mirror of
https://github.com/coaidev/coai.git
synced 2025-05-20 05:20:15 +09:00
feat: support base64 images
This commit is contained in:
parent
4cd5c422c0
commit
1128f0014f
@ -6,26 +6,17 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func formatMessages(props *ChatProps) interface{} {
|
func formatMessages(props *ChatProps) interface{} {
|
||||||
if props.Model == globals.GPT4Vision {
|
if globals.IsOpenAIVisionModels(props.Model) {
|
||||||
base := props.Message[len(props.Message)-1].Content
|
|
||||||
urls := utils.ExtractImageUrls(base)
|
|
||||||
|
|
||||||
if len(urls) > 0 {
|
|
||||||
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
|
|
||||||
}
|
|
||||||
props.Message[len(props.Message)-1].Content = base
|
|
||||||
return props.Message
|
|
||||||
} else if globals.IsOpenAIVisionModels(props.Model) {
|
|
||||||
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
||||||
if message.Role == globals.User {
|
if message.Role == globals.User {
|
||||||
urls := utils.ExtractImageUrls(message.Content)
|
raw, urls := utils.ExtractImages(message.Content, true)
|
||||||
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
|
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
|
||||||
obj, err := utils.NewImage(url)
|
obj, err := utils.NewImage(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,7 +34,7 @@ func formatMessages(props *ChatProps) interface{} {
|
|||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
Content: utils.Prepend(images, MessageContent{
|
Content: utils.Prepend(images, MessageContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: &message.Content,
|
Text: &raw,
|
||||||
}),
|
}),
|
||||||
ToolCalls: message.ToolCalls,
|
ToolCalls: message.ToolCalls,
|
||||||
ToolCallId: message.ToolCallId,
|
ToolCallId: message.ToolCallId,
|
||||||
|
@ -55,9 +55,11 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
messages := formatMessages(props)
|
||||||
|
|
||||||
return ChatRequest{
|
return ChatRequest{
|
||||||
Model: props.Model,
|
Model: props.Model,
|
||||||
Messages: formatMessages(props),
|
Messages: messages,
|
||||||
MaxToken: props.Token,
|
MaxToken: props.Token,
|
||||||
Stream: stream,
|
Stream: stream,
|
||||||
PresencePenalty: props.PresencePenalty,
|
PresencePenalty: props.PresencePenalty,
|
||||||
|
@ -6,31 +6,21 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func formatMessages(props *ChatProps) interface{} {
|
func formatMessages(props *ChatProps) interface{} {
|
||||||
if props.Model == globals.GPT4Vision {
|
if globals.IsOpenAIVisionModels(props.Model) {
|
||||||
base := props.Message[len(props.Message)-1].Content
|
|
||||||
urls := utils.ExtractImageUrls(base)
|
|
||||||
|
|
||||||
if len(urls) > 0 {
|
|
||||||
base = fmt.Sprintf("%s %s", strings.Join(urls, " "), base)
|
|
||||||
}
|
|
||||||
props.Message[len(props.Message)-1].Content = base
|
|
||||||
return props.Message
|
|
||||||
} else if globals.IsOpenAIVisionModels(props.Model) {
|
|
||||||
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
||||||
if message.Role == globals.User {
|
if message.Role == globals.User {
|
||||||
urls := utils.ExtractImageUrls(message.Content)
|
content, urls := utils.ExtractImages(message.Content, true)
|
||||||
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
|
images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent {
|
||||||
obj, err := utils.NewImage(url)
|
obj, err := utils.NewImage(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url))
|
||||||
|
} else {
|
||||||
|
props.Buffer.AddImage(obj)
|
||||||
}
|
}
|
||||||
|
|
||||||
props.Buffer.AddImage(obj)
|
|
||||||
|
|
||||||
return &MessageContent{
|
return &MessageContent{
|
||||||
Type: "image_url",
|
Type: "image_url",
|
||||||
ImageUrl: &ImageUrl{
|
ImageUrl: &ImageUrl{
|
||||||
@ -43,7 +33,7 @@ func formatMessages(props *ChatProps) interface{} {
|
|||||||
Role: message.Role,
|
Role: message.Role,
|
||||||
Content: utils.Prepend(images, MessageContent{
|
Content: utils.Prepend(images, MessageContent{
|
||||||
Type: "text",
|
Type: "text",
|
||||||
Text: &message.Content,
|
Text: &content,
|
||||||
}),
|
}),
|
||||||
ToolCalls: message.ToolCalls,
|
ToolCalls: message.ToolCalls,
|
||||||
ToolCallId: message.ToolCallId,
|
ToolCallId: message.ToolCallId,
|
||||||
|
@ -44,19 +44,21 @@ func getMimeType(content string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart {
|
func getGeminiContent(parts []GeminiChatPart, content string, model string) []GeminiChatPart {
|
||||||
parts = append(parts, GeminiChatPart{
|
|
||||||
Text: &content,
|
|
||||||
})
|
|
||||||
|
|
||||||
if model == globals.GeminiPro {
|
if model == globals.GeminiPro {
|
||||||
return parts
|
return append(parts, GeminiChatPart{
|
||||||
|
Text: &content,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
urls := utils.ExtractImageUrls(content)
|
raw, urls := utils.ExtractImages(content, true)
|
||||||
if len(urls) > geminiMaxImages {
|
if len(urls) > geminiMaxImages {
|
||||||
urls = urls[:geminiMaxImages]
|
urls = urls[:geminiMaxImages]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
parts = append(parts, GeminiChatPart{
|
||||||
|
Text: &raw,
|
||||||
|
})
|
||||||
|
|
||||||
for _, url := range urls {
|
for _, url := range urls {
|
||||||
data, err := utils.ConvertToBase64(url)
|
data, err := utils.ConvertToBase64(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -111,6 +111,7 @@ var OpenAIDalleModels = []string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var OpenAIVisionModels = []string{
|
var OpenAIVisionModels = []string{
|
||||||
|
//GPT4Vision, GPT4All, GPT4Dalle,
|
||||||
GPT4VisionPreview, GPT41106VisionPreview,
|
GPT4VisionPreview, GPT41106VisionPreview,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ func getImageProps(form RelayImageForm, messages []globals.Message, buffer *util
|
|||||||
func getUrlFromBuffer(buffer *utils.Buffer) string {
|
func getUrlFromBuffer(buffer *utils.Buffer) string {
|
||||||
content := buffer.Read()
|
content := buffer.Read()
|
||||||
|
|
||||||
urls := utils.ExtractImageUrls(content)
|
_, urls := utils.ExtractImages(content, true)
|
||||||
if len(urls) > 0 {
|
if len(urls) > 0 {
|
||||||
return urls[len(urls)-1]
|
return urls[len(urls)-1]
|
||||||
}
|
}
|
||||||
|
@ -169,11 +169,33 @@ func ExtractUrls(data string) []string {
|
|||||||
return re.FindAllString(data, -1)
|
return re.FindAllString(data, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ExtractImageUrls(data string) []string {
|
func ExtractImages(data string, includeBase64 bool) (content string, images []string) {
|
||||||
|
ext := ExtractExternalImages(data)
|
||||||
|
if includeBase64 {
|
||||||
|
images = append(ext, ExtractBase64Images(data)...)
|
||||||
|
} else {
|
||||||
|
images = ext
|
||||||
|
}
|
||||||
|
|
||||||
|
content = data
|
||||||
|
for _, image := range images {
|
||||||
|
content = strings.ReplaceAll(content, image, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
return content, images
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractBase64Images(data string) []string {
|
||||||
|
// get base64 images from data (data:image/png;base64,xxxxxx) (\n \\n [space] \\t \\r \\v \\f break the base64 string)
|
||||||
|
re := regexp.MustCompile(`(data:image/\w+;base64,[\w+/=]+)`)
|
||||||
|
return re.FindAllString(data, -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExtractExternalImages(data string) []string {
|
||||||
// https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
|
// https://platform.openai.com/docs/guides/vision/what-type-of-files-can-i-upload
|
||||||
|
|
||||||
re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic)(?:\s\S+)?)`)
|
re := regexp.MustCompile(`(https?://\S+\.(?:png|jpg|jpeg|gif|webp|heif|heic)(?:\s\S+)?)`)
|
||||||
return re.FindAllString(strings.ToLower(data), -1)
|
return re.FindAllString(data, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ContainUnicode(data string) bool {
|
func ContainUnicode(data string) bool {
|
||||||
|
@ -31,12 +31,8 @@ func Base64EncodeBytes(raw []byte) string {
|
|||||||
return base64.StdEncoding.EncodeToString(raw)
|
return base64.StdEncoding.EncodeToString(raw)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Base64Decode(raw string) string {
|
func Base64Decode(raw string) ([]byte, error) {
|
||||||
if data, err := base64.StdEncoding.DecodeString(raw); err == nil {
|
return base64.StdEncoding.DecodeString(raw)
|
||||||
return string(data)
|
|
||||||
} else {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func Base64DecodeBytes(raw string) []byte {
|
func Base64DecodeBytes(raw string) []byte {
|
||||||
|
@ -19,6 +19,24 @@ type Image struct {
|
|||||||
type Images []Image
|
type Images []Image
|
||||||
|
|
||||||
func NewImage(url string) (*Image, error) {
|
func NewImage(url string) (*Image, error) {
|
||||||
|
if strings.HasPrefix(url, "data:image/") {
|
||||||
|
data := strings.Split(url, ",")
|
||||||
|
if len(data) != 2 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
decoded, err := Base64Decode(data[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
img, _, err := image.Decode(strings.NewReader(string(decoded)))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Image{Object: img}, nil
|
||||||
|
}
|
||||||
|
|
||||||
res, err := http.Get(url)
|
res, err := http.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -53,6 +71,14 @@ func NewImage(url string) (*Image, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func ConvertToBase64(url string) (string, error) {
|
func ConvertToBase64(url string) (string, error) {
|
||||||
|
if strings.HasPrefix(url, "data:image/") {
|
||||||
|
data := strings.Split(url, ",")
|
||||||
|
if len(data) != 2 {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
return data[1], nil
|
||||||
|
}
|
||||||
|
|
||||||
res, err := http.Get(url)
|
res, err := http.Get(url)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -17,6 +17,9 @@ var maxTimeout = 30 * time.Minute
|
|||||||
func newClient() *http.Client {
|
func newClient() *http.Client {
|
||||||
return &http.Client{
|
return &http.Client{
|
||||||
Timeout: maxTimeout,
|
Timeout: maxTimeout,
|
||||||
|
Transport: &http.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user