diff --git a/adapter/common/types.go b/adapter/common/types.go index 3f8723a..47da4e5 100644 --- a/adapter/common/types.go +++ b/adapter/common/types.go @@ -6,28 +6,37 @@ import ( ) type RequestProps struct { - MaxRetries *int - Current int - Group string - - Proxy globals.ProxyConfig + MaxRetries *int `json:"-"` + Current int `json:"-"` + Group string `json:"-"` + Proxy globals.ProxyConfig `json:"-"` } type ChatProps struct { RequestProps - Model string - OriginalModel string + Model string `json:"model,omitempty"` + OriginalModel string `json:"-"` - Message []globals.Message - MaxTokens *int - PresencePenalty *float32 - FrequencyPenalty *float32 - RepetitionPenalty *float32 - Temperature *float32 - TopP *float32 - TopK *int - Tools *globals.FunctionTools - ToolChoice *interface{} - Buffer utils.Buffer + Message []globals.Message `json:"messages,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + PresencePenalty *float32 `json:"presence_penalty,omitempty"` + FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"` + RepetitionPenalty *float32 `json:"repetition_penalty,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + Tools *globals.FunctionTools `json:"tools,omitempty"` + ToolChoice *interface{} `json:"tool_choice,omitempty"` + Buffer *utils.Buffer `json:"-"` +} + +func (c *ChatProps) SetupBuffer(buf *utils.Buffer) { + buf.SetPrompts(c) + c.Buffer = buf +} + +func CreateChatProps(props *ChatProps, buffer *utils.Buffer) *ChatProps { + props.SetupBuffer(buffer) + return props } diff --git a/addition/generation/prompt.go b/addition/generation/prompt.go index 4312421..07dfc72 100644 --- a/addition/generation/prompt.go +++ b/addition/generation/prompt.go @@ -1,7 +1,7 @@ package generation import ( - "chat/adapter/common" + adaptercommon "chat/adapter/common" "chat/admin" "chat/channel" "chat/globals" @@ -17,17 +17,16 @@ func CreateGeneration(group, model, prompt, path string, hook func(buffer *utils message := GenerateMessage(prompt) buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model)) - err := channel.NewChatRequest(group, &adaptercommon.ChatProps{ + err := channel.NewChatRequest(group, adaptercommon.CreateChatProps(&adaptercommon.ChatProps{ OriginalModel: model, Message: message, - Buffer: *buffer, - }, func(data *globals.Chunk) error { + }, buffer), func(data *globals.Chunk) error { buffer.WriteChunk(data) hook(buffer, data.Content) return nil }) - admin.AnalysisRequest(model, buffer, err) + admin.AnalyseRequest(model, buffer, err) if err != nil { return err } diff --git a/admin/statistic.go b/admin/statistic.go index 94710ef..1f086af 100644 --- a/admin/statistic.go +++ b/admin/statistic.go @@ -4,8 +4,9 @@ import ( "chat/adapter" "chat/connection" "chat/utils" - "github.com/go-redis/redis/v8" "time" + + "github.com/go-redis/redis/v8" ) func IncrErrorRequest(cache *redis.Client) { @@ -25,7 +26,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) { utils.IncrWithExpire(cache, getModelFormat(getDay(), model), tokens, time.Hour*24*7*2) } -func AnalysisRequest(model string, buffer *utils.Buffer, err error) { +func AnalyseRequest(model string, buffer *utils.Buffer, err error) { instance := connection.Cache if adapter.IsAvailableError(err) { diff --git a/manager/chat.go b/manager/chat.go index 2c2e51f..7010169 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -104,7 +104,7 @@ func createChatTask( hit, err := channel.NewChatRequestWithCache( cache, buffer, auth.GetGroup(db, user), - &adaptercommon.ChatProps{ + adaptercommon.CreateChatProps(&adaptercommon.ChatProps{ Model: model, Message: segment, MaxTokens: instance.GetMaxTokens(), @@ -114,7 +114,7 @@ func createChatTask( PresencePenalty: instance.GetPresencePenalty(), FrequencyPenalty: instance.GetFrequencyPenalty(), RepetitionPenalty: instance.GetRepetitionPenalty(), - }, + }, buffer), // the function to handle the chunk data func(data *globals.Chunk) error { @@ -168,6 +168,7 @@ func createChatTask( interruptSignal <- err return hit, nil } + case signal := <-stopSignal: // if stop signal is received if signal { @@ -219,7 +220,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) hit, err := createChatTask(conn, user, buffer, db, cache, model, instance, segment, plan) - admin.AnalysisRequest(model, buffer, err) + admin.AnalyseRequest(model, buffer, err) if adapter.IsAvailableError(err) { globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) diff --git a/manager/chat_completions.go b/manager/chat_completions.go index 2ac8a7a..36ce60b 100644 --- a/manager/chat_completions.go +++ b/manager/chat_completions.go @@ -94,7 +94,7 @@ func ChatRelayAPI(c *gin.Context) { } func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps { - return &adaptercommon.ChatProps{ + return adaptercommon.CreateChatProps(&adaptercommon.ChatProps{ Model: form.Model, Message: messages, MaxTokens: form.MaxTokens, @@ -106,8 +106,7 @@ func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buff TopK: form.TopK, Tools: form.Tools, ToolChoice: form.ToolChoice, - Buffer: *buffer, - } + }, buffer) } func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) { @@ -120,7 +119,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals return nil }) - admin.AnalysisRequest(form.Model, buffer, err) + admin.AnalyseRequest(form.Model, buffer, err) if err != nil { auth.RevertSubscriptionUsage(db, cache, user, form.Model) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP())) @@ -235,7 +234,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g }, ) - admin.AnalysisRequest(form.Model, buffer, err) + admin.AnalyseRequest(form.Model, buffer, err) if err != nil { auth.RevertSubscriptionUsage(db, cache, user, form.Model) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err.Error(), form.Model, c.ClientIP())) diff --git a/manager/completions.go b/manager/completions.go index 567038d..f0762df 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -38,18 +38,17 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message [] hit, err := channel.NewChatRequestWithCache( cache, buffer, auth.GetGroup(db, user), - &adaptercommon.ChatProps{ + adaptercommon.CreateChatProps(&adaptercommon.ChatProps{ Model: model, Message: segment, - Buffer: *buffer, - }, + }, buffer), func(resp *globals.Chunk) error { buffer.WriteChunk(resp) return nil }, ) - admin.AnalysisRequest(model, buffer, err) + admin.AnalyseRequest(model, buffer, err) if err != nil { auth.RevertSubscriptionUsage(db, cache, user, model) return err.Error(), 0 diff --git a/manager/images.go b/manager/images.go index 1f036dc..5e80011 100644 --- a/manager/images.go +++ b/manager/images.go @@ -64,12 +64,11 @@ func ImagesRelayAPI(c *gin.Context) { } func getImageProps(form RelayImageForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps { - return &adaptercommon.ChatProps{ + return adaptercommon.CreateChatProps(&adaptercommon.ChatProps{ Model: form.Model, Message: messages, MaxTokens: utils.ToPtr(-1), - Buffer: *buffer, - } + }, buffer) } func getUrlFromBuffer(buffer *utils.Buffer) string { @@ -100,7 +99,7 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string, return nil }) - admin.AnalysisRequest(form.Model, buffer, err) + admin.AnalyseRequest(form.Model, buffer, err) if err != nil { auth.RevertSubscriptionUsage(db, cache, user, form.Model) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP()))