mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
fix: fix midjourney chunk stacking problem (#156) and stop signal cannot trigger in some channel formats issue
Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
parent
6615cd92a5
commit
ae063e943a
@ -6,8 +6,11 @@ import (
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const maxTimeout = 30 * time.Minute // 30 min timeout
|
||||
|
||||
func getStatusCode(action string, response *CommonResponse) error {
|
||||
code := response.Code
|
||||
switch code {
|
||||
@ -94,15 +97,18 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
|
||||
task := res.Result
|
||||
progress := -1
|
||||
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
utils.Sleep(50)
|
||||
form := getStorage(task)
|
||||
select {
|
||||
case <-ticker.C:
|
||||
form := getNotifyStorage(task)
|
||||
if form == nil {
|
||||
// hook for ping
|
||||
// hook for ping (in order to catch the stop signal)
|
||||
if err := hook(nil, -1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
@ -122,6 +128,14 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
|
||||
}
|
||||
progress = current
|
||||
}
|
||||
default:
|
||||
// ping
|
||||
if err := hook(form, -1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
case <-time.After(maxTimeout):
|
||||
return nil, fmt.Errorf("task timeout")
|
||||
}
|
||||
}
|
||||
}
|
@ -14,6 +14,6 @@ func setStorage(task string, form StorageForm) error {
|
||||
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
|
||||
}
|
||||
|
||||
func getStorage(task string) *StorageForm {
|
||||
return utils.GetJson[StorageForm](connection.Cache, getTaskName(task))
|
||||
func getNotifyStorage(task string) *StorageForm {
|
||||
return utils.GetCacheStore[StorageForm](connection.Cache, getTaskName(task))
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"chat/adapter/common"
|
||||
adaptercommon "chat/adapter/common"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
@ -10,7 +10,11 @@ import (
|
||||
)
|
||||
|
||||
func IsAvailableError(err error) bool {
|
||||
return err != nil && err.Error() != "signal"
|
||||
return err != nil && (err.Error() != "signal" && !strings.Contains(err.Error(), "signal"))
|
||||
}
|
||||
|
||||
func IsSkipError(err error) bool {
|
||||
return err == nil || (err.Error() == "signal" || strings.Contains(err.Error(), "signal"))
|
||||
}
|
||||
|
||||
func isQPSOverLimit(model string, err error) bool {
|
||||
@ -26,6 +30,7 @@ func NewChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps,
|
||||
retries := conf.GetRetry()
|
||||
props.Current++
|
||||
|
||||
fmt.Println(IsAvailableError(err))
|
||||
if IsAvailableError(err) {
|
||||
if isQPSOverLimit(props.OriginalModel, err) {
|
||||
// sleep for 0.5s to avoid qps limit
|
||||
|
@ -1,10 +1,12 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/connection"
|
||||
"chat/utils"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func IncrErrorRequest(cache *redis.Client) {
|
||||
@ -27,7 +29,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) {
|
||||
func AnalysisRequest(model string, buffer *utils.Buffer, err error) {
|
||||
instance := connection.Cache
|
||||
|
||||
if err != nil && err.Error() != "signal" {
|
||||
if adapter.IsAvailableError(err) {
|
||||
IncrErrorRequest(instance)
|
||||
return
|
||||
}
|
||||
|
@ -2,15 +2,20 @@ package channel
|
||||
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/adapter/common"
|
||||
adaptercommon "chat/adapter/common"
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error {
|
||||
if err := AuditContent(props); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ticker := ConduitInstance.GetTicker(props.OriginalModel, group)
|
||||
if ticker == nil || ticker.IsEmpty() {
|
||||
return fmt.Errorf("cannot find channel for model %s", props.OriginalModel)
|
||||
@ -20,8 +25,8 @@ func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.H
|
||||
for !ticker.IsDone() {
|
||||
if channel := ticker.Next(); channel != nil {
|
||||
props.MaxRetries = utils.ToPtr(channel.GetRetry())
|
||||
if err = adapter.NewChatRequest(channel, props, hook); err == nil || err.Error() == "signal" {
|
||||
return nil
|
||||
if err = adapter.NewChatRequest(channel, props, hook); adapter.IsSkipError(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName()))
|
||||
|
@ -63,7 +63,7 @@ func OriginIsAllowed(uri string) bool {
|
||||
}
|
||||
|
||||
func OriginIsOpen(c *gin.Context) bool {
|
||||
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard")
|
||||
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard") || strings.HasPrefix(c.Request.URL.Path, "/mj")
|
||||
}
|
||||
|
||||
const (
|
||||
|
@ -11,8 +11,9 @@ import (
|
||||
"chat/manager/conversation"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const defaultMessage = "empty response"
|
||||
@ -96,7 +97,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
||||
)
|
||||
|
||||
admin.AnalysisRequest(model, buffer, err)
|
||||
if err != nil && err.Error() != "signal" {
|
||||
if adapter.IsAvailableError(err) {
|
||||
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
||||
|
||||
auth.RevertSubscriptionUsage(db, cache, user, model)
|
||||
|
@ -4,8 +4,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
)
|
||||
|
||||
func Incr(cache *redis.Client, key string, delta int64) (int64, error) {
|
||||
@ -37,7 +38,7 @@ func SetJson(cache *redis.Client, key string, value interface{}, expiration int6
|
||||
return err
|
||||
}
|
||||
|
||||
func GetJson[T any](cache *redis.Client, key string) *T {
|
||||
func GetCacheStore[T any](cache *redis.Client, key string) *T {
|
||||
val, err := cache.Get(context.Background(), key).Result()
|
||||
if err != nil {
|
||||
return nil
|
||||
|
Loading…
Reference in New Issue
Block a user