mirror of
https://github.com/coaidev/coai.git
synced 2025-05-22 14:30:14 +09:00
feat: support midjourney upscale / variation / reroll actions (#40)
This commit is contained in:
parent
0c44eed4ec
commit
1f52d56e46
@ -69,7 +69,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
|
||||
|
||||
// CreateChatRequest is the native http request body for chatgpt
|
||||
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
||||
if globals.IsDalleModel(props.Model) {
|
||||
if globals.IsOpenAIDalleModel(props.Model) {
|
||||
return c.CreateImage(props)
|
||||
}
|
||||
|
||||
@ -94,7 +94,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
||||
|
||||
// CreateStreamChatRequest is the stream response body for chatgpt
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||
if globals.IsDalleModel(props.Model) {
|
||||
if globals.IsOpenAIDalleModel(props.Model) {
|
||||
if url, err := c.CreateImage(props); err != nil {
|
||||
return err
|
||||
} else {
|
||||
|
@ -35,7 +35,7 @@ func formatMessages(props *ChatProps) interface{} {
|
||||
}
|
||||
props.Message[len(props.Message)-1].Content = base
|
||||
return props.Message
|
||||
} else if globals.IsGPT41106VisionPreview(props.Model) {
|
||||
} else if globals.IsOpenAIVisionModels(props.Model) {
|
||||
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
||||
if message.Role == globals.User {
|
||||
urls := utils.ExtractImageUrls(message.Content)
|
||||
|
@ -70,7 +70,7 @@ func (c *ChatInstance) GetChatBody(props *ChatProps, stream bool) interface{} {
|
||||
|
||||
// CreateChatRequest is the native http request body for chatgpt
|
||||
func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
||||
if globals.IsDalleModel(props.Model) {
|
||||
if globals.IsOpenAIDalleModel(props.Model) {
|
||||
return c.CreateImage(props)
|
||||
}
|
||||
|
||||
@ -95,7 +95,7 @@ func (c *ChatInstance) CreateChatRequest(props *ChatProps) (string, error) {
|
||||
|
||||
// CreateStreamChatRequest is the stream response body for chatgpt
|
||||
func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback globals.Hook) error {
|
||||
if globals.IsDalleModel(props.Model) {
|
||||
if globals.IsOpenAIDalleModel(props.Model) {
|
||||
if url, err := c.CreateImage(props); err != nil {
|
||||
return err
|
||||
} else {
|
||||
|
@ -35,7 +35,7 @@ func formatMessages(props *ChatProps) interface{} {
|
||||
}
|
||||
props.Message[len(props.Message)-1].Content = base
|
||||
return props.Message
|
||||
} else if globals.IsGPT41106VisionPreview(props.Model) {
|
||||
} else if globals.IsOpenAIVisionModels(props.Model) {
|
||||
return utils.Each[globals.Message, Message](props.Message, func(message globals.Message) Message {
|
||||
if message.Role == globals.User {
|
||||
urls := utils.ExtractImageUrls(message.Content)
|
||||
|
@ -1,118 +1,59 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var midjourneyEmptySecret = "null"
|
||||
|
||||
func (c *ChatInstance) GetImagineUrl() string {
|
||||
func (c *ChatInstance) GetImagineEndpoint() string {
|
||||
return fmt.Sprintf("%s/mj/submit/imagine", c.GetEndpoint())
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetImagineHeaders() map[string]string {
|
||||
secret := c.GetApiSecret()
|
||||
if secret == "" || secret == midjourneyEmptySecret {
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
}
|
||||
func (c *ChatInstance) GetChangeEndpoint() string {
|
||||
return fmt.Sprintf("%s/mj/submit/change", c.GetEndpoint())
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"mj-api-secret": secret,
|
||||
func (c *ChatInstance) GetImagineRequest(prompt string) *ImagineRequest {
|
||||
return &ImagineRequest{
|
||||
NotifyHook: c.GetNotifyEndpoint(),
|
||||
Prompt: prompt,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateImagineRequest(prompt string) (*ImagineResponse, error) {
|
||||
func (c *ChatInstance) GetChangeRequest(action string, task string, index *int) *ChangeRequest {
|
||||
return &ChangeRequest{
|
||||
NotifyHook: c.GetNotifyEndpoint(),
|
||||
Action: action,
|
||||
Index: index,
|
||||
TaskId: task,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateImagineRequest(prompt string) (*CommonResponse, error) {
|
||||
res, err := utils.Post(
|
||||
c.GetImagineUrl(),
|
||||
c.GetImagineHeaders(),
|
||||
ImagineRequest{
|
||||
NotifyHook: fmt.Sprintf(
|
||||
"%s/mj/notify",
|
||||
globals.NotifyUrl,
|
||||
),
|
||||
Prompt: prompt,
|
||||
},
|
||||
c.GetImagineEndpoint(),
|
||||
c.GetMidjourneyHeaders(),
|
||||
c.GetImagineRequest(prompt),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return utils.MapToStruct[ImagineResponse](res), nil
|
||||
return utils.MapToStruct[CommonResponse](res), nil
|
||||
}
|
||||
|
||||
func getStatusCode(response *ImagineResponse) error {
|
||||
code := response.Code
|
||||
switch code {
|
||||
case SuccessCode, QueueCode:
|
||||
return nil
|
||||
case ExistedCode:
|
||||
return fmt.Errorf("task is existed, please try again later with another prompt")
|
||||
case MaxQueueCode:
|
||||
return fmt.Errorf("task queue is full, please try again later")
|
||||
case NudeCode:
|
||||
return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected")
|
||||
default:
|
||||
return fmt.Errorf(fmt.Sprintf("unknown error from midjourney (code: %d, description: %s)", code, response.Description))
|
||||
}
|
||||
}
|
||||
func (c *ChatInstance) CreateChangeRequest(action string, task string, index *int) (*CommonResponse, error) {
|
||||
res, err := utils.Post(
|
||||
c.GetChangeEndpoint(),
|
||||
c.GetMidjourneyHeaders(),
|
||||
c.GetChangeRequest(action, task, index),
|
||||
)
|
||||
|
||||
func getProgress(value string) int {
|
||||
progress := strings.TrimSuffix(value, "%")
|
||||
return utils.ParseInt(progress)
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateStreamImagineTask(prompt string, hook func(progress int) error) (string, error) {
|
||||
res, err := c.CreateImagineRequest(prompt)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := getStatusCode(res); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
task := res.Result
|
||||
progress := -1
|
||||
|
||||
for {
|
||||
utils.Sleep(100)
|
||||
form := getStorage(task)
|
||||
if form == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch form.Status {
|
||||
case Success:
|
||||
if err := hook(100); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return form.Url, nil
|
||||
case Failure:
|
||||
if err := hook(100); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "", fmt.Errorf("task failed: %s", form.FailReason)
|
||||
case InProgress:
|
||||
current := getProgress(form.Progress)
|
||||
if progress != current {
|
||||
if err := hook(current); err != nil {
|
||||
return "", err
|
||||
}
|
||||
progress = current
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateImagineTask(prompt string) (string, error) {
|
||||
return c.CreateStreamImagineTask(prompt, func(progress int) error {
|
||||
return nil
|
||||
})
|
||||
fmt.Println(res)
|
||||
return utils.MapToStruct[CommonResponse](res), nil
|
||||
}
|
||||
|
@ -7,6 +7,21 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const maxActions = 4
|
||||
const (
|
||||
ImagineAction = "IMAGINE"
|
||||
UpscaleAction = "UPSCALE"
|
||||
VariationAction = "VARIATION"
|
||||
RerollAction = "REROLL"
|
||||
)
|
||||
|
||||
const (
|
||||
ImagineCommand = "/IMAGINE"
|
||||
UpscaleCommand = "/UPSCALE"
|
||||
VariationCommand = "/VARIATION"
|
||||
RerollCommand = "/REROLL"
|
||||
)
|
||||
|
||||
type ChatProps struct {
|
||||
Messages []globals.Message
|
||||
Model string
|
||||
@ -30,7 +45,7 @@ func (c *ChatInstance) GetCleanPrompt(model string, prompt string) string {
|
||||
var res []string
|
||||
|
||||
for _, word := range arr {
|
||||
if utils.Contains[string](word, ModeArr) {
|
||||
if utils.Contains[string](word, RendererMode) {
|
||||
continue
|
||||
}
|
||||
res = append(res, word)
|
||||
@ -54,28 +69,69 @@ func (c *ChatInstance) CreateStreamChatRequest(props *ChatProps, callback global
|
||||
// ```
|
||||
// 
|
||||
|
||||
prompt := c.GetPrompt(props)
|
||||
if prompt == "" {
|
||||
action, prompt := c.ExtractPrompt(c.GetPrompt(props))
|
||||
if len(prompt) == 0 {
|
||||
return fmt.Errorf("format error: please provide available prompt")
|
||||
}
|
||||
|
||||
url, err := c.CreateStreamImagineTask(prompt, func(progress int) error {
|
||||
var begin bool
|
||||
|
||||
form, err := c.CreateStreamTask(action, prompt, func(form *StorageForm, progress int) error {
|
||||
if progress == 0 {
|
||||
begin = true
|
||||
if err := callback("```progress\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if progress == 100 {
|
||||
} else if progress == 100 && !begin {
|
||||
if err := callback("```progress\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if err := callback(fmt.Sprintf("%d\n", progress)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if progress == 100 {
|
||||
if err := callback("```\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return callback(fmt.Sprintf("%d\n", progress))
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("error from midjourney: %s", err.Error())
|
||||
}
|
||||
|
||||
return callback(utils.GetImageMarkdown(url))
|
||||
if err := callback(utils.GetImageMarkdown(form.Url)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.CallbackActions(form, callback)
|
||||
}
|
||||
|
||||
func toVirtualMessage(message string) string {
|
||||
return "https://chatnio.virtual" + strings.Replace(message, " ", "-", -1)
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CallbackActions(form *StorageForm, callback globals.Hook) error {
|
||||
if form.Action == UpscaleAction {
|
||||
return nil
|
||||
}
|
||||
|
||||
actions := utils.Range(1, maxActions+1)
|
||||
|
||||
upscale := strings.Join(utils.Each(actions, func(index int) string {
|
||||
return fmt.Sprintf("[U%d](%s)", index, toVirtualMessage(fmt.Sprintf("/UPSCALE %s %d", form.Task, index)))
|
||||
}), " ")
|
||||
|
||||
variation := strings.Join(utils.Each(actions, func(index int) string {
|
||||
return fmt.Sprintf("[V%d](%s)", index, toVirtualMessage(fmt.Sprintf("/VARIATION %s %d", form.Task, index)))
|
||||
}), " ")
|
||||
|
||||
reroll := fmt.Sprintf("[REROLL](%s)", toVirtualMessage(fmt.Sprintf("/REROLL %s", form.Task)))
|
||||
|
||||
return callback(fmt.Sprintf("\n\n%s\n\n%s\n\n%s\n", upscale, variation, reroll))
|
||||
}
|
||||
|
@ -55,6 +55,8 @@ func NotifyAPI(c *gin.Context) {
|
||||
}
|
||||
|
||||
err := setStorage(form.Id, StorageForm{
|
||||
Task: form.Id,
|
||||
Action: form.Action,
|
||||
Url: form.ImageUrl,
|
||||
FailReason: reason,
|
||||
Progress: form.Progress,
|
||||
|
120
adapter/midjourney/handler.go
Normal file
120
adapter/midjourney/handler.go
Normal file
@ -0,0 +1,120 @@
|
||||
package midjourney
|
||||
|
||||
import (
|
||||
"chat/utils"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getStatusCode(action string, response *CommonResponse) error {
|
||||
code := response.Code
|
||||
switch code {
|
||||
case SuccessCode, QueueCode:
|
||||
return nil
|
||||
case ExistedCode:
|
||||
if action != ImagineCommand {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("task is existed, please try again later with another prompt")
|
||||
case MaxQueueCode:
|
||||
return fmt.Errorf("task queue is full, please try again later")
|
||||
case NudeCode:
|
||||
return fmt.Errorf("prompt violates the content policy of midjourney, the request is rejected")
|
||||
default:
|
||||
return fmt.Errorf(fmt.Sprintf("unknown error from midjourney (code: %d, description: %s)", code, response.Description))
|
||||
}
|
||||
}
|
||||
|
||||
func getProgress(value string) int {
|
||||
progress := strings.TrimSuffix(value, "%")
|
||||
return utils.ParseInt(progress)
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetAction(command string) string {
|
||||
return strings.TrimLeft(command, "/")
|
||||
}
|
||||
|
||||
func (c *ChatInstance) ExtractPrompt(input string) (action string, prompt string) {
|
||||
segment := utils.SafeSplit(input, " ", 2)
|
||||
|
||||
action = strings.TrimSpace(segment[0])
|
||||
prompt = strings.TrimSpace(segment[1])
|
||||
|
||||
switch action {
|
||||
case ImagineCommand, VariationCommand, UpscaleCommand, RerollCommand:
|
||||
return
|
||||
default:
|
||||
return ImagineCommand, strings.TrimSpace(input)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) ExtractCommand(input string) (task string, index *int) {
|
||||
segment := utils.SafeSplit(input, " ", 2)
|
||||
|
||||
task = strings.TrimSpace(segment[0])
|
||||
|
||||
if segment[1] != "" {
|
||||
data := segment[1]
|
||||
slice := strings.Split(segment[1], " ")
|
||||
if len(slice) > 1 {
|
||||
data = slice[0]
|
||||
}
|
||||
|
||||
index = utils.ToPtr(utils.ParseInt(strings.TrimSpace(data)))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateRequest(action string, prompt string) (*CommonResponse, error) {
|
||||
switch action {
|
||||
case ImagineCommand:
|
||||
return c.CreateImagineRequest(prompt)
|
||||
case VariationCommand, UpscaleCommand, RerollCommand:
|
||||
task, index := c.ExtractCommand(prompt)
|
||||
|
||||
return c.CreateChangeRequest(c.GetAction(action), task, index)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown action: %s", action)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) CreateStreamTask(action string, prompt string, hook func(form *StorageForm, progress int) error) (*StorageForm, error) {
|
||||
res, err := c.CreateRequest(action, prompt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := getStatusCode(action, res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
task := res.Result
|
||||
progress := -1
|
||||
|
||||
for {
|
||||
utils.Sleep(100)
|
||||
form := getStorage(task)
|
||||
if form == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
switch form.Status {
|
||||
case Success:
|
||||
if err := hook(form, 100); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return form, nil
|
||||
case Failure:
|
||||
return nil, fmt.Errorf("task failed: %s", form.FailReason)
|
||||
case InProgress:
|
||||
current := getProgress(form.Progress)
|
||||
if progress != current {
|
||||
if err := hook(form, current); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
progress = current
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -2,8 +2,11 @@ package midjourney
|
||||
|
||||
import (
|
||||
"chat/globals"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var midjourneyEmptySecret = "null"
|
||||
|
||||
type ChatInstance struct {
|
||||
Endpoint string
|
||||
ApiSecret string
|
||||
@ -17,6 +20,24 @@ func (c *ChatInstance) GetEndpoint() string {
|
||||
return c.Endpoint
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetMidjourneyHeaders() map[string]string {
|
||||
secret := c.GetApiSecret()
|
||||
if secret == "" || secret == midjourneyEmptySecret {
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
}
|
||||
|
||||
return map[string]string{
|
||||
"Content-Type": "application/json",
|
||||
"mj-api-secret": secret,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ChatInstance) GetNotifyEndpoint() string {
|
||||
return fmt.Sprintf("%s/mj/notify", globals.NotifyUrl)
|
||||
}
|
||||
|
||||
func NewChatInstance(endpoint, apiSecret, whiteList string) *ChatInstance {
|
||||
SaveWhiteList(whiteList)
|
||||
|
||||
|
@ -22,22 +22,29 @@ const (
|
||||
RelaxMode = "--relax"
|
||||
)
|
||||
|
||||
var ModeArr = []string{TurboMode, FastMode, RelaxMode}
|
||||
var RendererMode = []string{TurboMode, FastMode, RelaxMode}
|
||||
|
||||
type ImagineHeader struct {
|
||||
type CommonHeader struct {
|
||||
ContentType string `json:"Content-Type"`
|
||||
MjApiSecret string `json:"mj-api-secret,omitempty"`
|
||||
}
|
||||
|
||||
type CommonResponse struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
type ImagineRequest struct {
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Prompt string `json:"prompt"`
|
||||
}
|
||||
|
||||
type ImagineResponse struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Result string `json:"result"`
|
||||
type ChangeRequest struct {
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Action string `json:"action"`
|
||||
Index *int `json:"index,omitempty"`
|
||||
TaskId string `json:"taskId"`
|
||||
}
|
||||
|
||||
type NotifyForm struct {
|
||||
@ -56,6 +63,8 @@ type NotifyForm struct {
|
||||
}
|
||||
|
||||
type StorageForm struct {
|
||||
Task string `json:"task"`
|
||||
Action string `json:"action"`
|
||||
Url string `json:"url"`
|
||||
FailReason string `json:"failReason"`
|
||||
Progress string `json:"progress"`
|
||||
|
@ -105,3 +105,37 @@
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.virtual-prompt {
|
||||
text-align: center;
|
||||
border-radius: var(--radius);
|
||||
border: 1px solid hsl(var(--border));
|
||||
padding: 0.5rem;
|
||||
}
|
||||
|
||||
.virtual-action {
|
||||
svg {
|
||||
transform: translateY(1px);
|
||||
}
|
||||
}
|
||||
|
||||
p:has(.virtual-action) {
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
|
||||
margin-top: 1.25rem !important;
|
||||
|
||||
|
||||
@media (min-width: 668px) {
|
||||
.virtual-action {
|
||||
margin-right: auto;
|
||||
|
||||
&:last-child {
|
||||
margin-right: 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -198,7 +198,7 @@
|
||||
box-sizing: content-box;
|
||||
background-color: var(--color-canvas-default);
|
||||
border-radius: 4px;
|
||||
max-height: 35vh;
|
||||
max-height: 50vh;
|
||||
}
|
||||
|
||||
.markdown-body code,
|
||||
|
@ -18,7 +18,10 @@ import {
|
||||
Codesandbox,
|
||||
Copy,
|
||||
Github,
|
||||
Maximize,
|
||||
RefreshCcwDot,
|
||||
Twitter,
|
||||
Wand2,
|
||||
Youtube,
|
||||
} from "lucide-react";
|
||||
import { copyClipboard } from "@/utils/dom.ts";
|
||||
@ -26,6 +29,18 @@ import { useToast } from "./ui/use-toast.ts";
|
||||
import { useTranslation } from "react-i18next";
|
||||
import { parseProgressbar } from "@/components/plugins/progress.tsx";
|
||||
import { cn } from "@/components/ui/lib/utils.ts";
|
||||
import { Button } from "@/components/ui/button.tsx";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogDescription,
|
||||
DialogFooter,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogTrigger,
|
||||
} from "@/components/ui/dialog.tsx";
|
||||
import { DialogClose } from "@radix-ui/react-dialog";
|
||||
import { posterEvent } from "@/events/poster.ts";
|
||||
|
||||
type MarkdownProps = {
|
||||
children: string;
|
||||
@ -73,6 +88,17 @@ function getSocialIcon(url: string) {
|
||||
}
|
||||
}
|
||||
|
||||
function getVirtualIcon(command: string) {
|
||||
switch (command) {
|
||||
case "/VARIATION":
|
||||
return <Wand2 className="h-4 w-4 inline-block mr-2" />;
|
||||
case "/UPSCALE":
|
||||
return <Maximize className="h-4 w-4 inline-block mr-2" />;
|
||||
case "/REROLL":
|
||||
return <RefreshCcwDot className="h-4 w-4 inline-block mr-2" />;
|
||||
}
|
||||
}
|
||||
|
||||
function MarkdownContent({
|
||||
children,
|
||||
className,
|
||||
@ -107,6 +133,43 @@ function MarkdownContent({
|
||||
a({ href, children }) {
|
||||
const url: string = href?.toString() || "";
|
||||
|
||||
if (url.startsWith("https://chatnio.virtual")) {
|
||||
const message = url.slice(23).replace(/-/g, " ");
|
||||
const prefix = message.split(" ")[0];
|
||||
const send = () => posterEvent.emit(message);
|
||||
|
||||
return (
|
||||
<Dialog>
|
||||
<DialogTrigger asChild>
|
||||
<Button
|
||||
variant={`outline`}
|
||||
className={`flex flex-row items-center virtual-action mx-1 my-0.5 min-w-[4rem]`}
|
||||
>
|
||||
{getVirtualIcon(prefix)}
|
||||
{children}
|
||||
</Button>
|
||||
</DialogTrigger>
|
||||
<DialogContent>
|
||||
<DialogHeader>
|
||||
<DialogTitle>{t("chat.send-message")}</DialogTitle>
|
||||
<DialogDescription className={`pb-2`}>
|
||||
{t("chat.send-message-desc")}
|
||||
</DialogDescription>
|
||||
<p className={`virtual-prompt`}>{message}</p>
|
||||
</DialogHeader>
|
||||
<DialogFooter>
|
||||
<DialogClose asChild>
|
||||
<Button variant={`outline`}>{t("cancel")}</Button>
|
||||
</DialogClose>
|
||||
<DialogClose onClick={send} asChild>
|
||||
<Button variant={`default`}>{t("confirm")}</Button>
|
||||
</DialogClose>
|
||||
</DialogFooter>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<a
|
||||
href={url}
|
||||
|
@ -158,7 +158,13 @@ function MessageContent({ message, end, index, onEvent }: MessageProps) {
|
||||
<MousePointerSquare className={`h-4 w-4 mr-1.5`} />
|
||||
{t("message.use")}
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuItem onClick={() => setOpen(true)}>
|
||||
<DropdownMenuItem
|
||||
onClick={() => {
|
||||
editedMessage?.length === 0 &&
|
||||
setEditedMessage(message.content);
|
||||
setOpen(true);
|
||||
}}
|
||||
>
|
||||
<PencilLine className={`h-4 w-4 mr-1.5`} />
|
||||
{t("message.edit")}
|
||||
</DropdownMenuItem>
|
||||
|
@ -46,6 +46,7 @@ import { chatEvent } from "@/events/chat.ts";
|
||||
import { cn } from "@/components/ui/lib/utils.ts";
|
||||
import { goAuth } from "@/utils/app.ts";
|
||||
import { getModelFromId } from "@/conf/model.ts";
|
||||
import { posterEvent } from "@/events/poster.ts";
|
||||
|
||||
type InterfaceProps = {
|
||||
setWorking: (working: boolean) => void;
|
||||
@ -105,8 +106,11 @@ function ChatWrapper() {
|
||||
setFiles([]);
|
||||
}
|
||||
|
||||
async function processSend(data: string): Promise<boolean> {
|
||||
if (requireAuth && !auth) {
|
||||
async function processSend(
|
||||
data: string,
|
||||
passAuth?: boolean,
|
||||
): Promise<boolean> {
|
||||
if (requireAuth && !auth && !passAuth) {
|
||||
toast({
|
||||
title: t("login-require"),
|
||||
action: (
|
||||
@ -159,6 +163,10 @@ function ChatWrapper() {
|
||||
});
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
posterEvent.bind((data) => processSend(data, true));
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
window.addEventListener("load", () => {
|
||||
const el = document.getElementById("input");
|
||||
|
@ -1,5 +1,6 @@
|
||||
import { useMemo } from "react";
|
||||
import { parseNumber } from "@/utils/base.ts";
|
||||
import { Check, Loader2 } from "lucide-react";
|
||||
|
||||
export function parseProgressbar(data: string) {
|
||||
const progress = useMemo(() => {
|
||||
@ -11,7 +12,16 @@ export function parseProgressbar(data: string) {
|
||||
|
||||
return (
|
||||
<div className={`progress`}>
|
||||
<p className={`text-primary select-none text-center text-white`}>
|
||||
<p
|
||||
className={`flex flex-row items-center justify-center text-primary select-none text-center text-white px-6`}
|
||||
>
|
||||
{progress < 100 ? (
|
||||
<Loader2
|
||||
className={`h-4 w-4 mr-2 inline-block animate-spin shrink-0`}
|
||||
/>
|
||||
) : (
|
||||
<Check className={`h-4 w-4 mr-2 inline-block animate-out shrink-0`} />
|
||||
)}
|
||||
Generating: {progress < 0 ? 0 : progress.toFixed()}%
|
||||
</p>
|
||||
{progress > 0 && (
|
||||
@ -24,7 +34,7 @@ export function parseProgressbar(data: string) {
|
||||
data-max={100}
|
||||
>
|
||||
<p
|
||||
className={`h-full w-full flex-1 bg-primary transition-all`}
|
||||
className={`h-full w-full flex-1 bg-primary transition-all duration-300`}
|
||||
style={{ transform: `translateX(-${100 - progress}%)` }}
|
||||
data-max={100}
|
||||
/>
|
||||
|
@ -66,8 +66,8 @@ export function Combobox({
|
||||
<CommandItem
|
||||
key={key}
|
||||
value={key}
|
||||
onSelect={(current) => {
|
||||
onChange(current);
|
||||
onSelect={() => {
|
||||
onChange(key);
|
||||
setOpen(false);
|
||||
}}
|
||||
>
|
||||
|
5
app/src/events/poster.ts
Normal file
5
app/src/events/poster.ts
Normal file
@ -0,0 +1,5 @@
|
||||
import { EventCommitter } from "@/events/struct.ts";
|
||||
|
||||
export const posterEvent = new EventCommitter<string>({
|
||||
name: "poster",
|
||||
});
|
@ -132,7 +132,9 @@
|
||||
"placeholder-raw": "写点什么...",
|
||||
"recall": "历史复原",
|
||||
"recall-desc": "检测到您上次有未发送的消息,已经为您恢复。",
|
||||
"recall-cancel": "取消"
|
||||
"recall-cancel": "取消",
|
||||
"send-message": "发送消息",
|
||||
"send-message-desc": "是否确认发送此消息?"
|
||||
},
|
||||
"message": {
|
||||
"copy": "复制消息",
|
||||
|
@ -80,7 +80,9 @@
|
||||
"recall-desc": "Detected that you have unsent messages last time, has been restored for you.",
|
||||
"recall-cancel": "Cancel",
|
||||
"placeholder-enter": "Write something... (Enter to send)",
|
||||
"placeholder-raw": "Write something..."
|
||||
"placeholder-raw": "Write something...",
|
||||
"send-message": "Send Message",
|
||||
"send-message-desc": "Are you sure you want to send this message?"
|
||||
},
|
||||
"message": {
|
||||
"copy": "Copy Message",
|
||||
|
@ -80,7 +80,9 @@
|
||||
"recall-desc": "最後に未送信のメッセージが検出され、復元されました。",
|
||||
"recall-cancel": "キャンセル",
|
||||
"placeholder-enter": "何か書いてください... (送信するにはEnterキーを押してください)",
|
||||
"placeholder-raw": "何か書いてください..."
|
||||
"placeholder-raw": "何か書いてください...",
|
||||
"send-message": "メッセージを送信",
|
||||
"send-message-desc": "このメッセージを送信してもよろしいですか?"
|
||||
},
|
||||
"message": {
|
||||
"copy": "メッセージをコピー",
|
||||
|
@ -80,7 +80,9 @@
|
||||
"recall-desc": "Обнаружено, что у вас есть неотправленные сообщения в прошлый раз, они были восстановлены для вас.",
|
||||
"recall-cancel": "Отмена",
|
||||
"placeholder-enter": "Напишите что-нибудь... (Введите, чтобы отправить)",
|
||||
"placeholder-raw": "Напишите что-нибудь..."
|
||||
"placeholder-raw": "Напишите что-нибудь...",
|
||||
"send-message": "Отправить",
|
||||
"send-message-desc": "Вы уверены, что хотите отправить это сообщение?"
|
||||
},
|
||||
"message": {
|
||||
"copy": "Копировать сообщение",
|
||||
|
@ -100,26 +100,29 @@ const (
|
||||
SkylarkChat = "skylark-chat"
|
||||
)
|
||||
|
||||
var DalleModels = []string{
|
||||
var OpenAIDalleModels = []string{
|
||||
Dalle, Dalle2, Dalle3,
|
||||
}
|
||||
|
||||
var OpenAIVisionModels = []string{
|
||||
GPT4VisionPreview, GPT41106VisionPreview,
|
||||
}
|
||||
|
||||
func in(value string, slice []string) bool {
|
||||
for _, item := range slice {
|
||||
if item == value {
|
||||
if item == value || strings.Contains(value, item) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func IsDalleModel(model string) bool {
|
||||
func IsOpenAIDalleModel(model string) bool {
|
||||
// using image generation api if model is in dalle models
|
||||
return in(model, DalleModels)
|
||||
return in(model, OpenAIDalleModels)
|
||||
}
|
||||
|
||||
func IsGPT41106VisionPreview(model string) bool {
|
||||
// enable openai image format for gpt-4-vision-preview model
|
||||
return (model == GPT41106VisionPreview || strings.Contains(model, GPT41106VisionPreview)) ||
|
||||
(model == GPT4VisionPreview || strings.Contains(model, GPT4VisionPreview))
|
||||
func IsOpenAIVisionModels(model string) bool {
|
||||
// enable openai image format for gpt-4-vision-preview models
|
||||
return in(model, OpenAIVisionModels)
|
||||
}
|
||||
|
@ -309,3 +309,11 @@ func Any(arr ...bool) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func Range(start int, end int) []int {
|
||||
var res []int
|
||||
for i := start; i < end; i++ {
|
||||
res = append(res, i)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
@ -255,3 +255,35 @@ func SortString(arr []string) []string {
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
func SafeSplit(data string, sep string, seglen int) (res []string) {
|
||||
// split string by sep, and each segment has seglen length
|
||||
// e.g. SafeSplit("abc,def,ghi", ",", 2) => ["abc", "def,ghi"]
|
||||
|
||||
if data == "" {
|
||||
for i := 0; i < seglen; i++ {
|
||||
res = append(res, "")
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
arr := strings.Split(data, sep)
|
||||
length := len(arr)
|
||||
|
||||
if length == seglen {
|
||||
return arr
|
||||
}
|
||||
|
||||
if length < seglen {
|
||||
for i := 0; i < seglen-length; i++ {
|
||||
arr = append(arr, "")
|
||||
}
|
||||
return arr
|
||||
} else {
|
||||
offset := length - seglen
|
||||
for i := 0; i < offset; i++ {
|
||||
arr[seglen-1] += sep + arr[seglen+i]
|
||||
}
|
||||
return arr[:seglen]
|
||||
}
|
||||
}
|
||||
|
@ -11,6 +11,11 @@ import (
|
||||
var configFile = "config/config.yaml"
|
||||
var configExampleFile = "config.example.yaml"
|
||||
|
||||
var redirectRoutes = []string{
|
||||
"/v1",
|
||||
"/mj",
|
||||
}
|
||||
|
||||
func ReadConf() {
|
||||
viper.SetConfigFile(configFile)
|
||||
|
||||
@ -58,11 +63,12 @@ func RegisterStaticRoute(engine *gin.Engine) {
|
||||
c.File("./app/dist/index.html")
|
||||
})
|
||||
|
||||
// redirect /v1 to /api/v1
|
||||
engine.Any("/v1/*path", func(c *gin.Context) {
|
||||
path := c.Param("path")
|
||||
c.Redirect(301, fmt.Sprintf("/api/v1/%s", path))
|
||||
})
|
||||
for _, route := range redirectRoutes {
|
||||
engine.Any(fmt.Sprintf("%s/*path", route), func(c *gin.Context) {
|
||||
path := c.Param("path")
|
||||
c.Redirect(301, fmt.Sprintf("/api%s/%s", route, path))
|
||||
})
|
||||
}
|
||||
|
||||
fmt.Println(`[service] start serving static files from ~/app/dist`)
|
||||
}
|
||||
|
@ -86,7 +86,7 @@ func (i *Image) GetPixelColor(x int, y int) (int, int, int) {
|
||||
}
|
||||
|
||||
func (i *Image) CountTokens(model string) int {
|
||||
if globals.IsGPT41106VisionPreview(model) {
|
||||
if globals.IsOpenAIVisionModels(model) {
|
||||
// tile size is 512x512
|
||||
// the max size of image is 2048x2048
|
||||
// the image that is larger than 2048x2048 will be resized in 16 tiles
|
||||
|
Loading…
Reference in New Issue
Block a user