From efc76c36424abccf2756f98a6ebf1be74f1c05fe Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Sun, 18 Feb 2024 12:08:25 +0800 Subject: [PATCH] feat: support function calling and tools calling in relay chat_completions! (#20) (#30) --- adapter/azure/chat.go | 15 +++-- adapter/azure/processor.go | 28 +++++---- adapter/baichuan/chat.go | 50 +++++---------- adapter/baichuan/processor.go | 96 ++++++----------------------- adapter/baichuan/types.go | 23 ++----- adapter/bing/chat.go | 4 +- adapter/chatgpt/chat.go | 6 +- adapter/chatgpt/processor.go | 28 +++++---- adapter/claude/chat.go | 2 +- adapter/dashscope/chat.go | 2 +- adapter/hunyuan/chat.go | 7 ++- adapter/midjourney/chat.go | 14 +++-- adapter/palm2/chat.go | 2 +- adapter/skylark/chat.go | 18 +++--- adapter/slack/struct.go | 2 +- adapter/sparkdesk/chat.go | 26 ++++---- adapter/zhinao/chat.go | 2 +- adapter/zhipuai/chat.go | 2 +- addition/generation/prompt.go | 6 +- app/src/components/plugins/file.tsx | 4 +- app/src/resources/i18n/cn.json | 2 +- app/src/resources/i18n/en.json | 2 +- app/src/resources/i18n/ja.json | 2 +- app/src/resources/i18n/ru.json | 2 +- app/src/translator/adapter.ts | 2 + app/src/translator/io.ts | 1 + channel/worker.go | 7 ++- globals/types.go | 9 ++- manager/chat.go | 4 +- manager/chat_completions.go | 40 ++++++------ manager/completions.go | 4 +- manager/images.go | 4 +- utils/base.go | 9 +++ utils/buffer.go | 54 +++++++++++++++- 34 files changed, 247 insertions(+), 232 deletions(-) diff --git a/adapter/azure/chat.go b/adapter/azure/chat.go index 1f5cd6b..6c70fa1 100644 --- a/adapter/azure/chat.go +++ b/adapter/azure/chat.go @@ -99,33 +99,40 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global if url, err := c.CreateImage(props); err != nil { return err } else { - return callback(url) + return callback(&globals.Chunk{Content: url}) } } isCompletionType := props.Model == globals.GPT3TurboInstruct + ticks := 0 err := utils.EventScanner(&utils.EventScannerProps{ Method: "POST", Uri: c.GetChatEndpoint(props), Headers: c.GetHeader(), Body: c.GetChatBody(props, true), Callback: func(data string) error { - partial, err := c.ProcessLine(props.Buffer, data, isCompletionType) + ticks += 1 + + partial, err := c.ProcessLine(data, isCompletionType) if err != nil { return err } - return callback(partial) }, }) if err != nil { if form := processChatErrorResponse(err.Body); form != nil { - return errors.New(fmt.Sprintf("%s (type: %s)", form.Error.Message, form.Error.Type)) + msg := fmt.Sprintf("%s (type: %s)", form.Error.Message, form.Error.Type) + return errors.New(msg) } return err.Error } + if ticks == 0 { + return errors.New("no response") + } + return nil } diff --git a/adapter/azure/processor.go b/adapter/azure/processor.go index 2549b3b..d5c1c17 100644 --- a/adapter/azure/processor.go +++ b/adapter/azure/processor.go @@ -74,16 +74,18 @@ func processChatErrorResponse(data string) *ChatStreamErrorResponse { return utils.UnmarshalForm[ChatStreamErrorResponse](data) } -func getChoices(buffer utils.Buffer, form *ChatStreamResponse) string { +func getChoices(form *ChatStreamResponse) *globals.Chunk { if len(form.Choices) == 0 { - return "" + return &globals.Chunk{Content: ""} } choice := form.Choices[0].Delta - buffer.AddToolCalls(choice.ToolCalls) - buffer.SetFunctionCall(choice.FunctionCall) - return choice.Content + return &globals.Chunk{ + Content: choice.Content, + ToolCall: choice.ToolCalls, + FunctionCall: choice.FunctionCall, + } } func getCompletionChoices(form *CompletionResponse) string { @@ -109,25 +111,27 @@ func getRobustnessResult(chunk string) string { } } -func (c *ChatInstance) ProcessLine(obj utils.Buffer, data string, isCompletionType bool) (string, error) { +func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals.Chunk, error) { if isCompletionType { - // legacy support + // openai legacy support if completion := processCompletionResponse(data); completion != nil { - return getCompletionChoices(completion), nil + return &globals.Chunk{ + Content: getCompletionChoices(completion), + }, nil } globals.Warn(fmt.Sprintf("chatgpt error: cannot parse completion response: %s", data)) - return "", errors.New("parser error: cannot parse completion response") + return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse completion response") } if form := processChatResponse(data); form != nil { - return getChoices(obj, form), nil + return getChoices(form), nil } if form := processChatErrorResponse(data); form != nil { - return "", errors.New(fmt.Sprintf("chatgpt error: %s (type: %s)", form.Error.Message, form.Error.Type)) + return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("chatgpt error: %s (type: %s)", form.Error.Message, form.Error.Type)) } globals.Warn(fmt.Sprintf("chatgpt error: cannot parse chat completion response: %s", data)) - return "", errors.New("parser error: cannot parse chat completion response") + return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse chat completion response") } diff --git a/adapter/baichuan/chat.go b/adapter/baichuan/chat.go index a7798f5..d59c4e5 100644 --- a/adapter/baichuan/chat.go +++ b/adapter/baichuan/chat.go @@ -3,8 +3,8 @@ package baichuan import ( "chat/globals" "chat/utils" + "errors" "fmt" - "strings" ) type ChatProps struct { @@ -72,44 +72,26 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { // CreateStreamChatRequest is the stream response body for baichuan func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { - buf := "" - cursor := 0 - chunk := "" - - err := utils.EventSource( - "POST", - c.GetChatEndpoint(), - c.GetHeader(), - c.GetChatBody(props, true), - func(data string) error { - data, err := c.ProcessLine(buf, data) - chunk += data - + err := utils.EventScanner(&utils.EventScannerProps{ + Method: "POST", + Uri: c.GetChatEndpoint(), + Headers: c.GetHeader(), + Body: c.GetChatBody(props, true), + Callback: func(data string) error { + partial, err := c.ProcessLine(data) if err != nil { - if strings.HasPrefix(err.Error(), "baichuan error") { - return err - } - - // error when break line - buf = buf + data - return nil + return err } - - buf = "" - if data != "" { - cursor += 1 - if err := callback(data); err != nil { - return err - } - } - return nil + return callback(partial) }, - ) + }) if err != nil { - return err - } else if len(chunk) == 0 { - return fmt.Errorf("empty response") + if form := processChatErrorResponse(err.Body); form != nil { + msg := fmt.Sprintf("%s (type: %s)", form.Error.Message, form.Error.Type) + return errors.New(msg) + } + return err.Error } return nil diff --git a/adapter/baichuan/processor.go b/adapter/baichuan/processor.go index f7d6935..c495476 100644 --- a/adapter/baichuan/processor.go +++ b/adapter/baichuan/processor.go @@ -5,95 +5,37 @@ import ( "chat/utils" "errors" "fmt" - "strings" ) -func processFormat(data string) string { - rep := strings.NewReplacer( - "data: {", - "\"data\": {", - ) - item := rep.Replace(data) - if !strings.HasPrefix(item, "{") { - item = "{" + item - } - if !strings.HasSuffix(item, "}}") { - item = item + "}" - } - - return item -} - func processChatResponse(data string) *ChatStreamResponse { - if strings.HasPrefix(data, "{") { - var form *ChatStreamResponse - if form = utils.UnmarshalForm[ChatStreamResponse](data); form != nil { - return form - } - - if form = utils.UnmarshalForm[ChatStreamResponse](data[:len(data)-1]); form != nil { - return form - } - - if form = utils.UnmarshalForm[ChatStreamResponse](data + "}"); form != nil { - return form - } - } - - return nil + return utils.UnmarshalForm[ChatStreamResponse](data) } func processChatErrorResponse(data string) *ChatStreamErrorResponse { - if strings.HasPrefix(data, "{") { - var form *ChatStreamErrorResponse - if form = utils.UnmarshalForm[ChatStreamErrorResponse](data); form != nil { - return form - } - if form = utils.UnmarshalForm[ChatStreamErrorResponse](data + "}"); form != nil { - return form - } - } - - return nil + return utils.UnmarshalForm[ChatStreamErrorResponse](data) } -func isDone(data string) bool { - return utils.Contains[string](data, []string{ - "{data: [DONE]}", "{data: [DONE]}}", "null}}", "{null}", - "{[DONE]}", "{data:}", "{data:}}", "data: [DONE]}}", - }) -} - -func getChoices(form *ChatStreamResponse) string { - if len(form.Data.Choices) == 0 { - if len(form.Choices) > 0 { - return form.Choices[0].Delta.Content - } +func getChoices(form *ChatStreamResponse) *globals.Chunk { + if len(form.Choices) == 0 { + return &globals.Chunk{Content: ""} } - return form.Data.Choices[0].Delta.Content + choice := form.Choices[0].Delta + + return &globals.Chunk{ + Content: choice.Content, + } } -func (c *ChatInstance) ProcessLine(buf, data string) (string, error) { - item := processFormat(buf + data) - if isDone(item) { - return "", nil - } - - if form := processChatResponse(item); form == nil { - // recursive call - if len(buf) > 0 { - return c.ProcessLine("", buf+item) - } - - if err := processChatErrorResponse(item); err == nil || err.Data.Error.Message == "" { - globals.Warn(fmt.Sprintf("baichuan error: cannot parse response: %s", item)) - return data, errors.New("parser error: cannot parse response") - } else { - return "", fmt.Errorf("baichuan error: %s (type: %s)", err.Data.Error.Message, err.Data.Error.Type) - } - - } else { +func (c *ChatInstance) ProcessLine(data string) (*globals.Chunk, error) { + if form := processChatResponse(data); form != nil { return getChoices(form), nil } + + if form := processChatErrorResponse(data); form != nil { + return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("baichuan error: %s (type: %s)", form.Error.Message, form.Error.Type)) + } + + globals.Warn(fmt.Sprintf("baichuan error: cannot parse chat completion response: %s", data)) + return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse chat completion response") } diff --git a/adapter/baichuan/types.go b/adapter/baichuan/types.go index 65aa6e3..c7b67f3 100644 --- a/adapter/baichuan/types.go +++ b/adapter/baichuan/types.go @@ -32,19 +32,6 @@ type ChatResponse struct { // ChatStreamResponse is the stream response body for baichuan type ChatStreamResponse struct { - Data struct { - ID string `json:"id"` - Object string `json:"object"` - Created int64 `json:"created"` - Model string `json:"model"` - Choices []struct { - Delta struct { - Content string `json:"content"` - } - Index int `json:"index"` - } `json:"choices"` - } `json:"data"` - ID string `json:"id"` Object string `json:"object"` Created int64 `json:"created"` @@ -58,10 +45,8 @@ type ChatStreamResponse struct { } type ChatStreamErrorResponse struct { - Data struct { - Error struct { - Message string `json:"message"` - Type string `json:"type"` - } `json:"error"` - } `json:"data"` + Error struct { + Message string `json:"message"` + Type string `json:"type"` + } `json:"error"` } diff --git a/adapter/bing/chat.go b/adapter/bing/chat.go index dc2bdec..2dfac87 100644 --- a/adapter/bing/chat.go +++ b/adapter/bing/chat.go @@ -39,7 +39,9 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho return nil } - if err := hook(form.Response); err != nil { + if err := hook(&globals.Chunk{ + Content: form.Response, + }); err != nil { return err } } diff --git a/adapter/chatgpt/chat.go b/adapter/chatgpt/chat.go index bd5a5d0..823c6c7 100644 --- a/adapter/chatgpt/chat.go +++ b/adapter/chatgpt/chat.go @@ -109,7 +109,9 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global if url, err := c.CreateImage(props); err != nil { return err } else { - return callback(url) + return callback(&globals.Chunk{ + Content: url, + }) } } @@ -124,7 +126,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global Callback: func(data string) error { ticks += 1 - partial, err := c.ProcessLine(props.Buffer, data, isCompletionType) + partial, err := c.ProcessLine(data, isCompletionType) if err != nil { return err } diff --git a/adapter/chatgpt/processor.go b/adapter/chatgpt/processor.go index 55a8930..5b433d4 100644 --- a/adapter/chatgpt/processor.go +++ b/adapter/chatgpt/processor.go @@ -72,16 +72,18 @@ func processChatErrorResponse(data string) *ChatStreamErrorResponse { return utils.UnmarshalForm[ChatStreamErrorResponse](data) } -func getChoices(buffer utils.Buffer, form *ChatStreamResponse) string { +func getChoices(form *ChatStreamResponse) *globals.Chunk { if len(form.Choices) == 0 { - return "" + return &globals.Chunk{Content: ""} } choice := form.Choices[0].Delta - buffer.AddToolCalls(choice.ToolCalls) - buffer.SetFunctionCall(choice.FunctionCall) - return choice.Content + return &globals.Chunk{ + Content: choice.Content, + ToolCall: choice.ToolCalls, + FunctionCall: choice.FunctionCall, + } } func getCompletionChoices(form *CompletionResponse) string { @@ -107,25 +109,27 @@ func getRobustnessResult(chunk string) string { } } -func (c *ChatInstance) ProcessLine(obj utils.Buffer, data string, isCompletionType bool) (string, error) { +func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals.Chunk, error) { if isCompletionType { - // legacy support + // openai legacy support if completion := processCompletionResponse(data); completion != nil { - return getCompletionChoices(completion), nil + return &globals.Chunk{ + Content: getCompletionChoices(completion), + }, nil } globals.Warn(fmt.Sprintf("chatgpt error: cannot parse completion response: %s", data)) - return "", errors.New("parser error: cannot parse completion response") + return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse completion response") } if form := processChatResponse(data); form != nil { - return getChoices(obj, form), nil + return getChoices(form), nil } if form := processChatErrorResponse(data); form != nil { - return "", errors.New(fmt.Sprintf("chatgpt error: %s (type: %s)", form.Error.Message, form.Error.Type)) + return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("chatgpt error: %s (type: %s)", form.Error.Message, form.Error.Type)) } globals.Warn(fmt.Sprintf("chatgpt error: cannot parse chat completion response: %s", data)) - return "", errors.New("parser error: cannot parse chat completion response") + return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse chat completion response") } diff --git a/adapter/claude/chat.go b/adapter/claude/chat.go index 1309bf5..e23ef82 100644 --- a/adapter/claude/chat.go +++ b/adapter/claude/chat.go @@ -127,7 +127,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho if resp, err := c.ProcessLine(buf, data); err == nil && len(resp) > 0 { buf = "" - if err := hook(resp); err != nil { + if err := hook(&globals.Chunk{Content: resp}); err != nil { return err } } else { diff --git a/adapter/dashscope/chat.go b/adapter/dashscope/chat.go index 18cd079..2bd6326 100644 --- a/adapter/dashscope/chat.go +++ b/adapter/dashscope/chat.go @@ -126,7 +126,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global return fmt.Errorf("dashscope error: %s", form.Message) } - if err := callback(form.Output.Text); err != nil { + if err := callback(&globals.Chunk{Content: form.Output.Text}); err != nil { return err } return nil diff --git a/adapter/hunyuan/chat.go b/adapter/hunyuan/chat.go index eb57709..2e9363f 100644 --- a/adapter/hunyuan/chat.go +++ b/adapter/hunyuan/chat.go @@ -50,7 +50,12 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global break } - if err := callback(chunk.Choices[0].Delta.Content); err != nil { + if len(chunk.Choices) == 0 { + continue + } + + choice := chunk.Choices[0].Delta + if err := callback(&globals.Chunk{Content: choice.Content}); err != nil { return err } } diff --git a/adapter/midjourney/chat.go b/adapter/midjourney/chat.go index c242f9c..d84c222 100644 --- a/adapter/midjourney/chat.go +++ b/adapter/midjourney/chat.go @@ -79,21 +79,21 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global form, err := c.CreateStreamTask(action, prompt, func(form *StorageForm, progress int) error { if progress == 0 { begin = true - if err := callback("```progress\n"); err != nil { + if err := callback(&globals.Chunk{Content: "```progress\n"}); err != nil { return err } } else if progress == 100 && !begin { - if err := callback("```progress\n"); err != nil { + if err := callback(&globals.Chunk{Content: "```progress\n"}); err != nil { return err } } - if err := callback(fmt.Sprintf("%d\n", progress)); err != nil { + if err := callback(&globals.Chunk{Content: fmt.Sprintf("%d\n", progress)}); err != nil { return err } if progress == 100 { - if err := callback("```\n"); err != nil { + if err := callback(&globals.Chunk{Content: "```\n"}); err != nil { return err } } @@ -105,7 +105,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global return fmt.Errorf("error from midjourney: %s", err.Error()) } - if err := callback(utils.GetImageMarkdown(form.Url)); err != nil { + if err := callback(&globals.Chunk{Content: utils.GetImageMarkdown(form.Url)}); err != nil { return err } @@ -133,5 +133,7 @@ func (c *ChatInstance) CallbackActions(form *StorageForm, callback globals.Hook) reroll := fmt.Sprintf("[REROLL](%s)", toVirtualMessage(fmt.Sprintf("/REROLL %s", form.Task))) - return callback(fmt.Sprintf("\n\n%s\n\n%s\n\n%s\n", upscale, variation, reroll)) + return callback(&globals.Chunk{ + Content: fmt.Sprintf("\n\n%s\n\n%s\n\n%s\n", upscale, variation, reroll), + }) } diff --git a/adapter/palm2/chat.go b/adapter/palm2/chat.go index 703db09..a55b56e 100644 --- a/adapter/palm2/chat.go +++ b/adapter/palm2/chat.go @@ -129,7 +129,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global } for _, item := range utils.SplitItem(response, " ") { - if err := callback(item); err != nil { + if err := callback(&globals.Chunk{Content: item}); err != nil { return err } } diff --git a/adapter/skylark/chat.go b/adapter/skylark/chat.go index 2b65bcb..3cd973b 100644 --- a/adapter/skylark/chat.go +++ b/adapter/skylark/chat.go @@ -66,14 +66,17 @@ func (c *ChatInstance) CreateRequest(props *ChatProps) *api.ChatReq { } } -func getChoice(choice *api.ChatResp, buffer utils.Buffer) string { +func getChoice(choice *api.ChatResp) *globals.Chunk { if choice == nil { - return "" + return &globals.Chunk{Content: ""} } - calls := choice.Choice.Message.FunctionCall - if calls != nil { - buffer.AddToolCalls(&globals.ToolCalls{ + message := choice.Choice.Message + + calls := message.FunctionCall + return &globals.Chunk{ + Content: message.Content, + ToolCall: utils.Multi(calls != nil, &globals.ToolCalls{ globals.ToolCall{ Type: "function", Id: globals.ToolCallId(fmt.Sprintf("%s-%s", calls.Name, choice.ReqId)), @@ -82,9 +85,8 @@ func getChoice(choice *api.ChatResp, buffer utils.Buffer) string { Arguments: calls.Arguments, }, }, - }) + }, nil), } - return choice.Choice.Message.Content } func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { @@ -99,7 +101,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global return partial.Error } - if err := callback(getChoice(partial, props.Buffer)); err != nil { + if err := callback(getChoice(partial)); err != nil { return err } } diff --git a/adapter/slack/struct.go b/adapter/slack/struct.go index 318af9f..cc89de7 100644 --- a/adapter/slack/struct.go +++ b/adapter/slack/struct.go @@ -77,7 +77,7 @@ func (c *ChatInstance) ProcessPartialResponse(res chan types.PartialResponse, ho if data.Error != nil { return data.Error } else if data.Text != "" { - if err := hook(data.Text); err != nil { + if err := hook(&globals.Chunk{Content: data.Text}); err != nil { return err } } diff --git a/adapter/sparkdesk/chat.go b/adapter/sparkdesk/chat.go index 5df8d21..25b7d69 100644 --- a/adapter/sparkdesk/chat.go +++ b/adapter/sparkdesk/chat.go @@ -67,26 +67,26 @@ func (c *ChatInstance) GetFunctionCalling(props *ChatProps) *FunctionsPayload { } } -func getChoice(form *ChatResponse, buffer utils.Buffer) string { - resp := form.Payload.Choices.Text - if len(resp) == 0 { - return "" +func getChoice(form *ChatResponse) *globals.Chunk { + if len(form.Payload.Choices.Text) == 0 { + return &globals.Chunk{Content: ""} } - if resp[0].FunctionCall != nil { - buffer.AddToolCalls(&globals.ToolCalls{ + choice := form.Payload.Choices.Text[0] + + return &globals.Chunk{ + Content: choice.Content, + ToolCall: utils.Multi(choice.FunctionCall != nil, &globals.ToolCalls{ globals.ToolCall{ Type: "function", - Id: globals.ToolCallId(fmt.Sprintf("%s-%s", resp[0].FunctionCall.Name, resp[0].FunctionCall.Arguments)), + Id: globals.ToolCallId(fmt.Sprintf("%s-%s", choice.FunctionCall.Name, choice.FunctionCall.Arguments)), Function: globals.ToolCallFunction{ - Name: resp[0].FunctionCall.Name, - Arguments: resp[0].FunctionCall.Arguments, + Name: choice.FunctionCall.Name, + Arguments: choice.FunctionCall.Arguments, }, }, - }) + }, nil), } - - return resp[0].Content } func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error { @@ -130,7 +130,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho return fmt.Errorf("sparkdesk error: %s (sid: %s)", form.Header.Message, form.Header.Sid) } - if err := hook(getChoice(form, props.Buffer)); err != nil { + if err := hook(getChoice(form)); err != nil { return err } } diff --git a/adapter/zhinao/chat.go b/adapter/zhinao/chat.go index b27d1e1..2736a11 100644 --- a/adapter/zhinao/chat.go +++ b/adapter/zhinao/chat.go @@ -103,7 +103,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global buf = "" if data != "" { cursor += 1 - if err := callback(data); err != nil { + if err := callback(&globals.Chunk{Content: data}); err != nil { return err } } diff --git a/adapter/zhipuai/chat.go b/adapter/zhipuai/chat.go index 478fab7..3de4d92 100644 --- a/adapter/zhipuai/chat.go +++ b/adapter/zhipuai/chat.go @@ -73,7 +73,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho } data = strings.TrimPrefix(data, "data:") - return hook(data) + return hook(&globals.Chunk{Content: data}) }, ) } diff --git a/addition/generation/prompt.go b/addition/generation/prompt.go index e3decdc..ecaaf09 100644 --- a/addition/generation/prompt.go +++ b/addition/generation/prompt.go @@ -21,9 +21,9 @@ func CreateGeneration(group, model, prompt, path string, hook func(buffer *utils Model: model, Message: message, Buffer: *buffer, - }, func(data string) error { - buffer.Write(data) - hook(buffer, data) + }, func(data *globals.Chunk) error { + buffer.WriteChunk(data) + hook(buffer, data.Content) return nil }) diff --git a/app/src/components/plugins/file.tsx b/app/src/components/plugins/file.tsx index 7621e22..ad6f795 100644 --- a/app/src/components/plugins/file.tsx +++ b/app/src/components/plugins/file.tsx @@ -24,7 +24,9 @@ export function parseFile(data: string, acceptDownload?: boolean) { const b64image = useMemo(() => { // get base64 image from content (like: ) - const match = content.match(/data:image\/([^;]+);base64,([a-zA-Z0-9+/=]+)/g); + const match = content.match( + /data:image\/([^;]+);base64,([a-zA-Z0-9+/=]+)/g, + ); return match ? match[0] : ""; }, [filename, content]); diff --git a/app/src/resources/i18n/cn.json b/app/src/resources/i18n/cn.json index a759ce5..5ff96e1 100644 --- a/app/src/resources/i18n/cn.json +++ b/app/src/resources/i18n/cn.json @@ -356,7 +356,7 @@ "align": "聊天框居中", "memory": "内存占用", "max-tokens": "最大回复 Token 数", - "max-tokens-tip": "最大回复 Token 数,超过此数值将会被截断", + "max-tokens-tip": "最大回复 Token 数,超过此数值将会被截断(过高的数值可能会导致超过模型的最大 Token 导致请求失败)", "temperature": "温度", "temperature-tip": "随机采样的比例,高温度会产生更多的随机性,低温度会产生较集中和确定性的文本", "top-p": "核采样概率阈值", diff --git a/app/src/resources/i18n/en.json b/app/src/resources/i18n/en.json index 69363f2..5ffb523 100644 --- a/app/src/resources/i18n/en.json +++ b/app/src/resources/i18n/en.json @@ -306,7 +306,7 @@ "temperature": "temperature", "temperature-tip": "Random sampling ratio, high temperature produces more randomness, low temperature produces more concentrated and deterministic text", "max-tokens": "Maximum number of response tokens", - "max-tokens-tip": "Maximum number of reply tokens, exceeding this value will be truncated", + "max-tokens-tip": "Maximum number of reply tokens, exceeding this value will be truncated (too high value may cause the request to fail due to exceeding the model's maximum token)", "top-p": "Kernel Sampling Probability Threshold", "top-p-tip": "(TopP) The higher the probability value, the higher the randomness generated; the lower the value, the higher the certainty generated", "top-k": "Sample Candidate Set Size", diff --git a/app/src/resources/i18n/ja.json b/app/src/resources/i18n/ja.json index 34002ee..9cc880c 100644 --- a/app/src/resources/i18n/ja.json +++ b/app/src/resources/i18n/ja.json @@ -306,7 +306,7 @@ "temperature": "温度", "temperature-tip": "ランダムサンプリング比、高温はよりランダム性を生み、低温はより集中的で決定論的なテキストを生成します", "max-tokens": "レスポンストークンの最大数", - "max-tokens-tip": "この値を超える返信トークンの最大数は切り捨てられます", + "max-tokens-tip": "この値を超えるリプライトークンの最大数は切り捨てられます(値が高すぎると、モデルの最大トークンを超えるために要求が失敗する可能性があります)", "top-p": "カーネルサンプリング確率閾値", "top-p-tip": "( TopP )確率値が高いほど生成されるランダム性が高く、値が低いほど生成される確実性が高くなります", "top-k": "サンプル候補セットサイズ", diff --git a/app/src/resources/i18n/ru.json b/app/src/resources/i18n/ru.json index 51b0472..6c199e1 100644 --- a/app/src/resources/i18n/ru.json +++ b/app/src/resources/i18n/ru.json @@ -306,7 +306,7 @@ "temperature": "Температура", "temperature-tip": "Коэффициент случайной выборки, высокая температура создает больше случайности, низкая температура создает более концентрированный и детерминированный текст", "max-tokens": "Максимальное количество маркеров ответа", - "max-tokens-tip": "Максимальное количество маркеров ответа, превышающее это значение, будет усечено", + "max-tokens-tip": "Максимальное количество маркеров ответа, превышающее это значение, будет усечено (слишком высокое значение может привести к сбою запроса из-за превышения максимального маркера модели)", "top-p": "Порог вероятности отбора проб ядра", "top-p-tip": "(TopP) Чем выше значение вероятности, тем выше генерируемая случайность; чем ниже значение, тем выше генерируемая определенность", "top-k": "Размер набора образцов-кандидатов", diff --git a/app/src/translator/adapter.ts b/app/src/translator/adapter.ts index cc7a425..da3514a 100644 --- a/app/src/translator/adapter.ts +++ b/app/src/translator/adapter.ts @@ -46,5 +46,7 @@ export function doTranslate( from = getFormattedLanguage(from); to = getFormattedLanguage(to); + if (content.startsWith("!!")) content = content.substring(2); + return translate(content, from, to); } diff --git a/app/src/translator/io.ts b/app/src/translator/io.ts index 1125ca7..52ccb67 100644 --- a/app/src/translator/io.ts +++ b/app/src/translator/io.ts @@ -23,6 +23,7 @@ export function getMigration( switch (typeof template) { case "string": if (typeof translation !== "string") return val; + else if (template.startsWith("!!")) return val; break; case "object": return getMigration(template, translation, val[0]); diff --git a/channel/worker.go b/channel/worker.go index ea5fc37..8daa8e3 100644 --- a/channel/worker.go +++ b/channel/worker.go @@ -53,7 +53,12 @@ func PreflightCache(cache *redis.Client, hash string, buffer *utils.Buffer, hook buffer.SetInputTokens(buf.CountInputToken()) buffer.SetToolCalls(buf.GetToolCalls()) buffer.SetFunctionCall(buf.GetFunctionCall()) - return idx, true, hook(data) + + return idx, true, hook(&globals.Chunk{ + Content: data, + FunctionCall: buf.GetFunctionCall(), + ToolCall: buf.GetToolCalls(), + }) } func StoreCache(cache *redis.Client, hash string, index int64, buffer *utils.Buffer) { diff --git a/globals/types.go b/globals/types.go index 9aef559..4bc1975 100644 --- a/globals/types.go +++ b/globals/types.go @@ -1,6 +1,7 @@ package globals -type Hook func(data string) error +type Hook func(data *Chunk) error + type Message struct { Role string `json:"role"` Content string `json:"content"` @@ -10,6 +11,12 @@ type Message struct { ToolCalls *ToolCalls `json:"tool_calls,omitempty"` // only `assistant` role } +type Chunk struct { + Content string `json:"content"` + ToolCall *ToolCalls `json:"tool_call,omitempty"` + FunctionCall *FunctionCall `json:"function_call,omitempty"` +} + type ChatSegmentResponse struct { Conversation int64 `json:"conversation"` Quota float32 `json:"quota"` diff --git a/manager/chat.go b/manager/chat.go index a645cbd..f08f501 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -100,13 +100,13 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve FrequencyPenalty: instance.GetFrequencyPenalty(), RepetitionPenalty: instance.GetRepetitionPenalty(), }, - func(data string) error { + func(data *globals.Chunk) error { if signal := conn.PeekWithType(StopType); signal != nil { // stop signal from client return fmt.Errorf("signal") } return conn.SendClient(globals.ChatSegmentResponse{ - Message: buffer.Write(data), + Message: buffer.WriteChunk(data), Quota: buffer.GetQuota(), End: false, Plan: plan, diff --git a/manager/chat_completions.go b/manager/chat_completions.go index cd438ac..6a80492 100644 --- a/manager/chat_completions.go +++ b/manager/chat_completions.go @@ -93,8 +93,8 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals cache := utils.GetCacheFromContext(c) buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model)) - hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error { - buffer.Write(data) + hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data *globals.Chunk) error { + buffer.WriteChunk(data) return nil }) @@ -137,14 +137,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals }) } -func getStreamTranshipmentForm(id string, created int64, form RelayForm, data string, buffer *utils.Buffer, end bool, err error) RelayStreamResponse { - toolCalling := buffer.GetToolCalls() - - var functionCalling *globals.FunctionCall - if end { - functionCalling = buffer.GetFunctionCall() - } - +func getStreamTranshipmentForm(id string, created int64, form RelayForm, data *globals.Chunk, buffer *utils.Buffer, end bool, err error) RelayStreamResponse { return RelayStreamResponse{ Id: fmt.Sprintf("chatcmpl-%s", id), Object: "chat.completion.chunk", @@ -155,9 +148,9 @@ func getStreamTranshipmentForm(id string, created int64, form RelayForm, data st Index: 0, Delta: globals.Message{ Role: globals.Assistant, - Content: data, - ToolCalls: toolCalling, - FunctionCall: functionCalling, + Content: data.Content, + ToolCalls: data.ToolCall, + FunctionCall: data.FunctionCall, }, FinishReason: utils.Multi[interface{}](end, "stop", nil), }, @@ -177,23 +170,30 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g db := utils.GetDBFromContext(c) cache := utils.GetCacheFromContext(c) + group := auth.GetGroup(db, user) + charge := channel.ChargeInstance.GetCharge(form.Model) + go func() { - buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model)) - hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getChatProps(form, messages, buffer, plan), func(data string) error { - partial <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false, nil) - return nil - }) + buffer := utils.NewBuffer(form.Model, messages, charge) + hit, err := channel.NewChatRequestWithCache( + cache, buffer, group, getChatProps(form, messages, buffer, plan), + func(data *globals.Chunk) error { + buffer.WriteChunk(data) + partial <- getStreamTranshipmentForm(id, created, form, data, buffer, false, nil) + return nil + }, + ) admin.AnalysisRequest(form.Model, buffer, err) if err != nil { auth.RevertSubscriptionUsage(db, cache, user, form.Model) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err.Error(), form.Model, c.ClientIP())) - partial <- getStreamTranshipmentForm(id, created, form, err.Error(), buffer, true, err) + partial <- getStreamTranshipmentForm(id, created, form, &globals.Chunk{Content: err.Error()}, buffer, true, err) close(partial) return } - partial <- getStreamTranshipmentForm(id, created, form, "", buffer, true, nil) + partial <- getStreamTranshipmentForm(id, created, form, &globals.Chunk{Content: ""}, buffer, true, nil) if !hit { CollectQuota(c, user, buffer, plan, err) diff --git a/manager/completions.go b/manager/completions.go index 3ff298e..ad1eef7 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -40,8 +40,8 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] Message: segment, Buffer: *buffer, }, - func(resp string) error { - buffer.Write(resp) + func(resp *globals.Chunk) error { + buffer.WriteChunk(resp) return nil }, ) diff --git a/manager/images.go b/manager/images.go index 830070d..838ca99 100644 --- a/manager/images.go +++ b/manager/images.go @@ -89,8 +89,8 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string, } buffer := utils.NewBuffer(form.Model, messages, channel.ChargeInstance.GetCharge(form.Model)) - hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getImageProps(form, messages, buffer), func(data string) error { - buffer.Write(data) + hit, err := channel.NewChatRequestWithCache(cache, buffer, auth.GetGroup(db, user), getImageProps(form, messages, buffer), func(data *globals.Chunk) error { + buffer.WriteChunk(data) return nil }) diff --git a/utils/base.go b/utils/base.go index ff63d19..2d60b88 100644 --- a/utils/base.go +++ b/utils/base.go @@ -81,6 +81,15 @@ func InsertSlice[T any](arr []T, index int, value []T) []T { return arr } +func Collect[T any](arr ...[]T) []T { + res := make([]T, 0) + + for _, v := range arr { + res = append(res, v...) + } + return res +} + func Append[T any](arr []T, value T) []T { return append(arr, value) } diff --git a/utils/buffer.go b/utils/buffer.go index 3996ed8..00e3e82 100644 --- a/utils/buffer.go +++ b/utils/buffer.go @@ -90,6 +90,18 @@ func (b *Buffer) Write(data string) string { return data } +func (b *Buffer) WriteChunk(data *globals.Chunk) string { + if data == nil { + return "" + } + + b.Write(data.Content) + b.AddToolCalls(data.ToolCall) + b.SetFunctionCall(data.FunctionCall) + + return data.Content +} + func (b *Buffer) GetChunk() string { return b.Latest } @@ -114,12 +126,52 @@ func (b *Buffer) SetToolCalls(toolCalls *globals.ToolCalls) { b.ToolCalls = toolCalls } +func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.ToolCall) { + for i, t := range tools { + if t.Id == tool.Id { + return i, &t + } + } + + if len(tool.Type) == 0 && len(tool.Id) == 0 { + length := len(tools) + + if length > 0 { + // if the tool is empty, return the last tool as the hit + return length - 1, &tools[length-1] + } + } + + return 0, nil +} + +func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls { + if source == nil { + return target + } + + tools := make(globals.ToolCalls, 0) + arr := Collect[globals.ToolCall](*source, *target) + + for _, tool := range arr { + idx, hit := hitTool(tool, tools) + + if hit != nil { + tools[idx].Function.Arguments += tool.Function.Arguments + } else { + tools = append(tools, tool) + } + } + + return &tools +} + func (b *Buffer) AddToolCalls(toolCalls *globals.ToolCalls) { if toolCalls == nil { return } - b.ToolCalls = toolCalls + b.ToolCalls = mixTools(b.ToolCalls, toolCalls) } func (b *Buffer) SetFunctionCall(functionCall *globals.FunctionCall) {