From 52419e1b2b83c121a2350d8ed020b6c156bd94c0 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Fri, 8 Mar 2024 22:18:45 +0800 Subject: [PATCH] feat: support relay midjourney-proxy format --- adapter/midjourney/api.go | 9 ++++++--- adapter/midjourney/chat.go | 18 ++++++++++++++++-- adapter/midjourney/handler.go | 2 +- manager/chat.go | 8 +++++--- manager/completions.go | 6 ++++-- utils/char.go | 18 ++++++++++++++++++ utils/net.go | 12 +++++++++++- utils/scanner.go | 4 +++- 8 files changed, 64 insertions(+), 13 deletions(-) diff --git a/adapter/midjourney/api.go b/adapter/midjourney/api.go index 9fb9bf2..bdd136c 100644 --- a/adapter/midjourney/api.go +++ b/adapter/midjourney/api.go @@ -30,7 +30,7 @@ func (c *ChatInstance) GetChangeRequest(action string, task string, index *int) } func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, error) { - res, err := utils.Post( + content, err := utils.PostRaw( c.GetImagineEndpoint(), c.GetMidjourneyHeaders(), c.GetImagineRequest(prompt), @@ -40,7 +40,11 @@ func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, err return nil, err } - return utils.MapToStruct[CommonResponse](res), nil + if data, err := utils.UnmarshalString[CommonResponse](content); err == nil { + return &data, nil + } else { + return nil, utils.ToMarkdownError(err, content) + } } func (c *ChatInstance) CreateChangeRequest(action string, task string, index *int) (*CommonResponse, error) { @@ -54,6 +58,5 @@ func (c *ChatInstance) CreateChangeRequest(action string, task string, index *in return nil, err } - fmt.Println(res) return utils.MapToStruct[CommonResponse](res), nil } diff --git a/adapter/midjourney/chat.go b/adapter/midjourney/chat.go index 33582e1..de156a2 100644 --- a/adapter/midjourney/chat.go +++ b/adapter/midjourney/chat.go @@ -40,7 +40,17 @@ func getMode(model string) string { } } +func (c *ChatInstance) IsIgnoreMode() bool { + return strings.HasSuffix(c.Endpoint, "/mj-relax") || + strings.HasSuffix(c.Endpoint, "/mj-fast") || + strings.HasSuffix(c.Endpoint, "/mj-turbo") +} + func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string { + if c.IsIgnoreMode() { + return prompt + } + arr := strings.Split(strings.TrimSpace(prompt), " ") var res []string @@ -57,7 +67,12 @@ func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string { } func (c *ChatInstance) GetPrompt(props *ChatProps) string { - return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content) + if len(props.Messages) == 0 { + return "" + } + + content := props.Messages[len(props.Messages)-1].Content + return c.GetCleanPrompt(props.Model, content) } func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { @@ -72,7 +87,6 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global if len(globals.NotifyUrl) == 0 { return fmt.Errorf("format error: please provide available notify url") } - action, prompt := c.ExtractPrompt(c.GetPrompt(props)) if len(prompt) == 0 { return fmt.Errorf("format error: please provide available prompt") diff --git a/adapter/midjourney/handler.go b/adapter/midjourney/handler.go index 6f9fd54..dff210a 100644 --- a/adapter/midjourney/handler.go +++ b/adapter/midjourney/handler.go @@ -93,7 +93,7 @@ func (c *ChatInstance) CreateStreamTask(action string, prompt string, hook func( progress := -1 for { - utils.Sleep(100) + utils.Sleep(50) form := getStorage(task) if form == nil { continue diff --git a/manager/chat.go b/manager/chat.go index f08f501..20ae633 100644 --- a/manager/chat.go +++ b/manager/chat.go @@ -11,6 +11,7 @@ import ( "chat/utils" "fmt" "github.com/gin-gonic/gin" + "runtime/debug" "time" ) @@ -57,8 +58,9 @@ func MockStreamSender(conn *Connection, message string) { func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string { defer func() { if err := recover(); err != nil { - globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)", - err, instance.GetModel(), conn.GetCtx().ClientIP(), + stack := debug.Stack() + globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s", + err, instance.GetModel(), conn.GetCtx().ClientIP(), stack, )) } }() @@ -116,7 +118,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve admin.AnalysisRequest(model, buffer, err) if err != nil && err.Error() != "signal" { - globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) + globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) auth.RevertSubscriptionUsage(db, cache, user, model) conn.Send(globals.ChatSegmentResponse{ diff --git a/manager/completions.go b/manager/completions.go index ad1eef7..cd8e9aa 100644 --- a/manager/completions.go +++ b/manager/completions.go @@ -10,13 +10,15 @@ import ( "chat/utils" "fmt" "github.com/gin-gonic/gin" + "runtime/debug" ) func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []globals.Message, enableWeb bool) (string, float32) { defer func() { if err := recover(); err != nil { - globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)", - err, model, c.ClientIP(), + stack := debug.Stack() + globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s", + err, model, c.ClientIP(), stack, )) } }() diff --git a/utils/char.go b/utils/char.go index 1a9927a..2d1b100 100644 --- a/utils/char.go +++ b/utils/char.go @@ -92,6 +92,20 @@ func MapToStruct[T any](data interface{}) *T { } } +func MapToRawStruct[T any](data interface{}) (*T, error) { + val, err := json.Marshal(data) + if err != nil { + return nil, err + } + + form, err := Unmarshal[T](val) + if err != nil { + return nil, err + } + + return &form, nil +} + func ParseInt(value string) int { if res, err := strconv.Atoi(value); err == nil { return res @@ -332,3 +346,7 @@ func ToSecret(raw string) string { func ToMarkdownCode(lang string, code string) string { return fmt.Sprintf("```%s\n%s\n```", lang, code) } + +func ToMarkdownError(err error, body string) error { + return fmt.Errorf("%s\n%s", err.Error(), ToMarkdownCode("html", body)) +} diff --git a/utils/net.go b/utils/net.go index be3da95..dc78d2a 100644 --- a/utils/net.go +++ b/utils/net.go @@ -8,6 +8,7 @@ import ( "github.com/goccy/go-json" "io" "net/http" + "runtime/debug" "strings" "time" ) @@ -89,6 +90,14 @@ func Post(uri string, headers map[string]string, body interface{}) (data interfa return data, err } +func PostRaw(uri string, headers map[string]string, body interface{}) (data string, err error) { + buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body)) + if err != nil { + return "", err + } + return string(buffer), nil +} + func ConvertBody(body interface{}) (form io.Reader) { if buffer, err := json.Marshal(body); err == nil { form = bytes.NewBuffer(buffer) @@ -100,7 +109,8 @@ func EventSource(method string, uri string, headers map[string]string, body inte // panic recovery defer func() { if err := recover(); err != nil { - globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)", err, uri, method)) + stack := debug.Stack() + globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)\n%s", err, uri, method, stack)) } }() diff --git a/utils/scanner.go b/utils/scanner.go index 23344ae..6e82959 100644 --- a/utils/scanner.go +++ b/utils/scanner.go @@ -7,6 +7,7 @@ import ( "fmt" "io" "net/http" + "runtime/debug" "strings" ) @@ -39,7 +40,8 @@ func EventScanner(props *EventScannerProps) *EventScannerError { // panic recovery defer func() { if r := recover(); r != nil { - globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)", r, props.Uri, props.Method)) + stack := debug.Stack() + globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)\n%s", r, props.Uri, props.Method, stack)) } }()