diff --git a/manager/transhipment.go b/manager/transhipment.go index f4e21da..c5e5bc9 100644 --- a/manager/transhipment.go +++ b/manager/transhipment.go @@ -139,9 +139,9 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string, }, }, Usage: Usage{ - PromptTokens: int(buffer.CountInputToken()), - CompletionTokens: int(buffer.CountOutputToken()), - TotalTokens: int(buffer.CountToken()), + PromptTokens: buffer.CountInputToken(), + CompletionTokens: buffer.CountOutputToken(), + TotalTokens: buffer.CountToken(), }, Quota: buffer.GetQuota(), }) @@ -164,9 +164,9 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm, }, }, Usage: Usage{ - PromptTokens: int(buffer.CountInputToken()), - CompletionTokens: int(buffer.CountOutputToken()), - TotalTokens: int(buffer.CountToken()), + PromptTokens: buffer.CountInputToken(), + CompletionTokens: buffer.CountOutputToken(), + TotalTokens: buffer.CountToken(), }, Quota: buffer.GetQuota(), } @@ -183,16 +183,18 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st Reversible: reversible && globals.IsGPT4Model(form.Model), Token: form.MaxTokens, }, func(data string) error { - channel <- getStreamTranshipmentForm(id, created, form, data, buffer, false) + channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false) return nil }); err != nil { channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true) CollectQuota(c, user, buffer.GetQuota(), reversible) + close(channel) return } channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true) CollectQuota(c, user, buffer.GetQuota(), reversible) + close(channel) return }() @@ -201,6 +203,8 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st c.SSEvent("message", resp) return true } + + w.Write([]byte("[DATA: DONE]")) return false }) } diff --git a/utils/buffer.go b/utils/buffer.go index 89908b2..184e136 100644 --- a/utils/buffer.go +++ b/utils/buffer.go @@ -80,14 +80,14 @@ func (b *Buffer) ReadHistory() []globals.Message { return b.History } -func (b *Buffer) CountInputToken() float32 { - return CountInputToken(b.Model, b.ReadHistory()) +func (b *Buffer) CountInputToken() int { + return GetWeightByModel(b.Model) * NumTokensFromMessages(b.History, b.Model) } -func (b *Buffer) CountOutputToken() float32 { - return CountOutputToken(b.Model, b.ReadTimes()) +func (b *Buffer) CountOutputToken() int { + return b.ReadTimes() * GetWeightByModel(b.Model) } -func (b *Buffer) CountToken() float32 { +func (b *Buffer) CountToken() int { return b.CountInputToken() + b.CountOutputToken() }