fix: fix midjourney chunk stacking problem (#156) and stop signal cannot trigger in some channel formats issue

Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
Deng Junhai 2024-03-31 17:26:22 +08:00
parent 6615cd92a5
commit ae063e943a
8 changed files with 68 additions and 40 deletions

View File

@ -6,8 +6,11 @@ import (
"chat/utils"
"fmt"
"strings"
"time"
)
const maxTimeout = 30 * time.Minute // 30 min timeout
func getStatusCode(action string, response *CommonResponse) error {
code := response.Code
switch code {
@ -94,34 +97,45 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
task := res.Result
progress := -1
ticker := time.NewTicker(50 * time.Millisecond)
defer ticker.Stop()
for {
utils.Sleep(50)
form := getStorage(task)
if form == nil {
// hook for ping
if err := hook(nil, -1); err != nil {
return nil, err
}
continue
}
switch form.Status {
case Success:
if err := hook(form, 100); err != nil {
return nil, err
}
return form, nil
case Failure:
return nil, fmt.Errorf("task failed: %s", form.FailReason)
case InProgress:
current := getProgress(form.Progress)
if progress != current {
if err := hook(form, current); err != nil {
select {
case <-ticker.C:
form := getNotifyStorage(task)
if form == nil {
// hook for ping (in order to catch the stop signal)
if err := hook(nil, -1); err != nil {
return nil, err
}
progress = current
continue
}
switch form.Status {
case Success:
if err := hook(form, 100); err != nil {
return nil, err
}
return form, nil
case Failure:
return nil, fmt.Errorf("task failed: %s", form.FailReason)
case InProgress:
current := getProgress(form.Progress)
if progress != current {
if err := hook(form, current); err != nil {
return nil, err
}
progress = current
}
default:
// ping
if err := hook(form, -1); err != nil {
return nil, err
}
}
case <-time.After(maxTimeout):
return nil, fmt.Errorf("task timeout")
}
}
}
}

View File

@ -14,6 +14,6 @@ func setStorage(task string, form StorageForm) error {
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
}
func getStorage(task string) *StorageForm {
return utils.GetJson[StorageForm](connection.Cache, getTaskName(task))
func getNotifyStorage(task string) *StorageForm {
return utils.GetCacheStore[StorageForm](connection.Cache, getTaskName(task))
}

View File

@ -1,7 +1,7 @@
package adapter
import (
"chat/adapter/common"
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"fmt"
@ -10,7 +10,11 @@ import (
)
func IsAvailableError(err error) bool {
return err != nil && err.Error() != "signal"
return err != nil && (err.Error() != "signal" && !strings.Contains(err.Error(), "signal"))
}
func IsSkipError(err error) bool {
return err == nil || (err.Error() == "signal" || strings.Contains(err.Error(), "signal"))
}
func isQPSOverLimit(model string, err error) bool {
@ -26,6 +30,7 @@ func NewChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps,
retries := conf.GetRetry()
props.Current++
fmt.Println(IsAvailableError(err))
if IsAvailableError(err) {
if isQPSOverLimit(props.OriginalModel, err) {
// sleep for 0.5s to avoid qps limit

View File

@ -1,10 +1,12 @@
package admin
import (
"chat/adapter"
"chat/connection"
"chat/utils"
"github.com/go-redis/redis/v8"
"time"
"github.com/go-redis/redis/v8"
)
func IncrErrorRequest(cache *redis.Client) {
@ -27,7 +29,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) {
func AnalysisRequest(model string, buffer *utils.Buffer, err error) {
instance := connection.Cache
if err != nil && err.Error() != "signal" {
if adapter.IsAvailableError(err) {
IncrErrorRequest(instance)
return
}

View File

@ -2,15 +2,20 @@ package channel
import (
"chat/adapter"
"chat/adapter/common"
adaptercommon "chat/adapter/common"
"chat/globals"
"chat/utils"
"fmt"
"github.com/go-redis/redis/v8"
"time"
"github.com/go-redis/redis/v8"
)
func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error {
if err := AuditContent(props); err != nil {
return err
}
ticker := ConduitInstance.GetTicker(props.OriginalModel, group)
if ticker == nil || ticker.IsEmpty() {
return fmt.Errorf("cannot find channel for model %s", props.OriginalModel)
@ -20,8 +25,8 @@ func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.H
for !ticker.IsDone() {
if channel := ticker.Next(); channel != nil {
props.MaxRetries = utils.ToPtr(channel.GetRetry())
if err = adapter.NewChatRequest(channel, props, hook); err == nil || err.Error() == "signal" {
return nil
if err = adapter.NewChatRequest(channel, props, hook); adapter.IsSkipError(err) {
return err
}
globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName()))

View File

@ -63,7 +63,7 @@ func OriginIsAllowed(uri string) bool {
}
func OriginIsOpen(c *gin.Context) bool {
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard")
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard") || strings.HasPrefix(c.Request.URL.Path, "/mj")
}
const (

View File

@ -11,8 +11,9 @@ import (
"chat/manager/conversation"
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"runtime/debug"
"github.com/gin-gonic/gin"
)
const defaultMessage = "empty response"
@ -96,7 +97,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
)
admin.AnalysisRequest(model, buffer, err)
if err != nil && err.Error() != "signal" {
if adapter.IsAvailableError(err) {
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
auth.RevertSubscriptionUsage(db, cache, user, model)

View File

@ -4,8 +4,9 @@ import (
"context"
"errors"
"fmt"
"github.com/go-redis/redis/v8"
"time"
"github.com/go-redis/redis/v8"
)
func Incr(cache *redis.Client, key string, delta int64) (int64, error) {
@ -37,7 +38,7 @@ func SetJson(cache *redis.Client, key string, value interface{}, expiration int6
return err
}
func GetJson[T any](cache *redis.Client, key string) *T {
func GetCacheStore[T any](cache *redis.Client, key string) *T {
val, err := cache.Get(context.Background(), key).Result()
if err != nil {
return nil