update model selector

This commit is contained in:
Zhang Minghan 2023-10-21 22:59:17 +08:00
parent a5e18a1ee2
commit e31d69ae18
12 changed files with 284 additions and 132 deletions

View File

@ -17,6 +17,9 @@
} }
.select-group-item { .select-group-item {
display: flex;
flex-direction: row;
align-items: center;
padding: 0.35rem 0.8rem; padding: 0.35rem 0.8rem;
border: 1px solid hsl(var(--border)); border: 1px solid hsl(var(--border));
border-radius: 4px; border-radius: 4px;
@ -34,6 +37,27 @@
background: hsl(var(--text)); background: hsl(var(--text));
border-color: hsl(var(--border-hover)); border-color: hsl(var(--border-hover));
color: hsl(var(--background)); color: hsl(var(--background));
.badge {
background: hsl(var(--background)) !important;
color: hsl(var(--text));
&:hover {
background: hsl(var(--background)) !important;
}
}
}
.badge {
user-select: none;
transition: .2s;
padding-left: 0.45rem;
padding-right: 0.45rem;
background: hsl(var(--primary)) !important;
&:hover {
background: hsl(var(--primary)) !important;
}
} }
} }
} }

View File

@ -6,7 +6,9 @@ import {
Copy, Copy,
File, File,
Loader2, Loader2,
MousePointerSquare, Power, RotateCcw, MousePointerSquare,
Power,
RotateCcw,
} from "lucide-react"; } from "lucide-react";
import { import {
ContextMenu, ContextMenu,
@ -165,21 +167,21 @@ function MessageContent({ message, end, onEvent }: MessageProps) {
<Loader2 className={`h-5 w-5 m-1 animate-spin`} /> <Loader2 className={`h-5 w-5 m-1 animate-spin`} />
)} )}
</div> </div>
{ {message.role === "assistant" && end === true && (
(message.role === "assistant" && end === true) && ( <div className={`message-toolbar`}>
<div className={`message-toolbar`}> {message.end !== false ? (
{ <RotateCcw
(message.end !== false) ? className={`h-4 w-4 m-0.5`}
<RotateCcw className={`h-4 w-4 m-0.5`} onClick={() => ( onClick={() => onEvent && onEvent("restart")}
onEvent && onEvent("restart") />
)} /> : ) : (
<Power className={`h-4 w-4 m-0.5`} onClick={() => ( <Power
onEvent && onEvent("stop") className={`h-4 w-4 m-0.5`}
)} /> onClick={() => onEvent && onEvent("stop")}
} />
</div> )}
) </div>
} )}
</div> </div>
); );
} }

View File

@ -7,13 +7,101 @@ import {
} from "./ui/select"; } from "./ui/select";
import { mobile } from "../utils.ts"; import { mobile } from "../utils.ts";
import { useEffect, useState } from "react"; import { useEffect, useState } from "react";
import { Badge } from "./ui/badge.tsx";
export type SelectItemProps = {
name: string;
value: string;
badge?: string;
tag?: any;
};
type SelectGroupProps = { type SelectGroupProps = {
current: string; current: SelectItemProps;
list: string[]; list: SelectItemProps[];
onChange?: (select: string) => void; onChange?: (select: string) => void;
maxElements?: number;
}; };
function GroupSelectItem(props: SelectItemProps) {
return (
<>
{props.value}
{props.badge && <Badge className="badge ml-1">{props.badge}</Badge>}
</>
);
}
function SelectGroupDesktop(props: SelectGroupProps) {
const max: number = props.maxElements || 5;
const range = props.list.length > max ? max : props.list.length;
const display = props.list.slice(0, range);
const hidden = props.list.slice(range);
return (
<div className={`select-group`}>
{display.map((select: SelectItemProps, idx: number) => (
<div
key={idx}
onClick={() => props.onChange?.(select.name)}
className={`select-group-item ${
select == props.current ? "active" : ""
}`}
>
<GroupSelectItem {...select} />
</div>
))}
{props.list.length > max && (
<Select
defaultValue={"..."}
value={props.current?.name || "..."}
onValueChange={(value: string) => props.onChange?.(value)}
>
<SelectTrigger
className={`w-max gap-1 select-group-item ${
hidden.includes(props.current) ? "active" : ""
}`}
>
<SelectValue asChild>
<span>
{hidden.includes(props.current) ? props.current.value : "..."}
</span>
</SelectValue>
</SelectTrigger>
<SelectContent>
{hidden.map((select: SelectItemProps, idx: number) => (
<SelectItem key={idx} value={select.name}>
<GroupSelectItem {...select} />
</SelectItem>
))}
</SelectContent>
</Select>
)}
</div>
);
}
function SelectGroupMobile(props: SelectGroupProps) {
return (
<Select
value={props.current.name}
onValueChange={(value: string) => props.onChange?.(value)}
>
<SelectTrigger className="select-group mobile">
<SelectValue placeholder={props.current.value} />
</SelectTrigger>
<SelectContent>
{props.list.map((select: SelectItemProps, idx: number) => (
<SelectItem key={idx} value={select.name}>
<GroupSelectItem {...select} />
</SelectItem>
))}
</SelectContent>
</Select>
);
}
function SelectGroup(props: SelectGroupProps) { function SelectGroup(props: SelectGroupProps) {
const [state, setState] = useState(mobile); const [state, setState] = useState(mobile);
useEffect(() => { useEffect(() => {
@ -23,35 +111,9 @@ function SelectGroup(props: SelectGroupProps) {
}, []); }, []);
return state ? ( return state ? (
<Select <SelectGroupMobile {...props} />
value={props.current}
onValueChange={(value: string) => props.onChange?.(value)}
>
<SelectTrigger className="select-group mobile">
<SelectValue placeholder={props.current} />
</SelectTrigger>
<SelectContent>
{props.list.map((select: string, idx: number) => (
<SelectItem key={idx} value={select}>
{select}
</SelectItem>
))}
</SelectContent>
</Select>
) : ( ) : (
<div className={`select-group`}> <SelectGroupDesktop {...props} />
{props.list.map((select: string, idx: number) => (
<div
key={idx}
onClick={() => props.onChange?.(select)}
className={`select-group-item ${
select == props.current ? "active" : ""
}`}
>
{select}
</div>
))}
</div>
); );
} }

View File

@ -0,0 +1,56 @@
import SelectGroup, { SelectItemProps } from "../SelectGroup.tsx";
import { supportModels } from "../../conf.ts";
import { selectModel, setModel } from "../../store/chat.ts";
import { useTranslation } from "react-i18next";
import { useDispatch, useSelector } from "react-redux";
import { selectAuthenticated } from "../../store/auth.ts";
import { useToast } from "../ui/use-toast.ts";
import { useEffect } from "react";
import { Model } from "../../conversation/types.ts";
function GetModel(name: string): Model {
return supportModels.find((model) => model.id === name) as Model;
}
function ModelSelector() {
const { t } = useTranslation();
const dispatch = useDispatch();
const { toast } = useToast();
const model = useSelector(selectModel);
const auth = useSelector(selectAuthenticated);
useEffect(() => {
if (auth && model === "GPT-3.5") dispatch(setModel("GPT-3.5-16k"));
}, [auth]);
const list = supportModels.map(
(model: Model): SelectItemProps => ({
name: model.id,
value: model.name,
badge: model.free ? "free" : undefined,
}),
);
return (
<SelectGroup
current={list.find((item) => item.name === model) as SelectItemProps}
list={list}
maxElements={6}
onChange={(value: string) => {
const model = GetModel(value);
console.debug(`[model] select model: ${model.name} (id: ${model.id})`);
if (!auth && model.auth) {
toast({
title: t("login-require"),
});
return;
}
dispatch(setModel(value));
}}
/>
);
}
export default ModelSelector;

View File

@ -1,6 +1,7 @@
import axios from "axios"; import axios from "axios";
import { Model } from "./conversation/types.ts";
export const version = "3.4.5"; export const version = "3.4.6";
export const deploy: boolean = true; export const deploy: boolean = true;
export let rest_api: string = "http://localhost:8094"; export let rest_api: string = "http://localhost:8094";
export let ws_api: string = "ws://localhost:8094"; export let ws_api: string = "ws://localhost:8094";
@ -11,19 +12,45 @@ if (deploy) {
} }
export const tokenField = deploy ? "token" : "token-dev"; export const tokenField = deploy ? "token" : "token-dev";
export const supportModels: string[] = [ export const supportModels: Model[] = [
"GPT-3.5", // openai models
"GPT-3.5-16k", { id: "gpt-3.5-turbo", name: "GPT-3.5", free: true, auth: false },
"GPT-4", { id: "gpt-3.5-turbo-16k", name: "GPT-3.5-16k", free: true, auth: true },
"GPT-4-32k", { id: "gpt-4", name: "GPT-4", free: false, auth: true },
"Claude-2", { id: "gpt-4-32k", name: "GPT-4-32k", free: false, auth: true },
"Claude-2-100k",
"SparkDesk 讯飞星火", // anthropic models
"Palm2", { id: "claude-1", name: "Claude-2", free: true, auth: false },
"New Bing", { id: "claude-2", name: "Claude-2-100k", free: false, auth: true }, // not claude-2-100k
"智谱 ChatGLM Pro",
"智谱 ChatGLM Std", // spark desk
"智谱 ChatGLM Lite", { id: "spark-desk", name: "SparkDesk 讯飞星火", free: false, auth: true },
// google palm2
{ id: "chat-bison-001", name: "Palm2", free: true, auth: true },
// new bing
{ id: "bing-creative", name: "New Bing", free: true, auth: true },
// zhipu models
{
id: "zhipu-chatglm-pro",
name: "智谱 ChatGLM Pro",
free: false,
auth: true,
},
{
id: "zhipu-chatglm-std",
name: "智谱 ChatGLM Std",
free: false,
auth: true,
},
{
id: "zhipu-chatglm-lite",
name: "智谱 ChatGLM Lite",
free: true,
auth: true,
},
]; ];
export const supportModelConvertor: Record<string, string> = { export const supportModelConvertor: Record<string, string> = {

View File

@ -71,10 +71,11 @@ export class Connection {
}, 500); }, 500);
} }
} catch { } catch {
if (t !== undefined) this.triggerCallback({ if (t !== undefined)
message: t("request-failed"), this.triggerCallback({
end: true, message: t("request-failed"),
}); end: true,
});
} }
} }

View File

@ -1,7 +1,7 @@
import { ChatProps, Connection, StreamMessage } from "./connection.ts"; import { ChatProps, Connection, StreamMessage } from "./connection.ts";
import { Message } from "./types.ts"; import { Message } from "./types.ts";
import { sharingEvent } from "../events/sharing.ts"; import { sharingEvent } from "../events/sharing.ts";
import {connectionEvent} from "../events/connection.ts"; import { connectionEvent } from "../events/connection.ts";
type ConversationCallback = (idx: number, message: Message[]) => void; type ConversationCallback = (idx: number, message: Message[]) => void;
@ -34,7 +34,9 @@ export class Conversation {
connectionEvent.addEventListener((ev) => { connectionEvent.addEventListener((ev) => {
if (ev.id === this.id) { if (ev.id === this.id) {
console.debug(`[conversation] connection event (id: ${this.id}, event: ${ev.event})`); console.debug(
`[conversation] connection event (id: ${this.id}, event: ${ev.event})`,
);
switch (ev.event) { switch (ev.event) {
case "stop": case "stop":
@ -52,10 +54,12 @@ export class Conversation {
break; break;
default: default:
console.debug(`[conversation] unknown event: ${ev.event} (from: ${ev.id})`); console.debug(
`[conversation] unknown event: ${ev.event} (from: ${ev.id})`,
);
} }
} }
}) });
} }
protected sendEvent(event: string, data?: string) { protected sendEvent(event: string, data?: string) {

View File

@ -8,6 +8,13 @@ export type Message = {
end?: boolean; end?: boolean;
}; };
export type Model = {
id: string;
name: string;
free: boolean;
auth: boolean;
};
export type Id = number; export type Id = number;
export type ConversationInstance = { export type ConversationInstance = {

View File

@ -1,4 +1,4 @@
import {EventCommitter} from "./struct.ts"; import { EventCommitter } from "./struct.ts";
export type ConnectionEvent = { export type ConnectionEvent = {
id: number; id: number;

View File

@ -1,18 +1,17 @@
import "../assets/generation.less"; import "../assets/generation.less";
import { useDispatch, useSelector } from "react-redux"; import { useSelector } from "react-redux";
import { selectAuthenticated } from "../store/auth.ts";
import { useTranslation } from "react-i18next"; import { useTranslation } from "react-i18next";
import { Button } from "../components/ui/button.tsx"; import { Button } from "../components/ui/button.tsx";
import { ChevronLeft, Cloud, FileDown, Send } from "lucide-react"; import { ChevronLeft, Cloud, FileDown, Send } from "lucide-react";
import { rest_api, supportModelConvertor, supportModels } from "../conf.ts"; import { rest_api, supportModelConvertor } from "../conf.ts";
import router from "../router.ts"; import router from "../router.ts";
import { Input } from "../components/ui/input.tsx"; import { Input } from "../components/ui/input.tsx";
import { useEffect, useRef, useState } from "react"; import { useEffect, useRef, useState } from "react";
import SelectGroup from "../components/SelectGroup.tsx";
import { manager } from "../conversation/generation.ts"; import { manager } from "../conversation/generation.ts";
import { useToast } from "../components/ui/use-toast.ts"; import { useToast } from "../components/ui/use-toast.ts";
import { handleGenerationData } from "../utils.ts"; import { handleGenerationData } from "../utils.ts";
import { selectModel, setModel } from "../store/chat.ts"; import { selectModel } from "../store/chat.ts";
import ModelSelector from "../components/home/ModelSelector.tsx";
type WrapperProps = { type WrapperProps = {
onSend?: (value: string, model: string) => boolean; onSend?: (value: string, model: string) => boolean;
@ -20,7 +19,6 @@ type WrapperProps = {
function Wrapper({ onSend }: WrapperProps) { function Wrapper({ onSend }: WrapperProps) {
const { t } = useTranslation(); const { t } = useTranslation();
const dispatch = useDispatch();
const ref = useRef(null); const ref = useRef(null);
const [stayed, setStayed] = useState<boolean>(false); const [stayed, setStayed] = useState<boolean>(false);
const [hash, setHash] = useState<string>(""); const [hash, setHash] = useState<string>("");
@ -28,14 +26,9 @@ function Wrapper({ onSend }: WrapperProps) {
const [quota, setQuota] = useState<number>(0); const [quota, setQuota] = useState<number>(0);
const model = useSelector(selectModel); const model = useSelector(selectModel);
const modelRef = useRef(model); const modelRef = useRef(model);
const auth = useSelector(selectAuthenticated);
const { toast } = useToast(); const { toast } = useToast();
useEffect(() => {
if (auth && model === "GPT-3.5") dispatch(setModel("GPT-3.5-16k"));
}, [auth]);
function clear() { function clear() {
setData(""); setData("");
setQuota(0); setQuota(0);
@ -154,19 +147,7 @@ function Wrapper({ onSend }: WrapperProps) {
</Button> </Button>
</div> </div>
<div className={`model-box`}> <div className={`model-box`}>
<SelectGroup <ModelSelector />
current={model}
list={supportModels}
onChange={(value: string) => {
if (!auth && value !== "GPT-3.5") {
toast({
title: t("login-require"),
});
return;
}
dispatch(setModel(value));
}}
/>
</div> </div>
</div> </div>
); );

View File

@ -22,7 +22,7 @@ import {
import { useDispatch, useSelector } from "react-redux"; import { useDispatch, useSelector } from "react-redux";
import type { RootState } from "../store"; import type { RootState } from "../store";
import { selectAuthenticated, selectInit } from "../store/auth.ts"; import { selectAuthenticated, selectInit } from "../store/auth.ts";
import { login, supportModels } from "../conf.ts"; import { login } from "../conf.ts";
import { import {
deleteConversation, deleteConversation,
toggleConversation, toggleConversation,
@ -39,7 +39,7 @@ import {
useEffectAsync, useEffectAsync,
copyClipboard, copyClipboard,
} from "../utils.ts"; } from "../utils.ts";
import { toast, useToast } from "../components/ui/use-toast.ts"; import { useToast } from "../components/ui/use-toast.ts";
import { ConversationInstance, Message } from "../conversation/types.ts"; import { ConversationInstance, Message } from "../conversation/types.ts";
import { import {
selectCurrent, selectCurrent,
@ -47,7 +47,6 @@ import {
selectHistory, selectHistory,
selectMessages, selectMessages,
selectWeb, selectWeb,
setModel,
setWeb, setWeb,
} from "../store/chat.ts"; } from "../store/chat.ts";
import { import {
@ -66,10 +65,10 @@ import MessageSegment from "../components/Message.tsx";
import { setMenu } from "../store/menu.ts"; import { setMenu } from "../store/menu.ts";
import FileProvider, { FileObject } from "../components/FileProvider.tsx"; import FileProvider, { FileObject } from "../components/FileProvider.tsx";
import router from "../router.ts"; import router from "../router.ts";
import SelectGroup from "../components/SelectGroup.tsx";
import EditorProvider from "../components/EditorProvider.tsx"; import EditorProvider from "../components/EditorProvider.tsx";
import ConversationSegment from "../components/home/ConversationSegment.tsx"; import ConversationSegment from "../components/home/ConversationSegment.tsx";
import {connectionEvent} from "../events/connection.ts"; import { connectionEvent } from "../events/connection.ts";
import ModelSelector from "../components/home/ModelSelector.tsx";
function SideBar() { function SideBar() {
const { t } = useTranslation(); const { t } = useTranslation();
@ -353,21 +352,19 @@ function ChatInterface() {
</Button> </Button>
</div> </div>
{ {messages.map((message, i) => (
messages.map((message, i) => <MessageSegment
<MessageSegment message={message}
message={message} end={i === messages.length - 1}
end={i === messages.length - 1} onEvent={(e: string) => {
onEvent={(e: string) => { connectionEvent.emit({
connectionEvent.emit({ id: current,
id: current, event: e,
event: e, });
}); }}
}} key={i}
key={i} />
/> ))}
)
}
</div> </div>
</> </>
); );
@ -390,10 +387,6 @@ function ChatWrapper() {
const target = useRef(null); const target = useRef(null);
manager.setDispatch(dispatch); manager.setDispatch(dispatch);
useEffect(() => {
if (auth && model === "GPT-3.5") dispatch(setModel("GPT-3.5-16k"));
}, [auth]);
function clearFile() { function clearFile() {
clearEvent?.(); clearEvent?.();
} }
@ -522,19 +515,7 @@ function ChatWrapper() {
</Button> </Button>
</div> </div>
<div className={`input-options`}> <div className={`input-options`}>
<SelectGroup <ModelSelector />
current={model}
list={supportModels}
onChange={(model: string) => {
if (!auth && model !== "GPT-3.5") {
toast({
title: t("login-require"),
});
return;
}
dispatch(setModel(model));
}}
/>
</div> </div>
</div> </div>
</div> </div>

View File

@ -1,8 +1,9 @@
import { createSlice } from "@reduxjs/toolkit"; import { createSlice } from "@reduxjs/toolkit";
import { ConversationInstance } from "../conversation/types.ts"; import {ConversationInstance, Model} from "../conversation/types.ts";
import { Message } from "../conversation/types.ts"; import { Message } from "../conversation/types.ts";
import { insertStart } from "../utils.ts"; import { insertStart } from "../utils.ts";
import { RootState } from "./index.ts"; import { RootState } from "./index.ts";
import { supportModels } from "../conf.ts";
type initialStateType = { type initialStateType = {
history: ConversationInstance[]; history: ConversationInstance[];
@ -12,12 +13,17 @@ type initialStateType = {
current: number; current: number;
}; };
function GetModel(model: string | undefined | null): string {
return model && supportModels.filter((item: Model) => item.id === model).length
? model : supportModels[0].id;
}
const chatSlice = createSlice({ const chatSlice = createSlice({
name: "chat", name: "chat",
initialState: { initialState: {
history: [], history: [],
messages: [], messages: [],
model: "GPT-3.5", model: GetModel(localStorage.getItem("model")),
web: false, web: false,
current: -1, current: -1,
} as initialStateType, } as initialStateType,
@ -44,6 +50,7 @@ const chatSlice = createSlice({
state.messages = action.payload as Message[]; state.messages = action.payload as Message[];
}, },
setModel: (state, action) => { setModel: (state, action) => {
localStorage.setItem("model", action.payload as string);
state.model = action.payload as string; state.model = action.payload as string;
}, },
setWeb: (state, action) => { setWeb: (state, action) => {