fix: nil pointer dereference error when carrying an image to a conversation (#221)

This commit is contained in:
Deng Junhai 2024-07-01 14:38:14 +08:00
parent 9cb958080b
commit 576213d21f
5 changed files with 137 additions and 13 deletions

View File

@ -73,12 +73,13 @@ type SearchState struct {
} }
type commonState struct { type commonState struct {
Article []string `json:"article" mapstructure:"article"` Article []string `json:"article" mapstructure:"article"`
Generation []string `json:"generation" mapstructure:"generation"` Generation []string `json:"generation" mapstructure:"generation"`
Cache []string `json:"cache" mapstructure:"cache"` Cache []string `json:"cache" mapstructure:"cache"`
Expire int64 `json:"expire" mapstructure:"expire"` Expire int64 `json:"expire" mapstructure:"expire"`
Size int64 `json:"size" mapstructure:"size"` Size int64 `json:"size" mapstructure:"size"`
ImageStore bool `json:"image_store" mapstructure:"imagestore"` ImageStore bool `json:"image_store" mapstructure:"imagestore"`
PromptStore bool `json:"prompt_store" mapstructure:"promptstore"`
} }
type SystemConfig struct { type SystemConfig struct {
@ -114,6 +115,8 @@ func (c *SystemConfig) Load() {
globals.CacheAcceptedSize = c.GetCacheAcceptedSize() globals.CacheAcceptedSize = c.GetCacheAcceptedSize()
globals.AcceptImageStore = c.AcceptImageStore() globals.AcceptImageStore = c.AcceptImageStore()
globals.AcceptPromptStore = c.Common.PromptStore
if c.General.PWAManifest == "" { if c.General.PWAManifest == "" {
c.General.PWAManifest = utils.ReadPWAManifest() c.General.PWAManifest = utils.ReadPWAManifest()
} }

View File

@ -48,3 +48,9 @@ const (
HttpsProxyType HttpsProxyType
Socks5ProxyType Socks5ProxyType
) )
const (
WebTokenType = "web"
ApiTokenType = "api"
SystemToken = "system"
)

View File

@ -23,6 +23,7 @@ var CacheAcceptedModels []string
var CacheAcceptedExpire int64 var CacheAcceptedExpire int64
var CacheAcceptedSize int64 var CacheAcceptedSize int64
var AcceptImageStore bool var AcceptImageStore bool
var AcceptPromptStore bool
var CloseRegistration bool var CloseRegistration bool
var CloseRelay bool var CloseRelay bool

View File

@ -2,7 +2,9 @@ package utils
import ( import (
"chat/globals" "chat/globals"
"fmt"
"strings" "strings"
"time"
) )
type Charge interface { type Charge interface {
@ -28,7 +30,11 @@ type Buffer struct {
ToolCalls *globals.ToolCalls `json:"tool_calls"` ToolCalls *globals.ToolCalls `json:"tool_calls"`
ToolCallsCursor int `json:"tool_calls_cursor"` ToolCallsCursor int `json:"tool_calls_cursor"`
FunctionCall *globals.FunctionCall `json:"function_call"` FunctionCall *globals.FunctionCall `json:"function_call"`
StartTime *time.Time `json:"-"`
Prompts string `json:"prompts"`
TokenName string `json:"-"`
Charge Charge `json:"-"` Charge Charge `json:"-"`
VisionRecall bool `json:"-"`
} }
func initInputToken(model string, history []globals.Message) int { func initInputToken(model string, history []globals.Message) int {
@ -71,6 +77,7 @@ func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
FunctionCall: nil, FunctionCall: nil,
ToolCalls: nil, ToolCalls: nil,
ToolCallsCursor: 0, ToolCallsCursor: 0,
StartTime: ToPtr(time.Now()),
} }
} }
@ -79,6 +86,11 @@ func (b *Buffer) GetCursor() int {
} }
func (b *Buffer) GetQuota() float32 { func (b *Buffer) GetQuota() float32 {
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(true))
}
func (b *Buffer) GetRecordQuota() float32 {
// end of the buffer, the output token is counted using the times
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false)) return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
} }
@ -106,15 +118,23 @@ func (b *Buffer) GetChunk() string {
return b.Latest return b.Latest
} }
func (b *Buffer) InitVisionRecall() {
// set the vision recall flag to true to prevent the buffer from adding more images of retrying the channel
b.VisionRecall = true
}
func (b *Buffer) AddImage(image *Image) { func (b *Buffer) AddImage(image *Image) {
if image != nil { if image == nil || b.VisionRecall {
b.Images = append(b.Images, *image) return
} }
b.Images = append(b.Images, *image)
tokens := image.CountTokens(b.Model)
b.InputTokens += tokens
if b.Charge.IsBillingType(globals.TokenBilling) { if b.Charge.IsBillingType(globals.TokenBilling) {
if image != nil { b.Quota += float32(tokens) / 1000 * b.Charge.GetInput()
b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput()
}
} }
} }
@ -145,6 +165,13 @@ func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.Tool
return 0, nil return 0, nil
} }
func appendTool(tool globals.ToolCall, chunk globals.ToolCall) string {
from := ToString(tool.Function.Arguments)
to := ToString(chunk.Function.Arguments)
return from + to
}
func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls { func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls {
if source == nil { if source == nil {
return target return target
@ -157,7 +184,7 @@ func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.Too
idx, hit := hitTool(tool, tools) idx, hit := hitTool(tool, tools)
if hit != nil { if hit != nil {
tools[idx].Function.Arguments += tool.Function.Arguments tools[idx].Function.Arguments = appendTool(tools[idx], tool)
} else { } else {
tools = append(tools, tool) tools = append(tools, tool)
} }
@ -209,6 +236,27 @@ func (b *Buffer) GetCharge() Charge {
return b.Charge return b.Charge
} }
func (b *Buffer) ToChargeInfo() string {
switch b.Charge.GetType() {
case globals.TokenBilling:
return fmt.Sprintf(
"input tokens: %0.4f quota / 1k tokens\n"+
"output tokens: %0.4f quota / 1k tokens\n",
b.Charge.GetInput(), b.Charge.GetOutput(),
)
case globals.TimesBilling:
return fmt.Sprintf("%f quota per request\n", b.Charge.GetLimit())
case globals.NonBilling:
return "no cost"
}
return ""
}
func (b *Buffer) SetPrompts(prompts interface{}) {
b.Prompts = ToString(prompts)
}
func (b *Buffer) Read() string { func (b *Buffer) Read() string {
return b.Data return b.Data
} }
@ -247,5 +295,49 @@ func (b *Buffer) CountOutputToken(running bool) int {
} }
func (b *Buffer) CountToken() int { func (b *Buffer) CountToken() int {
return b.CountInputToken() + b.CountOutputToken(false) return b.CountInputToken() + b.CountOutputToken(true)
}
func (b *Buffer) GetDuration() float32 {
if b.StartTime == nil {
return 0
}
return float32(time.Since(*b.StartTime).Seconds())
}
func (b *Buffer) GetStartTime() *time.Time {
return b.StartTime
}
func (b *Buffer) GetPrompts() string {
return b.Prompts
}
func (b *Buffer) GetTokenName() string {
if len(b.TokenName) == 0 {
return globals.WebTokenType
}
return b.TokenName
}
func (b *Buffer) SetTokenName(tokenName string) {
b.TokenName = tokenName
}
func (b *Buffer) GetRecordPrompts() string {
if !globals.AcceptPromptStore {
return ""
}
return b.GetPrompts()
}
func (b *Buffer) GetRecordResponsePrompts() string {
if !globals.AcceptPromptStore {
return ""
}
return b.Read()
} }

View File

@ -183,6 +183,28 @@ func Post(uri string, headers map[string]string, body interface{}, config ...glo
return data, err return data, err
} }
func ToString(data interface{}) string {
switch v := data.(type) {
case string:
return v
case int, int8, int16, int32, int64:
return fmt.Sprintf("%d", v)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
return fmt.Sprintf("%t", v)
default:
data := Marshal(data)
if len(data) > 0 {
return data
}
return fmt.Sprintf("%v", data)
}
}
func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) { func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) {
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config) buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config)
if err != nil { if err != nil {