feat: update and optimize tokenizer performance (#191)

Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
Deng Junhai 2024-06-22 01:59:27 +08:00
parent 2024302316
commit c81b599e90
4 changed files with 34 additions and 26 deletions

View File

@ -2,10 +2,11 @@ package globals
import ( import (
"fmt" "fmt"
"strings"
"github.com/natefinch/lumberjack" "github.com/natefinch/lumberjack"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/viper" "github.com/spf13/viper"
"strings"
) )
const DefaultLoggerFile = "chatnio.log" const DefaultLoggerFile = "chatnio.log"
@ -25,7 +26,7 @@ func (l *AppLogger) Format(entry *logrus.Entry) ([]byte, error) {
) )
if !viper.GetBool("log.ignore_console") { if !viper.GetBool("log.ignore_console") {
fmt.Println(data) fmt.Print(data)
} }
return []byte(data), nil return []byte(data), nil

View File

@ -154,7 +154,7 @@ func sendTranshipmentResponse(c *gin.Context, form RelayForm, messages []globals
}, },
Usage: Usage{ Usage: Usage{
PromptTokens: buffer.CountInputToken(), PromptTokens: buffer.CountInputToken(),
CompletionTokens: buffer.CountOutputToken(), CompletionTokens: buffer.CountOutputToken(false),
TotalTokens: buffer.CountToken(), TotalTokens: buffer.CountToken(),
}, },
Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())), Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())),
@ -205,7 +205,7 @@ func getStreamTranshipmentForm(id string, created int64, form RelayForm, data *g
}, },
Usage: Usage{ Usage: Usage{
PromptTokens: buffer.CountInputToken(), PromptTokens: buffer.CountInputToken(),
CompletionTokens: buffer.CountOutputToken(), CompletionTokens: buffer.CountOutputToken(true),
TotalTokens: buffer.CountToken(), TotalTokens: buffer.CountToken(),
}, },
Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())), Quota: utils.Multi[*float32](form.Official, nil, utils.ToPtr(buffer.GetQuota())),

View File

@ -57,7 +57,7 @@ func initInputToken(model string, history []globals.Message) int {
}) })
} }
return CountTokenPrice(history, model) return NumTokensFromMessages(history, model, false)
} }
func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer { func NewBuffer(model string, history []globals.Message, charge Charge) *Buffer {
@ -79,7 +79,7 @@ func (b *Buffer) GetCursor() int {
} }
func (b *Buffer) GetQuota() float32 { func (b *Buffer) GetQuota() float32 {
return b.Quota + CountOutputToken(b.Charge, b.Model, b.ReadTimes()) return b.Quota + CountOutputToken(b.Charge, b.CountOutputToken(false))
} }
func (b *Buffer) Write(data string) string { func (b *Buffer) Write(data string) string {
@ -197,11 +197,6 @@ func (b *Buffer) IsFunctionCalling() bool {
return b.FunctionCall != nil || b.ToolCalls != nil return b.FunctionCall != nil || b.ToolCalls != nil
} }
func (b *Buffer) WriteBytes(data []byte) []byte {
b.Write(string(data))
return data
}
func (b *Buffer) IsEmpty() bool { func (b *Buffer) IsEmpty() bool {
return b.Cursor == 0 && !b.IsFunctionCalling() return b.Cursor == 0 && !b.IsFunctionCalling()
} }
@ -241,10 +236,16 @@ func (b *Buffer) CountInputToken() int {
return b.InputTokens return b.InputTokens
} }
func (b *Buffer) CountOutputToken() int { func (b *Buffer) CountOutputToken(running bool) int {
return b.ReadTimes() * GetWeightByModel(b.Model) if running {
// performance optimization:
// if the buffer is still running, the output tokens are approximated by b.Times instead of being re-tokenized
return b.Times
}
return NumTokensFromResponse(b.Read(), b.Model)
} }
func (b *Buffer) CountToken() int { func (b *Buffer) CountToken() int {
return b.CountInputToken() + b.CountOutputToken() return b.CountInputToken() + b.CountOutputToken(false)
} }

View File

@ -3,8 +3,9 @@ package utils
import ( import (
"chat/globals" "chat/globals"
"fmt" "fmt"
"github.com/pkoukk/tiktoken-go"
"strings" "strings"
"github.com/pkoukk/tiktoken-go"
) )
// Using https://github.com/pkoukk/tiktoken-go // Using https://github.com/pkoukk/tiktoken-go
@ -45,9 +46,10 @@ func GetWeightByModel(model string) int {
} }
} }
} }
func NumTokensFromMessages(messages []globals.Message, model string) (tokens int) { func NumTokensFromMessages(messages []globals.Message, model string, responseType bool) (tokens int) {
tokensPerMessage := GetWeightByModel(model) tokensPerMessage := GetWeightByModel(model)
tkm, err := tiktoken.EncodingForModel(model) tkm, err := tiktoken.EncodingForModel(model)
if err != nil { if err != nil {
// the method above was deprecated, use the recall method instead // the method above was deprecated, use the recall method instead
// can not encode messages, use length of messages as a proxy for number of tokens // can not encode messages, use length of messages as a proxy for number of tokens
@ -59,16 +61,20 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int
if globals.DebugMode { if globals.DebugMode {
globals.Debug(fmt.Sprintf("[tiktoken] error encoding messages: %s (model: %s), using default model instead", err, model)) globals.Debug(fmt.Sprintf("[tiktoken] error encoding messages: %s (model: %s), using default model instead", err, model))
} }
return NumTokensFromMessages(messages, globals.GPT3Turbo0613) return NumTokensFromMessages(messages, globals.GPT3Turbo0613, responseType)
} }
for _, message := range messages { for _, message := range messages {
tokens += tokens += len(tkm.Encode(message.Content, nil, nil))
len(tkm.Encode(message.Content, nil, nil)) +
len(tkm.Encode(message.Role, nil, nil)) + if !responseType {
tokensPerMessage tokens += len(tkm.Encode(message.Role, nil, nil)) + tokensPerMessage
}
}
if !responseType {
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
} }
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
if globals.DebugMode { if globals.DebugMode {
globals.Debug(fmt.Sprintf("[tiktoken] num tokens from messages: %d (tokens per message: %d, model: %s)", tokens, tokensPerMessage, model)) globals.Debug(fmt.Sprintf("[tiktoken] num tokens from messages: %d (tokens per message: %d, model: %s)", tokens, tokensPerMessage, model))
@ -76,8 +82,8 @@ func NumTokensFromMessages(messages []globals.Message, model string) (tokens int
return tokens return tokens
} }
func CountTokenPrice(messages []globals.Message, model string) int { func NumTokensFromResponse(response string, model string) int {
return NumTokensFromMessages(messages, model) * GetWeightByModel(model) return NumTokensFromMessages([]globals.Message{{Content: response}}, model, true)
} }
func CountInputQuota(charge Charge, token int) float32 { func CountInputQuota(charge Charge, token int) float32 {
@ -88,10 +94,10 @@ func CountInputQuota(charge Charge, token int) float32 {
return 0 return 0
} }
func CountOutputToken(charge Charge, model string, token int) float32 { func CountOutputToken(charge Charge, token int) float32 {
switch charge.GetType() { switch charge.GetType() {
case globals.TokenBilling: case globals.TokenBilling:
return float32(token*GetWeightByModel(model)) / 1000 * charge.GetOutput() return float32(token) / 1000 * charge.GetOutput()
case globals.TimesBilling: case globals.TimesBilling:
return charge.GetOutput() return charge.GetOutput()
default: default: