diff --git a/manager/chat.go b/manager/chat.go index e8d6f6e..0cfff6f 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -10,6 +10,7 @@ import ( "chat/globals" "chat/manager/conversation" "chat/utils" + "time" "database/sql" "errors" @@ -51,7 +52,9 @@ type partialChunk struct { func createStopSignal(conn *Connection) chan bool { stopSignal := make(chan bool, 1) go func(conn *Connection, stopSignal chan bool) { + ticker := time.NewTicker(100 * time.Millisecond) defer func() { + ticker.Stop() if r := recover(); r != nil && !strings.Contains(fmt.Sprintf("%s", r), "closed channel") { stack := debug.Stack() globals.Warn(fmt.Sprintf("caught panic from stop signal: %s\n%s", r, stack)) @@ -59,9 +62,14 @@ func createStopSignal(conn *Connection) chan bool { }() for { - if conn.PeekStop() != nil { - stopSignal <- true - break + select { + case <-ticker.C: + state := conn.PeekStop() != nil // check the stop state + stopSignal <- state + + if state { + break + } } } }(conn, stopSignal) @@ -169,16 +177,19 @@ func createChatTask( return hit, sendPackError } - case <-stopSignal: - globals.Info(fmt.Sprintf("client stopped the chat request (model: %s, client: %s)", model, conn.GetCtx().ClientIP())) - _ = conn.SendClient(globals.ChatSegmentResponse{ - Quota: buffer.GetQuota(), - End: true, - Plan: plan, - }) - interruptSignal <- errors.New("signal") + case signal := <-stopSignal: + // if stop signal is received + if signal { + globals.Info(fmt.Sprintf("client stopped the chat request (model: %s, client: %s)", model, conn.GetCtx().ClientIP())) + _ = conn.SendClient(globals.ChatSegmentResponse{ + Quota: buffer.GetQuota(), + End: true, + Plan: plan, + }) + interruptSignal <- errors.New("signal") - return + return + } } } }