diff --git a/adapter/azure/processor.go b/adapter/azure/processor.go index 3b0e30e..221e537 100644 --- a/adapter/azure/processor.go +++ b/adapter/azure/processor.go @@ -16,11 +16,11 @@ func formatMessages(props *ChatProps) interface{} { images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent { obj, err := utils.NewImage(url) if err != nil { - globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url)) + globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), utils.Extract(url, 24, "..."))) return nil } - props.Buffer.AddImage(obj, url) + props.Buffer.AddImage(obj) return &MessageContent{ Type: "image_url", diff --git a/adapter/chatgpt/processor.go b/adapter/chatgpt/processor.go index 668571a..e52436c 100644 --- a/adapter/chatgpt/processor.go +++ b/adapter/chatgpt/processor.go @@ -15,10 +15,9 @@ func formatMessages(props *ChatProps) interface{} { content, urls := utils.ExtractImages(message.Content, true) images := utils.EachNotNil[string, MessageContent](urls, func(url string) *MessageContent { obj, err := utils.NewImage(url) + props.Buffer.AddImage(obj) if err != nil { - globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), url)) - } else { - props.Buffer.AddImage(obj, url) + globals.Info(fmt.Sprintf("cannot process image: %s (source: %s)", err.Error(), utils.Extract(url, 24, "..."))) } return &MessageContent{ diff --git a/utils/buffer.go b/utils/buffer.go index 16d0f3b..db509ce 100644 --- a/utils/buffer.go +++ b/utils/buffer.go @@ -29,10 +29,37 @@ type Buffer struct { Charge Charge `json:"-"` } +func initInputToken(charge Charge, model string, history []globals.Message) float32 { + if globals.IsOpenAIVisionModels(model) { + for _, message := range history { + if message.Role == globals.User { + content, _ := ExtractImages(message.Content, true) + message.Content = content + } + } + + history = Each(history, func(message globals.Message) globals.Message { + if message.Role == globals.User { + raw, _ := ExtractImages(message.Content, true) + return globals.Message{ + Role: message.Role, + Content: raw, + ToolCalls: message.ToolCalls, + ToolCallId: message.ToolCallId, + } + } + + return message + }) + } + + return CountInputToken(charge, model, history) +} + func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer { return &Buffer{ Model: model, - Quota: CountInputToken(charge, model, history), + Quota: initInputToken(charge, model, history), History: history, Charge: charge, } @@ -58,14 +85,15 @@ func (b *Buffer) GetChunk() string { return b.Latest } -func (b *Buffer) AddImage(image *Image, source string) { - b.Images = append(b.Images, *image) +func (b *Buffer) AddImage(image *Image) { + if image != nil { + b.Images = append(b.Images, *image) + } if b.Charge.IsBillingType(globals.TokenBilling) { - b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput() - - // remove tokens from image source - b.Quota -= CountInputToken(b.Charge, b.Model, []globals.Message{{Content: source, Role: globals.User}}) + if image != nil { + b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput() + } } } diff --git a/utils/image.go b/utils/image.go index 6d192d3..e829bd5 100644 --- a/utils/image.go +++ b/utils/image.go @@ -20,10 +20,11 @@ type Images []Image func NewImage(url string) (*Image, error) { if strings.HasPrefix(url, "data:image/") { - data := strings.Split(url, ",") - if len(data) != 2 { + data := SafeSplit(url, ",", 2) + if data[1] == "" { return nil, nil } + decoded, err := Base64Decode(data[1]) if err != nil { return nil, err @@ -78,7 +79,7 @@ func ConvertToBase64(url string) (string, error) { } return data[1], nil } - + res, err := http.Get(url) if err != nil { return "", err