fix: nil pointer dereference error when carrying an image to a conversation (#221)

This commit is contained in:
Deng Junhai 2024-07-02 21:41:23 +08:00
parent 980bb2c101
commit 140bed53f9
7 changed files with 48 additions and 41 deletions

View File

@ -6,28 +6,37 @@ import (
) )
type RequestProps struct { type RequestProps struct {
MaxRetries *int MaxRetries *int `json:"-"`
Current int Current int `json:"-"`
Group string Group string `json:"-"`
Proxy globals.ProxyConfig `json:"-"`
Proxy globals.ProxyConfig
} }
type ChatProps struct { type ChatProps struct {
RequestProps RequestProps
Model string Model string `json:"model,omitempty"`
OriginalModel string OriginalModel string `json:"-"`
Message []globals.Message Message []globals.Message `json:"messages,omitempty"`
MaxTokens *int MaxTokens *int `json:"max_tokens,omitempty"`
PresencePenalty *float32 PresencePenalty *float32 `json:"presence_penalty,omitempty"`
FrequencyPenalty *float32 FrequencyPenalty *float32 `json:"frequency_penalty,omitempty"`
RepetitionPenalty *float32 RepetitionPenalty *float32 `json:"repetition_penalty,omitempty"`
Temperature *float32 Temperature *float32 `json:"temperature,omitempty"`
TopP *float32 TopP *float32 `json:"top_p,omitempty"`
TopK *int TopK *int `json:"top_k,omitempty"`
Tools *globals.FunctionTools Tools *globals.FunctionTools `json:"tools,omitempty"`
ToolChoice *interface{} ToolChoice *interface{} `json:"tool_choice,omitempty"`
Buffer utils.Buffer Buffer *utils.Buffer `json:"-"`
}
func (c *ChatProps) SetupBuffer(buf *utils.Buffer) {
buf.SetPrompts(c)
c.Buffer = buf
}
func CreateChatProps(props *ChatProps, buffer *utils.Buffer) *ChatProps {
props.SetupBuffer(buffer)
return props
} }

View File

@ -1,7 +1,7 @@
package generation package generation
import ( import (
"chat/adapter/common" adaptercommon "chat/adapter/common"
"chat/admin" "chat/admin"
"chat/channel" "chat/channel"
"chat/globals" "chat/globals"
@ -17,17 +17,16 @@ func CreateGeneration(group, model, prompt, path string, hook func(buffer *utils
message := GenerateMessage(prompt) message := GenerateMessage(prompt)
buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model)) buffer := utils.NewBuffer(model, message, channel.ChargeInstance.GetCharge(model))
err := channel.NewChatRequest(group, &adaptercommon.ChatProps{ err := channel.NewChatRequest(group, adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
OriginalModel: model, OriginalModel: model,
Message: message, Message: message,
Buffer: *buffer, }, buffer), func(data *globals.Chunk) error {
}, func(data *globals.Chunk) error {
buffer.WriteChunk(data) buffer.WriteChunk(data)
hook(buffer, data.Content) hook(buffer, data.Content)
return nil return nil
}) })
admin.AnalysisRequest(model, buffer, err) admin.AnalyseRequest(model, buffer, err)
if err != nil { if err != nil {
return err return err
} }

View File

@ -4,8 +4,9 @@ import (
"chat/adapter" "chat/adapter"
"chat/connection" "chat/connection"
"chat/utils" "chat/utils"
"github.com/go-redis/redis/v8"
"time" "time"
"github.com/go-redis/redis/v8"
) )
func IncrErrorRequest(cache *redis.Client) { func IncrErrorRequest(cache *redis.Client) {
@ -25,7 +26,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) {
utils.IncrWithExpire(cache, getModelFormat(getDay(), model), tokens, time.Hour*24*7*2) utils.IncrWithExpire(cache, getModelFormat(getDay(), model), tokens, time.Hour*24*7*2)
} }
func AnalysisRequest(model string, buffer *utils.Buffer, err error) { func AnalyseRequest(model string, buffer *utils.Buffer, err error) {
instance := connection.Cache instance := connection.Cache
if adapter.IsAvailableError(err) { if adapter.IsAvailableError(err) {

View File

@ -104,7 +104,7 @@ func createChatTask(
hit, err := channel.NewChatRequestWithCache( hit, err := channel.NewChatRequestWithCache(
cache, buffer, cache, buffer,
auth.GetGroup(db, user), auth.GetGroup(db, user),
&adaptercommon.ChatProps{ adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
Model: model, Model: model,
Message: segment, Message: segment,
MaxTokens: instance.GetMaxTokens(), MaxTokens: instance.GetMaxTokens(),
@ -114,7 +114,7 @@ func createChatTask(
PresencePenalty: instance.GetPresencePenalty(), PresencePenalty: instance.GetPresencePenalty(),
FrequencyPenalty: instance.GetFrequencyPenalty(), FrequencyPenalty: instance.GetFrequencyPenalty(),
RepetitionPenalty: instance.GetRepetitionPenalty(), RepetitionPenalty: instance.GetRepetitionPenalty(),
}, }, buffer),
// the function to handle the chunk data // the function to handle the chunk data
func(data *globals.Chunk) error { func(data *globals.Chunk) error {
@ -168,6 +168,7 @@ func createChatTask(
interruptSignal <- err interruptSignal <- err
return hit, nil return hit, nil
} }
case signal := <-stopSignal: case signal := <-stopSignal:
// if stop signal is received // if stop signal is received
if signal { if signal {
@ -219,7 +220,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model)) buffer := utils.NewBuffer(model, segment, channel.ChargeInstance.GetCharge(model))
hit, err := createChatTask(conn, user, buffer, db, cache, model, instance, segment, plan) hit, err := createChatTask(conn, user, buffer, db, cache, model, instance, segment, plan)
admin.AnalysisRequest(model, buffer, err) admin.AnalyseRequest(model, buffer, err)
if adapter.IsAvailableError(err) { if adapter.IsAvailableError(err) {
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))

View File

@ -94,7 +94,7 @@ func ChatRelayAPI(c *gin.Context) {
} }
func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps { func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps {
return &adaptercommon.ChatProps{ return adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
Model: form.Model, Model: form.Model,
Message: messages, Message: messages,
MaxTokens: form.MaxTokens, MaxTokens: form.MaxTokens,
@ -106,8 +106,7 @@ func getChatProps(form RelayForm, messages []globals.Message, buffer *utils.Buff
TopK: form.TopK, TopK: form.TopK,
Tools: form.Tools, Tools: form.Tools,
ToolChoice: form.ToolChoice, ToolChoice: form.ToolChoice,
Buffer: *buffer, }, buffer)
}
} }
func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) { func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals.Message, id string, created int64, user *auth.User, plan bool) {
@ -120,7 +119,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
return nil return nil
}) })
admin.AnalysisRequest(form.Model, buffer, err) admin.AnalyseRequest(form.Model, buffer, err)
if err != nil { if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model) auth.RevertSubscriptionUsage(db, cache, user, form.Model)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP())) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP()))
@ -235,7 +234,7 @@ func sendStreamTranshipmentResponse(c *gin.Context, form RelayForm, messages []g
}, },
) )
admin.AnalysisRequest(form.Model, buffer, err) admin.AnalyseRequest(form.Model, buffer, err)
if err != nil { if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model) auth.RevertSubscriptionUsage(db, cache, user, form.Model)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err.Error(), form.Model, c.ClientIP())) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err.Error(), form.Model, c.ClientIP()))

View File

@ -38,18 +38,17 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
hit, err := channel.NewChatRequestWithCache( hit, err := channel.NewChatRequestWithCache(
cache, buffer, cache, buffer,
auth.GetGroup(db, user), auth.GetGroup(db, user),
&adaptercommon.ChatProps{ adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
Model: model, Model: model,
Message: segment, Message: segment,
Buffer: *buffer, }, buffer),
},
func(resp *globals.Chunk) error { func(resp *globals.Chunk) error {
buffer.WriteChunk(resp) buffer.WriteChunk(resp)
return nil return nil
}, },
) )
admin.AnalysisRequest(model, buffer, err) admin.AnalyseRequest(model, buffer, err)
if err != nil { if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, model) auth.RevertSubscriptionUsage(db, cache, user, model)
return err.Error(), 0 return err.Error(), 0

View File

@ -64,12 +64,11 @@ func ImagesRelayAPI(c *gin.Context) {
} }
func getImageProps(form RelayImageForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps { func getImageProps(form RelayImageForm, messages []globals.Message, buffer *utils.Buffer) *adaptercommon.ChatProps {
return &adaptercommon.ChatProps{ return adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
Model: form.Model, Model: form.Model,
Message: messages, Message: messages,
MaxTokens: utils.ToPtr(-1), MaxTokens: utils.ToPtr(-1),
Buffer: *buffer, }, buffer)
}
} }
func getUrlFromBuffer(buffer *utils.Buffer) string { func getUrlFromBuffer(buffer *utils.Buffer) string {
@ -100,7 +99,7 @@ func createRelayImageObject(c *gin.Context, form RelayImageForm, prompt string,
return nil return nil
}) })
admin.AnalysisRequest(form.Model, buffer, err) admin.AnalyseRequest(form.Model, buffer, err)
if err != nil { if err != nil {
auth.RevertSubscriptionUsage(db, cache, user, form.Model) auth.RevertSubscriptionUsage(db, cache, user, form.Model)
globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP())) globals.Warn(fmt.Sprintf("error from chat request api: %s (instance: %s, client: %s)", err, form.Model, c.ClientIP()))