mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 04:50:14 +09:00
255 lines
6.0 KiB
Go
255 lines
6.0 KiB
Go
package manager
|
|
|
|
import (
|
|
"chat/adapter"
|
|
adaptercommon "chat/adapter/common"
|
|
"chat/addition/web"
|
|
"chat/admin"
|
|
"chat/auth"
|
|
"chat/channel"
|
|
"chat/globals"
|
|
"chat/manager/conversation"
|
|
"chat/utils"
|
|
"time"
|
|
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"runtime/debug"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/go-redis/redis/v8"
|
|
)
|
|
|
|
const (
	// defaultMessage is sent to the client (and returned by ChatHandler) when
	// the chat request produced no output.
	defaultMessage = "empty response"

	// interruptMessage is the error text the chunk callback in createChatTask
	// returns to abort an in-flight chat request.
	interruptMessage = "interrupted"
)
|
|
|
|
func CollectQuota(c *gin.Context, user *auth.User, buffer *utils.Buffer, uncountable bool, err error) {
|
|
db := utils.GetDBFromContext(c)
|
|
quota := buffer.GetQuota()
|
|
|
|
if user == nil || quota <= 0 {
|
|
return
|
|
}
|
|
|
|
if buffer.IsEmpty() || err != nil {
|
|
return
|
|
}
|
|
|
|
if !uncountable {
|
|
user.UseQuota(db, quota)
|
|
}
|
|
}
|
|
|
|
type partialChunk struct {
|
|
Chunk *globals.Chunk
|
|
End bool
|
|
Hit bool
|
|
Error error
|
|
}
|
|
|
|
// createStopSignal starts a goroutine that polls conn.PeekStop every 100ms and
// reports the result on the returned channel. It sends false while PeekStop
// returns nil. The first time PeekStop returns non-nil it sends a single true,
// and then the goroutine exits.
//
// The channel has a buffer of one, so the poller blocks until the consumer
// reads the previous value. The consumer (createChatTask) closes the channel
// when it is finished. A later send on the closed channel panics ("send on
// closed channel"), and that panic is recovered here. This is how the poller
// learns it is no longer needed. Any other panic is logged.
func createStopSignal(conn *Connection) chan bool {
	stopSignal := make(chan bool, 1)

	go func() {
		ticker := time.NewTicker(100 * time.Millisecond)
		defer func() {
			ticker.Stop()

			// "send on closed channel" is the expected shutdown path; anything
			// else is a genuine fault worth logging.
			if r := recover(); r != nil && !strings.Contains(fmt.Sprintf("%s", r), "closed channel") {
				stack := debug.Stack()
				globals.Warn(fmt.Sprintf("caught panic from stop signal: %s\n%s", r, stack))
			}
		}()

		for range ticker.C {
			stopped := conn.PeekStop() != nil // check the stop state
			stopSignal <- stopped

			if stopped {
				// A bare `break` here would only leave a select, not the loop,
				// and the poller would keep polling after the stop was reported.
				return
			}
		}
	}()

	return stopSignal
}
|
|
|
|
// createChatTask runs one chat request against the upstream channel and
// streams its output to the client over conn.
//
// The request runs in its own goroutine (channel.NewChatRequestWithCache).
// Each chunk it produces is forwarded through chunkChan, appended to buffer,
// and sent to the client together with the running quota. The loop also
// watches the client stop signal produced by createStopSignal.
//
// It returns when:
//   - the request finishes: hit and err are the values reported by
//     NewChatRequestWithCache, except that an error whose text is
//     interruptMessage is reported as nil;
//   - the request goroutine panics: (false, non-nil error) is returned, so the
//     loop cannot wait forever for an End marker that will never come;
//   - the client stops the request: a final End segment is sent and
//     (false, nil) is returned;
//   - sending a segment to the client fails: the failure is logged and
//     (false, nil) is returned.
//
// On return, done is closed. The request goroutine then stops. Its chunk
// callback returns an interruptMessage error, and any pending send is
// abandoned. chunkChan is never closed, so the producer can never hit a
// send-on-closed-channel panic.
func createChatTask(
	conn *Connection, user *auth.User, buffer *utils.Buffer, db *sql.DB, cache *redis.Client,
	model string, instance *conversation.Conversation, segment []globals.Message, plan bool,
) (hit bool, err error) {
	chunkChan := make(chan partialChunk, 24) // chunks from the request goroutine (never closed)
	done := make(chan struct{})              // closed on return to release the request goroutine
	stopSignal := createStopSignal(conn)     // the signal to stop from the client

	defer func() {
		close(done)
		// createStopSignal's poller exits once its channel is closed.
		close(stopSignal)
	}()

	// send delivers msg to the loop below. It returns false without blocking
	// once this function has returned.
	send := func(msg partialChunk) bool {
		select {
		case chunkChan <- msg:
			return true
		case <-done:
			return false
		}
	}

	// create a new chat request routine
	go func() {
		defer func() {
			if r := recover(); r != nil {
				stack := debug.Stack()
				globals.Warn(fmt.Sprintf("caught panic from chat request: %s\n%s", r, stack))
				// Without an End marker the loop below would never return.
				send(partialChunk{End: true, Error: errors.New("chat request aborted unexpectedly")})
			}
		}()

		reqHit, reqErr := channel.NewChatRequestWithCache(
			cache, buffer,
			auth.GetGroup(db, user),
			adaptercommon.CreateChatProps(&adaptercommon.ChatProps{
				Model:             model,
				Message:           segment,
				MaxTokens:         instance.GetMaxTokens(),
				Temperature:       instance.GetTemperature(),
				TopP:              instance.GetTopP(),
				TopK:              instance.GetTopK(),
				PresencePenalty:   instance.GetPresencePenalty(),
				FrequencyPenalty:  instance.GetFrequencyPenalty(),
				RepetitionPenalty: instance.GetRepetitionPenalty(),
			}, buffer),

			// the function to handle the chunk data
			func(data *globals.Chunk) error {
				// Abort the upstream request once the consumer has gone away.
				select {
				case <-done:
					return errors.New(interruptMessage)
				default:
				}

				if !send(partialChunk{Chunk: data}) {
					return errors.New(interruptMessage)
				}
				return nil
			},
		)

		// chat request routine is done
		send(partialChunk{End: true, Hit: reqHit, Error: reqErr})
	}()

	for {
		select {
		case data := <-chunkChan:
			if data.End {
				// Always terminate on End. The original skipped interrupt
				// errors with `continue`, which left the loop waiting forever.
				if data.Error != nil && data.Error.Error() == interruptMessage {
					return data.Hit, nil
				}
				return data.Hit, data.Error
			}

			if sendErr := conn.SendClient(globals.ChatSegmentResponse{
				Message: buffer.WriteChunk(data.Chunk),
				Quota:   buffer.GetQuota(),
				End:     false,
				Plan:    plan,
			}); sendErr != nil {
				globals.Warn(fmt.Sprintf("failed to send message to client: %s", sendErr.Error()))
				return false, nil
			}

		case stopped := <-stopSignal:
			// if stop signal is received
			if stopped {
				globals.Info(fmt.Sprintf("client stopped the chat request (model: %s, client: %s)", model, conn.GetCtx().ClientIP()))
				// Best effort: the client asked to stop, so a send failure here is not actionable.
				_ = conn.SendClient(globals.ChatSegmentResponse{
					Quota: buffer.GetQuota(),
					End:   true,
					Plan:  plan,
				})
				return false, nil
			}
		}
	}
}
|
|
|
|
// ChatHandler processes one chat turn of instance for user and streams the
// result to the client over conn. restart is forwarded to web.ToChatSearched
// when building the message segment.
//
// It first sends the conversation id. It then asks
// auth.CanEnableModelWithSubscription whether the model may be used. On
// refusal, the reason is sent as a final segment and returned. Otherwise the
// request runs through createChatTask and is recorded with
// admin.AnalyseRequest. After that:
//   - when adapter.IsAvailableError(err) reports true, the subscription usage
//     is reverted, and the error text is sent and returned;
//   - quota is collected via CollectQuota unless createChatTask reported a hit;
//   - if the buffer is empty, defaultMessage is sent and returned;
//   - otherwise a final segment with the quota is sent, and the buffer content
//     is returned (defaultMessage if it reads empty).
//
// A panic is recovered and logged; in that case "" is returned.
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation, restart bool) string {
	defer func() {
		if r := recover(); r != nil {
			trace := debug.Stack()
			globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s",
				r, instance.GetModel(), conn.GetCtx().ClientIP(), trace,
			))
		}
	}()

	db, cache := conn.GetDB(), conn.GetCache()

	model := instance.GetModel()
	messages := adapter.ClearMessages(model, web.ToChatSearched(instance, restart))

	refusal, plan := auth.CanEnableModelWithSubscription(db, cache, user, model, messages)
	conn.Send(globals.ChatSegmentResponse{Conversation: instance.GetId()})

	if refusal != nil {
		reason := refusal.Error()
		conn.Send(globals.ChatSegmentResponse{Message: reason, Quota: 0, End: true})
		return reason
	}

	buf := utils.NewBuffer(model, messages, channel.ChargeInstance.GetCharge(model))
	hit, reqErr := createChatTask(conn, user, buf, db, cache, model, instance, messages, plan)

	admin.AnalyseRequest(model, buf, reqErr)
	if adapter.IsAvailableError(reqErr) {
		globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", reqErr, model, conn.GetCtx().ClientIP()))

		auth.RevertSubscriptionUsage(db, cache, user, model)
		text := reqErr.Error()
		conn.Send(globals.ChatSegmentResponse{Message: text, End: true})
		return text
	}

	// Cache hits are not charged.
	if !hit {
		CollectQuota(conn.GetCtx(), user, buf, plan, reqErr)
	}

	if buf.IsEmpty() {
		conn.Send(globals.ChatSegmentResponse{Message: defaultMessage, End: true})
		return defaultMessage
	}

	conn.Send(globals.ChatSegmentResponse{End: true, Quota: buf.GetQuota(), Plan: plan})
	return buf.ReadWithDefault(defaultMessage)
}
|