update token counter and close channel when buffer end

This commit is contained in:
Zhang Minghan 2023-10-02 21:56:55 +08:00
parent a8b1082261
commit 83133b526f
2 changed files with 16 additions and 12 deletions

View File

@ -139,9 +139,9 @@ func sendTranshipmentResponse(c *gin.Context, form TranshipmentForm, id string,
},
},
Usage: Usage{
PromptTokens: int(buffer.CountInputToken()),
CompletionTokens: int(buffer.CountOutputToken()),
TotalTokens: int(buffer.CountToken()),
PromptTokens: buffer.CountInputToken(),
CompletionTokens: buffer.CountOutputToken(),
TotalTokens: buffer.CountToken(),
},
Quota: buffer.GetQuota(),
})
@ -164,9 +164,9 @@ func getStreamTranshipmentForm(id string, created int64, form TranshipmentForm,
},
},
Usage: Usage{
PromptTokens: int(buffer.CountInputToken()),
CompletionTokens: int(buffer.CountOutputToken()),
TotalTokens: int(buffer.CountToken()),
PromptTokens: buffer.CountInputToken(),
CompletionTokens: buffer.CountOutputToken(),
TotalTokens: buffer.CountToken(),
},
Quota: buffer.GetQuota(),
}
@ -183,16 +183,18 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
Reversible: reversible && globals.IsGPT4Model(form.Model),
Token: form.MaxTokens,
}, func(data string) error {
channel <- getStreamTranshipmentForm(id, created, form, data, buffer, false)
channel <- getStreamTranshipmentForm(id, created, form, buffer.Write(data), buffer, false)
return nil
}); err != nil {
channel <- getStreamTranshipmentForm(id, created, form, fmt.Sprintf("Error: %s", err.Error()), buffer, true)
CollectQuota(c, user, buffer.GetQuota(), reversible)
close(channel)
return
}
channel <- getStreamTranshipmentForm(id, created, form, "", buffer, true)
CollectQuota(c, user, buffer.GetQuota(), reversible)
close(channel)
return
}()
@ -201,6 +203,8 @@ func sendStreamTranshipmentResponse(c *gin.Context, form TranshipmentForm, id st
c.SSEvent("message", resp)
return true
}
w.Write([]byte("[DATA: DONE]"))
return false
})
}

View File

@ -80,14 +80,14 @@ func (b *Buffer) ReadHistory() []globals.Message {
return b.History
}
func (b *Buffer) CountInputToken() float32 {
return CountInputToken(b.Model, b.ReadHistory())
func (b *Buffer) CountInputToken() int {
return GetWeightByModel(b.Model) * NumTokensFromMessages(b.History, b.Model)
}
func (b *Buffer) CountOutputToken() float32 {
return CountOutputToken(b.Model, b.ReadTimes())
func (b *Buffer) CountOutputToken() int {
return b.ReadTimes() * GetWeightByModel(b.Model)
}
func (b *Buffer) CountToken() float32 {
func (b *Buffer) CountToken() int {
return b.CountInputToken() + b.CountOutputToken()
}