update quota

This commit is contained in:
Zhang Minghan 2023-09-07 10:10:13 +08:00
parent 806cf3a048
commit 8958afca68
23 changed files with 383 additions and 114 deletions

View File

@ -14,6 +14,7 @@ import (
type AnonymousRequestBody struct { type AnonymousRequestBody struct {
Message string `json:"message" required:"true"` Message string `json:"message" required:"true"`
Web bool `json:"web"`
} }
type AnonymousResponseCache struct { type AnonymousResponseCache struct {
@ -57,18 +58,22 @@ func TestKey(key string) bool {
return res.(map[string]interface{})["choices"] != nil return res.(map[string]interface{})["choices"] != nil
} }
func GetAnonymousResponse(message string) (string, string, error) { func GetAnonymousResponse(message string, web bool) (string, string, error) {
if !web {
resp, err := GetChatGPTResponse([]types.ChatGPTMessage{{Role: "user", Content: message}}, 1000)
return "", resp, err
}
keyword, source := ChatWithWeb([]types.ChatGPTMessage{{Role: "user", Content: message}}, false) keyword, source := ChatWithWeb([]types.ChatGPTMessage{{Role: "user", Content: message}}, false)
resp, err := GetChatGPTResponse(source, 1000) resp, err := GetChatGPTResponse(source, 1000)
return keyword, resp, err return keyword, resp, err
} }
func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, string, error) { func GetAnonymousResponseWithCache(c *gin.Context, message string, web bool) (string, string, error) {
cache := c.MustGet("cache").(*redis.Client) cache := c.MustGet("cache").(*redis.Client)
res, err := cache.Get(c, fmt.Sprintf(":chatgpt:%s", message)).Result() res, err := cache.Get(c, fmt.Sprintf(":chatgpt-%v:%s", web, message)).Result()
form := utils.UnmarshalJson[AnonymousResponseCache](res) form := utils.UnmarshalJson[AnonymousResponseCache](res)
if err != nil || len(res) == 0 || res == "{}" || form.Message == "" { if err != nil || len(res) == 0 || res == "{}" || form.Message == "" {
key, res, err := GetAnonymousResponse(message) key, res, err := GetAnonymousResponse(message, web)
if err != nil { if err != nil {
return "", "There was something wrong...", err return "", "There was something wrong...", err
} }
@ -76,7 +81,7 @@ func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, stri
cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), utils.ToJson(AnonymousResponseCache{ cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), utils.ToJson(AnonymousResponseCache{
Keyword: key, Keyword: key,
Message: res, Message: res,
}), time.Hour*6) }), time.Hour*48)
return key, res, nil return key, res, nil
} }
return form.Keyword, form.Message, nil return form.Keyword, form.Message, nil
@ -103,7 +108,7 @@ func AnonymousAPI(c *gin.Context) {
}) })
return return
} }
key, res, err := GetAnonymousResponseWithCache(c, message) key, res, err := GetAnonymousResponseWithCache(c, message, body.Web)
if err != nil { if err != nil {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"status": false, "status": false,

73
api/buffer.go Normal file
View File

@ -0,0 +1,73 @@
package api
import (
"chat/auth"
"chat/types"
"chat/utils"
)
type Buffer struct {
Enable bool `json:"enable"`
Quota float32 `json:"quota"`
Data string `json:"data"`
Cursor int `json:"cursor"`
Times int `json:"times"`
}
func NewBuffer(enable bool, history []types.ChatGPTMessage) *Buffer {
buffer := &Buffer{Data: "", Cursor: 0, Times: 0, Enable: enable}
if enable {
buffer.Quota = auth.CountInputToken(utils.CountTokenPrice(history))
}
return buffer
}
func (b *Buffer) GetCursor() int {
return b.Cursor
}
func (b *Buffer) GetQuota() float32 {
if !b.Enable {
return 0.
}
return b.Quota + auth.CountOutputToken(b.ReadTimes())
}
func (b *Buffer) Write(data string) string {
b.Data += data
b.Cursor += len(data)
b.Times++
return data
}
func (b *Buffer) WriteBytes(data []byte) []byte {
b.Data += string(data)
b.Cursor += len(data)
b.Times++
return data
}
func (b *Buffer) IsEmpty() bool {
return b.Cursor == 0
}
func (b *Buffer) Reset() {
b.Data = ""
b.Cursor = 0
b.Times = 0
}
func (b *Buffer) Read() string {
return b.Data
}
func (b *Buffer) ReadWithDefault(_default string) string {
if b.IsEmpty() {
return _default
}
return b.Data
}
func (b *Buffer) ReadTimes() int {
return b.Times
}

View File

@ -7,7 +7,6 @@ import (
"chat/types" "chat/types"
"chat/utils" "chat/utils"
"database/sql" "database/sql"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -15,6 +14,8 @@ import (
"strings" "strings"
) )
const defaultMessage = "There was something wrong... Please try again later."
type WebsocketAuthForm struct { type WebsocketAuthForm struct {
Token string `json:"token" binding:"required"` Token string `json:"token" binding:"required"`
Id int64 `json:"id" binding:"required"` Id int64 `json:"id" binding:"required"`
@ -25,39 +26,47 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon
} }
func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string { func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true) var keyword string
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false}) var segment []types.ChatGPTMessage
msg := "" if instance.IsEnableWeb() {
keyword, segment = ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
} else {
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
}
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
if instance.IsEnableGPT4() && !auth.ReduceGPT4(db, user) { if instance.IsEnableGPT4() && !auth.ReduceGPT4(db, user) {
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: "You have run out of GPT-4 usage. Please buy more.", Message: "You have run out of GPT-4 usage. Please buy more.",
Quota: 0,
End: true, End: true,
}) })
return "You have run out of GPT-4 usage. Please buy more." return "You have run out of GPT-4 usage. Please buy more."
} }
buffer := NewBuffer(instance.IsEnableGPT4(), segment)
StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) { StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) {
msg += resp
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: resp, Message: buffer.Write(resp),
Quota: buffer.GetQuota(),
End: false, End: false,
}) })
}) })
if msg == "" { if buffer.IsEmpty() {
msg = "There was something wrong... Please try again later."
if instance.IsEnableGPT4() { if instance.IsEnableGPT4() {
auth.IncreaseGPT4(db, user, 1) auth.IncreaseGPT4(db, user, 1)
} }
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: msg, Message: defaultMessage,
Quota: buffer.GetQuota(),
End: false, End: false,
}) })
} }
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true}) SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true, Quota: buffer.GetQuota()})
return msg return buffer.ReadWithDefault(defaultMessage)
} }
func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string { func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
@ -84,10 +93,9 @@ func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *
return err.Error() return err.Error()
} }
markdown := fmt.Sprintln("![image](", url, ")") markdown := GetImageMarkdown(url)
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
Message: markdown, Message: markdown,
Keyword: "image",
End: true, End: true,
}) })
return markdown return markdown

View File

@ -70,3 +70,7 @@ func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *re
return GetImageWithCache(context.Background(), prompt, cache) return GetImageWithCache(context.Background(), prompt, cache)
} }
} }
func GetImageMarkdown(url string) string {
return fmt.Sprintln("![image](", url, ")")
}

View File

@ -25,7 +25,7 @@ import {
DropdownMenuTrigger, DropdownMenuTrigger,
} from "./components/ui/dropdown-menu.tsx"; } from "./components/ui/dropdown-menu.tsx";
import { Toaster } from "./components/ui/toaster.tsx"; import { Toaster } from "./components/ui/toaster.tsx";
import { login } from "./conf.ts"; import {login, tokenField} from "./conf.ts";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
function Settings() { function Settings() {
@ -67,7 +67,7 @@ function NavBar() {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useDispatch(); const dispatch = useDispatch();
useEffect(() => { useEffect(() => {
validateToken(dispatch, localStorage.getItem("token") ?? ""); validateToken(dispatch, localStorage.getItem(tokenField) ?? "");
}, []); }, []);
const auth = useSelector(selectAuthenticated); const auth = useSelector(selectAuthenticated);

View File

@ -13,8 +13,8 @@
.message { .message {
display: flex; display: flex;
gap: 6px;
flex-direction: column; flex-direction: column;
width: 100%;
&:last-child { &:last-child {
animation: FlexInAnimationFromBottom 0.2s cubic-bezier(0.175, 0.885, 0.32, 1.275) 0s 1 normal forwards running; animation: FlexInAnimationFromBottom 0.2s cubic-bezier(0.175, 0.885, 0.32, 1.275) 0s 1 normal forwards running;
@ -29,6 +29,36 @@
} }
} }
.message-quota {
display: flex;
flex-direction: row;
align-items: center;
user-select: none;
gap: 4px;
cursor: pointer;
border: 1px solid hsl(var(--input));
border-radius: var(--radius);
transition: 0.2s linear;
padding: 4px 8px;
width: max-content;
height: max-content;
white-space: nowrap;
.quota {
font-size: 14px;
color: hsl(var(--text-secondary));
}
.icon {
transform: translateY(1px);
color: hsl(var(--text-secondary));
}
&:hover {
border-color: hsl(var(--border-hover));
}
}
.message-content { .message-content {
padding: 8px 16px; padding: 8px 16px;
border-radius: var(--radius); border-radius: var(--radius);

View File

@ -64,3 +64,9 @@ strong {
color: hsl(var(--text)); color: hsl(var(--text));
border: 1px solid hsl(var(--border)); border: 1px solid hsl(var(--border));
} }
.icon-tooltip {
display: flex;
flex-direction: row;
align-items: center;
}

View File

@ -1,6 +1,6 @@
import { Message } from "../conversation/types.ts"; import { Message } from "../conversation/types.ts";
import Markdown from "./Markdown.tsx"; import Markdown from "./Markdown.tsx";
import {Copy, File, Loader2, MousePointerSquare} from "lucide-react"; import {Cloud, CloudFog, Copy, File, Loader2, MousePointerSquare} from "lucide-react";
import { import {
ContextMenu, ContextMenu,
ContextMenuContent, ContextMenuContent,
@ -9,6 +9,7 @@ import {
} from "./ui/context-menu.tsx"; } from "./ui/context-menu.tsx";
import {copyClipboard, saveAsFile, useInputValue} from "../utils.ts"; import {copyClipboard, saveAsFile, useInputValue} from "../utils.ts";
import {useTranslation} from "react-i18next"; import {useTranslation} from "react-i18next";
import {Tooltip, TooltipContent, TooltipProvider, TooltipTrigger} from "./ui/tooltip.tsx";
type MessageProps = { type MessageProps = {
message: Message; message: Message;
@ -22,6 +23,24 @@ function MessageSegment({ message }: MessageProps) {
<ContextMenuTrigger asChild> <ContextMenuTrigger asChild>
<div className={`message ${message.role}`}> <div className={`message ${message.role}`}>
<MessageContent message={message} /> <MessageContent message={message} />
{
(message.quota && message.quota > 0) ?
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div className={`message-quota`}>
<Cloud className={`h-4 w-4 icon`} />
<span className={`quota`}>{message.quota.toFixed(2)}</span>
</div>
</TooltipTrigger>
<TooltipContent className={`icon-tooltip`}>
<CloudFog className={`h-4 w-4 mr-2`} />
<p>{ t('quota-description') }</p>
</TooltipContent>
</Tooltip>
</TooltipProvider>
: null
}
</div> </div>
</ContextMenuTrigger> </ContextMenuTrigger>
<ContextMenuContent> <ContextMenuContent>
@ -41,6 +60,7 @@ function MessageSegment({ message }: MessageProps) {
function MessageContent({ message }: MessageProps) { function MessageContent({ message }: MessageProps) {
return ( return (
<>
<div className={`message-content`}> <div className={`message-content`}>
{message.keyword && message.keyword.length ? ( {message.keyword && message.keyword.length ? (
<div className={`bing`}> <div className={`bing`}>
@ -114,6 +134,7 @@ function MessageContent({ message }: MessageProps) {
<Loader2 className={`h-5 w-5 m-1 animate-spin`} /> <Loader2 className={`h-5 w-5 m-1 animate-spin`} />
)} )}
</div> </div>
</>
) )
} }

View File

@ -5,10 +5,12 @@ export let rest_api: string = "http://localhost:8094";
export let ws_api: string = "ws://localhost:8094"; export let ws_api: string = "ws://localhost:8094";
if (deploy) { if (deploy) {
rest_api = "https://nioapi.fystart.cn"; rest_api = "https://api.chatnio.net";
ws_api = "wss://nioapi.fystart.cn"; ws_api = "wss://api.chatnio.net";
} }
export const tokenField = deploy ? "token" : "token-dev";
export function login() { export function login() {
location.href = "https://deeptrain.lightxi.com/login?app=chatnio"; location.href = "https://deeptrain.lightxi.com/login?app=chatnio";
} }

View File

@ -1,9 +1,10 @@
import { ws_api } from "../conf.ts"; import {tokenField, ws_api} from "../conf.ts";
export const endpoint = `${ws_api}/chat`; export const endpoint = `${ws_api}/chat`;
export type StreamMessage = { export type StreamMessage = {
keyword?: string; keyword?: string;
quota?: number;
message: string; message: string;
end: boolean; end: boolean;
}; };
@ -35,7 +36,7 @@ export class Connection {
this.connection.onopen = () => { this.connection.onopen = () => {
this.state = true; this.state = true;
this.send({ this.send({
token: localStorage.getItem("token") || "", token: localStorage.getItem(tokenField) || "",
id: this.id, id: this.id,
}); });
}; };

View File

@ -79,9 +79,10 @@ export class Conversation {
this.triggerCallback(); this.triggerCallback();
} }
public updateMessage(idx: number, message: string, keyword?: string) { public updateMessage(idx: number, message: string, keyword?: string, quota?: number) {
this.data[idx].content += message; this.data[idx].content += message;
if (keyword) this.data[idx].keyword = keyword; if (keyword) this.data[idx].keyword = keyword;
if (quota) this.data[idx].quota = quota;
this.triggerCallback(); this.triggerCallback();
} }
@ -92,7 +93,7 @@ export class Conversation {
}); });
return (message: StreamMessage) => { return (message: StreamMessage) => {
this.updateMessage(cursor, message.message, message.keyword); this.updateMessage(cursor, message.message, message.keyword, message.quota);
if (message.end) { if (message.end) {
this.end = true; this.end = true;
} }

View File

@ -3,6 +3,7 @@ import { Conversation } from "./conversation.ts";
export type Message = { export type Message = {
content: string; content: string;
keyword?: string; keyword?: string;
quota?: number;
role: string; role: string;
}; };

View File

@ -51,7 +51,8 @@ const resources = {
"copy": "Copy", "copy": "Copy",
"save": "Save as File", "save": "Save as File",
"use": "Use Message", "use": "Use Message",
} },
"quota-description": "spending quota for the message",
}, },
}, },
cn: { cn: {
@ -94,7 +95,8 @@ const resources = {
"copy": "复制", "copy": "复制",
"save": "保存为文件", "save": "保存为文件",
"use": "使用消息", "use": "使用消息",
} },
"quota-description": "消息的配额支出",
}, },
}, },
}; };

View File

@ -1,7 +1,7 @@
import { useToast } from "../components/ui/use-toast.ts"; import { useToast } from "../components/ui/use-toast.ts";
import { useLocation } from "react-router-dom"; import { useLocation } from "react-router-dom";
import { ToastAction } from "../components/ui/toast.tsx"; import { ToastAction } from "../components/ui/toast.tsx";
import { login } from "../conf.ts"; import {login, tokenField} from "../conf.ts";
import { useEffect } from "react"; import { useEffect } from "react";
import Loader from "../components/Loader.tsx"; import Loader from "../components/Loader.tsx";
import "../assets/auth.less"; import "../assets/auth.less";
@ -16,7 +16,7 @@ function Auth() {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useDispatch(); const dispatch = useDispatch();
const search = new URLSearchParams(useLocation().search); const search = new URLSearchParams(useLocation().search);
const token = (search.get("token") || "").trim(); const token = (search.get(tokenField) || "").trim();
if (!token.length) { if (!token.length) {
toast({ toast({

View File

@ -137,7 +137,8 @@ function SideBar() {
{t("conversation.cancel")} {t("conversation.cancel")}
</AlertDialogCancel> </AlertDialogCancel>
<AlertDialogAction <AlertDialogAction
onClick={async () => { onClick={async (e) => {
e.preventDefault();
if ( if (
await deleteConversation(dispatch, conversation.id) await deleteConversation(dispatch, conversation.id)
) )

View File

@ -1,5 +1,6 @@
import { createSlice } from "@reduxjs/toolkit"; import { createSlice } from "@reduxjs/toolkit";
import axios from "axios"; import axios from "axios";
import {tokenField} from "../conf.ts";
export const authSlice = createSlice({ export const authSlice = createSlice({
name: "auth", name: "auth",
@ -12,7 +13,7 @@ export const authSlice = createSlice({
setToken: (state, action) => { setToken: (state, action) => {
state.token = action.payload as string; state.token = action.payload as string;
axios.defaults.headers.common["Authorization"] = state.token; axios.defaults.headers.common["Authorization"] = state.token;
localStorage.setItem("token", state.token); localStorage.setItem(tokenField, state.token);
}, },
setAuthenticated: (state, action) => { setAuthenticated: (state, action) => {
state.authenticated = action.payload as boolean; state.authenticated = action.payload as boolean;
@ -25,7 +26,7 @@ export const authSlice = createSlice({
state.authenticated = false; state.authenticated = false;
state.username = ""; state.username = "";
axios.defaults.headers.common["Authorization"] = ""; axios.defaults.headers.common["Authorization"] = "";
localStorage.removeItem("token"); localStorage.removeItem(tokenField);
location.reload(); location.reload();
}, },

View File

@ -4,6 +4,29 @@ import (
"database/sql" "database/sql"
) )
// Price Calculation
// 10 nio points = ¥1
// from 2023-9-6, 1 USD = 7.3124 CNY
//
// GPT-4 price (8k-context)
// Input Output
// $0.03 / 1K tokens $0.06 / 1K tokens
// ¥0.21 / 1K tokens ¥0.43 / 1K tokens
// 2.1 nio / 1K tokens 4.3 nio / 1K tokens
// Dalle price (512x512)
// $0.018 / per image
// ¥0.13 / per image
// 1 nio / per image
func CountInputToken(n int) float32 {
return float32(n) / 1000 * 2.1
}
func CountOutputToken(n int) float32 {
return float32(n) / 1000 * 4.3
}
func ReduceUsage(db *sql.DB, user *User, _t string) bool { func ReduceUsage(db *sql.DB, user *User, _t string) bool {
id := user.GetID(db) id := user.GetID(db)
var count int var count int
@ -82,7 +105,7 @@ func BuyDalle(db *sql.DB, user *User, value int) bool {
return true return true
} }
func CountGPT4Prize(value int) float32 { func CountGPT4price(value int) float32 {
if value <= 20 { if value <= 20 {
return float32(value) * 0.5 return float32(value) * 0.5
} }
@ -91,7 +114,7 @@ func CountGPT4Prize(value int) float32 {
} }
func BuyGPT4(db *sql.DB, user *User, value int) bool { func BuyGPT4(db *sql.DB, user *User, value int) bool {
if !Pay(user.Username, CountGPT4Prize(value)) { if !Pay(user.Username, CountGPT4price(value)) {
return false return false
} }

View File

@ -14,11 +14,13 @@ type Conversation struct {
Name string `json:"name"` Name string `json:"name"`
Message []types.ChatGPTMessage `json:"message"` Message []types.ChatGPTMessage `json:"message"`
EnableGPT4 bool `json:"enable_gpt4"` EnableGPT4 bool `json:"enable_gpt4"`
EnableWeb bool `json:"enable_web"`
} }
type FormMessage struct { type FormMessage struct {
Type string `json:"type"` // ping Type string `json:"type"` // ping
Message string `json:"message"` Message string `json:"message"`
Web bool `json:"web"`
GPT4 bool `json:"gpt4"` GPT4 bool `json:"gpt4"`
} }
@ -29,6 +31,7 @@ func NewConversation(db *sql.DB, id int64) *Conversation {
Name: "new chat", Name: "new chat",
Message: []types.ChatGPTMessage{}, Message: []types.ChatGPTMessage{},
EnableGPT4: false, EnableGPT4: false,
EnableWeb: false,
} }
} }
@ -36,10 +39,18 @@ func (c *Conversation) IsEnableGPT4() bool {
return c.EnableGPT4 return c.EnableGPT4
} }
func (c *Conversation) IsEnableWeb() bool {
return c.EnableWeb
}
func (c *Conversation) SetEnableGPT4(enable bool) { func (c *Conversation) SetEnableGPT4(enable bool) {
c.EnableGPT4 = enable c.EnableGPT4 = enable
} }
func (c *Conversation) SetEnableWeb(enable bool) {
c.EnableWeb = enable
}
func (c *Conversation) GetName() string { func (c *Conversation) GetName() string {
return c.Name return c.Name
} }
@ -131,6 +142,7 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
c.AddMessageFromUser(form.Message) c.AddMessageFromUser(form.Message)
c.SetEnableGPT4(form.GPT4) c.SetEnableGPT4(form.GPT4)
c.SetEnableWeb(form.Web)
return form.Message, nil return form.Message, nil
} }

3
go.mod
View File

@ -15,6 +15,7 @@ require (
github.com/cespare/xxhash/v2 v2.1.2 // indirect github.com/cespare/xxhash/v2 v2.1.2 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/dlclark/regexp2 v1.10.0 // indirect
github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
@ -22,6 +23,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-json v0.10.2 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/gorilla/websocket v1.5.0 // indirect github.com/gorilla/websocket v1.5.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect github.com/json-iterator/go v1.1.12 // indirect
@ -33,6 +35,7 @@ require (
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pkoukk/tiktoken-go v0.1.5 // indirect
github.com/spf13/afero v1.9.5 // indirect github.com/spf13/afero v1.9.5 // indirect
github.com/spf13/cast v1.5.1 // indirect github.com/spf13/cast v1.5.1 // indirect
github.com/spf13/jwalterweatherman v1.1.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect

6
go.sum
View File

@ -61,6 +61,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
@ -147,6 +149,8 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g= github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
@ -192,6 +196,8 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=

View File

@ -34,6 +34,7 @@ type ChatGPTStreamResponse struct {
} }
type ChatGPTSegmentResponse struct { type ChatGPTSegmentResponse struct {
Quota float32 `json:"quota"`
Keyword string `json:"keyword"` Keyword string `json:"keyword"`
Message string `json:"message"` Message string `json:"message"`
End bool `json:"end"` End bool `json:"end"`

View File

@ -37,6 +37,14 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
return form, err return form, err
} }
func Marshal[T interface{}](data T) string {
res, err := json.Marshal(data)
if err != nil {
return ""
}
return string(res)
}
func ToInt(value string) int { func ToInt(value string) int {
if res, err := strconv.Atoi(value); err == nil { if res, err := strconv.Atoi(value); err == nil {
return res return res

60
utils/tokenizer.go Normal file
View File

@ -0,0 +1,60 @@
package utils
import (
"chat/types"
"fmt"
"github.com/pkoukk/tiktoken-go"
"strings"
)
// Using https://github.com/pkoukk/tiktoken-go
// To count number of tokens of openai chat messages
// OpenAI Cookbook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
func GetWeightByModel(model string) int {
switch model {
case "gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613":
return 3
case "gpt-3.5-turbo-0301":
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
default:
if strings.Contains(model, "gpt-3.5-turbo") {
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
return GetWeightByModel("gpt-3.5-turbo-0613")
} else if strings.Contains(model, "gpt-4") {
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
return GetWeightByModel("gpt-4-0613")
} else {
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
panic(fmt.Errorf("not implemented for model %s", model))
}
}
}
func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (tokens int) {
weight := GetWeightByModel(model)
tkm, err := tiktoken.EncodingForModel(model)
if err != nil {
// can not encode messages, use length of messages as a proxy for number of tokens
// using rune instead of byte to account for unicode characters (e.g. emojis, non-english characters)
data := Marshal(messages)
return len([]rune(data)) * weight
}
for _, message := range messages {
tokens += weight
tokens += len(tkm.Encode(message.Content, nil, nil))
tokens += len(tkm.Encode(message.Role, nil, nil))
}
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
return tokens
}
func CountTokenPrice(messages []types.ChatGPTMessage) int {
return NumTokensFromMessages(messages, "gpt-4")
}