mirror of
https://github.com/coaidev/coai.git
synced 2025-05-20 05:20:15 +09:00
fix: fix midjourney chunk stacking problem (#156) and stop signal cannot trigger in some channel formats issue
Co-Authored-By: Minghan Zhang <112773885+zmh-program@users.noreply.github.com>
This commit is contained in:
parent
6615cd92a5
commit
ae063e943a
@ -6,8 +6,11 @@ import (
|
|||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxTimeout = 30 * time.Minute // 30 min timeout
|
||||||
|
|
||||||
func getStatusCode(action string, response *CommonResponse) error {
|
func getStatusCode(action string, response *CommonResponse) error {
|
||||||
code := response.Code
|
code := response.Code
|
||||||
switch code {
|
switch code {
|
||||||
@ -94,15 +97,18 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
|
|||||||
task := res.Result
|
task := res.Result
|
||||||
progress := -1
|
progress := -1
|
||||||
|
|
||||||
|
ticker := time.NewTicker(50 * time.Millisecond)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
utils.Sleep(50)
|
select {
|
||||||
form := getStorage(task)
|
case <-ticker.C:
|
||||||
|
form := getNotifyStorage(task)
|
||||||
if form == nil {
|
if form == nil {
|
||||||
// hook for ping
|
// hook for ping (in order to catch the stop signal)
|
||||||
if err := hook(nil, -1); err != nil {
|
if err := hook(nil, -1); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,6 +128,14 @@ func (c *ChatInstance) CreateStreamTask(props *adaptercommon.ChatProps, action s
|
|||||||
}
|
}
|
||||||
progress = current
|
progress = current
|
||||||
}
|
}
|
||||||
|
default:
|
||||||
|
// ping
|
||||||
|
if err := hook(form, -1); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case <-time.After(maxTimeout):
|
||||||
|
return nil, fmt.Errorf("task timeout")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -14,6 +14,6 @@ func setStorage(task string, form StorageForm) error {
|
|||||||
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
|
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStorage(task string) *StorageForm {
|
func getNotifyStorage(task string) *StorageForm {
|
||||||
return utils.GetJson[StorageForm](connection.Cache, getTaskName(task))
|
return utils.GetCacheStore[StorageForm](connection.Cache, getTaskName(task))
|
||||||
}
|
}
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
package adapter
|
package adapter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/adapter/common"
|
adaptercommon "chat/adapter/common"
|
||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -10,7 +10,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func IsAvailableError(err error) bool {
|
func IsAvailableError(err error) bool {
|
||||||
return err != nil && err.Error() != "signal"
|
return err != nil && (err.Error() != "signal" && !strings.Contains(err.Error(), "signal"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsSkipError(err error) bool {
|
||||||
|
return err == nil || (err.Error() == "signal" || strings.Contains(err.Error(), "signal"))
|
||||||
}
|
}
|
||||||
|
|
||||||
func isQPSOverLimit(model string, err error) bool {
|
func isQPSOverLimit(model string, err error) bool {
|
||||||
@ -26,6 +30,7 @@ func NewChatRequest(conf globals.ChannelConfig, props *adaptercommon.ChatProps,
|
|||||||
retries := conf.GetRetry()
|
retries := conf.GetRetry()
|
||||||
props.Current++
|
props.Current++
|
||||||
|
|
||||||
|
fmt.Println(IsAvailableError(err))
|
||||||
if IsAvailableError(err) {
|
if IsAvailableError(err) {
|
||||||
if isQPSOverLimit(props.OriginalModel, err) {
|
if isQPSOverLimit(props.OriginalModel, err) {
|
||||||
// sleep for 0.5s to avoid qps limit
|
// sleep for 0.5s to avoid qps limit
|
||||||
|
@ -1,10 +1,12 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/adapter"
|
||||||
"chat/connection"
|
"chat/connection"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
|
||||||
func IncrErrorRequest(cache *redis.Client) {
|
func IncrErrorRequest(cache *redis.Client) {
|
||||||
@ -27,7 +29,7 @@ func IncrModelRequest(cache *redis.Client, model string, tokens int64) {
|
|||||||
func AnalysisRequest(model string, buffer *utils.Buffer, err error) {
|
func AnalysisRequest(model string, buffer *utils.Buffer, err error) {
|
||||||
instance := connection.Cache
|
instance := connection.Cache
|
||||||
|
|
||||||
if err != nil && err.Error() != "signal" {
|
if adapter.IsAvailableError(err) {
|
||||||
IncrErrorRequest(instance)
|
IncrErrorRequest(instance)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -2,15 +2,20 @@ package channel
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"chat/adapter"
|
"chat/adapter"
|
||||||
"chat/adapter/common"
|
adaptercommon "chat/adapter/common"
|
||||||
"chat/globals"
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error {
|
func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.Hook) error {
|
||||||
|
if err := AuditContent(props); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
ticker := ConduitInstance.GetTicker(props.OriginalModel, group)
|
ticker := ConduitInstance.GetTicker(props.OriginalModel, group)
|
||||||
if ticker == nil || ticker.IsEmpty() {
|
if ticker == nil || ticker.IsEmpty() {
|
||||||
return fmt.Errorf("cannot find channel for model %s", props.OriginalModel)
|
return fmt.Errorf("cannot find channel for model %s", props.OriginalModel)
|
||||||
@ -20,8 +25,8 @@ func NewChatRequest(group string, props *adaptercommon.ChatProps, hook globals.H
|
|||||||
for !ticker.IsDone() {
|
for !ticker.IsDone() {
|
||||||
if channel := ticker.Next(); channel != nil {
|
if channel := ticker.Next(); channel != nil {
|
||||||
props.MaxRetries = utils.ToPtr(channel.GetRetry())
|
props.MaxRetries = utils.ToPtr(channel.GetRetry())
|
||||||
if err = adapter.NewChatRequest(channel, props, hook); err == nil || err.Error() == "signal" {
|
if err = adapter.NewChatRequest(channel, props, hook); adapter.IsSkipError(err) {
|
||||||
return nil
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName()))
|
globals.Warn(fmt.Sprintf("[channel] caught error %s for model %s at channel %s", err.Error(), props.OriginalModel, channel.GetName()))
|
||||||
|
@ -63,7 +63,7 @@ func OriginIsAllowed(uri string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OriginIsOpen(c *gin.Context) bool {
|
func OriginIsOpen(c *gin.Context) bool {
|
||||||
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard")
|
return strings.HasPrefix(c.Request.URL.Path, "/v1") || strings.HasPrefix(c.Request.URL.Path, "/dashboard") || strings.HasPrefix(c.Request.URL.Path, "/mj")
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -11,8 +11,9 @@ import (
|
|||||||
"chat/manager/conversation"
|
"chat/manager/conversation"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"runtime/debug"
|
"runtime/debug"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultMessage = "empty response"
|
const defaultMessage = "empty response"
|
||||||
@ -96,7 +97,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
|||||||
)
|
)
|
||||||
|
|
||||||
admin.AnalysisRequest(model, buffer, err)
|
admin.AnalysisRequest(model, buffer, err)
|
||||||
if err != nil && err.Error() != "signal" {
|
if adapter.IsAvailableError(err) {
|
||||||
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
||||||
|
|
||||||
auth.RevertSubscriptionUsage(db, cache, user, model)
|
auth.RevertSubscriptionUsage(db, cache, user, model)
|
||||||
|
@ -4,8 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-redis/redis/v8"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Incr(cache *redis.Client, key string, delta int64) (int64, error) {
|
func Incr(cache *redis.Client, key string, delta int64) (int64, error) {
|
||||||
@ -37,7 +38,7 @@ func SetJson(cache *redis.Client, key string, value interface{}, expiration int6
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetJson[T any](cache *redis.Client, key string) *T {
|
func GetCacheStore[T any](cache *redis.Client, key string) *T {
|
||||||
val, err := cache.Get(context.Background(), key).Result()
|
val, err := cache.Get(context.Background(), key).Result()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
|
Loading…
Reference in New Issue
Block a user