feat: support base64 images

This commit is contained in:
Zhang Minghan 2024-02-15 15:48:06 +08:00
parent 4cd5c422c0
commit 1128f0014f
10 changed files with 78 additions and 45 deletions

View File

@ -6,26 +6,17 @@ import (
"errors"
"fmt"
"regexp"
"strings"
)
func formatMessages(props *ChatProps) interface{} {
if props.Model == globals.GPT4Vision {
base := props.Message[len(props.Message)-1].Content
urls := utils.ExtractImageUrls(base)
if len(urls) > 0 {
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
}
props.Message[len(props.Message)-1].Content = base
return props.Message
} else if globals.IsOpenAIVisionModels(props.Model) {
if globals.IsOpenAIVisionModels(props.Model) {
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
if message.Role == globals.User {
urls := utils.ExtractImageUrls(message.Content)
raw, urls := utils.ExtractImages(message.Content, true)
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
obj, err := utils.NewImage(url)
if err != nil {
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url))
return nil
}
@ -43,7 +34,7 @@ func formatMessages(props *ChatProps) interface{} {
Role: message.Role,
Content: utils.Prepend(images, MessageContent{
Type: "text",
Text: &message.Content,
Text: &raw,
}),
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,

View File

@ -55,9 +55,11 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
}
}
messages := formatMessages(props)
return ChatRequest{
Model: props.Model,
Messages: formatMessages(props),
Messages: messages,
MaxToken: props.Token,
Stream: stream,
PresencePenalty: props.PresencePenalty,

View File

@ -6,31 +6,21 @@ import (
"errors"
"fmt"
"regexp"
"strings"
)
func formatMessages(props *ChatProps) interface{} {
if props.Model == globals.GPT4Vision {
base := props.Message[len(props.Message)-1].Content
urls := utils.ExtractImageUrls(base)
if len(urls) > 0 {
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
}
props.Message[len(props.Message)-1].Content = base
return props.Message
} else if globals.IsOpenAIVisionModels(props.Model) {
if globals.IsOpenAIVisionModels(props.Model) {
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
if message.Role == globals.User {
urls := utils.ExtractImageUrls(message.Content)
content, urls := utils.ExtractImages(message.Content, true)
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
obj, err := utils.NewImage(url)
if err != nil {
return nil
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url))
} else {
props.Buffer.AddImage(obj)
}
props.Buffer.AddImage(obj)
return &MessageContent{
Type: "image_url",
ImageUrl: &ImageUrl{
@ -43,7 +33,7 @@ func formatMessages(props *ChatProps) interface{} {
Role: message.Role,
Content: utils.Prepend(images, MessageContent{
Type: "text",
Text: &message.Content,
Text: &content,
}),
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,

View File

@ -44,19 +44,21 @@ func getMimeType(content string) string {
}
func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart {
parts = append(parts, GeminiChatPart{
Text: &content,
})
if model == globals.GeminiPro {
return parts
return append(parts, GeminiChatPart{
Text: &content,
})
}
urls := utils.ExtractImageUrls(content)
raw, urls := utils.ExtractImages(content, true)
if len(urls) > geminiMaxImages {
urls = urls[:geminiMaxImages]
}
parts = append(parts, GeminiChatPart{
Text: &raw,
})
for _, url := range urls {
data, err := utils.ConvertToBase64(url)
if err != nil {

View File

@ -111,6 +111,7 @@ var OpenAIDalleModels = []string{
}
var OpenAIVisionModels = []string{
//GPT4Vision, GPT4All, GPT4Dalle,
GPT4VisionPreview, GPT41106VisionPreview,
}

View File

@ -69,7 +69,7 @@ func getImageProps(form RelayImageForm, messages []globals.Message, buffer *util
func getUrlFromBuffer(buffer *utils.Buffer) string {
content := buffer.Read()
urls := utils.ExtractImageUrls(content)
_, urls := utils.ExtractImages(content, true)
if len(urls) > 0 {
return urls[len(urls)-1]
}

View File

@ -169,11 +169,33 @@ func ExtractUrls(data string) []string {
return re.FindAllString(data, -1)
}
func ExtractImageUrls(data string) []string {
func ExtractImages(data string, includeBase64 bool) (content string, images []string) {
ext := ExtractExternalImages(data)
if includeBase64 {
images = append(ext, ExtractBase64Images(data)...)
} else {
images = ext
}
content = data
for _, image := range images {
content = strings.ReplaceAll(content, image, "")
}
return content, images
}
func ExtractBase64Images(data string) []string {
// get base64 images from data () (\n \\n [space] \\t \\r \\v \\f break the base64 string)
re := regexp.MustCompile(`(data:image/\w+;base64,[\w+/=]+)`)
return re.FindAllString(data, -1)
}
func ExtractExternalImages(data string) []string {
// https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic)(?:\s\S+)?)`)
return re.FindAllString(strings.ToLower(data), -1)
return re.FindAllString(data, -1)
}
func ContainUnicode(data string) bool {

View File

@ -31,12 +31,8 @@ func Base64EncodeBytes(raw []byte) string {
return base64.StdEncoding.EncodeToString(raw)
}
func Base64Decode(raw string) string {
if data, err := base64.StdEncoding.DecodeString(raw); err == nil {
return string(data)
} else {
return ""
}
func Base64Decode(raw string) ([]byte, error) {
return base64.StdEncoding.DecodeString(raw)
}
func Base64DecodeBytes(raw string) []byte {

View File

@ -19,6 +19,24 @@ type Image struct {
type Images []Image
func NewImage(url string) (*Image, error) {
if strings.HasPrefix(url, "data:image/") {
data := strings.Split(url, ",")
if len(data) != 2 {
return nil, nil
}
decoded, err := Base64Decode(data[1])
if err != nil {
return nil, err
}
img, _, err := image.Decode(strings.NewReader(string(decoded)))
if err != nil {
return nil, err
}
return &Image{Object: img}, nil
}
res, err := http.Get(url)
if err != nil {
return nil, err
@ -53,6 +71,14 @@ func NewImage(url string) (*Image, error) {
}
func ConvertToBase64(url string) (string, error) {
if strings.HasPrefix(url, "data:image/") {
data := strings.Split(url, ",")
if len(data) != 2 {
return "", nil
}
return data[1], nil
}
res, err := http.Get(url)
if err != nil {
return "", err

View File

@ -17,6 +17,9 @@ var maxTimeout = 30 * time.Minute
func newClient() *http.Client {
return &http.Client{
Timeout: maxTimeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
},
}
}