mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
feat: support relay midjourney-proxy format
This commit is contained in:
parent
fe6c9b0d14
commit
52419e1b2b
@ -30,7 +30,7 @@ func (c *ChatInstance) GetChangeRequest(action string, task string, index *int)
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, error) {
|
||||
res, err := utils.Post(
|
||||
content, err := utils.PostRaw(
|
||||
c.GetImagineEndpoint(),
|
||||
c.GetMidjourneyHeaders(),
|
||||
c.GetImagineRequest(prompt),
|
||||
@ -40,7 +40,11 @@ func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return utils.MapToStruct[CommonResponse](res), nil
|
||||
if data, err := utils.UnmarshalString[CommonResponse](content); err == nil {
|
||||
return &data, nil
|
||||
} else {
|
||||
return nil, utils.ToMarkdownError(err, content)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateChangeRequest(action string, task string, index *int) (*CommonResponse, error) {
|
||||
@ -54,6 +58,5 @@ func (c *ChatInstance) CreateChangeRequest(action string, task string, index *in
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fmt.Println(res)
|
||||
return utils.MapToStruct[CommonResponse](res), nil
|
||||
}
|
||||
|
@ -40,7 +40,17 @@ func getMode(model string) string {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) IsIgnoreMode() bool {
|
||||
return strings.HasSuffix(c.Endpoint, "/mj-relax") ||
|
||||
strings.HasSuffix(c.Endpoint, "/mj-fast") ||
|
||||
strings.HasSuffix(c.Endpoint, "/mj-turbo")
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
|
||||
if c.IsIgnoreMode() {
|
||||
return prompt
|
||||
}
|
||||
|
||||
arr := strings.Split(strings.TrimSpace(prompt), " ")
|
||||
var res []string
|
||||
|
||||
@ -57,7 +67,12 @@ func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetPrompt(props *ChatProps) string {
|
||||
return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content)
|
||||
if len(props.Messages) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
content := props.Messages[len(props.Messages)-1].Content
|
||||
return c.GetCleanPrompt(props.Model, content)
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||
@ -72,7 +87,6 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global
|
||||
if len(globals.NotifyUrl) == 0 {
|
||||
return fmt.Errorf("format error: please provide available notify url")
|
||||
}
|
||||
|
||||
action, prompt := c.ExtractPrompt(c.GetPrompt(props))
|
||||
if len(prompt) == 0 {
|
||||
return fmt.Errorf("format error: please provide available prompt")
|
||||
|
@ -93,7 +93,7 @@ func (c *ChatInstance) CreateStreamTask(action string, prompt string, hook func(
|
||||
progress := -1
|
||||
|
||||
for {
|
||||
utils.Sleep(100)
|
||||
utils.Sleep(50)
|
||||
form := getStorage(task)
|
||||
if form == nil {
|
||||
continue
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
)
|
||||
|
||||
@ -57,8 +58,9 @@ func MockStreamSender(conn *Connection, message string) {
|
||||
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
|
||||
err, instance.GetModel(), conn.GetCtx().ClientIP(),
|
||||
stack := debug.Stack()
|
||||
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s",
|
||||
err, instance.GetModel(), conn.GetCtx().ClientIP(), stack,
|
||||
))
|
||||
}
|
||||
}()
|
||||
@ -116,7 +118,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
||||
|
||||
admin.AnalysisRequest(model, buffer, err)
|
||||
if err != nil && err.Error() != "signal" {
|
||||
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
||||
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
||||
|
||||
auth.RevertSubscriptionUsage(db, cache, user, model)
|
||||
conn.Send(globals.ChatSegmentResponse{
|
||||
|
@ -10,13 +10,15 @@ import (
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []globals.Message, enableWeb bool) (string, float32) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
|
||||
err, model, c.ClientIP(),
|
||||
stack := debug.Stack()
|
||||
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s",
|
||||
err, model, c.ClientIP(), stack,
|
||||
))
|
||||
}
|
||||
}()
|
||||
|
@ -92,6 +92,20 @@ func MapToStruct[T any](data interface{}) *T {
|
||||
}
|
||||
}
|
||||
|
||||
func MapToRawStruct[T any](data interface{}) (*T, error) {
|
||||
val, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
form, err := Unmarshal[T](val)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &form, nil
|
||||
}
|
||||
|
||||
func ParseInt(value string) int {
|
||||
if res, err := strconv.Atoi(value); err == nil {
|
||||
return res
|
||||
@ -332,3 +346,7 @@ func ToSecret(raw string) string {
|
||||
func ToMarkdownCode(lang string, code string) string {
|
||||
return fmt.Sprintf("```%s\n%s\n```", lang, code)
|
||||
}
|
||||
|
||||
func ToMarkdownError(err error, body string) error {
|
||||
return fmt.Errorf("%s\n%s", err.Error(), ToMarkdownCode("html", body))
|
||||
}
|
||||
|
12
utils/net.go
12
utils/net.go
@ -8,6 +8,7 @@ import (
|
||||
"github.com/goccy/go-json"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
@ -89,6 +90,14 @@ func Post(uri string, headers map[string]string, body interface{}) (data interfa
|
||||
return data, err
|
||||
}
|
||||
|
||||
func PostRaw(uri string, headers map[string]string, body interface{}) (data string, err error) {
|
||||
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(buffer), nil
|
||||
}
|
||||
|
||||
func ConvertBody(body interface{}) (form io.Reader) {
|
||||
if buffer, err := json.Marshal(body); err == nil {
|
||||
form = bytes.NewBuffer(buffer)
|
||||
@ -100,7 +109,8 @@ func EventSource(method string, uri string, headers map[string]string, body inte
|
||||
// panic recovery
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)", err, uri, method))
|
||||
stack := debug.Stack()
|
||||
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)\n%s", err, uri, method, stack))
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
)
|
||||
|
||||
@ -39,7 +40,8 @@ func EventScanner(props *EventScannerProps) *EventScannerError {
|
||||
// panic recovery
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)", r, props.Uri, props.Method))
|
||||
stack := debug.Stack()
|
||||
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)\n%s", r, props.Uri, props.Method, stack))
|
||||
}
|
||||
}()
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user