From 422da58b2432e209f3fc4ba86db5f64dbe6e2462 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Tue, 12 Mar 2024 14:36:18 +0800 Subject: [PATCH] restruct: restruct channel factories --- adapter/adapter.go | 185 +++-------------------- adapter/azure/chat.go | 38 ++--- adapter/azure/image.go | 9 +- adapter/azure/processor.go | 9 +- adapter/azure/struct.go | 3 +- adapter/azure/types.go | 12 +- adapter/baichuan/chat.go | 15 +- adapter/baichuan/struct.go | 3 +- adapter/bing/chat.go | 8 +- adapter/bing/struct.go | 3 +- adapter/chatgpt/test.go | 47 ------ adapter/claude/chat.go | 23 +-- adapter/claude/struct.go | 3 +- adapter/common/interface.go | 11 ++ adapter/common/types.go | 31 ++++ adapter/dashscope/chat.go | 25 +-- adapter/dashscope/struct.go | 3 +- adapter/hunyuan/chat.go | 10 +- adapter/hunyuan/struct.go | 3 +- adapter/midjourney/chat.go | 14 +- adapter/midjourney/handler.go | 5 + adapter/midjourney/struct.go | 3 +- adapter/{chatgpt => openai}/chat.go | 40 ++--- adapter/{chatgpt => openai}/image.go | 7 +- adapter/{chatgpt => openai}/processor.go | 11 +- adapter/{chatgpt => openai}/struct.go | 5 +- adapter/{chatgpt => openai}/types.go | 14 +- adapter/palm2/chat.go | 20 +-- adapter/palm2/struct.go | 3 +- adapter/palm2/types.go | 4 +- adapter/request.go | 9 +- adapter/skylark/chat.go | 24 +-- adapter/skylark/struct.go | 3 +- adapter/slack/chat.go | 7 +- adapter/slack/struct.go | 3 +- adapter/sparkdesk/chat.go | 33 ++-- adapter/sparkdesk/struct.go | 17 +-- adapter/zhinao/chat.go | 23 +-- adapter/zhinao/struct.go | 3 +- adapter/zhipuai/chat.go | 12 +- adapter/zhipuai/struct.go | 3 +- addition/generation/prompt.go | 10 +- app/src/admin/channel.ts | 22 +-- app/src/store/chat.ts | 2 +- channel/worker.go | 22 ++- cli/exec.go | 2 - cli/filter.go | 18 --- globals/interface.go | 1 + manager/chat.go | 3 +- manager/chat_completions.go | 6 +- manager/completions.go | 4 +- manager/images.go | 6 +- 52 files changed, 284 insertions(+), 516 deletions(-) delete mode 100644 adapter/chatgpt/test.go create mode 100644 adapter/common/interface.go create mode 100644 adapter/common/types.go rename adapter/{chatgpt => openai}/chat.go (71%) rename adapter/{chatgpt => openai}/image.go (86%) rename adapter/{chatgpt => openai}/processor.go (88%) rename adapter/{chatgpt => openai}/struct.go (84%) rename adapter/{chatgpt => openai}/types.go (90%) delete mode 100644 cli/filter.go diff --git a/adapter/adapter.go b/adapter/adapter.go index 92be5ee..445424b 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -4,11 +4,12 @@ import ( "chat/adapter/azure" "chat/adapter/baichuan" "chat/adapter/bing" - "chat/adapter/chatgpt" "chat/adapter/claude" + "chat/adapter/common" "chat/adapter/dashscope" "chat/adapter/hunyuan" "chat/adapter/midjourney" + "chat/adapter/openai" "chat/adapter/palm2" "chat/adapter/skylark" "chat/adapter/slack" @@ -16,171 +17,33 @@ import ( "chat/adapter/zhinao" "chat/adapter/zhipuai" "chat/globals" - "chat/utils" "fmt" ) -type RequestProps struct { - MaxRetries *int - Current int - Group string +var channelFactories = map[string]adaptercommon.FactoryCreator{ + globals.OpenAIChannelType: openai.NewChatInstanceFromConfig, + globals.AzureOpenAIChannelType: azure.NewChatInstanceFromConfig, + globals.ClaudeChannelType: claude.NewChatInstanceFromConfig, + globals.SlackChannelType: slack.NewChatInstanceFromConfig, + globals.BingChannelType: bing.NewChatInstanceFromConfig, + globals.PalmChannelType: palm2.NewChatInstanceFromConfig, + globals.SparkdeskChannelType: sparkdesk.NewChatInstanceFromConfig, + globals.ChatGLMChannelType: zhipuai.NewChatInstanceFromConfig, + globals.QwenChannelType: dashscope.NewChatInstanceFromConfig, + globals.HunyuanChannelType: hunyuan.NewChatInstanceFromConfig, + globals.BaichuanChannelType: baichuan.NewChatInstanceFromConfig, + globals.SkylarkChannelType: skylark.NewChatInstanceFromConfig, + globals.ZhinaoChannelType: zhinao.NewChatInstanceFromConfig, + globals.MidjourneyChannelType: midjourney.NewChatInstanceFromConfig, } -type ChatProps struct { - RequestProps +func createChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps, hook globals.Hook) error { + props.Model = conf.GetModelReflect(props.OriginalModel) - Model string - Message []globals.Message - MaxTokens *int - PresencePenalty *float32 - FrequencyPenalty *float32 - RepetitionPenalty *float32 - Temperature *float32 - TopP *float32 - TopK *int - Tools *globals.FunctionTools - ToolChoice *interface{} - Buffer utils.Buffer -} - -func createChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error { - model := conf.GetModelReflect(props.Model) - - switch conf.GetType() { - case globals.OpenAIChannelType: - return chatgpt.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&chatgpt.ChatProps{ - Model: model, - Message: props.Message, - Token: props.MaxTokens, - PresencePenalty: props.PresencePenalty, - FrequencyPenalty: props.FrequencyPenalty, - Temperature: props.Temperature, - TopP: props.TopP, - Tools: props.Tools, - ToolChoice: props.ToolChoice, - Buffer: props.Buffer, - }, hook) - - case globals.AzureOpenAIChannelType: - return azure.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&azure.ChatProps{ - Model: model, - Message: props.Message, - Token: props.MaxTokens, - PresencePenalty: props.PresencePenalty, - FrequencyPenalty: props.FrequencyPenalty, - Temperature: props.Temperature, - TopP: props.TopP, - Tools: props.Tools, - ToolChoice: props.ToolChoice, - Buffer: props.Buffer, - }, hook) - - case globals.ClaudeChannelType: - return claude.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&claude.ChatProps{ - Model: model, - Message: props.Message, - Token: props.MaxTokens, - TopP: props.TopP, - TopK: props.TopK, - Temperature: props.Temperature, - }, hook) - - case globals.SlackChannelType: - return slack.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&slack.ChatProps{ - Message: props.Message, - }, hook) - - case globals.BingChannelType: - return bing.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&bing.ChatProps{ - Model: model, - Message: props.Message, - }, hook) - - case globals.PalmChannelType: - return palm2.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&palm2.ChatProps{ - Model: model, - Message: props.Message, - }, hook) - - case globals.SparkdeskChannelType: - return sparkdesk.NewChatInstance(conf, model).CreateStreamChatRequest(&sparkdesk.ChatProps{ - Model: model, - Message: props.Message, - Token: props.MaxTokens, - Temperature: props.Temperature, - TopK: props.TopK, - Tools: props.Tools, - Buffer: props.Buffer, - }, hook) - - case globals.ChatGLMChannelType: - return zhipuai.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&zhipuai.ChatProps{ - Model: model, - Message: props.Message, - Temperature: props.Temperature, - TopP: props.TopP, - }, hook) - - case globals.QwenChannelType: - return dashscope.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&dashscope.ChatProps{ - Model: model, - Message: props.Message, - Token: props.MaxTokens, - Temperature: props.Temperature, - TopP: props.TopP, - TopK: props.TopK, - RepetitionPenalty: props.RepetitionPenalty, - }, hook) - - case globals.HunyuanChannelType: - return hunyuan.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&hunyuan.ChatProps{ - Model: model, - Message: props.Message, - Temperature: props.Temperature, - TopP: props.TopP, - }, hook) - - case globals.BaichuanChannelType: - return baichuan.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&baichuan.ChatProps{ - Model: model, - Message: props.Message, - TopP: props.TopP, - TopK: props.TopK, - Temperature: props.Temperature, - }, hook) - - case globals.SkylarkChannelType: - return skylark.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&skylark.ChatProps{ - Model: model, - Message: props.Message, - Token: props.MaxTokens, - TopP: props.TopP, - TopK: props.TopK, - Temperature: props.Temperature, - FrequencyPenalty: props.FrequencyPenalty, - PresencePenalty: props.PresencePenalty, - RepeatPenalty: props.RepetitionPenalty, - Tools: props.Tools, - }, hook) - - case globals.ZhinaoChannelType: - return zhinao.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&zhinao.ChatProps{ - Model: model, - Message: props.Message, - Token: props.MaxTokens, - TopP: props.TopP, - TopK: props.TopK, - Temperature: props.Temperature, - RepetitionPenalty: props.RepetitionPenalty, - }, hook) - - case globals.MidjourneyChannelType: - return midjourney.NewChatInstanceFromConfig(conf).CreateStreamChatRequest(&midjourney.ChatProps{ - Model: model, - Messages: props.Message, - }, hook) - - default: - return fmt.Errorf("unknown channel type %s (model: %s)", conf.GetType(), props.Model) + factoryType := conf.GetType() + if factory, ok := channelFactories[factoryType]; ok { + return factory(conf).CreateStreamChatRequest(props, hook) } + + return fmt.Errorf("unknown channel type %s (channel #%d)", conf.GetType(), conf.GetId()) } diff --git a/adapter/azure/chat.go b/adapter/azure/chat.go index 6c70fa1..86082ab 100644 --- a/adapter/azure/chat.go +++ b/adapter/azure/chat.go @@ -1,6 +1,7 @@ package azure import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "errors" @@ -8,20 +9,7 @@ import ( "strings" ) -type ChatProps struct { - Model string - Message []globals.Message - Token *int - PresencePenalty *float32 - FrequencyPenalty *float32 - Temperature *float32 - TopP *float32 - Tools *globals.FunctionTools - ToolChoice *interface{} - Buffer utils.Buffer -} - -func (c *ChatInstance) GetChatEndpoint(props *ChatProps) string { +func (c *ChatInstance) GetChatEndpoint(props *adaptercommon.ChatProps) string { model := strings.ReplaceAll(props.Model, ".", "") if props.Model == globals.GPT3TurboInstruct { return fmt.Sprintf("%s/openai/deployments/%s/completions?api-version=%s", c.GetResource(), model, c.GetEndpoint()) @@ -37,7 +25,7 @@ func (c *ChatInstance) GetCompletionPrompt(messages []globals.Message) string { return result } -func (c *ChatInstance) GetLatestPrompt(props *ChatProps) string { +func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string { if len(props.Message) == 0 { return "" } @@ -45,19 +33,19 @@ func (c *ChatInstance) GetLatestPrompt(props *ChatProps) string { return props.Message[len(props.Message)-1].Content } -func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { +func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) interface{} { if props.Model == globals.GPT3TurboInstruct { // for completions return CompletionRequest{ Prompt: c.GetCompletionPrompt(props.Message), - MaxToken: props.Token, + MaxToken: props.MaxTokens, Stream: stream, } } return ChatRequest{ Messages: formatMessages(props), - MaxToken: props.Token, + MaxToken: props.MaxTokens, Stream: stream, PresencePenalty: props.PresencePenalty, FrequencyPenalty: props.FrequencyPenalty, @@ -68,8 +56,8 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { } } -// CreateChatRequest is the native http request body for chatgpt -func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { +// CreateChatRequest is the native http request body for openai +func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) { if globals.IsOpenAIDalleModel(props.Model) { return c.CreateImage(props) } @@ -81,20 +69,20 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { ) if err != nil || res == nil { - return "", fmt.Errorf("chatgpt error: %s", err.Error()) + return "", fmt.Errorf("openai error: %s", err.Error()) } data := utils.MapToStruct[ChatResponse](res) if data == nil { - return "", fmt.Errorf("chatgpt error: cannot parse response") + return "", fmt.Errorf("openai error: cannot parse response") } else if data.Error.Message != "" { - return "", fmt.Errorf("chatgpt error: %s", data.Error.Message) + return "", fmt.Errorf("openai error: %s", data.Error.Message) } return data.Choices[0].Message.Content, nil } -// CreateStreamChatRequest is the stream response body for chatgpt -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +// CreateStreamChatRequest is the stream response body for openai +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { if globals.IsOpenAIDalleModel(props.Model) { if url, err := c.CreateImage(props); err != nil { return err diff --git a/adapter/azure/image.go b/adapter/azure/image.go index 5a7e62a..74689a0 100644 --- a/adapter/azure/image.go +++ b/adapter/azure/image.go @@ -1,6 +1,7 @@ package azure import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -32,21 +33,21 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) { N: 1, }) if err != nil || res == nil { - return "", fmt.Errorf("chatgpt error: %s", err.Error()) + return "", fmt.Errorf("openai error: %s", err.Error()) } data := utils.MapToStruct[ImageResponse](res) if data == nil { - return "", fmt.Errorf("chatgpt error: cannot parse response") + return "", fmt.Errorf("openai error: cannot parse response") } else if data.Error.Message != "" { - return "", fmt.Errorf("chatgpt error: %s", data.Error.Message) + return "", fmt.Errorf("openai error: %s", data.Error.Message) } return data.Data[0].Url, nil } // CreateImage will create a dalle image from prompt, return markdown of image -func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) { +func (c *ChatInstance) CreateImage(props *adaptercommon.ChatProps) (string, error) { url, err := c.CreateImageRequest(ImageProps{ Model: props.Model, Prompt: c.GetLatestPrompt(props), diff --git a/adapter/azure/processor.go b/adapter/azure/processor.go index 1f14234..27a1785 100644 --- a/adapter/azure/processor.go +++ b/adapter/azure/processor.go @@ -1,6 +1,7 @@ package azure import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "errors" @@ -8,7 +9,7 @@ import ( "regexp" ) -func formatMessages(props *ChatProps) interface{} { +func formatMessages(props *adaptercommon.ChatProps) interface{} { if globals.IsVisionModel(props.Model) { return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message { if message.Role == globals.User { @@ -120,7 +121,7 @@ func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals }, nil } - globals.Warn(fmt.Sprintf("chatgpt error: cannot parse completion response: %s", data)) + globals.Warn(fmt.Sprintf("openai error: cannot parse completion response: %s", data)) return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse completion response") } @@ -129,9 +130,9 @@ func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals } if form := processChatErrorResponse(data); form != nil { - return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("chatgpt error: %s (type: %s)", form.Error.Message, form.Error.Type)) + return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("openai error: %s (type: %s)", form.Error.Message, form.Error.Type)) } - globals.Warn(fmt.Sprintf("chatgpt error: cannot parse chat completion response: %s", data)) + globals.Warn(fmt.Sprintf("openai error: cannot parse chat completion response: %s", data)) return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse chat completion response") } diff --git a/adapter/azure/struct.go b/adapter/azure/struct.go index 3f5b6b5..6b66f6a 100644 --- a/adapter/azure/struct.go +++ b/adapter/azure/struct.go @@ -1,6 +1,7 @@ package azure import ( + factory "chat/adapter/common" "chat/globals" ) @@ -42,7 +43,7 @@ func NewChatInstance(endpoint, apiKey string, resource string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { param := conf.SplitRandomSecret(2) return NewChatInstance( conf.GetEndpoint(), diff --git a/adapter/azure/types.go b/adapter/azure/types.go index 89458bb..96ed1e4 100644 --- a/adapter/azure/types.go +++ b/adapter/azure/types.go @@ -24,7 +24,7 @@ type Message struct { ToolCalls *globals.ToolCalls `json:"tool_calls,omitempty"` // only `assistant` role } -// ChatRequest is the request body for chatgpt +// ChatRequest is the request body for openai type ChatRequest struct { Model string `json:"model"` Messages interface{} `json:"messages"` @@ -38,7 +38,7 @@ type ChatRequest struct { ToolChoice *interface{} `json:"tool_choice,omitempty"` // string or object } -// CompletionRequest is the request body for chatgpt completion +// CompletionRequest is the request body for openai completion type CompletionRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` @@ -46,7 +46,7 @@ type CompletionRequest struct { Stream bool `json:"stream"` } -// ChatResponse is the native http request body for chatgpt +// ChatResponse is the native http request body for openai type ChatResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -62,7 +62,7 @@ type ChatResponse struct { } `json:"error"` } -// ChatStreamResponse is the stream response body for chatgpt +// ChatStreamResponse is the stream response body for openai type ChatStreamResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -75,7 +75,7 @@ type ChatStreamResponse struct { } `json:"choices"` } -// CompletionResponse is the native http request body / stream response body for chatgpt completion +// CompletionResponse is the native http request body / stream response body for openai completion type CompletionResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -96,7 +96,7 @@ type ChatStreamErrorResponse struct { type ImageSize string -// ImageRequest is the request body for chatgpt dalle image generation +// ImageRequest is the request body for openai dalle image generation type ImageRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` diff --git a/adapter/baichuan/chat.go b/adapter/baichuan/chat.go index d59c4e5..7c78db3 100644 --- a/adapter/baichuan/chat.go +++ b/adapter/baichuan/chat.go @@ -1,20 +1,13 @@ package baichuan import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "errors" "fmt" ) -type ChatProps struct { - Model string - Message []globals.Message - TopP *float32 - TopK *int - Temperature *float32 -} - func (c *ChatInstance) GetChatEndpoint() string { return fmt.Sprintf("%s/v1/chat/completions", c.GetEndpoint()) } @@ -38,7 +31,7 @@ func (c *ChatInstance) GetMessages(messages []globals.Message) []globals.Message return messages } -func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) ChatRequest { +func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) ChatRequest { return ChatRequest{ Model: c.GetModel(props.Model), Messages: c.GetMessages(props.Message), @@ -50,7 +43,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) ChatRequest { } // CreateChatRequest is the native http request body for baichuan -func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { +func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) { res, err := utils.Post( c.GetChatEndpoint(), c.GetHeader(), @@ -71,7 +64,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { } // CreateStreamChatRequest is the stream response body for baichuan -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { err := utils.EventScanner(&utils.EventScannerProps{ Method: "POST", Uri: c.GetChatEndpoint(), diff --git a/adapter/baichuan/struct.go b/adapter/baichuan/struct.go index cf74675..f137bb0 100644 --- a/adapter/baichuan/struct.go +++ b/adapter/baichuan/struct.go @@ -1,6 +1,7 @@ package baichuan import ( + factory "chat/adapter/common" "chat/globals" "fmt" ) @@ -32,7 +33,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { return NewChatInstance( conf.GetEndpoint(), conf.GetRandomSecret(), diff --git a/adapter/bing/chat.go b/adapter/bing/chat.go index 2dfac87..f371f8e 100644 --- a/adapter/bing/chat.go +++ b/adapter/bing/chat.go @@ -1,18 +1,14 @@ package bing import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" "strings" ) -type ChatProps struct { - Message []globals.Message - Model string -} - -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error { var conn *utils.WebSocket if conn = utils.NewWebsocketClient(c.GetEndpoint()); conn == nil { return fmt.Errorf("bing error: websocket connection failed") diff --git a/adapter/bing/struct.go b/adapter/bing/struct.go index e17361d..25be0ba 100644 --- a/adapter/bing/struct.go +++ b/adapter/bing/struct.go @@ -1,6 +1,7 @@ package bing import ( + factory "chat/adapter/common" "chat/globals" "fmt" ) @@ -21,7 +22,7 @@ func NewChatInstance(endpoint, secret string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { return NewChatInstance( conf.GetEndpoint(), conf.GetRandomSecret(), diff --git a/adapter/chatgpt/test.go b/adapter/chatgpt/test.go deleted file mode 100644 index 221ecdc..0000000 --- a/adapter/chatgpt/test.go +++ /dev/null @@ -1,47 +0,0 @@ -package chatgpt - -import ( - "chat/globals" - "chat/utils" - "fmt" - "github.com/spf13/viper" - "strings" -) - -func (c *ChatInstance) Test() bool { - result, err := c.CreateChatRequest(&ChatProps{ - Model: globals.GPT3Turbo, - Message: []globals.Message{{Role: globals.User, Content: "hi"}}, - Token: utils.ToPtr(1), - }) - if err != nil { - fmt.Println(fmt.Sprintf("%s: test failed (%s)", c.GetApiKey(), err.Error())) - } - - return err == nil && len(result) > 0 -} - -func FilterKeys(v string) []string { - endpoint := viper.GetString(fmt.Sprintf("openai.%s.endpoint", v)) - keys := strings.Split(viper.GetString(fmt.Sprintf("openai.%s.apikey", v)), "|") - - return FilterKeysNative(endpoint, keys) -} - -func FilterKeysNative(endpoint string, keys []string) []string { - stack := make(chan string, len(keys)) - for _, key := range keys { - go func(key string) { - instance := NewChatInstance(endpoint, key) - stack <- utils.Multi[string](instance.Test(), key, "") - }(key) - } - - var result []string - for i := 0; i < len(keys); i++ { - if res := <-stack; res != "" { - result = append(result, res) - } - } - return result -} diff --git a/adapter/claude/chat.go b/adapter/claude/chat.go index e23ef82..10a5a3a 100644 --- a/adapter/claude/chat.go +++ b/adapter/claude/chat.go @@ -1,6 +1,7 @@ package claude import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -9,15 +10,6 @@ import ( const defaultTokens = 2500 -type ChatProps struct { - Model string - Message []globals.Message - Token *int - Temperature *float32 - TopP *float32 - TopK *int -} - func (c *ChatInstance) GetChatEndpoint() string { return fmt.Sprintf("%s/v1/complete", c.GetEndpoint()) } @@ -52,15 +44,15 @@ func (c *ChatInstance) ConvertMessage(message []globals.Message) string { return fmt.Sprintf("%s\n\nAssistant:", result) } -func (c *ChatInstance) GetTokens(props *ChatProps) int { - if props.Token == nil || *props.Token <= 0 { +func (c *ChatInstance) GetTokens(props *adaptercommon.ChatProps) int { + if props.MaxTokens == nil || *props.MaxTokens <= 0 { return defaultTokens } - return *props.Token + return *props.MaxTokens } -func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) *ChatBody { +func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) *ChatBody { return &ChatBody{ Prompt: c.ConvertMessage(props.Message), MaxTokensToSample: c.GetTokens(props), @@ -73,7 +65,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) *ChatBody { } // CreateChatRequest is the request for anthropic claude -func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { +func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) { data, err := utils.Post(c.GetChatEndpoint(), c.GetChatHeaders(), c.GetChatBody(props, false)) if err != nil { return "", fmt.Errorf("claude error: %s", err.Error()) @@ -115,7 +107,7 @@ func (c *ChatInstance) ProcessLine(buf, data string) (string, error) { } // CreateStreamChatRequest is the stream request for anthropic claude -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error { buf := "" return utils.EventSource( @@ -124,7 +116,6 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho c.GetChatHeaders(), c.GetChatBody(props, true), func(data string) error { - if resp, err := c.ProcessLine(buf, data); err == nil && len(resp) > 0 { buf = "" if err := hook(&globals.Chunk{Content: resp}); err != nil { diff --git a/adapter/claude/struct.go b/adapter/claude/struct.go index 3a4acdc..2997c7c 100644 --- a/adapter/claude/struct.go +++ b/adapter/claude/struct.go @@ -1,6 +1,7 @@ package claude import ( + factory "chat/adapter/common" "chat/globals" ) @@ -16,7 +17,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { return NewChatInstance( conf.GetEndpoint(), conf.GetRandomSecret(), diff --git a/adapter/common/interface.go b/adapter/common/interface.go new file mode 100644 index 0000000..f2cec99 --- /dev/null +++ b/adapter/common/interface.go @@ -0,0 +1,11 @@ +package adaptercommon + +import ( + "chat/globals" +) + +type Factory interface { + CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error +} + +type FactoryCreator func(globals.ChannelConfig) Factory diff --git a/adapter/common/types.go b/adapter/common/types.go new file mode 100644 index 0000000..a0f8014 --- /dev/null +++ b/adapter/common/types.go @@ -0,0 +1,31 @@ +package adaptercommon + +import ( + "chat/globals" + "chat/utils" +) + +type RequestProps struct { + MaxRetries *int + Current int + Group string +} + +type ChatProps struct { + RequestProps + + Model string + OriginalModel string + + Message []globals.Message + MaxTokens *int + PresencePenalty *float32 + FrequencyPenalty *float32 + RepetitionPenalty *float32 + Temperature *float32 + TopP *float32 + TopK *int + Tools *globals.FunctionTools + ToolChoice *interface{} + Buffer utils.Buffer +} diff --git a/adapter/dashscope/chat.go b/adapter/dashscope/chat.go index 2bd6326..c47f94d 100644 --- a/adapter/dashscope/chat.go +++ b/adapter/dashscope/chat.go @@ -1,6 +1,7 @@ package dashscope import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -9,16 +10,6 @@ import ( const defaultMaxTokens = 1500 -type ChatProps struct { - Model string - Token *int - Temperature *float32 - TopP *float32 - TopK *int - RepetitionPenalty *float32 - Message []globals.Message -} - func (c *ChatInstance) GetHeader() map[string]string { return map[string]string{ "Content-Type": "application/json", @@ -43,16 +34,16 @@ func (c *ChatInstance) FormatMessages(message []globals.Message) []Message { return messages } -func (c *ChatInstance) GetMaxTokens(props *ChatProps) int { +func (c *ChatInstance) GetMaxTokens(props *adaptercommon.ChatProps) int { // dashscope has a restriction of 1500 tokens in completion - if props.Token == nil || *props.Token <= 0 || *props.Token > 1500 { + if props.MaxTokens == nil || *props.MaxTokens <= 0 || *props.MaxTokens > 1500 { return defaultMaxTokens } - return *props.Token + return *props.MaxTokens } -func (c *ChatInstance) GetTopP(props *ChatProps) *float32 { +func (c *ChatInstance) GetTopP(props *adaptercommon.ChatProps) *float32 { // range of top_p should be (0.0, 1.0) if props.TopP == nil { return nil @@ -67,7 +58,7 @@ func (c *ChatInstance) GetTopP(props *ChatProps) *float32 { return props.TopP } -func (c *ChatInstance) GetRepeatPenalty(props *ChatProps) *float32 { +func (c *ChatInstance) GetRepeatPenalty(props *adaptercommon.ChatProps) *float32 { // range of repetition_penalty should greater than 0.0 if props.RepetitionPenalty == nil { return nil @@ -80,7 +71,7 @@ func (c *ChatInstance) GetRepeatPenalty(props *ChatProps) *float32 { return props.RepetitionPenalty } -func (c *ChatInstance) GetChatBody(props *ChatProps) ChatRequest { +func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps) ChatRequest { return ChatRequest{ Model: strings.TrimSuffix(props.Model, "-net"), Input: ChatInput{ @@ -102,7 +93,7 @@ func (c *ChatInstance) GetChatEndpoint() string { return fmt.Sprintf("%s/api/v1/services/aigc/text-generation/generation", c.Endpoint) } -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { return utils.EventSource( "POST", c.GetChatEndpoint(), diff --git a/adapter/dashscope/struct.go b/adapter/dashscope/struct.go index 858b7b8..c3f12ae 100644 --- a/adapter/dashscope/struct.go +++ b/adapter/dashscope/struct.go @@ -1,6 +1,7 @@ package dashscope import ( + factory "chat/adapter/common" "chat/globals" ) @@ -24,7 +25,7 @@ func NewChatInstance(endpoint string, apiKey string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { return NewChatInstance( conf.GetEndpoint(), conf.GetRandomSecret(), diff --git a/adapter/hunyuan/chat.go b/adapter/hunyuan/chat.go index 2e9363f..fa12fc4 100644 --- a/adapter/hunyuan/chat.go +++ b/adapter/hunyuan/chat.go @@ -1,18 +1,12 @@ package hunyuan import ( + adaptercommon "chat/adapter/common" "chat/globals" "context" "fmt" ) -type ChatProps struct { - Model string - Message []globals.Message - Temperature *float32 - TopP *float32 -} - func (c *ChatInstance) FormatMessages(messages []globals.Message) []globals.Message { var result []globals.Message for _, message := range messages { @@ -36,7 +30,7 @@ func (c *ChatInstance) FormatMessages(messages []globals.Message) []globals.Mess return result } -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { credential := NewCredential(c.GetSecretId(), c.GetSecretKey()) client := NewInstance(c.GetAppId(), c.GetEndpoint(), credential) channel, err := client.Chat(context.Background(), NewRequest(Stream, c.FormatMessages(props.Message), props.Temperature, props.TopP)) diff --git a/adapter/hunyuan/struct.go b/adapter/hunyuan/struct.go index fbe0b3d..2e2684d 100644 --- a/adapter/hunyuan/struct.go +++ b/adapter/hunyuan/struct.go @@ -1,6 +1,7 @@ package hunyuan import ( + factory "chat/adapter/common" "chat/globals" "chat/utils" ) @@ -37,7 +38,7 @@ func NewChatInstance(endpoint, appId, secretId, secretKey string) *ChatInstance } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { params := conf.SplitRandomSecret(3) return NewChatInstance( conf.GetEndpoint(), diff --git a/adapter/midjourney/chat.go b/adapter/midjourney/chat.go index de156a2..ead51a8 100644 --- a/adapter/midjourney/chat.go +++ b/adapter/midjourney/chat.go @@ -1,6 +1,7 @@ package midjourney import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -66,16 +67,16 @@ func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string { return target } -func (c *ChatInstance) GetPrompt(props *ChatProps) string { - if len(props.Messages) == 0 { +func (c *ChatInstance) GetPrompt(props *adaptercommon.ChatProps) string { + if len(props.Message) == 0 { return "" } - content := props.Messages[len(props.Messages)-1].Content + content := props.Message[len(props.Message)-1].Content return c.GetCleanPrompt(props.Model, content) } -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { // partial response like: // ```progress // 0 @@ -95,6 +96,11 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global var begin bool form, err := c.CreateStreamTask(action, prompt, func(form *StorageForm, progress int) error { + if progress == -1 { + // ping event + return callback(&globals.Chunk{Content: ""}) + } + if progress == 0 { begin = true if err := callback(&globals.Chunk{Content: "```progress\n"}); err != nil { diff --git a/adapter/midjourney/handler.go b/adapter/midjourney/handler.go index dff210a..26d6aec 100644 --- a/adapter/midjourney/handler.go +++ b/adapter/midjourney/handler.go @@ -96,6 +96,11 @@ func (c *ChatInstance) CreateStreamTask(action string, prompt string, hook func( utils.Sleep(50) form := getStorage(task) if form == nil { + // hook for ping + if err := hook(nil, -1); err != nil { + return nil, err + } + continue } diff --git a/adapter/midjourney/struct.go b/adapter/midjourney/struct.go index 7656e79..6c665a3 100644 --- a/adapter/midjourney/struct.go +++ b/adapter/midjourney/struct.go @@ -1,6 +1,7 @@ package midjourney import ( + factory "chat/adapter/common" "chat/globals" "fmt" ) @@ -47,7 +48,7 @@ func NewChatInstance(endpoint, apiSecret, whiteList string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { params := conf.SplitRandomSecret(2) return NewChatInstance( diff --git a/adapter/chatgpt/chat.go b/adapter/openai/chat.go similarity index 71% rename from adapter/chatgpt/chat.go rename to adapter/openai/chat.go index d3183cd..bcc0f38 100644 --- a/adapter/chatgpt/chat.go +++ b/adapter/openai/chat.go @@ -1,6 +1,7 @@ -package chatgpt +package openai import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "errors" @@ -8,20 +9,7 @@ import ( "regexp" ) -type ChatProps struct { - Model string - Message []globals.Message - Token *int - PresencePenalty *float32 - FrequencyPenalty *float32 - Temperature *float32 - TopP *float32 - Tools *globals.FunctionTools - ToolChoice *interface{} - Buffer utils.Buffer -} - -func (c *ChatInstance) GetChatEndpoint(props *ChatProps) string { +func (c *ChatInstance) GetChatEndpoint(props *adaptercommon.ChatProps) string { if props.Model == globals.GPT3TurboInstruct { return fmt.Sprintf("%s/v1/completions", c.GetEndpoint()) } @@ -36,7 +24,7 @@ func (c *ChatInstance) GetCompletionPrompt(messages []globals.Message) string { return result } -func (c *ChatInstance) GetLatestPrompt(props *ChatProps) string { +func (c *ChatInstance) GetLatestPrompt(props *adaptercommon.ChatProps) string { if len(props.Message) == 0 { return "" } @@ -44,13 +32,13 @@ func (c *ChatInstance) GetLatestPrompt(props *ChatProps) string { return props.Message[len(props.Message)-1].Content } -func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { +func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) interface{} { if props.Model == globals.GPT3TurboInstruct { // for completions return CompletionRequest{ Model: props.Model, Prompt: c.GetCompletionPrompt(props.Message), - MaxToken: props.Token, + MaxToken: props.MaxTokens, Stream: stream, } } @@ -60,7 +48,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { return ChatRequest{ Model: props.Model, Messages: messages, - MaxToken: props.Token, + MaxToken: props.MaxTokens, Stream: stream, PresencePenalty: props.PresencePenalty, FrequencyPenalty: props.FrequencyPenalty, @@ -71,8 +59,8 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { } } -// CreateChatRequest is the native http request body for chatgpt -func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { +// CreateChatRequest is the native http request body for openai +func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) { if globals.IsOpenAIDalleModel(props.Model) { return c.CreateImage(props) } @@ -84,14 +72,14 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { ) if err != nil || res == nil { - return "", fmt.Errorf("chatgpt error: %s", err.Error()) + return "", fmt.Errorf("openai error: %s", err.Error()) } data := utils.MapToStruct[ChatResponse](res) if data == nil { - return "", fmt.Errorf("chatgpt error: cannot parse response") + return "", fmt.Errorf("openai error: cannot parse response") } else if data.Error.Message != "" { - return "", fmt.Errorf("chatgpt error: %s", data.Error.Message) + return "", fmt.Errorf("openai error: %s", data.Error.Message) } return data.Choices[0].Message.Content, nil } @@ -103,8 +91,8 @@ func hideRequestId(message string) string { return exp.ReplaceAllString(message, "") } -// CreateStreamChatRequest is the stream response body for chatgpt -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +// CreateStreamChatRequest is the stream response body for openai +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { if globals.IsOpenAIDalleModel(props.Model) { if url, err := c.CreateImage(props); err != nil { return err diff --git a/adapter/chatgpt/image.go b/adapter/openai/image.go similarity index 86% rename from adapter/chatgpt/image.go rename to adapter/openai/image.go index ea4ad81..91dce0b 100644 --- a/adapter/chatgpt/image.go +++ b/adapter/openai/image.go @@ -1,6 +1,7 @@ -package chatgpt +package openai import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -37,7 +38,7 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) { data := utils.MapToStruct[ImageResponse](res) if data == nil { - return "", fmt.Errorf("chatgpt error: cannot parse response") + return "", fmt.Errorf("openai error: cannot parse response") } else if data.Error.Message != "" { return "", fmt.Errorf(data.Error.Message) } @@ -46,7 +47,7 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) { } // CreateImage will create a dalle image from prompt, return markdown of image -func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) { +func (c *ChatInstance) CreateImage(props *adaptercommon.ChatProps) (string, error) { url, err := c.CreateImageRequest(ImageProps{ Model: props.Model, Prompt: c.GetLatestPrompt(props), diff --git a/adapter/chatgpt/processor.go b/adapter/openai/processor.go similarity index 88% rename from adapter/chatgpt/processor.go rename to adapter/openai/processor.go index 0f6f729..8e36c74 100644 --- a/adapter/chatgpt/processor.go +++ b/adapter/openai/processor.go @@ -1,6 +1,7 @@ -package chatgpt +package openai import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "errors" @@ -8,7 +9,7 @@ import ( "regexp" ) -func formatMessages(props *ChatProps) interface{} { +func formatMessages(props *adaptercommon.ChatProps) interface{} { if globals.IsVisionModel(props.Model) { return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message { if message.Role == globals.User { @@ -118,7 +119,7 @@ func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals }, nil } - globals.Warn(fmt.Sprintf("chatgpt error: cannot parse completion response: %s", data)) + globals.Warn(fmt.Sprintf("openai error: cannot parse completion response: %s", data)) return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse completion response") } @@ -127,9 +128,9 @@ func (c *ChatInstance) ProcessLine(data string, isCompletionType bool) (*globals } if form := processChatErrorResponse(data); form != nil { - return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("chatgpt error: %s (type: %s)", form.Error.Message, form.Error.Type)) + return &globals.Chunk{Content: ""}, errors.New(fmt.Sprintf("openai error: %s (type: %s)", form.Error.Message, form.Error.Type)) } - globals.Warn(fmt.Sprintf("chatgpt error: cannot parse chat completion response: %s", data)) + globals.Warn(fmt.Sprintf("openai error: cannot parse chat completion response: %s", data)) return &globals.Chunk{Content: ""}, errors.New("parser error: cannot parse chat completion response") } diff --git a/adapter/chatgpt/struct.go b/adapter/openai/struct.go similarity index 84% rename from adapter/chatgpt/struct.go rename to adapter/openai/struct.go index fd2512e..b3dcb20 100644 --- a/adapter/chatgpt/struct.go +++ b/adapter/openai/struct.go @@ -1,6 +1,7 @@ -package chatgpt +package openai import ( + factory "chat/adapter/common" "chat/globals" "fmt" ) @@ -37,7 +38,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { return NewChatInstance( conf.GetEndpoint(), conf.GetRandomSecret(), diff --git a/adapter/chatgpt/types.go b/adapter/openai/types.go similarity index 90% rename from adapter/chatgpt/types.go rename to adapter/openai/types.go index 45a472f..4e31e00 100644 --- a/adapter/chatgpt/types.go +++ b/adapter/openai/types.go @@ -1,4 +1,4 @@ -package chatgpt +package openai import "chat/globals" @@ -24,7 +24,7 @@ type Message struct { ToolCalls *globals.ToolCalls `json:"tool_calls,omitempty"` // only `assistant` role } -// ChatRequest is the request body for chatgpt +// ChatRequest is the request body for openai type ChatRequest struct { Model string `json:"model"` Messages interface{} `json:"messages"` @@ -38,7 +38,7 @@ type ChatRequest struct { ToolChoice *interface{} `json:"tool_choice,omitempty"` // string or object } -// CompletionRequest is the request body for chatgpt completion +// CompletionRequest is the request body for openai completion type CompletionRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` @@ -46,7 +46,7 @@ type CompletionRequest struct { Stream bool `json:"stream"` } -// ChatResponse is the native http request body for chatgpt +// ChatResponse is the native http request body for openai type ChatResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -62,7 +62,7 @@ type ChatResponse struct { } `json:"error"` } -// ChatStreamResponse is the stream response body for chatgpt +// ChatStreamResponse is the stream response body for openai type ChatStreamResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -75,7 +75,7 @@ type ChatStreamResponse struct { } `json:"choices"` } -// CompletionResponse is the native http request body / stream response body for chatgpt completion +// CompletionResponse is the native http request body / stream response body for openai completion type CompletionResponse struct { ID string `json:"id"` Object string `json:"object"` @@ -96,7 +96,7 @@ type ChatStreamErrorResponse struct { type ImageSize string -// ImageRequest is the request body for chatgpt dalle image generation +// ImageRequest is the request body for openai dalle image generation type ImageRequest struct { Model string `json:"model"` Prompt string `json:"prompt"` diff --git a/adapter/palm2/chat.go b/adapter/palm2/chat.go index a55b56e..3e839df 100644 --- a/adapter/palm2/chat.go +++ b/adapter/palm2/chat.go @@ -1,6 +1,7 @@ package palm2 import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -8,15 +9,6 @@ import ( var geminiMaxImages = 16 -type ChatProps struct { - Model string - Message []globals.Message - Temperature *float64 - TopP *float64 - TopK *int - MaxOutputTokens *int -} - func (c *ChatInstance) GetChatEndpoint(model string) string { if model == globals.ChatBison001 { return fmt.Sprintf("%s/v1beta2/models/%s:generateMessage?key=%s", c.Endpoint, model, c.ApiKey) @@ -51,7 +43,7 @@ func (c *ChatInstance) ConvertMessage(message []globals.Message) []PalmMessage { return result } -func (c *ChatInstance) GetPalm2ChatBody(props *ChatProps) *PalmChatBody { +func (c *ChatInstance) GetPalm2ChatBody(props *adaptercommon.ChatProps) *PalmChatBody { return &PalmChatBody{ Prompt: PalmPrompt{ Messages: c.ConvertMessage(props.Message), @@ -59,12 +51,12 @@ func (c *ChatInstance) GetPalm2ChatBody(props *ChatProps) *PalmChatBody { } } -func (c *ChatInstance) GetGeminiChatBody(props *ChatProps) *GeminiChatBody { +func (c *ChatInstance) GetGeminiChatBody(props *adaptercommon.ChatProps) *GeminiChatBody { return &GeminiChatBody{ Contents: c.GetGeminiContents(props.Model, props.Message), GenerationConfig: GeminiConfig{ Temperature: props.Temperature, - MaxOutputTokens: props.MaxOutputTokens, + MaxOutputTokens: props.MaxTokens, TopP: props.TopP, TopK: props.TopK, }, @@ -95,7 +87,7 @@ func (c *ChatInstance) GetGeminiChatResponse(data interface{}) (string, error) { return "", fmt.Errorf("gemini: cannot parse response") } -func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { +func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) { uri := c.GetChatEndpoint(props.Model) if props.Model == globals.ChatBison001 { @@ -122,7 +114,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { // CreateStreamChatRequest is the mock stream request for palm2 // tips: palm2 does not support stream request -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { response, err := c.CreateChatRequest(props) if err != nil { return err diff --git a/adapter/palm2/struct.go b/adapter/palm2/struct.go index 5d41cf5..ab265ce 100644 --- a/adapter/palm2/struct.go +++ b/adapter/palm2/struct.go @@ -1,6 +1,7 @@ package palm2 import ( + factory "chat/adapter/common" "chat/globals" ) @@ -24,7 +25,7 @@ func NewChatInstance(endpoint string, apiKey string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { return NewChatInstance( conf.GetEndpoint(), conf.GetRandomSecret(), diff --git a/adapter/palm2/types.go b/adapter/palm2/types.go index c3e1ee8..7a41551 100644 --- a/adapter/palm2/types.go +++ b/adapter/palm2/types.go @@ -31,9 +31,9 @@ type GeminiChatBody struct { } type GeminiConfig struct { - Temperature *float64 `json:"temperature,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` MaxOutputTokens *int `json:"maxOutputTokens,omitempty"` - TopP *float64 `json:"topP,omitempty"` + TopP *float32 `json:"topP,omitempty"` TopK *int `json:"topK,omitempty"` } diff --git a/adapter/request.go b/adapter/request.go index 03e2461..52345b4 100644 --- a/adapter/request.go +++ b/adapter/request.go @@ -1,6 +1,7 @@ package adapter import ( + "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -21,24 +22,24 @@ func isQPSOverLimit(model string, err error) bool { } } -func NewChatRequest(conf globals.ChannelConfig, props *ChatProps, hook globals.Hook) error { +func NewChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps, hook globals.Hook) error { err := createChatRequest(conf, props, hook) retries := conf.GetRetry() props.Current++ if IsAvailableError(err) { - if isQPSOverLimit(props.Model, err) { + if isQPSOverLimit(props.OriginalModel, err) { // sleep for 0.5s to avoid qps limit - globals.Info(fmt.Sprintf("qps limit for %s, sleep and retry (times: %d)", props.Model, props.Current)) + globals.Info(fmt.Sprintf("qps limit for %s, sleep and retry (times: %d)", props.OriginalModel, props.Current)) time.Sleep(500 * time.Millisecond) return NewChatRequest(conf, props, hook) } if props.Current < retries { content := strings.Replace(err.Error(), "\n", "", -1) - globals.Warn(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.Model, props.Current+1, retries, content)) + globals.Warn(fmt.Sprintf("retrying chat request for %s (attempt %d/%d, error: %s)", props.OriginalModel, props.Current+1, retries, content)) return NewChatRequest(conf, props, hook) } } diff --git a/adapter/skylark/chat.go b/adapter/skylark/chat.go index e034bab..d889d6c 100644 --- a/adapter/skylark/chat.go +++ b/adapter/skylark/chat.go @@ -1,6 +1,7 @@ package skylark import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -10,21 +11,6 @@ import ( const defaultMaxTokens int64 = 1500 -type ChatProps struct { - Model string - Message []globals.Message - Token *int - - PresencePenalty *float32 - FrequencyPenalty *float32 - RepeatPenalty *float32 - Temperature *float32 - TopP *float32 - TopK *int - Tools *globals.FunctionTools - Buffer utils.Buffer -} - func getMessages(messages []globals.Message) []*api.Message { return utils.Each[globals.Message, *api.Message](messages, func(message globals.Message) *api.Message { if message.Role == globals.Tool { @@ -47,7 +33,7 @@ func (c *ChatInstance) GetMaxTokens(token *int) int64 { return int64(*token) } -func (c *ChatInstance) CreateRequest(props *ChatProps) *api.ChatReq { +func (c *ChatInstance) CreateRequest(props *adaptercommon.ChatProps) *api.ChatReq { return &api.ChatReq{ Model: &api.Model{ Name: props.Model, @@ -59,8 +45,8 @@ func (c *ChatInstance) CreateRequest(props *ChatProps) *api.ChatReq { Temperature: utils.GetPtrVal(props.Temperature, 0.), PresencePenalty: utils.GetPtrVal(props.PresencePenalty, 0.), FrequencyPenalty: utils.GetPtrVal(props.FrequencyPenalty, 0.), - RepetitionPenalty: utils.GetPtrVal(props.RepeatPenalty, 0.), - MaxTokens: c.GetMaxTokens(props.Token), + RepetitionPenalty: utils.GetPtrVal(props.RepetitionPenalty, 0.), + MaxTokens: c.GetMaxTokens(props.MaxTokens), }, Functions: getFunctions(props.Tools), } @@ -96,7 +82,7 @@ func getChoice(choice *api.ChatResp) *globals.Chunk { } } -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { req := c.CreateRequest(props) channel, err := c.Instance.StreamChat(req) if err != nil { diff --git a/adapter/skylark/struct.go b/adapter/skylark/struct.go index 9e5d2fe..ee5d34c 100644 --- a/adapter/skylark/struct.go +++ b/adapter/skylark/struct.go @@ -1,6 +1,7 @@ package skylark import ( + factory "chat/adapter/common" "chat/globals" "github.com/volcengine/volc-sdk-golang/service/maas" "strings" @@ -43,7 +44,7 @@ func NewChatInstance(endpoint, accessKey, secretKey string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { params := conf.SplitRandomSecret(2) return NewChatInstance( diff --git a/adapter/slack/chat.go b/adapter/slack/chat.go index 4dbb4b3..7d0746d 100644 --- a/adapter/slack/chat.go +++ b/adapter/slack/chat.go @@ -1,15 +1,12 @@ package slack import ( + adaptercommon "chat/adapter/common" "chat/globals" "context" ) -type ChatProps struct { - Message []globals.Message -} - -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error { if err := c.Instance.NewChannel(c.GetChannel()); err != nil { return err } diff --git a/adapter/slack/struct.go b/adapter/slack/struct.go index cc89de7..3abc2ec 100644 --- a/adapter/slack/struct.go +++ b/adapter/slack/struct.go @@ -1,6 +1,7 @@ package slack import ( + factory "chat/adapter/common" "chat/globals" "fmt" "github.com/bincooo/claude-api" @@ -46,7 +47,7 @@ func NewChatInstance(botId, token, channel string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { params := conf.SplitRandomSecret(2) return NewChatInstance( params[0], params[1], diff --git a/adapter/sparkdesk/chat.go b/adapter/sparkdesk/chat.go index a6df841..f1271d7 100644 --- a/adapter/sparkdesk/chat.go +++ b/adapter/sparkdesk/chat.go @@ -1,42 +1,33 @@ package sparkdesk import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" "strings" ) -type ChatProps struct { - Model string - Message []globals.Message - Token *int - Temperature *float32 - TopK *int - Tools *globals.FunctionTools - Buffer utils.Buffer -} - -func GetToken(props *ChatProps) *int { - if props.Token == nil { +func GetToken(props *adaptercommon.ChatProps) *int { + if props.MaxTokens == nil { return nil } switch props.Model { case globals.SparkDeskV2, globals.SparkDeskV3: - if *props.Token > 8192 { + if *props.MaxTokens > 8192 { return utils.ToPtr(8192) } case globals.SparkDesk: - if *props.Token > 4096 { + if *props.MaxTokens > 4096 { return utils.ToPtr(4096) } } - return props.Token + return props.MaxTokens } -func (c *ChatInstance) GetMessages(props *ChatProps) []Message { +func (c *ChatInstance) GetMessages(props *adaptercommon.ChatProps) []Message { var messages []Message for _, message := range props.Message { if message.Role == globals.Tool { @@ -54,7 +45,7 @@ func (c *ChatInstance) GetMessages(props *ChatProps) []Message { return messages } -func (c *ChatInstance) GetFunctionCalling(props *ChatProps) *FunctionsPayload { +func (c *ChatInstance) GetFunctionCalling(props *adaptercommon.ChatProps) *FunctionsPayload { if props.Model != globals.SparkDeskV3 || props.Tools == nil { return nil } @@ -102,9 +93,11 @@ func getChoice(form *ChatResponse) *globals.Chunk { } } -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error { + endpoint := fmt.Sprintf("%s/%s/chat", c.Endpoint, TransformAddr(props.Model)) + var conn *utils.WebSocket - if conn = utils.NewWebsocketClient(c.GenerateUrl()); conn == nil { + if conn = utils.NewWebsocketClient(c.GenerateUrl(endpoint)); conn == nil { return fmt.Errorf("sparkdesk error: websocket connection failed") } defer conn.DeferClose() @@ -121,7 +114,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho }, Parameter: RequestParameter{ Chat: ChatParameter{ - Domain: c.Model, + Domain: TransformModel(props.Model), MaxToken: GetToken(props), }, }, diff --git a/adapter/sparkdesk/struct.go b/adapter/sparkdesk/struct.go index 98910a8..c1d51ca 100644 --- a/adapter/sparkdesk/struct.go +++ b/adapter/sparkdesk/struct.go @@ -1,6 +1,7 @@ package sparkdesk import ( + factory "chat/adapter/common" "chat/globals" "crypto/hmac" "crypto/sha256" @@ -15,7 +16,6 @@ type ChatInstance struct { AppId string ApiSecret string ApiKey string - Model string Endpoint string } @@ -45,24 +45,23 @@ func TransformModel(model string) string { } } -func NewChatInstance(conf globals.ChannelConfig, model string) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { params := conf.SplitRandomSecret(3) return &ChatInstance{ AppId: params[0], ApiSecret: params[1], ApiKey: params[2], - Model: TransformModel(model), - Endpoint: fmt.Sprintf("%s/%s/chat", conf.GetEndpoint(), TransformAddr(model)), + Endpoint: conf.GetEndpoint(), } } -func (c *ChatInstance) CreateUrl(host, date, auth string) string { +func (c *ChatInstance) CreateUrl(endpoint, host, date, auth string) string { v := make(url.Values) v.Add("host", host) v.Add("date", date) v.Add("authorization", auth) - return fmt.Sprintf("%s?%s", c.Endpoint, v.Encode()) + return fmt.Sprintf("%s?%s", endpoint, v.Encode()) } func (c *ChatInstance) Sign(data, key string) string { @@ -72,8 +71,8 @@ func (c *ChatInstance) Sign(data, key string) string { } // GenerateUrl will generate the signed url for sparkdesk api -func (c *ChatInstance) GenerateUrl() string { - uri, err := url.Parse(c.Endpoint) +func (c *ChatInstance) GenerateUrl(endpoint string) string { + uri, err := url.Parse(endpoint) if err != nil { return "" } @@ -96,5 +95,5 @@ func (c *ChatInstance) GenerateUrl() string { ), )) - return c.CreateUrl(uri.Host, date, authorization) + return c.CreateUrl(endpoint, uri.Host, date, authorization) } diff --git a/adapter/zhinao/chat.go b/adapter/zhinao/chat.go index 2736a11..0ca297c 100644 --- a/adapter/zhinao/chat.go +++ b/adapter/zhinao/chat.go @@ -1,22 +1,13 @@ package zhinao import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" "strings" ) -type ChatProps struct { - Model string - Message []globals.Message - Token *int - TopP *float32 - TopK *int - Temperature *float32 - RepetitionPenalty *float32 -} - func (c *ChatInstance) GetChatEndpoint() string { return fmt.Sprintf("%s/v1/chat/completions", c.GetEndpoint()) } @@ -30,10 +21,10 @@ func (c *ChatInstance) GetModel(model string) string { } } -func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { +func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) interface{} { // 2048 is the max token for 360GPT - if props.Token != nil && *props.Token > 2048 { - props.Token = utils.ToPtr(2048) + if props.MaxTokens != nil && *props.MaxTokens > 2048 { + props.MaxTokens = utils.ToPtr(2048) } return ChatRequest{ @@ -45,7 +36,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { return &message }), - MaxToken: props.Token, + MaxToken: props.MaxTokens, Stream: stream, Temperature: props.Temperature, TopP: props.TopP, @@ -55,7 +46,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { } // CreateChatRequest is the native http request body for zhinao -func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { +func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) { res, err := utils.Post( c.GetChatEndpoint(), c.GetHeader(), @@ -76,7 +67,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { } // CreateStreamChatRequest is the stream response body for zhinao -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { buf := "" cursor := 0 chunk := "" diff --git a/adapter/zhinao/struct.go b/adapter/zhinao/struct.go index 4037f90..d15e406 100644 --- a/adapter/zhinao/struct.go +++ b/adapter/zhinao/struct.go @@ -1,6 +1,7 @@ package zhinao import ( + factory "chat/adapter/common" "chat/globals" "fmt" ) @@ -32,7 +33,7 @@ func NewChatInstance(endpoint, apiKey string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { return NewChatInstance( conf.GetEndpoint(), conf.GetRandomSecret(), diff --git a/adapter/zhipuai/chat.go b/adapter/zhipuai/chat.go index 3de4d92..62747b4 100644 --- a/adapter/zhipuai/chat.go +++ b/adapter/zhipuai/chat.go @@ -1,19 +1,13 @@ package zhipuai import ( + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" "strings" ) -type ChatProps struct { - Model string - Message []globals.Message - Temperature *float32 `json:"temperature,omitempty"` - TopP *float32 `json:"top_p,omitempty"` -} - func (c *ChatInstance) GetChatEndpoint(model string) string { return fmt.Sprintf("%s/api/paas/v3/model-api/%s/sse-invoke", c.GetEndpoint(), c.GetModel(model)) } @@ -47,7 +41,7 @@ func (c *ChatInstance) FormatMessages(messages []globals.Message) []globals.Mess return messages } -func (c *ChatInstance) GetBody(props *ChatProps) ChatRequest { +func (c *ChatInstance) GetBody(props *adaptercommon.ChatProps) ChatRequest { return ChatRequest{ Prompt: c.FormatMessages(props.Message), TopP: props.TopP, @@ -55,7 +49,7 @@ func (c *ChatInstance) GetBody(props *ChatProps) ChatRequest { } } -func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error { +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error { return utils.EventSource( "POST", c.GetChatEndpoint(props.Model), diff --git a/adapter/zhipuai/struct.go b/adapter/zhipuai/struct.go index d41c54b..6c3c5fd 100644 --- a/adapter/zhipuai/struct.go +++ b/adapter/zhipuai/struct.go @@ -1,6 +1,7 @@ package zhipuai import ( + factory "chat/adapter/common" "chat/globals" "chat/utils" "github.com/dgrijalva/jwt-go" @@ -47,6 +48,6 @@ func NewChatInstance(endpoint, apikey string) *ChatInstance { } } -func NewChatInstanceFromConfig(conf globals.ChannelConfig) *ChatInstance { +func NewChatInstanceFromConfig(conf globals.ChannelConfig) factory.Factory { return NewChatInstance(conf.GetEndpoint(), conf.GetRandomSecret()) } diff --git a/addition/generation/prompt.go b/addition/generation/prompt.go index ecaaf09..4312421 100644 --- a/addition/generation/prompt.go +++ b/addition/generation/prompt.go @@ -1,7 +1,7 @@ package generation import ( - "chat/adapter" + "chat/adapter/common" "chat/admin" "chat/channel" "chat/globals" @@ -17,10 +17,10 @@ func CreateGeneration(group, model, prompt, path string, hook func(buffer *utils message := GenerateMessage(prompt) buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model)) - err := channel.NewChatRequest(group, &adapter.ChatProps{ - Model: model, - Message: message, - Buffer: *buffer, + err := channel.NewChatRequest(group, &adaptercommon.ChatProps{ + OriginalModel: model, + Message: message, + Buffer: *buffer, }, func(data *globals.Chunk) error { buffer.WriteChunk(data) hook(buffer, data.Content) diff --git a/app/src/admin/channel.ts b/app/src/admin/channel.ts index 2e04126..7b9283f 100644 --- a/app/src/admin/channel.ts +++ b/app/src/admin/channel.ts @@ -32,18 +32,18 @@ export type ChannelInfo = { export const ChannelTypes: Record = { openai: "OpenAI", azure: "Azure OpenAI", - claude: "Claude", - slack: "Slack", - sparkdesk: "讯飞星火", - chatglm: "智谱 ChatGLM", - qwen: "通义千问", - hunyuan: "腾讯混元", - zhinao: "360 智脑", - baichuan: "百川 AI", - skylark: "火山方舟", - bing: "New Bing", + claude: "Anthropic Claude", palm: "Google Gemini", - midjourney: "Midjourney", + midjourney: "Midjourney Proxy", + sparkdesk: "讯飞星火 SparkDesk", + chatglm: "智谱 ChatGLM", + qwen: "通义千问 TongYi", + hunyuan: "腾讯混元 Hunyuan", + zhinao: "360智脑 360GLM", + baichuan: "百川大模型 BaichuanAI", + skylark: "云雀大模型 SkylarkLLM", + bing: "New Bing", + slack: "Slack Claude", }; export const ChannelInfos: Record = { diff --git a/app/src/store/chat.ts b/app/src/store/chat.ts index 155c511..aaf1719 100644 --- a/app/src/store/chat.ts +++ b/app/src/store/chat.ts @@ -169,7 +169,7 @@ const chatSlice = createSlice({ }); const instance = conversation.messages[conversation.messages.length - 1]; - instance.content += message.message; + if (message.message.length > 0) instance.content += message.message; if (message.keyword) instance.keyword = message.keyword; if (message.quota) instance.quota = message.quota; if (message.end) instance.end = message.end; diff --git a/channel/worker.go b/channel/worker.go index 6e95562..abc90b4 100644 --- a/channel/worker.go +++ b/channel/worker.go @@ -2,6 +2,7 @@ package channel import ( "chat/adapter" + "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -9,10 +10,10 @@ import ( "time" ) -func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) error { - ticker := ConduitInstance.GetTicker(props.Model, group) +func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error { + ticker := ConduitInstance.GetTicker(props.OriginalModel, group) if ticker == nil || ticker.IsEmpty() { - return fmt.Errorf("cannot find channel for model %s", props.Model) + return fmt.Errorf("cannot find channel for model %s", props.OriginalModel) } var err error @@ -23,14 +24,14 @@ func NewChatRequest(group string, props *adapter.ChatProps, hook globals.Hook) e return nil } - globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.Model, channel.GetName())) + globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName())) } } - globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.Model)) + globals.Info(fmt.Sprintf("[channel] channels are exhausted for model %s", props.OriginalModel)) if err == nil { - err = fmt.Errorf("channels are exhausted for model %s", props.Model) + err = fmt.Errorf("channels are exhausted for model %s", props.OriginalModel) } return err @@ -78,9 +79,14 @@ func StoreCache(cache *redis.Client, hash string, index int64, buffer *utils.Buf cache.Set(cache.Context(), key, raw, expire) } -func NewChatRequestWithCache(cache *redis.Client, buffer *utils.Buffer, group string, props *adapter.ChatProps, hook globals.Hook) (bool, error) { +func NewChatRequestWithCache(cache *redis.Client, buffer *utils.Buffer, group string, props *adaptercommon.ChatProps, hook globals.Hook) (bool, error) { hash := utils.Md5Encrypt(utils.Marshal(props)) - idx, hit, err := PreflightCache(cache, props.Model, hash, buffer, hook) + + if len(props.OriginalModel) == 0 { + props.OriginalModel = props.Model + } + + idx, hit, err := PreflightCache(cache, props.OriginalModel, hash, buffer, hook) if hit { return true, err } diff --git a/cli/exec.go b/cli/exec.go index 1b4d6b5..8752379 100644 --- a/cli/exec.go +++ b/cli/exec.go @@ -12,8 +12,6 @@ func Run() bool { Help() case "invite": CreateInvitationCommand(param) - case "filter": - FilterApiKeyCommand(param) case "token": CreateTokenCommand(param) case "root": diff --git a/cli/filter.go b/cli/filter.go deleted file mode 100644 index 7b6850c..0000000 --- a/cli/filter.go +++ /dev/null @@ -1,18 +0,0 @@ -package cli - -import ( - "chat/adapter/chatgpt" - "fmt" - "strings" -) - -func FilterApiKeyCommand(args []string) { - data := strings.Trim(strings.TrimSpace(GetArgString(args, 0)), "\"") - endpoint := "https://api.openai.com" - keys := strings.Split(data, "|") - - available := chatgpt.FilterKeysNative(endpoint, keys) - - outputInfo("filter", fmt.Sprintf("filtered %d keys, %d available, %d unavailable", len(keys), len(available), len(keys)-len(available))) - fmt.Println(strings.Join(available, "|")) -} diff --git a/globals/interface.go b/globals/interface.go index 7360d4b..9f0be6c 100644 --- a/globals/interface.go +++ b/globals/interface.go @@ -10,6 +10,7 @@ type ChannelConfig interface { SplitRandomSecret(num int) []string GetEndpoint() string ProcessError(err error) error + GetId() int } type AuthLike interface { diff --git a/manager/chat.go b/manager/chat.go index 20ae633..fd0d85b 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -2,6 +2,7 @@ package manager import ( "chat/adapter" + "chat/adapter/common" "chat/addition/web" "chat/admin" "chat/auth" @@ -90,7 +91,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve hit, err := channel.NewChatRequestWithCache( cache, buffer, auth.GetGroup(db, user), - &adapter.ChatProps{ + &adaptercommon.ChatProps{ Model: model, Message: segment, Buffer: *buffer, diff --git a/manager/chat_completions.go b/manager/chat_completions.go index dd46641..03cfefd 100644 --- a/manager/chat_completions.go +++ b/manager/chat_completions.go @@ -1,7 +1,7 @@ package manager import ( - "chat/adapter" + "chat/adapter/common" "chat/addition/web" "chat/admin" "chat/auth" @@ -76,8 +76,8 @@ func ChatRelayAPI(c *gin.Context) { } } -func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adapter.ChatProps { - return &adapter.ChatProps{ +func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer, plan bool) *adaptercommon.ChatProps { + return &adaptercommon.ChatProps{ Model: form.Model, Message: messages, MaxTokens: form.MaxTokens, diff --git a/manager/completions.go b/manager/completions.go index cd8e9aa..47ad6b4 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -1,7 +1,7 @@ package manager import ( - "chat/adapter" + "chat/adapter/common" "chat/addition/web" "chat/admin" "chat/auth" @@ -37,7 +37,7 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] hit, err := channel.NewChatRequestWithCache( cache, buffer, auth.GetGroup(db, user), - &adapter.ChatProps{ + &adaptercommon.ChatProps{ Model: model, Message: segment, Buffer: *buffer, diff --git a/manager/images.go b/manager/images.go index 838ca99..bf1003f 100644 --- a/manager/images.go +++ b/manager/images.go @@ -1,7 +1,7 @@ package manager import ( - "chat/adapter" + "chat/adapter/common" "chat/admin" "chat/auth" "chat/channel" @@ -57,8 +57,8 @@ func ImagesRelayAPI(c *gin.Context) { createRelayImageObject(c, form, prompt, created, user, supportRelayPlan()) } -func getImageProps(form RelayImageForm, messages []globals.Message, buffer *utils.Buffer) *adapter.ChatProps { - return &adapter.ChatProps{ +func getImageProps(form RelayImageForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps { + return &adaptercommon.ChatProps{ Model: form.Model, Message: messages, MaxTokens: utils.ToPtr(-1),