From fff9fd8b06cac03c55ebefb03524bda293292d5d Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Tue, 10 Oct 2023 06:53:44 +0800 Subject: [PATCH] add zhipuai models: chatglm_pro, chatglm_std, chatglm_lite --- adapter/adapter.go | 6 +++ adapter/bing/chat.go | 2 +- adapter/zhipuai/chat.go | 63 ++++++++++++++++++++++++++++ adapter/zhipuai/struct.go | 52 +++++++++++++++++++++++ adapter/zhipuai/types.go | 25 +++++++++++ app/src/conf.ts | 10 +++-- auth/payment.go | 11 +++-- globals/variables.go | 13 ++++++ manager/conversation/conversation.go | 2 +- middleware/auth.go | 2 +- utils/base.go | 4 ++ utils/tokenizer.go | 15 ++++++- 12 files changed, 194 insertions(+), 11 deletions(-) create mode 100644 adapter/zhipuai/chat.go create mode 100644 adapter/zhipuai/struct.go create mode 100644 adapter/zhipuai/types.go diff --git a/adapter/adapter.go b/adapter/adapter.go index af48fb0..04169cc 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -7,6 +7,7 @@ import ( "chat/adapter/palm2" "chat/adapter/slack" "chat/adapter/sparkdesk" + "chat/adapter/zhipuai" "chat/globals" "chat/utils" "github.com/spf13/viper" @@ -67,6 +68,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error { Model: props.Model, Message: props.Message, }, hook) + } else if globals.IsZhiPuModel(props.Model) { + return zhipuai.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhipuai.ChatProps{ + Model: props.Model, + Message: props.Message, + }, hook) } return nil diff --git a/adapter/bing/chat.go b/adapter/bing/chat.go index a8bf58e..3ec85a1 100644 --- a/adapter/bing/chat.go +++ b/adapter/bing/chat.go @@ -19,7 +19,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho } defer conn.DeferClose() - model, _ := strings.CutPrefix(props.Model, "bing-") + model := strings.TrimPrefix(props.Model, "bing-") prompt := props.Message[len(props.Message)-1].Content if err := conn.SendJSON(&ChatRequest{ Prompt: prompt, diff --git a/adapter/zhipuai/chat.go b/adapter/zhipuai/chat.go new file mode 100644 index 0000000..9edbe42 --- /dev/null +++ b/adapter/zhipuai/chat.go @@ -0,0 +1,63 @@ +package zhipuai + +import ( + "chat/globals" + "chat/utils" + "fmt" + "strings" +) + +type ChatProps struct { + Model string + Message []globals.Message +} + +func (c *ChatInstance) GetChatEndpoint(model string) string { + return fmt.Sprintf("%s/api/paas/v3/model-api/%s/sse-invoke", c.GetEndpoint(), c.GetModel(model)) +} + +func (c *ChatInstance) GetModel(model string) string { + switch model { + case globals.ZhiPuChatGLMPro: + return ChatGLMPro + case globals.ZhiPuChatGLMStd: + return ChatGLMStd + case globals.ZhiPuChatGLMLite: + return ChatGLMLite + default: + return ChatGLMStd + } +} + +func (c *ChatInstance) FormatMessages(messages []globals.Message) []globals.Message { + messages = utils.DeepCopy[[]globals.Message](messages) + for i := range messages { + if messages[i].Role == "system" { + messages[i].Role = "user" + } + } + return messages +} + +func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error { + return utils.EventSource( + "POST", + c.GetChatEndpoint(props.Model), + map[string]string{ + "Content-Type": "application/json", + "Accept": "text/event-stream", + "Authorization": c.GetToken(), + }, + ChatRequest{ + Prompt: c.FormatMessages(props.Message), + }, + func(data string) error { + if !strings.HasPrefix(data, "data:") { + return nil + } + + data = strings.TrimPrefix(data, "data:") + return hook(data) + }, + ) +} diff --git a/adapter/zhipuai/struct.go b/adapter/zhipuai/struct.go new file mode 100644 index 0000000..ac4250f --- /dev/null +++ b/adapter/zhipuai/struct.go @@ -0,0 +1,52 @@ +package zhipuai + +import ( + "chat/utils" + "github.com/dgrijalva/jwt-go" + "github.com/spf13/viper" + "strings" + "time" +) + +type ChatInstance struct { + Endpoint string + ApiKey string +} + +func (c *ChatInstance) GetToken() string { + // get jwt token for zhipuai api + segment := strings.Split(c.ApiKey, ".") + if len(segment) != 2 { + return "" + } + id, secret := segment[0], segment[1] + + payload := utils.MapToStruct[jwt.MapClaims](Payload{ + ApiKey: id, + Exp: time.Now().Add(time.Minute*5).Unix() * 1000, + TimeStamp: time.Now().Unix() * 1000, + }) + + instance := jwt.NewWithClaims(jwt.SigningMethodHS256, payload) + instance.Header = map[string]interface{}{ + "alg": "HS256", + "sign_type": "SIGN", + } + token, _ := instance.SignedString([]byte(secret)) + return token +} + +func (c *ChatInstance) GetEndpoint() string { + return c.Endpoint +} + +func NewChatInstance(endpoint, apikey string) *ChatInstance { + return &ChatInstance{ + Endpoint: endpoint, + ApiKey: apikey, + } +} + +func NewChatInstanceFromConfig() *ChatInstance { + return NewChatInstance(viper.GetString("zhipuai.endpoint"), viper.GetString("zhipuai.apikey")) +} diff --git a/adapter/zhipuai/types.go b/adapter/zhipuai/types.go new file mode 100644 index 0000000..2e52def --- /dev/null +++ b/adapter/zhipuai/types.go @@ -0,0 +1,25 @@ +package zhipuai + +import "chat/globals" + +const ( + ChatGLMPro = "chatglm_pro" + ChatGLMStd = "chatglm_std" + ChatGLMLite = "chatglm_lite" +) + +type Payload struct { + ApiKey string `json:"api_key"` + Exp int64 `json:"exp"` + TimeStamp int64 `json:"timestamp"` +} + +type ChatRequest struct { + Prompt []globals.Message `json:"prompt"` +} + +type Occurrence struct { + Code int `json:"code"` + Msg string `json:"msg"` + Success bool `json:"success"` +} diff --git a/app/src/conf.ts b/app/src/conf.ts index 4355df0..9145ed4 100644 --- a/app/src/conf.ts +++ b/app/src/conf.ts @@ -1,6 +1,6 @@ import axios from "axios"; -export const version = "3.3.4"; +export const version = "3.4.0"; export const deploy: boolean = true; export let rest_api: string = "http://localhost:8094"; export let ws_api: string = "ws://localhost:8094"; @@ -21,8 +21,9 @@ export const supportModels: string[] = [ "SparkDesk 讯飞星火", "Palm2", "New Bing", - // "Claude-2", - // "Claude-2-100k", + "智谱 ChatGLM Pro", + "智谱 ChatGLM Std", + "智谱 ChatGLM Lite", ]; export const supportModelConvertor: Record = { @@ -35,6 +36,9 @@ export const supportModelConvertor: Record = { "SparkDesk 讯飞星火": "spark-desk", Palm2: "chat-bison-001", "New Bing": "bing-creative", + "智谱 ChatGLM Pro": "zhipu-chatglm-pro", + "智谱 ChatGLM Std": "zhipu-chatglm-std", + "智谱 ChatGLM Lite": "zhipu-chatglm-lite", }; export function login() { diff --git a/auth/payment.go b/auth/payment.go index 96cfba9..d5ca693 100644 --- a/auth/payment.go +++ b/auth/payment.go @@ -73,15 +73,18 @@ func ReduceDalle(db *sql.DB, user *User) bool { } func CanEnableModel(db *sql.DB, user *User, model string) bool { + auth := user != nil switch model { case globals.GPT4, globals.GPT40613, globals.GPT40314: - return user != nil && user.GetQuota(db) >= 5 + return auth && user.GetQuota(db) >= 5 case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314: - return user != nil && user.GetQuota(db) >= 50 + return auth && user.GetQuota(db) >= 50 case globals.SparkDesk: - return user != nil && user.GetQuota(db) >= 1 + return auth && user.GetQuota(db) >= 1 case globals.Claude2100k: - return user != nil && user.GetQuota(db) >= 1 + return auth && user.GetQuota(db) >= 1 + case globals.ZhiPuChatGLMPro, globals.ZhiPuChatGLMStd: + return auth && user.GetQuota(db) >= 1 default: return true } diff --git a/globals/variables.go b/globals/variables.go index eab0789..905abf9 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -37,6 +37,9 @@ const ( BingCreative = "bing-creative" BingBalanced = "bing-balanced" BingPrecise = "bing-precise" + ZhiPuChatGLMPro = "zhipu-chatglm-pro" + ZhiPuChatGLMStd = "zhipu-chatglm-std" + ZhiPuChatGLMLite = "zhipu-chatglm-lite" ) var GPT3TurboArray = []string{ @@ -74,6 +77,12 @@ var BingModelArray = []string{ BingPrecise, } +var ZhiPuModelArray = []string{ + ZhiPuChatGLMPro, + ZhiPuChatGLMStd, + ZhiPuChatGLMLite, +} + var LongContextModelArray = []string{ GPT3Turbo16k, GPT3Turbo16k0613, @@ -134,6 +143,10 @@ func IsBingModel(model string) bool { return in(model, BingModelArray) } +func IsZhiPuModel(model string) bool { + return in(model, ZhiPuModelArray) +} + func IsLongContextModel(model string) bool { return in(model, LongContextModelArray) } diff --git a/manager/conversation/conversation.go b/manager/conversation/conversation.go index 3108507..0dabc7f 100644 --- a/manager/conversation/conversation.go +++ b/manager/conversation/conversation.go @@ -121,7 +121,7 @@ func (c *Conversation) GetMessageSegment(length int) []globals.Message { } func CopyMessage(message []globals.Message) []globals.Message { - return utils.UnmarshalJson[[]globals.Message](utils.ToJson(message)) // deep copy + return utils.DeepCopy[[]globals.Message](message) // deep copy } func (c *Conversation) GetLastMessage() globals.Message { diff --git a/middleware/auth.go b/middleware/auth.go index 0d31a8e..0677152 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -52,7 +52,7 @@ func AuthMiddleware() gin.HandlerFunc { k := strings.TrimSpace(c.GetHeader("Authorization")) if k != "" { if strings.HasPrefix(k, "Bearer ") { - k, _ = strings.CutPrefix(k, "Bearer ") + k = strings.TrimPrefix(k, "Bearer ") } if strings.HasPrefix(k, "sk-") { // api agent diff --git a/utils/base.go b/utils/base.go index aadb97a..fc27abb 100644 --- a/utils/base.go +++ b/utils/base.go @@ -65,6 +65,10 @@ func UnmarshalJson[T any](value string) T { } } +func DeepCopy[T any](value T) T { + return UnmarshalJson[T](ToJson(value)) +} + func GetSegment[T any](arr []T, length int) []T { if length > len(arr) { return arr diff --git a/utils/tokenizer.go b/utils/tokenizer.go index cdda6a6..99dbffd 100644 --- a/utils/tokenizer.go +++ b/utils/tokenizer.go @@ -1,6 +1,7 @@ package utils import ( + "chat/adapter/zhipuai" "chat/globals" "fmt" "github.com/pkoukk/tiktoken-go" @@ -52,7 +53,11 @@ func GetWeightByModel(model string) int { globals.GPT40613, globals.SparkDesk: return 3 - case globals.GPT3Turbo0301, globals.GPT3Turbo16k0301: + case globals.GPT3Turbo0301, + globals.GPT3Turbo16k0301, + globals.ZhiPuChatGLMLite, + globals.ZhiPuChatGLMStd, + globals.ZhiPuChatGLMPro: return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n default: if strings.Contains(model, globals.GPT3Turbo) { @@ -110,6 +115,10 @@ func CountInputToken(model string, v []globals.Message) float32 { return 0 case globals.Claude2100k: return float32(CountTokenPrice(v, model)) / 1000 * 0.008 + case zhipuai.ChatGLMPro: + return float32(CountTokenPrice(v, model)) / 1000 * 0.1 + case zhipuai.ChatGLMStd: + return float32(CountTokenPrice(v, model)) / 1000 * 0.05 default: return 0 } @@ -131,6 +140,10 @@ func CountOutputToken(model string, t int) float32 { return 0 case globals.Claude2100k: return float32(t*GetWeightByModel(model)) / 1000 * 0.008 + case zhipuai.ChatGLMPro: + return float32(t*GetWeightByModel(model)) / 1000 * 0.1 + case zhipuai.ChatGLMStd: + return float32(t*GetWeightByModel(model)) / 1000 * 0.05 default: return 0 }