mirror of
https://github.com/coaidev/coai.git
synced 2025-05-28 17:30:15 +09:00
add thread limit
This commit is contained in:
parent
6b37bc13b9
commit
e34df4146b
27
api/chat.go
27
api/chat.go
@ -7,6 +7,7 @@ import (
|
||||
"chat/types"
|
||||
"chat/utils"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/gorilla/websocket"
|
||||
@ -17,6 +18,7 @@ import (
|
||||
const defaultErrorMessage = "There was something wrong... Please try again later."
|
||||
const defaultQuotaMessage = "You have run out of GPT-4 usage. Please keep your nio points above **5**."
|
||||
const defaultImageMessage = "Please provide description for the image (e.g. /image an apple)."
|
||||
const maxThread = 3
|
||||
|
||||
type WebsocketAuthForm struct {
|
||||
Token string `json:"token" binding:"required"`
|
||||
@ -105,6 +107,14 @@ func ImageChat(conn *websocket.Conn, instance *conversation.Conversation, user *
|
||||
return markdown
|
||||
}
|
||||
|
||||
func ChatHandler(conn *websocket.Conn, instance *conversation.Conversation, user *auth.User, db *sql.DB, cache *redis.Client) string {
|
||||
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
|
||||
return ImageChat(conn, instance, user, db, cache)
|
||||
} else {
|
||||
return TextChat(db, user, conn, instance)
|
||||
}
|
||||
}
|
||||
|
||||
func ChatAPI(c *gin.Context) {
|
||||
// websocket connection
|
||||
upgrader := websocket.Upgrader{
|
||||
@ -147,6 +157,7 @@ func ChatAPI(c *gin.Context) {
|
||||
}
|
||||
|
||||
db := c.MustGet("db").(*sql.DB)
|
||||
cache := c.MustGet("cache").(*redis.Client)
|
||||
var instance *conversation.Conversation
|
||||
if form.Id == -1 {
|
||||
// create new conversation
|
||||
@ -159,19 +170,23 @@ func ChatAPI(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
id := user.GetID(db)
|
||||
|
||||
for {
|
||||
_, message, err = conn.ReadMessage()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if instance.HandleMessage(db, message) {
|
||||
var msg string
|
||||
if strings.HasPrefix(instance.GetLatestMessage(), "/image") {
|
||||
cache := c.MustGet("cache").(*redis.Client)
|
||||
msg = ImageChat(conn, instance, user, db, cache)
|
||||
} else {
|
||||
msg = TextChat(db, user, conn, instance)
|
||||
if !utils.IncrWithLimit(cache, fmt.Sprintf(":chatthread:%d", id), 1, maxThread, 60) {
|
||||
SendSegmentMessage(conn, types.ChatGPTSegmentResponse{
|
||||
Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", maxThread),
|
||||
End: true,
|
||||
})
|
||||
return
|
||||
}
|
||||
msg := ChatHandler(conn, instance, user, db, cache)
|
||||
utils.DecrInt(cache, fmt.Sprintf(":chatthread:%d", id), 1)
|
||||
instance.SaveResponse(db, msg)
|
||||
}
|
||||
}
|
||||
|
1
app/.gitignore
vendored
1
app/.gitignore
vendored
@ -10,6 +10,7 @@ lerna-debug.log*
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
dev-dist
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
|
@ -79,7 +79,10 @@ function SideBar() {
|
||||
<Button
|
||||
variant={`ghost`}
|
||||
size={`icon`}
|
||||
onClick={() => toggleConversation(dispatch, -1)}
|
||||
onClick={async () => {
|
||||
await toggleConversation(dispatch, -1);
|
||||
if (mobile) dispatch(setMenu(false));
|
||||
}}
|
||||
>
|
||||
<Plus className={`h-4 w-4`} />
|
||||
</Button>
|
||||
@ -112,8 +115,8 @@ function SideBar() {
|
||||
current === conversation.id ? "active" : ""
|
||||
}`}
|
||||
key={i}
|
||||
onClick={() => {
|
||||
toggleConversation(dispatch, conversation.id);
|
||||
onClick={async () => {
|
||||
await toggleConversation(dispatch, conversation.id);
|
||||
if (mobile) dispatch(setMenu(false));
|
||||
}}
|
||||
>
|
||||
|
@ -52,6 +52,6 @@ const refreshPackage = async (dispatch: any) => {
|
||||
};
|
||||
|
||||
export const refreshPackageTask = (dispatch: any) => {
|
||||
setInterval(() => refreshPackage(dispatch), 5000);
|
||||
setInterval(() => refreshPackage(dispatch), 20000);
|
||||
refreshPackage(dispatch).then();
|
||||
};
|
||||
|
@ -2,7 +2,6 @@ import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react-swc'
|
||||
import path from "path"
|
||||
import { createHtmlPlugin } from 'vite-plugin-html'
|
||||
import htmlMinifierTerser from 'vite-plugin-html-minifier-terser'
|
||||
import {VitePWA} from "vite-plugin-pwa";
|
||||
|
||||
// https://vitejs.dev/config/
|
||||
@ -48,6 +47,16 @@ export default defineConfig({
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
urlPattern: new RegExp('^https://fonts.gstatic.googlefonts.cn/(.*)'),
|
||||
handler: 'CacheFirst',
|
||||
options: {
|
||||
cacheName: 'google-fonts-gstatic',
|
||||
expiration: {
|
||||
maxEntries: 3600,
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
urlPattern: new RegExp('https://cdn.zmh-program.site/(.*)'),
|
||||
handler: 'CacheFirst',
|
||||
|
@ -1 +1,48 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Incr(cache *redis.Client, key string, delta int64) (int64, error) {
|
||||
return cache.IncrBy(context.Background(), key, delta).Result()
|
||||
}
|
||||
|
||||
func Decr(cache *redis.Client, key string, delta int64) (int64, error) {
|
||||
return cache.DecrBy(context.Background(), key, delta).Result()
|
||||
}
|
||||
|
||||
func GetInt(cache *redis.Client, key string) (int64, error) {
|
||||
return cache.Get(context.Background(), key).Int64()
|
||||
}
|
||||
|
||||
func SetInt(cache *redis.Client, key string, value int64, expiration int64) error {
|
||||
return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err()
|
||||
}
|
||||
|
||||
func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool {
|
||||
// not exist
|
||||
if _, err := cache.Get(context.Background(), key).Result(); err != nil {
|
||||
if err == redis.Nil {
|
||||
cache.Set(context.Background(), key, delta, time.Duration(expiration)*time.Second)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
res, err := Incr(cache, key, delta)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if res > limit {
|
||||
cache.Set(context.Background(), key, limit, time.Duration(expiration)*time.Second)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func DecrInt(cache *redis.Client, key string, delta int64) bool {
|
||||
_, err := Decr(cache, key, delta)
|
||||
return err == nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user