mirror of
https://github.com/coaidev/coai.git
synced 2025-05-19 21:10:18 +09:00
feat: support relay midjourney-proxy format
This commit is contained in:
parent
fe6c9b0d14
commit
52419e1b2b
@ -30,7 +30,7 @@ func (c *ChatInstance) GetChangeRequest(action string, task string, index *int)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, error) {
|
func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, error) {
|
||||||
res, err := utils.Post(
|
content, err := utils.PostRaw(
|
||||||
c.GetImagineEndpoint(),
|
c.GetImagineEndpoint(),
|
||||||
c.GetMidjourneyHeaders(),
|
c.GetMidjourneyHeaders(),
|
||||||
c.GetImagineRequest(prompt),
|
c.GetImagineRequest(prompt),
|
||||||
@ -40,7 +40,11 @@ func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, err
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return utils.MapToStruct[CommonResponse](res), nil
|
if data, err := utils.UnmarshalString[CommonResponse](content); err == nil {
|
||||||
|
return &data, nil
|
||||||
|
} else {
|
||||||
|
return nil, utils.ToMarkdownError(err, content)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChatInstance) CreateChangeRequest(action string, task string, index *int) (*CommonResponse, error) {
|
func (c *ChatInstance) CreateChangeRequest(action string, task string, index *int) (*CommonResponse, error) {
|
||||||
@ -54,6 +58,5 @@ func (c *ChatInstance) CreateChangeRequest(action string, task string, index *in
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Println(res)
|
|
||||||
return utils.MapToStruct[CommonResponse](res), nil
|
return utils.MapToStruct[CommonResponse](res), nil
|
||||||
}
|
}
|
||||||
|
@ -40,7 +40,17 @@ func getMode(model string) string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) IsIgnoreMode() bool {
|
||||||
|
return strings.HasSuffix(c.Endpoint, "/mj-relax") ||
|
||||||
|
strings.HasSuffix(c.Endpoint, "/mj-fast") ||
|
||||||
|
strings.HasSuffix(c.Endpoint, "/mj-turbo")
|
||||||
|
}
|
||||||
|
|
||||||
func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
|
func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
|
||||||
|
if c.IsIgnoreMode() {
|
||||||
|
return prompt
|
||||||
|
}
|
||||||
|
|
||||||
arr := strings.Split(strings.TrimSpace(prompt), " ")
|
arr := strings.Split(strings.TrimSpace(prompt), " ")
|
||||||
var res []string
|
var res []string
|
||||||
|
|
||||||
@ -57,7 +67,12 @@ func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChatInstance) GetPrompt(props *ChatProps) string {
|
func (c *ChatInstance) GetPrompt(props *ChatProps) string {
|
||||||
return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content)
|
if len(props.Messages) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
content := props.Messages[len(props.Messages)-1].Content
|
||||||
|
return c.GetCleanPrompt(props.Model, content)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||||
@ -72,7 +87,6 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global
|
|||||||
if len(globals.NotifyUrl) == 0 {
|
if len(globals.NotifyUrl) == 0 {
|
||||||
return fmt.Errorf("format error: please provide available notify url")
|
return fmt.Errorf("format error: please provide available notify url")
|
||||||
}
|
}
|
||||||
|
|
||||||
action, prompt := c.ExtractPrompt(c.GetPrompt(props))
|
action, prompt := c.ExtractPrompt(c.GetPrompt(props))
|
||||||
if len(prompt) == 0 {
|
if len(prompt) == 0 {
|
||||||
return fmt.Errorf("format error: please provide available prompt")
|
return fmt.Errorf("format error: please provide available prompt")
|
||||||
|
@ -93,7 +93,7 @@ func (c *ChatInstance) CreateStreamTask(action string, prompt string, hook func(
|
|||||||
progress := -1
|
progress := -1
|
||||||
|
|
||||||
for {
|
for {
|
||||||
utils.Sleep(100)
|
utils.Sleep(50)
|
||||||
form := getStorage(task)
|
form := getStorage(task)
|
||||||
if form == nil {
|
if form == nil {
|
||||||
continue
|
continue
|
||||||
|
@ -11,6 +11,7 @@ import (
|
|||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"runtime/debug"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -57,8 +58,9 @@ func MockStreamSender(conn *Connection, message string) {
|
|||||||
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string {
|
func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conversation) string {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
|
stack := debug.Stack()
|
||||||
err, instance.GetModel(), conn.GetCtx().ClientIP(),
|
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s",
|
||||||
|
err, instance.GetModel(), conn.GetCtx().ClientIP(), stack,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@ -116,7 +118,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
|
|||||||
|
|
||||||
admin.AnalysisRequest(model, buffer, err)
|
admin.AnalysisRequest(model, buffer, err)
|
||||||
if err != nil && err.Error() != "signal" {
|
if err != nil && err.Error() != "signal" {
|
||||||
globals.Warn(fmt.Sprintf("caught error from chat handler: %s (instance: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
globals.Warn(fmt.Sprintf("%s (model: %s, client: %s)", err, model, conn.GetCtx().ClientIP()))
|
||||||
|
|
||||||
auth.RevertSubscriptionUsage(db, cache, user, model)
|
auth.RevertSubscriptionUsage(db, cache, user, model)
|
||||||
conn.Send(globals.ChatSegmentResponse{
|
conn.Send(globals.ChatSegmentResponse{
|
||||||
|
@ -10,13 +10,15 @@ import (
|
|||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"runtime/debug"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []globals.Message, enableWeb bool) (string, float32) {
|
func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []globals.Message, enableWeb bool) (string, float32) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)",
|
stack := debug.Stack()
|
||||||
err, model, c.ClientIP(),
|
globals.Warn(fmt.Sprintf("caught panic from chat handler: %s (instance: %s, client: %s)\n%s",
|
||||||
|
err, model, c.ClientIP(), stack,
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -92,6 +92,20 @@ func MapToStruct[T any](data interface{}) *T {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func MapToRawStruct[T any](data interface{}) (*T, error) {
|
||||||
|
val, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
form, err := Unmarshal[T](val)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &form, nil
|
||||||
|
}
|
||||||
|
|
||||||
func ParseInt(value string) int {
|
func ParseInt(value string) int {
|
||||||
if res, err := strconv.Atoi(value); err == nil {
|
if res, err := strconv.Atoi(value); err == nil {
|
||||||
return res
|
return res
|
||||||
@ -332,3 +346,7 @@ func ToSecret(raw string) string {
|
|||||||
func ToMarkdownCode(lang string, code string) string {
|
func ToMarkdownCode(lang string, code string) string {
|
||||||
return fmt.Sprintf("```%s\n%s\n```", lang, code)
|
return fmt.Sprintf("```%s\n%s\n```", lang, code)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ToMarkdownError(err error, body string) error {
|
||||||
|
return fmt.Errorf("%s\n%s", err.Error(), ToMarkdownCode("html", body))
|
||||||
|
}
|
||||||
|
12
utils/net.go
12
utils/net.go
@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -89,6 +90,14 @@ func Post(uri string, headers map[string]string, body interface{}) (data interfa
|
|||||||
return data, err
|
return data, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func PostRaw(uri string, headers map[string]string, body interface{}) (data string, err error) {
|
||||||
|
buffer, err := HttpRaw(uri, http.MethodPost, headers, ConvertBody(body))
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(buffer), nil
|
||||||
|
}
|
||||||
|
|
||||||
func ConvertBody(body interface{}) (form io.Reader) {
|
func ConvertBody(body interface{}) (form io.Reader) {
|
||||||
if buffer, err := json.Marshal(body); err == nil {
|
if buffer, err := json.Marshal(body); err == nil {
|
||||||
form = bytes.NewBuffer(buffer)
|
form = bytes.NewBuffer(buffer)
|
||||||
@ -100,7 +109,8 @@ func EventSource(method string, uri string, headers map[string]string, body inte
|
|||||||
// panic recovery
|
// panic recovery
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := recover(); err != nil {
|
if err := recover(); err != nil {
|
||||||
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)", err, uri, method))
|
stack := debug.Stack()
|
||||||
|
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)\n%s", err, uri, method, stack))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
@ -7,6 +7,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"runtime/debug"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -39,7 +40,8 @@ func EventScanner(props *EventScannerProps) *EventScannerError {
|
|||||||
// panic recovery
|
// panic recovery
|
||||||
defer func() {
|
defer func() {
|
||||||
if r := recover(); r != nil {
|
if r := recover(); r != nil {
|
||||||
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)", r, props.Uri, props.Method))
|
stack := debug.Stack()
|
||||||
|
globals.Warn(fmt.Sprintf("event source panic: %s (uri: %s, method: %s)\n%s", r, props.Uri, props.Method, stack))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user