feat: support relay midjourney-proxy format

This commit is contained in:
Zhang Minghan 2024-03-08 22:18:45 +08:00
parent fe6c9b0d14
commit 52419e1b2b
8 changed files with 64 additions and 13 deletions

View File

@ -30,7 +30,7 @@ func (c *ChatInstance) GetChangeRequest(action string, task string, index *int)
} }
func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, error) { func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, error) {
res, err := utils.Post( content, err := utils.PostRaw(
c.GetImagineEndpoint(), c.GetImagineEndpoint(),
c.GetMidjourneyHeaders(), c.GetMidjourneyHeaders(),
c.GetImagineRequest(prompt), c.GetImagineRequest(prompt),
@ -40,7 +40,11 @@ func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, err
return nil, err return nil, err
} }
return utils.MapToStruct[CommonResponse](res), nil if data, err := utils.UnmarshalString[CommonResponse](content); err == nil {
return &data, nil
} else {
return nil, utils.ToMarkdownError(err, content)
}
} }
func (c *ChatInstance) CreateChangeRequest(action string, task string, index *int) (*CommonResponse, error) { func (c *ChatInstance) CreateChangeRequest(action string, task string, index *int) (*CommonResponse, error) {
@ -54,6 +58,5 @@ func (c *ChatInstance) CreateChangeRequest(action string, task string, index *in
return nil, err return nil, err
} }
fmt.Println(res)
return utils.MapToStruct[CommonResponse](res), nil return utils.MapToStruct[CommonResponse](res), nil
} }

View File

@ -40,7 +40,17 @@ func getMode(model string) string {
} }
} }
func (c *ChatInstance) IsIgnoreMode() bool {
return strings.HasSuffix(c.Endpoint, "/mj-relax") ||
strings.HasSuffix(c.Endpoint, "/mj-fast") ||
strings.HasSuffix(c.Endpoint, "/mj-turbo")
}
func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string { func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
if c.IsIgnoreMode() {
return prompt
}
arr := strings.Split(strings.TrimSpace(prompt), " ") arr := strings.Split(strings.TrimSpace(prompt), " ")
var res []string var res []string
@ -57,7 +67,12 @@ func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
} }
func (c *ChatInstance) GetPrompt(props *ChatProps) string { func (c *ChatInstance) GetPrompt(props *ChatProps) string {
return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content) if len(props.Messages) == 0 {
return ""
}
content := props.Messages[len(props.Messages)-1].Content
return c.GetCleanPrompt(props.Model, content)
} }
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
@ -72,7 +87,6 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global
if len(globals.NotifyUrl) == 0 { if len(globals.NotifyUrl) == 0 {
return fmt.Errorf("format error: please provide available notify url") return fmt.Errorf("format error: please provide available notify url")
} }
action, prompt := c.ExtractPrompt(c.GetPrompt(props)) action, prompt := c.ExtractPrompt(c.GetPrompt(props))
if len(prompt) == 0 { if len(prompt) == 0 {
return fmt.Errorf("format error: please provide available prompt") return fmt.Errorf("format error: please provide available prompt")

View File

@ -93,7 +93,7 @@ func (c *ChatInstance) CreateStreamTask(action string, prompt string, hook func(
progress := -1 progress := -1
for { for {
utils.Sleep(100) utils.Sleep(50)
form := getStorage(task) form := getStorage(task)
if form == nil { if form == nil {
continue continue

View File

@ -11,6 +11,7 @@ import (
"chat/utils" "chat/utils"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"runtime/debug"
"time" "time"
) )
@ -57,8 +58,9 @@ func MockStreamSender(conn *Connection, message string) {
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string { func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)", stack := debug.Stack()
err, instance.GetModel(), conn.GetCtx().ClientIP(), globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s",
err, instance.GetModel(), conn.GetCtx().ClientIP(), stack,
)) ))
} }
}() }()
@ -116,7 +118,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
admin.AnalysisRequest(model, buffer, err) admin.AnalysisRequest(model, buffer, err)
if err != nil && err.Error() != "signal" { if err != nil && err.Error() != "signal" {
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP())) globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
auth.RevertSubscriptionUsage(db, cache, user, model) auth.RevertSubscriptionUsage(db, cache, user, model)
conn.Send(globals.ChatSegmentResponse{ conn.Send(globals.ChatSegmentResponse{

View File

@ -10,13 +10,15 @@ import (
"chat/utils" "chat/utils"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"runtime/debug"
) )
func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []globals.Message, enableWeb bool) (string, float32) { func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []globals.Message, enableWeb bool) (string, float32) {
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)", stack := debug.Stack()
err, model, c.ClientIP(), globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s",
err, model, c.ClientIP(), stack,
)) ))
} }
}() }()

View File

@ -92,6 +92,20 @@ func MapToStruct[T any](data interface{}) *T {
} }
} }
func MapToRawStruct[T any](data interface{}) (*T, error) {
val, err := json.Marshal(data)
if err != nil {
return nil, err
}
form, err := Unmarshal[T](val)
if err != nil {
return nil, err
}
return &form, nil
}
func ParseInt(value string) int { func ParseInt(value string) int {
if res, err := strconv.Atoi(value); err == nil { if res, err := strconv.Atoi(value); err == nil {
return res return res
@ -332,3 +346,7 @@ func ToSecret(raw string) string {
func ToMarkdownCode(lang string, code string) string { func ToMarkdownCode(lang string, code string) string {
return fmt.Sprintf("```%s\n%s\n```", lang, code) return fmt.Sprintf("```%s\n%s\n```", lang, code)
} }
func ToMarkdownError(err error, body string) error {
return fmt.Errorf("%s\n%s", err.Error(), ToMarkdownCode("html", body))
}

View File

@ -8,6 +8,7 @@ import (
"github.com/goccy/go-json" "github.com/goccy/go-json"
"io" "io"
"net/http" "net/http"
"runtime/debug"
"strings" "strings"
"time" "time"
) )
@ -89,6 +90,14 @@ func Post(uri string, headers map[string]string, body interface{}) (data interfa
return data, err return data, err
} }
func PostRaw(uri string, headers map[string]string, body interface{}) (data string, err error) {
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body))
if err != nil {
return "", err
}
return string(buffer), nil
}
func ConvertBody(body interface{}) (form io.Reader) { func ConvertBody(body interface{}) (form io.Reader) {
if buffer, err := json.Marshal(body); err == nil { if buffer, err := json.Marshal(body); err == nil {
form = bytes.NewBuffer(buffer) form = bytes.NewBuffer(buffer)
@ -100,7 +109,8 @@ func EventSource(method string, uri string, headers map[string]string, body inte
// panic recovery // panic recovery
defer func() { defer func() {
if err := recover(); err != nil { if err := recover(); err != nil {
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)", err, uri, method)) stack := debug.Stack()
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)\n%s", err, uri, method, stack))
} }
}() }()

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"runtime/debug"
"strings" "strings"
) )
@ -39,7 +40,8 @@ func EventScanner(props *EventScannerProps) *EventScannerError {
// panic recovery // panic recovery
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)", r, props.Uri, props.Method)) stack := debug.Stack()
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)\n%s", r, props.Uri, props.Method, stack))
} }
}() }()