From ccd9f13eaa89734336744bda0a58ac33695a83b9 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Sat, 17 Feb 2024 23:39:57 +0800 Subject: [PATCH] fix: fix frontend base64 image matcher, input tokens and b64images content hidden in un-vision models --- adapter/request.go | 23 +++++++++++++++++ app/src/components/plugins/file.tsx | 2 +- channel/worker.go | 1 + globals/variables.go | 9 +++++++ manager/cache.go | 38 ----------------------------- manager/chat.go | 30 ++++------------------- manager/chat_completions.go | 2 +- manager/completions.go | 16 ------------ utils/buffer.go | 28 ++++++++++++--------- utils/tokenizer.go | 8 +++--- 10 files changed, 60 insertions(+), 97 deletions(-) delete mode 100644 manager/cache.go diff --git a/adapter/request.go b/adapter/request.go index 4f37e91..03e2461 100644 --- a/adapter/request.go +++ b/adapter/request.go @@ -2,6 +2,7 @@ package adapter import ( "chat/globals" + "chat/utils" "fmt" "strings" "time" @@ -44,3 +45,25 @@ func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.H return conf.ProcessError(err) } + +func ClearMessages(model string, messages []globals.Message) []globals.Message { + if globals.IsVisionModel(model) { + return messages + } + + return utils.Each[globals.Message](messages, func(message globals.Message) globals.Message { + if message.Role != globals.User { + return message + } + + images := utils.ExtractBase64Images(message.Content) + for _, image := range images { + if len(image) <= 46 { + continue + } + + message.Content = strings.Replace(message.Content, image, utils.Extract(image, 46, " ..."), -1) + } + return message + }) +} diff --git a/app/src/components/plugins/file.tsx b/app/src/components/plugins/file.tsx index 4b82c69..7621e22 100644 --- a/app/src/components/plugins/file.tsx +++ b/app/src/components/plugins/file.tsx @@ -24,7 +24,7 @@ export function parseFile(data: string, acceptDownload?: boolean) { const b64image = useMemo(() => { // get base64 image from content (like: data:image/png;base64,xxxxx) - const match = content.match(/(data:image\/.*;base64,.*=)/); + const match = content.match(/data:image\/([^;]+);base64,([a-zA-Z0-9+/=]+)/g); return match ? match[0] : ""; }, [filename, content]); diff --git a/channel/worker.go b/channel/worker.go index acc0fab..ea5fc37 100644 --- a/channel/worker.go +++ b/channel/worker.go @@ -50,6 +50,7 @@ func PreflightCache(cache *redis.Client, hash string, buffer *utils.Buffer, hook return idx, false, nil } + buffer.SetInputTokens(buf.CountInputToken()) buffer.SetToolCalls(buf.GetToolCalls()) buffer.SetFunctionCall(buf.GetFunctionCall()) return idx, true, hook(data) diff --git a/globals/variables.go b/globals/variables.go index 62261f1..ccdd152 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -115,6 +115,11 @@ var OpenAIVisionModels = []string{ GPT4VisionPreview, GPT41106VisionPreview, } +var VisionModels = []string{ + GPT4VisionPreview, GPT41106VisionPreview, + GeminiProVision, +} + func in(value string, slice []string) bool { for _, item := range slice { if item == value || strings.Contains(value, item) { @@ -133,3 +138,7 @@ func IsOpenAIVisionModels(model string) bool { // enable openai image format for gpt-4-vision-preview models return in(model, OpenAIVisionModels) } + +func IsVisionModel(model string) bool { + return in(model, VisionModels) +} diff --git a/manager/cache.go b/manager/cache.go deleted file mode 100644 index a5361f9..0000000 --- a/manager/cache.go +++ /dev/null @@ -1,38 +0,0 @@ -package manager - -import ( - "chat/channel" - "chat/globals" - "chat/utils" - "fmt" - "github.com/gin-gonic/gin" - "time" -) - -type CacheProps struct { - Message []globals.Message `json:"message" required:"true"` - Model string `json:"model" required:"true"` - Reversible bool `json:"reversible"` -} - -type CacheData struct { - Message string `json:"message"` -} - -func ExtractCacheData(c *gin.Context, props *CacheProps) *CacheData { - hash := utils.Md5Encrypt(utils.Marshal(props)) - data, err := utils.GetCacheFromContext(c).Get(c, fmt.Sprintf(":niodata:%s", hash)).Result() - if err == nil && data != "" { - return utils.UnmarshalForm[CacheData](data) - } - return nil -} - -func SaveCacheData(c *gin.Context, props *CacheProps, data *CacheData) { - if channel.ChargeInstance.IsBilling(props.Model) { - return - } - - hash := utils.Md5Encrypt(utils.Marshal(props)) - utils.GetCacheFromContext(c).Set(c, fmt.Sprintf(":niodata:%s", hash), utils.Marshal(data), time.Hour*12) -} diff --git a/manager/chat.go b/manager/chat.go index d9e83b3..a645cbd 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -63,11 +63,12 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve } }() - segment := web.UsingWebSegment(instance) - - model := instance.GetModel() db := conn.GetDB() cache := conn.GetCache() + + model := instance.GetModel() + segment := adapter.ClearMessages(model, web.UsingWebSegment(instance)) + check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model) conn.Send(globals.ChatSegmentResponse{ Conversation: instance.GetId(), @@ -83,15 +84,6 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve return message } - if form := ExtractCacheData(conn.GetCtx(), &CacheProps{ - Message: segment, - Model: model, - Reversible: plan, - }); form != nil { - MockStreamSender(conn, form.Message) - return form.Message - } - buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) hit, err := channel.NewChatRequestWithCache( cache, buffer, @@ -152,17 +144,5 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve Plan: plan, }) - result := buffer.ReadWithDefault(defaultMessage) - - if err == nil && result != defaultMessage { - SaveCacheData(conn.GetCtx(), &CacheProps{ - Message: segment, - Model: model, - Reversible: plan, - }, &CacheData{ - Message: result, - }) - } - - return result + return buffer.ReadWithDefault(defaultMessage) } diff --git a/manager/chat_completions.go b/manager/chat_completions.go index d76ed10..cd438ac 100644 --- a/manager/chat_completions.go +++ b/manager/chat_completions.go @@ -121,7 +121,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals Index: 0, Message: globals.Message{ Role: globals.Assistant, - Content: buffer.ReadWithDefault(defaultMessage), + Content: buffer.Read(), ToolCalls: buffer.GetToolCalls(), FunctionCall: buffer.GetFunctionCall(), }, diff --git a/manager/completions.go b/manager/completions.go index 473b18b..3ff298e 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -31,14 +31,6 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] return check.Error(), 0 } - if form := ExtractCacheData(c, &CacheProps{ - Message: segment, - Model: model, - Reversible: plan, - }); form != nil { - return form.Message, 0 - } - buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) hit, err := channel.NewChatRequestWithCache( cache, buffer, @@ -64,13 +56,5 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] CollectQuota(c, user, buffer, plan, err) } - SaveCacheData(c, &CacheProps{ - Message: segment, - Model: model, - Reversible: plan, - }, &CacheData{ - Message: buffer.ReadWithDefault(defaultMessage), - }) - return buffer.ReadWithDefault(defaultMessage), buffer.GetQuota() } diff --git a/utils/buffer.go b/utils/buffer.go index 5ea6fb1..3996ed8 100644 --- a/utils/buffer.go +++ b/utils/buffer.go @@ -23,7 +23,7 @@ type Buffer struct { Latest string `json:"latest"` Cursor int `json:"cursor"` Times int `json:"times"` - History []globals.Message `json:"history"` + InputTokens int `json:"input_tokens"` Images Images `json:"images"` ToolCalls *globals.ToolCalls `json:"tool_calls"` ToolCallsCursor int `json:"tool_calls_cursor"` @@ -31,8 +31,8 @@ type Buffer struct { Charge Charge `json:"-"` } -func initInputToken(charge Charge, model string, history []globals.Message) float32 { - if globals.IsOpenAIVisionModels(model) { +func initInputToken(model string, history []globals.Message) int { + if globals.IsVisionModel(model) { for _, message := range history { if message.Role == globals.User { content, _ := ExtractImages(message.Content, true) @@ -57,15 +57,20 @@ func initInputToken(charge Charge, model string, history []globals.Message) floa }) } - return CountInputToken(charge, model, history) + return CountTokenPrice(history, model) } func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer { + token := initInputToken(model, history) + return &Buffer{ - Model: model, - Quota: initInputToken(charge, model, history), - History: history, - Charge: charge, + Model: model, + Quota: CountInputQuota(charge, token), + InputTokens: token, + Charge: charge, + FunctionCall: nil, + ToolCalls: nil, + ToolCallsCursor: 0, } } @@ -115,7 +120,6 @@ func (b *Buffer) AddToolCalls(toolCalls *globals.ToolCalls) { } b.ToolCalls = toolCalls - b.ToolCallsCursor += 1 } func (b *Buffer) SetFunctionCall(functionCall *globals.FunctionCall) { @@ -177,12 +181,12 @@ func (b *Buffer) ReadTimes() int { return b.Times } -func (b *Buffer) ReadHistory() []globals.Message { - return b.History +func (b *Buffer) SetInputTokens(tokens int) { + b.InputTokens = tokens } func (b *Buffer) CountInputToken() int { - return GetWeightByModel(b.Model) * NumTokensFromMessages(b.History, b.Model) + return b.InputTokens } func (b *Buffer) CountOutputToken() int { diff --git a/utils/tokenizer.go b/utils/tokenizer.go index 1cf290f..728b152 100644 --- a/utils/tokenizer.go +++ b/utils/tokenizer.go @@ -69,12 +69,12 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int } func CountTokenPrice(messages []globals.Message, model string) int { - return NumTokensFromMessages(messages, model) + return NumTokensFromMessages(messages, model) * GetWeightByModel(model) } -func CountInputToken(charge Charge, model string, message []globals.Message) float32 { - if charge.IsBillingType(globals.TokenBilling) { - return float32(CountTokenPrice(message, model)) / 1000 * charge.GetInput() +func CountInputQuota(charge Charge, token int) float32 { + if charge.GetType() == globals.TokenBilling { + return float32(token) / 1000 * charge.GetInput() } return 0