diff --git a/adapter/adapter.go b/adapter/adapter.go index 78ee57b..576ccac 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -6,6 +6,7 @@ import ( "chat/adapter/bing" "chat/adapter/claude" adaptercommon "chat/adapter/common" + "chat/adapter/coze" "chat/adapter/dashscope" "chat/adapter/deepseek" "chat/adapter/dify" @@ -39,6 +40,7 @@ var channelFactories = map[string]adaptercommon.FactoryCreator{ globals.MidjourneyChannelType: midjourney.NewChatInstanceFromConfig, globals.DeepseekChannelType: deepseek.NewChatInstanceFromConfig, globals.DifyChannelType: dify.NewChatInstanceFromConfig, + globals.CozeChannelType: coze.NewChatInstanceFromConfig, globals.MoonshotChannelType: openai.NewChatInstanceFromConfig, // openai format globals.GroqChannelType: openai.NewChatInstanceFromConfig, // openai format diff --git a/adapter/coze/chat.go b/adapter/coze/chat.go new file mode 100644 index 0000000..d97ca36 --- /dev/null +++ b/adapter/coze/chat.go @@ -0,0 +1,204 @@ +package coze + +import ( + adaptercommon "chat/adapter/common" + "chat/globals" + "chat/utils" + "encoding/json" + "errors" + "fmt" + "strings" + "sync" + "time" +) + +type ChatInstance struct { + Endpoint string + ApiKey string + AutoSaveHistory bool + responseComplete bool +} + +func (c *ChatInstance) GetEndpoint() string { + return c.Endpoint +} + +func (c *ChatInstance) GetApiKey() string { + return c.ApiKey +} + +func (c *ChatInstance) GetHeader() map[string]string { + return map[string]string{ + "Content-Type": "application/json", + "Authorization": fmt.Sprintf("Bearer %s", c.GetApiKey()), + } +} + +func NewChatInstance(endpoint, apiKey string) *ChatInstance { + return &ChatInstance{ + Endpoint: endpoint, + ApiKey: apiKey, + AutoSaveHistory: false, + } +} + +func NewChatInstanceFromConfig(conf globals.ChannelConfig) adaptercommon.Factory { + return NewChatInstance( + conf.GetEndpoint(), + conf.GetRandomSecret(), + ) +} + +func (c *ChatInstance) GetChatEndpoint() string { + return fmt.Sprintf("%s/v3/chat", c.GetEndpoint()) +} + +func (c *ChatInstance) GetChatBody(props *adaptercommon.ChatProps, stream bool) interface{} { + additionalMessages := []EnterMessage{} + + for _, msg := range props.Message { + enterMsg := EnterMessage{ + Role: msg.Role, + Content: msg.Content, + ContentType: "text", + } + + if msg.Role == "user" { + enterMsg.Type = "question" + } else if msg.Role == "assistant" { + enterMsg.Type = "answer" + } + + additionalMessages = append(additionalMessages, enterMsg) + } + + // `user_id` is required in coze + timestamp := time.Now().UnixNano() + userID := fmt.Sprintf("user_%d", timestamp) + + return ChatRequest{ + BotID: props.Model, + UserID: userID, + AdditionalMessages: additionalMessages, + Stream: stream, + AutoSaveHistory: c.AutoSaveHistory, + } +} + +func (c *ChatInstance) ProcessLine(data string) (string, error) { + if c.responseComplete { + return "", nil + } + + if data == "" { + return "", nil + } + + chunk, complete, err := processStreamResponse(data) + if err != nil { + return "", err + } + + if complete { + c.responseComplete = true + } + + return chunk.Content, nil +} + +func (c *ChatInstance) CreateChatRequest(props *adaptercommon.ChatProps) (string, error) { + // TODO: use standard non-stream request + c.AutoSaveHistory = true + + res, err := utils.Post( + c.GetChatEndpoint(), + c.GetHeader(), + c.GetChatBody(props, false), + props.Proxy, + ) + + if err != nil || res == nil { + return "", fmt.Errorf("coze error: %s", err.Error()) + } + + responseBody := utils.Marshal(res) + response := processChatResponse(responseBody) + if response == nil { + return "", fmt.Errorf("coze error: cannot parse response") + } + + if response.Code != 0 { + return "", fmt.Errorf("coze error: %s (code: %d)", response.Msg, response.Code) + } + + var responseContent string + var responseMutex sync.Mutex + + err = c.CreateStreamChatRequest(props, func(chunk *globals.Chunk) error { + responseMutex.Lock() + defer responseMutex.Unlock() + responseContent += chunk.Content + return nil + }) + + if err != nil { + return "", err + } + + if responseContent == "" { + return "", fmt.Errorf("coze error: empty response from API") + } + + return responseContent, nil +} + +func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, callback globals.Hook) error { + c.responseComplete = false + c.AutoSaveHistory = false + + err := utils.EventScanner(&utils.EventScannerProps{ + Method: "POST", + Uri: c.GetChatEndpoint(), + Headers: c.GetHeader(), + Body: c.GetChatBody(props, true), + FullSSE: true, + Callback: func(data string) error { + partial, err := c.ProcessLine(data) + if err != nil { + return err + } + + if partial != "" { + err = callback(&globals.Chunk{Content: partial}) + if err != nil { + return err + } + } + return nil + }, + }, props.Proxy) + + c.responseComplete = true + + if err != nil { + if strings.Contains(err.Body, "\"code\":") { + errorResp := processChatErrorResponse(err.Body) + if errorResp != nil && errorResp.Data.Code != 0 { + return errors.New(fmt.Sprintf("coze error: %s (code: %d)", errorResp.Data.Msg, errorResp.Data.Code)) + } + + var genericResp map[string]interface{} + if jsonErr := json.Unmarshal([]byte(err.Body), &genericResp); jsonErr == nil { + errMsg, _ := json.Marshal(genericResp) + return errors.New(fmt.Sprintf("coze error: %s", string(errMsg))) + } + } + + if err.Error != nil { + return err.Error + } + return errors.New(fmt.Sprintf("coze error: unexpected error in stream request")) + } + + return nil +} diff --git a/adapter/coze/processor.go b/adapter/coze/processor.go new file mode 100644 index 0000000..caa0833 --- /dev/null +++ b/adapter/coze/processor.go @@ -0,0 +1,147 @@ +package coze + +import ( + "chat/globals" + "chat/utils" + "errors" + "fmt" + "strconv" + "strings" +) + +func processChatResponse(data string) *ChatResponse { + if form := utils.UnmarshalForm[ChatResponse](data); form != nil { + return form + } + return nil +} + +func processChatStreamResponse(data string) *ChatStreamResponse { + if form := utils.UnmarshalForm[ChatStreamResponse](data); form != nil { + return form + } + return nil +} + +func processChatStreamData(data string) *ChatStreamData { + if form := utils.UnmarshalForm[ChatStreamData](data); form != nil { + return form + } + return nil +} + +func processChatErrorResponse(data string) *ChatStreamErrorResponse { + if form := utils.UnmarshalForm[ChatStreamErrorResponse](data); form != nil { + return form + } + return nil +} + +func processSSEData(data string) (event string, eventData string, err error) { + if data == "" { + return "", "", nil + } + + sseLines := strings.Split(data, "\n") + for _, line := range sseLines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "event:") { + event = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + } else if strings.HasPrefix(line, "data:") { + eventData = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + } + } + + if eventData == "" { + return "", "", nil + } + + if strings.HasPrefix(eventData, "\"") && strings.HasSuffix(eventData, "\"") && len(eventData) > 2 { + unquoted, err := strconv.Unquote(eventData) + if err == nil { + eventData = unquoted + } + } + + return event, eventData, nil +} + +func processEventContent(event string, eventData string) (content string, complete bool, err error) { + switch event { + case "conversation.message.delta": + content, _ := parseEventContent(event, eventData) + if content != "" { + return content, false, nil + } + + streamData := processChatStreamData(eventData) + if streamData != nil && streamData.Type == "answer" && streamData.Role == "assistant" && streamData.Content != "" { + return streamData.Content, false, nil + } + case "conversation.message.completed": + return "", false, nil + case "conversation.chat.completed": + return "", true, nil + case "conversation.chat.failed": + streamData := processChatStreamData(eventData) + if streamData != nil { + if streamData.Code != 0 && streamData.Msg != "" { + return "", false, errors.New(fmt.Sprintf("coze error: %s (code: %d)", streamData.Msg, streamData.Code)) + } + } + return "", false, errors.New("coze error: conversation failed") + case "done": + return "", true, nil + } + + errorResp := processChatErrorResponse(eventData) + if errorResp != nil && errorResp.Data.Code != 0 { + return "", false, errors.New(fmt.Sprintf("coze error: %s (code: %d)", errorResp.Data.Msg, errorResp.Data.Code)) + } + + streamData := processChatStreamData(eventData) + if streamData != nil { + if streamData.Code != 0 && streamData.Msg != "" { + return "", false, errors.New(fmt.Sprintf("coze error: %s (code: %d)", streamData.Msg, streamData.Code)) + } + + if streamData.LastError.Code != 0 && streamData.LastError.Msg != "" { + return "", false, errors.New(fmt.Sprintf("coze error: %s (code: %d)", streamData.LastError.Msg, streamData.LastError.Code)) + } + } + + return "", false, nil +} + +func parseEventContent(eventType string, eventData string) (string, error) { + if eventType == "conversation.message.delta" { + streamResp := processChatStreamResponse(fmt.Sprintf(`{"event":"%s","data":%s}`, eventType, eventData)) + if streamResp != nil { + streamData := processChatStreamData(streamResp.Data) + if streamData != nil && streamData.Type == "answer" && streamData.Role == "assistant" && streamData.Content != "" { + return streamData.Content, nil + } + } + } + return "", nil +} + +func processStreamResponse(data string) (*globals.Chunk, bool, error) { + event, eventData, err := processSSEData(data) + if err != nil { + return nil, false, err + } + + if event == "" || eventData == "" { + return &globals.Chunk{Content: ""}, false, nil + } + + content, complete, err := processEventContent(event, eventData) + if err != nil { + return nil, false, err + } + + return &globals.Chunk{ + Content: content, + }, complete, nil +} diff --git a/adapter/coze/struct.go b/adapter/coze/struct.go new file mode 100644 index 0000000..2ce2738 --- /dev/null +++ b/adapter/coze/struct.go @@ -0,0 +1,98 @@ +package coze + +type ChatRequest struct { + BotID string `json:"bot_id"` + UserID string `json:"user_id"` + AdditionalMessages []EnterMessage `json:"additional_messages,omitempty"` + Stream bool `json:"stream"` + CustomVariables map[string]string `json:"custom_variables,omitempty"` + AutoSaveHistory bool `json:"auto_save_history"` + MetaData map[string]string `json:"meta_data,omitempty"` + ExtraParams map[string]string `json:"extra_params,omitempty"` + ShortcutCommand *ShortcutCommand `json:"shortcut_command,omitempty"` +} + +type EnterMessage struct { + Role string `json:"role"` + Type string `json:"type,omitempty"` + Content string `json:"content,omitempty"` + ContentType string `json:"content_type,omitempty"` + MetaData map[string]string `json:"meta_data,omitempty"` +} + +type ShortcutCommand struct { + // TODO: support for adding this on demand +} + +type ObjectString struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + FileID string `json:"file_id,omitempty"` + FileURL string `json:"file_url,omitempty"` +} + +type ChatResponse struct { + Data struct { + ID string `json:"id"` + ConversationID string `json:"conversation_id"` + BotID string `json:"bot_id"` + CreatedAt int64 `json:"created_at"` + CompletedAt int64 `json:"completed_at"` + LastError interface{} `json:"last_error"` + MetaData map[string]string `json:"meta_data"` + Status string `json:"status"` + Usage *Usage `json:"usage"` + } `json:"data"` + Code int `json:"code"` + Msg string `json:"msg"` +} + +type Usage struct { + TokenCount int `json:"token_count"` + OutputTokens int `json:"output_tokens"` + InputTokens int `json:"input_tokens"` +} + +type ChatStreamResponse struct { + Event string `json:"event"` + Data string `json:"data"` +} + +type ChatStreamData struct { + ID string `json:"id,omitempty"` + Role string `json:"role,omitempty"` + Type string `json:"type,omitempty"` + Content string `json:"content,omitempty"` + ContentType string `json:"content_type,omitempty"` + + ChatID string `json:"chat_id,omitempty"` + ConversationID string `json:"conversation_id,omitempty"` + BotID string `json:"bot_id,omitempty"` + SectionID string `json:"section_id,omitempty"` + + CreatedAt int64 `json:"created_at,omitempty"` + CompletedAt int64 `json:"completed_at,omitempty"` + UpdatedAt int64 `json:"updated_at,omitempty"` + + Status string `json:"status,omitempty"` + LastError struct { + Code int `json:"code"` + Msg string `json:"msg"` + } `json:"last_error,omitempty"` + Code int `json:"code"` + Msg string `json:"msg"` + + Usage *Usage `json:"usage,omitempty"` + + MetaData map[string]string `json:"meta_data,omitempty"` + FromModule interface{} `json:"from_module,omitempty"` + FromUnit interface{} `json:"from_unit,omitempty"` +} + +type ChatStreamErrorResponse struct { + Event string `json:"event"` + Data struct { + Code int `json:"code"` + Msg string `json:"msg"` + } `json:"data"` +} diff --git a/app/src/admin/channel.ts b/app/src/admin/channel.ts index 93c3e82..9ab95d1 100644 --- a/app/src/admin/channel.ts +++ b/app/src/admin/channel.ts @@ -68,6 +68,7 @@ export const ChannelTypes: Record = { slack: "Slack Claude", deepseek: "深度求索 DeepSeek", dify: "Dify", + coze: "扣子 Coze", }; export const ShortChannelTypes: Record = { @@ -89,6 +90,7 @@ export const ShortChannelTypes: Record = { slack: "Slack", deepseek: "深度求索", dify: "Dify", + coze: "Coze", }; export const ChannelInfos: Record = { @@ -297,6 +299,17 @@ export const ChannelInfos: Record = { "> 因此,您需要为每一个 Dify 平台的 CHATFLOW 分别创建渠道 \n" + "> 如果需要让系统自动适配 Dify 平台的图标(商业版 / Pro),请将模型名称填写为 **dify** 开头的模型,如 **dify-chat** \n", }, + coze: { + endpoint: "https://api.coze.cn", + format: "", + models: [""], + description: + "> 扣子 Coze 的模型名称即为 Coze 平台的 **bot_id** \n" + + "> 进入智能体的开发页面,开发页面 URL 中 bot 参数后的数字就是智能体 ID \n" + + "> 例如 [https://www.coze.cn/space/341****/bot/73428668*****](https://www.coze.cn/space/341****/bot/73428668*****),智能体 ID 为 73428668***** \n" + + "> 确保当前使用的访问密钥已被授予智能体所属空间的 chat 权限 \n" + + "> 如果需要让系统自动适配扣子 Coze 平台的图标(商业版 / Pro),请在 **模型映射** 中将 **bot_id** 映射为 **coze** 开头的模型,如 coze-chat>73428668***** \n", + }, }; export const defaultChannelModels: string[] = getUniqueList( diff --git a/app/src/admin/colors.ts b/app/src/admin/colors.ts index 24e0dea..ae78308 100644 --- a/app/src/admin/colors.ts +++ b/app/src/admin/colors.ts @@ -76,7 +76,7 @@ export const modelColorMapper: Record = { // ByteDance Skylark / Doubao / Coze skylark: "sky-300", doubao: "sky-300", - coze: "sky-300", + coze: "indigo-400", // Dify dify: "gray-300", diff --git a/globals/constant.go b/globals/constant.go index fb380ea..38e7915 100644 --- a/globals/constant.go +++ b/globals/constant.go @@ -27,6 +27,7 @@ const ( GroqChannelType = "groq" DeepseekChannelType = "deepseek" DifyChannelType = "dify" + CozeChannelType = "coze" ) const ( diff --git a/utils/scanner.go b/utils/scanner.go index c3d217a..54f6ffa 100644 --- a/utils/scanner.go +++ b/utils/scanner.go @@ -17,6 +17,7 @@ type EventScannerProps struct { Headers map[string]string Body interface{} Callback func(string) error + FullSSE bool } type EventScannerError struct { @@ -85,7 +86,89 @@ func EventScanner(props *EventScannerProps, config ...globals.ProxyConfig) *Even } } - scanner := bufio.NewScanner(resp.Body) + if props.FullSSE { + return processFullSSE(resp.Body, props.Callback) + } + + return processLegacySSE(resp.Body, props.Callback) +} + +func processFullSSE(body io.ReadCloser, callback func(string) error) *EventScannerError { + scanner := bufio.NewScanner(body) + var eventType, eventData string + var buffer strings.Builder + + for scanner.Scan() { + line := scanner.Text() + + if len(strings.TrimSpace(line)) == 0 { + if eventData != "" { + if eventType != "" { + buffer.WriteString("event: ") + buffer.WriteString(eventType) + buffer.WriteString("\n") + } + buffer.WriteString("data: ") + buffer.WriteString(eventData) + + eventStr := buffer.String() + if globals.DebugMode { + globals.Debug(fmt.Sprintf("[sse-full] event: %s", eventStr)) + } + + if err := callback(eventStr); err != nil { + err := body.Close() + if err != nil { + globals.Debug(fmt.Sprintf("[sse] event source close error: %s", err.Error())) + } + return &EventScannerError{Error: err} + } + + eventType = "" + eventData = "" + buffer.Reset() + } + continue + } + + if strings.HasPrefix(line, "event:") { + eventType = strings.TrimSpace(strings.TrimPrefix(line, "event:")) + continue + } + + if strings.HasPrefix(line, "data:") { + eventData = strings.TrimSpace(strings.TrimPrefix(line, "data:")) + + if eventData == "[DONE]" || strings.HasPrefix(eventData, "[DONE]") { + continue + } + } + } + + if eventData != "" { + if eventType != "" { + buffer.WriteString("event: ") + buffer.WriteString(eventType) + buffer.WriteString("\n") + } + buffer.WriteString("data: ") + buffer.WriteString(eventData) + + eventStr := buffer.String() + if globals.DebugMode { + globals.Debug(fmt.Sprintf("[sse-full] last event: %s", eventStr)) + } + + if err := callback(eventStr); err != nil { + return &EventScannerError{Error: err} + } + } + + return nil +} + +func processLegacySSE(body io.ReadCloser, callback func(string) error) *EventScannerError { + scanner := bufio.NewScanner(body) scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { if atEOF && len(data) == 0 { // when EOF and empty data @@ -125,9 +208,9 @@ func EventScanner(props *EventScannerProps, config ...globals.ProxyConfig) *Even } // callback chunk - if err := props.Callback(chunk); err != nil { + if err := callback(chunk); err != nil { // break connection on callback error - err := resp.Body.Close() + err := body.Close() if err != nil { globals.Debug(fmt.Sprintf("[sse] event source close error: %s", err.Error())) }