From ae063e943ab29018261e3771218e3acff5ae9815 Mon Sep 17 00:00:00 2001 From: Deng Junhai Date: Sun, 31 Mar 2024 17:26:22 +0800 Subject: [PATCH] fix: fix midjourney chunk stacking problem (#156) and stop signal cannot trigger in some channel formats issue Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com> --- adapter/midjourney/handler.go | 64 +++++++++++++++++++++-------------- adapter/midjourney/storage.go | 4 +-- adapter/request.go | 9 +++-- admin/statistic.go | 6 ++-- channel/worker.go | 13 ++++--- globals/variables.go | 2 +- manager/chat.go | 5 +-- utils/cache.go | 5 +-- 8 files changed, 68 insertions(+), 40 deletions(-) diff --git a/adapter/midjourney/handler.go b/adapter/midjourney/handler.go index c2a5d79..85ba9a2 100644 --- a/adapter/midjourney/handler.go +++ b/adapter/midjourney/handler.go @@ -6,8 +6,11 @@ import ( "chat/utils" "fmt" "strings" + "time" ) +const maxTimeout = 30 * time.Minute // 30 min timeout + func getStatusCode(action string, response *CommonResponse) error { code := response.Code switch code { @@ -94,34 +97,45 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s task := res.Result progress := -1 + ticker := time.NewTicker(50 * time.Millisecond) + defer ticker.Stop() + for { - utils.Sleep(50) - form := getStorage(task) - if form == nil { - // hook for ping - if err := hook(nil, -1); err != nil { - return nil, err - } - - continue - } - - switch form.Status { - case Success: - if err := hook(form, 100); err != nil { - return nil, err - } - return form, nil - case Failure: - return nil, fmt.Errorf("task failed: %s", form.FailReason) - case InProgress: - current := getProgress(form.Progress) - if progress != current { - if err := hook(form, current); err != nil { + select { + case <-ticker.C: + form := getNotifyStorage(task) + if form == nil { + // hook for ping (in order to catch the stop signal) + if err := hook(nil, -1); err != nil { return nil, err } - progress = current + continue } + + switch form.Status { + case Success: + if err := hook(form, 100); err != nil { + return nil, err + } + return form, nil + case Failure: + return nil, fmt.Errorf("task failed: %s", form.FailReason) + case InProgress: + current := getProgress(form.Progress) + if progress != current { + if err := hook(form, current); err != nil { + return nil, err + } + progress = current + } + default: + // ping + if err := hook(form, -1); err != nil { + return nil, err + } + } + case <-time.After(maxTimeout): + return nil, fmt.Errorf("task timeout") } } -} +} \ No newline at end of file diff --git a/adapter/midjourney/storage.go b/adapter/midjourney/storage.go index 18b5c7b..3b13482 100644 --- a/adapter/midjourney/storage.go +++ b/adapter/midjourney/storage.go @@ -14,6 +14,6 @@ func setStorage(task string, form StorageForm) error { return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60) } -func getStorage(task string) *StorageForm { - return utils.GetJson[StorageForm](connection.Cache, getTaskName(task)) +func getNotifyStorage(task string) *StorageForm { + return utils.GetCacheStore[StorageForm](connection.Cache, getTaskName(task)) } diff --git a/adapter/request.go b/adapter/request.go index 074d48a..298c00b 100644 --- a/adapter/request.go +++ b/adapter/request.go @@ -1,7 +1,7 @@ package adapter import ( - "chat/adapter/common" + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" @@ -10,7 +10,11 @@ import ( ) func IsAvailableError(err error) bool { - return err != nil && err.Error() != "signal" + return err != nil && (err.Error() != "signal" && !strings.Contains(err.Error(), "signal")) +} + +func IsSkipError(err error) bool { + return err == nil || (err.Error() == "signal" || strings.Contains(err.Error(), "signal")) } func isQPSOverLimit(model string, err error) bool { @@ -26,6 +30,7 @@ func NewChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps, retries := conf.GetRetry() props.Current++ + fmt.Println(IsAvailableError(err)) if IsAvailableError(err) { if isQPSOverLimit(props.OriginalModel, err) { // sleep for 0.5s to avoid qps limit diff --git a/admin/statistic.go b/admin/statistic.go index 0838571..c846a2e 100644 --- a/admin/statistic.go +++ b/admin/statistic.go @@ -1,10 +1,12 @@ package admin import ( + "chat/adapter" "chat/connection" "chat/utils" - "github.com/go-redis/redis/v8" "time" + + "github.com/go-redis/redis/v8" ) func IncrErrorRequest(cache *redis.Client) { @@ -27,7 +29,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) { func AnalysisRequest(model string, buffer *utils.Buffer, err error) { instance := connection.Cache - if err != nil && err.Error() != "signal" { + if adapter.IsAvailableError(err) { IncrErrorRequest(instance) return } diff --git a/channel/worker.go b/channel/worker.go index abc90b4..a69244f 100644 --- a/channel/worker.go +++ b/channel/worker.go @@ -2,15 +2,20 @@ package channel import ( "chat/adapter" - "chat/adapter/common" + adaptercommon "chat/adapter/common" "chat/globals" "chat/utils" "fmt" - "github.com/go-redis/redis/v8" "time" + + "github.com/go-redis/redis/v8" ) func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error { + if err := AuditContent(props); err != nil { + return err + } + ticker := ConduitInstance.GetTicker(props.OriginalModel, group) if ticker == nil || ticker.IsEmpty() { return fmt.Errorf("cannot find channel for model %s", props.OriginalModel) @@ -20,8 +25,8 @@ func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.H for !ticker.IsDone() { if channel := ticker.Next(); channel != nil { props.MaxRetries = utils.ToPtr(channel.GetRetry()) - if err = adapter.NewChatRequest(channel, props, hook); err == nil || err.Error() == "signal" { - return nil + if err = adapter.NewChatRequest(channel, props, hook); adapter.IsSkipError(err) { + return err } globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName())) diff --git a/globals/variables.go b/globals/variables.go index 0eee1ac..802728d 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -63,7 +63,7 @@ func OriginIsAllowed(uri string) bool { } func OriginIsOpen(c *gin.Context) bool { - return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard") + return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard") || strings.HasPrefix(c.Request.URL.Path, "/mj") } const ( diff --git a/manager/chat.go b/manager/chat.go index 418031d..ed1817a 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -11,8 +11,9 @@ import ( "chat/manager/conversation" "chat/utils" "fmt" - "github.com/gin-gonic/gin" "runtime/debug" + + "github.com/gin-gonic/gin" ) const defaultMessage = "empty response" @@ -96,7 +97,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve ) admin.AnalysisRequest(model, buffer, err) - if err != nil && err.Error() != "signal" { + if adapter.IsAvailableError(err) { globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) auth.RevertSubscriptionUsage(db, cache, user, model) diff --git a/utils/cache.go b/utils/cache.go index 561df96..6fb95c5 100644 --- a/utils/cache.go +++ b/utils/cache.go @@ -4,8 +4,9 @@ import ( "context" "errors" "fmt" - "github.com/go-redis/redis/v8" "time" + + "github.com/go-redis/redis/v8" ) func Incr(cache *redis.Client, key string, delta int64) (int64, error) { @@ -37,7 +38,7 @@ func SetJson(cache *redis.Client, key string, value interface{}, expiration int6 return err } -func GetJson[T any](cache *redis.Client, key string) *T { +func GetCacheStore[T any](cache *redis.Client, key string) *T { val, err := cache.Get(context.Background(), key).Result() if err != nil { return nil