From 8958afca685af9fdff37d30fc81dbd2a99941f74 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Thu, 7 Sep 2023 10:10:13 +0800 Subject: [PATCH] update quota --- api/anonymous.go | 17 ++- api/buffer.go | 73 ++++++++++++ api/chat.go | 34 +++--- api/image.go | 4 + app/src/App.tsx | 4 +- app/src/assets/chat.less | 32 ++++- app/src/assets/main.less | 6 + app/src/components/Message.tsx | 169 +++++++++++++++------------ app/src/conf.ts | 6 +- app/src/conversation/connection.ts | 5 +- app/src/conversation/conversation.ts | 5 +- app/src/conversation/types.ts | 1 + app/src/i18n.ts | 6 +- app/src/routes/Auth.tsx | 4 +- app/src/routes/Home.tsx | 3 +- app/src/store/auth.ts | 5 +- auth/usage.go | 27 ++++- conversation/conversation.go | 12 ++ go.mod | 3 + go.sum | 6 + types/types.go | 7 +- utils/char.go | 8 ++ utils/tokenizer.go | 60 ++++++++++ 23 files changed, 383 insertions(+), 114 deletions(-) create mode 100644 api/buffer.go create mode 100644 utils/tokenizer.go diff --git a/api/anonymous.go b/api/anonymous.go index 622faab..38662ad 100644 --- a/api/anonymous.go +++ b/api/anonymous.go @@ -14,6 +14,7 @@ import ( type AnonymousRequestBody struct { Message string `json:"message" required:"true"` + Web bool `json:"web"` } type AnonymousResponseCache struct { @@ -57,18 +58,22 @@ func TestKey(key string) bool { return res.(map[string]interface{})["choices"] != nil } -func GetAnonymousResponse(message string) (string, string, error) { +func GetAnonymousResponse(message string, web bool) (string, string, error) { + if !web { + resp, err := GetChatGPTResponse([]types.ChatGPTMessage{{Role: "user", Content: message}}, 1000) + return "", resp, err + } keyword, source := ChatWithWeb([]types.ChatGPTMessage{{Role: "user", Content: message}}, false) resp, err := GetChatGPTResponse(source, 1000) return keyword, resp, err } -func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, string, error) { +func GetAnonymousResponseWithCache(c *gin.Context, message string, web bool) (string, string, error) { cache := c.MustGet("cache").(*redis.Client) - res, err := cache.Get(c, fmt.Sprintf(":chatgpt:%s", message)).Result() + res, err := cache.Get(c, fmt.Sprintf(":chatgpt-%v:%s", web, message)).Result() form := utils.UnmarshalJson[AnonymousResponseCache](res) if err != nil || len(res) == 0 || res == "{}" || form.Message == "" { - key, res, err := GetAnonymousResponse(message) + key, res, err := GetAnonymousResponse(message, web) if err != nil { return "", "There was something wrong...", err } @@ -76,7 +81,7 @@ func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, stri cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), utils.ToJson(AnonymousResponseCache{ Keyword: key, Message: res, - }), time.Hour*6) + }), time.Hour*48) return key, res, nil } return form.Keyword, form.Message, nil @@ -103,7 +108,7 @@ func AnonymousAPI(c *gin.Context) { }) return } - key, res, err := GetAnonymousResponseWithCache(c, message) + key, res, err := GetAnonymousResponseWithCache(c, message, body.Web) if err != nil { c.JSON(http.StatusOK, gin.H{ "status": false, diff --git a/api/buffer.go b/api/buffer.go new file mode 100644 index 0000000..0f81d5e --- /dev/null +++ b/api/buffer.go @@ -0,0 +1,73 @@ +package api + +import ( + "chat/auth" + "chat/types" + "chat/utils" +) + +type Buffer struct { + Enable bool `json:"enable"` + Quota float32 `json:"quota"` + Data string `json:"data"` + Cursor int `json:"cursor"` + Times int `json:"times"` +} + +func NewBuffer(enable bool, history []types.ChatGPTMessage) *Buffer { + buffer := &Buffer{Data: "", Cursor: 0, Times: 0, Enable: enable} + if enable { + buffer.Quota = auth.CountInputToken(utils.CountTokenPrice(history)) + } + return buffer +} + +func (b *Buffer) GetCursor() int { + return b.Cursor +} + +func (b *Buffer) GetQuota() float32 { + if !b.Enable { + return 0. + } + return b.Quota + auth.CountOutputToken(b.ReadTimes()) +} + +func (b *Buffer) Write(data string) string { + b.Data += data + b.Cursor += len(data) + b.Times++ + return data +} + +func (b *Buffer) WriteBytes(data []byte) []byte { + b.Data += string(data) + b.Cursor += len(data) + b.Times++ + return data +} + +func (b *Buffer) IsEmpty() bool { + return b.Cursor == 0 +} + +func (b *Buffer) Reset() { + b.Data = "" + b.Cursor = 0 + b.Times = 0 +} + +func (b *Buffer) Read() string { + return b.Data +} + +func (b *Buffer) ReadWithDefault(_default string) string { + if b.IsEmpty() { + return _default + } + return b.Data +} + +func (b *Buffer) ReadTimes() int { + return b.Times +} diff --git a/api/chat.go b/api/chat.go index 86cb817..c016093 100644 --- a/api/chat.go +++ b/api/chat.go @@ -7,7 +7,6 @@ import ( "chat/types" "chat/utils" "database/sql" - "fmt" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "github.com/gorilla/websocket" @@ -15,6 +14,8 @@ import ( "strings" ) +const defaultMessage = "There was something wrong... Please try again later." + type WebsocketAuthForm struct { Token string `json:"token" binding:"required"` Id int64 `json:"id" binding:"required"` @@ -25,39 +26,47 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon } func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string { - keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true) - SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false}) + var keyword string + var segment []types.ChatGPTMessage - msg := "" + if instance.IsEnableWeb() { + keyword, segment = ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true) + } else { + segment = conversation.CopyMessage(instance.GetMessageSegment(12)) + } + + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false}) if instance.IsEnableGPT4() && !auth.ReduceGPT4(db, user) { SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ Message: "You have run out of GPT-4 usage. Please buy more.", + Quota: 0, End: true, }) return "You have run out of GPT-4 usage. Please buy more." } + buffer := NewBuffer(instance.IsEnableGPT4(), segment) StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) { - msg += resp SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ - Message: resp, + Message: buffer.Write(resp), + Quota: buffer.GetQuota(), End: false, }) }) - if msg == "" { - msg = "There was something wrong... Please try again later." + if buffer.IsEmpty() { if instance.IsEnableGPT4() { auth.IncreaseGPT4(db, user, 1) } SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ - Message: msg, + Message: defaultMessage, + Quota: buffer.GetQuota(), End: false, }) } - SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true}) + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true, Quota: buffer.GetQuota()}) - return msg + return buffer.ReadWithDefault(defaultMessage) } func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string { @@ -84,10 +93,9 @@ func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user * return err.Error() } - markdown := fmt.Sprintln("![image](", url, ")") + markdown := GetImageMarkdown(url) SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ Message: markdown, - Keyword: "image", End: true, }) return markdown diff --git a/api/image.go b/api/image.go index 6e501e0..5bfaa61 100644 --- a/api/image.go +++ b/api/image.go @@ -70,3 +70,7 @@ func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *re return GetImageWithCache(context.Background(), prompt, cache) } } + +func GetImageMarkdown(url string) string { + return fmt.Sprintln("![image](", url, ")") +} diff --git a/app/src/App.tsx b/app/src/App.tsx index 4919070..959bc60 100644 --- a/app/src/App.tsx +++ b/app/src/App.tsx @@ -25,7 +25,7 @@ import { DropdownMenuTrigger, } from "./components/ui/dropdown-menu.tsx"; import { Toaster } from "./components/ui/toaster.tsx"; -import { login } from "./conf.ts"; +import {login, tokenField} from "./conf.ts"; import { useTranslation } from "react-i18next"; function Settings() { @@ -67,7 +67,7 @@ function NavBar() { const { t } = useTranslation(); const dispatch = useDispatch(); useEffect(() => { - validateToken(dispatch, localStorage.getItem("token") ?? ""); + validateToken(dispatch, localStorage.getItem(tokenField) ?? ""); }, []); const auth = useSelector(selectAuthenticated); diff --git a/app/src/assets/chat.less b/app/src/assets/chat.less index 906585d..db38a8a 100644 --- a/app/src/assets/chat.less +++ b/app/src/assets/chat.less @@ -13,8 +13,8 @@ .message { display: flex; + gap: 6px; flex-direction: column; - width: 100%; &:last-child { animation: FlexInAnimationFromBottom 0.2s cubic-bezier(0.175, 0.885, 0.32, 1.275) 0s 1 normal forwards running; @@ -29,6 +29,36 @@ } } + .message-quota { + display: flex; + flex-direction: row; + align-items: center; + user-select: none; + gap: 4px; + cursor: pointer; + border: 1px solid hsl(var(--input)); + border-radius: var(--radius); + transition: 0.2s linear; + padding: 4px 8px; + width: max-content; + height: max-content; + white-space: nowrap; + + .quota { + font-size: 14px; + color: hsl(var(--text-secondary)); + } + + .icon { + transform: translateY(1px); + color: hsl(var(--text-secondary)); + } + + &:hover { + border-color: hsl(var(--border-hover)); + } + } + .message-content { padding: 8px 16px; border-radius: var(--radius); diff --git a/app/src/assets/main.less b/app/src/assets/main.less index 8ae27a7..3dc2709 100644 --- a/app/src/assets/main.less +++ b/app/src/assets/main.less @@ -64,3 +64,9 @@ strong { color: hsl(var(--text)); border: 1px solid hsl(var(--border)); } + +.icon-tooltip { + display: flex; + flex-direction: row; + align-items: center; +} diff --git a/app/src/components/Message.tsx b/app/src/components/Message.tsx index cf8ffae..9cbe331 100644 --- a/app/src/components/Message.tsx +++ b/app/src/components/Message.tsx @@ -1,6 +1,6 @@ import { Message } from "../conversation/types.ts"; import Markdown from "./Markdown.tsx"; -import {Copy, File, Loader2, MousePointerSquare} from "lucide-react"; +import {Cloud, CloudFog, Copy, File, Loader2, MousePointerSquare} from "lucide-react"; import { ContextMenu, ContextMenuContent, @@ -9,6 +9,7 @@ import { } from "./ui/context-menu.tsx"; import {copyClipboard, saveAsFile, useInputValue} from "../utils.ts"; import {useTranslation} from "react-i18next"; +import {Tooltip, TooltipContent, TooltipProvider, TooltipTrigger} from "./ui/tooltip.tsx"; type MessageProps = { message: Message; @@ -22,6 +23,24 @@ function MessageSegment({ message }: MessageProps) {
+ { + (message.quota && message.quota > 0) ? + + + +
+ + {message.quota.toFixed(2)} +
+
+ + +

{ t('quota-description') }

+
+
+
+ : null + }
@@ -41,79 +60,81 @@ function MessageSegment({ message }: MessageProps) { function MessageContent({ message }: MessageProps) { return ( -
- {message.keyword && message.keyword.length ? ( -
- - bing - - - - - - - - - - - - - - - - - - - - - {message.keyword} -
- ) : null} - {message.content.length ? ( - - ) : ( - - )} -
+ <> +
+ {message.keyword && message.keyword.length ? ( +
+ + bing + + + + + + + + + + + + + + + + + + + + + {message.keyword} +
+ ) : null} + {message.content.length ? ( + + ) : ( + + )} +
+ ) } diff --git a/app/src/conf.ts b/app/src/conf.ts index b310b50..fbdad44 100644 --- a/app/src/conf.ts +++ b/app/src/conf.ts @@ -5,10 +5,12 @@ export let rest_api: string = "http://localhost:8094"; export let ws_api: string = "ws://localhost:8094"; if (deploy) { - rest_api = "https://nioapi.fystart.cn"; - ws_api = "wss://nioapi.fystart.cn"; + rest_api = "https://api.chatnio.net"; + ws_api = "wss://api.chatnio.net"; } +export const tokenField = deploy ? "token" : "token-dev"; + export function login() { location.href = "https://deeptrain.lightxi.com/login?app=chatnio"; } diff --git a/app/src/conversation/connection.ts b/app/src/conversation/connection.ts index 9e49fac..8c83586 100644 --- a/app/src/conversation/connection.ts +++ b/app/src/conversation/connection.ts @@ -1,9 +1,10 @@ -import { ws_api } from "../conf.ts"; +import {tokenField, ws_api} from "../conf.ts"; export const endpoint = `${ws_api}/chat`; export type StreamMessage = { keyword?: string; + quota?: number; message: string; end: boolean; }; @@ -35,7 +36,7 @@ export class Connection { this.connection.onopen = () => { this.state = true; this.send({ - token: localStorage.getItem("token") || "", + token: localStorage.getItem(tokenField) || "", id: this.id, }); }; diff --git a/app/src/conversation/conversation.ts b/app/src/conversation/conversation.ts index 5605b1e..2a2f9ef 100644 --- a/app/src/conversation/conversation.ts +++ b/app/src/conversation/conversation.ts @@ -79,9 +79,10 @@ export class Conversation { this.triggerCallback(); } - public updateMessage(idx: number, message: string, keyword?: string) { + public updateMessage(idx: number, message: string, keyword?: string, quota?: number) { this.data[idx].content += message; if (keyword) this.data[idx].keyword = keyword; + if (quota) this.data[idx].quota = quota; this.triggerCallback(); } @@ -92,7 +93,7 @@ export class Conversation { }); return (message: StreamMessage) => { - this.updateMessage(cursor, message.message, message.keyword); + this.updateMessage(cursor, message.message, message.keyword, message.quota); if (message.end) { this.end = true; } diff --git a/app/src/conversation/types.ts b/app/src/conversation/types.ts index 71f1146..e8b76b6 100644 --- a/app/src/conversation/types.ts +++ b/app/src/conversation/types.ts @@ -3,6 +3,7 @@ import { Conversation } from "./conversation.ts"; export type Message = { content: string; keyword?: string; + quota?: number; role: string; }; diff --git a/app/src/i18n.ts b/app/src/i18n.ts index 11d7e5c..2ba3faa 100644 --- a/app/src/i18n.ts +++ b/app/src/i18n.ts @@ -51,7 +51,8 @@ const resources = { "copy": "Copy", "save": "Save as File", "use": "Use Message", - } + }, + "quota-description": "spending quota for the message", }, }, cn: { @@ -94,7 +95,8 @@ const resources = { "copy": "复制", "save": "保存为文件", "use": "使用消息", - } + }, + "quota-description": "消息的配额支出", }, }, }; diff --git a/app/src/routes/Auth.tsx b/app/src/routes/Auth.tsx index 29a235c..ecc1714 100644 --- a/app/src/routes/Auth.tsx +++ b/app/src/routes/Auth.tsx @@ -1,7 +1,7 @@ import { useToast } from "../components/ui/use-toast.ts"; import { useLocation } from "react-router-dom"; import { ToastAction } from "../components/ui/toast.tsx"; -import { login } from "../conf.ts"; +import {login, tokenField} from "../conf.ts"; import { useEffect } from "react"; import Loader from "../components/Loader.tsx"; import "../assets/auth.less"; @@ -16,7 +16,7 @@ function Auth() { const { t } = useTranslation(); const dispatch = useDispatch(); const search = new URLSearchParams(useLocation().search); - const token = (search.get("token") || "").trim(); + const token = (search.get(tokenField) || "").trim(); if (!token.length) { toast({ diff --git a/app/src/routes/Home.tsx b/app/src/routes/Home.tsx index f5cce78..ee34a7b 100644 --- a/app/src/routes/Home.tsx +++ b/app/src/routes/Home.tsx @@ -137,7 +137,8 @@ function SideBar() { {t("conversation.cancel")} { + onClick={async (e) => { + e.preventDefault(); if ( await deleteConversation(dispatch, conversation.id) ) diff --git a/app/src/store/auth.ts b/app/src/store/auth.ts index 16c1ac8..4f2db0f 100644 --- a/app/src/store/auth.ts +++ b/app/src/store/auth.ts @@ -1,5 +1,6 @@ import { createSlice } from "@reduxjs/toolkit"; import axios from "axios"; +import {tokenField} from "../conf.ts"; export const authSlice = createSlice({ name: "auth", @@ -12,7 +13,7 @@ export const authSlice = createSlice({ setToken: (state, action) => { state.token = action.payload as string; axios.defaults.headers.common["Authorization"] = state.token; - localStorage.setItem("token", state.token); + localStorage.setItem(tokenField, state.token); }, setAuthenticated: (state, action) => { state.authenticated = action.payload as boolean; @@ -25,7 +26,7 @@ export const authSlice = createSlice({ state.authenticated = false; state.username = ""; axios.defaults.headers.common["Authorization"] = ""; - localStorage.removeItem("token"); + localStorage.removeItem(tokenField); location.reload(); }, diff --git a/auth/usage.go b/auth/usage.go index 094dd9a..6f79a4f 100644 --- a/auth/usage.go +++ b/auth/usage.go @@ -4,6 +4,29 @@ import ( "database/sql" ) +// Price Calculation +// 10 nio points = ¥1 +// from 2023-9-6, 1 USD = 7.3124 CNY +// +// GPT-4 price (8k-context) +// Input Output +// $0.03 / 1K tokens $0.06 / 1K tokens +// ¥0.21 / 1K tokens ¥0.43 / 1K tokens +// 2.1 nio / 1K tokens 4.3 nio / 1K tokens + +// Dalle price (512x512) +// $0.018 / per image +// ¥0.13 / per image +// 1 nio / per image + +func CountInputToken(n int) float32 { + return float32(n) / 1000 * 2.1 +} + +func CountOutputToken(n int) float32 { + return float32(n) / 1000 * 4.3 +} + func ReduceUsage(db *sql.DB, user *User, _t string) bool { id := user.GetID(db) var count int @@ -82,7 +105,7 @@ func BuyDalle(db *sql.DB, user *User, value int) bool { return true } -func CountGPT4Prize(value int) float32 { +func CountGPT4price(value int) float32 { if value <= 20 { return float32(value) * 0.5 } @@ -91,7 +114,7 @@ func CountGPT4Prize(value int) float32 { } func BuyGPT4(db *sql.DB, user *User, value int) bool { - if !Pay(user.Username, CountGPT4Prize(value)) { + if !Pay(user.Username, CountGPT4price(value)) { return false } diff --git a/conversation/conversation.go b/conversation/conversation.go index 21adeb5..e681a28 100644 --- a/conversation/conversation.go +++ b/conversation/conversation.go @@ -14,11 +14,13 @@ type Conversation struct { Name string `json:"name"` Message []types.ChatGPTMessage `json:"message"` EnableGPT4 bool `json:"enable_gpt4"` + EnableWeb bool `json:"enable_web"` } type FormMessage struct { Type string `json:"type"` // ping Message string `json:"message"` + Web bool `json:"web"` GPT4 bool `json:"gpt4"` } @@ -29,6 +31,7 @@ func NewConversation(db *sql.DB, id int64) *Conversation { Name: "new chat", Message: []types.ChatGPTMessage{}, EnableGPT4: false, + EnableWeb: false, } } @@ -36,10 +39,18 @@ func (c *Conversation) IsEnableGPT4() bool { return c.EnableGPT4 } +func (c *Conversation) IsEnableWeb() bool { + return c.EnableWeb +} + func (c *Conversation) SetEnableGPT4(enable bool) { c.EnableGPT4 = enable } +func (c *Conversation) SetEnableWeb(enable bool) { + c.EnableWeb = enable +} + func (c *Conversation) GetName() string { return c.Name } @@ -131,6 +142,7 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) { c.AddMessageFromUser(form.Message) c.SetEnableGPT4(form.GPT4) + c.SetEnableWeb(form.Web) return form.Message, nil } diff --git a/go.mod b/go.mod index ab6ae39..ba6cf29 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect @@ -22,6 +23,7 @@ require ( github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/goccy/go-json v0.10.2 // indirect + github.com/google/uuid v1.3.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -33,6 +35,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pkoukk/tiktoken-go v0.1.5 // indirect github.com/spf13/afero v1.9.5 // indirect github.com/spf13/cast v1.5.1 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect diff --git a/go.sum b/go.sum index a9244aa..65db482 100644 --- a/go.sum +++ b/go.sum @@ -61,6 +61,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= @@ -147,6 +149,8 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= @@ -192,6 +196,8 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= +github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4= +github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= diff --git a/types/types.go b/types/types.go index 0ad714c..499c41a 100644 --- a/types/types.go +++ b/types/types.go @@ -34,7 +34,8 @@ type ChatGPTStreamResponse struct { } type ChatGPTSegmentResponse struct { - Keyword string `json:"keyword"` - Message string `json:"message"` - End bool `json:"end"` + Quota float32 `json:"quota"` + Keyword string `json:"keyword"` + Message string `json:"message"` + End bool `json:"end"` } diff --git a/utils/char.go b/utils/char.go index b176be1..f5ba3dc 100644 --- a/utils/char.go +++ b/utils/char.go @@ -37,6 +37,14 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) { return form, err } +func Marshal[T interface{}](data T) string { + res, err := json.Marshal(data) + if err != nil { + return "" + } + return string(res) +} + func ToInt(value string) int { if res, err := strconv.Atoi(value); err == nil { return res diff --git a/utils/tokenizer.go b/utils/tokenizer.go new file mode 100644 index 0000000..5bac7d7 --- /dev/null +++ b/utils/tokenizer.go @@ -0,0 +1,60 @@ +package utils + +import ( + "chat/types" + "fmt" + "github.com/pkoukk/tiktoken-go" + "strings" +) + +// Using https://github.com/pkoukk/tiktoken-go +// To count number of tokens of openai chat messages +// OpenAI Cookbook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb + +func GetWeightByModel(model string) int { + switch model { + case "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-16k-0613", + "gpt-4-0314", + "gpt-4-32k-0314", + "gpt-4-0613", + "gpt-4-32k-0613": + return 3 + case "gpt-3.5-turbo-0301": + return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n + default: + if strings.Contains(model, "gpt-3.5-turbo") { + // warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613. + return GetWeightByModel("gpt-3.5-turbo-0613") + } else if strings.Contains(model, "gpt-4") { + // warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613. + return GetWeightByModel("gpt-4-0613") + } else { + // not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens + panic(fmt.Errorf("not implemented for model %s", model)) + } + } +} +func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (tokens int) { + weight := GetWeightByModel(model) + tkm, err := tiktoken.EncodingForModel(model) + if err != nil { + // can not encode messages, use length of messages as a proxy for number of tokens + // using rune instead of byte to account for unicode characters (e.g. emojis, non-english characters) + + data := Marshal(messages) + return len([]rune(data)) * weight + } + + for _, message := range messages { + tokens += weight + tokens += len(tkm.Encode(message.Content, nil, nil)) + tokens += len(tkm.Encode(message.Role, nil, nil)) + } + tokens += 3 // every reply is primed with <|start|>assistant<|message|> + return tokens +} + +func CountTokenPrice(messages []types.ChatGPTMessage) int { + return NumTokensFromMessages(messages, "gpt-4") +}