mirror of
https://github.com/coaidev/coai.git
synced 2025-05-21 14:00:13 +09:00
add zhipuai models: chatglm_pro, chatglm_std, chatglm_lite
This commit is contained in:
parent
0e47b9e4f9
commit
fff9fd8b06
@ -7,6 +7,7 @@ import (
|
||||
"chat/adapter/palm2"
|
||||
"chat/adapter/slack"
|
||||
"chat/adapter/sparkdesk"
|
||||
"chat/adapter/zhipuai"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"github.com/spf13/viper"
|
||||
@ -67,6 +68,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
|
||||
Model: props.Model,
|
||||
Message: props.Message,
|
||||
}, hook)
|
||||
} else if globals.IsZhiPuModel(props.Model) {
|
||||
return zhipuai.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhipuai.ChatProps{
|
||||
Model: props.Model,
|
||||
Message: props.Message,
|
||||
}, hook)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -19,7 +19,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho
|
||||
}
|
||||
defer conn.DeferClose()
|
||||
|
||||
model, _ := strings.CutPrefix(props.Model, "bing-")
|
||||
model := strings.TrimPrefix(props.Model, "bing-")
|
||||
prompt := props.Message[len(props.Message)-1].Content
|
||||
if err := conn.SendJSON(&ChatRequest{
|
||||
Prompt: prompt,
|
||||
|
63
adapter/zhipuai/chat.go
Normal file
63
adapter/zhipuai/chat.go
Normal file
@ -0,0 +1,63 @@
|
||||
package zhipuai
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ChatProps carries the parameters for one zhipuai streaming chat call.
type ChatProps struct {
	Model   string            // internal model id (e.g. globals.ZhiPuChatGLMPro), resolved via GetModel
	Message []globals.Message // conversation history; sent as the request prompt
}
|
||||
|
||||
func (c *ChatInstance) GetChatEndpoint(model string) string {
|
||||
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/sse-invoke", c.GetEndpoint(), c.GetModel(model))
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetModel(model string) string {
|
||||
switch model {
|
||||
case globals.ZhiPuChatGLMPro:
|
||||
return ChatGLMPro
|
||||
case globals.ZhiPuChatGLMStd:
|
||||
return ChatGLMStd
|
||||
case globals.ZhiPuChatGLMLite:
|
||||
return ChatGLMLite
|
||||
default:
|
||||
return ChatGLMStd
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) FormatMessages(messages []globals.Message) []globals.Message {
|
||||
messages = utils.DeepCopy[[]globals.Message](messages)
|
||||
for i := range messages {
|
||||
if messages[i].Role == "system" {
|
||||
messages[i].Role = "user"
|
||||
}
|
||||
}
|
||||
return messages
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error {
|
||||
return utils.EventSource(
|
||||
"POST",
|
||||
c.GetChatEndpoint(props.Model),
|
||||
map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "text/event-stream",
|
||||
"Authorization": c.GetToken(),
|
||||
},
|
||||
ChatRequest{
|
||||
Prompt: c.FormatMessages(props.Message),
|
||||
},
|
||||
func(data string) error {
|
||||
if !strings.HasPrefix(data, "data:") {
|
||||
return nil
|
||||
}
|
||||
|
||||
data = strings.TrimPrefix(data, "data:")
|
||||
return hook(data)
|
||||
},
|
||||
)
|
||||
}
|
52
adapter/zhipuai/struct.go
Normal file
52
adapter/zhipuai/struct.go
Normal file
@ -0,0 +1,52 @@
|
||||
package zhipuai
|
||||
|
||||
import (
|
||||
"chat/utils"
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"github.com/spf13/viper"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ChatInstance is a zhipuai API client bound to one endpoint/key pair.
type ChatInstance struct {
	Endpoint string // API base URL
	ApiKey   string // key in "id.secret" form; split apart in GetToken
}
|
||||
|
||||
func (c *ChatInstance) GetToken() string {
|
||||
// get jwt token for zhipuai api
|
||||
segment := strings.Split(c.ApiKey, ".")
|
||||
if len(segment) != 2 {
|
||||
return ""
|
||||
}
|
||||
id, secret := segment[0], segment[1]
|
||||
|
||||
payload := utils.MapToStruct[jwt.MapClaims](Payload{
|
||||
ApiKey: id,
|
||||
Exp: time.Now().Add(time.Minute*5).Unix() * 1000,
|
||||
TimeStamp: time.Now().Unix() * 1000,
|
||||
})
|
||||
|
||||
instance := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||
instance.Header = map[string]interface{}{
|
||||
"alg": "HS256",
|
||||
"sign_type": "SIGN",
|
||||
}
|
||||
token, _ := instance.SignedString([]byte(secret))
|
||||
return token
|
||||
}
|
||||
|
||||
// GetEndpoint returns the configured zhipuai API base URL.
func (c *ChatInstance) GetEndpoint() string {
	return c.Endpoint
}
|
||||
|
||||
func NewChatInstance(endpoint, apikey string) *ChatInstance {
|
||||
return &ChatInstance{
|
||||
Endpoint: endpoint,
|
||||
ApiKey: apikey,
|
||||
}
|
||||
}
|
||||
|
||||
func NewChatInstanceFromConfig() *ChatInstance {
|
||||
return NewChatInstance(viper.GetString("zhipuai.endpoint"), viper.GetString("zhipuai.apikey"))
|
||||
}
|
25
adapter/zhipuai/types.go
Normal file
25
adapter/zhipuai/types.go
Normal file
@ -0,0 +1,25 @@
|
||||
package zhipuai
|
||||
|
||||
import "chat/globals"
|
||||
|
||||
// Upstream zhipuai model identifiers, interpolated into the
// model-api endpoint path by GetChatEndpoint.
const (
	ChatGLMPro  = "chatglm_pro"
	ChatGLMStd  = "chatglm_std"
	ChatGLMLite = "chatglm_lite"
)
|
||||
|
||||
// Payload is the JWT claim set required by the zhipuai API
// (converted to jwt.MapClaims in GetToken).
type Payload struct {
	ApiKey    string `json:"api_key"`   // key id: the part of the api key before the "."
	Exp       int64  `json:"exp"`       // expiry, unix time in milliseconds
	TimeStamp int64  `json:"timestamp"` // issue time, unix time in milliseconds
}
|
||||
|
||||
// ChatRequest is the JSON body posted to the sse-invoke endpoint.
type ChatRequest struct {
	Prompt []globals.Message `json:"prompt"` // full formatted message history
}
|
||||
|
||||
// Occurrence models a zhipuai status/error response.
// NOTE(review): not referenced anywhere in this chunk — presumably decoded
// from non-streaming API replies; confirm against callers before relying on it.
type Occurrence struct {
	Code    int    `json:"code"`
	Msg     string `json:"msg"`
	Success bool   `json:"success"`
}
|
@ -1,6 +1,6 @@
|
||||
import axios from "axios";
|
||||
|
||||
export const version = "3.3.4";
|
||||
export const version = "3.4.0";
|
||||
export const deploy: boolean = true;
|
||||
export let rest_api: string = "http://localhost:8094";
|
||||
export let ws_api: string = "ws://localhost:8094";
|
||||
@ -21,8 +21,9 @@ export const supportModels: string[] = [
|
||||
"SparkDesk 讯飞星火",
|
||||
"Palm2",
|
||||
"New Bing",
|
||||
// "Claude-2",
|
||||
// "Claude-2-100k",
|
||||
"智谱 ChatGLM Pro",
|
||||
"智谱 ChatGLM Std",
|
||||
"智谱 ChatGLM Lite",
|
||||
];
|
||||
|
||||
export const supportModelConvertor: Record<string, string> = {
|
||||
@ -35,6 +36,9 @@ export const supportModelConvertor: Record<string, string> = {
|
||||
"SparkDesk 讯飞星火": "spark-desk",
|
||||
Palm2: "chat-bison-001",
|
||||
"New Bing": "bing-creative",
|
||||
"智谱 ChatGLM Pro": "zhipu-chatglm-pro",
|
||||
"智谱 ChatGLM Std": "zhipu-chatglm-std",
|
||||
"智谱 ChatGLM Lite": "zhipu-chatglm-lite",
|
||||
};
|
||||
|
||||
export function login() {
|
||||
|
@ -73,15 +73,18 @@ func ReduceDalle(db *sql.DB, user *User) bool {
|
||||
}
|
||||
|
||||
func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
||||
auth := user != nil
|
||||
switch model {
|
||||
case globals.GPT4, globals.GPT40613, globals.GPT40314:
|
||||
return user != nil && user.GetQuota(db) >= 5
|
||||
return auth && user.GetQuota(db) >= 5
|
||||
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
||||
return user != nil && user.GetQuota(db) >= 50
|
||||
return auth && user.GetQuota(db) >= 50
|
||||
case globals.SparkDesk:
|
||||
return user != nil && user.GetQuota(db) >= 1
|
||||
return auth && user.GetQuota(db) >= 1
|
||||
case globals.Claude2100k:
|
||||
return user != nil && user.GetQuota(db) >= 1
|
||||
return auth && user.GetQuota(db) >= 1
|
||||
case globals.ZhiPuChatGLMPro, globals.ZhiPuChatGLMStd:
|
||||
return auth && user.GetQuota(db) >= 1
|
||||
default:
|
||||
return true
|
||||
}
|
||||
|
@ -37,6 +37,9 @@ const (
|
||||
BingCreative = "bing-creative"
|
||||
BingBalanced = "bing-balanced"
|
||||
BingPrecise = "bing-precise"
|
||||
ZhiPuChatGLMPro = "zhipu-chatglm-pro"
|
||||
ZhiPuChatGLMStd = "zhipu-chatglm-std"
|
||||
ZhiPuChatGLMLite = "zhipu-chatglm-lite"
|
||||
)
|
||||
|
||||
var GPT3TurboArray = []string{
|
||||
@ -74,6 +77,12 @@ var BingModelArray = []string{
|
||||
BingPrecise,
|
||||
}
|
||||
|
||||
var ZhiPuModelArray = []string{
|
||||
ZhiPuChatGLMPro,
|
||||
ZhiPuChatGLMStd,
|
||||
ZhiPuChatGLMLite,
|
||||
}
|
||||
|
||||
var LongContextModelArray = []string{
|
||||
GPT3Turbo16k,
|
||||
GPT3Turbo16k0613,
|
||||
@ -134,6 +143,10 @@ func IsBingModel(model string) bool {
|
||||
return in(model, BingModelArray)
|
||||
}
|
||||
|
||||
func IsZhiPuModel(model string) bool {
|
||||
return in(model, ZhiPuModelArray)
|
||||
}
|
||||
|
||||
func IsLongContextModel(model string) bool {
|
||||
return in(model, LongContextModelArray)
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ func (c *Conversation) GetMessageSegment(length int) []globals.Message {
|
||||
}
|
||||
|
||||
func CopyMessage(message []globals.Message) []globals.Message {
|
||||
return utils.UnmarshalJson[[]globals.Message](utils.ToJson(message)) // deep copy
|
||||
return utils.DeepCopy[[]globals.Message](message) // deep copy
|
||||
}
|
||||
|
||||
func (c *Conversation) GetLastMessage() globals.Message {
|
||||
|
@ -52,7 +52,7 @@ func AuthMiddleware() gin.HandlerFunc {
|
||||
k := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||
if k != "" {
|
||||
if strings.HasPrefix(k, "Bearer ") {
|
||||
k, _ = strings.CutPrefix(k, "Bearer ")
|
||||
k = strings.TrimPrefix(k, "Bearer ")
|
||||
}
|
||||
|
||||
if strings.HasPrefix(k, "sk-") { // api agent
|
||||
|
@ -65,6 +65,10 @@ func UnmarshalJson[T any](value string) T {
|
||||
}
|
||||
}
|
||||
|
||||
func DeepCopy[T any](value T) T {
|
||||
return UnmarshalJson[T](ToJson(value))
|
||||
}
|
||||
|
||||
func GetSegment[T any](arr []T, length int) []T {
|
||||
if length > len(arr) {
|
||||
return arr
|
||||
|
@ -1,6 +1,7 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"chat/adapter/zhipuai"
|
||||
"chat/globals"
|
||||
"fmt"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
@ -52,7 +53,11 @@ func GetWeightByModel(model string) int {
|
||||
globals.GPT40613,
|
||||
globals.SparkDesk:
|
||||
return 3
|
||||
case globals.GPT3Turbo0301, globals.GPT3Turbo16k0301:
|
||||
case globals.GPT3Turbo0301,
|
||||
globals.GPT3Turbo16k0301,
|
||||
globals.ZhiPuChatGLMLite,
|
||||
globals.ZhiPuChatGLMStd,
|
||||
globals.ZhiPuChatGLMPro:
|
||||
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
default:
|
||||
if strings.Contains(model, globals.GPT3Turbo) {
|
||||
@ -110,6 +115,10 @@ func CountInputToken(model string, v []globals.Message) float32 {
|
||||
return 0
|
||||
case globals.Claude2100k:
|
||||
return float32(CountTokenPrice(v, model)) / 1000 * 0.008
|
||||
case zhipuai.ChatGLMPro:
|
||||
return float32(CountTokenPrice(v, model)) / 1000 * 0.1
|
||||
case zhipuai.ChatGLMStd:
|
||||
return float32(CountTokenPrice(v, model)) / 1000 * 0.05
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@ -131,6 +140,10 @@ func CountOutputToken(model string, t int) float32 {
|
||||
return 0
|
||||
case globals.Claude2100k:
|
||||
return float32(t*GetWeightByModel(model)) / 1000 * 0.008
|
||||
case zhipuai.ChatGLMPro:
|
||||
return float32(t*GetWeightByModel(model)) / 1000 * 0.1
|
||||
case zhipuai.ChatGLMStd:
|
||||
return float32(t*GetWeightByModel(model)) / 1000 * 0.05
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user