add midjourney models

Zhang Minghan 2023-11-13 11:36:58 +08:00
parent d084e544e6
commit 042f67fd74
22 changed files with 441 additions and 24 deletions

View File

@@ -4,6 +4,7 @@ import (
     "chat/adapter/bing"
     "chat/adapter/claude"
     "chat/adapter/dashscope"
+    "chat/adapter/midjourney"
     "chat/adapter/oneapi"
     "chat/adapter/palm2"
     "chat/adapter/slack"
@@ -61,6 +62,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
             Model:   props.Model,
             Message: props.Message,
         }, hook)
+    } else if globals.IsMidjourneyModel(props.Model) {
+        return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{
+            Model:    props.Model,
+            Messages: props.Message,
+        }, hook)
     }

     return hook("Sorry, we cannot find the model you are looking for. Please try another model.")

View File

@@ -70,7 +70,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {

 // CreateChatRequest is the native http request body for chatgpt
 func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
-    if props.Model == globals.Dalle2 {
+    if globals.IsDalleModel(props.Model) {
         return c.CreateImage(props)
     }

@@ -95,7 +95,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {

 // CreateStreamChatRequest is the stream response body for chatgpt
 func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
-    if props.Model == globals.Dalle2 {
+    if globals.IsDalleModel(props.Model) {
         if url, err := c.CreateImage(props); err != nil {
             return err
         } else {

View File

@@ -1,12 +1,14 @@
 package chatgpt

 import (
+    "chat/globals"
     "chat/utils"
     "fmt"
     "strings"
 )

 type ImageProps struct {
+    Model  string
     Prompt string
     Size   ImageSize
 }
@@ -20,8 +22,13 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
     res, err := utils.Post(
         c.GetImageEndpoint(),
         c.GetHeader(), ImageRequest{
+            Model:  props.Model,
             Prompt: props.Prompt,
-            Size:   utils.Multi[ImageSize](len(props.Size) == 0, ImageSize512, props.Size),
+            Size: utils.Multi[ImageSize](
+                props.Model == globals.Dalle3,
+                ImageSize1024,
+                ImageSize512,
+            ),
             N: 1,
         })
     if err != nil || res == nil {
@@ -41,6 +48,7 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
 // CreateImage will create a dalle image from prompt, return markdown of image
 func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) {
     url, err := c.CreateImageRequest(ImageProps{
+        Model:  props.Model,
         Prompt: c.GetLatestPrompt(props),
     })
     if err != nil {

View File

@@ -52,7 +52,7 @@ func NewChatInstanceFromModel(props *InstanceProps) *ChatInstance {
         globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
         return NewChatInstanceFromConfig("gpt4")

-    case globals.GPT4Vision, globals.GPT4Dalle, globals.GPT4All:
+    case globals.GPT4Vision, globals.GPT4Dalle, globals.Dalle3, globals.GPT4All:
         return NewChatInstanceFromConfig("reverse")

     case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0613, globals.GPT3Turbo0301,

View File

@@ -89,6 +89,7 @@ type ImageSize string

 // ImageRequest is the request body for chatgpt dalle image generation
 type ImageRequest struct {
+    Model  string    `json:"model"`
     Prompt string    `json:"prompt"`
     Size   ImageSize `json:"size"`
     N      int       `json:"n"`

View File: adapter/midjourney/api.go (new file)

@@ -0,0 +1,101 @@
package midjourney

import (
    "chat/utils"
    "fmt"
    "github.com/spf13/viper"
    "strings"
)

func (c *ChatInstance) GetImagineUrl() string {
    return fmt.Sprintf("%s/mj/submit/imagine", c.GetEndpoint())
}

func (c *ChatInstance) CreateImagineRequest(prompt string) (*ImagineResponse, error) {
    res, err := utils.Post(
        c.GetImagineUrl(),
        map[string]string{
            "Content-Type":  "application/json",
            "mj-api-secret": c.GetApiSecret(),
        },
        ImagineRequest{
            NotifyHook: fmt.Sprintf("%s/mj/notify", viper.GetString("midjourney.expose")),
            Prompt:     prompt,
        },
    )
    if err != nil {
        return nil, err
    }
    return utils.MapToStruct[ImagineResponse](res), nil
}

func getStatusCode(code int) error {
    switch code {
    case SuccessCode, QueueCode:
        return nil
    case ExistedCode:
        return fmt.Errorf("task already exists, please try again later with another prompt")
    case MaxQueueCode:
        return fmt.Errorf("task queue is full, please try again later")
    case NudeCode:
        return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected")
    default:
        return fmt.Errorf("unknown error from midjourney")
    }
}

func getProgress(value string) int {
    progress := strings.TrimSuffix(value, "%")
    return utils.ParseInt(progress)
}

func (c *ChatInstance) CreateStreamImagineTask(prompt string, hook func(progress int) error) (string, error) {
    res, err := c.CreateImagineRequest(prompt)
    if err != nil {
        return "", err
    }
    if err := getStatusCode(res.Code); err != nil {
        return "", err
    }

    task := res.Result
    progress := -1

    for {
        utils.Sleep(100)
        form := getStorage(task)
        if form == nil {
            continue
        }

        switch form.Status {
        case Success:
            if err := hook(100); err != nil {
                return "", err
            }
            return form.Url, nil
        case Failure:
            if err := hook(100); err != nil {
                return "", err
            }
            return "", fmt.Errorf("task failed: %s", form.FailReason)
        case InProgress:
            current := getProgress(form.Progress)
            if progress != current {
                if err := hook(current); err != nil {
                    return "", err
                }
                progress = current
            }
        }
    }
}

func (c *ChatInstance) CreateImagineTask(prompt string) (string, error) {
    return c.CreateStreamImagineTask(prompt, func(progress int) error {
        return nil
    })
}
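Note that CreateImagineRequest only submits the task to the midjourney-proxy; the result arrives later through the /mj/notify webhook and is read back from the cache by getStorage, which the loop above polls every 100 ms. A hedged usage sketch, assuming the viper config is already loaded:

package main

import (
    "fmt"
    "log"

    "chat/adapter/midjourney"
)

func main() {
    instance := midjourney.NewChatInstanceFromConfig()

    // Submit the task and block until the notify webhook reports completion.
    url, err := instance.CreateStreamImagineTask("a red panda in watercolor --fast", func(progress int) error {
        log.Printf("midjourney progress: %d%%", progress)
        return nil
    })
    if err != nil {
        log.Fatal(err) // status-code errors or the task's FailReason
    }
    fmt.Println(url) // final image URL recorded by the webhook
}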

View File

@@ -0,0 +1,79 @@
package midjourney

import (
    "chat/globals"
    "chat/utils"
    "fmt"
    "strings"
)

type ChatProps struct {
    Messages []globals.Message
    Model    string
}

func getMode(model string) string {
    switch model {
    case globals.Midjourney: // relax
        return RelaxMode
    case globals.MidjourneyFast: // fast
        return FastMode
    case globals.MidjourneyTurbo: // turbo
        return TurboMode
    default:
        return RelaxMode
    }
}

func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
    arr := strings.Split(strings.TrimSpace(prompt), " ")
    var res []string
    for _, word := range arr {
        if utils.Contains[string](word, ModeArr) {
            continue
        }
        res = append(res, word)
    }

    res = append(res, getMode(model))
    target := strings.Join(res, " ")
    return target
}

func (c *ChatInstance) GetPrompt(props *ChatProps) string {
    return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content)
}

func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
    // partial response like:
    // ```progress
    // 0
    // ...
    // 100
    // ```
    // ![image](...)

    prompt := c.GetPrompt(props)
    if prompt == "" {
        return fmt.Errorf("format error: please provide an available prompt")
    }

    if err := callback("```progress\n"); err != nil {
        return err
    }

    url, err := c.CreateStreamImagineTask(prompt, func(progress int) error {
        return callback(fmt.Sprintf("%d\n", progress))
    })

    if err := callback("```\n"); err != nil {
        return err
    }

    if err != nil {
        return fmt.Errorf("error from midjourney: %s", err.Error())
    }

    return callback(utils.GetImageMarkdown(url))
}
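GetCleanPrompt strips any mode flag the user typed and appends the flag implied by the selected model, so the generation mode always follows the model id. A minimal sketch of that behaviour:

package main

import (
    "fmt"

    "chat/adapter/midjourney"
    "chat/globals"
)

func main() {
    c := &midjourney.ChatInstance{}
    // The user-typed "--turbo" is stripped and the flag implied by the plain
    // "midjourney" model ("--relax") is appended instead.
    fmt.Println(c.GetCleanPrompt(globals.Midjourney, "a cat in the snow --turbo"))
    // Output: a cat in the snow --relax
}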

View File

@@ -0,0 +1,51 @@
package midjourney

import (
    "chat/utils"
    "fmt"
    "github.com/gin-gonic/gin"
    "github.com/spf13/viper"
    "net/http"
    "strings"
)

func InWhiteList(ip string) bool {
    arr := strings.Split(viper.GetString("midjourney.white_list"), ",")
    return utils.Contains[string](ip, arr)
}

func NotifyAPI(c *gin.Context) {
    if !InWhiteList(c.ClientIP()) {
        fmt.Println(fmt.Sprintf("[midjourney] notify api: banned request from %s", c.ClientIP()))
        c.AbortWithStatus(http.StatusForbidden)
        return
    }

    var form NotifyForm
    if err := c.ShouldBindJSON(&form); err != nil {
        c.AbortWithStatus(http.StatusBadRequest)
        return
    }

    // fmt.Println(fmt.Sprintf("[midjourney] notify api: get notify: %s (from: %s)", utils.Marshal(form), c.ClientIP()))

    if !utils.Contains(form.Status, []string{InProgress, Success, Failure}) {
        // ignore
        return
    }

    reason, ok := form.FailReason.(string)
    if !ok {
        reason = "unknown"
    }

    err := setStorage(form.Id, StorageForm{
        Url:        form.ImageUrl,
        FailReason: reason,
        Progress:   form.Progress,
        Status:     form.Status,
    })

    c.JSON(http.StatusOK, gin.H{
        "status": err == nil,
    })
}

View File

@@ -0,0 +1,19 @@
package midjourney

import (
    "chat/connection"
    "chat/utils"
    "fmt"
)

func getTaskName(task string) string {
    return fmt.Sprintf("nio:mj-task:%s", task)
}

func setStorage(task string, form StorageForm) error {
    return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
}

func getStorage(task string) *StorageForm {
    return utils.GetJson[StorageForm](connection.Cache, getTaskName(task))
}

View File

@@ -0,0 +1,32 @@
package midjourney

import (
    "github.com/spf13/viper"
)

type ChatInstance struct {
    Endpoint  string
    ApiSecret string
}

func (c *ChatInstance) GetApiSecret() string {
    return c.ApiSecret
}

func (c *ChatInstance) GetEndpoint() string {
    return c.Endpoint
}

func NewChatInstance(endpoint string, apiSecret string) *ChatInstance {
    return &ChatInstance{
        Endpoint:  endpoint,
        ApiSecret: apiSecret,
    }
}

func NewChatInstanceFromConfig() *ChatInstance {
    return NewChatInstance(
        viper.GetString("midjourney.endpoint"),
        viper.GetString("midjourney.api_secret"),
    )
}
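The adapter is configured entirely through viper; the keys it reads are spread across this diff (endpoint and api_secret here, expose in the imagine request, white_list in the notify handler). A hedged sketch of providing them programmatically, e.g. in a test; a real deployment would set the same keys in the project's config file, and all values below are placeholders:

package main

import "github.com/spf13/viper"

func main() {
    viper.Set("midjourney.endpoint", "http://localhost:8080")  // base URL of the midjourney-proxy deployment
    viper.Set("midjourney.api_secret", "your-mj-api-secret")   // forwarded as the mj-api-secret header
    viper.Set("midjourney.expose", "https://chat.example.com") // public base URL used to build the notifyHook callback
    viper.Set("midjourney.white_list", "127.0.0.1,::1")        // comma-separated IPs allowed to call POST /mj/notify
}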

View File

@@ -0,0 +1,58 @@
package midjourney

const (
    SuccessCode  = 1
    ExistedCode  = 21
    QueueCode    = 22
    MaxQueueCode = 23
    NudeCode     = 24
)

const (
    NotStartStatus = "NOT_START"
    Submitted      = "SUBMITTED"
    InProgress     = "IN_PROGRESS"
    Failure        = "FAILURE"
    Success        = "SUCCESS"
)

const (
    TurboMode = "--turbo"
    FastMode  = "--fast"
    RelaxMode = "--relax"
)

var ModeArr = []string{TurboMode, FastMode, RelaxMode}

type ImagineRequest struct {
    NotifyHook string `json:"notifyHook"`
    Prompt     string `json:"prompt"`
}

type ImagineResponse struct {
    Code        int    `json:"code"`
    Description string `json:"description"`
    Result      string `json:"result"`
}

type NotifyForm struct {
    Id          string      `json:"id"`
    Action      string      `json:"action"`
    Status      string      `json:"status"`
    Prompt      string      `json:"prompt"`
    PromptEn    string      `json:"promptEn"`
    Description string      `json:"description"`
    SubmitTime  int64       `json:"submitTime"`
    StartTime   int64       `json:"startTime"`
    FinishTime  int64       `json:"finishTime"`
    Progress    string      `json:"progress"`
    ImageUrl    string      `json:"imageUrl"`
    FailReason  interface{} `json:"failReason"`
}

type StorageForm struct {
    Url        string `json:"url"`
    FailReason string `json:"failReason"`
    Progress   string `json:"progress"`
    Status     string `json:"status"`
}
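For reference, a hedged example of the callback body a midjourney-proxy deployment would POST to /mj/notify, matching the NotifyForm tags above. The values are invented; the handler binds the real payload with gin's ShouldBindJSON, and plain encoding/json is used here only to illustrate the shape:

package main

import (
    "encoding/json"
    "fmt"

    "chat/adapter/midjourney"
)

func main() {
    // Invented payload for illustration only; the field names follow the json tags above.
    body := `{"id":"1699845123456789","action":"IMAGINE","status":"IN_PROGRESS","progress":"45%","imageUrl":"","failReason":null}`

    var form midjourney.NotifyForm
    if err := json.Unmarshal([]byte(body), &form); err == nil {
        fmt.Println(form.Status, form.Progress) // IN_PROGRESS 45%
    }
}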

View File: adapter/router.go (new file)

@@ -0,0 +1,10 @@
package adapter

import (
    "chat/adapter/midjourney"
    "github.com/gin-gonic/gin"
)

func Register(app *gin.Engine) {
    app.POST("/mj/notify", midjourney.NotifyAPI)
}

View File

@@ -31,12 +31,7 @@ function ChatInterface({ setTarget }: ChatInterfaceProps) {
     if (!ref.current) return;
     const el = ref.current as HTMLDivElement;
-    const event = () => {
-      setScrollable(
-        el.scrollTop + el.clientHeight + 20 >= el.scrollHeight, // at bottom
-      );
-    }
+    const event = () => setScrollable(el.scrollTop + el.clientHeight + 20 >= el.scrollHeight);

     return addEventListeners(el, [
       "scroll", "scrollend",
       "resize", "touchend",

View File

@@ -8,7 +8,7 @@ import {
 } from "@/utils/env.ts";
 import { getMemory } from "@/utils/memory.ts";

-export const version = "3.6.19";
+export const version = "3.6.20";
 export const dev: boolean = getDev();
 export const deploy: boolean = true;
 export let rest_api: string = getRestApi(deploy);
@@ -61,7 +61,12 @@ export const supportModels: Model[] = [
   { id: "chat-bison-001", name: "Palm2", free: true, auth: true },

   // dalle models
-  { id: "dalle", name: "DALLE2", free: true, auth: true },
+  { id: "dall-e-3", name: "DALLE 3", free: false, auth: true },
+  { id: "dall-e-2", name: "DALLE 2", free: true, auth: true },
+
+  { id: "midjourney", name: "Midjourney", free: false, auth: true },
+  { id: "midjourney-fast", name: "Midjourney Fast", free: false, auth: true },
+  { id: "midjourney-turbo", name: "Midjourney Turbo", free: false, auth: true },

   // reverse models
   { id: "gpt-4-v", name: "GPT-4 Vision", free: false, auth: true },
@@ -96,7 +101,11 @@ export const planModels = [
   "claude-2-100k",
 ];

-export const expensiveModels = ["gpt-4-32k-0613"];
+export const expensiveModels = [
+  "dall-e-3",
+  "midjourney-turbo",
+  "gpt-4-32k-0613",
+];

 export function login() {
   location.href = `https://deeptrain.net/login?app=${dev ? "dev" : "chatnio"}`;

View File

@@ -22,7 +22,6 @@ export const apiSlice = createSlice({
       state.dialog = false;
     },
     setKey: (state, action) => {
-      if (!action.payload.length) return;
       state.key = action.payload as string;
     },
   },
@@ -35,7 +34,13 @@ export default apiSlice.reducer;
 export const dialogSelector = (state: RootState): boolean => state.api.dialog;
 export const keySelector = (state: RootState): string => state.api.key;

-export const getApiKey = async (dispatch: AppDispatch) => {
+export const getApiKey = async (dispatch: AppDispatch, retries?: boolean) => {
   const response = await getKey();
-  if (response.status) dispatch(setKey(response.key));
+  if (response.status) {
+    if (response.key.length === 0 && retries !== false) {
+      await getApiKey(dispatch, false);
+      return;
+    }
+    dispatch(setKey(response.key));
+  };
 };

View File

@@ -11,7 +11,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
     switch model {
     case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0301, globals.GPT3Turbo0613:
         return true
-    case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview:
+    case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview, globals.Dalle3:
         return user != nil && user.GetQuota(db) >= 5
     case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
         return user != nil && user.GetQuota(db) >= 50
@@ -23,7 +23,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
         return user != nil && user.GetQuota(db) >= 1
     case globals.QwenTurbo, globals.QwenPlus, globals.QwenPlusNet, globals.QwenTurboNet:
         return user != nil && user.GetQuota(db) >= 1
-    case globals.Midjourney, globals.StableDiffusion:
+    case globals.StableDiffusion, globals.Midjourney, globals.MidjourneyFast, globals.MidjourneyTurbo:
         return user != nil && user.GetQuota(db) >= 1
     case globals.LLaMa27B, globals.LLaMa213B, globals.LLaMa270B,
         globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B:

View File

@@ -56,7 +56,8 @@ const (
     GPT432k     = "gpt-4-32k"
     GPT432k0314 = "gpt-4-32k-0314"
     GPT432k0613 = "gpt-4-32k-0613"
-    Dalle2      = "dalle"
+    Dalle2      = "dall-e-2"
+    Dalle3      = "dall-e-3"
     Claude1     = "claude-1"
     Claude1100k = "claude-1.3"
     Claude2     = "claude-1-100k"
@@ -78,6 +79,8 @@ const (
     QwenTurboNet    = "qwen-turbo-net"
     QwenPlusNet     = "qwen-plus-net"
     Midjourney      = "midjourney"
+    MidjourneyFast  = "midjourney-fast"
+    MidjourneyTurbo = "midjourney-turbo"
     StableDiffusion = "stable-diffusion"
     LLaMa270B       = "llama-2-70b"
     LLaMa213B       = "llama-2-13b"
@@ -148,6 +151,12 @@ var QwenModelArray = []string{
     QwenPlusNet,
 }

+var MidjourneyModelArray = []string{
+    Midjourney,
+    MidjourneyFast,
+    MidjourneyTurbo,
+}
+
 var LongContextModelArray = []string{
     GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
     GPT41106Preview, GPT432k, GPT432k0314, GPT432k0613,
@@ -179,14 +188,14 @@ var AllModels = []string{
     GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
     GPT4, GPT40314, GPT40613, GPT4Vision, GPT4All, GPT41106Preview, GPT4Dalle,
     GPT432k, GPT432k0314, GPT432k0613,
-    Dalle2,
+    Dalle2, Dalle3,
     Claude1, Claude1100k, Claude2, Claude2100k, ClaudeSlack,
     SparkDesk, SparkDeskV2, SparkDeskV3,
     ChatBison001,
     BingCreative, BingBalanced, BingPrecise,
     ZhiPuChatGLMTurbo, ZhiPuChatGLMPro, ZhiPuChatGLMStd, ZhiPuChatGLMLite,
     QwenTurbo, QwenPlus, QwenTurboNet, QwenPlusNet,
-    Midjourney, StableDiffusion,
+    StableDiffusion, Midjourney, MidjourneyFast, MidjourneyTurbo,
     LLaMa270B, LLaMa213B, LLaMa27B,
     CodeLLaMa34B, CodeLLaMa13B, CodeLLaMa7B,
 }
@@ -213,7 +222,7 @@ func IsGPT3TurboModel(model string) bool {
 }

 func IsChatGPTModel(model string) bool {
-    return IsGPT3TurboModel(model) || IsGPT4Model(model)
+    return IsGPT3TurboModel(model) || IsGPT4Model(model) || IsDalleModel(model)
 }

 func IsClaudeModel(model string) bool {
@@ -224,6 +233,10 @@ func IsLLaMaModel(model string) bool {
     return in(model, LLaMaModelArray)
 }

+func IsDalleModel(model string) bool {
+    return model == Dalle2 || model == Dalle3
+}
+
 func IsClaude100KModel(model string) bool {
     return model == Claude1100k || model == Claude2100k
 }
@@ -252,6 +265,10 @@ func IsQwenModel(model string) bool {
     return in(model, QwenModelArray)
 }

+func IsMidjourneyModel(model string) bool {
+    return in(model, MidjourneyModelArray)
+}
+
 func IsLongContextModel(model string) bool {
     return in(model, LongContextModelArray)
 }
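A quick sketch of how the new helpers classify the added models (all constant names are taken from this diff):

package main

import (
    "fmt"

    "chat/globals"
)

func main() {
    fmt.Println(globals.IsMidjourneyModel(globals.MidjourneyTurbo)) // true
    fmt.Println(globals.IsDalleModel(globals.Dalle3))               // true
    // dall-e now routes through the chatgpt adapter, so it also counts as a chatgpt model:
    fmt.Println(globals.IsChatGPTModel(globals.Dalle2)) // true
}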

View File

@@ -1,6 +1,7 @@
 package main

 import (
+    "chat/adapter"
     "chat/addition"
     "chat/admin"
     "chat/auth"
@@ -29,6 +30,7 @@ func main() {
     {
         auth.Register(app)
         admin.Register(app)
+        adapter.Register(app)
         manager.Register(app)
         addition.Register(app)
         conversation.Register(app)

View File

@@ -3,6 +3,7 @@ package utils
 import (
     "fmt"
     "github.com/goccy/go-json"
+    "time"
 )

 func Sum[T int | int64 | float32 | float64](arr []T) T {
@@ -149,3 +150,7 @@ func EachNotNil[T any, U any](arr []T, f func(T) *U) []U {
     }
     return res
 }
+
+func Sleep(ms int) {
+    time.Sleep(time.Duration(ms) * time.Millisecond)
+}

View File

@@ -31,6 +31,19 @@ func SetInt(cache *redis.Client, key string, value int64, expiration int64) error {
     return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err()
 }

+func SetJson(cache *redis.Client, key string, value interface{}, expiration int64) error {
+    err := cache.Set(context.Background(), key, Marshal(value), time.Duration(expiration)*time.Second).Err()
+    return err
+}
+
+func GetJson[T any](cache *redis.Client, key string) *T {
+    val, err := cache.Get(context.Background(), key).Result()
+    if err != nil {
+        return nil
+    }
+    return UnmarshalForm[T](val)
+}
+
 func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool {
     // not exist
     if _, err := cache.Get(context.Background(), key).Result(); err != nil {

View File

@@ -69,7 +69,7 @@ func MapToStruct[T any](data interface{}) *T {
     }
 }

-func ToInt(value string) int {
+func ParseInt(value string) int {
     if res, err := strconv.Atoi(value); err == nil {
         return res
     } else {

View File

@@ -145,6 +145,12 @@ func CountOutputToken(model string, t int) float32 {
         return 0.25
     case globals.Midjourney:
         return 0.5
+    case globals.MidjourneyFast:
+        return 2
+    case globals.MidjourneyTurbo:
+        return 5
+    case globals.Dalle3:
+        return 5.6
     default:
         return 0
     }