mirror of
https://github.com/coaidev/coai.git
synced 2025-05-20 21:40:15 +09:00
add midjourney models
This commit is contained in:
parent
d084e544e6
commit
042f67fd74
@ -4,6 +4,7 @@ import (
|
|||||||
"chat/adapter/bing"
|
"chat/adapter/bing"
|
||||||
"chat/adapter/claude"
|
"chat/adapter/claude"
|
||||||
"chat/adapter/dashscope"
|
"chat/adapter/dashscope"
|
||||||
|
"chat/adapter/midjourney"
|
||||||
"chat/adapter/oneapi"
|
"chat/adapter/oneapi"
|
||||||
"chat/adapter/palm2"
|
"chat/adapter/palm2"
|
||||||
"chat/adapter/slack"
|
"chat/adapter/slack"
|
||||||
@ -61,6 +62,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
|
|||||||
Model: props.Model,
|
Model: props.Model,
|
||||||
Message: props.Message,
|
Message: props.Message,
|
||||||
}, hook)
|
}, hook)
|
||||||
|
} else if globals.IsMidjourneyModel(props.Model) {
|
||||||
|
return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{
|
||||||
|
Model: props.Model,
|
||||||
|
Messages: props.Message,
|
||||||
|
}, hook)
|
||||||
}
|
}
|
||||||
|
|
||||||
return hook("Sorry, we cannot find the model you are looking for. Please try another model.")
|
return hook("Sorry, we cannot find the model you are looking for. Please try another model.")
|
||||||
|
@ -70,7 +70,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
|
|||||||
|
|
||||||
// CreateChatRequest is the native http request body for chatgpt
|
// CreateChatRequest is the native http request body for chatgpt
|
||||||
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
||||||
if props.Model == globals.Dalle2 {
|
if globals.IsDalleModel(props.Model) {
|
||||||
return c.CreateImage(props)
|
return c.CreateImage(props)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -95,7 +95,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
|||||||
|
|
||||||
// CreateStreamChatRequest is the stream response body for chatgpt
|
// CreateStreamChatRequest is the stream response body for chatgpt
|
||||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||||
if props.Model == globals.Dalle2 {
|
if globals.IsDalleModel(props.Model) {
|
||||||
if url, err := c.CreateImage(props); err != nil {
|
if url, err := c.CreateImage(props); err != nil {
|
||||||
return err
|
return err
|
||||||
} else {
|
} else {
|
||||||
|
@ -1,12 +1,14 @@
|
|||||||
package chatgpt
|
package chatgpt
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/globals"
|
||||||
"chat/utils"
|
"chat/utils"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ImageProps struct {
|
type ImageProps struct {
|
||||||
|
Model string
|
||||||
Prompt string
|
Prompt string
|
||||||
Size ImageSize
|
Size ImageSize
|
||||||
}
|
}
|
||||||
@ -20,8 +22,13 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
|
|||||||
res, err := utils.Post(
|
res, err := utils.Post(
|
||||||
c.GetImageEndpoint(),
|
c.GetImageEndpoint(),
|
||||||
c.GetHeader(), ImageRequest{
|
c.GetHeader(), ImageRequest{
|
||||||
|
Model: props.Model,
|
||||||
Prompt: props.Prompt,
|
Prompt: props.Prompt,
|
||||||
Size: utils.Multi[ImageSize](len(props.Size) == 0, ImageSize512, props.Size),
|
Size: utils.Multi[ImageSize](
|
||||||
|
props.Model == globals.Dalle3,
|
||||||
|
ImageSize1024,
|
||||||
|
ImageSize512,
|
||||||
|
),
|
||||||
N: 1,
|
N: 1,
|
||||||
})
|
})
|
||||||
if err != nil || res == nil {
|
if err != nil || res == nil {
|
||||||
@ -41,6 +48,7 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
|
|||||||
// CreateImage will create a dalle image from prompt, return markdown of image
|
// CreateImage will create a dalle image from prompt, return markdown of image
|
||||||
func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) {
|
func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) {
|
||||||
url, err := c.CreateImageRequest(ImageProps{
|
url, err := c.CreateImageRequest(ImageProps{
|
||||||
|
Model: props.Model,
|
||||||
Prompt: c.GetLatestPrompt(props),
|
Prompt: c.GetLatestPrompt(props),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -52,7 +52,7 @@ func NewChatInstanceFromModel(props *InstanceProps) *ChatInstance {
|
|||||||
globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
||||||
return NewChatInstanceFromConfig("gpt4")
|
return NewChatInstanceFromConfig("gpt4")
|
||||||
|
|
||||||
case globals.GPT4Vision, globals.GPT4Dalle, globals.GPT4All:
|
case globals.GPT4Vision, globals.GPT4Dalle, globals.Dalle3, globals.GPT4All:
|
||||||
return NewChatInstanceFromConfig("reverse")
|
return NewChatInstanceFromConfig("reverse")
|
||||||
|
|
||||||
case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0613, globals.GPT3Turbo0301,
|
case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0613, globals.GPT3Turbo0301,
|
||||||
|
@ -89,6 +89,7 @@ type ImageSize string
|
|||||||
|
|
||||||
// ImageRequest is the request body for chatgpt dalle image generation
|
// ImageRequest is the request body for chatgpt dalle image generation
|
||||||
type ImageRequest struct {
|
type ImageRequest struct {
|
||||||
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt"`
|
||||||
Size ImageSize `json:"size"`
|
Size ImageSize `json:"size"`
|
||||||
N int `json:"n"`
|
N int `json:"n"`
|
||||||
|
101
adapter/midjourney/api.go
Normal file
101
adapter/midjourney/api.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/utils"
|
||||||
|
"fmt"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetImagineUrl() string {
|
||||||
|
return fmt.Sprintf("%s/mj/submit/imagine", c.GetEndpoint())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) CreateImagineRequest(prompt string) (*ImagineResponse, error) {
|
||||||
|
res, err := utils.Post(
|
||||||
|
c.GetImagineUrl(),
|
||||||
|
map[string]string{
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"mj-api-secret": c.GetApiSecret(),
|
||||||
|
},
|
||||||
|
ImagineRequest{
|
||||||
|
NotifyHook: fmt.Sprintf("%s/mj/notify", viper.GetString("midjourney.expose")),
|
||||||
|
Prompt: prompt,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return utils.MapToStruct[ImagineResponse](res), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStatusCode(code int) error {
|
||||||
|
switch code {
|
||||||
|
case SuccessCode, QueueCode:
|
||||||
|
return nil
|
||||||
|
case ExistedCode:
|
||||||
|
return fmt.Errorf("task is existed, please try again later with another prompt")
|
||||||
|
case MaxQueueCode:
|
||||||
|
return fmt.Errorf("task queue is full, please try again later")
|
||||||
|
case NudeCode:
|
||||||
|
return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected")
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown error from midjourney")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getProgress(value string) int {
|
||||||
|
progress := strings.TrimSuffix(value, "%")
|
||||||
|
return utils.ParseInt(progress)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) CreateStreamImagineTask(prompt string, hook func(progress int) error) (string, error) {
|
||||||
|
res, err := c.CreateImagineRequest(prompt)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := getStatusCode(res.Code); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
task := res.Result
|
||||||
|
progress := -1
|
||||||
|
|
||||||
|
for {
|
||||||
|
utils.Sleep(100)
|
||||||
|
form := getStorage(task)
|
||||||
|
if form == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch form.Status {
|
||||||
|
case Success:
|
||||||
|
if err := hook(100); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return form.Url, nil
|
||||||
|
case Failure:
|
||||||
|
if err := hook(100); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("task failed: %s", form.FailReason)
|
||||||
|
case InProgress:
|
||||||
|
current := getProgress(form.Progress)
|
||||||
|
if progress != current {
|
||||||
|
if err := hook(current); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
progress = current
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) CreateImagineTask(prompt string) (string, error) {
|
||||||
|
return c.CreateStreamImagineTask(prompt, func(progress int) error {
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
79
adapter/midjourney/chat.go
Normal file
79
adapter/midjourney/chat.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/globals"
|
||||||
|
"chat/utils"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatProps struct {
|
||||||
|
Messages []globals.Message
|
||||||
|
Model string
|
||||||
|
}
|
||||||
|
|
||||||
|
func getMode(model string) string {
|
||||||
|
switch model {
|
||||||
|
case globals.Midjourney: // relax
|
||||||
|
return RelaxMode
|
||||||
|
case globals.MidjourneyFast: // fast
|
||||||
|
return FastMode
|
||||||
|
case globals.MidjourneyTurbo: // turbo
|
||||||
|
return TurboMode
|
||||||
|
default:
|
||||||
|
return RelaxMode
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
|
||||||
|
arr := strings.Split(strings.TrimSpace(prompt), " ")
|
||||||
|
var res []string
|
||||||
|
|
||||||
|
for _, word := range arr {
|
||||||
|
if utils.Contains[string](word, ModeArr) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
res = append(res, word)
|
||||||
|
}
|
||||||
|
|
||||||
|
res = append(res, getMode(model))
|
||||||
|
target := strings.Join(res, " ")
|
||||||
|
return target
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetPrompt(props *ChatProps) string {
|
||||||
|
return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||||
|
// partial response like:
|
||||||
|
// ```progress
|
||||||
|
// 0
|
||||||
|
// ...
|
||||||
|
// 100
|
||||||
|
// ```
|
||||||
|
// 
|
||||||
|
|
||||||
|
prompt := c.GetPrompt(props)
|
||||||
|
if prompt == "" {
|
||||||
|
return fmt.Errorf("format error: please provide available prompt")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := callback("```progress\n"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
url, err := c.CreateStreamImagineTask(prompt, func(progress int) error {
|
||||||
|
return callback(fmt.Sprintf("%d\n", progress))
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := callback("```\n"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error from midjourney: %s", err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return callback(utils.GetImageMarkdown(url))
|
||||||
|
}
|
51
adapter/midjourney/expose.go
Normal file
51
adapter/midjourney/expose.go
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/utils"
|
||||||
|
"fmt"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func InWhiteList(ip string) bool {
|
||||||
|
arr := strings.Split(viper.GetString("midjourney.white_list"), ",")
|
||||||
|
return utils.Contains[string](ip, arr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NotifyAPI(c *gin.Context) {
|
||||||
|
if !InWhiteList(c.ClientIP()) {
|
||||||
|
fmt.Println(fmt.Sprintf("[midjourney] notify api: banned request from %s", c.ClientIP()))
|
||||||
|
c.AbortWithStatus(http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var form NotifyForm
|
||||||
|
if err := c.ShouldBindJSON(&form); err != nil {
|
||||||
|
c.AbortWithStatus(http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// fmt.Println(fmt.Sprintf("[midjourney] notify api: get notify: %s (from: %s)", utils.Marshal(form), c.ClientIP()))
|
||||||
|
|
||||||
|
if !utils.Contains(form.Status, []string{InProgress, Success, Failure}) {
|
||||||
|
// ignore
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
reason, ok := form.FailReason.(string)
|
||||||
|
if !ok {
|
||||||
|
reason = "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
err := setStorage(form.Id, StorageForm{
|
||||||
|
Url: form.ImageUrl,
|
||||||
|
FailReason: reason,
|
||||||
|
Progress: form.Progress,
|
||||||
|
Status: form.Status,
|
||||||
|
})
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"status": err == nil,
|
||||||
|
})
|
||||||
|
}
|
19
adapter/midjourney/storage.go
Normal file
19
adapter/midjourney/storage.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/connection"
|
||||||
|
"chat/utils"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getTaskName(task string) string {
|
||||||
|
return fmt.Sprintf("nio:mj-task:%s", task)
|
||||||
|
}
|
||||||
|
|
||||||
|
func setStorage(task string, form StorageForm) error {
|
||||||
|
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStorage(task string) *StorageForm {
|
||||||
|
return utils.GetJson[StorageForm](connection.Cache, getTaskName(task))
|
||||||
|
}
|
32
adapter/midjourney/struct.go
Normal file
32
adapter/midjourney/struct.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package midjourney
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ChatInstance struct {
|
||||||
|
Endpoint string
|
||||||
|
ApiSecret string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetApiSecret() string {
|
||||||
|
return c.ApiSecret
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChatInstance) GetEndpoint() string {
|
||||||
|
return c.Endpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatInstance(endpoint string, apiSecret string) *ChatInstance {
|
||||||
|
return &ChatInstance{
|
||||||
|
Endpoint: endpoint,
|
||||||
|
ApiSecret: apiSecret,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChatInstanceFromConfig() *ChatInstance {
|
||||||
|
return NewChatInstance(
|
||||||
|
viper.GetString("midjourney.endpoint"),
|
||||||
|
viper.GetString("midjourney.api_secret"),
|
||||||
|
)
|
||||||
|
}
|
58
adapter/midjourney/types.go
Normal file
58
adapter/midjourney/types.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package midjourney
|
||||||
|
|
||||||
|
const (
|
||||||
|
SuccessCode = 1
|
||||||
|
ExistedCode = 21
|
||||||
|
QueueCode = 22
|
||||||
|
MaxQueueCode = 23
|
||||||
|
NudeCode = 24
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NotStartStatus = "NOT_START"
|
||||||
|
Submitted = "SUBMITTED"
|
||||||
|
InProgress = "IN_PROGRESS"
|
||||||
|
Failure = "FAILURE"
|
||||||
|
Success = "SUCCESS"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TurboMode = "--turbo"
|
||||||
|
FastMode = "--fast"
|
||||||
|
RelaxMode = "--relax"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ModeArr = []string{TurboMode, FastMode, RelaxMode}
|
||||||
|
|
||||||
|
type ImagineRequest struct {
|
||||||
|
NotifyHook string `json:"notifyHook"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ImagineResponse struct {
|
||||||
|
Code int `json:"code"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Result string `json:"result"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NotifyForm struct {
|
||||||
|
Id string `json:"id"`
|
||||||
|
Action string `json:"action"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
Prompt string `json:"prompt"`
|
||||||
|
PromptEn string `json:"promptEn"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
SubmitTime int64 `json:"submitTime"`
|
||||||
|
StartTime int64 `json:"startTime"`
|
||||||
|
FinishTime int64 `json:"finishTime"`
|
||||||
|
Progress string `json:"progress"`
|
||||||
|
ImageUrl string `json:"imageUrl"`
|
||||||
|
FailReason interface{} `json:"failReason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type StorageForm struct {
|
||||||
|
Url string `json:"url"`
|
||||||
|
FailReason string `json:"failReason"`
|
||||||
|
Progress string `json:"progress"`
|
||||||
|
Status string `json:"status"`
|
||||||
|
}
|
10
adapter/router.go
Normal file
10
adapter/router.go
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"chat/adapter/midjourney"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Register(app *gin.Engine) {
|
||||||
|
app.POST("/mj/notify", midjourney.NotifyAPI)
|
||||||
|
}
|
@ -31,12 +31,7 @@ function ChatInterface({ setTarget }: ChatInterfaceProps) {
|
|||||||
if (!ref.current) return;
|
if (!ref.current) return;
|
||||||
const el = ref.current as HTMLDivElement;
|
const el = ref.current as HTMLDivElement;
|
||||||
|
|
||||||
const event = () => {
|
const event = () => setScrollable(el.scrollTop + el.clientHeight + 20 >= el.scrollHeight);
|
||||||
setScrollable(
|
|
||||||
el.scrollTop + el.clientHeight + 20 >= el.scrollHeight, // at bottom
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
return addEventListeners(el, [
|
return addEventListeners(el, [
|
||||||
"scroll", "scrollend",
|
"scroll", "scrollend",
|
||||||
"resize", "touchend",
|
"resize", "touchend",
|
||||||
|
@ -8,7 +8,7 @@ import {
|
|||||||
} from "@/utils/env.ts";
|
} from "@/utils/env.ts";
|
||||||
import { getMemory } from "@/utils/memory.ts";
|
import { getMemory } from "@/utils/memory.ts";
|
||||||
|
|
||||||
export const version = "3.6.19";
|
export const version = "3.6.20";
|
||||||
export const dev: boolean = getDev();
|
export const dev: boolean = getDev();
|
||||||
export const deploy: boolean = true;
|
export const deploy: boolean = true;
|
||||||
export let rest_api: string = getRestApi(deploy);
|
export let rest_api: string = getRestApi(deploy);
|
||||||
@ -61,7 +61,12 @@ export const supportModels: Model[] = [
|
|||||||
{ id: "chat-bison-001", name: "Palm2", free: true, auth: true },
|
{ id: "chat-bison-001", name: "Palm2", free: true, auth: true },
|
||||||
|
|
||||||
// dalle models
|
// dalle models
|
||||||
{ id: "dalle", name: "DALLE2", free: true, auth: true },
|
{ id: "dall-e-3", name: "DALLE 3", free: false, auth: true },
|
||||||
|
{ id: "dall-e-2", name: "DALLE 2", free: true, auth: true },
|
||||||
|
|
||||||
|
{ id: "midjourney", name: "Midjourney", free: false, auth: true },
|
||||||
|
{ id: "midjourney-fast", name: "Midjourney Fast", free: false, auth: true },
|
||||||
|
{ id: "midjourney-turbo", name: "Midjourney Turbo", free: false, auth: true },
|
||||||
|
|
||||||
// reverse models
|
// reverse models
|
||||||
{ id: "gpt-4-v", name: "GPT-4 Vision", free: false, auth: true },
|
{ id: "gpt-4-v", name: "GPT-4 Vision", free: false, auth: true },
|
||||||
@ -96,7 +101,11 @@ export const planModels = [
|
|||||||
"claude-2-100k",
|
"claude-2-100k",
|
||||||
];
|
];
|
||||||
|
|
||||||
export const expensiveModels = ["gpt-4-32k-0613"];
|
export const expensiveModels = [
|
||||||
|
"dall-e-3",
|
||||||
|
"midjourney-turbo",
|
||||||
|
"gpt-4-32k-0613",
|
||||||
|
];
|
||||||
|
|
||||||
export function login() {
|
export function login() {
|
||||||
location.href = `https://deeptrain.net/login?app=${dev ? "dev" : "chatnio"}`;
|
location.href = `https://deeptrain.net/login?app=${dev ? "dev" : "chatnio"}`;
|
||||||
|
@ -22,7 +22,6 @@ export const apiSlice = createSlice({
|
|||||||
state.dialog = false;
|
state.dialog = false;
|
||||||
},
|
},
|
||||||
setKey: (state, action) => {
|
setKey: (state, action) => {
|
||||||
if (!action.payload.length) return;
|
|
||||||
state.key = action.payload as string;
|
state.key = action.payload as string;
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@ -35,7 +34,13 @@ export default apiSlice.reducer;
|
|||||||
export const dialogSelector = (state: RootState): boolean => state.api.dialog;
|
export const dialogSelector = (state: RootState): boolean => state.api.dialog;
|
||||||
export const keySelector = (state: RootState): string => state.api.key;
|
export const keySelector = (state: RootState): string => state.api.key;
|
||||||
|
|
||||||
export const getApiKey = async (dispatch: AppDispatch) => {
|
export const getApiKey = async (dispatch: AppDispatch, retries?: boolean) => {
|
||||||
const response = await getKey();
|
const response = await getKey();
|
||||||
if (response.status) dispatch(setKey(response.key));
|
if (response.status) {
|
||||||
|
if (response.key.length === 0 && retries !== false) {
|
||||||
|
await getApiKey(dispatch, false);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
dispatch(setKey(response.key));
|
||||||
|
};
|
||||||
};
|
};
|
||||||
|
@ -11,7 +11,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
|||||||
switch model {
|
switch model {
|
||||||
case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0301, globals.GPT3Turbo0613:
|
case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0301, globals.GPT3Turbo0613:
|
||||||
return true
|
return true
|
||||||
case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview:
|
case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview, globals.Dalle3:
|
||||||
return user != nil && user.GetQuota(db) >= 5
|
return user != nil && user.GetQuota(db) >= 5
|
||||||
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
||||||
return user != nil && user.GetQuota(db) >= 50
|
return user != nil && user.GetQuota(db) >= 50
|
||||||
@ -23,7 +23,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
|||||||
return user != nil && user.GetQuota(db) >= 1
|
return user != nil && user.GetQuota(db) >= 1
|
||||||
case globals.QwenTurbo, globals.QwenPlus, globals.QwenPlusNet, globals.QwenTurboNet:
|
case globals.QwenTurbo, globals.QwenPlus, globals.QwenPlusNet, globals.QwenTurboNet:
|
||||||
return user != nil && user.GetQuota(db) >= 1
|
return user != nil && user.GetQuota(db) >= 1
|
||||||
case globals.Midjourney, globals.StableDiffusion:
|
case globals.StableDiffusion, globals.Midjourney, globals.MidjourneyFast, globals.MidjourneyTurbo:
|
||||||
return user != nil && user.GetQuota(db) >= 1
|
return user != nil && user.GetQuota(db) >= 1
|
||||||
case globals.LLaMa27B, globals.LLaMa213B, globals.LLaMa270B,
|
case globals.LLaMa27B, globals.LLaMa213B, globals.LLaMa270B,
|
||||||
globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B:
|
globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B:
|
||||||
|
@ -56,7 +56,8 @@ const (
|
|||||||
GPT432k = "gpt-4-32k"
|
GPT432k = "gpt-4-32k"
|
||||||
GPT432k0314 = "gpt-4-32k-0314"
|
GPT432k0314 = "gpt-4-32k-0314"
|
||||||
GPT432k0613 = "gpt-4-32k-0613"
|
GPT432k0613 = "gpt-4-32k-0613"
|
||||||
Dalle2 = "dalle"
|
Dalle2 = "dall-e-2"
|
||||||
|
Dalle3 = "dall-e-3"
|
||||||
Claude1 = "claude-1"
|
Claude1 = "claude-1"
|
||||||
Claude1100k = "claude-1.3"
|
Claude1100k = "claude-1.3"
|
||||||
Claude2 = "claude-1-100k"
|
Claude2 = "claude-1-100k"
|
||||||
@ -78,6 +79,8 @@ const (
|
|||||||
QwenTurboNet = "qwen-turbo-net"
|
QwenTurboNet = "qwen-turbo-net"
|
||||||
QwenPlusNet = "qwen-plus-net"
|
QwenPlusNet = "qwen-plus-net"
|
||||||
Midjourney = "midjourney"
|
Midjourney = "midjourney"
|
||||||
|
MidjourneyFast = "midjourney-fast"
|
||||||
|
MidjourneyTurbo = "midjourney-turbo"
|
||||||
StableDiffusion = "stable-diffusion"
|
StableDiffusion = "stable-diffusion"
|
||||||
LLaMa270B = "llama-2-70b"
|
LLaMa270B = "llama-2-70b"
|
||||||
LLaMa213B = "llama-2-13b"
|
LLaMa213B = "llama-2-13b"
|
||||||
@ -148,6 +151,12 @@ var QwenModelArray = []string{
|
|||||||
QwenPlusNet,
|
QwenPlusNet,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var MidjourneyModelArray = []string{
|
||||||
|
Midjourney,
|
||||||
|
MidjourneyFast,
|
||||||
|
MidjourneyTurbo,
|
||||||
|
}
|
||||||
|
|
||||||
var LongContextModelArray = []string{
|
var LongContextModelArray = []string{
|
||||||
GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
|
GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
|
||||||
GPT41106Preview, GPT432k, GPT432k0314, GPT432k0613,
|
GPT41106Preview, GPT432k, GPT432k0314, GPT432k0613,
|
||||||
@ -179,14 +188,14 @@ var AllModels = []string{
|
|||||||
GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
|
GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
|
||||||
GPT4, GPT40314, GPT40613, GPT4Vision, GPT4All, GPT41106Preview, GPT4Dalle,
|
GPT4, GPT40314, GPT40613, GPT4Vision, GPT4All, GPT41106Preview, GPT4Dalle,
|
||||||
GPT432k, GPT432k0314, GPT432k0613,
|
GPT432k, GPT432k0314, GPT432k0613,
|
||||||
Dalle2,
|
Dalle2, Dalle3,
|
||||||
Claude1, Claude1100k, Claude2, Claude2100k, ClaudeSlack,
|
Claude1, Claude1100k, Claude2, Claude2100k, ClaudeSlack,
|
||||||
SparkDesk, SparkDeskV2, SparkDeskV3,
|
SparkDesk, SparkDeskV2, SparkDeskV3,
|
||||||
ChatBison001,
|
ChatBison001,
|
||||||
BingCreative, BingBalanced, BingPrecise,
|
BingCreative, BingBalanced, BingPrecise,
|
||||||
ZhiPuChatGLMTurbo, ZhiPuChatGLMPro, ZhiPuChatGLMStd, ZhiPuChatGLMLite,
|
ZhiPuChatGLMTurbo, ZhiPuChatGLMPro, ZhiPuChatGLMStd, ZhiPuChatGLMLite,
|
||||||
QwenTurbo, QwenPlus, QwenTurboNet, QwenPlusNet,
|
QwenTurbo, QwenPlus, QwenTurboNet, QwenPlusNet,
|
||||||
Midjourney, StableDiffusion,
|
StableDiffusion, Midjourney, MidjourneyFast, MidjourneyTurbo,
|
||||||
LLaMa270B, LLaMa213B, LLaMa27B,
|
LLaMa270B, LLaMa213B, LLaMa27B,
|
||||||
CodeLLaMa34B, CodeLLaMa13B, CodeLLaMa7B,
|
CodeLLaMa34B, CodeLLaMa13B, CodeLLaMa7B,
|
||||||
}
|
}
|
||||||
@ -213,7 +222,7 @@ func IsGPT3TurboModel(model string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func IsChatGPTModel(model string) bool {
|
func IsChatGPTModel(model string) bool {
|
||||||
return IsGPT3TurboModel(model) || IsGPT4Model(model)
|
return IsGPT3TurboModel(model) || IsGPT4Model(model) || IsDalleModel(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsClaudeModel(model string) bool {
|
func IsClaudeModel(model string) bool {
|
||||||
@ -224,6 +233,10 @@ func IsLLaMaModel(model string) bool {
|
|||||||
return in(model, LLaMaModelArray)
|
return in(model, LLaMaModelArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsDalleModel(model string) bool {
|
||||||
|
return model == Dalle2 || model == Dalle3
|
||||||
|
}
|
||||||
|
|
||||||
func IsClaude100KModel(model string) bool {
|
func IsClaude100KModel(model string) bool {
|
||||||
return model == Claude1100k || model == Claude2100k
|
return model == Claude1100k || model == Claude2100k
|
||||||
}
|
}
|
||||||
@ -252,6 +265,10 @@ func IsQwenModel(model string) bool {
|
|||||||
return in(model, QwenModelArray)
|
return in(model, QwenModelArray)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func IsMidjourneyModel(model string) bool {
|
||||||
|
return in(model, MidjourneyModelArray)
|
||||||
|
}
|
||||||
|
|
||||||
func IsLongContextModel(model string) bool {
|
func IsLongContextModel(model string) bool {
|
||||||
return in(model, LongContextModelArray)
|
return in(model, LongContextModelArray)
|
||||||
}
|
}
|
||||||
|
2
main.go
2
main.go
@ -1,6 +1,7 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"chat/adapter"
|
||||||
"chat/addition"
|
"chat/addition"
|
||||||
"chat/admin"
|
"chat/admin"
|
||||||
"chat/auth"
|
"chat/auth"
|
||||||
@ -29,6 +30,7 @@ func main() {
|
|||||||
{
|
{
|
||||||
auth.Register(app)
|
auth.Register(app)
|
||||||
admin.Register(app)
|
admin.Register(app)
|
||||||
|
adapter.Register(app)
|
||||||
manager.Register(app)
|
manager.Register(app)
|
||||||
addition.Register(app)
|
addition.Register(app)
|
||||||
conversation.Register(app)
|
conversation.Register(app)
|
||||||
|
@ -3,6 +3,7 @@ package utils
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/goccy/go-json"
|
"github.com/goccy/go-json"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Sum[T int | int64 | float32 | float64](arr []T) T {
|
func Sum[T int | int64 | float32 | float64](arr []T) T {
|
||||||
@ -149,3 +150,7 @@ func EachNotNil[T any, U any](arr []T, f func(T) *U) []U {
|
|||||||
}
|
}
|
||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Sleep(ms int) {
|
||||||
|
time.Sleep(time.Duration(ms) * time.Millisecond)
|
||||||
|
}
|
||||||
|
@ -31,6 +31,19 @@ func SetInt(cache *redis.Client, key string, value int64, expiration int64) erro
|
|||||||
return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err()
|
return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func SetJson(cache *redis.Client, key string, value interface{}, expiration int64) error {
|
||||||
|
err := cache.Set(context.Background(), key, Marshal(value), time.Duration(expiration)*time.Second).Err()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetJson[T any](cache *redis.Client, key string) *T {
|
||||||
|
val, err := cache.Get(context.Background(), key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return UnmarshalForm[T](val)
|
||||||
|
}
|
||||||
|
|
||||||
func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool {
|
func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool {
|
||||||
// not exist
|
// not exist
|
||||||
if _, err := cache.Get(context.Background(), key).Result(); err != nil {
|
if _, err := cache.Get(context.Background(), key).Result(); err != nil {
|
||||||
|
@ -69,7 +69,7 @@ func MapToStruct[T any](data interface{}) *T {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ToInt(value string) int {
|
func ParseInt(value string) int {
|
||||||
if res, err := strconv.Atoi(value); err == nil {
|
if res, err := strconv.Atoi(value); err == nil {
|
||||||
return res
|
return res
|
||||||
} else {
|
} else {
|
||||||
|
@ -145,6 +145,12 @@ func CountOutputToken(model string, t int) float32 {
|
|||||||
return 0.25
|
return 0.25
|
||||||
case globals.Midjourney:
|
case globals.Midjourney:
|
||||||
return 0.5
|
return 0.5
|
||||||
|
case globals.MidjourneyFast:
|
||||||
|
return 2
|
||||||
|
case globals.MidjourneyTurbo:
|
||||||
|
return 5
|
||||||
|
case globals.Dalle3:
|
||||||
|
return 5.6
|
||||||
default:
|
default:
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user