add zhipuai models: chatglm_pro, chatglm_std, chatglm_lite

This commit is contained in:
Zhang Minghan 2023-10-10 06:53:44 +08:00
parent 0e47b9e4f9
commit fff9fd8b06
12 changed files with 194 additions and 11 deletions

View File

@ -7,6 +7,7 @@ import (
"chat/adapter/palm2"
"chat/adapter/slack"
"chat/adapter/sparkdesk"
"chat/adapter/zhipuai"
"chat/globals"
"chat/utils"
"github.com/spf13/viper"
@ -67,6 +68,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
Model: props.Model,
Message: props.Message,
}, hook)
} else if globals.IsZhiPuModel(props.Model) {
return zhipuai.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhipuai.ChatProps{
Model: props.Model,
Message: props.Message,
}, hook)
}
return nil

View File

@ -19,7 +19,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho
}
defer conn.DeferClose()
model, _ := strings.CutPrefix(props.Model, "bing-")
model := strings.TrimPrefix(props.Model, "bing-")
prompt := props.Message[len(props.Message)-1].Content
if err := conn.SendJSON(&ChatRequest{
Prompt: prompt,

63
adapter/zhipuai/chat.go Normal file
View File

@ -0,0 +1,63 @@
package zhipuai
import (
"chat/globals"
"chat/utils"
"fmt"
"strings"
)
// ChatProps carries the per-request parameters for a zhipuai chat call:
// the public model id selected by the user and the conversation history.
type ChatProps struct {
	Model   string            // public model id, e.g. "zhipu-chatglm-pro"
	Message []globals.Message // full conversation history to send upstream
}
// GetChatEndpoint builds the SSE invoke URL for the given public model id,
// translating it to the upstream model name first.
func (c *ChatInstance) GetChatEndpoint(model string) string {
	base := c.GetEndpoint()
	upstream := c.GetModel(model)
	return fmt.Sprintf("%s/api/paas/v3/model-api/%s/sse-invoke", base, upstream)
}
// GetModel translates a public model id (zhipu-chatglm-*) into the name the
// zhipuai API expects, falling back to chatglm_std for any unknown id.
func (c *ChatInstance) GetModel(model string) string {
	translation := map[string]string{
		globals.ZhiPuChatGLMPro:  ChatGLMPro,
		globals.ZhiPuChatGLMStd:  ChatGLMStd,
		globals.ZhiPuChatGLMLite: ChatGLMLite,
	}
	if upstream, ok := translation[model]; ok {
		return upstream
	}
	return ChatGLMStd
}
// FormatMessages returns a deep copy of messages in which every "system"
// role is downgraded to "user"; the input slice is never mutated.
func (c *ChatInstance) FormatMessages(messages []globals.Message) []globals.Message {
	copied := utils.DeepCopy[[]globals.Message](messages)
	for idx, message := range copied {
		if message.Role == "system" {
			copied[idx].Role = "user"
		}
	}
	return copied
}
// CreateStreamChatRequest opens a server-sent-events stream against the
// zhipuai chat endpoint and forwards each payload line to hook.
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error {
	headers := map[string]string{
		"Content-Type":  "application/json",
		"Accept":        "text/event-stream",
		"Authorization": c.GetToken(),
	}
	body := ChatRequest{
		Prompt: c.FormatMessages(props.Message),
	}
	handler := func(data string) error {
		// only "data:" lines carry payload; every other SSE line is ignored
		payload, ok := strings.CutPrefix(data, "data:")
		if !ok {
			return nil
		}
		return hook(payload)
	}
	return utils.EventSource("POST", c.GetChatEndpoint(props.Model), headers, body, handler)
}

52
adapter/zhipuai/struct.go Normal file
View File

@ -0,0 +1,52 @@
package zhipuai
import (
"chat/utils"
"github.com/dgrijalva/jwt-go"
"github.com/spf13/viper"
"strings"
"time"
)
// ChatInstance is a client for the zhipuai open API.
type ChatInstance struct {
	Endpoint string // API base URL
	ApiKey   string // credential in "{id}.{secret}" form, used to mint JWTs
}
// GetToken mints a short-lived JWT bearer token for the zhipuai API.
// The api key is expected in "{id}.{secret}" form: the id is embedded in the
// claims and the secret signs the token. Returns "" (fail closed) when the
// key is malformed or signing fails.
// NOTE(review): github.com/dgrijalva/jwt-go is archived with known CVEs;
// consider migrating to github.com/golang-jwt/jwt.
func (c *ChatInstance) GetToken() string {
	// zhipu keys are always exactly "{id}.{secret}"
	segment := strings.Split(c.ApiKey, ".")
	if len(segment) != 2 {
		return ""
	}
	id, secret := segment[0], segment[1]

	// exp / timestamp are in milliseconds; the token is valid for 5 minutes
	payload := utils.MapToStruct[jwt.MapClaims](Payload{
		ApiKey:    id,
		Exp:       time.Now().Add(time.Minute*5).Unix() * 1000,
		TimeStamp: time.Now().Unix() * 1000,
	})

	instance := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
	// zhipu requires a non-standard "sign_type" header next to the algorithm
	instance.Header = map[string]interface{}{
		"alg":       "HS256",
		"sign_type": "SIGN",
	}

	token, err := instance.SignedString([]byte(secret))
	if err != nil {
		// was silently discarded before; make the failure explicit and
		// fail closed with an empty token the API will reject
		return ""
	}
	return token
}
// GetEndpoint returns the configured API base URL.
func (c *ChatInstance) GetEndpoint() string {
	return c.Endpoint
}
// NewChatInstance constructs a zhipuai client from an explicit endpoint
// and api key.
func NewChatInstance(endpoint, apikey string) *ChatInstance {
	instance := ChatInstance{
		Endpoint: endpoint,
		ApiKey:   apikey,
	}
	return &instance
}
// NewChatInstanceFromConfig constructs a zhipuai client from the
// "zhipuai.endpoint" and "zhipuai.apikey" viper configuration keys.
func NewChatInstanceFromConfig() *ChatInstance {
	return NewChatInstance(viper.GetString("zhipuai.endpoint"), viper.GetString("zhipuai.apikey"))
}

25
adapter/zhipuai/types.go Normal file
View File

@ -0,0 +1,25 @@
package zhipuai
import "chat/globals"
// Upstream model names as expected by the zhipuai API
// (distinct from the public "zhipu-chatglm-*" ids in globals).
const (
	ChatGLMPro  = "chatglm_pro"
	ChatGLMStd  = "chatglm_std"
	ChatGLMLite = "chatglm_lite"
)

// Payload holds the JWT claims zhipuai authentication expects.
// Exp and TimeStamp are unix times in milliseconds.
type Payload struct {
	ApiKey    string `json:"api_key"`
	Exp       int64  `json:"exp"`
	TimeStamp int64  `json:"timestamp"`
}

// ChatRequest is the JSON body of the sse-invoke chat call.
type ChatRequest struct {
	Prompt []globals.Message `json:"prompt"`
}

// Occurrence mirrors the status envelope returned by the zhipuai API
// (presumably its error/result wrapper — not referenced in this view).
type Occurrence struct {
	Code    int    `json:"code"`
	Msg     string `json:"msg"`
	Success bool   `json:"success"`
}

View File

@ -1,6 +1,6 @@
import axios from "axios";
export const version = "3.3.4";
export const version = "3.4.0";
export const deploy: boolean = true;
export let rest_api: string = "http://localhost:8094";
export let ws_api: string = "ws://localhost:8094";
@ -21,8 +21,9 @@ export const supportModels: string[] = [
"SparkDesk 讯飞星火",
"Palm2",
"New Bing",
// "Claude-2",
// "Claude-2-100k",
"智谱 ChatGLM Pro",
"智谱 ChatGLM Std",
"智谱 ChatGLM Lite",
];
export const supportModelConvertor: Record<string, string> = {
@ -35,6 +36,9 @@ export const supportModelConvertor: Record<string, string> = {
"SparkDesk 讯飞星火": "spark-desk",
Palm2: "chat-bison-001",
"New Bing": "bing-creative",
"智谱 ChatGLM Pro": "zhipu-chatglm-pro",
"智谱 ChatGLM Std": "zhipu-chatglm-std",
"智谱 ChatGLM Lite": "zhipu-chatglm-lite",
};
export function login() {

View File

@ -73,15 +73,18 @@ func ReduceDalle(db *sql.DB, user *User) bool {
}
func CanEnableModel(db *sql.DB, user *User, model string) bool {
auth := user != nil
switch model {
case globals.GPT4, globals.GPT40613, globals.GPT40314:
return user != nil && user.GetQuota(db) >= 5
return auth && user.GetQuota(db) >= 5
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
return user != nil && user.GetQuota(db) >= 50
return auth && user.GetQuota(db) >= 50
case globals.SparkDesk:
return user != nil && user.GetQuota(db) >= 1
return auth && user.GetQuota(db) >= 1
case globals.Claude2100k:
return user != nil && user.GetQuota(db) >= 1
return auth && user.GetQuota(db) >= 1
case globals.ZhiPuChatGLMPro, globals.ZhiPuChatGLMStd:
return auth && user.GetQuota(db) >= 1
default:
return true
}

View File

@ -37,6 +37,9 @@ const (
BingCreative = "bing-creative"
BingBalanced = "bing-balanced"
BingPrecise = "bing-precise"
ZhiPuChatGLMPro = "zhipu-chatglm-pro"
ZhiPuChatGLMStd = "zhipu-chatglm-std"
ZhiPuChatGLMLite = "zhipu-chatglm-lite"
)
var GPT3TurboArray = []string{
@ -74,6 +77,12 @@ var BingModelArray = []string{
BingPrecise,
}
// ZhiPuModelArray lists the public ids of all zhipu chatglm models.
var ZhiPuModelArray = []string{
	ZhiPuChatGLMPro,
	ZhiPuChatGLMStd,
	ZhiPuChatGLMLite,
}
var LongContextModelArray = []string{
GPT3Turbo16k,
GPT3Turbo16k0613,
@ -134,6 +143,10 @@ func IsBingModel(model string) bool {
return in(model, BingModelArray)
}
// IsZhiPuModel reports whether model is one of the zhipu chatglm model ids.
func IsZhiPuModel(model string) bool {
	return in(model, ZhiPuModelArray)
}
// IsLongContextModel reports whether model is one of the long-context model ids.
func IsLongContextModel(model string) bool {
	return in(model, LongContextModelArray)
}

View File

@ -121,7 +121,7 @@ func (c *Conversation) GetMessageSegment(length int) []globals.Message {
}
func CopyMessage(message []globals.Message) []globals.Message {
return utils.UnmarshalJson[[]globals.Message](utils.ToJson(message)) // deep copy
return utils.DeepCopy[[]globals.Message](message) // deep copy
}
func (c *Conversation) GetLastMessage() globals.Message {

View File

@ -52,7 +52,7 @@ func AuthMiddleware() gin.HandlerFunc {
k := strings.TrimSpace(c.GetHeader("Authorization"))
if k != "" {
if strings.HasPrefix(k, "Bearer ") {
k, _ = strings.CutPrefix(k, "Bearer ")
k = strings.TrimPrefix(k, "Bearer ")
}
if strings.HasPrefix(k, "sk-") { // api agent

View File

@ -65,6 +65,10 @@ func UnmarshalJson[T any](value string) T {
}
}
func DeepCopy[T any](value T) T {
return UnmarshalJson[T](ToJson(value))
}
func GetSegment[T any](arr []T, length int) []T {
if length > len(arr) {
return arr

View File

@ -1,6 +1,7 @@
package utils
import (
"chat/adapter/zhipuai"
"chat/globals"
"fmt"
"github.com/pkoukk/tiktoken-go"
@ -52,7 +53,11 @@ func GetWeightByModel(model string) int {
globals.GPT40613,
globals.SparkDesk:
return 3
case globals.GPT3Turbo0301, globals.GPT3Turbo16k0301:
case globals.GPT3Turbo0301,
globals.GPT3Turbo16k0301,
globals.ZhiPuChatGLMLite,
globals.ZhiPuChatGLMStd,
globals.ZhiPuChatGLMPro:
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
default:
if strings.Contains(model, globals.GPT3Turbo) {
@ -110,6 +115,10 @@ func CountInputToken(model string, v []globals.Message) float32 {
return 0
case globals.Claude2100k:
return float32(CountTokenPrice(v, model)) / 1000 * 0.008
case zhipuai.ChatGLMPro:
return float32(CountTokenPrice(v, model)) / 1000 * 0.1
case zhipuai.ChatGLMStd:
return float32(CountTokenPrice(v, model)) / 1000 * 0.05
default:
return 0
}
@ -131,6 +140,10 @@ func CountOutputToken(model string, t int) float32 {
return 0
case globals.Claude2100k:
return float32(t*GetWeightByModel(model)) / 1000 * 0.008
case zhipuai.ChatGLMPro:
return float32(t*GetWeightByModel(model)) / 1000 * 0.1
case zhipuai.ChatGLMStd:
return float32(t*GetWeightByModel(model)) / 1000 * 0.05
default:
return 0
}