fix: nil pointer dereference error when carrying an image to a conversation (#221)

This commit is contained in:
Deng Junhai 2024-07-01 14:38:14 +08:00
parent 9cb958080b
commit 576213d21f
5 changed files with 137 additions and 13 deletions

View File

@ -79,6 +79,7 @@ type commonState struct {
Expire int64 `json:"expire" mapstructure:"expire"`
Size int64 `json:"size" mapstructure:"size"`
ImageStore bool `json:"image_store" mapstructure:"imagestore"`
PromptStore bool `json:"prompt_store" mapstructure:"promptstore"`
}
type SystemConfig struct {
@ -114,6 +115,8 @@ func (c *SystemConfig) Load() {
globals.CacheAcceptedSize = c.GetCacheAcceptedSize()
globals.AcceptImageStore = c.AcceptImageStore()
globals.AcceptPromptStore = c.Common.PromptStore
if c.General.PWAManifest == "" {
c.General.PWAManifest = utils.ReadPWAManifest()
}

View File

@ -48,3 +48,9 @@ const (
HttpsProxyType
Socks5ProxyType
)
const (
WebTokenType = "web"
ApiTokenType = "api"
SystemToken = "system"
)

View File

@ -23,6 +23,7 @@ var CacheAcceptedModels []string
var CacheAcceptedExpire int64
var CacheAcceptedSize int64
var AcceptImageStore bool
var AcceptPromptStore bool
var CloseRegistration bool
var CloseRelay bool

View File

@ -2,7 +2,9 @@ package utils
import (
"chat/globals"
"fmt"
"strings"
"time"
)
type Charge interface {
@ -28,7 +30,11 @@ type Buffer struct {
ToolCalls *globals.ToolCalls `json:"tool_calls"`
ToolCallsCursor int `json:"tool_calls_cursor"`
FunctionCall *globals.FunctionCall `json:"function_call"`
StartTime *time.Time `json:"-"`
Prompts string `json:"prompts"`
TokenName string `json:"-"`
Charge Charge `json:"-"`
VisionRecall bool `json:"-"`
}
func initInputToken(model string, history []globals.Message) int {
@ -71,6 +77,7 @@ func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
FunctionCall: nil,
ToolCalls: nil,
ToolCallsCursor: 0,
StartTime: ToPtr(time.Now()),
}
}
@ -79,6 +86,11 @@ func (b *Buffer) GetCursor() int {
}
func (b *Buffer) GetQuota() float32 {
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(true))
}
func (b *Buffer) GetRecordQuota() float32 {
// end of the buffer, the output token is counted using the times
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
}
@ -106,15 +118,23 @@ func (b *Buffer) GetChunk() string {
return b.Latest
}
func (b *Buffer) AddImage(image *Image) {
if image != nil {
b.Images = append(b.Images, *image)
func (b *Buffer) InitVisionRecall() {
// set the vision recall flag to true to prevent the buffer from adding more images of retrying the channel
b.VisionRecall = true
}
if b.Charge.IsBillingType(globals.TokenBilling) {
if image != nil {
b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput()
func (b *Buffer) AddImage(image *Image) {
if image == nil || b.VisionRecall {
return
}
b.Images = append(b.Images, *image)
tokens := image.CountTokens(b.Model)
b.InputTokens += tokens
if b.Charge.IsBillingType(globals.TokenBilling) {
b.Quota += float32(tokens) / 1000 * b.Charge.GetInput()
}
}
@ -145,6 +165,13 @@ func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.Tool
return 0, nil
}
func appendTool(tool globals.ToolCall, chunk globals.ToolCall) string {
from := ToString(tool.Function.Arguments)
to := ToString(chunk.Function.Arguments)
return from + to
}
func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls {
if source == nil {
return target
@ -157,7 +184,7 @@ func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.Too
idx, hit := hitTool(tool, tools)
if hit != nil {
tools[idx].Function.Arguments += tool.Function.Arguments
tools[idx].Function.Arguments = appendTool(tools[idx], tool)
} else {
tools = append(tools, tool)
}
@ -209,6 +236,27 @@ func (b *Buffer) GetCharge() Charge {
return b.Charge
}
func (b *Buffer) ToChargeInfo() string {
switch b.Charge.GetType() {
case globals.TokenBilling:
return fmt.Sprintf(
"input tokens: %0.4f quota / 1k tokens\n"+
"output tokens: %0.4f quota / 1k tokens\n",
b.Charge.GetInput(), b.Charge.GetOutput(),
)
case globals.TimesBilling:
return fmt.Sprintf("%f quota per request\n", b.Charge.GetLimit())
case globals.NonBilling:
return "no cost"
}
return ""
}
func (b *Buffer) SetPrompts(prompts interface{}) {
b.Prompts = ToString(prompts)
}
func (b *Buffer) Read() string {
return b.Data
}
@ -247,5 +295,49 @@ func (b *Buffer) CountOutputToken(running bool) int {
}
func (b *Buffer) CountToken() int {
return b.CountInputToken() + b.CountOutputToken(false)
return b.CountInputToken() + b.CountOutputToken(true)
}
func (b *Buffer) GetDuration() float32 {
if b.StartTime == nil {
return 0
}
return float32(time.Since(*b.StartTime).Seconds())
}
func (b *Buffer) GetStartTime() *time.Time {
return b.StartTime
}
func (b *Buffer) GetPrompts() string {
return b.Prompts
}
func (b *Buffer) GetTokenName() string {
if len(b.TokenName) == 0 {
return globals.WebTokenType
}
return b.TokenName
}
func (b *Buffer) SetTokenName(tokenName string) {
b.TokenName = tokenName
}
func (b *Buffer) GetRecordPrompts() string {
if !globals.AcceptPromptStore {
return ""
}
return b.GetPrompts()
}
func (b *Buffer) GetRecordResponsePrompts() string {
if !globals.AcceptPromptStore {
return ""
}
return b.Read()
}

View File

@ -183,6 +183,28 @@ func Post(uri string, headers map[string]string, body interface{}, config ...glo
return data, err
}
func ToString(data interface{}) string {
switch v := data.(type) {
case string:
return v
case int, int8, int16, int32, int64:
return fmt.Sprintf("%d", v)
case uint, uint8, uint16, uint32, uint64:
return fmt.Sprintf("%d", v)
case float32, float64:
return fmt.Sprintf("%f", v)
case bool:
return fmt.Sprintf("%t", v)
default:
data := Marshal(data)
if len(data) > 0 {
return data
}
return fmt.Sprintf("%v", data)
}
}
func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) {
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config)
if err != nil {