fix: fix midjourney chunk stacking problem (#156) and stop signal cannot trigger in some channel formats issue

Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
Deng Junhai 2024-03-31 17:26:22 +08:00
parent 6615cd92a5
commit ae063e943a
8 changed files with 68 additions and 40 deletions

View File

@ -6,8 +6,11 @@ import (
"chat/utils" "chat/utils"
"fmt" "fmt"
"strings" "strings"
"time"
) )
const maxTimeout = 30 * time.Minute // 30 min timeout
func getStatusCode(action string, response *CommonResponse) error { func getStatusCode(action string, response *CommonResponse) error {
code := response.Code code := response.Code
switch code { switch code {
@ -94,15 +97,18 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
task := res.Result task := res.Result
progress := -1 progress := -1
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for { for {
utils.Sleep(50) select {
form := getStorage(task) case <-ticker.C:
form := getNotifyStorage(task)
if form == nil { if form == nil {
// hook for ping // hook for ping (in order to catch the stop signal)
if err := hook(nil, -1); err != nil { if err := hook(nil, -1); err != nil {
return nil, err return nil, err
} }
continue continue
} }
@ -122,6 +128,14 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
} }
progress = current progress = current
} }
default:
// ping
if err := hook(form, -1); err != nil {
return nil, err
}
}
case <-time.After(maxTimeout):
return nil, fmt.Errorf("task timeout")
} }
} }
} }

View File

@ -14,6 +14,6 @@ func setStorage(task string, form StorageForm) error {
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60) return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
} }
func getStorage(task string) *StorageForm { func getNotifyStorage(task string) *StorageForm {
return utils.GetJson[StorageForm](connection.Cache, getTaskName(task)) return utils.GetCacheStore[StorageForm](connection.Cache, getTaskName(task))
} }

View File

@ -1,7 +1,7 @@
package adapter package adapter
import ( import (
"chat/adapter/common" adaptercommon "chat/adapter/common"
"chat/globals" "chat/globals"
"chat/utils" "chat/utils"
"fmt" "fmt"
@ -10,7 +10,11 @@ import (
) )
func IsAvailableError(err error) bool { func IsAvailableError(err error) bool {
return err != nil && err.Error() != "signal" return err != nil && (err.Error() != "signal" && !strings.Contains(err.Error(), "signal"))
}
func IsSkipError(err error) bool {
return err == nil || (err.Error() == "signal" || strings.Contains(err.Error(), "signal"))
} }
func isQPSOverLimit(model string, err error) bool { func isQPSOverLimit(model string, err error) bool {
@ -26,6 +30,7 @@ func NewChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps,
retries := conf.GetRetry() retries := conf.GetRetry()
props.Current++ props.Current++
fmt.Println(IsAvailableError(err))
if IsAvailableError(err) { if IsAvailableError(err) {
if isQPSOverLimit(props.OriginalModel, err) { if isQPSOverLimit(props.OriginalModel, err) {
// sleep for 0.5s to avoid qps limit // sleep for 0.5s to avoid qps limit

View File

@ -1,10 +1,12 @@
package admin package admin
import ( import (
"chat/adapter"
"chat/connection" "chat/connection"
"chat/utils" "chat/utils"
"github.com/go-redis/redis/v8"
"time" "time"
"github.com/go-redis/redis/v8"
) )
func IncrErrorRequest(cache *redis.Client) { func IncrErrorRequest(cache *redis.Client) {
@ -27,7 +29,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) {
func AnalysisRequest(model string, buffer *utils.Buffer, err error) { func AnalysisRequest(model string, buffer *utils.Buffer, err error) {
instance := connection.Cache instance := connection.Cache
if err != nil && err.Error() != "signal" { if adapter.IsAvailableError(err) {
IncrErrorRequest(instance) IncrErrorRequest(instance)
return return
} }

View File

@ -2,15 +2,20 @@ package channel
import ( import (
"chat/adapter" "chat/adapter"
"chat/adapter/common" adaptercommon "chat/adapter/common"
"chat/globals" "chat/globals"
"chat/utils" "chat/utils"
"fmt" "fmt"
"github.com/go-redis/redis/v8"
"time" "time"
"github.com/go-redis/redis/v8"
) )
func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error { func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error {
if err := AuditContent(props); err != nil {
return err
}
ticker := ConduitInstance.GetTicker(props.OriginalModel, group) ticker := ConduitInstance.GetTicker(props.OriginalModel, group)
if ticker == nil || ticker.IsEmpty() { if ticker == nil || ticker.IsEmpty() {
return fmt.Errorf("cannot find channel for model %s", props.OriginalModel) return fmt.Errorf("cannot find channel for model %s", props.OriginalModel)
@ -20,8 +25,8 @@ func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.H
for !ticker.IsDone() { for !ticker.IsDone() {
if channel := ticker.Next(); channel != nil { if channel := ticker.Next(); channel != nil {
props.MaxRetries = utils.ToPtr(channel.GetRetry()) props.MaxRetries = utils.ToPtr(channel.GetRetry())
if err = adapter.NewChatRequest(channel, props, hook); err == nil || err.Error() == "signal" { if err = adapter.NewChatRequest(channel, props, hook); adapter.IsSkipError(err) {
return nil return err
} }
globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName())) globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName()))

View File

@ -63,7 +63,7 @@ func OriginIsAllowed(uri string) bool {
} }
func OriginIsOpen(c *gin.Context) bool { func OriginIsOpen(c *gin.Context) bool {
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard") return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard") || strings.HasPrefix(c.Request.URL.Path, "/mj")
} }
const ( const (

View File

@ -11,8 +11,9 @@ import (
"chat/manager/conversation" "chat/manager/conversation"
"chat/utils" "chat/utils"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"runtime/debug" "runtime/debug"
"github.com/gin-gonic/gin"
) )
const defaultMessage = "empty response" const defaultMessage = "empty response"
@ -96,7 +97,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
) )
admin.AnalysisRequest(model, buffer, err) admin.AnalysisRequest(model, buffer, err)
if err != nil && err.Error() != "signal" { if adapter.IsAvailableError(err) {
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
auth.RevertSubscriptionUsage(db, cache, user, model) auth.RevertSubscriptionUsage(db, cache, user, model)

View File

@ -4,8 +4,9 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"github.com/go-redis/redis/v8"
"time" "time"
"github.com/go-redis/redis/v8"
) )
func Incr(cache *redis.Client, key string, delta int64) (int64, error) { func Incr(cache *redis.Client, key string, delta int64) (int64, error) {
@ -37,7 +38,7 @@ func SetJson(cache *redis.Client, key string, value interface{}, expiration int6
return err return err
} }
func GetJson[T any](cache *redis.Client, key string) *T { func GetCacheStore[T any](cache *redis.Client, key string) *T {
val, err := cache.Get(context.Background(), key).Result() val, err := cache.Get(context.Background(), key).Result()
if err != nil { if err != nil {
return nil return nil