add midjourney models

Zhang Minghan 2023-11-13 11:36:58 +08:00
parent d084e544e6
commit 042f67fd74
22 changed files with 441 additions and 24 deletions

View File

@ -4,6 +4,7 @@ import (
"chat/adapter/bing"
"chat/adapter/claude"
"chat/adapter/dashscope"
"chat/adapter/midjourney"
"chat/adapter/oneapi"
"chat/adapter/palm2"
"chat/adapter/slack"
@ -61,6 +62,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
Model: props.Model,
Message: props.Message,
}, hook)
} else if globals.IsMidjourneyModel(props.Model) {
return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{
Model: props.Model,
Messages: props.Message,
}, hook)
}
return hook("Sorry, we cannot find the model you are looking for. Please try another model.")

View File

@ -70,7 +70,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
// CreateChatRequest is the native http request body for chatgpt
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
if props.Model == globals.Dalle2 {
if globals.IsDalleModel(props.Model) {
return c.CreateImage(props)
}
@ -95,7 +95,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
// CreateStreamChatRequest is the stream response body for chatgpt
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
if props.Model == globals.Dalle2 {
if globals.IsDalleModel(props.Model) {
if url, err := c.CreateImage(props); err != nil {
return err
} else {

View File

@ -1,12 +1,14 @@
package chatgpt
import (
"chat/globals"
"chat/utils"
"fmt"
"strings"
)
type ImageProps struct {
Model string
Prompt string
Size ImageSize
}
@ -20,9 +22,14 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
res, err := utils.Post(
c.GetImageEndpoint(),
c.GetHeader(), ImageRequest{
Model: props.Model,
Prompt: props.Prompt,
Size: utils.Multi[ImageSize](len(props.Size) == 0, ImageSize512, props.Size),
N: 1,
Size: utils.Multi[ImageSize](
props.Model == globals.Dalle3,
ImageSize1024,
ImageSize512,
),
N: 1,
})
if err != nil || res == nil {
return "", fmt.Errorf("chatgpt error: %s", err.Error())
@ -41,6 +48,7 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
// CreateImage will create a dalle image from prompt, return markdown of image
func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) {
url, err := c.CreateImageRequest(ImageProps{
Model: props.Model,
Prompt: c.GetLatestPrompt(props),
})
if err != nil {

View File

@ -52,7 +52,7 @@ func NewChatInstanceFromModel(props *InstanceProps) *ChatInstance {
globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
return NewChatInstanceFromConfig("gpt4")
case globals.GPT4Vision, globals.GPT4Dalle, globals.GPT4All:
case globals.GPT4Vision, globals.GPT4Dalle, globals.Dalle3, globals.GPT4All:
return NewChatInstanceFromConfig("reverse")
case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0613, globals.GPT3Turbo0301,

View File

@ -89,6 +89,7 @@ type ImageSize string
// ImageRequest is the request body for chatgpt dalle image generation
type ImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size ImageSize `json:"size"`
N int `json:"n"`

adapter/midjourney/api.go Normal file
View File

@ -0,0 +1,101 @@
package midjourney
import (
"chat/utils"
"fmt"
"github.com/spf13/viper"
"strings"
)
func (c *ChatInstance) GetImagineUrl() string {
return fmt.Sprintf("%s/mj/submit/imagine", c.GetEndpoint())
}
func (c *ChatInstance) CreateImagineRequest(prompt string) (*ImagineResponse, error) {
res, err := utils.Post(
c.GetImagineUrl(),
map[string]string{
"Content-Type": "application/json",
"mj-api-secret": c.GetApiSecret(),
},
ImagineRequest{
NotifyHook: fmt.Sprintf("%s/mj/notify", viper.GetString("midjourney.expose")),
Prompt: prompt,
},
)
if err != nil {
return nil, err
}
return utils.MapToStruct[ImagineResponse](res), nil
}
func getStatusCode(code int) error {
switch code {
case SuccessCode, QueueCode:
return nil
case ExistedCode:
return fmt.Errorf("task is existed, please try again later with another prompt")
case MaxQueueCode:
return fmt.Errorf("task queue is full, please try again later")
case NudeCode:
return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected")
default:
return fmt.Errorf("unknown error from midjourney")
}
}
func getProgress(value string) int {
progress := strings.TrimSuffix(value, "%")
return utils.ParseInt(progress)
}
func (c *ChatInstance) CreateStreamImagineTask(prompt string, hook func(progress int) error) (string, error) {
res, err := c.CreateImagineRequest(prompt)
if err != nil {
return "", err
}
if err := getStatusCode(res.Code); err != nil {
return "", err
}
task := res.Result
progress := -1
for {
utils.Sleep(100)
form := getStorage(task)
if form == nil {
continue
}
switch form.Status {
case Success:
if err := hook(100); err != nil {
return "", err
}
return form.Url, nil
case Failure:
if err := hook(100); err != nil {
return "", err
}
return "", fmt.Errorf("task failed: %s", form.FailReason)
case InProgress:
current := getProgress(form.Progress)
if progress != current {
if err := hook(current); err != nil {
return "", err
}
progress = current
}
}
}
}
func (c *ChatInstance) CreateImagineTask(prompt string) (string, error) {
return c.CreateStreamImagineTask(prompt, func(progress int) error {
return nil
})
}
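A rough usage sketch for the task API above, assuming the `midjourney.endpoint` and `midjourney.api_secret` config keys are set and that the midjourney-proxy deployment can reach the exposed `/mj/notify` callback — the polling loop only advances once the notify handler has written progress for the task into the cache:

```go
package main

import (
	"fmt"
	"log"

	"chat/adapter/midjourney"
)

func main() {
	instance := midjourney.NewChatInstanceFromConfig()

	// Progress values arrive through the /mj/notify callback and are read back
	// from the task storage inside CreateStreamImagineTask.
	url, err := instance.CreateStreamImagineTask("a watercolor fox", func(progress int) error {
		fmt.Printf("progress: %d%%\n", progress)
		return nil
	})
	if err != nil {
		log.Fatalln("imagine task failed:", err)
	}
	fmt.Println("image url:", url)
}
```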

View File

@ -0,0 +1,79 @@
package midjourney
import (
"chat/globals"
"chat/utils"
"fmt"
"strings"
)
type ChatProps struct {
Messages []globals.Message
Model string
}
func getMode(model string) string {
switch model {
case globals.Midjourney: // relax
return RelaxMode
case globals.MidjourneyFast: // fast
return FastMode
case globals.MidjourneyTurbo: // turbo
return TurboMode
default:
return RelaxMode
}
}
func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
arr := strings.Split(strings.TrimSpace(prompt), " ")
var res []string
for _, word := range arr {
if utils.Contains[string](word, ModeArr) {
continue
}
res = append(res, word)
}
res = append(res, getMode(model))
target := strings.Join(res, " ")
return target
}
func (c *ChatInstance) GetPrompt(props *ChatProps) string {
return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content)
}
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
// partial response like:
// ```progress
// 0
// ...
// 100
// ```
// ![image](...)
prompt := c.GetPrompt(props)
if prompt == "" {
return fmt.Errorf("format error: please provide available prompt")
}
if err := callback("```progress\n"); err != nil {
return err
}
url, err := c.CreateStreamImagineTask(prompt, func(progress int) error {
return callback(fmt.Sprintf("%d\n", progress))
})
if err := callback("```\n"); err != nil {
return err
}
if err != nil {
return fmt.Errorf("error from midjourney: %s", err.Error())
}
return callback(utils.GetImageMarkdown(url))
}
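To illustrate the prompt cleanup (a sketch, not part of the commit): any `--relax`/`--fast`/`--turbo` flag already present in the user prompt is stripped, and the flag implied by the selected model is appended instead.

```go
package main

import (
	"fmt"

	"chat/adapter/midjourney"
	"chat/globals"
)

func main() {
	c := midjourney.NewChatInstanceFromConfig()

	fmt.Println(c.GetCleanPrompt(globals.MidjourneyTurbo, "a red panda --fast")) // "a red panda --turbo"
	fmt.Println(c.GetCleanPrompt(globals.Midjourney, "a red panda"))             // "a red panda --relax"
}
```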

View File

@ -0,0 +1,51 @@
package midjourney
import (
"chat/utils"
"fmt"
"github.com/gin-gonic/gin"
"github.com/spf13/viper"
"net/http"
"strings"
)
func InWhiteList(ip string) bool {
arr := strings.Split(viper.GetString("midjourney.white_list"), ",")
return utils.Contains[string](ip, arr)
}
func NotifyAPI(c *gin.Context) {
if !InWhiteList(c.ClientIP()) {
fmt.Println(fmt.Sprintf("[midjourney] notify api: banned request from %s", c.ClientIP()))
c.AbortWithStatus(http.StatusForbidden)
return
}
var form NotifyForm
if err := c.ShouldBindJSON(&form); err != nil {
c.AbortWithStatus(http.StatusBadRequest)
return
}
// fmt.Println(fmt.Sprintf("[midjourney] notify api: get notify: %s (from: %s)", utils.Marshal(form), c.ClientIP()))
if !utils.Contains(form.Status, []string{InProgress, Success, Failure}) {
// ignore
return
}
reason, ok := form.FailReason.(string)
if !ok {
reason = "unknown"
}
err := setStorage(form.Id, StorageForm{
Url: form.ImageUrl,
FailReason: reason,
Progress: form.Progress,
Status: form.Status,
})
c.JSON(http.StatusOK, gin.H{
"status": err == nil,
})
}
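For local testing, the callback can be simulated by posting a `NotifyForm`-shaped payload to the exposed endpoint; a sketch only — the URL and task id are placeholders, the sender's IP must appear in `midjourney.white_list`, and it assumes `utils.Post` accepts `(url, headers, body)` as in the calls above:

```go
package main

import (
	"fmt"
	"log"

	"chat/utils"
)

func main() {
	res, err := utils.Post(
		"http://localhost:8094/mj/notify", // hypothetical local deployment
		map[string]string{"Content-Type": "application/json"},
		map[string]interface{}{
			"id":       "1700000000000000", // placeholder task id
			"status":   "IN_PROGRESS",
			"progress": "45%",
			"imageUrl": "",
		},
	)
	if err != nil {
		log.Fatalln("notify request failed:", err)
	}
	fmt.Println(res)
}
```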

View File

@ -0,0 +1,19 @@
package midjourney
import (
"chat/connection"
"chat/utils"
"fmt"
)
func getTaskName(task string) string {
return fmt.Sprintf("nio:mj-task:%s", task)
}
func setStorage(task string, form StorageForm) error {
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
}
func getStorage(task string) *StorageForm {
return utils.GetJson[StorageForm](connection.Cache, getTaskName(task))
}

View File

@ -0,0 +1,32 @@
package midjourney
import (
"github.com/spf13/viper"
)
type ChatInstance struct {
Endpoint string
ApiSecret string
}
func (c *ChatInstance) GetApiSecret() string {
return c.ApiSecret
}
func (c *ChatInstance) GetEndpoint() string {
return c.Endpoint
}
func NewChatInstance(endpoint string, apiSecret string) *ChatInstance {
return &ChatInstance{
Endpoint: endpoint,
ApiSecret: apiSecret,
}
}
func NewChatInstanceFromConfig() *ChatInstance {
return NewChatInstance(
viper.GetString("midjourney.endpoint"),
viper.GetString("midjourney.api_secret"),
)
}
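The constructor above reads its settings from viper; across this commit the `midjourney` config section is consulted for four keys: `endpoint`, `api_secret`, `expose`, and `white_list`. For tests, an instance can also be wired up directly with placeholder values:

```go
package main

import "chat/adapter/midjourney"

func main() {
	// Placeholder values; equivalent to NewChatInstanceFromConfig() when
	// midjourney.endpoint and midjourney.api_secret are set in the config.
	instance := midjourney.NewChatInstance(
		"http://localhost:8080", // midjourney-proxy base URL (midjourney.endpoint)
		"replace-with-secret",   // sent as the mj-api-secret header (midjourney.api_secret)
	)
	_ = instance
}
```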

View File

@ -0,0 +1,58 @@
package midjourney
const (
SuccessCode = 1
ExistedCode = 21
QueueCode = 22
MaxQueueCode = 23
NudeCode = 24
)
const (
NotStartStatus = "NOT_START"
Submitted = "SUBMITTED"
InProgress = "IN_PROGRESS"
Failure = "FAILURE"
Success = "SUCCESS"
)
const (
TurboMode = "--turbo"
FastMode = "--fast"
RelaxMode = "--relax"
)
var ModeArr = []string{TurboMode, FastMode, RelaxMode}
type ImagineRequest struct {
NotifyHook string `json:"notifyHook"`
Prompt string `json:"prompt"`
}
type ImagineResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result string `json:"result"`
}
type NotifyForm struct {
Id string `json:"id"`
Action string `json:"action"`
Status string `json:"status"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
Progress string `json:"progress"`
ImageUrl string `json:"imageUrl"`
FailReason interface{} `json:"failReason"`
}
type StorageForm struct {
Url string `json:"url"`
FailReason string `json:"failReason"`
Progress string `json:"progress"`
Status string `json:"status"`
}

adapter/router.go Normal file
View File

@ -0,0 +1,10 @@
package adapter
import (
"chat/adapter/midjourney"
"github.com/gin-gonic/gin"
)
func Register(app *gin.Engine) {
app.POST("/mj/notify", midjourney.NotifyAPI)
}

View File

@ -31,12 +31,7 @@ function ChatInterface({ setTarget }: ChatInterfaceProps) {
if (!ref.current) return;
const el = ref.current as HTMLDivElement;
const event = () => {
setScrollable(
el.scrollTop + el.clientHeight + 20 >= el.scrollHeight, // at bottom
);
}
const event = () => setScrollable(el.scrollTop + el.clientHeight + 20 >= el.scrollHeight);
return addEventListeners(el, [
"scroll", "scrollend",
"resize", "touchend",

View File

@ -8,7 +8,7 @@ import {
} from "@/utils/env.ts";
import { getMemory } from "@/utils/memory.ts";
export const version = "3.6.19";
export const version = "3.6.20";
export const dev: boolean = getDev();
export const deploy: boolean = true;
export let rest_api: string = getRestApi(deploy);
@ -61,7 +61,12 @@ export const supportModels: Model[] = [
{ id: "chat-bison-001", name: "Palm2", free: true, auth: true },
// dalle models
{ id: "dalle", name: "DALLE2", free: true, auth: true },
{ id: "dall-e-3", name: "DALLE 3", free: false, auth: true },
{ id: "dall-e-2", name: "DALLE 2", free: true, auth: true },
{ id: "midjourney", name: "Midjourney", free: false, auth: true },
{ id: "midjourney-fast", name: "Midjourney Fast", free: false, auth: true },
{ id: "midjourney-turbo", name: "Midjourney Turbo", free: false, auth: true },
// reverse models
{ id: "gpt-4-v", name: "GPT-4 Vision", free: false, auth: true },
@ -96,7 +101,11 @@ export const planModels = [
"claude-2-100k",
];
export const expensiveModels = ["gpt-4-32k-0613"];
export const expensiveModels = [
"dall-e-3",
"midjourney-turbo",
"gpt-4-32k-0613",
];
export function login() {
location.href = `https://deeptrain.net/login?app=${dev ? "dev" : "chatnio"}`;

View File

@ -22,7 +22,6 @@ export const apiSlice = createSlice({
state.dialog = false;
},
setKey: (state, action) => {
if (!action.payload.length) return;
state.key = action.payload as string;
},
},
@ -35,7 +34,13 @@ export default apiSlice.reducer;
export const dialogSelector = (state: RootState): boolean => state.api.dialog;
export const keySelector = (state: RootState): string => state.api.key;
export const getApiKey = async (dispatch: AppDispatch) => {
export const getApiKey = async (dispatch: AppDispatch, retries?: boolean) => {
const response = await getKey();
if (response.status) dispatch(setKey(response.key));
if (response.status) {
if (response.key.length === 0 && retries !== false) {
await getApiKey(dispatch, false);
return;
}
dispatch(setKey(response.key));
};
};

View File

@ -11,7 +11,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
switch model {
case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0301, globals.GPT3Turbo0613:
return true
case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview:
case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview, globals.Dalle3:
return user != nil && user.GetQuota(db) >= 5
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
return user != nil && user.GetQuota(db) >= 50
@ -23,7 +23,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
return user != nil && user.GetQuota(db) >= 1
case globals.QwenTurbo, globals.QwenPlus, globals.QwenPlusNet, globals.QwenTurboNet:
return user != nil && user.GetQuota(db) >= 1
case globals.Midjourney, globals.StableDiffusion:
case globals.StableDiffusion, globals.Midjourney, globals.MidjourneyFast, globals.MidjourneyTurbo:
return user != nil && user.GetQuota(db) >= 1
case globals.LLaMa27B, globals.LLaMa213B, globals.LLaMa270B,
globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B:

View File

@ -56,7 +56,8 @@ const (
GPT432k = "gpt-4-32k"
GPT432k0314 = "gpt-4-32k-0314"
GPT432k0613 = "gpt-4-32k-0613"
Dalle2 = "dalle"
Dalle2 = "dall-e-2"
Dalle3 = "dall-e-3"
Claude1 = "claude-1"
Claude1100k = "claude-1.3"
Claude2 = "claude-1-100k"
@ -78,6 +79,8 @@ const (
QwenTurboNet = "qwen-turbo-net"
QwenPlusNet = "qwen-plus-net"
Midjourney = "midjourney"
MidjourneyFast = "midjourney-fast"
MidjourneyTurbo = "midjourney-turbo"
StableDiffusion = "stable-diffusion"
LLaMa270B = "llama-2-70b"
LLaMa213B = "llama-2-13b"
@ -148,6 +151,12 @@ var QwenModelArray = []string{
QwenPlusNet,
}
var MidjourneyModelArray = []string{
Midjourney,
MidjourneyFast,
MidjourneyTurbo,
}
var LongContextModelArray = []string{
GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
GPT41106Preview, GPT432k, GPT432k0314, GPT432k0613,
@ -179,14 +188,14 @@ var AllModels = []string{
GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
GPT4, GPT40314, GPT40613, GPT4Vision, GPT4All, GPT41106Preview, GPT4Dalle,
GPT432k, GPT432k0314, GPT432k0613,
Dalle2,
Dalle2, Dalle3,
Claude1, Claude1100k, Claude2, Claude2100k, ClaudeSlack,
SparkDesk, SparkDeskV2, SparkDeskV3,
ChatBison001,
BingCreative, BingBalanced, BingPrecise,
ZhiPuChatGLMTurbo, ZhiPuChatGLMPro, ZhiPuChatGLMStd, ZhiPuChatGLMLite,
QwenTurbo, QwenPlus, QwenTurboNet, QwenPlusNet,
Midjourney, StableDiffusion,
StableDiffusion, Midjourney, MidjourneyFast, MidjourneyTurbo,
LLaMa270B, LLaMa213B, LLaMa27B,
CodeLLaMa34B, CodeLLaMa13B, CodeLLaMa7B,
}
@ -213,7 +222,7 @@ func IsGPT3TurboModel(model string) bool {
}
func IsChatGPTModel(model string) bool {
return IsGPT3TurboModel(model) || IsGPT4Model(model)
return IsGPT3TurboModel(model) || IsGPT4Model(model) || IsDalleModel(model)
}
func IsClaudeModel(model string) bool {
@ -224,6 +233,10 @@ func IsLLaMaModel(model string) bool {
return in(model, LLaMaModelArray)
}
func IsDalleModel(model string) bool {
return model == Dalle2 || model == Dalle3
}
func IsClaude100KModel(model string) bool {
return model == Claude1100k || model == Claude2100k
}
@ -252,6 +265,10 @@ func IsQwenModel(model string) bool {
return in(model, QwenModelArray)
}
func IsMidjourneyModel(model string) bool {
return in(model, MidjourneyModelArray)
}
func IsLongContextModel(model string) bool {
return in(model, LongContextModelArray)
}

View File

@ -1,6 +1,7 @@
package main
import (
"chat/adapter"
"chat/addition"
"chat/admin"
"chat/auth"
@ -29,6 +30,7 @@ func main() {
{
auth.Register(app)
admin.Register(app)
adapter.Register(app)
manager.Register(app)
addition.Register(app)
conversation.Register(app)

View File

@ -3,6 +3,7 @@ package utils
import (
"fmt"
"github.com/goccy/go-json"
"time"
)
func Sum[T int | int64 | float32 | float64](arr []T) T {
@ -149,3 +150,7 @@ func EachNotNil[T any, U any](arr []T, f func(T) *U) []U {
}
return res
}
func Sleep(ms int) {
time.Sleep(time.Duration(ms) * time.Millisecond)
}

View File

@ -31,6 +31,19 @@ func SetInt(cache *redis.Client, key string, value int64, expiration int64) erro
return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err()
}
func SetJson(cache *redis.Client, key string, value interface{}, expiration int64) error {
err := cache.Set(context.Background(), key, Marshal(value), time.Duration(expiration)*time.Second).Err()
return err
}
func GetJson[T any](cache *redis.Client, key string) *T {
val, err := cache.Get(context.Background(), key).Result()
if err != nil {
return nil
}
return UnmarshalForm[T](val)
}
func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool {
// not exist
if _, err := cache.Get(context.Background(), key).Result(); err != nil {
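A small roundtrip sketch for the new `SetJson`/`GetJson` helpers above; the key name and `Job` struct are illustrative only, and it assumes `connection.Cache` has already been initialised elsewhere in the app:

```go
package main

import (
	"fmt"
	"log"

	"chat/connection"
	"chat/utils"
)

// Job is a placeholder struct for the example.
type Job struct {
	Name     string `json:"name"`
	Progress int    `json:"progress"`
}

func main() {
	if err := utils.SetJson(connection.Cache, "nio:example", Job{Name: "demo", Progress: 42}, 60); err != nil {
		log.Fatalln("cache write failed:", err)
	}
	if job := utils.GetJson[Job](connection.Cache, "nio:example"); job != nil {
		fmt.Println(job.Name, job.Progress) // demo 42
	}
}
```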

View File

@ -69,7 +69,7 @@ func MapToStruct[T any](data interface{}) *T {
}
}
func ToInt(value string) int {
func ParseInt(value string) int {
if res, err := strconv.Atoi(value); err == nil {
return res
} else {

View File

@ -145,6 +145,12 @@ func CountOutputToken(model string, t int) float32 {
return 0.25
case globals.Midjourney:
return 0.5
case globals.MidjourneyFast:
return 2
case globals.MidjourneyTurbo:
return 5
case globals.Dalle3:
return 5.6
default:
return 0
}