diff --git a/api/chat.go b/api/chat.go index 8b91056..907e28b 100644 --- a/api/chat.go +++ b/api/chat.go @@ -7,6 +7,7 @@ import ( "chat/types" "chat/utils" "database/sql" + "fmt" "github.com/gin-gonic/gin" "github.com/go-redis/redis/v8" "github.com/gorilla/websocket" @@ -17,6 +18,7 @@ import ( const defaultErrorMessage = "There was something wrong... Please try again later." const defaultQuotaMessage = "You have run out of GPT-4 usage. Please keep your nio points above **5**." const defaultImageMessage = "Please provide description for the image (e.g. /image an apple)." +const maxThread = 3 type WebsocketAuthForm struct { Token string `json:"token" binding:"required"` @@ -105,6 +107,14 @@ func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user * return markdown } +func ChatHandler(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string { + if strings.HasPrefix(instance.GetLatestMessage(), "/image") { + return ImageChat(conn, instance, user, db, cache) + } else { + return TextChat(db, user, conn, instance) + } +} + func ChatAPI(c *gin.Context) { // websocket connection upgrader := websocket.Upgrader{ @@ -147,6 +157,7 @@ func ChatAPI(c *gin.Context) { } db := c.MustGet("db").(*sql.DB) + cache := c.MustGet("cache").(*redis.Client) var instance *conversation.Conversation if form.Id == -1 { // create new conversation @@ -159,19 +170,23 @@ func ChatAPI(c *gin.Context) { } } + id := user.GetID(db) + for { _, message, err = conn.ReadMessage() if err != nil { return } if instance.HandleMessage(db, message) { - var msg string - if strings.HasPrefix(instance.GetLatestMessage(), "/image") { - cache := c.MustGet("cache").(*redis.Client) - msg = ImageChat(conn, instance, user, db, cache) - } else { - msg = TextChat(db, user, conn, instance) + if !utils.IncrWithLimit(cache, fmt.Sprintf(":chatthread:%d", id), 1, maxThread, 60) { + SendSegmentMessage(conn, types.ChatGPTSegmentResponse{ + Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", maxThread), + End: true, + }) + return } + msg := ChatHandler(conn, instance, user, db, cache) + utils.DecrInt(cache, fmt.Sprintf(":chatthread:%d", id), 1) instance.SaveResponse(db, msg) } } diff --git a/app/.gitignore b/app/.gitignore index a547bf3..6d6ae5a 100644 --- a/app/.gitignore +++ b/app/.gitignore @@ -10,6 +10,7 @@ lerna-debug.log* node_modules dist dist-ssr +dev-dist *.local # Editor directories and files diff --git a/app/src/routes/Home.tsx b/app/src/routes/Home.tsx index 51166c1..2a151b6 100644 --- a/app/src/routes/Home.tsx +++ b/app/src/routes/Home.tsx @@ -79,7 +79,10 @@ function SideBar() { @@ -112,8 +115,8 @@ function SideBar() { current === conversation.id ? "active" : "" }`} key={i} - onClick={() => { - toggleConversation(dispatch, conversation.id); + onClick={async () => { + await toggleConversation(dispatch, conversation.id); if (mobile) dispatch(setMenu(false)); }} > diff --git a/app/src/store/package.ts b/app/src/store/package.ts index 51ff59f..6ce9713 100644 --- a/app/src/store/package.ts +++ b/app/src/store/package.ts @@ -52,6 +52,6 @@ const refreshPackage = async (dispatch: any) => { }; export const refreshPackageTask = (dispatch: any) => { - setInterval(() => refreshPackage(dispatch), 5000); + setInterval(() => refreshPackage(dispatch), 20000); refreshPackage(dispatch).then(); }; diff --git a/app/vite.config.ts b/app/vite.config.ts index 5ad69f7..7a104da 100644 --- a/app/vite.config.ts +++ b/app/vite.config.ts @@ -2,7 +2,6 @@ import { defineConfig } from 'vite' import react from '@vitejs/plugin-react-swc' import path from "path" import { createHtmlPlugin } from 'vite-plugin-html' -import htmlMinifierTerser from 'vite-plugin-html-minifier-terser' import {VitePWA} from "vite-plugin-pwa"; // https://vitejs.dev/config/ @@ -48,6 +47,16 @@ export default defineConfig({ } } }, + { + urlPattern: new RegExp('^https://fonts.gstatic.googlefonts.cn/(.*)'), + handler: 'CacheFirst', + options: { + cacheName: 'google-fonts-gstatic', + expiration: { + maxEntries: 3600, + } + } + }, { urlPattern: new RegExp('https://cdn.zmh-program.site/(.*)'), handler: 'CacheFirst', diff --git a/utils/cache.go b/utils/cache.go index d4b585b..e342fed 100644 --- a/utils/cache.go +++ b/utils/cache.go @@ -1 +1,48 @@ package utils + +import ( + "context" + "github.com/go-redis/redis/v8" + "time" +) + +func Incr(cache *redis.Client, key string, delta int64) (int64, error) { + return cache.IncrBy(context.Background(), key, delta).Result() +} + +func Decr(cache *redis.Client, key string, delta int64) (int64, error) { + return cache.DecrBy(context.Background(), key, delta).Result() +} + +func GetInt(cache *redis.Client, key string) (int64, error) { + return cache.Get(context.Background(), key).Int64() +} + +func SetInt(cache *redis.Client, key string, value int64, expiration int64) error { + return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err() +} + +func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool { + // not exist + if _, err := cache.Get(context.Background(), key).Result(); err != nil { + if err == redis.Nil { + cache.Set(context.Background(), key, delta, time.Duration(expiration)*time.Second) + return true + } + return false + } + res, err := Incr(cache, key, delta) + if err != nil { + return false + } + if res > limit { + cache.Set(context.Background(), key, limit, time.Duration(expiration)*time.Second) + return false + } + return true +} + +func DecrInt(cache *redis.Client, key string, delta int64) bool { + _, err := Decr(cache, key, delta) + return err == nil +}