update token counter and close channel when buffer end

This commit is contained in:
Zhang Minghan 2023-10-02 21:56:55 +08:00
parent a8b1082261
commit 83133b526f
2 changed files with 16 additions and 12 deletions

View File

@ -139,9 +139,9 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
}, },
}, },
Usage: Usage{ Usage: Usage{
PromptTokens: int(buffer.CountInputToken()), PromptTokens: buffer.CountInputToken(),
CompletionTokens: int(buffer.CountOutputToken()), CompletionTokens: buffer.CountOutputToken(),
TotalTokens: int(buffer.CountToken()), TotalTokens: buffer.CountToken(),
}, },
Quota: buffer.GetQuota(), Quota: buffer.GetQuota(),
}) })
@ -164,9 +164,9 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
}, },
}, },
Usage: Usage{ Usage: Usage{
PromptTokens: int(buffer.CountInputToken()), PromptTokens: buffer.CountInputToken(),
CompletionTokens: int(buffer.CountOutputToken()), CompletionTokens: buffer.CountOutputToken(),
TotalTokens: int(buffer.CountToken()), TotalTokens: buffer.CountToken(),
}, },
Quota: buffer.GetQuota(), Quota: buffer.GetQuota(),
} }
@ -183,16 +183,18 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
Reversible: reversible && globals.IsGPT4Model(form.Model), Reversible: reversible && globals.IsGPT4Model(form.Model),
Token: form.MaxTokens, Token: form.MaxTokens,
}, func(data string) error { }, func(data string) error {
channel <- getStreamTranshipmentForm(id, created, form, data, buffer, false) channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
return nil return nil
}); err != nil { }); err != nil {
channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true) channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
CollectQuota(c, user, buffer.GetQuota(), reversible) CollectQuota(c, user, buffer.GetQuota(), reversible)
close(channel)
return return
} }
channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true) channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
CollectQuota(c, user, buffer.GetQuota(), reversible) CollectQuota(c, user, buffer.GetQuota(), reversible)
close(channel)
return return
}() }()
@ -201,6 +203,8 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
c.SSEvent("message", resp) c.SSEvent("message", resp)
return true return true
} }
w.Write([]byte("[DATA: DONE]"))
return false return false
}) })
} }

View File

@ -80,14 +80,14 @@ func (b *Buffer) ReadHistory() []globals.Message {
return b.History return b.History
} }
func (b *Buffer) CountInputToken() float32 { func (b *Buffer) CountInputToken() int {
return CountInputToken(b.Model, b.ReadHistory()) return GetWeightByModel(b.Model) * NumTokensFromMessages(b.History, b.Model)
} }
func (b *Buffer) CountOutputToken() float32 { func (b *Buffer) CountOutputToken() int {
return CountOutputToken(b.Model, b.ReadTimes()) return b.ReadTimes() * GetWeightByModel(b.Model)
} }
func (b *Buffer) CountToken() float32 { func (b *Buffer) CountToken() int {
return b.CountInputToken() + b.CountOutputToken() return b.CountInputToken() + b.CountOutputToken()
} }