mirror of
https://github.com/coaidev/coai.git
synced 2025-05-20 05:20:15 +09:00
update quota
This commit is contained in:
parent
806cf3a048
commit
8958afca68
@ -14,6 +14,7 @@ import (
|
|||||||
|
|
||||||
type AnonymousRequestBody struct {
|
type AnonymousRequestBody struct {
|
||||||
Message string `json:"message" required:"true"`
|
Message string `json:"message" required:"true"`
|
||||||
|
Web bool `json:"web"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type AnonymousResponseCache struct {
|
type AnonymousResponseCache struct {
|
||||||
@ -57,18 +58,22 @@ func TestKey(key string) bool {
|
|||||||
return res.(map[string]interface{})["choices"] != nil
|
return res.(map[string]interface{})["choices"] != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAnonymousResponse(message string) (string, string, error) {
|
func GetAnonymousResponse(message string, web bool) (string, string, error) {
|
||||||
|
if !web {
|
||||||
|
resp, err := GetChatGPTResponse([]types.ChatGPTMessage{{Role: "user", Content: message}}, 1000)
|
||||||
|
return "", resp, err
|
||||||
|
}
|
||||||
keyword, source := ChatWithWeb([]types.ChatGPTMessage{{Role: "user", Content: message}}, false)
|
keyword, source := ChatWithWeb([]types.ChatGPTMessage{{Role: "user", Content: message}}, false)
|
||||||
resp, err := GetChatGPTResponse(source, 1000)
|
resp, err := GetChatGPTResponse(source, 1000)
|
||||||
return keyword, resp, err
|
return keyword, resp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, string, error) {
|
func GetAnonymousResponseWithCache(c *gin.Context, message string, web bool) (string, string, error) {
|
||||||
cache := c.MustGet("cache").(*redis.Client)
|
cache := c.MustGet("cache").(*redis.Client)
|
||||||
res, err := cache.Get(c, fmt.Sprintf(":chatgpt:%s", message)).Result()
|
res, err := cache.Get(c, fmt.Sprintf(":chatgpt-%v:%s", web, message)).Result()
|
||||||
form := utils.UnmarshalJson[AnonymousResponseCache](res)
|
form := utils.UnmarshalJson[AnonymousResponseCache](res)
|
||||||
if err != nil || len(res) == 0 || res == "{}" || form.Message == "" {
|
if err != nil || len(res) == 0 || res == "{}" || form.Message == "" {
|
||||||
key, res, err := GetAnonymousResponse(message)
|
key, res, err := GetAnonymousResponse(message, web)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "There was something wrong...", err
|
return "", "There was something wrong...", err
|
||||||
}
|
}
|
||||||
@ -76,7 +81,7 @@ func GetAnonymousResponseWithCache(c *gin.Context, message string) (string, stri
|
|||||||
cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), utils.ToJson(AnonymousResponseCache{
|
cache.Set(c, fmt.Sprintf(":chatgpt:%s", message), utils.ToJson(AnonymousResponseCache{
|
||||||
Keyword: key,
|
Keyword: key,
|
||||||
Message: res,
|
Message: res,
|
||||||
}), time.Hour*6)
|
}), time.Hour*48)
|
||||||
return key, res, nil
|
return key, res, nil
|
||||||
}
|
}
|
||||||
return form.Keyword, form.Message, nil
|
return form.Keyword, form.Message, nil
|
||||||
@ -103,7 +108,7 @@ func AnonymousAPI(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
key, res, err := GetAnonymousResponseWithCache(c, message)
|
key, res, err := GetAnonymousResponseWithCache(c, message, body.Web)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"status": false,
|
"status": false,
|
||||||
|
73
api/buffer.go
Normal file
73
api/buffer.go
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
package api
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/auth"
|
||||||
|
"chat/types"
|
||||||
|
"chat/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Buffer struct {
|
||||||
|
Enable bool `json:"enable"`
|
||||||
|
Quota float32 `json:"quota"`
|
||||||
|
Data string `json:"data"`
|
||||||
|
Cursor int `json:"cursor"`
|
||||||
|
Times int `json:"times"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBuffer(enable bool, history []types.ChatGPTMessage) *Buffer {
|
||||||
|
buffer := &Buffer{Data: "", Cursor: 0, Times: 0, Enable: enable}
|
||||||
|
if enable {
|
||||||
|
buffer.Quota = auth.CountInputToken(utils.CountTokenPrice(history))
|
||||||
|
}
|
||||||
|
return buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetCursor() int {
|
||||||
|
return b.Cursor
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) GetQuota() float32 {
|
||||||
|
if !b.Enable {
|
||||||
|
return 0.
|
||||||
|
}
|
||||||
|
return b.Quota + auth.CountOutputToken(b.ReadTimes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) Write(data string) string {
|
||||||
|
b.Data += data
|
||||||
|
b.Cursor += len(data)
|
||||||
|
b.Times++
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) WriteBytes(data []byte) []byte {
|
||||||
|
b.Data += string(data)
|
||||||
|
b.Cursor += len(data)
|
||||||
|
b.Times++
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) IsEmpty() bool {
|
||||||
|
return b.Cursor == 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) Reset() {
|
||||||
|
b.Data = ""
|
||||||
|
b.Cursor = 0
|
||||||
|
b.Times = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) Read() string {
|
||||||
|
return b.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) ReadWithDefault(_default string) string {
|
||||||
|
if b.IsEmpty() {
|
||||||
|
return _default
|
||||||
|
}
|
||||||
|
return b.Data
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Buffer) ReadTimes() int {
|
||||||
|
return b.Times
|
||||||
|
}
|
34
api/chat.go
34
api/chat.go
@ -7,7 +7,6 @@ import (
|
|||||||
"chat/types"
|
"chat/types"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-redis/redis/v8"
|
"github.com/go-redis/redis/v8"
|
||||||
"github.com/gorilla/websocket"
|
"github.com/gorilla/websocket"
|
||||||
@ -15,6 +14,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const defaultMessage = "There was something wrong... Please try again later."
|
||||||
|
|
||||||
type WebsocketAuthForm struct {
|
type WebsocketAuthForm struct {
|
||||||
Token string `json:"token" binding:"required"`
|
Token string `json:"token" binding:"required"`
|
||||||
Id int64 `json:"id" binding:"required"`
|
Id int64 `json:"id" binding:"required"`
|
||||||
@ -25,39 +26,47 @@ func SendSegmentMessage(conn *websocket.Conn, message types.ChatGPTSegmentRespon
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
|
func TextChat(db *sql.DB, user *auth.User, conn *websocket.Conn, instance *conversation.Conversation) string {
|
||||||
keyword, segment := ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
|
var keyword string
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
|
var segment []types.ChatGPTMessage
|
||||||
|
|
||||||
msg := ""
|
if instance.IsEnableWeb() {
|
||||||
|
keyword, segment = ChatWithWeb(conversation.CopyMessage(instance.GetMessageSegment(12)), true)
|
||||||
|
} else {
|
||||||
|
segment = conversation.CopyMessage(instance.GetMessageSegment(12))
|
||||||
|
}
|
||||||
|
|
||||||
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{Keyword: keyword, End: false})
|
||||||
|
|
||||||
if instance.IsEnableGPT4() && !auth.ReduceGPT4(db, user) {
|
if instance.IsEnableGPT4() && !auth.ReduceGPT4(db, user) {
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
Message: "You have run out of GPT-4 usage. Please buy more.",
|
Message: "You have run out of GPT-4 usage. Please buy more.",
|
||||||
|
Quota: 0,
|
||||||
End: true,
|
End: true,
|
||||||
})
|
})
|
||||||
return "You have run out of GPT-4 usage. Please buy more."
|
return "You have run out of GPT-4 usage. Please buy more."
|
||||||
}
|
}
|
||||||
|
|
||||||
|
buffer := NewBuffer(instance.IsEnableGPT4(), segment)
|
||||||
StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) {
|
StreamRequest(instance.IsEnableGPT4(), segment, 2000, func(resp string) {
|
||||||
msg += resp
|
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
Message: resp,
|
Message: buffer.Write(resp),
|
||||||
|
Quota: buffer.GetQuota(),
|
||||||
End: false,
|
End: false,
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
if msg == "" {
|
if buffer.IsEmpty() {
|
||||||
msg = "There was something wrong... Please try again later."
|
|
||||||
if instance.IsEnableGPT4() {
|
if instance.IsEnableGPT4() {
|
||||||
auth.IncreaseGPT4(db, user, 1)
|
auth.IncreaseGPT4(db, user, 1)
|
||||||
}
|
}
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
Message: msg,
|
Message: defaultMessage,
|
||||||
|
Quota: buffer.GetQuota(),
|
||||||
End: false,
|
End: false,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true})
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{End: true, Quota: buffer.GetQuota()})
|
||||||
|
|
||||||
return msg
|
return buffer.ReadWithDefault(defaultMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
|
func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
|
||||||
@ -84,10 +93,9 @@ func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *
|
|||||||
return err.Error()
|
return err.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
markdown := fmt.Sprintln("")
|
markdown := GetImageMarkdown(url)
|
||||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||||
Message: markdown,
|
Message: markdown,
|
||||||
Keyword: "image",
|
|
||||||
End: true,
|
End: true,
|
||||||
})
|
})
|
||||||
return markdown
|
return markdown
|
||||||
|
@ -70,3 +70,7 @@ func GetImageWithUserLimit(user *auth.User, prompt string, db *sql.DB, cache *re
|
|||||||
return GetImageWithCache(context.Background(), prompt, cache)
|
return GetImageWithCache(context.Background(), prompt, cache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetImageMarkdown(url string) string {
|
||||||
|
return fmt.Sprintln("")
|
||||||
|
}
|
||||||
|
@ -25,7 +25,7 @@ import {
|
|||||||
DropdownMenuTrigger,
|
DropdownMenuTrigger,
|
||||||
} from "./components/ui/dropdown-menu.tsx";
|
} from "./components/ui/dropdown-menu.tsx";
|
||||||
import { Toaster } from "./components/ui/toaster.tsx";
|
import { Toaster } from "./components/ui/toaster.tsx";
|
||||||
import { login } from "./conf.ts";
|
import {login, tokenField} from "./conf.ts";
|
||||||
import { useTranslation } from "react-i18next";
|
import { useTranslation } from "react-i18next";
|
||||||
|
|
||||||
function Settings() {
|
function Settings() {
|
||||||
@ -67,7 +67,7 @@ function NavBar() {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useDispatch();
|
const dispatch = useDispatch();
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
validateToken(dispatch, localStorage.getItem("token") ?? "");
|
validateToken(dispatch, localStorage.getItem(tokenField) ?? "");
|
||||||
}, []);
|
}, []);
|
||||||
const auth = useSelector(selectAuthenticated);
|
const auth = useSelector(selectAuthenticated);
|
||||||
|
|
||||||
|
@ -13,8 +13,8 @@
|
|||||||
|
|
||||||
.message {
|
.message {
|
||||||
display: flex;
|
display: flex;
|
||||||
|
gap: 6px;
|
||||||
flex-direction: column;
|
flex-direction: column;
|
||||||
width: 100%;
|
|
||||||
|
|
||||||
&:last-child {
|
&:last-child {
|
||||||
animation: FlexInAnimationFromBottom 0.2s cubic-bezier(0.175, 0.885, 0.32, 1.275) 0s 1 normal forwards running;
|
animation: FlexInAnimationFromBottom 0.2s cubic-bezier(0.175, 0.885, 0.32, 1.275) 0s 1 normal forwards running;
|
||||||
@ -29,6 +29,36 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.message-quota {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
align-items: center;
|
||||||
|
user-select: none;
|
||||||
|
gap: 4px;
|
||||||
|
cursor: pointer;
|
||||||
|
border: 1px solid hsl(var(--input));
|
||||||
|
border-radius: var(--radius);
|
||||||
|
transition: 0.2s linear;
|
||||||
|
padding: 4px 8px;
|
||||||
|
width: max-content;
|
||||||
|
height: max-content;
|
||||||
|
white-space: nowrap;
|
||||||
|
|
||||||
|
.quota {
|
||||||
|
font-size: 14px;
|
||||||
|
color: hsl(var(--text-secondary));
|
||||||
|
}
|
||||||
|
|
||||||
|
.icon {
|
||||||
|
transform: translateY(1px);
|
||||||
|
color: hsl(var(--text-secondary));
|
||||||
|
}
|
||||||
|
|
||||||
|
&:hover {
|
||||||
|
border-color: hsl(var(--border-hover));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
.message-content {
|
.message-content {
|
||||||
padding: 8px 16px;
|
padding: 8px 16px;
|
||||||
border-radius: var(--radius);
|
border-radius: var(--radius);
|
||||||
|
@ -64,3 +64,9 @@ strong {
|
|||||||
color: hsl(var(--text));
|
color: hsl(var(--text));
|
||||||
border: 1px solid hsl(var(--border));
|
border: 1px solid hsl(var(--border));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.icon-tooltip {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
align-items: center;
|
||||||
|
}
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
import { Message } from "../conversation/types.ts";
|
import { Message } from "../conversation/types.ts";
|
||||||
import Markdown from "./Markdown.tsx";
|
import Markdown from "./Markdown.tsx";
|
||||||
import {Copy, File, Loader2, MousePointerSquare} from "lucide-react";
|
import {Cloud, CloudFog, Copy, File, Loader2, MousePointerSquare} from "lucide-react";
|
||||||
import {
|
import {
|
||||||
ContextMenu,
|
ContextMenu,
|
||||||
ContextMenuContent,
|
ContextMenuContent,
|
||||||
@ -9,6 +9,7 @@ import {
|
|||||||
} from "./ui/context-menu.tsx";
|
} from "./ui/context-menu.tsx";
|
||||||
import {copyClipboard, saveAsFile, useInputValue} from "../utils.ts";
|
import {copyClipboard, saveAsFile, useInputValue} from "../utils.ts";
|
||||||
import {useTranslation} from "react-i18next";
|
import {useTranslation} from "react-i18next";
|
||||||
|
import {Tooltip, TooltipContent, TooltipProvider, TooltipTrigger} from "./ui/tooltip.tsx";
|
||||||
|
|
||||||
type MessageProps = {
|
type MessageProps = {
|
||||||
message: Message;
|
message: Message;
|
||||||
@ -22,6 +23,24 @@ function MessageSegment({ message }: MessageProps) {
|
|||||||
<ContextMenuTrigger asChild>
|
<ContextMenuTrigger asChild>
|
||||||
<div className={`message ${message.role}`}>
|
<div className={`message ${message.role}`}>
|
||||||
<MessageContent message={message} />
|
<MessageContent message={message} />
|
||||||
|
{
|
||||||
|
(message.quota && message.quota > 0) ?
|
||||||
|
<TooltipProvider>
|
||||||
|
<Tooltip>
|
||||||
|
<TooltipTrigger asChild>
|
||||||
|
<div className={`message-quota`}>
|
||||||
|
<Cloud className={`h-4 w-4 icon`} />
|
||||||
|
<span className={`quota`}>{message.quota.toFixed(2)}</span>
|
||||||
|
</div>
|
||||||
|
</TooltipTrigger>
|
||||||
|
<TooltipContent className={`icon-tooltip`}>
|
||||||
|
<CloudFog className={`h-4 w-4 mr-2`} />
|
||||||
|
<p>{ t('quota-description') }</p>
|
||||||
|
</TooltipContent>
|
||||||
|
</Tooltip>
|
||||||
|
</TooltipProvider>
|
||||||
|
: null
|
||||||
|
}
|
||||||
</div>
|
</div>
|
||||||
</ContextMenuTrigger>
|
</ContextMenuTrigger>
|
||||||
<ContextMenuContent>
|
<ContextMenuContent>
|
||||||
@ -41,6 +60,7 @@ function MessageSegment({ message }: MessageProps) {
|
|||||||
|
|
||||||
function MessageContent({ message }: MessageProps) {
|
function MessageContent({ message }: MessageProps) {
|
||||||
return (
|
return (
|
||||||
|
<>
|
||||||
<div className={`message-content`}>
|
<div className={`message-content`}>
|
||||||
{message.keyword && message.keyword.length ? (
|
{message.keyword && message.keyword.length ? (
|
||||||
<div className={`bing`}>
|
<div className={`bing`}>
|
||||||
@ -114,6 +134,7 @@ function MessageContent({ message }: MessageProps) {
|
|||||||
<Loader2 className={`h-5 w-5 m-1 animate-spin`} />
|
<Loader2 className={`h-5 w-5 m-1 animate-spin`} />
|
||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
|
</>
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -5,10 +5,12 @@ export let rest_api: string = "http://localhost:8094";
|
|||||||
export let ws_api: string = "ws://localhost:8094";
|
export let ws_api: string = "ws://localhost:8094";
|
||||||
|
|
||||||
if (deploy) {
|
if (deploy) {
|
||||||
rest_api = "https://nioapi.fystart.cn";
|
rest_api = "https://api.chatnio.net";
|
||||||
ws_api = "wss://nioapi.fystart.cn";
|
ws_api = "wss://api.chatnio.net";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const tokenField = deploy ? "token" : "token-dev";
|
||||||
|
|
||||||
export function login() {
|
export function login() {
|
||||||
location.href = "https://deeptrain.lightxi.com/login?app=chatnio";
|
location.href = "https://deeptrain.lightxi.com/login?app=chatnio";
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,10 @@
|
|||||||
import { ws_api } from "../conf.ts";
|
import {tokenField, ws_api} from "../conf.ts";
|
||||||
|
|
||||||
export const endpoint = `${ws_api}/chat`;
|
export const endpoint = `${ws_api}/chat`;
|
||||||
|
|
||||||
export type StreamMessage = {
|
export type StreamMessage = {
|
||||||
keyword?: string;
|
keyword?: string;
|
||||||
|
quota?: number;
|
||||||
message: string;
|
message: string;
|
||||||
end: boolean;
|
end: boolean;
|
||||||
};
|
};
|
||||||
@ -35,7 +36,7 @@ export class Connection {
|
|||||||
this.connection.onopen = () => {
|
this.connection.onopen = () => {
|
||||||
this.state = true;
|
this.state = true;
|
||||||
this.send({
|
this.send({
|
||||||
token: localStorage.getItem("token") || "",
|
token: localStorage.getItem(tokenField) || "",
|
||||||
id: this.id,
|
id: this.id,
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
@ -79,9 +79,10 @@ export class Conversation {
|
|||||||
this.triggerCallback();
|
this.triggerCallback();
|
||||||
}
|
}
|
||||||
|
|
||||||
public updateMessage(idx: number, message: string, keyword?: string) {
|
public updateMessage(idx: number, message: string, keyword?: string, quota?: number) {
|
||||||
this.data[idx].content += message;
|
this.data[idx].content += message;
|
||||||
if (keyword) this.data[idx].keyword = keyword;
|
if (keyword) this.data[idx].keyword = keyword;
|
||||||
|
if (quota) this.data[idx].quota = quota;
|
||||||
this.triggerCallback();
|
this.triggerCallback();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,7 +93,7 @@ export class Conversation {
|
|||||||
});
|
});
|
||||||
|
|
||||||
return (message: StreamMessage) => {
|
return (message: StreamMessage) => {
|
||||||
this.updateMessage(cursor, message.message, message.keyword);
|
this.updateMessage(cursor, message.message, message.keyword, message.quota);
|
||||||
if (message.end) {
|
if (message.end) {
|
||||||
this.end = true;
|
this.end = true;
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ import { Conversation } from "./conversation.ts";
|
|||||||
export type Message = {
|
export type Message = {
|
||||||
content: string;
|
content: string;
|
||||||
keyword?: string;
|
keyword?: string;
|
||||||
|
quota?: number;
|
||||||
role: string;
|
role: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -51,7 +51,8 @@ const resources = {
|
|||||||
"copy": "Copy",
|
"copy": "Copy",
|
||||||
"save": "Save as File",
|
"save": "Save as File",
|
||||||
"use": "Use Message",
|
"use": "Use Message",
|
||||||
}
|
},
|
||||||
|
"quota-description": "spending quota for the message",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
cn: {
|
cn: {
|
||||||
@ -94,7 +95,8 @@ const resources = {
|
|||||||
"copy": "复制",
|
"copy": "复制",
|
||||||
"save": "保存为文件",
|
"save": "保存为文件",
|
||||||
"use": "使用消息",
|
"use": "使用消息",
|
||||||
}
|
},
|
||||||
|
"quota-description": "消息的配额支出",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import { useToast } from "../components/ui/use-toast.ts";
|
import { useToast } from "../components/ui/use-toast.ts";
|
||||||
import { useLocation } from "react-router-dom";
|
import { useLocation } from "react-router-dom";
|
||||||
import { ToastAction } from "../components/ui/toast.tsx";
|
import { ToastAction } from "../components/ui/toast.tsx";
|
||||||
import { login } from "../conf.ts";
|
import {login, tokenField} from "../conf.ts";
|
||||||
import { useEffect } from "react";
|
import { useEffect } from "react";
|
||||||
import Loader from "../components/Loader.tsx";
|
import Loader from "../components/Loader.tsx";
|
||||||
import "../assets/auth.less";
|
import "../assets/auth.less";
|
||||||
@ -16,7 +16,7 @@ function Auth() {
|
|||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
const dispatch = useDispatch();
|
const dispatch = useDispatch();
|
||||||
const search = new URLSearchParams(useLocation().search);
|
const search = new URLSearchParams(useLocation().search);
|
||||||
const token = (search.get("token") || "").trim();
|
const token = (search.get(tokenField) || "").trim();
|
||||||
|
|
||||||
if (!token.length) {
|
if (!token.length) {
|
||||||
toast({
|
toast({
|
||||||
|
@ -137,7 +137,8 @@ function SideBar() {
|
|||||||
{t("conversation.cancel")}
|
{t("conversation.cancel")}
|
||||||
</AlertDialogCancel>
|
</AlertDialogCancel>
|
||||||
<AlertDialogAction
|
<AlertDialogAction
|
||||||
onClick={async () => {
|
onClick={async (e) => {
|
||||||
|
e.preventDefault();
|
||||||
if (
|
if (
|
||||||
await deleteConversation(dispatch, conversation.id)
|
await deleteConversation(dispatch, conversation.id)
|
||||||
)
|
)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import { createSlice } from "@reduxjs/toolkit";
|
import { createSlice } from "@reduxjs/toolkit";
|
||||||
import axios from "axios";
|
import axios from "axios";
|
||||||
|
import {tokenField} from "../conf.ts";
|
||||||
|
|
||||||
export const authSlice = createSlice({
|
export const authSlice = createSlice({
|
||||||
name: "auth",
|
name: "auth",
|
||||||
@ -12,7 +13,7 @@ export const authSlice = createSlice({
|
|||||||
setToken: (state, action) => {
|
setToken: (state, action) => {
|
||||||
state.token = action.payload as string;
|
state.token = action.payload as string;
|
||||||
axios.defaults.headers.common["Authorization"] = state.token;
|
axios.defaults.headers.common["Authorization"] = state.token;
|
||||||
localStorage.setItem("token", state.token);
|
localStorage.setItem(tokenField, state.token);
|
||||||
},
|
},
|
||||||
setAuthenticated: (state, action) => {
|
setAuthenticated: (state, action) => {
|
||||||
state.authenticated = action.payload as boolean;
|
state.authenticated = action.payload as boolean;
|
||||||
@ -25,7 +26,7 @@ export const authSlice = createSlice({
|
|||||||
state.authenticated = false;
|
state.authenticated = false;
|
||||||
state.username = "";
|
state.username = "";
|
||||||
axios.defaults.headers.common["Authorization"] = "";
|
axios.defaults.headers.common["Authorization"] = "";
|
||||||
localStorage.removeItem("token");
|
localStorage.removeItem(tokenField);
|
||||||
|
|
||||||
location.reload();
|
location.reload();
|
||||||
},
|
},
|
||||||
|
@ -4,6 +4,29 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Price Calculation
|
||||||
|
// 10 nio points = ¥1
|
||||||
|
// from 2023-9-6, 1 USD = 7.3124 CNY
|
||||||
|
//
|
||||||
|
// GPT-4 price (8k-context)
|
||||||
|
// Input Output
|
||||||
|
// $0.03 / 1K tokens $0.06 / 1K tokens
|
||||||
|
// ¥0.21 / 1K tokens ¥0.43 / 1K tokens
|
||||||
|
// 2.1 nio / 1K tokens 4.3 nio / 1K tokens
|
||||||
|
|
||||||
|
// Dalle price (512x512)
|
||||||
|
// $0.018 / per image
|
||||||
|
// ¥0.13 / per image
|
||||||
|
// 1 nio / per image
|
||||||
|
|
||||||
|
func CountInputToken(n int) float32 {
|
||||||
|
return float32(n) / 1000 * 2.1
|
||||||
|
}
|
||||||
|
|
||||||
|
func CountOutputToken(n int) float32 {
|
||||||
|
return float32(n) / 1000 * 4.3
|
||||||
|
}
|
||||||
|
|
||||||
func ReduceUsage(db *sql.DB, user *User, _t string) bool {
|
func ReduceUsage(db *sql.DB, user *User, _t string) bool {
|
||||||
id := user.GetID(db)
|
id := user.GetID(db)
|
||||||
var count int
|
var count int
|
||||||
@ -82,7 +105,7 @@ func BuyDalle(db *sql.DB, user *User, value int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountGPT4Prize(value int) float32 {
|
func CountGPT4price(value int) float32 {
|
||||||
if value <= 20 {
|
if value <= 20 {
|
||||||
return float32(value) * 0.5
|
return float32(value) * 0.5
|
||||||
}
|
}
|
||||||
@ -91,7 +114,7 @@ func CountGPT4Prize(value int) float32 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func BuyGPT4(db *sql.DB, user *User, value int) bool {
|
func BuyGPT4(db *sql.DB, user *User, value int) bool {
|
||||||
if !Pay(user.Username, CountGPT4Prize(value)) {
|
if !Pay(user.Username, CountGPT4price(value)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -14,11 +14,13 @@ type Conversation struct {
|
|||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Message []types.ChatGPTMessage `json:"message"`
|
Message []types.ChatGPTMessage `json:"message"`
|
||||||
EnableGPT4 bool `json:"enable_gpt4"`
|
EnableGPT4 bool `json:"enable_gpt4"`
|
||||||
|
EnableWeb bool `json:"enable_web"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type FormMessage struct {
|
type FormMessage struct {
|
||||||
Type string `json:"type"` // ping
|
Type string `json:"type"` // ping
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
|
Web bool `json:"web"`
|
||||||
GPT4 bool `json:"gpt4"`
|
GPT4 bool `json:"gpt4"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -29,6 +31,7 @@ func NewConversation(db *sql.DB, id int64) *Conversation {
|
|||||||
Name: "new chat",
|
Name: "new chat",
|
||||||
Message: []types.ChatGPTMessage{},
|
Message: []types.ChatGPTMessage{},
|
||||||
EnableGPT4: false,
|
EnableGPT4: false,
|
||||||
|
EnableWeb: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -36,10 +39,18 @@ func (c *Conversation) IsEnableGPT4() bool {
|
|||||||
return c.EnableGPT4
|
return c.EnableGPT4
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) IsEnableWeb() bool {
|
||||||
|
return c.EnableWeb
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conversation) SetEnableGPT4(enable bool) {
|
func (c *Conversation) SetEnableGPT4(enable bool) {
|
||||||
c.EnableGPT4 = enable
|
c.EnableGPT4 = enable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Conversation) SetEnableWeb(enable bool) {
|
||||||
|
c.EnableWeb = enable
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Conversation) GetName() string {
|
func (c *Conversation) GetName() string {
|
||||||
return c.Name
|
return c.Name
|
||||||
}
|
}
|
||||||
@ -131,6 +142,7 @@ func (c *Conversation) AddMessageFromUserForm(data []byte) (string, error) {
|
|||||||
|
|
||||||
c.AddMessageFromUser(form.Message)
|
c.AddMessageFromUser(form.Message)
|
||||||
c.SetEnableGPT4(form.GPT4)
|
c.SetEnableGPT4(form.GPT4)
|
||||||
|
c.SetEnableWeb(form.Web)
|
||||||
return form.Message, nil
|
return form.Message, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
3
go.mod
3
go.mod
@ -15,6 +15,7 @@ require (
|
|||||||
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
github.com/cespare/xxhash/v2 v2.1.2 // indirect
|
||||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||||
|
github.com/dlclark/regexp2 v1.10.0 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
github.com/fsnotify/fsnotify v1.6.0 // indirect
|
||||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
@ -22,6 +23,7 @@ require (
|
|||||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||||
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
github.com/go-playground/validator/v10 v10.14.0 // indirect
|
||||||
github.com/goccy/go-json v0.10.2 // indirect
|
github.com/goccy/go-json v0.10.2 // indirect
|
||||||
|
github.com/google/uuid v1.3.0 // indirect
|
||||||
github.com/gorilla/websocket v1.5.0 // indirect
|
github.com/gorilla/websocket v1.5.0 // indirect
|
||||||
github.com/hashicorp/hcl v1.0.0 // indirect
|
github.com/hashicorp/hcl v1.0.0 // indirect
|
||||||
github.com/json-iterator/go v1.1.12 // indirect
|
github.com/json-iterator/go v1.1.12 // indirect
|
||||||
@ -33,6 +35,7 @@ require (
|
|||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.5 // indirect
|
||||||
github.com/spf13/afero v1.9.5 // indirect
|
github.com/spf13/afero v1.9.5 // indirect
|
||||||
github.com/spf13/cast v1.5.1 // indirect
|
github.com/spf13/cast v1.5.1 // indirect
|
||||||
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
github.com/spf13/jwalterweatherman v1.1.0 // indirect
|
||||||
|
6
go.sum
6
go.sum
@ -61,6 +61,8 @@ github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumC
|
|||||||
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
|
||||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
|
||||||
|
github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0=
|
||||||
|
github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||||
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98=
|
||||||
@ -147,6 +149,8 @@ github.com/google/pprof v0.0.0-20201203190320-1bf35d6f28c2/go.mod h1:kpwsk12EmLe
|
|||||||
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
github.com/google/pprof v0.0.0-20201218002935-b9804c9f04c2/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE=
|
||||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
|
||||||
|
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
|
||||||
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
|
||||||
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
github.com/googleapis/google-cloud-go-testing v0.0.0-20200911160855-bcd43fbb19e8/go.mod h1:dvDLG8qkwmyD9a/MJJN3XJcT3xFxOKAvTZGvuZmac9g=
|
||||||
@ -192,6 +196,8 @@ github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZ
|
|||||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||||
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.5 h1:hAlT4dCf6Uk50x8E7HQrddhH3EWMKUN+LArExQQsQx4=
|
||||||
|
github.com/pkoukk/tiktoken-go v0.1.5/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
|
||||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||||
|
@ -34,6 +34,7 @@ type ChatGPTStreamResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type ChatGPTSegmentResponse struct {
|
type ChatGPTSegmentResponse struct {
|
||||||
|
Quota float32 `json:"quota"`
|
||||||
Keyword string `json:"keyword"`
|
Keyword string `json:"keyword"`
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
End bool `json:"end"`
|
End bool `json:"end"`
|
||||||
|
@ -37,6 +37,14 @@ func Unmarshal[T interface{}](data []byte) (form T, err error) {
|
|||||||
return form, err
|
return form, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Marshal[T interface{}](data T) string {
|
||||||
|
res, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(res)
|
||||||
|
}
|
||||||
|
|
||||||
func ToInt(value string) int {
|
func ToInt(value string) int {
|
||||||
if res, err := strconv.Atoi(value); err == nil {
|
if res, err := strconv.Atoi(value); err == nil {
|
||||||
return res
|
return res
|
||||||
|
60
utils/tokenizer.go
Normal file
60
utils/tokenizer.go
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/types"
|
||||||
|
"fmt"
|
||||||
|
"github.com/pkoukk/tiktoken-go"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Using https://github.com/pkoukk/tiktoken-go
|
||||||
|
// To count number of tokens of openai chat messages
|
||||||
|
// OpenAI Cookbook: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||||
|
|
||||||
|
func GetWeightByModel(model string) int {
|
||||||
|
switch model {
|
||||||
|
case "gpt-3.5-turbo-0613",
|
||||||
|
"gpt-3.5-turbo-16k-0613",
|
||||||
|
"gpt-4-0314",
|
||||||
|
"gpt-4-32k-0314",
|
||||||
|
"gpt-4-0613",
|
||||||
|
"gpt-4-32k-0613":
|
||||||
|
return 3
|
||||||
|
case "gpt-3.5-turbo-0301":
|
||||||
|
return 4 // every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||||
|
default:
|
||||||
|
if strings.Contains(model, "gpt-3.5-turbo") {
|
||||||
|
// warning: gpt-3.5-turbo may update over time. Returning num tokens assuming gpt-3.5-turbo-0613.
|
||||||
|
return GetWeightByModel("gpt-3.5-turbo-0613")
|
||||||
|
} else if strings.Contains(model, "gpt-4") {
|
||||||
|
// warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.
|
||||||
|
return GetWeightByModel("gpt-4-0613")
|
||||||
|
} else {
|
||||||
|
// not implemented: See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens
|
||||||
|
panic(fmt.Errorf("not implemented for model %s", model))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
func NumTokensFromMessages(messages []types.ChatGPTMessage, model string) (tokens int) {
|
||||||
|
weight := GetWeightByModel(model)
|
||||||
|
tkm, err := tiktoken.EncodingForModel(model)
|
||||||
|
if err != nil {
|
||||||
|
// can not encode messages, use length of messages as a proxy for number of tokens
|
||||||
|
// using rune instead of byte to account for unicode characters (e.g. emojis, non-english characters)
|
||||||
|
|
||||||
|
data := Marshal(messages)
|
||||||
|
return len([]rune(data)) * weight
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, message := range messages {
|
||||||
|
tokens += weight
|
||||||
|
tokens += len(tkm.Encode(message.Content, nil, nil))
|
||||||
|
tokens += len(tkm.Encode(message.Role, nil, nil))
|
||||||
|
}
|
||||||
|
tokens += 3 // every reply is primed with <|start|>assistant<|message|>
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func CountTokenPrice(messages []types.ChatGPTMessage) int {
|
||||||
|
return NumTokensFromMessages(messages, "gpt-4")
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user