add thread limit

This commit is contained in:
Zhang Minghan 2023-09-08 12:08:57 +08:00
parent 6b37bc13b9
commit e34df4146b
6 changed files with 86 additions and 11 deletions

View File

@ -7,6 +7,7 @@ import (
"chat/types" "chat/types"
"chat/utils" "chat/utils"
"database/sql" "database/sql"
"fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-redis/redis/v8" "github.com/go-redis/redis/v8"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
@ -17,6 +18,7 @@ import (
const defaultErrorMessage = "There was something wrong... Please try again later." const defaultErrorMessage = "There was something wrong... Please try again later."
const defaultQuotaMessage = "You have run out of GPT-4 usage. Please keep your nio points above **5**." const defaultQuotaMessage = "You have run out of GPT-4 usage. Please keep your nio points above **5**."
const defaultImageMessage = "Please provide description for the image (e.g. /image an apple)." const defaultImageMessage = "Please provide description for the image (e.g. /image an apple)."
const maxThread = 3
type WebsocketAuthForm struct { type WebsocketAuthForm struct {
Token string `json:"token" binding:"required"` Token string `json:"token" binding:"required"`
@ -105,6 +107,14 @@ func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *
return markdown return markdown
} }
func ChatHandler(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
return ImageChat(conn, instance, user, db, cache)
} else {
return TextChat(db, user, conn, instance)
}
}
func ChatAPI(c *gin.Context) { func ChatAPI(c *gin.Context) {
// websocket connection // websocket connection
upgrader := websocket.Upgrader{ upgrader := websocket.Upgrader{
@ -147,6 +157,7 @@ func ChatAPI(c *gin.Context) {
} }
db := c.MustGet("db").(*sql.DB) db := c.MustGet("db").(*sql.DB)
cache := c.MustGet("cache").(*redis.Client)
var instance *conversation.Conversation var instance *conversation.Conversation
if form.Id == -1 { if form.Id == -1 {
// create new conversation // create new conversation
@ -159,19 +170,23 @@ func ChatAPI(c *gin.Context) {
} }
} }
id := user.GetID(db)
for { for {
_, message, err = conn.ReadMessage() _, message, err = conn.ReadMessage()
if err != nil { if err != nil {
return return
} }
if instance.HandleMessage(db, message) { if instance.HandleMessage(db, message) {
var msg string if !utils.IncrWithLimit(cache, fmt.Sprintf(":chatthread:%d", id), 1, maxThread, 60) {
if strings.HasPrefix(instance.GetLatestMessage(), "/image") { SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
cache := c.MustGet("cache").(*redis.Client) Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", maxThread),
msg = ImageChat(conn, instance, user, db, cache) End: true,
} else { })
msg = TextChat(db, user, conn, instance) return
} }
msg := ChatHandler(conn, instance, user, db, cache)
utils.DecrInt(cache, fmt.Sprintf(":chatthread:%d", id), 1)
instance.SaveResponse(db, msg) instance.SaveResponse(db, msg)
} }
} }

1
app/.gitignore vendored
View File

@ -10,6 +10,7 @@ lerna-debug.log*
node_modules node_modules
dist dist
dist-ssr dist-ssr
dev-dist
*.local *.local
# Editor directories and files # Editor directories and files

View File

@ -79,7 +79,10 @@ function SideBar() {
<Button <Button
variant={`ghost`} variant={`ghost`}
size={`icon`} size={`icon`}
onClick={() => toggleConversation(dispatch, -1)} onClick={async () => {
await toggleConversation(dispatch, -1);
if (mobile) dispatch(setMenu(false));
}}
> >
<Plus className={`h-4 w-4`} /> <Plus className={`h-4 w-4`} />
</Button> </Button>
@ -112,8 +115,8 @@ function SideBar() {
current === conversation.id ? "active" : "" current === conversation.id ? "active" : ""
}`} }`}
key={i} key={i}
onClick={() => { onClick={async () => {
toggleConversation(dispatch, conversation.id); await toggleConversation(dispatch, conversation.id);
if (mobile) dispatch(setMenu(false)); if (mobile) dispatch(setMenu(false));
}} }}
> >

View File

@ -52,6 +52,6 @@ const refreshPackage = async (dispatch: any) => {
}; };
export const refreshPackageTask = (dispatch: any) => { export const refreshPackageTask = (dispatch: any) => {
setInterval(() => refreshPackage(dispatch), 5000); setInterval(() => refreshPackage(dispatch), 20000);
refreshPackage(dispatch).then(); refreshPackage(dispatch).then();
}; };

View File

@ -2,7 +2,6 @@ import { defineConfig } from 'vite'
import react from '@vitejs/plugin-react-swc' import react from '@vitejs/plugin-react-swc'
import path from "path" import path from "path"
import { createHtmlPlugin } from 'vite-plugin-html' import { createHtmlPlugin } from 'vite-plugin-html'
import htmlMinifierTerser from 'vite-plugin-html-minifier-terser'
import {VitePWA} from "vite-plugin-pwa"; import {VitePWA} from "vite-plugin-pwa";
// https://vitejs.dev/config/ // https://vitejs.dev/config/
@ -48,6 +47,16 @@ export default defineConfig({
} }
} }
}, },
{
urlPattern: new RegExp('^https://fonts.gstatic.googlefonts.cn/(.*)'),
handler: 'CacheFirst',
options: {
cacheName: 'google-fonts-gstatic',
expiration: {
maxEntries: 3600,
}
}
},
{ {
urlPattern: new RegExp('https://cdn.zmh-program.site/(.*)'), urlPattern: new RegExp('https://cdn.zmh-program.site/(.*)'),
handler: 'CacheFirst', handler: 'CacheFirst',

View File

@ -1 +1,48 @@
package utils package utils
import (
"context"
"github.com/go-redis/redis/v8"
"time"
)
func Incr(cache *redis.Client, key string, delta int64) (int64, error) {
return cache.IncrBy(context.Background(), key, delta).Result()
}
func Decr(cache *redis.Client, key string, delta int64) (int64, error) {
return cache.DecrBy(context.Background(), key, delta).Result()
}
func GetInt(cache *redis.Client, key string) (int64, error) {
return cache.Get(context.Background(), key).Int64()
}
func SetInt(cache *redis.Client, key string, value int64, expiration int64) error {
return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err()
}
func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool {
// not exist
if _, err := cache.Get(context.Background(), key).Result(); err != nil {
if err == redis.Nil {
cache.Set(context.Background(), key, delta, time.Duration(expiration)*time.Second)
return true
}
return false
}
res, err := Incr(cache, key, delta)
if err != nil {
return false
}
if res > limit {
cache.Set(context.Background(), key, limit, time.Duration(expiration)*time.Second)
return false
}
return true
}
func DecrInt(cache *redis.Client, key string, delta int64) bool {
_, err := Decr(cache, key, delta)
return err == nil
}