fix: frontend base64 image matcher, input token counting, and hiding base64 image content from non-vision models

Zhang Minghan 2024-02-17 23:39:57 +08:00
parent 357d22a940
commit ccd9f13eaa
10 changed files with 60 additions and 97 deletions

View File

@@ -2,6 +2,7 @@ package adapter
import (
    "chat/globals"
    "chat/utils"
    "fmt"
    "strings"
    "time"
@@ -44,3 +45,25 @@ func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.H
    return conf.ProcessError(err)
}

// ClearMessages strips bulky base64 image payloads from user messages when
// the model has no vision capability: any inline image longer than 46
// characters is truncated to a short prefix followed by " ...".
func ClearMessages(model string, messages []globals.Message) []globals.Message {
    if globals.IsVisionModel(model) {
        return messages
    }

    return utils.Each[globals.Message](messages, func(message globals.Message) globals.Message {
        if message.Role != globals.User {
            return message
        }

        images := utils.ExtractBase64Images(message.Content)
        for _, image := range images {
            if len(image) <= 46 {
                continue
            }
            message.Content = strings.Replace(message.Content, image, utils.Extract(image, 46, " ..."), -1)
        }
        return message
    })
}
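For illustration, the truncation behaves roughly like this. This is a minimal, self-contained sketch: extractBase64Images and extract are hypothetical stand-ins for the repo's utils.ExtractBase64Images and utils.Extract, whose implementations are not shown in this diff.

package main

import (
    "fmt"
    "regexp"
    "strings"
)

// hypothetical stand-in for utils.ExtractBase64Images
func extractBase64Images(content string) []string {
    re := regexp.MustCompile(`data:image/[^;]+;base64,[a-zA-Z0-9+/=]+`)
    return re.FindAllString(content, -1)
}

// hypothetical stand-in for utils.Extract: keep the first n characters, append a tail
func extract(s string, n int, tail string) string {
    if len(s) <= n {
        return s
    }
    return s[:n] + tail
}

func main() {
    content := "see: data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
    for _, image := range extractBase64Images(content) {
        if len(image) > 46 {
            content = strings.Replace(content, image, extract(image, 46, " ..."), -1)
        }
    }
    fmt.Println(content) // see: data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA ...
}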

View File

@@ -24,7 +24,7 @@ export function parseFile(data: string, acceptDownload?: boolean) {
const b64image = useMemo(() => {
    // get base64 image from content (like: data:image/png;base64,xxxxx)
    const match = content.match(/(data:image\/.*;base64,.*=)/);
    const match = content.match(/data:image\/([^;]+);base64,([a-zA-Z0-9+/=]+)/g);
    return match ? match[0] : "";
}, [filename, content]);
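The old pattern had two problems: the greedy .* could swallow unrelated text (or several images at once), and the match was anchored to a trailing =, so unpadded base64 payloads were never matched. The rewrite restricts the payload to base64 characters and matches each data URI separately with the g flag. Since the pattern syntax is the same in Go's RE2, the fix can be sanity-checked in the language of the rest of this commit (the test strings here are made up):

package main

import (
    "fmt"
    "regexp"
)

func main() {
    re := regexp.MustCompile(`data:image/([^;]+);base64,([a-zA-Z0-9+/=]+)`)

    // unpadded payload (no trailing "="): the old `.*=` pattern missed this
    fmt.Println(re.FindString("data:image/webp;base64,UklGRh4"))

    // two images in one string are matched one by one, rather than being
    // swallowed into a single greedy match from the first prefix to the last "="
    fmt.Println(re.FindAllString("a: data:image/png;base64,AAAA= b: data:image/jpeg;base64,BBBB=", -1))
}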

View File

@@ -50,6 +50,7 @@ func PreflightCache(cache *redis.Client, hash string, buffer *utils.Buffer, hook
        return idx, false, nil
    }

    buffer.SetInputTokens(buf.CountInputToken())
    buffer.SetToolCalls(buf.GetToolCalls())
    buffer.SetFunctionCall(buf.GetFunctionCall())
    return idx, true, hook(data)
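On a cache hit the live buffer is rebuilt from the stored one, so it has to inherit the original counters; without the added SetInputTokens call, a replayed response appeared to have zero input tokens. A rough illustration with a hypothetical, simplified stand-in type (not the repo's actual utils.Buffer):

package main

import "fmt"

// simplified stand-in for the cached and live buffers
type buffer struct {
    inputTokens int
}

func main() {
    cached := buffer{inputTokens: 420} // state restored from the cache
    live := buffer{}                   // fresh buffer for this request

    // what the added SetInputTokens call above does on a hit
    live.inputTokens = cached.inputTokens

    fmt.Println(live.inputTokens) // 420 instead of 0
}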

View File

@@ -115,6 +115,11 @@ var OpenAIVisionModels = []string{
    GPT4VisionPreview, GPT41106VisionPreview,
}

var VisionModels = []string{
    GPT4VisionPreview, GPT41106VisionPreview,
    GeminiProVision,
}

func in(value string, slice []string) bool {
    for _, item := range slice {
        if item == value || strings.Contains(value, item) {
@@ -133,3 +138,7 @@ func IsOpenAIVisionModels(model string) bool {
    // enable openai image format for gpt-4-vision-preview models
    return in(model, OpenAIVisionModels)
}

func IsVisionModel(model string) bool {
    return in(model, VisionModels)
}
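Because in matches by substring as well as by equality, versioned or suffixed model names pass the check too. A small self-contained demonstration; the constant values are assumed from the usual model IDs and are not shown in this diff:

package main

import (
    "fmt"
    "strings"
)

// local copies for illustration; the real list lives in chat/globals
var visionModels = []string{"gpt-4-vision-preview", "gpt-4-1106-vision-preview", "gemini-pro-vision"}

func isVisionModel(model string) bool {
    for _, item := range visionModels {
        if item == model || strings.Contains(model, item) {
            return true
        }
    }
    return false
}

func main() {
    fmt.Println(isVisionModel("gpt-4-vision-preview"))  // true: exact match
    fmt.Println(isVisionModel("gemini-pro-vision-001")) // true: substring match
    fmt.Println(isVisionModel("gpt-3.5-turbo"))         // false
}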

View File

@@ -1,38 +0,0 @@
package manager

import (
    "chat/channel"
    "chat/globals"
    "chat/utils"
    "fmt"
    "github.com/gin-gonic/gin"
    "time"
)

type CacheProps struct {
    Message    []globals.Message `json:"message" required:"true"`
    Model      string            `json:"model" required:"true"`
    Reversible bool              `json:"reversible"`
}

type CacheData struct {
    Message string `json:"message"`
}

func ExtractCacheData(c *gin.Context, props *CacheProps) *CacheData {
    hash := utils.Md5Encrypt(utils.Marshal(props))
    data, err := utils.GetCacheFromContext(c).Get(c, fmt.Sprintf(":niodata:%s", hash)).Result()
    if err == nil && data != "" {
        return utils.UnmarshalForm[CacheData](data)
    }

    return nil
}

func SaveCacheData(c *gin.Context, props *CacheProps, data *CacheData) {
    if channel.ChargeInstance.IsBilling(props.Model) {
        return
    }

    hash := utils.Md5Encrypt(utils.Marshal(props))
    utils.GetCacheFromContext(c).Set(c, fmt.Sprintf(":niodata:%s", hash), utils.Marshal(data), time.Hour*12)
}

View File

@@ -63,11 +63,12 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
        }
    }()

    segment := web.UsingWebSegment(instance)
    model := instance.GetModel()
    db := conn.GetDB()
    cache := conn.GetCache()

    model := instance.GetModel()
    segment := adapter.ClearMessages(model, web.UsingWebSegment(instance))

    check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model)
    conn.Send(globals.ChatSegmentResponse{
        Conversation: instance.GetId(),
@@ -83,15 +84,6 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
        return message
    }

    if form := ExtractCacheData(conn.GetCtx(), &CacheProps{
        Message:    segment,
        Model:      model,
        Reversible: plan,
    }); form != nil {
        MockStreamSender(conn, form.Message)
        return form.Message
    }

    buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
    hit, err := channel.NewChatRequestWithCache(
        cache, buffer,
@@ -152,17 +144,5 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
        Plan: plan,
    })

    result := buffer.ReadWithDefault(defaultMessage)

    if err == nil && result != defaultMessage {
        SaveCacheData(conn.GetCtx(), &CacheProps{
            Message:    segment,
            Model:      model,
            Reversible: plan,
        }, &CacheData{
            Message: result,
        })
    }

    return result
    return buffer.ReadWithDefault(defaultMessage)
}

View File

@@ -121,7 +121,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
    Index: 0,
    Message: globals.Message{
        Role:         globals.Assistant,
        Content:      buffer.ReadWithDefault(defaultMessage),
        Content:      buffer.Read(),
        ToolCalls:    buffer.GetToolCalls(),
        FunctionCall: buffer.GetFunctionCall(),
    },

View File

@@ -31,14 +31,6 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
        return check.Error(), 0
    }

    if form := ExtractCacheData(c, &CacheProps{
        Message:    segment,
        Model:      model,
        Reversible: plan,
    }); form != nil {
        return form.Message, 0
    }

    buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
    hit, err := channel.NewChatRequestWithCache(
        cache, buffer,
@@ -64,13 +56,5 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
        CollectQuota(c, user, buffer, plan, err)
    }

    SaveCacheData(c, &CacheProps{
        Message:    segment,
        Model:      model,
        Reversible: plan,
    }, &CacheData{
        Message: buffer.ReadWithDefault(defaultMessage),
    })

    return buffer.ReadWithDefault(defaultMessage), buffer.GetQuota()
}

View File

@@ -23,7 +23,7 @@ type Buffer struct {
    Latest          string             `json:"latest"`
    Cursor          int                `json:"cursor"`
    Times           int                `json:"times"`
    History         []globals.Message  `json:"history"`
    InputTokens     int                `json:"input_tokens"`
    Images          Images             `json:"images"`
    ToolCalls       *globals.ToolCalls `json:"tool_calls"`
    ToolCallsCursor int                `json:"tool_calls_cursor"`
@@ -31,8 +31,8 @@ type Buffer struct {
    Charge Charge `json:"-"`
}

func initInputToken(charge Charge, model string, history []globals.Message) float32 {
    if globals.IsOpenAIVisionModels(model) {
func initInputToken(model string, history []globals.Message) int {
    if globals.IsVisionModel(model) {
        for _, message := range history {
            if message.Role == globals.User {
                content, _ := ExtractImages(message.Content, true)
@@ -57,15 +57,20 @@ func initInputToken(charge Charge, model string, history []globals.Message) floa
        })
    }

    return CountInputToken(charge, model, history)
    return CountTokenPrice(history, model)
}

func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
    token := initInputToken(model, history)

    return &Buffer{
        Model:   model,
        Quota:   initInputToken(charge, model, history),
        History: history,
        Charge:  charge,
        Model:           model,
        Quota:           CountInputQuota(charge, token),
        InputTokens:     token,
        Charge:          charge,
        FunctionCall:    nil,
        ToolCalls:       nil,
        ToolCallsCursor: 0,
    }
}
@@ -115,7 +120,6 @@ func (b *Buffer) AddToolCalls(toolCalls *globals.ToolCalls) {
    }

    b.ToolCalls = toolCalls
    b.ToolCallsCursor += 1
}

func (b *Buffer) SetFunctionCall(functionCall *globals.FunctionCall) {
@@ -177,12 +181,12 @@ func (b *Buffer) ReadTimes() int {
    return b.Times
}

func (b *Buffer) ReadHistory() []globals.Message {
    return b.History
func (b *Buffer) SetInputTokens(tokens int) {
    b.InputTokens = tokens
}

func (b *Buffer) CountInputToken() int {
    return GetWeightByModel(b.Model) * NumTokensFromMessages(b.History, b.Model)
    return b.InputTokens
}

func (b *Buffer) CountOutputToken() int {
View File

@@ -69,12 +69,12 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int
}

func CountTokenPrice(messages []globals.Message, model string) int {
    return NumTokensFromMessages(messages, model)
    return NumTokensFromMessages(messages, model) * GetWeightByModel(model)
}

func CountInputToken(charge Charge, model string, message []globals.Message) float32 {
    if charge.IsBillingType(globals.TokenBilling) {
        return float32(CountTokenPrice(message, model)) / 1000 * charge.GetInput()
func CountInputQuota(charge Charge, token int) float32 {
    if charge.GetType() == globals.TokenBilling {
        return float32(token) / 1000 * charge.GetInput()
    }
    return 0
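Under token billing, the input quota is therefore weighted tokens / 1000 × the model's input price. A worked example; the weight and price are made-up values:

package main

import "fmt"

func main() {
    // CountTokenPrice: message tokens scaled by the model weight
    tokens := 1150 * 2 // e.g. 1150 raw tokens, weight 2

    // CountInputQuota under TokenBilling, with an assumed input
    // price of 0.01 quota per 1K tokens
    quota := float32(tokens) / 1000 * 0.01
    fmt.Println(quota) // 0.023
}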