mirror of
https://github.com/coaidev/coai.git
synced 2025-05-20 05:20:15 +09:00
add midjourney models
This commit is contained in:
parent
d084e544e6
commit
042f67fd74
@ -4,6 +4,7 @@ import (
|
||||
"chat/adapter/bing"
|
||||
"chat/adapter/claude"
|
||||
"chat/adapter/dashscope"
|
||||
"chat/adapter/midjourney"
|
||||
"chat/adapter/oneapi"
|
||||
"chat/adapter/palm2"
|
||||
"chat/adapter/slack"
|
||||
@ -61,6 +62,11 @@ func NewChatRequest(props *ChatProps, hook globals.Hook) error {
|
||||
Model: props.Model,
|
||||
Message: props.Message,
|
||||
}, hook)
|
||||
} else if globals.IsMidjourneyModel(props.Model) {
|
||||
return midjourney.NewChatInstanceFromConfig().CreateStreamChatRequest(&midjourney.ChatProps{
|
||||
Model: props.Model,
|
||||
Messages: props.Message,
|
||||
}, hook)
|
||||
}
|
||||
|
||||
return hook("Sorry, we cannot find the model you are looking for. Please try another model.")
|
||||
|
@ -70,7 +70,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
|
||||
|
||||
// CreateChatRequest is the native http request body for chatgpt
|
||||
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
||||
if props.Model == globals.Dalle2 {
|
||||
if globals.IsDalleModel(props.Model) {
|
||||
return c.CreateImage(props)
|
||||
}
|
||||
|
||||
@ -95,7 +95,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
||||
|
||||
// CreateStreamChatRequest is the stream response body for chatgpt
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||
if props.Model == globals.Dalle2 {
|
||||
if globals.IsDalleModel(props.Model) {
|
||||
if url, err := c.CreateImage(props); err != nil {
|
||||
return err
|
||||
} else {
|
||||
|
@ -1,12 +1,14 @@
|
||||
package chatgpt
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ImageProps struct {
|
||||
Model string
|
||||
Prompt string
|
||||
Size ImageSize
|
||||
}
|
||||
@ -20,8 +22,13 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
|
||||
res, err := utils.Post(
|
||||
c.GetImageEndpoint(),
|
||||
c.GetHeader(), ImageRequest{
|
||||
Model: props.Model,
|
||||
Prompt: props.Prompt,
|
||||
Size: utils.Multi[ImageSize](len(props.Size) == 0, ImageSize512, props.Size),
|
||||
Size: utils.Multi[ImageSize](
|
||||
props.Model == globals.Dalle3,
|
||||
ImageSize1024,
|
||||
ImageSize512,
|
||||
),
|
||||
N: 1,
|
||||
})
|
||||
if err != nil || res == nil {
|
||||
@ -41,6 +48,7 @@ func (c *ChatInstance) CreateImageRequest(props ImageProps) (string, error) {
|
||||
// CreateImage will create a dalle image from prompt, return markdown of image
|
||||
func (c *ChatInstance) CreateImage(props *ChatProps) (string, error) {
|
||||
url, err := c.CreateImageRequest(ImageProps{
|
||||
Model: props.Model,
|
||||
Prompt: c.GetLatestPrompt(props),
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -52,7 +52,7 @@ func NewChatInstanceFromModel(props *InstanceProps) *ChatInstance {
|
||||
globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
||||
return NewChatInstanceFromConfig("gpt4")
|
||||
|
||||
case globals.GPT4Vision, globals.GPT4Dalle, globals.GPT4All:
|
||||
case globals.GPT4Vision, globals.GPT4Dalle, globals.Dalle3, globals.GPT4All:
|
||||
return NewChatInstanceFromConfig("reverse")
|
||||
|
||||
case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0613, globals.GPT3Turbo0301,
|
||||
|
@ -89,6 +89,7 @@ type ImageSize string
|
||||
|
||||
// ImageRequest is the request body for chatgpt dalle image generation
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Size ImageSize `json:"size"`
|
||||
N int `json:"n"`
|
||||
|
101
adapter/midjourney/api.go
Normal file
101
adapter/midjourney/api.go
Normal file
@ -0,0 +1,101 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (c *ChatInstance) GetImagineUrl() string {
|
||||
return fmt.Sprintf("%s/mj/submit/imagine", c.GetEndpoint())
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateImagineRequest(prompt string) (*ImagineResponse, error) {
|
||||
res, err := utils.Post(
|
||||
c.GetImagineUrl(),
|
||||
map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"mj-api-secret": c.GetApiSecret(),
|
||||
},
|
||||
ImagineRequest{
|
||||
NotifyHook: fmt.Sprintf("%s/mj/notify", viper.GetString("midjourney.expose")),
|
||||
Prompt: prompt,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return utils.MapToStruct[ImagineResponse](res), nil
|
||||
}
|
||||
|
||||
func getStatusCode(code int) error {
|
||||
switch code {
|
||||
case SuccessCode, QueueCode:
|
||||
return nil
|
||||
case ExistedCode:
|
||||
return fmt.Errorf("task is existed, please try again later with another prompt")
|
||||
case MaxQueueCode:
|
||||
return fmt.Errorf("task queue is full, please try again later")
|
||||
case NudeCode:
|
||||
return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected")
|
||||
default:
|
||||
return fmt.Errorf("unknown error from midjourney")
|
||||
}
|
||||
}
|
||||
|
||||
func getProgress(value string) int {
|
||||
progress := strings.TrimSuffix(value, "%")
|
||||
return utils.ParseInt(progress)
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateStreamImagineTask(prompt string, hook func(progress int) error) (string, error) {
|
||||
res, err := c.CreateImagineRequest(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := getStatusCode(res.Code); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
task := res.Result
|
||||
progress := -1
|
||||
|
||||
for {
|
||||
utils.Sleep(100)
|
||||
form := getStorage(task)
|
||||
if form == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch form.Status {
|
||||
case Success:
|
||||
if err := hook(100); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return form.Url, nil
|
||||
case Failure:
|
||||
if err := hook(100); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "", fmt.Errorf("task failed: %s", form.FailReason)
|
||||
case InProgress:
|
||||
current := getProgress(form.Progress)
|
||||
if progress != current {
|
||||
if err := hook(current); err != nil {
|
||||
return "", err
|
||||
}
|
||||
progress = current
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateImagineTask(prompt string) (string, error) {
|
||||
return c.CreateStreamImagineTask(prompt, func(progress int) error {
|
||||
return nil
|
||||
})
|
||||
}
|
79
adapter/midjourney/chat.go
Normal file
79
adapter/midjourney/chat.go
Normal file
@ -0,0 +1,79 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ChatProps struct {
|
||||
Messages []globals.Message
|
||||
Model string
|
||||
}
|
||||
|
||||
func getMode(model string) string {
|
||||
switch model {
|
||||
case globals.Midjourney: // relax
|
||||
return RelaxMode
|
||||
case globals.MidjourneyFast: // fast
|
||||
return FastMode
|
||||
case globals.MidjourneyTurbo: // turbo
|
||||
return TurboMode
|
||||
default:
|
||||
return RelaxMode
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
|
||||
arr := strings.Split(strings.TrimSpace(prompt), " ")
|
||||
var res []string
|
||||
|
||||
for _, word := range arr {
|
||||
if utils.Contains[string](word, ModeArr) {
|
||||
continue
|
||||
}
|
||||
res = append(res, word)
|
||||
}
|
||||
|
||||
res = append(res, getMode(model))
|
||||
target := strings.Join(res, " ")
|
||||
return target
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetPrompt(props *ChatProps) string {
|
||||
return c.GetCleanPrompt(props.Model, props.Messages[len(props.Messages)-1].Content)
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||
// partial response like:
|
||||
// ```progress
|
||||
// 0
|
||||
// ...
|
||||
// 100
|
||||
// ```
|
||||
// 
|
||||
|
||||
prompt := c.GetPrompt(props)
|
||||
if prompt == "" {
|
||||
return fmt.Errorf("format error: please provide available prompt")
|
||||
}
|
||||
|
||||
if err := callback("```progress\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
url, err := c.CreateStreamImagineTask(prompt, func(progress int) error {
|
||||
return callback(fmt.Sprintf("%d\n", progress))
|
||||
})
|
||||
|
||||
if err := callback("```\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from midjourney: %s", err.Error())
|
||||
}
|
||||
|
||||
return callback(utils.GetImageMarkdown(url))
|
||||
}
|
51
adapter/midjourney/expose.go
Normal file
51
adapter/midjourney/expose.go
Normal file
@ -0,0 +1,51 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/spf13/viper"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func InWhiteList(ip string) bool {
|
||||
arr := strings.Split(viper.GetString("midjourney.white_list"), ",")
|
||||
return utils.Contains[string](ip, arr)
|
||||
}
|
||||
|
||||
func NotifyAPI(c *gin.Context) {
|
||||
if !InWhiteList(c.ClientIP()) {
|
||||
fmt.Println(fmt.Sprintf("[midjourney] notify api: banned request from %s", c.ClientIP()))
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
var form NotifyForm
|
||||
if err := c.ShouldBindJSON(&form); err != nil {
|
||||
c.AbortWithStatus(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// fmt.Println(fmt.Sprintf("[midjourney] notify api: get notify: %s (from: %s)", utils.Marshal(form), c.ClientIP()))
|
||||
|
||||
if !utils.Contains(form.Status, []string{InProgress, Success, Failure}) {
|
||||
// ignore
|
||||
return
|
||||
}
|
||||
|
||||
reason, ok := form.FailReason.(string)
|
||||
if !ok {
|
||||
reason = "unknown"
|
||||
}
|
||||
|
||||
err := setStorage(form.Id, StorageForm{
|
||||
Url: form.ImageUrl,
|
||||
FailReason: reason,
|
||||
Progress: form.Progress,
|
||||
Status: form.Status,
|
||||
})
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": err == nil,
|
||||
})
|
||||
}
|
19
adapter/midjourney/storage.go
Normal file
19
adapter/midjourney/storage.go
Normal file
@ -0,0 +1,19 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"chat/connection"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func getTaskName(task string) string {
|
||||
return fmt.Sprintf("nio:mj-task:%s", task)
|
||||
}
|
||||
|
||||
func setStorage(task string, form StorageForm) error {
|
||||
return utils.SetJson(connection.Cache, getTaskName(task), form, 60*60)
|
||||
}
|
||||
|
||||
func getStorage(task string) *StorageForm {
|
||||
return utils.GetJson[StorageForm](connection.Cache, getTaskName(task))
|
||||
}
|
32
adapter/midjourney/struct.go
Normal file
32
adapter/midjourney/struct.go
Normal file
@ -0,0 +1,32 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type ChatInstance struct {
|
||||
Endpoint string
|
||||
ApiSecret string
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetApiSecret() string {
|
||||
return c.ApiSecret
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetEndpoint() string {
|
||||
return c.Endpoint
|
||||
}
|
||||
|
||||
func NewChatInstance(endpoint string, apiSecret string) *ChatInstance {
|
||||
return &ChatInstance{
|
||||
Endpoint: endpoint,
|
||||
ApiSecret: apiSecret,
|
||||
}
|
||||
}
|
||||
|
||||
func NewChatInstanceFromConfig() *ChatInstance {
|
||||
return NewChatInstance(
|
||||
viper.GetString("midjourney.endpoint"),
|
||||
viper.GetString("midjourney.api_secret"),
|
||||
)
|
||||
}
|
58
adapter/midjourney/types.go
Normal file
58
adapter/midjourney/types.go
Normal file
@ -0,0 +1,58 @@
|
||||
package midjourney
|
||||
|
||||
const (
|
||||
SuccessCode = 1
|
||||
ExistedCode = 21
|
||||
QueueCode = 22
|
||||
MaxQueueCode = 23
|
||||
NudeCode = 24
|
||||
)
|
||||
|
||||
const (
|
||||
NotStartStatus = "NOT_START"
|
||||
Submitted = "SUBMITTED"
|
||||
InProgress = "IN_PROGRESS"
|
||||
Failure = "FAILURE"
|
||||
Success = "SUCCESS"
|
||||
)
|
||||
|
||||
const (
|
||||
TurboMode = "--turbo"
|
||||
FastMode = "--fast"
|
||||
RelaxMode = "--relax"
|
||||
)
|
||||
|
||||
var ModeArr = []string{TurboMode, FastMode, RelaxMode}
|
||||
|
||||
type ImagineRequest struct {
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type ImagineResponse struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
type NotifyForm struct {
|
||||
Id string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
Status string `json:"status"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
Progress string `json:"progress"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
FailReason interface{} `json:"failReason"`
|
||||
}
|
||||
|
||||
type StorageForm struct {
|
||||
Url string `json:"url"`
|
||||
FailReason string `json:"failReason"`
|
||||
Progress string `json:"progress"`
|
||||
Status string `json:"status"`
|
||||
}
|
10
adapter/router.go
Normal file
10
adapter/router.go
Normal file
@ -0,0 +1,10 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"chat/adapter/midjourney"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func Register(app *gin.Engine) {
|
||||
app.POST("/mj/notify", midjourney.NotifyAPI)
|
||||
}
|
@ -31,12 +31,7 @@ function ChatInterface({ setTarget }: ChatInterfaceProps) {
|
||||
if (!ref.current) return;
|
||||
const el = ref.current as HTMLDivElement;
|
||||
|
||||
const event = () => {
|
||||
setScrollable(
|
||||
el.scrollTop + el.clientHeight + 20 >= el.scrollHeight, // at bottom
|
||||
);
|
||||
}
|
||||
|
||||
const event = () => setScrollable(el.scrollTop + el.clientHeight + 20 >= el.scrollHeight);
|
||||
return addEventListeners(el, [
|
||||
"scroll", "scrollend",
|
||||
"resize", "touchend",
|
||||
|
@ -8,7 +8,7 @@ import {
|
||||
} from "@/utils/env.ts";
|
||||
import { getMemory } from "@/utils/memory.ts";
|
||||
|
||||
export const version = "3.6.19";
|
||||
export const version = "3.6.20";
|
||||
export const dev: boolean = getDev();
|
||||
export const deploy: boolean = true;
|
||||
export let rest_api: string = getRestApi(deploy);
|
||||
@ -61,7 +61,12 @@ export const supportModels: Model[] = [
|
||||
{ id: "chat-bison-001", name: "Palm2", free: true, auth: true },
|
||||
|
||||
// dalle models
|
||||
{ id: "dalle", name: "DALLE2", free: true, auth: true },
|
||||
{ id: "dall-e-3", name: "DALLE 3", free: false, auth: true },
|
||||
{ id: "dall-e-2", name: "DALLE 2", free: true, auth: true },
|
||||
|
||||
{ id: "midjourney", name: "Midjourney", free: false, auth: true },
|
||||
{ id: "midjourney-fast", name: "Midjourney Fast", free: false, auth: true },
|
||||
{ id: "midjourney-turbo", name: "Midjourney Turbo", free: false, auth: true },
|
||||
|
||||
// reverse models
|
||||
{ id: "gpt-4-v", name: "GPT-4 Vision", free: false, auth: true },
|
||||
@ -96,7 +101,11 @@ export const planModels = [
|
||||
"claude-2-100k",
|
||||
];
|
||||
|
||||
export const expensiveModels = ["gpt-4-32k-0613"];
|
||||
export const expensiveModels = [
|
||||
"dall-e-3",
|
||||
"midjourney-turbo",
|
||||
"gpt-4-32k-0613",
|
||||
];
|
||||
|
||||
export function login() {
|
||||
location.href = `https://deeptrain.net/login?app=${dev ? "dev" : "chatnio"}`;
|
||||
|
@ -22,7 +22,6 @@ export const apiSlice = createSlice({
|
||||
state.dialog = false;
|
||||
},
|
||||
setKey: (state, action) => {
|
||||
if (!action.payload.length) return;
|
||||
state.key = action.payload as string;
|
||||
},
|
||||
},
|
||||
@ -35,7 +34,13 @@ export default apiSlice.reducer;
|
||||
export const dialogSelector = (state: RootState): boolean => state.api.dialog;
|
||||
export const keySelector = (state: RootState): string => state.api.key;
|
||||
|
||||
export const getApiKey = async (dispatch: AppDispatch) => {
|
||||
export const getApiKey = async (dispatch: AppDispatch, retries?: boolean) => {
|
||||
const response = await getKey();
|
||||
if (response.status) dispatch(setKey(response.key));
|
||||
if (response.status) {
|
||||
if (response.key.length === 0 && retries !== false) {
|
||||
await getApiKey(dispatch, false);
|
||||
return;
|
||||
}
|
||||
dispatch(setKey(response.key));
|
||||
};
|
||||
};
|
||||
|
@ -11,7 +11,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
||||
switch model {
|
||||
case globals.GPT3Turbo, globals.GPT3TurboInstruct, globals.GPT3Turbo0301, globals.GPT3Turbo0613:
|
||||
return true
|
||||
case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview:
|
||||
case globals.GPT4, globals.GPT40613, globals.GPT40314, globals.GPT41106Preview, globals.Dalle3:
|
||||
return user != nil && user.GetQuota(db) >= 5
|
||||
case globals.GPT432k, globals.GPT432k0613, globals.GPT432k0314:
|
||||
return user != nil && user.GetQuota(db) >= 50
|
||||
@ -23,7 +23,7 @@ func CanEnableModel(db *sql.DB, user *User, model string) bool {
|
||||
return user != nil && user.GetQuota(db) >= 1
|
||||
case globals.QwenTurbo, globals.QwenPlus, globals.QwenPlusNet, globals.QwenTurboNet:
|
||||
return user != nil && user.GetQuota(db) >= 1
|
||||
case globals.Midjourney, globals.StableDiffusion:
|
||||
case globals.StableDiffusion, globals.Midjourney, globals.MidjourneyFast, globals.MidjourneyTurbo:
|
||||
return user != nil && user.GetQuota(db) >= 1
|
||||
case globals.LLaMa27B, globals.LLaMa213B, globals.LLaMa270B,
|
||||
globals.CodeLLaMa34B, globals.CodeLLaMa13B, globals.CodeLLaMa7B:
|
||||
|
@ -56,7 +56,8 @@ const (
|
||||
GPT432k = "gpt-4-32k"
|
||||
GPT432k0314 = "gpt-4-32k-0314"
|
||||
GPT432k0613 = "gpt-4-32k-0613"
|
||||
Dalle2 = "dalle"
|
||||
Dalle2 = "dall-e-2"
|
||||
Dalle3 = "dall-e-3"
|
||||
Claude1 = "claude-1"
|
||||
Claude1100k = "claude-1.3"
|
||||
Claude2 = "claude-1-100k"
|
||||
@ -78,6 +79,8 @@ const (
|
||||
QwenTurboNet = "qwen-turbo-net"
|
||||
QwenPlusNet = "qwen-plus-net"
|
||||
Midjourney = "midjourney"
|
||||
MidjourneyFast = "midjourney-fast"
|
||||
MidjourneyTurbo = "midjourney-turbo"
|
||||
StableDiffusion = "stable-diffusion"
|
||||
LLaMa270B = "llama-2-70b"
|
||||
LLaMa213B = "llama-2-13b"
|
||||
@ -148,6 +151,12 @@ var QwenModelArray = []string{
|
||||
QwenPlusNet,
|
||||
}
|
||||
|
||||
var MidjourneyModelArray = []string{
|
||||
Midjourney,
|
||||
MidjourneyFast,
|
||||
MidjourneyTurbo,
|
||||
}
|
||||
|
||||
var LongContextModelArray = []string{
|
||||
GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
|
||||
GPT41106Preview, GPT432k, GPT432k0314, GPT432k0613,
|
||||
@ -179,14 +188,14 @@ var AllModels = []string{
|
||||
GPT3Turbo16k, GPT3Turbo16k0613, GPT3Turbo16k0301,
|
||||
GPT4, GPT40314, GPT40613, GPT4Vision, GPT4All, GPT41106Preview, GPT4Dalle,
|
||||
GPT432k, GPT432k0314, GPT432k0613,
|
||||
Dalle2,
|
||||
Dalle2, Dalle3,
|
||||
Claude1, Claude1100k, Claude2, Claude2100k, ClaudeSlack,
|
||||
SparkDesk, SparkDeskV2, SparkDeskV3,
|
||||
ChatBison001,
|
||||
BingCreative, BingBalanced, BingPrecise,
|
||||
ZhiPuChatGLMTurbo, ZhiPuChatGLMPro, ZhiPuChatGLMStd, ZhiPuChatGLMLite,
|
||||
QwenTurbo, QwenPlus, QwenTurboNet, QwenPlusNet,
|
||||
Midjourney, StableDiffusion,
|
||||
StableDiffusion, Midjourney, MidjourneyFast, MidjourneyTurbo,
|
||||
LLaMa270B, LLaMa213B, LLaMa27B,
|
||||
CodeLLaMa34B, CodeLLaMa13B, CodeLLaMa7B,
|
||||
}
|
||||
@ -213,7 +222,7 @@ func IsGPT3TurboModel(model string) bool {
|
||||
}
|
||||
|
||||
func IsChatGPTModel(model string) bool {
|
||||
return IsGPT3TurboModel(model) || IsGPT4Model(model)
|
||||
return IsGPT3TurboModel(model) || IsGPT4Model(model) || IsDalleModel(model)
|
||||
}
|
||||
|
||||
func IsClaudeModel(model string) bool {
|
||||
@ -224,6 +233,10 @@ func IsLLaMaModel(model string) bool {
|
||||
return in(model, LLaMaModelArray)
|
||||
}
|
||||
|
||||
func IsDalleModel(model string) bool {
|
||||
return model == Dalle2 || model == Dalle3
|
||||
}
|
||||
|
||||
func IsClaude100KModel(model string) bool {
|
||||
return model == Claude1100k || model == Claude2100k
|
||||
}
|
||||
@ -252,6 +265,10 @@ func IsQwenModel(model string) bool {
|
||||
return in(model, QwenModelArray)
|
||||
}
|
||||
|
||||
func IsMidjourneyModel(model string) bool {
|
||||
return in(model, MidjourneyModelArray)
|
||||
}
|
||||
|
||||
func IsLongContextModel(model string) bool {
|
||||
return in(model, LongContextModelArray)
|
||||
}
|
||||
|
2
main.go
2
main.go
@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"chat/adapter"
|
||||
"chat/addition"
|
||||
"chat/admin"
|
||||
"chat/auth"
|
||||
@ -29,6 +30,7 @@ func main() {
|
||||
{
|
||||
auth.Register(app)
|
||||
admin.Register(app)
|
||||
adapter.Register(app)
|
||||
manager.Register(app)
|
||||
addition.Register(app)
|
||||
conversation.Register(app)
|
||||
|
@ -3,6 +3,7 @@ package utils
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/goccy/go-json"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Sum[T int | int64 | float32 | float64](arr []T) T {
|
||||
@ -149,3 +150,7 @@ func EachNotNil[T any, U any](arr []T, f func(T) *U) []U {
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func Sleep(ms int) {
|
||||
time.Sleep(time.Duration(ms) * time.Millisecond)
|
||||
}
|
||||
|
@ -31,6 +31,19 @@ func SetInt(cache *redis.Client, key string, value int64, expiration int64) erro
|
||||
return cache.Set(context.Background(), key, value, time.Duration(expiration)*time.Second).Err()
|
||||
}
|
||||
|
||||
func SetJson(cache *redis.Client, key string, value interface{}, expiration int64) error {
|
||||
err := cache.Set(context.Background(), key, Marshal(value), time.Duration(expiration)*time.Second).Err()
|
||||
return err
|
||||
}
|
||||
|
||||
func GetJson[T any](cache *redis.Client, key string) *T {
|
||||
val, err := cache.Get(context.Background(), key).Result()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return UnmarshalForm[T](val)
|
||||
}
|
||||
|
||||
func IncrWithLimit(cache *redis.Client, key string, delta int64, limit int64, expiration int64) bool {
|
||||
// not exist
|
||||
if _, err := cache.Get(context.Background(), key).Result(); err != nil {
|
||||
|
@ -69,7 +69,7 @@ func MapToStruct[T any](data interface{}) *T {
|
||||
}
|
||||
}
|
||||
|
||||
func ToInt(value string) int {
|
||||
func ParseInt(value string) int {
|
||||
if res, err := strconv.Atoi(value); err == nil {
|
||||
return res
|
||||
} else {
|
||||
|
@ -145,6 +145,12 @@ func CountOutputToken(model string, t int) float32 {
|
||||
return 0.25
|
||||
case globals.Midjourney:
|
||||
return 0.5
|
||||
case globals.MidjourneyFast:
|
||||
return 2
|
||||
case globals.MidjourneyTurbo:
|
||||
return 5
|
||||
case globals.Dalle3:
|
||||
return 5.6
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user