update quota

This commit is contained in:
Zhang Minghan 2023-09-07 10:10:13 +08:00
parent 806cf3a048
commit 8958afca68
23 changed files with 383 additions and 114 deletions

View File

@ -14,6 +14,7 @@ import (
type AnonymousRequestBody struct {
Message string `json:"message" required:"true"`
Web bool `json:"web"`
}
type AnonymousResponseCache struct {
@ -57,18 +58,22 @@ func TestKey(key string) bool {
return res.(map[string]interface{})["choices"] != nil
}
func GetAnonymousResponse(message string) (string, string, error) {
func GetAnonymousResponse(message string, web bool) (string, string, error) {
if !web {
resp, err := GetChatGPTResponse([]types.ChatGPTMessage{{Role: "user", Content: message}}, 1000)
return "", resp, err
}
keyword, source := ChatWithWeb([]types.ChatGPTMessage{{Role: "user", Content: message}}, false)
resp, err := GetChatGPTResponse(source, 1000)
return keyword, resp, err
}
func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, string, error) {
func GetAnonymousResponseWithCache(c *gin.Context, message string, web bool) (string, string, error) {
cache := c.MustGet("cache").(*redis.Client)
res, err := cache.Get(c, fmt.Sprintf(":chatgpt:%s", message)).Result()
res, err := cache.Get(c, fmt.Sprintf(":chatgpt-%v:%s", web, message)).Result()
form := utils.UnmarshalJson[AnonymousResponseCache](res)
if err != nil || len(res) == 0 || res == "{}" || form.Message == "" {
key, res, err := GetAnonymousResponse(message)
key, res, err := GetAnonymousResponse(message, web)
if err != nil {
return "", "There was something wrong...", err
}
@ -76,7 +81,7 @@ func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, stri
cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), utils.ToJson(AnonymousResponseCache{
Keyword: key,
Message: res,
}), time.Hour*6)
}), time.Hour*48)
return key, res, nil
}
return form.Keyword, form.Message, nil
@ -103,7 +108,7 @@ func AnonymousAPI(c *gin.Context) {
})
return
}
key, res, err := GetAnonymousResponseWithCache(c, message)
key, res, err := GetAnonymousResponseWithCache(c, message, body.Web)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"status": false,

73
api/buffer.go Normal file
View File

@ -0,0 +1,73 @@
package api
import (
"chat/auth"
"chat/types"
"chat/utils"
)
type Buffer struct {
Enable bool `json:"enable"`
Quota float32 `json:"quota"`
Data string `json:"data"`
Cursor int `json:"cursor"`
Times int `json:"times"`
}
func NewBuffer(enable bool, history []types.ChatGPTMessage) *Buffer {
buffer := &Buffer{Data: "", Cursor: 0, Times: 0, Enable: enable}
if enable {
buffer.Quota = auth.CountInputToken(utils.CountTokenPrice(history))
}
return buffer
}
func (b *Buffer) GetCursor() int {
return b.Cursor
}
func (b *Buffer) GetQuota() float32 {
if !b.Enable {
return 0.
}
return b.Quota + auth.CountOutputToken(b.ReadTimes())
}
func (b *Buffer) Write(data string) string {
b.Data += data
b.Cursor += len(data)
b.Times++
return data
}
func (b *Buffer) WriteBytes(data []byte) []byte {
b.Data += string(data)
b.Cursor += len(data)
b.Times++
return data
}
func (b *Buffer) IsEmpty() bool {
return b.Cursor == 0
}
func (b *Buffer) Reset() {
b.Data = ""
b.Cursor = 0
b.Times = 0
}
func (b *Buffer) Read() string {
return b.Data
}
func (b *Buffer) ReadWithDefault(_default string) string {
if b.IsEmpty() {
return _default
}
return b.Data
}
func (b *Buffer) ReadTimes() int {
return b.Times
}

View File

@ -7,7 +7,6 @@ import (
"chat/types"
"chat/utils"
"database/sql"
"fmt"
"github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8"
"github.com/gorilla/websocket"
@ -15,6 +14,8 @@ import (
"strings"
)
const defaultMessage = "There was something wrong... Please try again later."
type WebsocketAuthForm struct {
Token string `json:"token" binding:"required"`
Id int64 `json:"id" binding:"required"`
@ -25,39 +26,47 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon
}
func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
var keyword string
var segment []types.ChatGPTMessage
msg := ""
if instance.IsEnableWeb() {
keyword, segment = ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
} else {
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
}
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
if instance.IsEnableGPT4() && !auth.ReduceGPT4(db, user) {
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: "You have run out of GPT-4 usage. Please buy more.",
Quota: 0,
End: true,
})
return "You have run out of GPT-4 usage. Please buy more."
}
buffer := NewBuffer(instance.IsEnableGPT4(), segment)
StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) {
msg += resp
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: resp,
Message: buffer.Write(resp),
Quota: buffer.GetQuota(),
End: false,
})
})
if msg == "" {
msg = "There was something wrong... Please try again later."
if buffer.IsEmpty() {
if instance.IsEnableGPT4() {
auth.IncreaseGPT4(db, user, 1)
}
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: msg,
Message: defaultMessage,
Quota: buffer.GetQuota(),
End: false,
})
}
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true})
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true, Quota: buffer.GetQuota()})
return msg
return buffer.ReadWithDefault(defaultMessage)
}
func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
@ -84,10 +93,9 @@ func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *
return err.Error()
}
markdown := fmt.Sprintln("![image](", url, ")")
markdown := GetImageMarkdown(url)
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: markdown,
Keyword: "image",
End: true,
})
return markdown

View File

@ -70,3 +70,7 @@ func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *re
return GetImageWithCache(context.Background(), prompt, cache)
}
}
func GetImageMarkdown(url string) string {
return fmt.Sprintln("![image](", url, ")")
}

View File

@ -25,7 +25,7 @@ import {
DropdownMenuTrigger,
} from "./components/ui/dropdown-menu.tsx";
import { Toaster } from "./components/ui/toaster.tsx";
import { login } from "./conf.ts";
import {login, tokenField} from "./conf.ts";
import { useTranslation } from "react-i18next";
function Settings() {
@ -67,7 +67,7 @@ function NavBar() {
const { t } = useTranslation();
const dispatch = useDispatch();
useEffect(() => {
validateToken(dispatch, localStorage.getItem("token") ?? "");
validateToken(dispatch, localStorage.getItem(tokenField) ?? "");
}, []);
const auth = useSelector(selectAuthenticated);

View File

@ -13,8 +13,8 @@
.message {
display: flex;
gap: 6px;
flex-direction: column;
width: 100%;
&:last-child {
animation: FlexInAnimationFromBottom 0.2s cubic-bezier(0.175, 0.885, 0.32, 1.275) 0s 1 normal forwards running;
@ -29,6 +29,36 @@
}
}
.message-quota {
display: flex;
flex-direction: row;
align-items: center;
user-select: none;
gap: 4px;
cursor: pointer;
border: 1px solid hsl(var(--input));
border-radius: var(--radius);
transition: 0.2s linear;
padding: 4px 8px;
width: max-content;
height: max-content;
white-space: nowrap;
.quota {
font-size: 14px;
color: hsl(var(--text-secondary));
}
.icon {
transform: translateY(1px);
color: hsl(var(--text-secondary));
}
&:hover {
border-color: hsl(var(--border-hover));
}
}
.message-content {
padding: 8px 16px;
border-radius: var(--radius);

View File

@ -64,3 +64,9 @@ strong {
color: hsl(var(--text));
border: 1px solid hsl(var(--border));
}
.icon-tooltip {
display: flex;
flex-direction: row;
align-items: center;
}

View File

@ -1,6 +1,6 @@
import { Message } from "../conversation/types.ts";
import Markdown from "./Markdown.tsx";
import {Copy, File, Loader2, MousePointerSquare} from "lucide-react";
import {Cloud, CloudFog, Copy, File, Loader2, MousePointerSquare} from "lucide-react";
import {
ContextMenu,
ContextMenuContent,
@ -9,6 +9,7 @@ import {
} from "./ui/context-menu.tsx";
import {copyClipboard, saveAsFile, useInputValue} from "../utils.ts";
import {useTranslation} from "react-i18next";
import {Tooltip, TooltipContent, TooltipProvider, TooltipTrigger} from "./ui/tooltip.tsx";
type MessageProps = {
message: Message;
@ -22,6 +23,24 @@ function MessageSegment({ message }: MessageProps) {
<ContextMenuTrigger asChild>
<div className={`message ${message.role}`}>
<MessageContent message={message} />
{
(message.quota && message.quota > 0) ?
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div className={`message-quota`}>
<Cloud className={`h-4 w-4 icon`} />
<span className={`quota`}>{message.quota.toFixed(2)}</span>
</div>
</TooltipTrigger>
<TooltipContent className={`icon-tooltip`}>
<CloudFog className={`h-4 w-4 mr-2`} />
<p>{ t('quota-description') }</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
: null
}
</div>
</ContextMenuTrigger>
<ContextMenuContent>
@ -41,6 +60,7 @@ function MessageSegment({ message }: MessageProps) {
function MessageContent({ message }: MessageProps) {
return (
<>
<div className={`message-content`}>
{message.keyword && message.keyword.length ? (
<div className={`bing`}>
@ -114,6 +134,7 @@ function MessageContent({ message }: MessageProps) {
<Loader2 className={`h-5 w-5 m-1 animate-spin`} />
)}
</div>
</>
)
}

View File

@ -5,10 +5,12 @@ export let rest_api: string = "http://localhost:8094";
export let ws_api: string = "ws://localhost:8094";
if (deploy) {
rest_api = "https://nioapi.fystart.cn";
ws_api = "wss://nioapi.fystart.cn";
rest_api = "https://api.chatnio.net";
ws_api = "wss://api.chatnio.net";
}
export const tokenField = deploy ? "token" : "token-dev";
export function login() {
location.href = "https://deeptrain.lightxi.com/login?app=chatnio";
}

View File

@ -1,9 +1,10 @@
import { ws_api } from "../conf.ts";
import {tokenField, ws_api} from "../conf.ts";
export const endpoint = `${ws_api}/chat`;
export type StreamMessage = {
keyword?: string;
quota?: number;
message: string;
end: boolean;
};
@ -35,7 +36,7 @@ export class Connection {
this.connection.onopen = () => {
this.state = true;
this.send({
token: localStorage.getItem("token") || "",
token: localStorage.getItem(tokenField) || "",
id: this.id,
});
};

View File

@ -79,9 +79,10 @@ export class Conversation {
this.triggerCallback();
}
public updateMessage(idx: number, message: string, keyword?: string) {
public updateMessage(idx: number, message: string, keyword?: string, quota?: number) {
this.data[idx].content += message;
if (keyword) this.data[idx].keyword = keyword;
if (quota) this.data[idx].quota = quota;
this.triggerCallback();
}
@ -92,7 +93,7 @@ export class Conversation {
});
return (message: StreamMessage) => {
this.updateMessage(cursor, message.message, message.keyword);
this.updateMessage(cursor, message.message, message.keyword, message.quota);
if (message.end) {
this.end = true;
}

View File

@ -3,6 +3,7 @@ import { Conversation } from "./conversation.ts";
export type Message = {
content: string;
keyword?: string;
quota?: number;
role: string;
};

View File

@ -51,7 +51,8 @@ const resources = {
"copy": "Copy",
"save": "Save as File",
"use": "Use Message",
}
},
"quota-description": "spending quota for the message",
},
},
cn: {
@ -94,7 +95,8 @@ const resources = {
"copy": "复制",
"save": "保存为文件",
"use": "使用消息",
}
},
"quota-description": "消息的配额支出",
},
},
};

View File

@ -1,7 +1,7 @@
import { useToast } from "../components/ui/use-toast.ts";
import { useLocation } from "react-router-dom";
import { ToastAction } from "../components/ui/toast.tsx";
import { login } from "../conf.ts";
import {login, tokenField} from "../conf.ts";
import { useEffect } from "react";
import Loader from "../components/Loader.tsx";
import "../assets/auth.less";
@ -16,7 +16,7 @@ function Auth() {
const { t } = useTranslation();
const dispatch = useDispatch();
const search = new URLSearchParams(useLocation().search);
const token = (search.get("token") || "").trim();
const token = (search.get(tokenField) || "").trim();
if (!token.length) {
toast({

View File

@ -137,7 +137,8 @@ function SideBar() {
{t("conversation.cancel")}
</AlertDialogCancel>
<AlertDialogAction
onClick={async () => {
onClick={async (e) => {
e.preventDefault();
if (
await deleteConversation(dispatch, conversation.id)
)

View File

@ -1,5 +1,6 @@
import { createSlice } from "@reduxjs/toolkit";
import axios from "axios";
import {tokenField} from "../conf.ts";
export const authSlice = createSlice({
name: "auth",
@ -12,7 +13,7 @@ export const authSlice = createSlice({
setToken: (state, action) => {
state.token = action.payload as string;
axios.defaults.headers.common["Authorization"] = state.token;
localStorage.setItem("token", state.token);
localStorage.setItem(tokenField, state.token);
},
setAuthenticated: (state, action) => {
state.authenticated = action.payload as boolean;
@ -25,7 +26,7 @@ export const authSlice = createSlice({
state.authenticated = false;
state.username = "";
axios.defaults.headers.common["Authorization"] = "";
localStorage.removeItem("token");
localStorage.removeItem(tokenField);
location.reload();
},

View File

@ -4,6 +4,29 @@ import (
"database/sql"
)
// Price Calculation
// 10 nio points = ¥1
// from 2023-9-6, 1 USD = 7.3124 CNY
//
// GPT-4 price (8k-context)
// Input Output
// $0.03 / 1K tokens $0.06 / 1K tokens
// ¥0.21 / 1K tokens ¥0.43 / 1K tokens
// 2.1 nio / 1K tokens 4.3 nio / 1K tokens
// Dalle price (512x512)
// $0.018 / per image
// ¥0.13 / per image
// 1 nio / per image
func CountInputToken(n int) float32 {
return float32(n) / 1000 * 2.1
}
func CountOutputToken(n int) float32 {
return float32(n) / 1000 * 4.3
}
func ReduceUsage(db *sql.DB, user *User, _t string) bool {
id := user.GetID(db)
var count int
@ -82,7 +105,7 @@ func BuyDalle(db *sql.DB, user *User, value int) bool {
return true
}
func CountGPT4Prize(value int) float32 {
func CountGPT4price(value int) float32 {
if value <= 20 {
return float32(value) * 0.5
}
@ -91,7 +114,7 @@ func CountGPT4Prize(value int) float32 {
}
func BuyGPT4(db *sql.DB, user *User, value int) bool {
if !Pay(user.Username, CountGPT4Prize(value)) {
if !Pay(user.Username, CountGPT4price(value)) {
return false
}

View File

@ -14,11 +14,13 @@ type Conversation struct {
Name string `json:"name"`
Message []types.ChatGPTMessage `json:"message"`
EnableGPT4 bool `json:"enable_gpt4"`
EnableWeb bool `json:"enable_web"`
}
type FormMessage struct {
Type string `json:"type"` // ping
Message string `json:"message"`
Web bool `json:"web"`
GPT4 bool `json:"gpt4"`
}
@ -29,6 +31,7 @@ func NewConversation(db *sql.DB, id int64) *Conversation {
Name: "new chat",
Message: []types.ChatGPTMessage{},
EnableGPT4: false,
EnableWeb: false,
}
}
@ -36,10 +39,18 @@ func (c *Conversation) IsEnableGPT4() bool {
return c.EnableGPT4
}
func (c *Conversation) IsEnableWeb() bool {
return c.EnableWeb
}
func (c *Conversation) SetEnableGPT4(enable bool) {
c.EnableGPT4 = enable
}
func (c *Conversation) SetEnableWeb(enable bool) {
c.EnableWeb = enable
}
func (c *Conversation) GetName() string {
return c.Name
}
@ -131,6 +142,7 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
c.AddMessageFromUser(form.Message)
c.SetEnableGPT4(form.GPT4)
c.SetEnableWeb(form.Web)
return form.Message, nil
}

3
go.mod
View File

@ -15,6 +15,7 @@ require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
@ -22,6 +23,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/websocket v1.5.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
@ -33,6 +35,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pkoukk/tiktoken-go v0.1.5 // indirect
github.com/spf13/afero v1.9.5 // indirect
github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect

6
go.sum
View File

@ -61,6 +61,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
@ -147,6 +149,8 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
@ -192,6 +196,8 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=

View File

@ -34,6 +34,7 @@ type ChatGPTStreamResponse struct {
}
type ChatGPTSegmentResponse struct {
Quota float32 `json:"quota"`
Keyword string `json:"keyword"`
Message string `json:"message"`
End bool `json:"end"`

View File

@ -37,6 +37,14 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
return form, err
}
func Marshal[T interface{}](data T) string {
res, err := json.Marshal(data)
if err != nil {
return ""
}
return string(res)
}
func ToInt(value string) int {
if res, err := strconv.Atoi(value); err == nil {
return res

60
utils/tokenizer.go Normal file
View File

@ -0,0 +1,60 @@
package utils
import (
"chat/types"
"fmt"
"github.com/pkoukk/tiktoken-go"
"strings"
)
// Using https://github.com/pkoukk/tiktoken-go
// To count number of tokens of openai chat messages
// OpenAI Cookbook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
func GetWeightByModel(model string) int {
switch model {
case "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613":
return 3
case "gpt-3.5-turbo-0301":
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
default:
if strings.Contains(model, "gpt-3.5-turbo") {
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
return GetWeightByModel("gpt-3.5-turbo-0613")
} else if strings.Contains(model, "gpt-4") {
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
return GetWeightByModel("gpt-4-0613")
} else {
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
panic(fmt.Errorf("not implemented for model %s", model))
}
}
}
func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (tokens int) {
weight := GetWeightByModel(model)
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
// can not encode messages, use length of messages as a proxy for number of tokens
// using rune instead of byte to account for unicode characters (e.g. emojis, non-english characters)
data := Marshal(messages)
return len([]rune(data)) * weight
}
for _, message := range messages {
tokens += weight
tokens += len(tkm.Encode(message.Content, nil, nil))
tokens += len(tkm.Encode(message.Role, nil, nil))
}
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
return tokens
}
func CountTokenPrice(messages []types.ChatGPTMessage) int {
return NumTokensFromMessages(messages, "gpt-4")
}