diff --git a/channel/system.go b/channel/system.go index 9454f23..9bce7bd 100644 --- a/channel/system.go +++ b/channel/system.go @@ -73,12 +73,13 @@ type SearchState struct { } type commonState struct { - Article []string `json:"article" mapstructure:"article"` - Generation []string `json:"generation" mapstructure:"generation"` - Cache []string `json:"cache" mapstructure:"cache"` - Expire int64 `json:"expire" mapstructure:"expire"` - Size int64 `json:"size" mapstructure:"size"` - ImageStore bool `json:"image_store" mapstructure:"imagestore"` + Article []string `json:"article" mapstructure:"article"` + Generation []string `json:"generation" mapstructure:"generation"` + Cache []string `json:"cache" mapstructure:"cache"` + Expire int64 `json:"expire" mapstructure:"expire"` + Size int64 `json:"size" mapstructure:"size"` + ImageStore bool `json:"image_store" mapstructure:"imagestore"` + PromptStore bool `json:"prompt_store" mapstructure:"promptstore"` } type SystemConfig struct { @@ -114,6 +115,8 @@ func (c *SystemConfig) Load() { globals.CacheAcceptedSize = c.GetCacheAcceptedSize() globals.AcceptImageStore = c.AcceptImageStore() + globals.AcceptPromptStore = c.Common.PromptStore + if c.General.PWAManifest == "" { c.General.PWAManifest = utils.ReadPWAManifest() } diff --git a/globals/constant.go b/globals/constant.go index 497aba0..1f9dd5b 100644 --- a/globals/constant.go +++ b/globals/constant.go @@ -48,3 +48,9 @@ const ( HttpsProxyType Socks5ProxyType ) + +const ( + WebTokenType = "web" + ApiTokenType = "api" + SystemToken = "system" +) diff --git a/globals/variables.go b/globals/variables.go index 6460cba..8032b94 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -23,6 +23,7 @@ var CacheAcceptedModels []string var CacheAcceptedExpire int64 var CacheAcceptedSize int64 var AcceptImageStore bool +var AcceptPromptStore bool var CloseRegistration bool var CloseRelay bool diff --git a/utils/buffer.go b/utils/buffer.go index fe0af4b..601c636 100644 --- a/utils/buffer.go +++ b/utils/buffer.go @@ -2,7 +2,9 @@ package utils import ( "chat/globals" + "fmt" "strings" + "time" ) type Charge interface { @@ -28,7 +30,11 @@ type Buffer struct { ToolCalls *globals.ToolCalls `json:"tool_calls"` ToolCallsCursor int `json:"tool_calls_cursor"` FunctionCall *globals.FunctionCall `json:"function_call"` + StartTime *time.Time `json:"-"` + Prompts string `json:"prompts"` + TokenName string `json:"-"` Charge Charge `json:"-"` + VisionRecall bool `json:"-"` } func initInputToken(model string, history []globals.Message) int { @@ -71,6 +77,7 @@ func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer { FunctionCall: nil, ToolCalls: nil, ToolCallsCursor: 0, + StartTime: ToPtr(time.Now()), } } @@ -79,6 +86,11 @@ func (b *Buffer) GetCursor() int { } func (b *Buffer) GetQuota() float32 { + return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(true)) +} + +func (b *Buffer) GetRecordQuota() float32 { + // end of the buffer, the output token is counted using the times return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false)) } @@ -106,15 +118,23 @@ func (b *Buffer) GetChunk() string { return b.Latest } +func (b *Buffer) InitVisionRecall() { + // set the vision recall flag to true to prevent the buffer from adding more images of retrying the channel + b.VisionRecall = true +} + func (b *Buffer) AddImage(image *Image) { - if image != nil { - b.Images = append(b.Images, *image) + if image == nil || b.VisionRecall { + return } + b.Images = append(b.Images, *image) + + tokens := image.CountTokens(b.Model) + b.InputTokens += tokens + if b.Charge.IsBillingType(globals.TokenBilling) { - if image != nil { - b.Quota += float32(image.CountTokens(b.Model)) * b.Charge.GetInput() - } + b.Quota += float32(tokens) / 1000 * b.Charge.GetInput() } } @@ -145,6 +165,13 @@ func hitTool(tool globals.ToolCall, tools globals.ToolCalls) (int, *globals.Tool return 0, nil } +func appendTool(tool globals.ToolCall, chunk globals.ToolCall) string { + from := ToString(tool.Function.Arguments) + to := ToString(chunk.Function.Arguments) + + return from + to +} + func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.ToolCalls { if source == nil { return target @@ -157,7 +184,7 @@ func mixTools(source *globals.ToolCalls, target *globals.ToolCalls) *globals.Too idx, hit := hitTool(tool, tools) if hit != nil { - tools[idx].Function.Arguments += tool.Function.Arguments + tools[idx].Function.Arguments = appendTool(tools[idx], tool) } else { tools = append(tools, tool) } @@ -209,6 +236,27 @@ func (b *Buffer) GetCharge() Charge { return b.Charge } +func (b *Buffer) ToChargeInfo() string { + switch b.Charge.GetType() { + case globals.TokenBilling: + return fmt.Sprintf( + "input tokens: %0.4f quota / 1k tokens\n"+ + "output tokens: %0.4f quota / 1k tokens\n", + b.Charge.GetInput(), b.Charge.GetOutput(), + ) + case globals.TimesBilling: + return fmt.Sprintf("%f quota per request\n", b.Charge.GetLimit()) + case globals.NonBilling: + return "no cost" + } + + return "" +} + +func (b *Buffer) SetPrompts(prompts interface{}) { + b.Prompts = ToString(prompts) +} + func (b *Buffer) Read() string { return b.Data } @@ -247,5 +295,49 @@ func (b *Buffer) CountOutputToken(running bool) int { } func (b *Buffer) CountToken() int { - return b.CountInputToken() + b.CountOutputToken(false) + return b.CountInputToken() + b.CountOutputToken(true) +} + +func (b *Buffer) GetDuration() float32 { + if b.StartTime == nil { + return 0 + } + + return float32(time.Since(*b.StartTime).Seconds()) +} + +func (b *Buffer) GetStartTime() *time.Time { + return b.StartTime +} + +func (b *Buffer) GetPrompts() string { + return b.Prompts +} + +func (b *Buffer) GetTokenName() string { + if len(b.TokenName) == 0 { + return globals.WebTokenType + } + + return b.TokenName +} + +func (b *Buffer) SetTokenName(tokenName string) { + b.TokenName = tokenName +} + +func (b *Buffer) GetRecordPrompts() string { + if !globals.AcceptPromptStore { + return "" + } + + return b.GetPrompts() +} + +func (b *Buffer) GetRecordResponsePrompts() string { + if !globals.AcceptPromptStore { + return "" + } + + return b.Read() } diff --git a/utils/net.go b/utils/net.go index 015d95c..350049c 100644 --- a/utils/net.go +++ b/utils/net.go @@ -183,6 +183,28 @@ func Post(uri string, headers map[string]string, body interface{}, config ...glo return data, err } +func ToString(data interface{}) string { + switch v := data.(type) { + case string: + return v + case int, int8, int16, int32, int64: + return fmt.Sprintf("%d", v) + case uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%d", v) + case float32, float64: + return fmt.Sprintf("%f", v) + case bool: + return fmt.Sprintf("%t", v) + default: + data := Marshal(data) + if len(data) > 0 { + return data + } + + return fmt.Sprintf("%v", data) + } +} + func PostRaw(uri string, headers map[string]string, body interface{}, config ...globals.ProxyConfig) (data string, err error) { buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body), config) if err != nil {