From 042f67fd74095ade1405afca7b17a19cbb01e273 Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Mon, 13 Nov 2023 11:36:58 +0800 Subject: [PATCH] add midjourney models --- adapter/adapter.go | 6 ++ adapter/chatgpt/chat.go | 4 +- adapter/chatgpt/image.go | 12 ++- adapter/chatgpt/struct.go | 2 +- adapter/chatgpt/types.go | 1 + adapter/midjourney/api.go | 101 ++++++++++++++++++++++ adapter/midjourney/chat.go | 79 +++++++++++++++++ adapter/midjourney/expose.go | 51 +++++++++++ adapter/midjourney/storage.go | 19 ++++ adapter/midjourney/struct.go | 32 +++++++ adapter/midjourney/types.go | 58 +++++++++++++ adapter/router.go | 10 +++ app/src/components/home/ChatInterface.tsx | 7 +- app/src/conf.ts | 15 +++- app/src/store/api.ts | 11 ++- auth/rule.go | 4 +- globals/variables.go | 25 +++++- main.go | 2 + utils/base.go | 5 ++ utils/cache.go | 13 +++ utils/char.go | 2 +- utils/tokenizer.go | 6 ++ 22 files changed, 441 insertions(+), 24 deletions(-) create mode 100644 adapter/midjourney/api.go create mode 100644 adapter/midjourney/chat.go create mode 100644 adapter/midjourney/expose.go create mode 100644 adapter/midjourney/storage.go create mode 100644 adapter/midjourney/struct.go create mode 100644 adapter/midjourney/types.go create mode 100644 adapter/router.go diff --git a/adapter/adapter.go b/adapter/adapter.go index b444c39..9eac39b 100644 --- a/adapter/adapter.go +++ b/adapter/adapter.go @@ -4,6 +4,7 @@ import ( "chat/adapter/bing" "chat/adapter/claude" "chat/adapter/dashscope" + "chat/adapter/midjourney" "chat/adapter/oneapi" "chat/adapter/palm2" "chat/adapter/slack" @@ -61,6 +62,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error { Model: props.Model, Message: props.Message, }, hook) + } else if globals.IsMidjourneyModel(props.Model) { + return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{ + Model: props.Model, + Messages: props.Message, + }, hook) } return hook("Sorry, we cannot find the model you are looking for. Please try another model.") diff --git a/adapter/chatgpt/chat.go b/adapter/chatgpt/chat.go index c087d92..64c0c1d 100644 --- a/adapter/chatgpt/chat.go +++ b/adapter/chatgpt/chat.go @@ -70,7 +70,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} { // CreateChatRequest is the native http request body for chatgpt func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { - if props.Model == globals.Dalle2 { + if globals.IsDalleModel(props.Model) { return c.CreateImage(props) } @@ -95,7 +95,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) { // CreateStreamChatRequest is the stream response body for chatgpt func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { - if props.Model == globals.Dalle2 { + if globals.IsDalleModel(props.Model) { if url, err := c.CreateImage(props); err != nil { return err } else { diff --git a/adapter/chatgpt/image.go b/adapter/chatgpt/image.go index 70cfdd8..1cca870 100644 --- a/adapter/chatgpt/image.go +++ b/adapter/chatgpt/image.go @@ -1,12 +1,14 @@ package chatgpt import ( + "chat/globals" "chat/utils" "fmt" "strings" ) type ImageProps struct { + Model string Prompt string Size ImageSize } @@ -20,9 +22,14 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) { res, err := utils.Post( c.GetImageEndpoint(), c.GetHeader(), ImageRequest{ + Model: props.Model, Prompt: props.Prompt, - Size: utils.Multi[ImageSize](len(props.Size) == 0, ImageSize512, props.Size), - N: 1, + Size: utils.Multi[ImageSize]( + props.Model == globals.Dalle3, + ImageSize1024, + ImageSize512, + ), + N: 1, }) if err != nil || res == nil { return "", fmt.Errorf("chatgpt error: %s", err.Error()) @@ -41,6 +48,7 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) { // CreateImage will create a dalle image from prompt, return markdown of image func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) { url, err := c.CreateImageRequest(ImageProps{ + Model: props.Model, Prompt: c.GetLatestPrompt(props), }) if err != nil { diff --git a/adapter/chatgpt/struct.go b/adapter/chatgpt/struct.go index 8a15052..8851cf5 100644 --- a/adapter/chatgpt/struct.go +++ b/adapter/chatgpt/struct.go @@ -52,7 +52,7 @@ func NewChatInstanceFromModel(props *InstanceProps) *ChatInstance { globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314: return NewChatInstanceFromConfig("gpt4") - case globals.GPT4Vision, globals.GPT4Dalle, globals.GPT4All: + case globals.GPT4Vision, globals.GPT4Dalle, globals.Dalle3, globals.GPT4All: return NewChatInstanceFromConfig("reverse") case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0613, globals.GPT3Turbo0301, diff --git a/adapter/chatgpt/types.go b/adapter/chatgpt/types.go index 09af0c4..e68e5fb 100644 --- a/adapter/chatgpt/types.go +++ b/adapter/chatgpt/types.go @@ -89,6 +89,7 @@ type ImageSize string // ImageRequest is the request body for chatgpt dalle image generation type ImageRequest struct { + Model string `json:"model"` Prompt string `json:"prompt"` Size ImageSize `json:"size"` N int `json:"n"` diff --git a/adapter/midjourney/api.go b/adapter/midjourney/api.go new file mode 100644 index 0000000..c9cd6d6 --- /dev/null +++ b/adapter/midjourney/api.go @@ -0,0 +1,101 @@ +package midjourney + +import ( + "chat/utils" + "fmt" + "github.com/spf13/viper" + "strings" +) + +func (c *ChatInstance) GetImagineUrl() string { + return fmt.Sprintf("%s/mj/submit/imagine", c.GetEndpoint()) +} + +func (c *ChatInstance) CreateImagineRequest(prompt string) (*ImagineResponse, error) { + res, err := utils.Post( + c.GetImagineUrl(), + map[string]string{ + "Content-Type": "application/json", + "mj-api-secret": c.GetApiSecret(), + }, + ImagineRequest{ + NotifyHook: fmt.Sprintf("%s/mj/notify", viper.GetString("midjourney.expose")), + Prompt: prompt, + }, + ) + + if err != nil { + return nil, err + } + + return utils.MapToStruct[ImagineResponse](res), nil +} + +func getStatusCode(code int) error { + switch code { + case SuccessCode, QueueCode: + return nil + case ExistedCode: + return fmt.Errorf("task is existed, please try again later with another prompt") + case MaxQueueCode: + return fmt.Errorf("task queue is full, please try again later") + case NudeCode: + return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected") + default: + return fmt.Errorf("unknown error from midjourney") + } +} + +func getProgress(value string) int { + progress := strings.TrimSuffix(value, "%") + return utils.ParseInt(progress) +} + +func (c *ChatInstance) CreateStreamImagineTask(prompt string, hook func(progress int) error) (string, error) { + res, err := c.CreateImagineRequest(prompt) + if err != nil { + return "", err + } + + if err := getStatusCode(res.Code); err != nil { + return "", err + } + + task := res.Result + progress := -1 + + for { + utils.Sleep(100) + form := getStorage(task) + if form == nil { + continue + } + + switch form.Status { + case Success: + if err := hook(100); err != nil { + return "", err + } + return form.Url, nil + case Failure: + if err := hook(100); err != nil { + return "", err + } + return "", fmt.Errorf("task failed: %s", form.FailReason) + case InProgress: + current := getProgress(form.Progress) + if progress != current { + if err := hook(current); err != nil { + return "", err + } + progress = current + } + } + } +} + +func (c *ChatInstance) CreateImagineTask(prompt string) (string, error) { + return c.CreateStreamImagineTask(prompt, func(progress int) error { + return nil + }) +} diff --git a/adapter/midjourney/chat.go b/adapter/midjourney/chat.go new file mode 100644 index 0000000..ee0c120 --- /dev/null +++ b/adapter/midjourney/chat.go @@ -0,0 +1,79 @@ +package midjourney + +import ( + "chat/globals" + "chat/utils" + "fmt" + "strings" +) + +type ChatProps struct { + Messages []globals.Message + Model string +} + +func getMode(model string) string { + switch model { + case globals.Midjourney: // relax + return RelaxMode + case globals.MidjourneyFast: // fast + return FastMode + case globals.MidjourneyTurbo: // turbo + return TurboMode + default: + return RelaxMode + } +} + +func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string { + arr := strings.Split(strings.TrimSpace(prompt), " ") + var res []string + + for _, word := range arr { + if utils.Contains[string](word, ModeArr) { + continue + } + res = append(res, word) + } + + res = append(res, getMode(model)) + target := strings.Join(res, " ") + return target +} + +func (c *ChatInstance) GetPrompt(props *ChatProps) string { + return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content) +} + +func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error { + // partial response like: + // ```progress + // 0 + // ... + // 100 + // ``` + // ![image](...) + + prompt := c.GetPrompt(props) + if prompt == "" { + return fmt.Errorf("format error: please provide available prompt") + } + + if err := callback("```progress\n"); err != nil { + return err + } + + url, err := c.CreateStreamImagineTask(prompt, func(progress int) error { + return callback(fmt.Sprintf("%d\n", progress)) + }) + + if err := callback("```\n"); err != nil { + return err + } + + if err != nil { + return fmt.Errorf("error from midjourney: %s", err.Error()) + } + + return callback(utils.GetImageMarkdown(url)) +} diff --git a/adapter/midjourney/expose.go b/adapter/midjourney/expose.go new file mode 100644 index 0000000..588eb65 --- /dev/null +++ b/adapter/midjourney/expose.go @@ -0,0 +1,51 @@ +package midjourney + +import ( + "chat/utils" + "fmt" + "github.com/gin-gonic/gin" + "github.com/spf13/viper" + "net/http" + "strings" +) + +func InWhiteList(ip string) bool { + arr := strings.Split(viper.GetString("midjourney.white_list"), ",") + return utils.Contains[string](ip, arr) +} + +func NotifyAPI(c *gin.Context) { + if !InWhiteList(c.ClientIP()) { + fmt.Println(fmt.Sprintf("[midjourney] notify api: banned request from %s", c.ClientIP())) + c.AbortWithStatus(http.StatusForbidden) + return + } + + var form NotifyForm + if err := c.ShouldBindJSON(&form); err != nil { + c.AbortWithStatus(http.StatusBadRequest) + return + } + // fmt.Println(fmt.Sprintf("[midjourney] notify api: get notify: %s (from: %s)", utils.Marshal(form), c.ClientIP())) + + if !utils.Contains(form.Status, []string{InProgress, Success, Failure}) { + // ignore + return + } + + reason, ok := form.FailReason.(string) + if !ok { + reason = "unknown" + } + + err := setStorage(form.Id, StorageForm{ + Url: form.ImageUrl, + FailReason: reason, + Progress: form.Progress, + Status: form.Status, + }) + + c.JSON(http.StatusOK, gin.H{ + "status": err == nil, + }) +} diff --git a/adapter/midjourney/storage.go b/adapter/midjourney/storage.go new file mode 100644 index 0000000..18b5c7b --- /dev/null +++ b/adapter/midjourney/storage.go @@ -0,0 +1,19 @@ +package midjourney + +import ( + "chat/connection" + "chat/utils" + "fmt" +) + +func getTaskName(task string) string { + return fmt.Sprintf("nio:mj-task:%s", task) +} + +func setStorage(task string, form StorageForm) error { + return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60) +} + +func getStorage(task string) *StorageForm { + return utils.GetJson[StorageForm](connection.Cache, getTaskName(task)) +} diff --git a/adapter/midjourney/struct.go b/adapter/midjourney/struct.go new file mode 100644 index 0000000..6e32c9a --- /dev/null +++ b/adapter/midjourney/struct.go @@ -0,0 +1,32 @@ +package midjourney + +import ( + "github.com/spf13/viper" +) + +type ChatInstance struct { + Endpoint string + ApiSecret string +} + +func (c *ChatInstance) GetApiSecret() string { + return c.ApiSecret +} + +func (c *ChatInstance) GetEndpoint() string { + return c.Endpoint +} + +func NewChatInstance(endpoint string, apiSecret string) *ChatInstance { + return &ChatInstance{ + Endpoint: endpoint, + ApiSecret: apiSecret, + } +} + +func NewChatInstanceFromConfig() *ChatInstance { + return NewChatInstance( + viper.GetString("midjourney.endpoint"), + viper.GetString("midjourney.api_secret"), + ) +} diff --git a/adapter/midjourney/types.go b/adapter/midjourney/types.go new file mode 100644 index 0000000..9d7e052 --- /dev/null +++ b/adapter/midjourney/types.go @@ -0,0 +1,58 @@ +package midjourney + +const ( + SuccessCode = 1 + ExistedCode = 21 + QueueCode = 22 + MaxQueueCode = 23 + NudeCode = 24 +) + +const ( + NotStartStatus = "NOT_START" + Submitted = "SUBMITTED" + InProgress = "IN_PROGRESS" + Failure = "FAILURE" + Success = "SUCCESS" +) + +const ( + TurboMode = "--turbo" + FastMode = "--fast" + RelaxMode = "--relax" +) + +var ModeArr = []string{TurboMode, FastMode, RelaxMode} + +type ImagineRequest struct { + NotifyHook string `json:"notifyHook"` + Prompt string `json:"prompt"` +} + +type ImagineResponse struct { + Code int `json:"code"` + Description string `json:"description"` + Result string `json:"result"` +} + +type NotifyForm struct { + Id string `json:"id"` + Action string `json:"action"` + Status string `json:"status"` + Prompt string `json:"prompt"` + PromptEn string `json:"promptEn"` + Description string `json:"description"` + SubmitTime int64 `json:"submitTime"` + StartTime int64 `json:"startTime"` + FinishTime int64 `json:"finishTime"` + Progress string `json:"progress"` + ImageUrl string `json:"imageUrl"` + FailReason interface{} `json:"failReason"` +} + +type StorageForm struct { + Url string `json:"url"` + FailReason string `json:"failReason"` + Progress string `json:"progress"` + Status string `json:"status"` +} diff --git a/adapter/router.go b/adapter/router.go new file mode 100644 index 0000000..e8cfc1d --- /dev/null +++ b/adapter/router.go @@ -0,0 +1,10 @@ +package adapter + +import ( + "chat/adapter/midjourney" + "github.com/gin-gonic/gin" +) + +func Register(app *gin.Engine) { + app.POST("/mj/notify", midjourney.NotifyAPI) +} diff --git a/app/src/components/home/ChatInterface.tsx b/app/src/components/home/ChatInterface.tsx index 9737bc4..38454a9 100644 --- a/app/src/components/home/ChatInterface.tsx +++ b/app/src/components/home/ChatInterface.tsx @@ -31,12 +31,7 @@ function ChatInterface({ setTarget }: ChatInterfaceProps) { if (!ref.current) return; const el = ref.current as HTMLDivElement; - const event = () => { - setScrollable( - el.scrollTop + el.clientHeight + 20 >= el.scrollHeight, // at bottom - ); - } - + const event = () => setScrollable(el.scrollTop + el.clientHeight + 20 >= el.scrollHeight); return addEventListeners(el, [ "scroll", "scrollend", "resize", "touchend", diff --git a/app/src/conf.ts b/app/src/conf.ts index 17c9dff..c267ed0 100644 --- a/app/src/conf.ts +++ b/app/src/conf.ts @@ -8,7 +8,7 @@ import { } from "@/utils/env.ts"; import { getMemory } from "@/utils/memory.ts"; -export const version = "3.6.19"; +export const version = "3.6.20"; export const dev: boolean = getDev(); export const deploy: boolean = true; export let rest_api: string = getRestApi(deploy); @@ -61,7 +61,12 @@ export const supportModels: Model[] = [ { id: "chat-bison-001", name: "Palm2", free: true, auth: true }, // dalle models - { id: "dalle", name: "DALLE2", free: true, auth: true }, + { id: "dall-e-3", name: "DALLE 3", free: false, auth: true }, + { id: "dall-e-2", name: "DALLE 2", free: true, auth: true }, + + { id: "midjourney", name: "Midjourney", free: false, auth: true }, + { id: "midjourney-fast", name: "Midjourney Fast", free: false, auth: true }, + { id: "midjourney-turbo", name: "Midjourney Turbo", free: false, auth: true }, // reverse models { id: "gpt-4-v", name: "GPT-4 Vision", free: false, auth: true }, @@ -96,7 +101,11 @@ export const planModels = [ "claude-2-100k", ]; -export const expensiveModels = ["gpt-4-32k-0613"]; +export const expensiveModels = [ + "dall-e-3", + "midjourney-turbo", + "gpt-4-32k-0613", +]; export function login() { location.href = `https://deeptrain.net/login?app=${dev ? "dev" : "chatnio"}`; diff --git a/app/src/store/api.ts b/app/src/store/api.ts index 9fc5d83..2ffd288 100644 --- a/app/src/store/api.ts +++ b/app/src/store/api.ts @@ -22,7 +22,6 @@ export const apiSlice = createSlice({ state.dialog = false; }, setKey: (state, action) => { - if (!action.payload.length) return; state.key = action.payload as string; }, }, @@ -35,7 +34,13 @@ export default apiSlice.reducer; export const dialogSelector = (state: RootState): boolean => state.api.dialog; export const keySelector = (state: RootState): string => state.api.key; -export const getApiKey = async (dispatch: AppDispatch) => { +export const getApiKey = async (dispatch: AppDispatch, retries?: boolean) => { const response = await getKey(); - if (response.status) dispatch(setKey(response.key)); + if (response.status) { + if (response.key.length === 0 && retries !== false) { + await getApiKey(dispatch, false); + return; + } + dispatch(setKey(response.key)); + }; }; diff --git a/auth/rule.go b/auth/rule.go index d0aefd2..b4aa4d0 100644 --- a/auth/rule.go +++ b/auth/rule.go @@ -11,7 +11,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool { switch model { case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0301, globals.GPT3Turbo0613: return true - case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview: + case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview, globals.Dalle3: return user != nil && user.GetQuota(db) >= 5 case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314: return user != nil && user.GetQuota(db) >= 50 @@ -23,7 +23,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool { return user != nil && user.GetQuota(db) >= 1 case globals.QwenTurbo, globals.QwenPlus, globals.QwenPlusNet, globals.QwenTurboNet: return user != nil && user.GetQuota(db) >= 1 - case globals.Midjourney, globals.StableDiffusion: + case globals.StableDiffusion, globals.Midjourney, globals.MidjourneyFast, globals.MidjourneyTurbo: return user != nil && user.GetQuota(db) >= 1 case globals.LLaMa27B, globals.LLaMa213B, globals.LLaMa270B, globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B: diff --git a/globals/variables.go b/globals/variables.go index fe66d68..248b335 100644 --- a/globals/variables.go +++ b/globals/variables.go @@ -56,7 +56,8 @@ const ( GPT432k = "gpt-4-32k" GPT432k0314 = "gpt-4-32k-0314" GPT432k0613 = "gpt-4-32k-0613" - Dalle2 = "dalle" + Dalle2 = "dall-e-2" + Dalle3 = "dall-e-3" Claude1 = "claude-1" Claude1100k = "claude-1.3" Claude2 = "claude-1-100k" @@ -78,6 +79,8 @@ const ( QwenTurboNet = "qwen-turbo-net" QwenPlusNet = "qwen-plus-net" Midjourney = "midjourney" + MidjourneyFast = "midjourney-fast" + MidjourneyTurbo = "midjourney-turbo" StableDiffusion = "stable-diffusion" LLaMa270B = "llama-2-70b" LLaMa213B = "llama-2-13b" @@ -148,6 +151,12 @@ var QwenModelArray = []string{ QwenPlusNet, } +var MidjourneyModelArray = []string{ + Midjourney, + MidjourneyFast, + MidjourneyTurbo, +} + var LongContextModelArray = []string{ GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301, GPT41106Preview, GPT432k, GPT432k0314, GPT432k0613, @@ -179,14 +188,14 @@ var AllModels = []string{ GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301, GPT4, GPT40314, GPT40613, GPT4Vision, GPT4All, GPT41106Preview, GPT4Dalle, GPT432k, GPT432k0314, GPT432k0613, - Dalle2, + Dalle2, Dalle3, Claude1, Claude1100k, Claude2, Claude2100k, ClaudeSlack, SparkDesk, SparkDeskV2, SparkDeskV3, ChatBison001, BingCreative, BingBalanced, BingPrecise, ZhiPuChatGLMTurbo, ZhiPuChatGLMPro, ZhiPuChatGLMStd, ZhiPuChatGLMLite, QwenTurbo, QwenPlus, QwenTurboNet, QwenPlusNet, - Midjourney, StableDiffusion, + StableDiffusion, Midjourney, MidjourneyFast, MidjourneyTurbo, LLaMa270B, LLaMa213B, LLaMa27B, CodeLLaMa34B, CodeLLaMa13B, CodeLLaMa7B, } @@ -213,7 +222,7 @@ func IsGPT3TurboModel(model string) bool { } func IsChatGPTModel(model string) bool { - return IsGPT3TurboModel(model) || IsGPT4Model(model) + return IsGPT3TurboModel(model) || IsGPT4Model(model) || IsDalleModel(model) } func IsClaudeModel(model string) bool { @@ -224,6 +233,10 @@ func IsLLaMaModel(model string) bool { return in(model, LLaMaModelArray) } +func IsDalleModel(model string) bool { + return model == Dalle2 || model == Dalle3 +} + func IsClaude100KModel(model string) bool { return model == Claude1100k || model == Claude2100k } @@ -252,6 +265,10 @@ func IsQwenModel(model string) bool { return in(model, QwenModelArray) } +func IsMidjourneyModel(model string) bool { + return in(model, MidjourneyModelArray) +} + func IsLongContextModel(model string) bool { return in(model, LongContextModelArray) } diff --git a/main.go b/main.go index b13a258..75f6856 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,7 @@ package main import ( + "chat/adapter" "chat/addition" "chat/admin" "chat/auth" @@ -29,6 +30,7 @@ func main() { { auth.Register(app) admin.Register(app) + adapter.Register(app) manager.Register(app) addition.Register(app) conversation.Register(app) diff --git a/utils/base.go b/utils/base.go index d06896f..c3bb625 100644 --- a/utils/base.go +++ b/utils/base.go @@ -3,6 +3,7 @@ package utils import ( "fmt" "github.com/goccy/go-json" + "time" ) func Sum[T int | int64 | float32 | float64](arr []T) T { @@ -149,3 +150,7 @@ func EachNotNil[T any, U any](arr []T, f func(T) *U) []U { } return res } + +func Sleep(ms int) { + time.Sleep(time.Duration(ms) * time.Millisecond) +} diff --git a/utils/cache.go b/utils/cache.go index 018a873..3769a4f 100644 --- a/utils/cache.go +++ b/utils/cache.go @@ -31,6 +31,19 @@ func SetInt(cache *redis.Client, key string, value int64, expiration int64) erro return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err() } +func SetJson(cache *redis.Client, key string, value interface{}, expiration int64) error { + err := cache.Set(context.Background(), key, Marshal(value), time.Duration(expiration)*time.Second).Err() + return err +} + +func GetJson[T any](cache *redis.Client, key string) *T { + val, err := cache.Get(context.Background(), key).Result() + if err != nil { + return nil + } + return UnmarshalForm[T](val) +} + func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool { // not exist if _, err := cache.Get(context.Background(), key).Result(); err != nil { diff --git a/utils/char.go b/utils/char.go index a5e3de1..ee3af23 100644 --- a/utils/char.go +++ b/utils/char.go @@ -69,7 +69,7 @@ func MapToStruct[T any](data interface{}) *T { } } -func ToInt(value string) int { +func ParseInt(value string) int { if res, err := strconv.Atoi(value); err == nil { return res } else { diff --git a/utils/tokenizer.go b/utils/tokenizer.go index bb5c7ed..2af9bba 100644 --- a/utils/tokenizer.go +++ b/utils/tokenizer.go @@ -145,6 +145,12 @@ func CountOutputToken(model string, t int) float32 { return 0.25 case globals.Midjourney: return 0.5 + case globals.MidjourneyFast: + return 2 + case globals.MidjourneyTurbo: + return 5 + case globals.Dalle3: + return 5.6 default: return 0 }