diff --git a/api/buffer.go b/api/buffer.go
index 8e4872b..42c634c 100644
--- a/api/buffer.go
+++ b/api/buffer.go
@@ -3,23 +3,24 @@ package api
import (
"chat/auth"
"chat/types"
- "chat/utils"
)
type Buffer struct {
- Enable bool `json:"enable"`
+ Model string `json:"model"`
Quota float32 `json:"quota"`
Data string `json:"data"`
Cursor int `json:"cursor"`
Times int `json:"times"`
}
-func NewBuffer(enable bool, history []types.ChatGPTMessage) *Buffer {
- buffer := &Buffer{Data: "", Cursor: 0, Times: 0, Enable: enable}
- if enable {
- buffer.Quota = auth.CountInputToken(utils.CountTokenPrice(history))
+func NewBuffer(model string, history []types.ChatGPTMessage) *Buffer {
+ return &Buffer{
+ Data: "",
+ Cursor: 0,
+ Times: 0,
+ Model: model,
+ Quota: auth.CountInputToken(model, history),
}
- return buffer
}
func (b *Buffer) GetCursor() int {
@@ -27,10 +28,7 @@ func (b *Buffer) GetCursor() int {
}
func (b *Buffer) GetQuota() float32 {
- if !b.Enable {
- return 0.
- }
- return b.Quota + auth.CountOutputToken(b.ReadTimes())
+ return b.Quota + auth.CountOutputToken(b.Model, b.ReadTimes())
}
func (b *Buffer) Write(data string) string {
diff --git a/api/chat.go b/api/chat.go
index 0b66fa5..da084ad 100644
--- a/api/chat.go
+++ b/api/chat.go
@@ -29,15 +29,15 @@ func SendSegmentMessage(conn *websocket.Conn, message interface{}) {
_ = conn.WriteMessage(websocket.TextMessage, []byte(utils.ToJson(message)))
}
-func GetErrorQuota(isGPT4 bool) float32 {
- if isGPT4 {
+func GetErrorQuota(model string) float32 {
+ if types.IsGPT4Model(model) {
return -0xe // special value for error
} else {
return 0
}
}
-func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
+func GetTextSegment(instance *conversation.Conversation) (string, []types.ChatGPTMessage) {
var keyword string
var segment []types.ChatGPTMessage
@@ -46,11 +46,16 @@ func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.
} else {
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
}
+ return keyword, segment
+}
+func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
+ keyword, segment := GetTextSegment(instance)
SendSegmentMessage(conn, types.ChatSegmentResponse{Keyword: keyword, End: false})
- isProPlan := auth.CanEnableSubscription(db, cache, user)
- if instance.IsEnableGPT4() && (!isProPlan) && (!auth.CanEnableGPT4(db, user)) {
+ model := instance.GetModel()
+ useReverse := auth.CanEnableSubscription(db, cache, user)
+ if !auth.CanEnableModelWithSubscription(db, user, model, useReverse) {
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: defaultQuotaMessage,
Quota: 0,
@@ -59,9 +64,9 @@ func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.
return defaultQuotaMessage
}
- buffer := NewBuffer(instance.IsEnableGPT4(), segment)
- StreamRequest(instance.IsEnableGPT4(), isProPlan, segment,
- utils.Multi(instance.IsEnableGPT4() || isProPlan, -1, 2000),
+ buffer := NewBuffer(model, segment)
+ StreamRequest(model, useReverse, segment,
+ utils.Multi(types.IsGPT4Model(model) || useReverse, -1, 2000),
func(resp string) {
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: buffer.Write(resp),
@@ -70,19 +75,19 @@ func TextChat(db *sql.DB, cache *redis.Client, user *auth.User, conn *websocket.
})
})
if buffer.IsEmpty() {
- if isProPlan {
+ if useReverse {
auth.DecreaseSubscriptionUsage(cache, user)
}
SendSegmentMessage(conn, types.ChatSegmentResponse{
Message: defaultErrorMessage,
- Quota: GetErrorQuota(instance.IsEnableGPT4()),
+ Quota: GetErrorQuota(model),
End: true,
})
return defaultErrorMessage
}
// collect quota
- if !isProPlan {
+ if !useReverse {
user.UseQuota(db, buffer.GetQuota())
}
SendSegmentMessage(conn, types.ChatSegmentResponse{End: true, Quota: buffer.GetQuota()})
diff --git a/api/stream.go b/api/stream.go
index f46db29..7013f58 100644
--- a/api/stream.go
+++ b/api/stream.go
@@ -105,14 +105,25 @@ func NativeStreamRequest(model string, endpoint string, apikeys string, messages
}
}
-func StreamRequest(enableGPT4 bool, isProPlan bool, messages []types.ChatGPTMessage, token int, callback func(string)) {
- if enableGPT4 {
- if isProPlan {
+func StreamRequest(model string, enableReverse bool, messages []types.ChatGPTMessage, token int, callback func(string)) {
+ switch model {
+ case types.GPT4,
+ types.GPT40314,
+ types.GPT40613:
+ if enableReverse {
NativeStreamRequest(viper.GetString("openai.reverse"), viper.GetString("openai.pro_endpoint"), viper.GetString("openai.pro"), messages, token, callback)
} else {
- NativeStreamRequest("gpt-4", viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
+ NativeStreamRequest(model, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
}
- } else {
- NativeStreamRequest("gpt-3.5-turbo-16k-0613", viper.GetString("openai.user_endpoint"), viper.GetString("openai.user"), messages, token, callback)
+ case types.GPT432k,
+ types.GPT432k0613,
+ types.GPT432k0314:
+ NativeStreamRequest(model, viper.GetString("openai.gpt4_endpoint"), viper.GetString("openai.gpt4"), messages, token, callback)
+ case types.GPT3Turbo16k,
+ types.GPT3Turbo16k0301,
+ types.GPT3Turbo16k0613:
+ NativeStreamRequest(types.GPT3Turbo16k, viper.GetString("openai.user_endpoint"), viper.GetString("openai.user"), messages, token, callback)
+ default:
+ NativeStreamRequest(types.GPT3Turbo, viper.GetString("openai.anonymous_endpoint"), viper.GetString("openai.anonymous"), messages, token, callback)
}
}
diff --git a/app/src/i18n.ts b/app/src/i18n.ts
index 688a449..026a363 100644
--- a/app/src/i18n.ts
+++ b/app/src/i18n.ts
@@ -62,7 +62,6 @@ const resources = {
buy: "Buy {{amount}} points",
dalle: "DALL·E Image Generator",
"dalle-free": "5 free quotas per day",
- gpt4: "GPT-4",
flex: "Flexible Billing",
input: "Input",
output: "Output",
@@ -106,7 +105,7 @@ const resources = {
"free-web": "web searching feature",
"free-conversation": "conversation storage",
"free-api": "API calls",
- "pro-gpt4": "GPT-4 10 requests per day",
+ "pro-gpt4": "GPT-4 50 requests per day",
"pro-dalle": "50 quotas per day",
"pro-service": "Priority Service Support",
"pro-thread": "Concurrency Increase",
@@ -207,7 +206,6 @@ const resources = {
buy: "购买 {{amount}} 点数",
dalle: "DALL·E AI 绘图",
"dalle-free": "每天 5 次免费绘图配额",
- gpt4: "GPT-4",
flex: "灵活计费",
input: "输入",
output: "输出",
@@ -250,7 +248,7 @@ const resources = {
"free-web": "联网搜索功能",
"free-conversation": "对话存储记录",
"free-api": "API 调用",
- "pro-gpt4": "GPT-4 每日请求 10 次",
+ "pro-gpt4": "GPT-4 每日请求 50 次",
"pro-dalle": "每日 50 次绘图",
"pro-service": "优先服务支持",
"pro-thread": "并发数提升",
@@ -354,7 +352,6 @@ const resources = {
buy: "Купить {{amount}} очков",
dalle: "Генератор изображений DALL·E",
"dalle-free": "5 бесплатных квот в день",
- gpt4: "GPT-4",
flex: "Гибкая тарификация",
input: "Вход",
output: "Выход",
@@ -399,7 +396,7 @@ const resources = {
"free-web": "веб-поиск",
"free-conversation": "хранение разговоров",
"free-api": "API вызовы",
- "pro-gpt4": "GPT-4 10 запросов в день",
+ "pro-gpt4": "GPT-4 50 запросов в день",
"pro-dalle": "50 квот в день",
"pro-service": "Приоритетная служба поддержки",
"pro-thread": "Увеличение параллелизма",
diff --git a/app/src/routes/Quota.tsx b/app/src/routes/Quota.tsx
index 287f211..ca971e7 100644
--- a/app/src/routes/Quota.tsx
+++ b/app/src/routes/Quota.tsx
@@ -248,7 +248,7 @@ function Quota() {
-
{t("buy.gpt4")}
+
GPT-4
{t("buy.flex")}
@@ -276,6 +276,37 @@ function Quota() {
4.3 / 1k token
+
+
+
+
GPT-4-32K
+
+
+ {t("buy.flex")}
+
+
+
+
+
+ {t("buy.input")}
+
+
+
+
+ 4.2 / 1k token
+
+
+
+
+
+ {t("buy.output")}
+
+
+
+
+ 8.6 / 1k token
+
+
diff --git a/app/src/routes/Subscription.tsx b/app/src/routes/Subscription.tsx
index cb2e984..e5afe29 100644
--- a/app/src/routes/Subscription.tsx
+++ b/app/src/routes/Subscription.tsx
@@ -46,6 +46,8 @@ import { buySubscription } from "../conversation/addition.ts";
function calc_prize(month: number): number {
if (month >= 12) {
+ return 8 * month * 0.8;
+ } else if (month >= 6) {
return 8 * month * 0.9;
}
return 8 * month;
@@ -101,17 +103,22 @@ function Upgrade({ children }: UpgradeProps) {
{t(`sub.time.1`)}
{t(`sub.time.3`)}
- {t(`sub.time.6`)}
+
+ {t(`sub.time.6`)}
+
+ {t(`percent`, { cent: 9 })}
+
+
{t(`sub.time.12`)}
- {t(`percent`, { cent: 9 })}
+ {t(`percent`, { cent: 8 })}
- {t("sub.price", { price: calc_prize(month) })}
+ {t("sub.price", { price: calc_prize(month).toFixed(2) })}
diff --git a/auth/subscription.go b/auth/subscription.go
index 4bc3db6..c749aca 100644
--- a/auth/subscription.go
+++ b/auth/subscription.go
@@ -29,7 +29,7 @@ func BuySubscription(db *sql.DB, user *User, month int) bool {
func IncreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
today := time.Now().Format("2006-01-02")
- return utils.IncrWithLimit(cache, fmt.Sprintf(":subscription-usage:%s:%d", today, user.ID), 1, 999, 60*60*24) // 1 day
+ return utils.IncrWithLimit(cache, fmt.Sprintf(":subscription-usage:%s:%d", today, user.ID), 1, 50, 60*60*24) // 1 day
}
func DecreaseSubscriptionUsage(cache *redis.Client, user *User) bool {
diff --git a/auth/usage.go b/auth/usage.go
index 3f66f6b..567fcfb 100644
--- a/auth/usage.go
+++ b/auth/usage.go
@@ -1,6 +1,8 @@
package auth
import (
+ "chat/types"
+ "chat/utils"
"database/sql"
)
@@ -13,18 +15,46 @@ import (
// $0.03 / 1K tokens $0.06 / 1K tokens
// ¥0.21 / 1K tokens ¥0.43 / 1K tokens
// 2.1 nio / 1K tokens 4.3 nio / 1K tokens
+//
+// GPT-4 price (32k-context)
+// Input Output
+// $0.06 / 1K tokens $0.12 / 1K tokens
+// ¥0.43 / 1K tokens ¥0.86 / 1K tokens
+// 4.3 nio / 1K tokens 8.6 nio / 1K tokens
// Dalle price (512x512)
// $0.018 / per image
// ¥0.13 / per image
// 1 nio / per image
-func CountInputToken(n int) float32 {
- return float32(n) / 1000 * 2.1
+func CountInputToken(model string, v []types.ChatGPTMessage) float32 {
+ switch model {
+ case types.GPT3Turbo:
+ return 0
+ case types.GPT3Turbo16k:
+ return 0
+	case types.GPT4, types.GPT40314, types.GPT40613:
+		return float32(utils.CountTokenPrice(v, model)) / 1000 * 2.1
+	case types.GPT432k, types.GPT432k0314, types.GPT432k0613:
+		return float32(utils.CountTokenPrice(v, model)) / 1000 * 2.1 * 2
+ default:
+ return 0
+ }
}
-func CountOutputToken(n int) float32 {
- return float32(n) / 1000 * 4.3
+func CountOutputToken(model string, t int) float32 {
+ switch model {
+ case types.GPT3Turbo:
+ return 0
+ case types.GPT3Turbo16k:
+ return 0
+	case types.GPT4, types.GPT40314, types.GPT40613:
+		return float32(t*utils.GetWeightByModel(model)) / 1000 * 4.3
+	case types.GPT432k, types.GPT432k0314, types.GPT432k0613:
+		return float32(t*utils.GetWeightByModel(model)) / 1000 * 8.6
+ default:
+ return 0
+ }
}
func ReduceDalle(db *sql.DB, user *User) bool {
@@ -34,8 +64,24 @@ func ReduceDalle(db *sql.DB, user *User) bool {
return user.UseQuota(db, 1)
}
-func CanEnableGPT4(db *sql.DB, user *User) bool {
- return user.GetQuota(db) >= 5
+func CanEnableModel(db *sql.DB, user *User, model string) bool {
+ switch model {
+ case types.GPT4, types.GPT40613, types.GPT40314:
+ return user.GetQuota(db) >= 5
+ case types.GPT432k, types.GPT432k0613, types.GPT432k0314:
+ return user.GetQuota(db) >= 50
+ default:
+ return true
+ }
+}
+
+func CanEnableModelWithSubscription(db *sql.DB, user *User, model string, useReverse bool) bool {
+ if utils.Contains(model, types.GPT4Array) {
+ if useReverse {
+ return true
+ }
+ }
+ return CanEnableModel(db, user, model)
}
func BuyQuota(db *sql.DB, user *User, quota int) bool {
diff --git a/conversation/conversation.go b/conversation/conversation.go
index e681a28..1509d22 100644
--- a/conversation/conversation.go
+++ b/conversation/conversation.go
@@ -9,42 +9,42 @@ import (
)
type Conversation struct {
- UserID int64 `json:"user_id"`
- Id int64 `json:"id"`
- Name string `json:"name"`
- Message []types.ChatGPTMessage `json:"message"`
- EnableGPT4 bool `json:"enable_gpt4"`
- EnableWeb bool `json:"enable_web"`
+ UserID int64 `json:"user_id"`
+ Id int64 `json:"id"`
+ Name string `json:"name"`
+ Message []types.ChatGPTMessage `json:"message"`
+ Model string `json:"model"`
+ EnableWeb bool `json:"enable_web"`
}
type FormMessage struct {
Type string `json:"type"` // ping
Message string `json:"message"`
Web bool `json:"web"`
- GPT4 bool `json:"gpt4"`
+ Model string `json:"model"`
}
func NewConversation(db *sql.DB, id int64) *Conversation {
return &Conversation{
- UserID: id,
- Id: GetConversationLengthByUserID(db, id) + 1,
- Name: "new chat",
- Message: []types.ChatGPTMessage{},
- EnableGPT4: false,
- EnableWeb: false,
+ UserID: id,
+ Id: GetConversationLengthByUserID(db, id) + 1,
+ Name: "new chat",
+ Message: []types.ChatGPTMessage{},
+ Model: types.GPT3Turbo,
+ EnableWeb: false,
}
}
-func (c *Conversation) IsEnableGPT4() bool {
- return c.EnableGPT4
+func (c *Conversation) GetModel() string {
+ return c.Model
}
func (c *Conversation) IsEnableWeb() bool {
return c.EnableWeb
}
-func (c *Conversation) SetEnableGPT4(enable bool) {
- c.EnableGPT4 = enable
+func (c *Conversation) SetModel(model string) {
+ c.Model = model
}
func (c *Conversation) SetEnableWeb(enable bool) {
@@ -141,7 +141,7 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
}
c.AddMessageFromUser(form.Message)
- c.SetEnableGPT4(form.GPT4)
+ c.SetModel(form.Model)
c.SetEnableWeb(form.Web)
return form.Message, nil
}
diff --git a/generation/api.go b/generation/api.go
index c7cf81c..8caed43 100644
--- a/generation/api.go
+++ b/generation/api.go
@@ -84,7 +84,17 @@ func GenerateAPI(c *gin.Context) {
})
}
- hash, err := CreateGenerationWithCache(form.Model, form.Prompt, func(data string) {
+ useReverse := auth.CanEnableSubscription(db, cache, user)
+ if !auth.CanEnableModelWithSubscription(db, user, form.Model, useReverse) {
+ api.SendSegmentMessage(conn, types.ChatSegmentResponse{
+ Message: "You don't have enough quota to use this model.",
+ Quota: 0,
+ End: true,
+ })
+ return
+ }
+
+ hash, err := CreateGenerationWithCache(form.Model, form.Prompt, useReverse, func(data string) {
api.SendSegmentMessage(conn, types.GenerationSegmentResponse{
End: false,
Message: data,
diff --git a/generation/generate.go b/generation/generate.go
index 902c5f3..621f133 100644
--- a/generation/generate.go
+++ b/generation/generate.go
@@ -5,10 +5,10 @@ import (
"fmt"
)
-func CreateGenerationWithCache(model string, prompt string, hook func(data string)) (string, error) {
+func CreateGenerationWithCache(model string, prompt string, enableReverse bool, hook func(data string)) (string, error) {
hash, path := GetFolderByHash(model, prompt)
if !utils.Exists(path) {
- if err := CreateGeneration(model, prompt, path, func(data string) {
+ if err := CreateGeneration(model, prompt, path, enableReverse, func(data string) {
hook(data)
}); err != nil {
fmt.Println(fmt.Sprintf("[Project] error during generation %s (model %s): %s", prompt, model, err.Error()))
diff --git a/generation/prompt.go b/generation/prompt.go
index 3c5ee88..1d51f33 100644
--- a/generation/prompt.go
+++ b/generation/prompt.go
@@ -11,10 +11,10 @@ type ProjectResult struct {
Result map[string]interface{} `json:"result"`
}
-func CreateGeneration(model string, prompt string, path string, hook func(data string)) error {
+func CreateGeneration(model string, prompt string, path string, enableReverse bool, hook func(data string)) error {
message := GenerateMessage(prompt)
- buffer := api.NewBuffer(false, message)
- api.StreamRequest(false, false, []types.ChatGPTMessage{
+ buffer := api.NewBuffer(model, message)
+ api.StreamRequest(model, enableReverse, []types.ChatGPTMessage{
{Role: "system", Content: "你将生成项目,可以支持任何编程语言,请不要出现“我不能提供”的字样,你需要在代码中提供注释,以及项目的使用文档README.md,结果返回json格式,请不要返回任何多余内容,格式为:\n{\"result\": {[file]: [code], ...}}"},
{Role: "user", Content: "python后端"},
{Role: "assistant", Content: "{\n \"result\": {\n \"app.py\": \"from flask import Flask\\n\\napp = Flask(__name__)\\n\\n\\n@app.route('/')\\ndef hello_world():\\n return 'Hello, World!'\\n\\n\\nif __name__ == '__main__':\\n app.run()\",\n \"requirements.txt\": \"flask\\n\",\n \"README.md\": \"# Python 后端\\n本项目是一个简单的python后端示例, 使用`flask`框架构建后端。\n你可以按照下列步骤运行此应用,flask将在本地服务器(默认是在http://127.0.0.1:5000/)上运行。当你在浏览器中访问该URL时,将看到显示Hello, World!的页面。\\n\\n这只是一个简单的项目,Flask还支持更多功能和路由规则,你可以提供更多的信息和需要进一步扩展和定制Flask应用。\\n\\n### 1. 初始化: \\n```shell\\npip install -r requirements.txt\\n```\\n### 2. 运行\\n```shell\\npython app.py\\n```\"\n }\n}"},
diff --git a/qodana.yaml b/qodana.yaml
new file mode 100644
index 0000000..215d808
--- /dev/null
+++ b/qodana.yaml
@@ -0,0 +1,29 @@
+#-------------------------------------------------------------------------------#
+# Qodana analysis is configured by qodana.yaml file #
+# https://www.jetbrains.com/help/qodana/qodana-yaml.html #
+#-------------------------------------------------------------------------------#
+version: "1.0"
+
+#Specify inspection profile for code analysis
+profile:
+ name: qodana.starter
+
+#Enable inspections
+#include:
+# - name:
+
+#Disable inspections
+#exclude:
+# - name:
+# paths:
+# -
+
+#Execute shell command before Qodana execution (Applied in CI/CD pipeline)
+#bootstrap: sh ./prepare-qodana.sh
+
+#Install IDE plugins before Qodana execution (Applied in CI/CD pipeline)
+#plugins:
+# - id: #(plugin id can be found at https://plugins.jetbrains.com)
+
+#Specify Qodana linter for analysis (Applied in CI/CD pipeline)
+linter: jetbrains/qodana-go:latest
diff --git a/types/globals.go b/types/globals.go
new file mode 100644
index 0000000..3570a6a
--- /dev/null
+++ b/types/globals.go
@@ -0,0 +1,61 @@
+package types
+
+const (
+	GPT3Turbo        = "gpt-3.5-turbo"
+	GPT3Turbo0613    = "gpt-3.5-turbo-0613"
+	GPT3Turbo0301    = "gpt-3.5-turbo-0301"
+	GPT3Turbo16k     = "gpt-3.5-turbo-16k"
+	GPT3Turbo16k0613 = "gpt-3.5-turbo-16k-0613"
+	GPT3Turbo16k0301 = "gpt-3.5-turbo-16k-0301"
+	GPT4             = "gpt-4"
+	GPT40314         = "gpt-4-0314"
+	GPT40613         = "gpt-4-0613"
+	GPT432k          = "gpt-4-32k"
+	GPT432k0314      = "gpt-4-32k-0314"
+	GPT432k0613      = "gpt-4-32k-0613"
+	Dalle            = "dalle"
+)
+
+var GPT3TurboArray = []string{
+	GPT3Turbo,
+	GPT3Turbo0613,
+	GPT3Turbo0301,
+}
+
+var GPT3Turbo16kArray = []string{
+	GPT3Turbo16k,
+	GPT3Turbo16k0613,
+	GPT3Turbo16k0301,
+}
+
+var GPT4Array = []string{
+	GPT4,
+	GPT40314,
+	GPT40613,
+}
+
+var GPT432kArray = []string{
+	GPT432k,
+	GPT432k0314,
+	GPT432k0613,
+}
+
+// inArray reports whether value is present in arr. It is defined locally
+// instead of using utils.Contains because importing chat/utils here would
+// create an import cycle (chat/utils already imports chat/types).
+func inArray(value string, arr []string) bool {
+	for _, item := range arr {
+		if item == value {
+			return true
+		}
+	}
+	return false
+}
+
+func IsGPT4Model(model string) bool {
+	return inArray(model, GPT4Array) || inArray(model, GPT432kArray)
+}
+
+func IsGPT3TurboModel(model string) bool {
+	return inArray(model, GPT3TurboArray) || inArray(model, GPT3Turbo16kArray)
+}
diff --git a/utils/tokenizer.go b/utils/tokenizer.go
index 5bac7d7..44d8e5c 100644
--- a/utils/tokenizer.go
+++ b/utils/tokenizer.go
@@ -13,22 +13,29 @@ import (
func GetWeightByModel(model string) int {
switch model {
- case "gpt-3.5-turbo-0613",
- "gpt-3.5-turbo-16k-0613",
- "gpt-4-0314",
- "gpt-4-32k-0314",
- "gpt-4-0613",
- "gpt-4-32k-0613":
+ case types.GPT432k,
+ types.GPT432k0613,
+ types.GPT432k0314:
+ return 3 * 10
+ case types.GPT3Turbo,
+ types.GPT3Turbo0613,
+
+ types.GPT3Turbo16k,
+ types.GPT3Turbo16k0613,
+
+ types.GPT4,
+ types.GPT40314,
+ types.GPT40613:
return 3
- case "gpt-3.5-turbo-0301":
+ case types.GPT3Turbo0301, types.GPT3Turbo16k0301:
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
default:
- if strings.Contains(model, "gpt-3.5-turbo") {
+ if strings.Contains(model, types.GPT3Turbo) {
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
- return GetWeightByModel("gpt-3.5-turbo-0613")
- } else if strings.Contains(model, "gpt-4") {
+ return GetWeightByModel(types.GPT3Turbo0613)
+ } else if strings.Contains(model, types.GPT4) {
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
- return GetWeightByModel("gpt-4-0613")
+ return GetWeightByModel(types.GPT40613)
} else {
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
panic(fmt.Errorf("not implemented for model %s", model))
@@ -55,6 +62,6 @@ func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (token
return tokens
}
-func CountTokenPrice(messages []types.ChatGPTMessage) int {
- return NumTokensFromMessages(messages, "gpt-4")
+func CountTokenPrice(messages []types.ChatGPTMessage, model string) int {
+ return NumTokensFromMessages(messages, model)
}