mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
feat: support base64 images
This commit is contained in:
parent
4cd5c422c0
commit
1128f0014f
@ -6,26 +6,17 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func formatMessages(props *ChatProps) interface{} {
|
||||
if props.Model == globals.GPT4Vision {
|
||||
base := props.Message[len(props.Message)-1].Content
|
||||
urls := utils.ExtractImageUrls(base)
|
||||
|
||||
if len(urls) > 0 {
|
||||
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
|
||||
}
|
||||
props.Message[len(props.Message)-1].Content = base
|
||||
return props.Message
|
||||
} else if globals.IsOpenAIVisionModels(props.Model) {
|
||||
if globals.IsOpenAIVisionModels(props.Model) {
|
||||
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
||||
if message.Role == globals.User {
|
||||
urls := utils.ExtractImageUrls(message.Content)
|
||||
raw, urls := utils.ExtractImages(message.Content, true)
|
||||
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
|
||||
obj, err := utils.NewImage(url)
|
||||
if err != nil {
|
||||
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -43,7 +34,7 @@ func formatMessages(props *ChatProps) interface{} {
|
||||
Role: message.Role,
|
||||
Content: utils.Prepend(images, MessageContent{
|
||||
Type: "text",
|
||||
Text: &message.Content,
|
||||
Text: &raw,
|
||||
}),
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
|
@ -55,9 +55,11 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
|
||||
}
|
||||
}
|
||||
|
||||
messages := formatMessages(props)
|
||||
|
||||
return ChatRequest{
|
||||
Model: props.Model,
|
||||
Messages: formatMessages(props),
|
||||
Messages: messages,
|
||||
MaxToken: props.Token,
|
||||
Stream: stream,
|
||||
PresencePenalty: props.PresencePenalty,
|
||||
|
@ -6,30 +6,20 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func formatMessages(props *ChatProps) interface{} {
|
||||
if props.Model == globals.GPT4Vision {
|
||||
base := props.Message[len(props.Message)-1].Content
|
||||
urls := utils.ExtractImageUrls(base)
|
||||
|
||||
if len(urls) > 0 {
|
||||
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
|
||||
}
|
||||
props.Message[len(props.Message)-1].Content = base
|
||||
return props.Message
|
||||
} else if globals.IsOpenAIVisionModels(props.Model) {
|
||||
if globals.IsOpenAIVisionModels(props.Model) {
|
||||
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
||||
if message.Role == globals.User {
|
||||
urls := utils.ExtractImageUrls(message.Content)
|
||||
content, urls := utils.ExtractImages(message.Content, true)
|
||||
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
|
||||
obj, err := utils.NewImage(url)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url))
|
||||
} else {
|
||||
props.Buffer.AddImage(obj)
|
||||
}
|
||||
|
||||
return &MessageContent{
|
||||
Type: "image_url",
|
||||
@ -43,7 +33,7 @@ func formatMessages(props *ChatProps) interface{} {
|
||||
Role: message.Role,
|
||||
Content: utils.Prepend(images, MessageContent{
|
||||
Type: "text",
|
||||
Text: &message.Content,
|
||||
Text: &content,
|
||||
}),
|
||||
ToolCalls: message.ToolCalls,
|
||||
ToolCallId: message.ToolCallId,
|
||||
|
@ -44,19 +44,21 @@ func getMimeType(content string) string {
|
||||
}
|
||||
|
||||
func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart {
|
||||
parts = append(parts, GeminiChatPart{
|
||||
if model == globals.GeminiPro {
|
||||
return append(parts, GeminiChatPart{
|
||||
Text: &content,
|
||||
})
|
||||
|
||||
if model == globals.GeminiPro {
|
||||
return parts
|
||||
}
|
||||
|
||||
urls := utils.ExtractImageUrls(content)
|
||||
raw, urls := utils.ExtractImages(content, true)
|
||||
if len(urls) > geminiMaxImages {
|
||||
urls = urls[:geminiMaxImages]
|
||||
}
|
||||
|
||||
parts = append(parts, GeminiChatPart{
|
||||
Text: &raw,
|
||||
})
|
||||
|
||||
for _, url := range urls {
|
||||
data, err := utils.ConvertToBase64(url)
|
||||
if err != nil {
|
||||
|
@ -111,6 +111,7 @@ var OpenAIDalleModels = []string{
|
||||
}
|
||||
|
||||
var OpenAIVisionModels = []string{
|
||||
//GPT4Vision, GPT4All, GPT4Dalle,
|
||||
GPT4VisionPreview, GPT41106VisionPreview,
|
||||
}
|
||||
|
||||
|
@ -69,7 +69,7 @@ func getImageProps(form RelayImageForm, messages []globals.Message, buffer *util
|
||||
func getUrlFromBuffer(buffer *utils.Buffer) string {
|
||||
content := buffer.Read()
|
||||
|
||||
urls := utils.ExtractImageUrls(content)
|
||||
_, urls := utils.ExtractImages(content, true)
|
||||
if len(urls) > 0 {
|
||||
return urls[len(urls)-1]
|
||||
}
|
||||
|
@ -169,11 +169,33 @@ func ExtractUrls(data string) []string {
|
||||
return re.FindAllString(data, -1)
|
||||
}
|
||||
|
||||
func ExtractImageUrls(data string) []string {
|
||||
func ExtractImages(data string, includeBase64 bool) (content string, images []string) {
|
||||
ext := ExtractExternalImages(data)
|
||||
if includeBase64 {
|
||||
images = append(ext, ExtractBase64Images(data)...)
|
||||
} else {
|
||||
images = ext
|
||||
}
|
||||
|
||||
content = data
|
||||
for _, image := range images {
|
||||
content = strings.ReplaceAll(content, image, "")
|
||||
}
|
||||
|
||||
return content, images
|
||||
}
|
||||
|
||||
func ExtractBase64Images(data string) []string {
|
||||
// get base64 images from data (data:image/png;base64,xxxxxx) (\n \\n [space] \\t \\r \\v \\f break the base64 string)
|
||||
re := regexp.MustCompile(`(data:image/\w+;base64,[\w+/=]+)`)
|
||||
return re.FindAllString(data, -1)
|
||||
}
|
||||
|
||||
func ExtractExternalImages(data string) []string {
|
||||
// https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
|
||||
|
||||
re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic)(?:\s\S+)?)`)
|
||||
return re.FindAllString(strings.ToLower(data), -1)
|
||||
return re.FindAllString(data, -1)
|
||||
}
|
||||
|
||||
func ContainUnicode(data string) bool {
|
||||
|
@ -31,12 +31,8 @@ func Base64EncodeBytes(raw []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(raw)
|
||||
}
|
||||
|
||||
func Base64Decode(raw string) string {
|
||||
if data, err := base64.StdEncoding.DecodeString(raw); err == nil {
|
||||
return string(data)
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
func Base64Decode(raw string) ([]byte, error) {
|
||||
return base64.StdEncoding.DecodeString(raw)
|
||||
}
|
||||
|
||||
func Base64DecodeBytes(raw string) []byte {
|
||||
|
@ -19,6 +19,24 @@ type Image struct {
|
||||
type Images []Image
|
||||
|
||||
func NewImage(url string) (*Image, error) {
|
||||
if strings.HasPrefix(url, "data:image/") {
|
||||
data := strings.Split(url, ",")
|
||||
if len(data) != 2 {
|
||||
return nil, nil
|
||||
}
|
||||
decoded, err := Base64Decode(data[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
img, _, err := image.Decode(strings.NewReader(string(decoded)))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Image{Object: img}, nil
|
||||
}
|
||||
|
||||
res, err := http.Get(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -53,6 +71,14 @@ func NewImage(url string) (*Image, error) {
|
||||
}
|
||||
|
||||
func ConvertToBase64(url string) (string, error) {
|
||||
if strings.HasPrefix(url, "data:image/") {
|
||||
data := strings.Split(url, ",")
|
||||
if len(data) != 2 {
|
||||
return "", nil
|
||||
}
|
||||
return data[1], nil
|
||||
}
|
||||
|
||||
res, err := http.Get(url)
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
@ -17,6 +17,9 @@ var maxTimeout = 30 * time.Minute
|
||||
func newClient() *http.Client {
|
||||
return &http.Client{
|
||||
Timeout: maxTimeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user