feat: support midjourney upscale / variation / reroll actions (#40)

This commit is contained in:
Zhang Minghan 2024-02-02 09:01:12 +08:00
parent 0c44eed4ec
commit 1f52d56e46
27 changed files with 467 additions and 135 deletions

View File

@ -69,7 +69,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
// CreateChatRequest is the native http request body for chatgpt
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
if globals.IsDalleModel(props.Model) {
if globals.IsOpenAIDalleModel(props.Model) {
return c.CreateImage(props)
}
@ -94,7 +94,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
// CreateStreamChatRequest is the stream response body for chatgpt
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
if globals.IsDalleModel(props.Model) {
if globals.IsOpenAIDalleModel(props.Model) {
if url, err := c.CreateImage(props); err != nil {
return err
} else {

View File

@ -35,7 +35,7 @@ func formatMessages(props *ChatProps) interface{} {
}
props.Message[len(props.Message)-1].Content = base
return props.Message
} else if globals.IsGPT41106VisionPreview(props.Model) {
} else if globals.IsOpenAIVisionModels(props.Model) {
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
if message.Role == globals.User {
urls := utils.ExtractImageUrls(message.Content)

View File

@ -70,7 +70,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
// CreateChatRequest is the native http request body for chatgpt
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
if globals.IsDalleModel(props.Model) {
if globals.IsOpenAIDalleModel(props.Model) {
return c.CreateImage(props)
}
@ -95,7 +95,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
// CreateStreamChatRequest is the stream response body for chatgpt
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
if globals.IsDalleModel(props.Model) {
if globals.IsOpenAIDalleModel(props.Model) {
if url, err := c.CreateImage(props); err != nil {
return err
} else {

View File

@ -35,7 +35,7 @@ func formatMessages(props *ChatProps) interface{} {
}
props.Message[len(props.Message)-1].Content = base
return props.Message
} else if globals.IsGPT41106VisionPreview(props.Model) {
} else if globals.IsOpenAIVisionModels(props.Model) {
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
if message.Role == globals.User {
urls := utils.ExtractImageUrls(message.Content)

View File

@ -1,118 +1,59 @@
package midjourney
import (
"chat/globals"
"chat/utils"
"fmt"
"strings"
)
var midjourneyEmptySecret = "null"
func (c *ChatInstance) GetImagineUrl() string {
func (c *ChatInstance) GetImagineEndpoint() string {
return fmt.Sprintf("%s/mj/submit/imagine", c.GetEndpoint())
}
func (c *ChatInstance) GetImagineHeaders() map[string]string {
secret := c.GetApiSecret()
if secret == "" || secret == midjourneyEmptySecret {
return map[string]string{
"Content-Type": "application/json",
}
}
func (c *ChatInstance) GetChangeEndpoint() string {
return fmt.Sprintf("%s/mj/submit/change", c.GetEndpoint())
}
return map[string]string{
"Content-Type": "application/json",
"mj-api-secret": secret,
func (c *ChatInstance) GetImagineRequest(prompt string) *ImagineRequest {
return &ImagineRequest{
NotifyHook: c.GetNotifyEndpoint(),
Prompt: prompt,
}
}
func (c *ChatInstance) CreateImagineRequest(prompt string) (*ImagineResponse, error) {
func (c *ChatInstance) GetChangeRequest(action string, task string, index *int) *ChangeRequest {
return &ChangeRequest{
NotifyHook: c.GetNotifyEndpoint(),
Action: action,
Index: index,
TaskId: task,
}
}
func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, error) {
res, err := utils.Post(
c.GetImagineUrl(),
c.GetImagineHeaders(),
ImagineRequest{
NotifyHook: fmt.Sprintf(
"%s/mj/notify",
globals.NotifyUrl,
),
Prompt: prompt,
},
c.GetImagineEndpoint(),
c.GetMidjourneyHeaders(),
c.GetImagineRequest(prompt),
)
if err != nil {
return nil, err
}
return utils.MapToStruct[ImagineResponse](res), nil
return utils.MapToStruct[CommonResponse](res), nil
}
func getStatusCode(response *ImagineResponse) error {
code := response.Code
switch code {
case SuccessCode, QueueCode:
return nil
case ExistedCode:
return fmt.Errorf("task is existed, please try again later with another prompt")
case MaxQueueCode:
return fmt.Errorf("task queue is full, please try again later")
case NudeCode:
return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected")
default:
return fmt.Errorf(fmt.Sprintf("unknown error from midjourney (code: %d, description: %s)", code, response.Description))
}
}
func (c *ChatInstance) CreateChangeRequest(action string, task string, index *int) (*CommonResponse, error) {
res, err := utils.Post(
c.GetChangeEndpoint(),
c.GetMidjourneyHeaders(),
c.GetChangeRequest(action, task, index),
)
func getProgress(value string) int {
progress := strings.TrimSuffix(value, "%")
return utils.ParseInt(progress)
}
func (c *ChatInstance) CreateStreamImagineTask(prompt string, hook func(progress int) error) (string, error) {
res, err := c.CreateImagineRequest(prompt)
if err != nil {
return "", err
return nil, err
}
if err := getStatusCode(res); err != nil {
return "", err
}
task := res.Result
progress := -1
for {
utils.Sleep(100)
form := getStorage(task)
if form == nil {
continue
}
switch form.Status {
case Success:
if err := hook(100); err != nil {
return "", err
}
return form.Url, nil
case Failure:
if err := hook(100); err != nil {
return "", err
}
return "", fmt.Errorf("task failed: %s", form.FailReason)
case InProgress:
current := getProgress(form.Progress)
if progress != current {
if err := hook(current); err != nil {
return "", err
}
progress = current
}
}
}
}
func (c *ChatInstance) CreateImagineTask(prompt string) (string, error) {
return c.CreateStreamImagineTask(prompt, func(progress int) error {
return nil
})
fmt.Println(res)
return utils.MapToStruct[CommonResponse](res), nil
}

View File

@ -7,6 +7,21 @@ import (
"strings"
)
const maxActions = 4
const (
ImagineAction = "IMAGINE"
UpscaleAction = "UPSCALE"
VariationAction = "VARIATION"
RerollAction = "REROLL"
)
const (
ImagineCommand = "/IMAGINE"
UpscaleCommand = "/UPSCALE"
VariationCommand = "/VARIATION"
RerollCommand = "/REROLL"
)
type ChatProps struct {
Messages []globals.Message
Model string
@ -30,7 +45,7 @@ func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
var res []string
for _, word := range arr {
if utils.Contains[string](word, ModeArr) {
if utils.Contains[string](word, RendererMode) {
continue
}
res = append(res, word)
@ -54,28 +69,69 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global
// ```
// ![image](...)
prompt := c.GetPrompt(props)
if prompt == "" {
action, prompt := c.ExtractPrompt(c.GetPrompt(props))
if len(prompt) == 0 {
return fmt.Errorf("format error: please provide available prompt")
}
url, err := c.CreateStreamImagineTask(prompt, func(progress int) error {
var begin bool
form, err := c.CreateStreamTask(action, prompt, func(form *StorageForm, progress int) error {
if progress == 0 {
begin = true
if err := callback("```progress\n"); err != nil {
return err
}
} else if progress == 100 {
} else if progress == 100 && !begin {
if err := callback("```progress\n"); err != nil {
return err
}
}
if err := callback(fmt.Sprintf("%d\n", progress)); err != nil {
return err
}
if progress == 100 {
if err := callback("```\n"); err != nil {
return err
}
}
return callback(fmt.Sprintf("%d\n", progress))
return nil
})
if err != nil {
return fmt.Errorf("error from midjourney: %s", err.Error())
}
return callback(utils.GetImageMarkdown(url))
if err := callback(utils.GetImageMarkdown(form.Url)); err != nil {
return err
}
return c.CallbackActions(form, callback)
}
func toVirtualMessage(message string) string {
return "https://chatnio.virtual" + strings.Replace(message, " ", "-", -1)
}
func (c *ChatInstance) CallbackActions(form *StorageForm, callback globals.Hook) error {
if form.Action == UpscaleAction {
return nil
}
actions := utils.Range(1, maxActions+1)
upscale := strings.Join(utils.Each(actions, func(index int) string {
return fmt.Sprintf("[U%d](%s)", index, toVirtualMessage(fmt.Sprintf("/UPSCALE %s %d", form.Task, index)))
}), " ")
variation := strings.Join(utils.Each(actions, func(index int) string {
return fmt.Sprintf("[V%d](%s)", index, toVirtualMessage(fmt.Sprintf("/VARIATION %s %d", form.Task, index)))
}), " ")
reroll := fmt.Sprintf("[REROLL](%s)", toVirtualMessage(fmt.Sprintf("/REROLL %s", form.Task)))
return callback(fmt.Sprintf("\n\n%s\n\n%s\n\n%s\n", upscale, variation, reroll))
}

View File

@ -55,6 +55,8 @@ func NotifyAPI(c *gin.Context) {
}
err := setStorage(form.Id, StorageForm{
Task: form.Id,
Action: form.Action,
Url: form.ImageUrl,
FailReason: reason,
Progress: form.Progress,

View File

@ -0,0 +1,120 @@
package midjourney
import (
"chat/utils"
"fmt"
"strings"
)
func getStatusCode(action string, response *CommonResponse) error {
code := response.Code
switch code {
case SuccessCode, QueueCode:
return nil
case ExistedCode:
if action != ImagineCommand {
return nil
}
return fmt.Errorf("task is existed, please try again later with another prompt")
case MaxQueueCode:
return fmt.Errorf("task queue is full, please try again later")
case NudeCode:
return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected")
default:
return fmt.Errorf(fmt.Sprintf("unknown error from midjourney (code: %d, description: %s)", code, response.Description))
}
}
func getProgress(value string) int {
progress := strings.TrimSuffix(value, "%")
return utils.ParseInt(progress)
}
func (c *ChatInstance) GetAction(command string) string {
return strings.TrimLeft(command, "/")
}
func (c *ChatInstance) ExtractPrompt(input string) (action string, prompt string) {
segment := utils.SafeSplit(input, " ", 2)
action = strings.TrimSpace(segment[0])
prompt = strings.TrimSpace(segment[1])
switch action {
case ImagineCommand, VariationCommand, UpscaleCommand, RerollCommand:
return
default:
return ImagineCommand, strings.TrimSpace(input)
}
}
func (c *ChatInstance) ExtractCommand(input string) (task string, index *int) {
segment := utils.SafeSplit(input, " ", 2)
task = strings.TrimSpace(segment[0])
if segment[1] != "" {
data := segment[1]
slice := strings.Split(segment[1], " ")
if len(slice) > 1 {
data = slice[0]
}
index = utils.ToPtr(utils.ParseInt(strings.TrimSpace(data)))
}
return
}
func (c *ChatInstance) CreateRequest(action string, prompt string) (*CommonResponse, error) {
switch action {
case ImagineCommand:
return c.CreateImagineRequest(prompt)
case VariationCommand, UpscaleCommand, RerollCommand:
task, index := c.ExtractCommand(prompt)
return c.CreateChangeRequest(c.GetAction(action), task, index)
default:
return nil, fmt.Errorf("unknown action: %s", action)
}
}
func (c *ChatInstance) CreateStreamTask(action string, prompt string, hook func(form *StorageForm, progress int) error) (*StorageForm, error) {
res, err := c.CreateRequest(action, prompt)
if err != nil {
return nil, err
}
if err := getStatusCode(action, res); err != nil {
return nil, err
}
task := res.Result
progress := -1
for {
utils.Sleep(100)
form := getStorage(task)
if form == nil {
continue
}
switch form.Status {
case Success:
if err := hook(form, 100); err != nil {
return nil, err
}
return form, nil
case Failure:
return nil, fmt.Errorf("task failed: %s", form.FailReason)
case InProgress:
current := getProgress(form.Progress)
if progress != current {
if err := hook(form, current); err != nil {
return nil, err
}
progress = current
}
}
}
}

View File

@ -2,8 +2,11 @@ package midjourney
import (
"chat/globals"
"fmt"
)
var midjourneyEmptySecret = "null"
type ChatInstance struct {
Endpoint string
ApiSecret string
@ -17,6 +20,24 @@ func (c *ChatInstance) GetEndpoint() string {
return c.Endpoint
}
func (c *ChatInstance) GetMidjourneyHeaders() map[string]string {
secret := c.GetApiSecret()
if secret == "" || secret == midjourneyEmptySecret {
return map[string]string{
"Content-Type": "application/json",
}
}
return map[string]string{
"Content-Type": "application/json",
"mj-api-secret": secret,
}
}
func (c *ChatInstance) GetNotifyEndpoint() string {
return fmt.Sprintf("%s/mj/notify", globals.NotifyUrl)
}
func NewChatInstance(endpoint, apiSecret, whiteList string) *ChatInstance {
SaveWhiteList(whiteList)

View File

@ -22,22 +22,29 @@ const (
RelaxMode = "--relax"
)
var ModeArr = []string{TurboMode, FastMode, RelaxMode}
var RendererMode = []string{TurboMode, FastMode, RelaxMode}
type ImagineHeader struct {
type CommonHeader struct {
ContentType string `json:"Content-Type"`
MjApiSecret string `json:"mj-api-secret,omitempty"`
}
type CommonResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result string `json:"result"`
}
type ImagineRequest struct {
NotifyHook string `json:"notifyHook"`
Prompt string `json:"prompt"`
}
type ImagineResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Result string `json:"result"`
type ChangeRequest struct {
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index *int `json:"index,omitempty"`
TaskId string `json:"taskId"`
}
type NotifyForm struct {
@ -56,6 +63,8 @@ type NotifyForm struct {
}
type StorageForm struct {
Task string `json:"task"`
Action string `json:"action"`
Url string `json:"url"`
FailReason string `json:"failReason"`
Progress string `json:"progress"`

View File

@ -105,3 +105,37 @@
}
}
}
.virtual-prompt {
text-align: center;
border-radius: var(--radius);
border: 1px solid hsl(var(--border));
padding: 0.5rem;
}
.virtual-action {
svg {
transform: translateY(1px);
}
}
p:has(.virtual-action) {
display: flex;
flex-direction: row;
flex-wrap: wrap;
align-items: center;
justify-content: center;
margin-top: 1.25rem !important;
@media (min-width: 668px) {
.virtual-action {
margin-right: auto;
&:last-child {
margin-right: 0;
}
}
}
}

View File

@ -198,7 +198,7 @@
box-sizing: content-box;
background-color: var(--color-canvas-default);
border-radius: 4px;
max-height: 35vh;
max-height: 50vh;
}
.markdown-body code,

View File

@ -18,7 +18,10 @@ import {
Codesandbox,
Copy,
Github,
Maximize,
RefreshCcwDot,
Twitter,
Wand2,
Youtube,
} from "lucide-react";
import { copyClipboard } from "@/utils/dom.ts";
@ -26,6 +29,18 @@ import { useToast } from "./ui/use-toast.ts";
import { useTranslation } from "react-i18next";
import { parseProgressbar } from "@/components/plugins/progress.tsx";
import { cn } from "@/components/ui/lib/utils.ts";
import { Button } from "@/components/ui/button.tsx";
import {
Dialog,
DialogContent,
DialogDescription,
DialogFooter,
DialogHeader,
DialogTitle,
DialogTrigger,
} from "@/components/ui/dialog.tsx";
import { DialogClose } from "@radix-ui/react-dialog";
import { posterEvent } from "@/events/poster.ts";
type MarkdownProps = {
children: string;
@ -73,6 +88,17 @@ function getSocialIcon(url: string) {
}
}
function getVirtualIcon(command: string) {
switch (command) {
case "/VARIATION":
return <Wand2 className="h-4 w-4 inline-block mr-2" />;
case "/UPSCALE":
return <Maximize className="h-4 w-4 inline-block mr-2" />;
case "/REROLL":
return <RefreshCcwDot className="h-4 w-4 inline-block mr-2" />;
}
}
function MarkdownContent({
children,
className,
@ -107,6 +133,43 @@ function MarkdownContent({
a({ href, children }) {
const url: string = href?.toString() || "";
if (url.startsWith("https://chatnio.virtual")) {
const message = url.slice(23).replace(/-/g, " ");
const prefix = message.split(" ")[0];
const send = () => posterEvent.emit(message);
return (
<Dialog>
<DialogTrigger asChild>
<Button
variant={`outline`}
className={`flex flex-row items-center virtual-action mx-1 my-0.5 min-w-[4rem]`}
>
{getVirtualIcon(prefix)}
{children}
</Button>
</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>{t("chat.send-message")}</DialogTitle>
<DialogDescription className={`pb-2`}>
{t("chat.send-message-desc")}
</DialogDescription>
<p className={`virtual-prompt`}>{message}</p>
</DialogHeader>
<DialogFooter>
<DialogClose asChild>
<Button variant={`outline`}>{t("cancel")}</Button>
</DialogClose>
<DialogClose onClick={send} asChild>
<Button variant={`default`}>{t("confirm")}</Button>
</DialogClose>
</DialogFooter>
</DialogContent>
</Dialog>
);
}
return (
<a
href={url}

View File

@ -158,7 +158,13 @@ function MessageContent({ message, end, index, onEvent }: MessageProps) {
<MousePointerSquare className={`h-4 w-4 mr-1.5`} />
{t("message.use")}
</DropdownMenuItem>
<DropdownMenuItem onClick={() => setOpen(true)}>
<DropdownMenuItem
onClick={() => {
editedMessage?.length === 0 &&
setEditedMessage(message.content);
setOpen(true);
}}
>
<PencilLine className={`h-4 w-4 mr-1.5`} />
{t("message.edit")}
</DropdownMenuItem>

View File

@ -46,6 +46,7 @@ import { chatEvent } from "@/events/chat.ts";
import { cn } from "@/components/ui/lib/utils.ts";
import { goAuth } from "@/utils/app.ts";
import { getModelFromId } from "@/conf/model.ts";
import { posterEvent } from "@/events/poster.ts";
type InterfaceProps = {
setWorking: (working: boolean) => void;
@ -105,8 +106,11 @@ function ChatWrapper() {
setFiles([]);
}
async function processSend(data: string): Promise<boolean> {
if (requireAuth && !auth) {
async function processSend(
data: string,
passAuth?: boolean,
): Promise<boolean> {
if (requireAuth && !auth && !passAuth) {
toast({
title: t("login-require"),
action: (
@ -159,6 +163,10 @@ function ChatWrapper() {
});
}
useEffect(() => {
posterEvent.bind((data) => processSend(data, true));
}, []);
useEffect(() => {
window.addEventListener("load", () => {
const el = document.getElementById("input");

View File

@ -1,5 +1,6 @@
import { useMemo } from "react";
import { parseNumber } from "@/utils/base.ts";
import { Check, Loader2 } from "lucide-react";
export function parseProgressbar(data: string) {
const progress = useMemo(() => {
@ -11,7 +12,16 @@ export function parseProgressbar(data: string) {
return (
<div className={`progress`}>
<p className={`text-primary select-none text-center text-white`}>
<p
className={`flex flex-row items-center justify-center text-primary select-none text-center text-white px-6`}
>
{progress < 100 ? (
<Loader2
className={`h-4 w-4 mr-2 inline-block animate-spin shrink-0`}
/>
) : (
<Check className={`h-4 w-4 mr-2 inline-block animate-out shrink-0`} />
)}
Generating: {progress < 0 ? 0 : progress.toFixed()}%
</p>
{progress > 0 && (
@ -24,7 +34,7 @@ export function parseProgressbar(data: string) {
data-max={100}
>
<p
className={`h-full w-full flex-1 bg-primary transition-all`}
className={`h-full w-full flex-1 bg-primary transition-all duration-300`}
style={{ transform: `translateX(-${100 - progress}%)` }}
data-max={100}
/>

View File

@ -66,8 +66,8 @@ export function Combobox({
<CommandItem
key={key}
value={key}
onSelect={(current) => {
onChange(current);
onSelect={() => {
onChange(key);
setOpen(false);
}}
>

5
app/src/events/poster.ts Normal file
View File

@ -0,0 +1,5 @@
import { EventCommitter } from "@/events/struct.ts";
export const posterEvent = new EventCommitter<string>({
name: "poster",
});

View File

@ -132,7 +132,9 @@
"placeholder-raw": "写点什么...",
"recall": "历史复原",
"recall-desc": "检测到您上次有未发送的消息,已经为您恢复。",
"recall-cancel": "取消"
"recall-cancel": "取消",
"send-message": "发送消息",
"send-message-desc": "是否确认发送此消息?"
},
"message": {
"copy": "复制消息",

View File

@ -80,7 +80,9 @@
"recall-desc": "Detected that you have unsent messages last time, has been restored for you.",
"recall-cancel": "Cancel",
"placeholder-enter": "Write something... (Enter to send)",
"placeholder-raw": "Write something..."
"placeholder-raw": "Write something...",
"send-message": "Send Message",
"send-message-desc": "Are you sure you want to send this message?"
},
"message": {
"copy": "Copy Message",

View File

@ -80,7 +80,9 @@
"recall-desc": "最後に未送信のメッセージが検出され、復元されました。",
"recall-cancel": "キャンセル",
"placeholder-enter": "何か書いてください... 送信するにはEnterキーを押してください",
"placeholder-raw": "何か書いてください..."
"placeholder-raw": "何か書いてください...",
"send-message": "メッセージを送信",
"send-message-desc": "このメッセージを送信してもよろしいですか?"
},
"message": {
"copy": "メッセージをコピー",

View File

@ -80,7 +80,9 @@
"recall-desc": "Обнаружено, что у вас есть неотправленные сообщения в прошлый раз, они были восстановлены для вас.",
"recall-cancel": "Отмена",
"placeholder-enter": "Напишите что-нибудь... (Введите, чтобы отправить)",
"placeholder-raw": "Напишите что-нибудь..."
"placeholder-raw": "Напишите что-нибудь...",
"send-message": "Отправить",
"send-message-desc": "Вы уверены, что хотите отправить это сообщение?"
},
"message": {
"copy": "Копировать сообщение",

View File

@ -100,26 +100,29 @@ const (
SkylarkChat = "skylark-chat"
)
var DalleModels = []string{
var OpenAIDalleModels = []string{
Dalle, Dalle2, Dalle3,
}
var OpenAIVisionModels = []string{
GPT4VisionPreview, GPT41106VisionPreview,
}
func in(value string, slice []string) bool {
for _, item := range slice {
if item == value {
if item == value || strings.Contains(value, item) {
return true
}
}
return false
}
func IsDalleModel(model string) bool {
func IsOpenAIDalleModel(model string) bool {
// using image generation api if model is in dalle models
return in(model, DalleModels)
return in(model, OpenAIDalleModels)
}
func IsGPT41106VisionPreview(model string) bool {
// enable openai image format for gpt-4-vision-preview model
return (model == GPT41106VisionPreview || strings.Contains(model, GPT41106VisionPreview)) ||
(model == GPT4VisionPreview || strings.Contains(model, GPT4VisionPreview))
func IsOpenAIVisionModels(model string) bool {
// enable openai image format for gpt-4-vision-preview models
return in(model, OpenAIVisionModels)
}

View File

@ -309,3 +309,11 @@ func Any(arr ...bool) bool {
}
return false
}
func Range(start int, end int) []int {
var res []int
for i := start; i < end; i++ {
res = append(res, i)
}
return res
}

View File

@ -255,3 +255,35 @@ func SortString(arr []string) []string {
return result
}
func SafeSplit(data string, sep string, seglen int) (res []string) {
// split string by sep, and each segment has seglen length
// e.g. SafeSplit("abc,def,ghi", ",", 2) => ["abc", "def,ghi"]
if data == "" {
for i := 0; i < seglen; i++ {
res = append(res, "")
}
return res
}
arr := strings.Split(data, sep)
length := len(arr)
if length == seglen {
return arr
}
if length < seglen {
for i := 0; i < seglen-length; i++ {
arr = append(arr, "")
}
return arr
} else {
offset := length - seglen
for i := 0; i < offset; i++ {
arr[seglen-1] += sep + arr[seglen+i]
}
return arr[:seglen]
}
}

View File

@ -11,6 +11,11 @@ import (
var configFile = "config/config.yaml"
var configExampleFile = "config.example.yaml"
var redirectRoutes = []string{
"/v1",
"/mj",
}
func ReadConf() {
viper.SetConfigFile(configFile)
@ -58,11 +63,12 @@ func RegisterStaticRoute(engine *gin.Engine) {
c.File("./app/dist/index.html")
})
// redirect /v1 to /api/v1
engine.Any("/v1/*path", func(c *gin.Context) {
path := c.Param("path")
c.Redirect(301, fmt.Sprintf("/api/v1/%s", path))
})
for _, route := range redirectRoutes {
engine.Any(fmt.Sprintf("%s/*path", route), func(c *gin.Context) {
path := c.Param("path")
c.Redirect(301, fmt.Sprintf("/api%s/%s", route, path))
})
}
fmt.Println(`[service] start serving static files from ~/app/dist`)
}

View File

@ -86,7 +86,7 @@ func (i *Image) GetPixelColor(x int, y int) (int, int, int) {
}
func (i *Image) CountTokens(model string) int {
if globals.IsGPT41106VisionPreview(model) {
if globals.IsOpenAIVisionModels(model) {
// tile size is 512x512
// the max size of image is 2048x2048
// the image that is larger than 2048x2048 will be resized in 16 tiles