From 1128f0014f094fde481df6447991a884cd2edbe0 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Thu, 15 Feb 2024 15:48:06 +0800 Subject: [PATCH] feat: support base64 images --- adapter/azure/processor.go | 17 ++++------------- adapter/chatgpt/chat.go | 4 +++- adapter/chatgpt/processor.go | 22 ++++++---------------- adapter/palm2/formatter.go | 14 ++++++++------ globals/variables.go | 1 + manager/images.go | 2 +- utils/char.go | 26 ++++++++++++++++++++++++-- utils/encrypt.go | 8 ++------ utils/image.go | 26 ++++++++++++++++++++++++++ utils/net.go | 3 +++ 10 files changed, 78 insertions(+), 45 deletions(-) diff --git a/adapter/azure/processor.go b/adapter/azure/processor.go index 91bd444..797f652 100644 --- a/adapter/azure/processor.go +++ b/adapter/azure/processor.go @@ -6,26 +6,17 @@ import ( "errors" "fmt" "regexp" - "strings" ) func formatMessages(props *ChatProps) interface{} { - if props.Model == globals.GPT4Vision { - base := props.Message[len(props.Message)-1].Content - urls := utils.ExtractImageUrls(base) - - if len(urls) > 0 { - base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base) - } - props.Message[len(props.Message)-1].Content = base - return props.Message - } else if globals.IsOpenAIVisionModels(props.Model) { + if globals.IsOpenAIVisionModels(props.Model) { return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message { if message.Role == globals.User { - urls := utils.ExtractImageUrls(message.Content) + raw, urls := utils.ExtractImages(message.Content, true) images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent { obj, err := utils.NewImage(url) if err != nil { + globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url)) return nil } @@ -43,7 +34,7 @@ func formatMessages(props *ChatProps) interface{} { Role: message.Role, Content: utils.Prepend(images, MessageContent{ Type: "text", - Text: &message.Content, + Text: &raw, }), ToolCalls: message.ToolCalls, ToolCallId: message.ToolCallId, diff --git a/adapter/chatgpt/chat.go b/adapter/chatgpt/chat.go index 0d4cfe1..bd5a5d0 100644 --- a/adapter/chatgpt/chat.go +++ b/adapter/chatgpt/chat.go @@ -55,9 +55,11 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { } } + messages := formatMessages(props) + return ChatRequest{ Model: props.Model, - Messages: formatMessages(props), + Messages: messages, MaxToken: props.Token, Stream: stream, PresencePenalty: props.PresencePenalty, diff --git a/adapter/chatgpt/processor.go b/adapter/chatgpt/processor.go index 1abe9e1..29baa48 100644 --- a/adapter/chatgpt/processor.go +++ b/adapter/chatgpt/processor.go @@ -6,31 +6,21 @@ import ( "errors" "fmt" "regexp" - "strings" ) func formatMessages(props *ChatProps) interface{} { - if props.Model == globals.GPT4Vision { - base := props.Message[len(props.Message)-1].Content - urls := utils.ExtractImageUrls(base) - - if len(urls) > 0 { - base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base) - } - props.Message[len(props.Message)-1].Content = base - return props.Message - } else if globals.IsOpenAIVisionModels(props.Model) { + if globals.IsOpenAIVisionModels(props.Model) { return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message { if message.Role == globals.User { - urls := utils.ExtractImageUrls(message.Content) + content, urls := utils.ExtractImages(message.Content, true) images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent { obj, err := utils.NewImage(url) if err != nil { - return nil + globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url)) + } else { + props.Buffer.AddImage(obj) } - props.Buffer.AddImage(obj) - return &MessageContent{ Type: "image_url", ImageUrl: &ImageUrl{ @@ -43,7 +33,7 @@ func formatMessages(props *ChatProps) interface{} { Role: message.Role, Content: utils.Prepend(images, MessageContent{ Type: "text", - Text: &message.Content, + Text: &content, }), ToolCalls: message.ToolCalls, ToolCallId: message.ToolCallId, diff --git a/adapter/palm2/formatter.go b/adapter/palm2/formatter.go index b96d20c..6a16027 100644 --- a/adapter/palm2/formatter.go +++ b/adapter/palm2/formatter.go @@ -44,19 +44,21 @@ func getMimeType(content string) string { } func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart { - parts = append(parts, GeminiChatPart{ - Text: &content, - }) - if model == globals.GeminiPro { - return parts + return append(parts, GeminiChatPart{ + Text: &content, + }) } - urls := utils.ExtractImageUrls(content) + raw, urls := utils.ExtractImages(content, true) if len(urls) > geminiMaxImages { urls = urls[:geminiMaxImages] } + parts = append(parts, GeminiChatPart{ + Text: &raw, + }) + for _, url := range urls { data, err := utils.ConvertToBase64(url) if err != nil { diff --git a/globals/variables.go b/globals/variables.go index e97a938..62261f1 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -111,6 +111,7 @@ var OpenAIDalleModels = []string{ } var OpenAIVisionModels = []string{ + //GPT4Vision, GPT4All, GPT4Dalle, GPT4VisionPreview, GPT41106VisionPreview, } diff --git a/manager/images.go b/manager/images.go index 551c263..830070d 100644 --- a/manager/images.go +++ b/manager/images.go @@ -69,7 +69,7 @@ func getImageProps(form RelayImageForm, messages []globals.Message, buffer *util func getUrlFromBuffer(buffer *utils.Buffer) string { content := buffer.Read() - urls := utils.ExtractImageUrls(content) + _, urls := utils.ExtractImages(content, true) if len(urls) > 0 { return urls[len(urls)-1] } diff --git a/utils/char.go b/utils/char.go index 7fbed1f..6ec1905 100644 --- a/utils/char.go +++ b/utils/char.go @@ -169,11 +169,33 @@ func ExtractUrls(data string) []string { return re.FindAllString(data, -1) } -func ExtractImageUrls(data string) []string { +func ExtractImages(data string, includeBase64 bool) (content string, images []string) { + ext := ExtractExternalImages(data) + if includeBase64 { + images = append(ext, ExtractBase64Images(data)...) + } else { + images = ext + } + + content = data + for _, image := range images { + content = strings.ReplaceAll(content, image, "") + } + + return content, images +} + +func ExtractBase64Images(data string) []string { + // get base64 images from data () (\n \\n [space] \\t \\r \\v \\f break the base64 string) + re := regexp.MustCompile(`(data:image/\w+;base64,[\w+/=]+)`) + return re.FindAllString(data, -1) +} + +func ExtractExternalImages(data string) []string { // https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic)(?:\s\S+)?)`) - return re.FindAllString(strings.ToLower(data), -1) + return re.FindAllString(data, -1) } func ContainUnicode(data string) bool { diff --git a/utils/encrypt.go b/utils/encrypt.go index b7ee38c..4b24257 100644 --- a/utils/encrypt.go +++ b/utils/encrypt.go @@ -31,12 +31,8 @@ func Base64EncodeBytes(raw []byte) string { return base64.StdEncoding.EncodeToString(raw) } -func Base64Decode(raw string) string { - if data, err := base64.StdEncoding.DecodeString(raw); err == nil { - return string(data) - } else { - return "" - } +func Base64Decode(raw string) ([]byte, error) { + return base64.StdEncoding.DecodeString(raw) } func Base64DecodeBytes(raw string) []byte { diff --git a/utils/image.go b/utils/image.go index 6254b4f..6d192d3 100644 --- a/utils/image.go +++ b/utils/image.go @@ -19,6 +19,24 @@ type Image struct { type Images []Image func NewImage(url string) (*Image, error) { + if strings.HasPrefix(url, "data:image/") { + data := strings.Split(url, ",") + if len(data) != 2 { + return nil, nil + } + decoded, err := Base64Decode(data[1]) + if err != nil { + return nil, err + } + + img, _, err := image.Decode(strings.NewReader(string(decoded))) + if err != nil { + return nil, err + } + + return &Image{Object: img}, nil + } + res, err := http.Get(url) if err != nil { return nil, err @@ -53,6 +71,14 @@ func NewImage(url string) (*Image, error) { } func ConvertToBase64(url string) (string, error) { + if strings.HasPrefix(url, "data:image/") { + data := strings.Split(url, ",") + if len(data) != 2 { + return "", nil + } + return data[1], nil + } + res, err := http.Get(url) if err != nil { return "", err diff --git a/utils/net.go b/utils/net.go index 3372554..be3da95 100644 --- a/utils/net.go +++ b/utils/net.go @@ -17,6 +17,9 @@ var maxTimeout = 30 * time.Minute func newClient() *http.Client { return &http.Client{ Timeout: maxTimeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + }, } }