mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
update quota
This commit is contained in:
parent
806cf3a048
commit
8958afca68
@ -14,6 +14,7 @@ import (
|
||||
|
||||
type AnonymousRequestBody struct {
|
||||
Message string `json:"message" required:"true"`
|
||||
Web bool `json:"web"`
|
||||
}
|
||||
|
||||
type AnonymousResponseCache struct {
|
||||
@ -57,18 +58,22 @@ func TestKey(key string) bool {
|
||||
return res.(map[string]interface{})["choices"] != nil
|
||||
}
|
||||
|
||||
func GetAnonymousResponse(message string) (string, string, error) {
|
||||
func GetAnonymousResponse(message string, web bool) (string, string, error) {
|
||||
if !web {
|
||||
resp, err := GetChatGPTResponse([]types.ChatGPTMessage{{Role: "user", Content: message}}, 1000)
|
||||
return "", resp, err
|
||||
}
|
||||
keyword, source := ChatWithWeb([]types.ChatGPTMessage{{Role: "user", Content: message}}, false)
|
||||
resp, err := GetChatGPTResponse(source, 1000)
|
||||
return keyword, resp, err
|
||||
}
|
||||
|
||||
func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, string, error) {
|
||||
func GetAnonymousResponseWithCache(c *gin.Context, message string, web bool) (string, string, error) {
|
||||
cache := c.MustGet("cache").(*redis.Client)
|
||||
res, err := cache.Get(c, fmt.Sprintf(":chatgpt:%s", message)).Result()
|
||||
res, err := cache.Get(c, fmt.Sprintf(":chatgpt-%v:%s", web, message)).Result()
|
||||
form := utils.UnmarshalJson[AnonymousResponseCache](res)
|
||||
if err != nil || len(res) == 0 || res == "{}" || form.Message == "" {
|
||||
key, res, err := GetAnonymousResponse(message)
|
||||
key, res, err := GetAnonymousResponse(message, web)
|
||||
if err != nil {
|
||||
return "", "There was something wrong...", err
|
||||
}
|
||||
@ -76,7 +81,7 @@ func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, stri
|
||||
cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), utils.ToJson(AnonymousResponseCache{
|
||||
Keyword: key,
|
||||
Message: res,
|
||||
}), time.Hour*6)
|
||||
}), time.Hour*48)
|
||||
return key, res, nil
|
||||
}
|
||||
return form.Keyword, form.Message, nil
|
||||
@ -103,7 +108,7 @@ func AnonymousAPI(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
key, res, err := GetAnonymousResponseWithCache(c, message)
|
||||
key, res, err := GetAnonymousResponseWithCache(c, message, body.Web)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": false,
|
||||
|
73
api/buffer.go
Normal file
73
api/buffer.go
Normal file
@ -0,0 +1,73 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"chat/auth"
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
)
|
||||
|
||||
type Buffer struct {
|
||||
Enable bool `json:"enable"`
|
||||
Quota float32 `json:"quota"`
|
||||
Data string `json:"data"`
|
||||
Cursor int `json:"cursor"`
|
||||
Times int `json:"times"`
|
||||
}
|
||||
|
||||
func NewBuffer(enable bool, history []types.ChatGPTMessage) *Buffer {
|
||||
buffer := &Buffer{Data: "", Cursor: 0, Times: 0, Enable: enable}
|
||||
if enable {
|
||||
buffer.Quota = auth.CountInputToken(utils.CountTokenPrice(history))
|
||||
}
|
||||
return buffer
|
||||
}
|
||||
|
||||
func (b *Buffer) GetCursor() int {
|
||||
return b.Cursor
|
||||
}
|
||||
|
||||
func (b *Buffer) GetQuota() float32 {
|
||||
if !b.Enable {
|
||||
return 0.
|
||||
}
|
||||
return b.Quota + auth.CountOutputToken(b.ReadTimes())
|
||||
}
|
||||
|
||||
func (b *Buffer) Write(data string) string {
|
||||
b.Data += data
|
||||
b.Cursor += len(data)
|
||||
b.Times++
|
||||
return data
|
||||
}
|
||||
|
||||
func (b *Buffer) WriteBytes(data []byte) []byte {
|
||||
b.Data += string(data)
|
||||
b.Cursor += len(data)
|
||||
b.Times++
|
||||
return data
|
||||
}
|
||||
|
||||
func (b *Buffer) IsEmpty() bool {
|
||||
return b.Cursor == 0
|
||||
}
|
||||
|
||||
func (b *Buffer) Reset() {
|
||||
b.Data = ""
|
||||
b.Cursor = 0
|
||||
b.Times = 0
|
||||
}
|
||||
|
||||
func (b *Buffer) Read() string {
|
||||
return b.Data
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadWithDefault(_default string) string {
|
||||
if b.IsEmpty() {
|
||||
return _default
|
||||
}
|
||||
return b.Data
|
||||
}
|
||||
|
||||
func (b *Buffer) ReadTimes() int {
|
||||
return b.Times
|
||||
}
|
34
api/chat.go
34
api/chat.go
@ -7,7 +7,6 @@ import (
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -15,6 +14,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const defaultMessage = "There was something wrong... Please try again later."
|
||||
|
||||
type WebsocketAuthForm struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
Id int64 `json:"id" binding:"required"`
|
||||
@ -25,39 +26,47 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon
|
||||
}
|
||||
|
||||
func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
|
||||
keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
|
||||
var keyword string
|
||||
var segment []types.ChatGPTMessage
|
||||
|
||||
msg := ""
|
||||
if instance.IsEnableWeb() {
|
||||
keyword, segment = ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
|
||||
} else {
|
||||
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
|
||||
}
|
||||
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
|
||||
|
||||
if instance.IsEnableGPT4() && !auth.ReduceGPT4(db, user) {
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||
Message: "You have run out of GPT-4 usage. Please buy more.",
|
||||
Quota: 0,
|
||||
End: true,
|
||||
})
|
||||
return "You have run out of GPT-4 usage. Please buy more."
|
||||
}
|
||||
|
||||
buffer := NewBuffer(instance.IsEnableGPT4(), segment)
|
||||
StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) {
|
||||
msg += resp
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||
Message: resp,
|
||||
Message: buffer.Write(resp),
|
||||
Quota: buffer.GetQuota(),
|
||||
End: false,
|
||||
})
|
||||
})
|
||||
if msg == "" {
|
||||
msg = "There was something wrong... Please try again later."
|
||||
if buffer.IsEmpty() {
|
||||
if instance.IsEnableGPT4() {
|
||||
auth.IncreaseGPT4(db, user, 1)
|
||||
}
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||
Message: msg,
|
||||
Message: defaultMessage,
|
||||
Quota: buffer.GetQuota(),
|
||||
End: false,
|
||||
})
|
||||
}
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true})
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true, Quota: buffer.GetQuota()})
|
||||
|
||||
return msg
|
||||
return buffer.ReadWithDefault(defaultMessage)
|
||||
}
|
||||
|
||||
func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
|
||||
@ -84,10 +93,9 @@ func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
markdown := fmt.Sprintln("")
|
||||
markdown := GetImageMarkdown(url)
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||
Message: markdown,
|
||||
Keyword: "image",
|
||||
End: true,
|
||||
})
|
||||
return markdown
|
||||
|
@ -70,3 +70,7 @@ func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *re
|
||||
return GetImageWithCache(context.Background(), prompt, cache)
|
||||
}
|
||||
}
|
||||
|
||||
func GetImageMarkdown(url string) string {
|
||||
return fmt.Sprintln("")
|
||||
}
|
||||
|
@ -25,7 +25,7 @@ import {
|
||||
DropdownMenuTrigger,
|
||||
} from "./components/ui/dropdown-menu.tsx";
|
||||
import { Toaster } from "./components/ui/toaster.tsx";
|
||||
import { login } from "./conf.ts";
|
||||
import {login, tokenField} from "./conf.ts";
|
||||
import { useTranslation } from "react-i18next";
|
||||
|
||||
function Settings() {
|
||||
@ -67,7 +67,7 @@ function NavBar() {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
useEffect(() => {
|
||||
validateToken(dispatch, localStorage.getItem("token") ?? "");
|
||||
validateToken(dispatch, localStorage.getItem(tokenField) ?? "");
|
||||
}, []);
|
||||
const auth = useSelector(selectAuthenticated);
|
||||
|
||||
|
@ -13,8 +13,8 @@
|
||||
|
||||
.message {
|
||||
display: flex;
|
||||
gap: 6px;
|
||||
flex-direction: column;
|
||||
width: 100%;
|
||||
|
||||
&:last-child {
|
||||
animation: FlexInAnimationFromBottom 0.2s cubic-bezier(0.175, 0.885, 0.32, 1.275) 0s 1 normal forwards running;
|
||||
@ -29,6 +29,36 @@
|
||||
}
|
||||
}
|
||||
|
||||
.message-quota {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
user-select: none;
|
||||
gap: 4px;
|
||||
cursor: pointer;
|
||||
border: 1px solid hsl(var(--input));
|
||||
border-radius: var(--radius);
|
||||
transition: 0.2s linear;
|
||||
padding: 4px 8px;
|
||||
width: max-content;
|
||||
height: max-content;
|
||||
white-space: nowrap;
|
||||
|
||||
.quota {
|
||||
font-size: 14px;
|
||||
color: hsl(var(--text-secondary));
|
||||
}
|
||||
|
||||
.icon {
|
||||
transform: translateY(1px);
|
||||
color: hsl(var(--text-secondary));
|
||||
}
|
||||
|
||||
&:hover {
|
||||
border-color: hsl(var(--border-hover));
|
||||
}
|
||||
}
|
||||
|
||||
.message-content {
|
||||
padding: 8px 16px;
|
||||
border-radius: var(--radius);
|
||||
|
@ -64,3 +64,9 @@ strong {
|
||||
color: hsl(var(--text));
|
||||
border: 1px solid hsl(var(--border));
|
||||
}
|
||||
|
||||
.icon-tooltip {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
}
|
||||
|
@ -1,6 +1,6 @@
|
||||
import { Message } from "../conversation/types.ts";
|
||||
import Markdown from "./Markdown.tsx";
|
||||
import {Copy, File, Loader2, MousePointerSquare} from "lucide-react";
|
||||
import {Cloud, CloudFog, Copy, File, Loader2, MousePointerSquare} from "lucide-react";
|
||||
import {
|
||||
ContextMenu,
|
||||
ContextMenuContent,
|
||||
@ -9,6 +9,7 @@ import {
|
||||
} from "./ui/context-menu.tsx";
|
||||
import {copyClipboard, saveAsFile, useInputValue} from "../utils.ts";
|
||||
import {useTranslation} from "react-i18next";
|
||||
import {Tooltip, TooltipContent, TooltipProvider, TooltipTrigger} from "./ui/tooltip.tsx";
|
||||
|
||||
type MessageProps = {
|
||||
message: Message;
|
||||
@ -22,6 +23,24 @@ function MessageSegment({ message }: MessageProps) {
|
||||
<ContextMenuTrigger asChild>
|
||||
<div className={`message ${message.role}`}>
|
||||
<MessageContent message={message} />
|
||||
{
|
||||
(message.quota && message.quota > 0) ?
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className={`message-quota`}>
|
||||
<Cloud className={`h-4 w-4 icon`} />
|
||||
<span className={`quota`}>{message.quota.toFixed(2)}</span>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent className={`icon-tooltip`}>
|
||||
<CloudFog className={`h-4 w-4 mr-2`} />
|
||||
<p>{ t('quota-description') }</p>
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
: null
|
||||
}
|
||||
</div>
|
||||
</ContextMenuTrigger>
|
||||
<ContextMenuContent>
|
||||
@ -41,6 +60,7 @@ function MessageSegment({ message }: MessageProps) {
|
||||
|
||||
function MessageContent({ message }: MessageProps) {
|
||||
return (
|
||||
<>
|
||||
<div className={`message-content`}>
|
||||
{message.keyword && message.keyword.length ? (
|
||||
<div className={`bing`}>
|
||||
@ -114,6 +134,7 @@ function MessageContent({ message }: MessageProps) {
|
||||
<Loader2 className={`h-5 w-5 m-1 animate-spin`} />
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -5,10 +5,12 @@ export let rest_api: string = "http://localhost:8094";
|
||||
export let ws_api: string = "ws://localhost:8094";
|
||||
|
||||
if (deploy) {
|
||||
rest_api = "https://nioapi.fystart.cn";
|
||||
ws_api = "wss://nioapi.fystart.cn";
|
||||
rest_api = "https://api.chatnio.net";
|
||||
ws_api = "wss://api.chatnio.net";
|
||||
}
|
||||
|
||||
export const tokenField = deploy ? "token" : "token-dev";
|
||||
|
||||
export function login() {
|
||||
location.href = "https://deeptrain.lightxi.com/login?app=chatnio";
|
||||
}
|
||||
|
@ -1,9 +1,10 @@
|
||||
import { ws_api } from "../conf.ts";
|
||||
import {tokenField, ws_api} from "../conf.ts";
|
||||
|
||||
export const endpoint = `${ws_api}/chat`;
|
||||
|
||||
export type StreamMessage = {
|
||||
keyword?: string;
|
||||
quota?: number;
|
||||
message: string;
|
||||
end: boolean;
|
||||
};
|
||||
@ -35,7 +36,7 @@ export class Connection {
|
||||
this.connection.onopen = () => {
|
||||
this.state = true;
|
||||
this.send({
|
||||
token: localStorage.getItem("token") || "",
|
||||
token: localStorage.getItem(tokenField) || "",
|
||||
id: this.id,
|
||||
});
|
||||
};
|
||||
|
@ -79,9 +79,10 @@ export class Conversation {
|
||||
this.triggerCallback();
|
||||
}
|
||||
|
||||
public updateMessage(idx: number, message: string, keyword?: string) {
|
||||
public updateMessage(idx: number, message: string, keyword?: string, quota?: number) {
|
||||
this.data[idx].content += message;
|
||||
if (keyword) this.data[idx].keyword = keyword;
|
||||
if (quota) this.data[idx].quota = quota;
|
||||
this.triggerCallback();
|
||||
}
|
||||
|
||||
@ -92,7 +93,7 @@ export class Conversation {
|
||||
});
|
||||
|
||||
return (message: StreamMessage) => {
|
||||
this.updateMessage(cursor, message.message, message.keyword);
|
||||
this.updateMessage(cursor, message.message, message.keyword, message.quota);
|
||||
if (message.end) {
|
||||
this.end = true;
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ import { Conversation } from "./conversation.ts";
|
||||
export type Message = {
|
||||
content: string;
|
||||
keyword?: string;
|
||||
quota?: number;
|
||||
role: string;
|
||||
};
|
||||
|
||||
|
@ -51,7 +51,8 @@ const resources = {
|
||||
"copy": "Copy",
|
||||
"save": "Save as File",
|
||||
"use": "Use Message",
|
||||
}
|
||||
},
|
||||
"quota-description": "spending quota for the message",
|
||||
},
|
||||
},
|
||||
cn: {
|
||||
@ -94,7 +95,8 @@ const resources = {
|
||||
"copy": "复制",
|
||||
"save": "保存为文件",
|
||||
"use": "使用消息",
|
||||
}
|
||||
},
|
||||
"quota-description": "消息的配额支出",
|
||||
},
|
||||
},
|
||||
};
|
||||
|
@ -1,7 +1,7 @@
|
||||
import { useToast } from "../components/ui/use-toast.ts";
|
||||
import { useLocation } from "react-router-dom";
|
||||
import { ToastAction } from "../components/ui/toast.tsx";
|
||||
import { login } from "../conf.ts";
|
||||
import {login, tokenField} from "../conf.ts";
|
||||
import { useEffect } from "react";
|
||||
import Loader from "../components/Loader.tsx";
|
||||
import "../assets/auth.less";
|
||||
@ -16,7 +16,7 @@ function Auth() {
|
||||
const { t } = useTranslation();
|
||||
const dispatch = useDispatch();
|
||||
const search = new URLSearchParams(useLocation().search);
|
||||
const token = (search.get("token") || "").trim();
|
||||
const token = (search.get(tokenField) || "").trim();
|
||||
|
||||
if (!token.length) {
|
||||
toast({
|
||||
|
@ -137,7 +137,8 @@ function SideBar() {
|
||||
{t("conversation.cancel")}
|
||||
</AlertDialogCancel>
|
||||
<AlertDialogAction
|
||||
onClick={async () => {
|
||||
onClick={async (e) => {
|
||||
e.preventDefault();
|
||||
if (
|
||||
await deleteConversation(dispatch, conversation.id)
|
||||
)
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { createSlice } from "@reduxjs/toolkit";
|
||||
import axios from "axios";
|
||||
import {tokenField} from "../conf.ts";
|
||||
|
||||
export const authSlice = createSlice({
|
||||
name: "auth",
|
||||
@ -12,7 +13,7 @@ export const authSlice = createSlice({
|
||||
setToken: (state, action) => {
|
||||
state.token = action.payload as string;
|
||||
axios.defaults.headers.common["Authorization"] = state.token;
|
||||
localStorage.setItem("token", state.token);
|
||||
localStorage.setItem(tokenField, state.token);
|
||||
},
|
||||
setAuthenticated: (state, action) => {
|
||||
state.authenticated = action.payload as boolean;
|
||||
@ -25,7 +26,7 @@ export const authSlice = createSlice({
|
||||
state.authenticated = false;
|
||||
state.username = "";
|
||||
axios.defaults.headers.common["Authorization"] = "";
|
||||
localStorage.removeItem("token");
|
||||
localStorage.removeItem(tokenField);
|
||||
|
||||
location.reload();
|
||||
},
|
||||
|
@ -4,6 +4,29 @@ import (
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// Price Calculation
|
||||
// 10 nio points = ¥1
|
||||
// from 2023-9-6, 1 USD = 7.3124 CNY
|
||||
//
|
||||
// GPT-4 price (8k-context)
|
||||
// Input Output
|
||||
// $0.03 / 1K tokens $0.06 / 1K tokens
|
||||
// ¥0.21 / 1K tokens ¥0.43 / 1K tokens
|
||||
// 2.1 nio / 1K tokens 4.3 nio / 1K tokens
|
||||
|
||||
// Dalle price (512x512)
|
||||
// $0.018 / per image
|
||||
// ¥0.13 / per image
|
||||
// 1 nio / per image
|
||||
|
||||
func CountInputToken(n int) float32 {
|
||||
return float32(n) / 1000 * 2.1
|
||||
}
|
||||
|
||||
func CountOutputToken(n int) float32 {
|
||||
return float32(n) / 1000 * 4.3
|
||||
}
|
||||
|
||||
func ReduceUsage(db *sql.DB, user *User, _t string) bool {
|
||||
id := user.GetID(db)
|
||||
var count int
|
||||
@ -82,7 +105,7 @@ func BuyDalle(db *sql.DB, user *User, value int) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func CountGPT4Prize(value int) float32 {
|
||||
func CountGPT4price(value int) float32 {
|
||||
if value <= 20 {
|
||||
return float32(value) * 0.5
|
||||
}
|
||||
@ -91,7 +114,7 @@ func CountGPT4Prize(value int) float32 {
|
||||
}
|
||||
|
||||
func BuyGPT4(db *sql.DB, user *User, value int) bool {
|
||||
if !Pay(user.Username, CountGPT4Prize(value)) {
|
||||
if !Pay(user.Username, CountGPT4price(value)) {
|
||||
return false
|
||||
}
|
||||
|
||||
|
@ -14,11 +14,13 @@ type Conversation struct {
|
||||
Name string `json:"name"`
|
||||
Message []types.ChatGPTMessage `json:"message"`
|
||||
EnableGPT4 bool `json:"enable_gpt4"`
|
||||
EnableWeb bool `json:"enable_web"`
|
||||
}
|
||||
|
||||
type FormMessage struct {
|
||||
Type string `json:"type"` // ping
|
||||
Message string `json:"message"`
|
||||
Web bool `json:"web"`
|
||||
GPT4 bool `json:"gpt4"`
|
||||
}
|
||||
|
||||
@ -29,6 +31,7 @@ func NewConversation(db *sql.DB, id int64) *Conversation {
|
||||
Name: "new chat",
|
||||
Message: []types.ChatGPTMessage{},
|
||||
EnableGPT4: false,
|
||||
EnableWeb: false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -36,10 +39,18 @@ func (c *Conversation) IsEnableGPT4() bool {
|
||||
return c.EnableGPT4
|
||||
}
|
||||
|
||||
func (c *Conversation) IsEnableWeb() bool {
|
||||
return c.EnableWeb
|
||||
}
|
||||
|
||||
func (c *Conversation) SetEnableGPT4(enable bool) {
|
||||
c.EnableGPT4 = enable
|
||||
}
|
||||
|
||||
func (c *Conversation) SetEnableWeb(enable bool) {
|
||||
c.EnableWeb = enable
|
||||
}
|
||||
|
||||
func (c *Conversation) GetName() string {
|
||||
return c.Name
|
||||
}
|
||||
@ -131,6 +142,7 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
|
||||
|
||||
c.AddMessageFromUser(form.Message)
|
||||
c.SetEnableGPT4(form.GPT4)
|
||||
c.SetEnableWeb(form.Web)
|
||||
return form.Message, nil
|
||||
}
|
||||
|
||||
|
3
go.mod
3
go.mod
@ -15,6 +15,7 @@ require (
|
||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
@ -22,6 +23,7 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||
github.com/goccy/go-json v0.10.2 // indirect
|
||||
github.com/google/uuid v1.3.0 // indirect
|
||||
github.com/gorilla/websocket v1.5.0 // indirect
|
||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
@ -33,6 +35,7 @@ require (
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/pkoukk/tiktoken-go v0.1.5 // indirect
|
||||
github.com/spf13/afero v1.9.5 // indirect
|
||||
github.com/spf13/cast v1.5.1 // indirect
|
||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||
|
6
go.sum
6
go.sum
@ -61,6 +61,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC
|
||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||
@ -147,6 +149,8 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe
|
||||
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||
@ -192,6 +196,8 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
||||
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
|
||||
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
|
@ -34,6 +34,7 @@ type ChatGPTStreamResponse struct {
|
||||
}
|
||||
|
||||
type ChatGPTSegmentResponse struct {
|
||||
Quota float32 `json:"quota"`
|
||||
Keyword string `json:"keyword"`
|
||||
Message string `json:"message"`
|
||||
End bool `json:"end"`
|
||||
|
@ -37,6 +37,14 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
|
||||
return form, err
|
||||
}
|
||||
|
||||
func Marshal[T interface{}](data T) string {
|
||||
res, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(res)
|
||||
}
|
||||
|
||||
func ToInt(value string) int {
|
||||
if res, err := strconv.Atoi(value); err == nil {
|
||||
return res
|
||||
|
60
utils/tokenizer.go
Normal file
60
utils/tokenizer.go
Normal file
@ -0,0 +1,60 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"chat/types"
|
||||
"fmt"
|
||||
"github.com/pkoukk/tiktoken-go"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Using https://github.com/pkoukk/tiktoken-go
|
||||
// To count number of tokens of openai chat messages
|
||||
// OpenAI Cookbook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
|
||||
func GetWeightByModel(model string) int {
|
||||
switch model {
|
||||
case "gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613":
|
||||
return 3
|
||||
case "gpt-3.5-turbo-0301":
|
||||
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
default:
|
||||
if strings.Contains(model, "gpt-3.5-turbo") {
|
||||
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
|
||||
return GetWeightByModel("gpt-3.5-turbo-0613")
|
||||
} else if strings.Contains(model, "gpt-4") {
|
||||
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
|
||||
return GetWeightByModel("gpt-4-0613")
|
||||
} else {
|
||||
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
|
||||
panic(fmt.Errorf("not implemented for model %s", model))
|
||||
}
|
||||
}
|
||||
}
|
||||
func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (tokens int) {
|
||||
weight := GetWeightByModel(model)
|
||||
tkm, err := tiktoken.EncodingForModel(model)
|
||||
if err != nil {
|
||||
// can not encode messages, use length of messages as a proxy for number of tokens
|
||||
// using rune instead of byte to account for unicode characters (e.g. emojis, non-english characters)
|
||||
|
||||
data := Marshal(messages)
|
||||
return len([]rune(data)) * weight
|
||||
}
|
||||
|
||||
for _, message := range messages {
|
||||
tokens += weight
|
||||
tokens += len(tkm.Encode(message.Content, nil, nil))
|
||||
tokens += len(tkm.Encode(message.Role, nil, nil))
|
||||
}
|
||||
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
|
||||
return tokens
|
||||
}
|
||||
|
||||
func CountTokenPrice(messages []types.ChatGPTMessage) int {
|
||||
return NumTokensFromMessages(messages, "gpt-4")
|
||||
}
|
Loading…
Reference in New Issue
Block a user