mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 04:50:14 +09:00
fix: nil pointer dereference
error when carrying an image to a conversation (#221)
This commit is contained in:
parent
9cb958080b
commit
576213d21f
@ -73,12 +73,13 @@ type SearchState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type commonState struct {
|
type commonState struct {
|
||||||
Article []string `json:"article" mapstructure:"article"`
|
Article []string `json:"article" mapstructure:"article"`
|
||||||
Generation []string `json:"generation" mapstructure:"generation"`
|
Generation []string `json:"generation" mapstructure:"generation"`
|
||||||
Cache []string `json:"cache" mapstructure:"cache"`
|
Cache []string `json:"cache" mapstructure:"cache"`
|
||||||
Expire int64 `json:"expire" mapstructure:"expire"`
|
Expire int64 `json:"expire" mapstructure:"expire"`
|
||||||
Size int64 `json:"size" mapstructure:"size"`
|
Size int64 `json:"size" mapstructure:"size"`
|
||||||
ImageStore bool `json:"image_store" mapstructure:"imagestore"`
|
ImageStore bool `json:"image_store" mapstructure:"imagestore"`
|
||||||
|
PromptStore bool `json:"prompt_store" mapstructure:"promptstore"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type SystemConfig struct {
|
type SystemConfig struct {
|
||||||
@ -114,6 +115,8 @@ func (c *SystemConfig) Load() {
|
|||||||
globals.CacheAcceptedSize = c.GetCacheAcceptedSize()
|
globals.CacheAcceptedSize = c.GetCacheAcceptedSize()
|
||||||
globals.AcceptImageStore = c.AcceptImageStore()
|
globals.AcceptImageStore = c.AcceptImageStore()
|
||||||
|
|
||||||
|
globals.AcceptPromptStore = c.Common.PromptStore
|
||||||
|
|
||||||
if c.General.PWAManifest == "" {
|
if c.General.PWAManifest == "" {
|
||||||
c.General.PWAManifest = utils.ReadPWAManifest()
|
c.General.PWAManifest = utils.ReadPWAManifest()
|
||||||
}
|
}
|
||||||
|
@ -48,3 +48,9 @@ const (
|
|||||||
HttpsProxyType
|
HttpsProxyType
|
||||||
Socks5ProxyType
|
Socks5ProxyType
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
WebTokenType = "web"
|
||||||
|
ApiTokenType = "api"
|
||||||
|
SystemToken = "system"
|
||||||
|
)
|
||||||
|
@ -23,6 +23,7 @@ var CacheAcceptedModels []string
|
|||||||
var CacheAcceptedExpire int64
|
var CacheAcceptedExpire int64
|
||||||
var CacheAcceptedSize int64
|
var CacheAcceptedSize int64
|
||||||
var AcceptImageStore bool
|
var AcceptImageStore bool
|
||||||
|
var AcceptPromptStore bool
|
||||||
var CloseRegistration bool
|
var CloseRegistration bool
|
||||||
var CloseRelay bool
|
var CloseRelay bool
|
||||||
|
|
||||||
|
106
utils/buffer.go
106
utils/buffer.go
@ -2,7 +2,9 @@ package utils
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/globals"
|
"chat/globals"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Charge interface {
|
type Charge interface {
|
||||||
@ -28,7 +30,11 @@ type Buffer struct {
|
|||||||
ToolCalls *globals.ToolCalls `json:"tool_calls"`
|
ToolCalls *globals.ToolCalls `json:"tool_calls"`
|
||||||
ToolCallsCursor int `json:"tool_calls_cursor"`
|
ToolCallsCursor int `json:"tool_calls_cursor"`
|
||||||
FunctionCall *globals.FunctionCall `json:"function_call"`
|
FunctionCall *globals.FunctionCall `json:"function_call"`
|
||||||
|
StartTime *time.Time `json:"-"`
|
||||||
|
Prompts string `json:"prompts"`
|
||||||
|
TokenName string `json:"-"`
|
||||||
Charge Charge `json:"-"`
|
Charge Charge `json:"-"`
|
||||||
|
VisionRecall bool `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func initInputToken(model string, history []globals.Message) int {
|
func initInputToken(model string, history []globals.Message) int {
|
||||||
@ -71,6 +77,7 @@ func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
|
|||||||
FunctionCall: nil,
|
FunctionCall: nil,
|
||||||
ToolCalls: nil,
|
ToolCalls: nil,
|
||||||
ToolCallsCursor: 0,
|
ToolCallsCursor: 0,
|
||||||
|
StartTime: ToPtr(time.Now()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,6 +86,11 @@ func (b *Buffer) GetCursor() int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) GetQuota() float32 {
|
func (b *Buffer) GetQuota() float32 {
|
||||||
|
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(true))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetRecordQuota() float32 {
|
||||||
|
// end of the buffer, the output token is counted using the times
|
||||||
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
|
return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,15 +118,23 @@ func (b *Buffer) GetChunk() string {
|
|||||||
return b.Latest
|
return b.Latest
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) InitVisionRecall() {
|
||||||
|
// set the vision recall flag to true to prevent the buffer from adding more images of retrying the channel
|
||||||
|
b.VisionRecall = true
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Buffer) AddImage(image *Image) {
|
func (b *Buffer) AddImage(image *Image) {
|
||||||
if image != nil {
|
if image == nil || b.VisionRecall {
|
||||||
b.Images = append(b.Images, *image)
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b.Images = append(b.Images, *image)
|
||||||
|
|
||||||
|
tokens := image.CountTokens(b.Model)
|
||||||
|
b.InputTokens += tokens
|
||||||
|
|
||||||
if b.Charge.IsBillingType(globals.TokenBilling) {
|
if b.Charge.IsBillingType(globals.TokenBilling) {
|
||||||
if image != nil {
|
b.Quota += float32(tokens) / 1000 * b.Charge.GetInput()
|
||||||
b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -145,6 +165,13 @@ func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.Tool
|
|||||||
return 0, nil
|
return 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func appendTool(tool globals.ToolCall, chunk globals.ToolCall) string {
|
||||||
|
from := ToString(tool.Function.Arguments)
|
||||||
|
to := ToString(chunk.Function.Arguments)
|
||||||
|
|
||||||
|
return from + to
|
||||||
|
}
|
||||||
|
|
||||||
func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls {
|
func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls {
|
||||||
if source == nil {
|
if source == nil {
|
||||||
return target
|
return target
|
||||||
@ -157,7 +184,7 @@ func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.Too
|
|||||||
idx, hit := hitTool(tool, tools)
|
idx, hit := hitTool(tool, tools)
|
||||||
|
|
||||||
if hit != nil {
|
if hit != nil {
|
||||||
tools[idx].Function.Arguments += tool.Function.Arguments
|
tools[idx].Function.Arguments = appendTool(tools[idx], tool)
|
||||||
} else {
|
} else {
|
||||||
tools = append(tools, tool)
|
tools = append(tools, tool)
|
||||||
}
|
}
|
||||||
@ -209,6 +236,27 @@ func (b *Buffer) GetCharge() Charge {
|
|||||||
return b.Charge
|
return b.Charge
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) ToChargeInfo() string {
|
||||||
|
switch b.Charge.GetType() {
|
||||||
|
case globals.TokenBilling:
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"input tokens: %0.4f quota / 1k tokens\n"+
|
||||||
|
"output tokens: %0.4f quota / 1k tokens\n",
|
||||||
|
b.Charge.GetInput(), b.Charge.GetOutput(),
|
||||||
|
)
|
||||||
|
case globals.TimesBilling:
|
||||||
|
return fmt.Sprintf("%f quota per request\n", b.Charge.GetLimit())
|
||||||
|
case globals.NonBilling:
|
||||||
|
return "no cost"
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) SetPrompts(prompts interface{}) {
|
||||||
|
b.Prompts = ToString(prompts)
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Buffer) Read() string {
|
func (b *Buffer) Read() string {
|
||||||
return b.Data
|
return b.Data
|
||||||
}
|
}
|
||||||
@ -247,5 +295,49 @@ func (b *Buffer) CountOutputToken(running bool) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (b *Buffer) CountToken() int {
|
func (b *Buffer) CountToken() int {
|
||||||
return b.CountInputToken() + b.CountOutputToken(false)
|
return b.CountInputToken() + b.CountOutputToken(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetDuration() float32 {
|
||||||
|
if b.StartTime == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
return float32(time.Since(*b.StartTime).Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetStartTime() *time.Time {
|
||||||
|
return b.StartTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetPrompts() string {
|
||||||
|
return b.Prompts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetTokenName() string {
|
||||||
|
if len(b.TokenName) == 0 {
|
||||||
|
return globals.WebTokenType
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.TokenName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) SetTokenName(tokenName string) {
|
||||||
|
b.TokenName = tokenName
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetRecordPrompts() string {
|
||||||
|
if !globals.AcceptPromptStore {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.GetPrompts()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetRecordResponsePrompts() string {
|
||||||
|
if !globals.AcceptPromptStore {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.Read()
|
||||||
}
|
}
|
||||||
|
22
utils/net.go
22
utils/net.go
@ -183,6 +183,28 @@ func Post(uri string, headers map[string]string, body interface{}, config ...glo
|
|||||||
return data, err
|
return data, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ToString(data interface{}) string {
|
||||||
|
switch v := data.(type) {
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return fmt.Sprintf("%d", v)
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return fmt.Sprintf("%d", v)
|
||||||
|
case float32, float64:
|
||||||
|
return fmt.Sprintf("%f", v)
|
||||||
|
case bool:
|
||||||
|
return fmt.Sprintf("%t", v)
|
||||||
|
default:
|
||||||
|
data := Marshal(data)
|
||||||
|
if len(data) > 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%v", data)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) {
|
func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) {
|
||||||
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config)
|
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
Loading…
Reference in New Issue
Block a user