// Mirror of https://github.com/coaidev/coai.git
// (synced 2025-05-19 04:50:14 +09:00; 195 lines, 3.3 KiB, Go).

package manager
|
|
|
|
import (
|
|
"chat/globals"
|
|
"chat/manager/conversation"
|
|
"chat/utils"
|
|
"database/sql"
|
|
"fmt"
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/go-redis/redis/v8"
|
|
)
|
|
|
|
// Message types carried in conversation.FormMessage.Type.
const (
	// ChatType is the default type: ReadWorker assigns it to any incoming
	// form whose Type is empty.
	ChatType = "chat"

	// StopType is one of the types PeekStop looks for at the head of the
	// buffer.
	StopType = "stop"

	// RestartType is the "restart" message type.
	RestartType = "restart"

	// ShareType is the "share" message type.
	ShareType = "share"

	// MaskType is the "mask" message type.
	MaskType = "mask"

	// EditType is the "edit" message type.
	EditType = "edit"

	// RemoveType is, together with StopType, one of the types PeekStop
	// looks for at the head of the buffer.
	RemoveType = "remove"
)
|
|
|
|
// Stack is the channel that buffers incoming forms for a Connection.
// Despite the name, a Go channel delivers in FIFO order; Peek re-inserts
// what it reads via utils.InsertChannel (presumably at the front — verify).
// A nil element is used as the stop sentinel (see Stop and Process).
type Stack chan *conversation.FormMessage
|
|
|
|
// Connection couples a WebSocket with a buffered queue of incoming forms
// and the identity data used for concurrent-thread limiting.
type Connection struct {
	// conn is the WebSocket wrapper that I/O, DB, cache and rate-limit
	// calls are delegated to.
	conn *utils.WebSocket

	// stack buffers forms between ReadWorker (producer) and Process
	// (consumer).
	stack Stack

	// auth selects globals.ChatMaxThread vs globals.AnonymousMaxThread in
	// Lock.
	auth bool

	// hash is the rate-limit key used by Lock and Release.
	hash string
}
|
|
|
|
func NewConnection(conn *utils.WebSocket, auth bool, hash string, bufferSize int) *Connection {
|
|
return &Connection{
|
|
conn: conn,
|
|
auth: auth,
|
|
hash: hash,
|
|
stack: make(Stack, bufferSize),
|
|
}
|
|
}
|
|
|
|
// GetConn returns the underlying WebSocket wrapper.
func (c *Connection) GetConn() *utils.WebSocket {
	return c.conn
}
|
|
|
|
// GetCtx returns the *gin.Context reported by the underlying WebSocket.
func (c *Connection) GetCtx() *gin.Context {
	return c.conn.GetCtx()
}
|
|
|
|
// GetStack returns the channel that buffers incoming forms. Callers that
// receive from it directly bypass Read/Peek and compete with Process.
func (c *Connection) GetStack() Stack {
	return c.stack
}
|
|
|
|
func (c *Connection) ReadWorker() {
|
|
for {
|
|
if c.IsClosed() {
|
|
break
|
|
}
|
|
|
|
form, err := utils.ReadForm[conversation.FormMessage](c.conn)
|
|
if err != nil {
|
|
break
|
|
}
|
|
|
|
if form.Type == "" {
|
|
form.Type = ChatType
|
|
}
|
|
|
|
c.Write(form)
|
|
}
|
|
|
|
c.Stop()
|
|
}
|
|
|
|
func (c *Connection) Write(data *conversation.FormMessage) {
|
|
if len(c.stack) == cap(c.stack) {
|
|
c.Skip()
|
|
}
|
|
c.stack <- data
|
|
}
|
|
|
|
// IsClosed reports whether the underlying WebSocket reports itself closed.
func (c *Connection) IsClosed() bool {
	return c.conn.IsClosed()
}
|
|
|
|
// Stop queues the nil sentinel; Process returns when it reads it.
func (c *Connection) Stop() {
	c.Write(nil)
}
|
|
|
|
func (c *Connection) Read() *conversation.FormMessage {
|
|
form := <-c.stack
|
|
return form
|
|
}
|
|
|
|
func (c *Connection) Peek() *conversation.FormMessage {
|
|
select {
|
|
case form := <-c.stack:
|
|
utils.InsertChannel(c.stack, form, 0)
|
|
return form
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (c *Connection) PeekWithType(t string) *conversation.FormMessage {
|
|
// skip if type is matched
|
|
|
|
if form := c.Peek(); form != nil {
|
|
if form.Type == t {
|
|
c.Skip()
|
|
return form
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Connection) PeekWithTypes(types ...string) *conversation.FormMessage {
|
|
// skip if type is matched
|
|
|
|
if form := c.Peek(); form != nil {
|
|
for _, t := range types {
|
|
if form.Type == t {
|
|
c.Skip()
|
|
return form
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Connection) PeekStop() *conversation.FormMessage {
|
|
return c.PeekWithTypes(StopType, RemoveType)
|
|
}
|
|
|
|
// Skip discards the form at the head of the buffer. If the buffer is empty
// it blocks until a form is queued.
func (c *Connection) Skip() {
	<-c.stack
}
|
|
|
|
// GetDB returns the *sql.DB exposed by the underlying WebSocket wrapper.
func (c *Connection) GetDB() *sql.DB {
	return c.conn.GetDB()
}
|
|
|
|
// GetCache returns the *redis.Client exposed by the underlying WebSocket
// wrapper.
func (c *Connection) GetCache() *redis.Client {
	return c.conn.GetCache()
}
|
|
|
|
// Send forwards message to the WebSocket's Send method. Unlike SendClient,
// it does not return an error to the caller.
func (c *Connection) Send(message globals.ChatSegmentResponse) {
	c.conn.Send(message)
}
|
|
|
|
// SendClient writes message via the WebSocket's SendJSON (presumably
// JSON-encoded — confirm in utils) and returns its error.
func (c *Connection) SendClient(message globals.ChatSegmentResponse) error {
	return c.conn.SendJSON(message)
}
|
|
|
|
func (c *Connection) Process(handler func(*conversation.FormMessage) error) {
|
|
for {
|
|
if form := c.Read(); form != nil {
|
|
if err := handler(form); err != nil {
|
|
return
|
|
}
|
|
} else {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// Handle runs the connection's message loop. Process(handler) runs on a new
// goroutine and ReadWorker runs on the calling goroutine, so Handle returns
// once ReadWorker stops (connection closed or read error).
//
// NOTE(review): Handle does not wait for the Process goroutine; handler may
// still be running after Handle returns.
func (c *Connection) Handle(handler func(*conversation.FormMessage) error) {
	go c.Process(handler)
	c.ReadWorker()
}
|
|
|
|
func (c *Connection) Lock() bool {
|
|
state := c.conn.IncrRateWithLimit(
|
|
c.hash,
|
|
utils.Multi[int64](c.auth, globals.ChatMaxThread, globals.AnonymousMaxThread),
|
|
60,
|
|
)
|
|
|
|
if !state {
|
|
c.conn.Send(globals.ChatSegmentResponse{
|
|
Message: fmt.Sprintf("You have reached the maximum number of threads (%d) the same time. Please wait for a while.", globals.ChatMaxThread),
|
|
End: true,
|
|
})
|
|
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// Release decrements the rate counter keyed by c.hash, the counter that
// Lock increments.
func (c *Connection) Release() {
	c.conn.DecrRate(c.hash)
}
|