mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 13:00:14 +09:00
61 lines
2.0 KiB
Go
61 lines
2.0 KiB
Go
package utils
|
|
|
|
import (
|
|
"chat/types"
|
|
"fmt"
|
|
"github.com/pkoukk/tiktoken-go"
|
|
"strings"
|
|
)
|
|
|
|
// Using https://github.com/pkoukk/tiktoken-go
// To count number of tokens of openai chat messages
// OpenAI Cookbook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb

// GetWeightByModel returns the fixed per-message token overhead for the
// given OpenAI chat model, following the OpenAI cookbook recipe.
// It panics for models it does not recognize at all.
func GetWeightByModel(model string) int {
	switch model {
	case "gpt-3.5-turbo-0301":
		// every message follows <|start|>{role/name}\n{content}<|end|>\n
		return 4
	case "gpt-3.5-turbo-0613",
		"gpt-3.5-turbo-16k-0613",
		"gpt-4-0314",
		"gpt-4-32k-0314",
		"gpt-4-0613",
		"gpt-4-32k-0613":
		return 3
	}

	// Unknown snapshot: fall back on the model family.
	switch {
	case strings.Contains(model, "gpt-3.5-turbo"):
		// warning: gpt-3.5-turbo may update over time; assume gpt-3.5-turbo-0613.
		return 3
	case strings.Contains(model, "gpt-4"):
		// warning: gpt-4 may update over time; assume gpt-4-0613.
		return 3
	default:
		// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
		panic(fmt.Errorf("not implemented for model %s", model))
	}
}
|
|
func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (tokens int) {
|
|
weight := GetWeightByModel(model)
|
|
tkm, err := tiktoken.EncodingForModel(model)
|
|
if err != nil {
|
|
// can not encode messages, use length of messages as a proxy for number of tokens
|
|
// using rune instead of byte to account for unicode characters (e.g. emojis, non-english characters)
|
|
|
|
data := Marshal(messages)
|
|
return len([]rune(data)) * weight
|
|
}
|
|
|
|
for _, message := range messages {
|
|
tokens += weight
|
|
tokens += len(tkm.Encode(message.Content, nil, nil))
|
|
tokens += len(tkm.Encode(message.Role, nil, nil))
|
|
}
|
|
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
|
|
return tokens
|
|
}
|
|
|
|
func CountTokenPrice(messages []types.ChatGPTMessage) int {
|
|
return NumTokensFromMessages(messages, "gpt-4")
|
|
}
|