mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 04:50:14 +09:00
fix: nil pointer dereference
error when carrying an image to a conversation (#221)
This commit is contained in:
parent
9cb958080b
commit
576213d21f
@ -79,6 +79,7 @@ type commonState struct {
|
||||
Expire int64 `json:"expire" mapstructure:"expire"`
|
||||
Size int64 `json:"size" mapstructure:"size"`
|
||||
ImageStore bool `json:"image_store" mapstructure:"imagestore"`
|
||||
PromptStore bool `json:"prompt_store" mapstructure:"promptstore"`
|
||||
}
|
||||
|
||||
type SystemConfig struct {
|
||||
@ -114,6 +115,8 @@ func (c *SystemConfig) Load() {
|
||||
globals.CacheAcceptedSize = c.GetCacheAcceptedSize()
|
||||
globals.AcceptImageStore = c.AcceptImageStore()
|
||||
|
||||
globals.AcceptPromptStore = c.Common.PromptStore
|
||||
|
||||
if c.General.PWAManifest == "" {
|
||||
c.General.PWAManifest = utils.ReadPWAManifest()
|
||||
}
|
||||
|
@ -48,3 +48,9 @@ const (
|
||||
HttpsProxyType
|
||||
Socks5ProxyType
|
||||
)
|
||||
|
||||
const (
|
||||
WebTokenType = "web"
|
||||
ApiTokenType = "api"
|
||||
SystemToken = "system"
|
||||
)
|
||||
|
@ -23,6 +23,7 @@ var CacheAcceptedModels []string
|
||||
var CacheAcceptedExpire int64
|
||||
var CacheAcceptedSize int64
|
||||
var AcceptImageStore bool
|
||||
var AcceptPromptStore bool
|
||||
var CloseRegistration bool
|
||||
var CloseRelay bool
|
||||
|
||||
|
108
utils/buffer.go
108
utils/buffer.go
@ -2,7 +2,9 @@ package utils
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Charge interface {
|
||||
@ -28,7 +30,11 @@ type Buffer struct {
|
||||
ToolCalls *globals.ToolCalls `json:"tool_calls"`
|
||||
ToolCallsCursor int `json:"tool_calls_cursor"`
|
||||
FunctionCall *globals.FunctionCall `json:"function_call"`
|
||||
StartTime *time.Time `json:"-"`
|
||||
Prompts string `json:"prompts"`
|
||||
TokenName string `json:"-"`
|
||||
Charge Charge `json:"-"`
|
||||
VisionRecall bool `json:"-"`
|
||||
}
|
||||
|
||||
func initInputToken(model string, history []globals.Message) int {
|
||||
@ -71,6 +77,7 @@ func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
|
||||
FunctionCall: nil,
|
||||
ToolCalls: nil,
|
||||
ToolCallsCursor: 0,
|
||||
StartTime: ToPtr(time.Now()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,6 +86,11 @@ func (b *Buffer) GetCursor() int {
|
||||
}
|
||||
|
||||
func (b *Buffer) GetQuota() float32 {
|
||||
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(true))
|
||||
}
|
||||
|
||||
func (b *Buffer) GetRecordQuota() float32 {
|
||||
// end of the buffer, the output token is counted using the times
|
||||
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
|
||||
}
|
||||
|
||||
@ -106,15 +118,23 @@ func (b *Buffer) GetChunk() string {
|
||||
return b.Latest
|
||||
}
|
||||
|
||||
func (b *Buffer) AddImage(image *Image) {
|
||||
if image != nil {
|
||||
b.Images = append(b.Images, *image)
|
||||
func (b *Buffer) InitVisionRecall() {
|
||||
// set the vision recall flag to true to prevent the buffer from adding more images of retrying the channel
|
||||
b.VisionRecall = true
|
||||
}
|
||||
|
||||
if b.Charge.IsBillingType(globals.TokenBilling) {
|
||||
if image != nil {
|
||||
b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput()
|
||||
func (b *Buffer) AddImage(image *Image) {
|
||||
if image == nil || b.VisionRecall {
|
||||
return
|
||||
}
|
||||
|
||||
b.Images = append(b.Images, *image)
|
||||
|
||||
tokens := image.CountTokens(b.Model)
|
||||
b.InputTokens += tokens
|
||||
|
||||
if b.Charge.IsBillingType(globals.TokenBilling) {
|
||||
b.Quota += float32(tokens) / 1000 * b.Charge.GetInput()
|
||||
}
|
||||
}
|
||||
|
||||
@ -145,6 +165,13 @@ func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.Tool
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func appendTool(tool globals.ToolCall, chunk globals.ToolCall) string {
|
||||
from := ToString(tool.Function.Arguments)
|
||||
to := ToString(chunk.Function.Arguments)
|
||||
|
||||
return from + to
|
||||
}
|
||||
|
||||
func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls {
|
||||
if source == nil {
|
||||
return target
|
||||
@ -157,7 +184,7 @@ func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.Too
|
||||
idx, hit := hitTool(tool, tools)
|
||||
|
||||
if hit != nil {
|
||||
tools[idx].Function.Arguments += tool.Function.Arguments
|
||||
tools[idx].Function.Arguments = appendTool(tools[idx], tool)
|
||||
} else {
|
||||
tools = append(tools, tool)
|
||||
}
|
||||
@ -209,6 +236,27 @@ func (b *Buffer) GetCharge() Charge {
|
||||
return b.Charge
|
||||
}
|
||||
|
||||
func (b *Buffer) ToChargeInfo() string {
|
||||
switch b.Charge.GetType() {
|
||||
case globals.TokenBilling:
|
||||
return fmt.Sprintf(
|
||||
"input tokens: %0.4f quota / 1k tokens\n"+
|
||||
"output tokens: %0.4f quota / 1k tokens\n",
|
||||
b.Charge.GetInput(), b.Charge.GetOutput(),
|
||||
)
|
||||
case globals.TimesBilling:
|
||||
return fmt.Sprintf("%f quota per request\n", b.Charge.GetLimit())
|
||||
case globals.NonBilling:
|
||||
return "no cost"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func (b *Buffer) SetPrompts(prompts interface{}) {
|
||||
b.Prompts = ToString(prompts)
|
||||
}
|
||||
|
||||
func (b *Buffer) Read() string {
|
||||
return b.Data
|
||||
}
|
||||
@ -247,5 +295,49 @@ func (b *Buffer) CountOutputToken(running bool) int {
|
||||
}
|
||||
|
||||
func (b *Buffer) CountToken() int {
|
||||
return b.CountInputToken() + b.CountOutputToken(false)
|
||||
return b.CountInputToken() + b.CountOutputToken(true)
|
||||
}
|
||||
|
||||
func (b *Buffer) GetDuration() float32 {
|
||||
if b.StartTime == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return float32(time.Since(*b.StartTime).Seconds())
|
||||
}
|
||||
|
||||
func (b *Buffer) GetStartTime() *time.Time {
|
||||
return b.StartTime
|
||||
}
|
||||
|
||||
func (b *Buffer) GetPrompts() string {
|
||||
return b.Prompts
|
||||
}
|
||||
|
||||
func (b *Buffer) GetTokenName() string {
|
||||
if len(b.TokenName) == 0 {
|
||||
return globals.WebTokenType
|
||||
}
|
||||
|
||||
return b.TokenName
|
||||
}
|
||||
|
||||
func (b *Buffer) SetTokenName(tokenName string) {
|
||||
b.TokenName = tokenName
|
||||
}
|
||||
|
||||
func (b *Buffer) GetRecordPrompts() string {
|
||||
if !globals.AcceptPromptStore {
|
||||
return ""
|
||||
}
|
||||
|
||||
return b.GetPrompts()
|
||||
}
|
||||
|
||||
func (b *Buffer) GetRecordResponsePrompts() string {
|
||||
if !globals.AcceptPromptStore {
|
||||
return ""
|
||||
}
|
||||
|
||||
return b.Read()
|
||||
}
|
||||
|
22
utils/net.go
22
utils/net.go
@ -183,6 +183,28 @@ func Post(uri string, headers map[string]string, body interface{}, config ...glo
|
||||
return data, err
|
||||
}
|
||||
|
||||
func ToString(data interface{}) string {
|
||||
switch v := data.(type) {
|
||||
case string:
|
||||
return v
|
||||
case int, int8, int16, int32, int64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case uint, uint8, uint16, uint32, uint64:
|
||||
return fmt.Sprintf("%d", v)
|
||||
case float32, float64:
|
||||
return fmt.Sprintf("%f", v)
|
||||
case bool:
|
||||
return fmt.Sprintf("%t", v)
|
||||
default:
|
||||
data := Marshal(data)
|
||||
if len(data) > 0 {
|
||||
return data
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", data)
|
||||
}
|
||||
}
|
||||
|
||||
func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) {
|
||||
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config)
|
||||
if err != nil {
|
||||
|
Loading…
Reference in New Issue
Block a user