mirror of
https://github.com/coaidev/coai.git
synced 2025-05-21 22:10:12 +09:00
add zhipuai models: chatglm_pro, chatglm_std, chatglm_lite
This commit is contained in:
parent
0e47b9e4f9
commit
fff9fd8b06
@ -7,6 +7,7 @@ import (
|
|||||||
"chat/adapter/palm2"
|
"chat/adapter/palm2"
|
||||||
"chat/adapter/slack"
|
"chat/adapter/slack"
|
||||||
"chat/adapter/sparkdesk"
|
"chat/adapter/sparkdesk"
|
||||||
|
"chat/adapter/zhipuai"
|
||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
@ -67,6 +68,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
|
|||||||
Model: props.Model,
|
Model: props.Model,
|
||||||
Message: props.Message,
|
Message: props.Message,
|
||||||
}, hook)
|
}, hook)
|
||||||
|
} else if globals.IsZhiPuModel(props.Model) {
|
||||||
|
return zhipuai.NewChatInstanceFromConfig().CreateStreamChatRequest(&zhipuai.ChatProps{
|
||||||
|
Model: props.Model,
|
||||||
|
Message: props.Message,
|
||||||
|
}, hook)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -19,7 +19,7 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Ho
|
|||||||
}
|
}
|
||||||
defer conn.DeferClose()
|
defer conn.DeferClose()
|
||||||
|
|
||||||
model, _ := strings.CutPrefix(props.Model, "bing-")
|
model := strings.TrimPrefix(props.Model, "bing-")
|
||||||
prompt := props.Message[len(props.Message)-1].Content
|
prompt := props.Message[len(props.Message)-1].Content
|
||||||
if err := conn.SendJSON(&ChatRequest{
|
if err := conn.SendJSON(&ChatRequest{
|
||||||
Prompt: prompt,
|
Prompt: prompt,
|
||||||
|
63
adapter/zhipuai/chat.go
Normal file
63
adapter/zhipuai/chat.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package zhipuai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/globals"
|
||||||
|
"chat/utils"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatProps struct {
|
||||||
|
Model string
|
||||||
|
Message []globals.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetChatEndpoint(model string) string {
|
||||||
|
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/sse-invoke", c.GetEndpoint(), c.GetModel(model))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetModel(model string) string {
|
||||||
|
switch model {
|
||||||
|
case globals.ZhiPuChatGLMPro:
|
||||||
|
return ChatGLMPro
|
||||||
|
case globals.ZhiPuChatGLMStd:
|
||||||
|
return ChatGLMStd
|
||||||
|
case globals.ZhiPuChatGLMLite:
|
||||||
|
return ChatGLMLite
|
||||||
|
default:
|
||||||
|
return ChatGLMStd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) FormatMessages(messages []globals.Message) []globals.Message {
|
||||||
|
messages = utils.DeepCopy[[]globals.Message](messages)
|
||||||
|
for i := range messages {
|
||||||
|
if messages[i].Role == "system" {
|
||||||
|
messages[i].Role = "user"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return messages
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, hook globals.Hook) error {
|
||||||
|
return utils.EventSource(
|
||||||
|
"POST",
|
||||||
|
c.GetChatEndpoint(props.Model),
|
||||||
|
map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "text/event-stream",
|
||||||
|
"Authorization": c.GetToken(),
|
||||||
|
},
|
||||||
|
ChatRequest{
|
||||||
|
Prompt: c.FormatMessages(props.Message),
|
||||||
|
},
|
||||||
|
func(data string) error {
|
||||||
|
if !strings.HasPrefix(data, "data:") {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
data = strings.TrimPrefix(data, "data:")
|
||||||
|
return hook(data)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
52
adapter/zhipuai/struct.go
Normal file
52
adapter/zhipuai/struct.go
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
package zhipuai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/utils"
|
||||||
|
"github.com/dgrijalva/jwt-go"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ChatInstance holds the connection settings for the ZhiPu AI adapter.
type ChatInstance struct {
	// Endpoint is the base URL that request paths are appended to.
	Endpoint string

	// ApiKey is the credential in "<id>.<secret>" form; GetToken splits it on
	// the dot to sign request tokens.
	ApiKey string
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetToken() string {
|
||||||
|
// get jwt token for zhipuai api
|
||||||
|
segment := strings.Split(c.ApiKey, ".")
|
||||||
|
if len(segment) != 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
id, secret := segment[0], segment[1]
|
||||||
|
|
||||||
|
payload := utils.MapToStruct[jwt.MapClaims](Payload{
|
||||||
|
ApiKey: id,
|
||||||
|
Exp: time.Now().Add(time.Minute*5).Unix() * 1000,
|
||||||
|
TimeStamp: time.Now().Unix() * 1000,
|
||||||
|
})
|
||||||
|
|
||||||
|
instance := jwt.NewWithClaims(jwt.SigningMethodHS256, payload)
|
||||||
|
instance.Header = map[string]interface{}{
|
||||||
|
"alg": "HS256",
|
||||||
|
"sign_type": "SIGN",
|
||||||
|
}
|
||||||
|
token, _ := instance.SignedString([]byte(secret))
|
||||||
|
return token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetEndpoint() string {
|
||||||
|
return c.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatInstance(endpoint, apikey string) *ChatInstance {
|
||||||
|
return &ChatInstance{
|
||||||
|
Endpoint: endpoint,
|
||||||
|
ApiKey: apikey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatInstanceFromConfig() *ChatInstance {
|
||||||
|
return NewChatInstance(viper.GetString("zhipuai.endpoint"), viper.GetString("zhipuai.apikey"))
|
||||||
|
}
|
25
adapter/zhipuai/types.go
Normal file
25
adapter/zhipuai/types.go
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
package zhipuai
|
||||||
|
|
||||||
|
import "chat/globals"
|
||||||
|
|
||||||
|
// Upstream model names as they appear in the ZhiPu AI request URL.
const (
	// ChatGLMPro is the upstream name of the "pro" model.
	ChatGLMPro = "chatglm_pro"

	// ChatGLMStd is the upstream name of the "std" model.
	ChatGLMStd = "chatglm_std"

	// ChatGLMLite is the upstream name of the "lite" model.
	ChatGLMLite = "chatglm_lite"
)
|
||||||
|
|
||||||
|
// Payload is the claim set of the JWT used to authenticate requests.
type Payload struct {
	// ApiKey is the id half of the configured API key.
	ApiKey string `json:"api_key"`

	// Exp is the expiry time; the token builder fills it with a Unix time
	// multiplied by 1000.
	Exp int64 `json:"exp"`

	// TimeStamp is the issue time; the token builder fills it with a Unix
	// time multiplied by 1000.
	TimeStamp int64 `json:"timestamp"`
}
|
||||||
|
|
||||||
|
type ChatRequest struct {
|
||||||
|
Prompt []globals.Message `json:"prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Occurrence describes a status object with "code", "msg" and "success"
// fields.
//
// NOTE(review): presumably the error/status envelope returned by the upstream
// API; nothing in the visible code decodes into it yet - confirm.
type Occurrence struct {
	// Code is the numeric status code.
	Code int `json:"code"`

	// Msg is the accompanying message text.
	Msg string `json:"msg"`

	// Success is the success flag.
	Success bool `json:"success"`
}
|
@ -1,6 +1,6 @@
|
|||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
|
|
||||||
export const version = "3.3.4";
|
export const version = "3.4.0";
|
||||||
export const deploy: boolean = true;
|
export const deploy: boolean = true;
|
||||||
export let rest_api: string = "http://localhost:8094";
|
export let rest_api: string = "http://localhost:8094";
|
||||||
export let ws_api: string = "ws://localhost:8094";
|
export let ws_api: string = "ws://localhost:8094";
|
||||||
@ -21,8 +21,9 @@ export const supportModels: string[] = [
|
|||||||
"SparkDesk 讯飞星火",
|
"SparkDesk 讯飞星火",
|
||||||
"Palm2",
|
"Palm2",
|
||||||
"New Bing",
|
"New Bing",
|
||||||
// "Claude-2",
|
"智谱 ChatGLM Pro",
|
||||||
// "Claude-2-100k",
|
"智谱 ChatGLM Std",
|
||||||
|
"智谱 ChatGLM Lite",
|
||||||
];
|
];
|
||||||
|
|
||||||
export const supportModelConvertor: Record<string, string> = {
|
export const supportModelConvertor: Record<string, string> = {
|
||||||
@ -35,6 +36,9 @@ export const supportModelConvertor: Record<string, string> = {
|
|||||||
"SparkDesk 讯飞星火": "spark-desk",
|
"SparkDesk 讯飞星火": "spark-desk",
|
||||||
Palm2: "chat-bison-001",
|
Palm2: "chat-bison-001",
|
||||||
"New Bing": "bing-creative",
|
"New Bing": "bing-creative",
|
||||||
|
"智谱 ChatGLM Pro": "zhipu-chatglm-pro",
|
||||||
|
"智谱 ChatGLM Std": "zhipu-chatglm-std",
|
||||||
|
"智谱 ChatGLM Lite": "zhipu-chatglm-lite",
|
||||||
};
|
};
|
||||||
|
|
||||||
export function login() {
|
export function login() {
|
||||||
|
@ -73,15 +73,18 @@ func ReduceDalle(db *sql.DB, user *User) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
||||||
|
auth := user != nil
|
||||||
switch model {
|
switch model {
|
||||||
case globals.GPT4, globals.GPT40613, globals.GPT40314:
|
case globals.GPT4, globals.GPT40613, globals.GPT40314:
|
||||||
return user != nil && user.GetQuota(db) >= 5
|
return auth && user.GetQuota(db) >= 5
|
||||||
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
||||||
return user != nil && user.GetQuota(db) >= 50
|
return auth && user.GetQuota(db) >= 50
|
||||||
case globals.SparkDesk:
|
case globals.SparkDesk:
|
||||||
return user != nil && user.GetQuota(db) >= 1
|
return auth && user.GetQuota(db) >= 1
|
||||||
case globals.Claude2100k:
|
case globals.Claude2100k:
|
||||||
return user != nil && user.GetQuota(db) >= 1
|
return auth && user.GetQuota(db) >= 1
|
||||||
|
case globals.ZhiPuChatGLMPro, globals.ZhiPuChatGLMStd:
|
||||||
|
return auth && user.GetQuota(db) >= 1
|
||||||
default:
|
default:
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -37,6 +37,9 @@ const (
|
|||||||
BingCreative = "bing-creative"
|
BingCreative = "bing-creative"
|
||||||
BingBalanced = "bing-balanced"
|
BingBalanced = "bing-balanced"
|
||||||
BingPrecise = "bing-precise"
|
BingPrecise = "bing-precise"
|
||||||
|
ZhiPuChatGLMPro = "zhipu-chatglm-pro"
|
||||||
|
ZhiPuChatGLMStd = "zhipu-chatglm-std"
|
||||||
|
ZhiPuChatGLMLite = "zhipu-chatglm-lite"
|
||||||
)
|
)
|
||||||
|
|
||||||
var GPT3TurboArray = []string{
|
var GPT3TurboArray = []string{
|
||||||
@ -74,6 +77,12 @@ var BingModelArray = []string{
|
|||||||
BingPrecise,
|
BingPrecise,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var ZhiPuModelArray = []string{
|
||||||
|
ZhiPuChatGLMPro,
|
||||||
|
ZhiPuChatGLMStd,
|
||||||
|
ZhiPuChatGLMLite,
|
||||||
|
}
|
||||||
|
|
||||||
var LongContextModelArray = []string{
|
var LongContextModelArray = []string{
|
||||||
GPT3Turbo16k,
|
GPT3Turbo16k,
|
||||||
GPT3Turbo16k0613,
|
GPT3Turbo16k0613,
|
||||||
@ -134,6 +143,10 @@ func IsBingModel(model string) bool {
|
|||||||
return in(model, BingModelArray)
|
return in(model, BingModelArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsZhiPuModel(model string) bool {
|
||||||
|
return in(model, ZhiPuModelArray)
|
||||||
|
}
|
||||||
|
|
||||||
func IsLongContextModel(model string) bool {
|
func IsLongContextModel(model string) bool {
|
||||||
return in(model, LongContextModelArray)
|
return in(model, LongContextModelArray)
|
||||||
}
|
}
|
||||||
|
@ -121,7 +121,7 @@ func (c *Conversation) GetMessageSegment(length int) []globals.Message {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func CopyMessage(message []globals.Message) []globals.Message {
|
func CopyMessage(message []globals.Message) []globals.Message {
|
||||||
return utils.UnmarshalJson[[]globals.Message](utils.ToJson(message)) // deep copy
|
return utils.DeepCopy[[]globals.Message](message) // deep copy
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) GetLastMessage() globals.Message {
|
func (c *Conversation) GetLastMessage() globals.Message {
|
||||||
|
@ -52,7 +52,7 @@ func AuthMiddleware() gin.HandlerFunc {
|
|||||||
k := strings.TrimSpace(c.GetHeader("Authorization"))
|
k := strings.TrimSpace(c.GetHeader("Authorization"))
|
||||||
if k != "" {
|
if k != "" {
|
||||||
if strings.HasPrefix(k, "Bearer ") {
|
if strings.HasPrefix(k, "Bearer ") {
|
||||||
k, _ = strings.CutPrefix(k, "Bearer ")
|
k = strings.TrimPrefix(k, "Bearer ")
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(k, "sk-") { // api agent
|
if strings.HasPrefix(k, "sk-") { // api agent
|
||||||
|
@ -65,6 +65,10 @@ func UnmarshalJson[T any](value string) T {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func DeepCopy[T any](value T) T {
|
||||||
|
return UnmarshalJson[T](ToJson(value))
|
||||||
|
}
|
||||||
|
|
||||||
func GetSegment[T any](arr []T, length int) []T {
|
func GetSegment[T any](arr []T, length int) []T {
|
||||||
if length > len(arr) {
|
if length > len(arr) {
|
||||||
return arr
|
return arr
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package utils
|
package utils
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/adapter/zhipuai"
|
||||||
"chat/globals"
|
"chat/globals"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/pkoukk/tiktoken-go"
|
"github.com/pkoukk/tiktoken-go"
|
||||||
@ -52,7 +53,11 @@ func GetWeightByModel(model string) int {
|
|||||||
globals.GPT40613,
|
globals.GPT40613,
|
||||||
globals.SparkDesk:
|
globals.SparkDesk:
|
||||||
return 3
|
return 3
|
||||||
case globals.GPT3Turbo0301, globals.GPT3Turbo16k0301:
|
case globals.GPT3Turbo0301,
|
||||||
|
globals.GPT3Turbo16k0301,
|
||||||
|
globals.ZhiPuChatGLMLite,
|
||||||
|
globals.ZhiPuChatGLMStd,
|
||||||
|
globals.ZhiPuChatGLMPro:
|
||||||
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
default:
|
default:
|
||||||
if strings.Contains(model, globals.GPT3Turbo) {
|
if strings.Contains(model, globals.GPT3Turbo) {
|
||||||
@ -110,6 +115,10 @@ func CountInputToken(model string, v []globals.Message) float32 {
|
|||||||
return 0
|
return 0
|
||||||
case globals.Claude2100k:
|
case globals.Claude2100k:
|
||||||
return float32(CountTokenPrice(v, model)) / 1000 * 0.008
|
return float32(CountTokenPrice(v, model)) / 1000 * 0.008
|
||||||
|
case zhipuai.ChatGLMPro:
|
||||||
|
return float32(CountTokenPrice(v, model)) / 1000 * 0.1
|
||||||
|
case zhipuai.ChatGLMStd:
|
||||||
|
return float32(CountTokenPrice(v, model)) / 1000 * 0.05
|
||||||
default:
|
default:
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@ -131,6 +140,10 @@ func CountOutputToken(model string, t int) float32 {
|
|||||||
return 0
|
return 0
|
||||||
case globals.Claude2100k:
|
case globals.Claude2100k:
|
||||||
return float32(t*GetWeightByModel(model)) / 1000 * 0.008
|
return float32(t*GetWeightByModel(model)) / 1000 * 0.008
|
||||||
|
case zhipuai.ChatGLMPro:
|
||||||
|
return float32(t*GetWeightByModel(model)) / 1000 * 0.1
|
||||||
|
case zhipuai.ChatGLMStd:
|
||||||
|
return float32(t*GetWeightByModel(model)) / 1000 * 0.05
|
||||||
default:
|
default:
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user