mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
feat: update and optimize tokenizer performance (#191)
Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
parent
2024302316
commit
4c3843b3be
@ -2,10 +2,11 @@ package globals
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/natefinch/lumberjack"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const DefaultLoggerFile = "chatnio.log"
|
||||
@ -25,7 +26,7 @@ func (l *AppLogger) Format(entry *logrus.Entry) ([]byte, error) {
|
||||
)
|
||||
|
||||
if !viper.GetBool("log.ignore_console") {
|
||||
fmt.Println(data)
|
||||
fmt.Print(data)
|
||||
}
|
||||
|
||||
return []byte(data), nil
|
||||
|
@ -197,11 +197,6 @@ func (b *Buffer) IsFunctionCalling() bool {
|
||||
return b.FunctionCall != nil || b.ToolCalls != nil
|
||||
}
|
||||
|
||||
func (b *Buffer) WriteBytes(data []byte) []byte {
|
||||
b.Write(string(data))
|
||||
return data
|
||||
}
|
||||
|
||||
func (b *Buffer) IsEmpty() bool {
|
||||
return b.Cursor == 0 && !b.IsFunctionCalling()
|
||||
}
|
||||
@ -237,12 +232,12 @@ func (b *Buffer) SetInputTokens(tokens int) {
|
||||
b.InputTokens = tokens
|
||||
}
|
||||
|
||||
func (b *Buffer) CountInputToken() int {
|
||||
return b.InputTokens
|
||||
func (b *Buffer) CountOutputToken() int {
|
||||
return b.ReadTimes() * GetWeightByModel(b.Model)
|
||||
}
|
||||
|
||||
func (b *Buffer) CountOutputToken() int {
|
||||
return b.ReadTimes() * GetWeightByModel(b.Model)
|
||||
return b.CountInputToken() + b.CountOutputToken()
|
||||
}
|
||||
|
||||
func (b *Buffer) CountToken() int {
|
||||
|
@ -3,8 +3,9 @@ package utils
|
||||
import (
|
||||
"chat/globals"
|
||||
"fmt"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"strings"
|
||||
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
)
|
||||
|
||||
// Using https://github.com/pkoukk/tiktoken-go
|
||||
@ -45,9 +46,10 @@ func GetWeightByModel(model string) int {
|
||||
}
|
||||
}
|
||||
}
|
||||
func NumTokensFromMessages(messages []globals.Message, model string) (tokens int) {
|
||||
func NumTokensFromMessages(messages []globals.Message, model string, responseType bool) (tokens int) {
|
||||
tokensPerMessage := GetWeightByModel(model)
|
||||
tkm, err := tiktoken.EncodingForModel(model)
|
||||
|
||||
if err != nil {
|
||||
// the method above was deprecated, use the recall method instead
|
||||
// can not encode messages, use length of messages as a proxy for number of tokens
|
||||
@ -59,16 +61,20 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int
|
||||
if globals.DebugMode {
|
||||
globals.Debug(fmt.Sprintf("[tiktoken] error encoding messages: %s (model: %s), using default model instead", err, model))
|
||||
}
|
||||
return NumTokensFromMessages(messages, globals.GPT3Turbo0613)
|
||||
return NumTokensFromMessages(messages, globals.GPT3Turbo0613, responseType)
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
tokens +=
|
||||
len(tkm.Encode(message.Content, nil, nil)) +
|
||||
len(tkm.Encode(message.Role, nil, nil)) +
|
||||
tokensPerMessage
|
||||
tokens += len(tkm.Encode(message.Content, nil, nil))
|
||||
|
||||
if !responseType {
|
||||
tokens += len(tkm.Encode(message.Role, nil, nil)) + tokensPerMessage
|
||||
}
|
||||
}
|
||||
|
||||
if !responseType {
|
||||
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
|
||||
}
|
||||
|
||||
if globals.DebugMode {
|
||||
globals.Debug(fmt.Sprintf("[tiktoken] num tokens from messages: %d (tokens per message: %d, model: %s)", tokens, tokensPerMessage, model))
|
||||
@ -76,8 +82,8 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTokenPrice(messages []globals.Message, model string) int {
|
||||
return NumTokensFromMessages(messages, model) * GetWeightByModel(model)
|
||||
func NumTokensFromResponse(response string, model string) int {
|
||||
return NumTokensFromMessages([]globals.Message{{Content: response}}, model, true)
|
||||
}
|
||||
|
||||
func CountInputQuota(charge Charge, token int) float32 {
|
||||
@ -88,10 +94,10 @@ func CountInputQuota(charge Charge, token int) float32 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func CountOutputToken(charge Charge, model string, token int) float32 {
|
||||
func CountOutputToken(charge Charge, token int) float32 {
|
||||
switch charge.GetType() {
|
||||
case globals.TokenBilling:
|
||||
return float32(token*GetWeightByModel(model)) / 1000 * charge.GetOutput()
|
||||
return float32(token) / 1000 * charge.GetOutput()
|
||||
case globals.TimesBilling:
|
||||
return charge.GetOutput()
|
||||
default:
|
||||
|
Loading…
Reference in New Issue
Block a user