diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 51fe74fe7..75120041c 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -1,17 +1,18 @@ import { useDebouncedCallback } from "use-debounce"; import React, { - useState, - useRef, - useEffect, - useMemo, - useCallback, Fragment, RefObject, + useCallback, + useEffect, + useMemo, + useRef, + useState, } from "react"; import SendWhiteIcon from "../icons/send-white.svg"; import BrainIcon from "../icons/brain.svg"; import RenameIcon from "../icons/rename.svg"; +import EditIcon from "../icons/rename.svg"; import ExportIcon from "../icons/share.svg"; import ReturnIcon from "../icons/return.svg"; import CopyIcon from "../icons/copy.svg"; @@ -24,11 +25,11 @@ import MaskIcon from "../icons/mask.svg"; import MaxIcon from "../icons/max.svg"; import MinIcon from "../icons/min.svg"; import ResetIcon from "../icons/reload.svg"; +import ReloadIcon from "../icons/reload.svg"; import BreakIcon from "../icons/break.svg"; import SettingsIcon from "../icons/chat-settings.svg"; import DeleteIcon from "../icons/clear.svg"; import PinIcon from "../icons/pin.svg"; -import EditIcon from "../icons/rename.svg"; import ConfirmIcon from "../icons/confirm.svg"; import CloseIcon from "../icons/close.svg"; import CancelIcon from "../icons/cancel.svg"; @@ -45,33 +46,32 @@ import QualityIcon from "../icons/hd.svg"; import StyleIcon from "../icons/palette.svg"; import PluginIcon from "../icons/plugin.svg"; import ShortcutkeyIcon from "../icons/shortcutkey.svg"; -import ReloadIcon from "../icons/reload.svg"; import HeadphoneIcon from "../icons/headphone.svg"; import { - ChatMessage, - SubmitKey, - useChatStore, BOT_HELLO, + ChatMessage, createMessage, - useAccessStore, - Theme, - useAppConfig, DEFAULT_TOPIC, ModelType, + SubmitKey, + Theme, + useAccessStore, + useAppConfig, + useChatStore, usePluginStore, } from "../store"; import { - copyToClipboard, - selectOrCopy, autoGrowTextArea, - useMobileScreen, - getMessageTextContent, + copyToClipboard, getMessageImages, - isVisionModel, + getMessageTextContent, isDalle3, - showPlugins, + isVisionModel, safeLocalStorage, + selectOrCopy, + showPlugins, + useMobileScreen, } from "../utils"; import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; @@ -79,7 +79,7 @@ import { uploadImage as uploadImageRemote } from "@/app/utils/chat"; import dynamic from "next/dynamic"; import { ChatControllerPool } from "../client/controller"; -import { DalleSize, DalleQuality, DalleStyle } from "../typing"; +import { DalleQuality, DalleSize, DalleStyle } from "../typing"; import { Prompt, usePromptStore } from "../store/prompt"; import Locale from "../locales"; @@ -102,8 +102,8 @@ import { ModelProvider, Path, REQUEST_TIMEOUT_MS, - UNFINISHED_INPUT, ServiceProvider, + UNFINISHED_INPUT, } from "../constant"; import { Avatar } from "./emoji"; import { ContextPrompts, MaskAvatar, MaskConfig } from "./mask"; @@ -113,9 +113,7 @@ import { prettyObject } from "../utils/format"; import { ExportMessageModal } from "./exporter"; import { getClientConfig } from "../config/client"; import { useAllModels } from "../utils/hooks"; -import { MultimodalContent } from "../client/api"; - -import { ClientApi } from "../client/api"; +import { ClientApi, MultimodalContent } from "../client/api"; import { createTTSPlayer } from "../utils/audio"; import { MsEdgeTTS, OUTPUT_FORMAT } from "../utils/ms_edge_tts"; @@ -427,6 +425,7 @@ function useScrollToBottom( // for auto-scroll const [autoScroll, setAutoScroll] = useState(true); + function scrollDomToBottom() { const dom = scrollRef.current; if (dom) { @@ -473,6 +472,7 @@ export function ChatActions(props: { // switch themes const theme = config.theme; + function nextTheme() { const themes = [Theme.Auto, Theme.Light, Theme.Dark]; const themeIndex = themes.indexOf(theme); @@ -1237,6 +1237,7 @@ function _Chat() { const accessStore = useAccessStore(); const [speechStatus, setSpeechStatus] = useState(false); const [speechLoading, setSpeechLoading] = useState(false); + async function openaiSpeech(text: string) { if (speechStatus) { ttsPlayer.stop(); @@ -1336,6 +1337,7 @@ function _Chat() { const [msgRenderIndex, _setMsgRenderIndex] = useState( Math.max(0, renderMessages.length - CHAT_PAGE_SIZE), ); + function setMsgRenderIndex(newIndex: number) { newIndex = Math.min(renderMessages.length - CHAT_PAGE_SIZE, newIndex); newIndex = Math.max(0, newIndex); @@ -1371,6 +1373,7 @@ function _Chat() { setHitBottom(isHitBottom); setAutoScroll(isHitBottom); }; + function scrollToBottom() { setMsgRenderIndex(renderMessages.length - CHAT_PAGE_SIZE); scrollDomToBottom(); @@ -1712,252 +1715,264 @@ function _Chat() { setAutoScroll(false); }} > - {messages.map((message, i) => { - const isUser = message.role === "user"; - const isContext = i < context.length; - const showActions = - i > 0 && - !(message.preview || message.content.length === 0) && - !isContext; - const showTyping = message.preview || message.streaming; + {messages + // TODO + // .filter((m) => !m.isMcpResponse) + .map((message, i) => { + const isUser = message.role === "user"; + const isContext = i < context.length; + const showActions = + i > 0 && + !(message.preview || message.content.length === 0) && + !isContext; + const showTyping = message.preview || message.streaming; - const shouldShowClearContextDivider = - i === clearContextIndex - 1; + const shouldShowClearContextDivider = + i === clearContextIndex - 1; - return ( - -
-
-
-
-
- } - aria={Locale.Chat.Actions.Edit} - onClick={async () => { - const newMessage = await showPrompt( - Locale.Chat.Actions.Edit, - getMessageTextContent(message), - 10, - ); - let newContent: string | MultimodalContent[] = - newMessage; - const images = getMessageImages(message); - if (images.length > 0) { - newContent = [ - { type: "text", text: newMessage }, - ]; - for (let i = 0; i < images.length; i++) { - newContent.push({ - type: "image_url", - image_url: { - url: images[i], - }, - }); - } - } - chatStore.updateTargetSession( - session, - (session) => { - const m = session.mask.context - .concat(session.messages) - .find((m) => m.id === message.id); - if (m) { - m.content = newContent; + return ( + +
+
+
+
+
+ } + aria={Locale.Chat.Actions.Edit} + onClick={async () => { + const newMessage = await showPrompt( + Locale.Chat.Actions.Edit, + getMessageTextContent(message), + 10, + ); + let newContent: + | string + | MultimodalContent[] = newMessage; + const images = getMessageImages(message); + if (images.length > 0) { + newContent = [ + { type: "text", text: newMessage }, + ]; + for (let i = 0; i < images.length; i++) { + newContent.push({ + type: "image_url", + image_url: { + url: images[i], + }, + }); } - }, - ); - }} - > -
- {isUser ? ( - - ) : ( - <> - {["system"].includes(message.role) ? ( - - ) : ( - - )} - + chatStore.updateTargetSession( + session, + (session) => { + const m = session.mask.context + .concat(session.messages) + .find((m) => m.id === message.id); + if (m) { + m.content = newContent; + } + }, + ); + }} + > +
+ {isUser ? ( + + ) : ( + <> + {["system"].includes(message.role) ? ( + + ) : ( + + )} + + )} +
+ {!isUser && ( +
+ {message.model} +
)} -
- {!isUser && ( -
- {message.model} -
- )} - {showActions && ( -
-
- {message.streaming ? ( - } - onClick={() => onUserStop(message.id ?? i)} - /> - ) : ( - <> + {showActions && ( +
+
+ {message.streaming ? ( } - onClick={() => onResend(message)} - /> - - } - onClick={() => onDelete(message.id ?? i)} - /> - - } - onClick={() => onPinMessage(message)} - /> - } + text={Locale.Chat.Actions.Stop} + icon={} onClick={() => - copyToClipboard( - getMessageTextContent(message), - ) + onUserStop(message.id ?? i) } /> - {config.ttsConfig.enable && ( + ) : ( + <> - ) : ( - - ) - } + text={Locale.Chat.Actions.Retry} + icon={} + onClick={() => onResend(message)} + /> + + } onClick={() => - openaiSpeech( + onDelete(message.id ?? i) + } + /> + + } + onClick={() => onPinMessage(message)} + /> + } + onClick={() => + copyToClipboard( getMessageTextContent(message), ) } /> - )} - - )} + {config.ttsConfig.enable && ( + + ) : ( + + ) + } + onClick={() => + openaiSpeech( + getMessageTextContent(message), + ) + } + /> + )} + + )} +
+ )} +
+ {message?.tools?.length == 0 && showTyping && ( +
+ {Locale.Chat.Typing}
)} -
- {message?.tools?.length == 0 && showTyping && ( -
- {Locale.Chat.Typing} -
- )} - {/*@ts-ignore*/} - {message?.tools?.length > 0 && ( -
- {message?.tools?.map((tool) => ( -
- {tool.isError === false ? ( - - ) : tool.isError === true ? ( - - ) : ( - - )} - {tool?.function?.name} -
- ))} -
- )} -
- onRightClick(e, message)} // hard to use - onDoubleClickCapture={() => { - if (!isMobileScreen) return; - setUserInput(getMessageTextContent(message)); - }} - fontSize={fontSize} - fontFamily={fontFamily} - parentRef={scrollRef} - defaultShow={i >= messages.length - 6} - /> - {getMessageImages(message).length == 1 && ( - + {/*@ts-ignore*/} + {message?.tools?.length > 0 && ( +
+ {message?.tools?.map((tool) => ( +
+ {tool.isError === false ? ( + + ) : tool.isError === true ? ( + + ) : ( + + )} + {tool?.function?.name} +
+ ))} +
)} - {getMessageImages(message).length > 1 && ( -
+ - {getMessageImages(message).map((image, index) => { - return ( - - ); - })} + // onContextMenu={(e) => onRightClick(e, message)} // hard to use + onDoubleClickCapture={() => { + if (!isMobileScreen) return; + setUserInput(getMessageTextContent(message)); + }} + fontSize={fontSize} + fontFamily={fontFamily} + parentRef={scrollRef} + defaultShow={i >= messages.length - 6} + /> + {getMessageImages(message).length == 1 && ( + + )} + {getMessageImages(message).length > 1 && ( +
+ {getMessageImages(message).map( + (image, index) => { + return ( + + ); + }, + )} +
+ )} +
+ {message?.audio_url && ( +
+
)} -
- {message?.audio_url && ( -
-
- )} -
- {isContext - ? Locale.Chat.IsContext - : message.date.toLocaleString()} +
+ {isContext + ? Locale.Chat.IsContext + : message.date.toLocaleString()} +
-
- {shouldShowClearContextDivider && } - - ); - })} + {shouldShowClearContextDivider && } + + ); + })}
(); +const clientsMap = new Map< + string, + { client: Client; primitives: Primitive[] } +>(); // Whether initialized let initialized = false; @@ -30,8 +38,11 @@ export async function initializeMcpClients() { try { logger.info(`Initializing MCP client: ${clientId}`); const client = await createClient(config, clientId); - clientsMap.set(clientId, client); - logger.success(`Client ${clientId} initialized`); + const primitives = await listPrimitives(client); + clientsMap.set(clientId, { client, primitives }); + logger.success( + `Client [${clientId}] initialized, ${primitives.length} primitives supported`, + ); } catch (error) { errorClients.push(clientId); logger.error(`Failed to initialize client ${clientId}: ${error}`); @@ -58,7 +69,7 @@ export async function executeMcpAction( ) { try { // Find the corresponding client - const client = clientsMap.get(clientId); + const client = clientsMap.get(clientId)?.client; if (!client) { logger.error(`Client ${clientId} not found`); return; @@ -80,3 +91,16 @@ export async function getAvailableClients() { (clientId) => !errorClients.includes(clientId), ); } + +// Get all primitives from all clients +export async function getAllPrimitives(): Promise< + { + clientId: string; + primitives: Primitive[]; + }[] +> { + return Array.from(clientsMap.entries()).map(([clientId, { primitives }]) => ({ + clientId, + primitives, + })); +} diff --git a/app/mcp/client.ts b/app/mcp/client.ts index 0600f00be..6650f9e2b 100644 --- a/app/mcp/client.ts +++ b/app/mcp/client.ts @@ -40,13 +40,13 @@ export async function createClient( return client; } -interface Primitive { +export interface Primitive { type: "resource" | "tool" | "prompt"; value: any; } /** List all resources, tools, and prompts */ -export async function listPrimitives(client: Client) { +export async function listPrimitives(client: Client): Promise { const capabilities = client.getServerCapabilities(); const primitives: Primitive[] = []; const promises = []; diff --git a/app/mcp/example.ts b/app/mcp/example.ts index 83fc8784c..f3b91fb8c 100644 --- a/app/mcp/example.ts +++ b/app/mcp/example.ts @@ -4,25 +4,25 @@ import conf from "./mcp_config.json"; const logger = new MCPClientLogger("MCP Server Example", true); -async function main() { - logger.info("Connecting to server..."); +const TEST_SERVER = "everything"; - const client = await createClient(conf.mcpServers.everything, "everything"); +async function main() { + logger.info(`All MCP servers: ${Object.keys(conf.mcpServers).join(", ")}`); + + logger.info(`Connecting to server ${TEST_SERVER}...`); + + const client = await createClient(conf.mcpServers[TEST_SERVER], TEST_SERVER); const primitives = await listPrimitives(client); - logger.success(`Connected to server everything`); + logger.success(`Connected to server ${TEST_SERVER}`); logger.info( - `server capabilities: ${Object.keys( - client.getServerCapabilities() ?? [], - ).join(", ")}`, + `${TEST_SERVER} supported primitives:\n${JSON.stringify( + primitives.filter((i) => i.type === "tool"), + null, + 2, + )}`, ); - - logger.info("Server supports the following primitives:"); - - primitives.forEach((primitive) => { - logger.info("\n" + JSON.stringify(primitive, null, 2)); - }); } main().catch((error) => { diff --git a/app/store/chat.ts b/app/store/chat.ts index e0ee95621..80c706ffd 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -21,6 +21,8 @@ import { DEFAULT_SYSTEM_TEMPLATE, GEMINI_SUMMARIZE_MODEL, KnowledgeCutOffDate, + MCP_PRIMITIVES_TEMPLATE, + MCP_SYSTEM_TEMPLATE, ServiceProvider, StoreKey, SUMMARIZE_MODEL, @@ -33,7 +35,7 @@ import { ModelConfig, ModelType, useAppConfig } from "./config"; import { useAccessStore } from "./access"; import { collectModelsWithDefaultModel } from "../utils/model"; import { createEmptyMask, Mask } from "./mask"; -import { executeMcpAction } from "../mcp/actions"; +import { executeMcpAction, getAllPrimitives } from "../mcp/actions"; import { extractMcpJson, isMcpJson } from "../mcp/utils"; const localStorage = safeLocalStorage(); @@ -196,6 +198,24 @@ function fillTemplateWith(input: string, modelConfig: ModelConfig) { return output; } +async function getMcpSystemPrompt(): Promise { + let primitives = await getAllPrimitives(); + primitives = primitives.filter((i) => + i.primitives.some((p) => p.type === "tool"), + ); + let primitivesString = ""; + primitives.forEach((i) => { + primitivesString += MCP_PRIMITIVES_TEMPLATE.replace( + "{{ clientId }}", + i.clientId, + ).replace( + "{{ primitives }}", + i.primitives.map((p) => JSON.stringify(p)).join("\n"), + ); + }); + return MCP_SYSTEM_TEMPLATE.replace("{{ MCP_PRIMITIVES }}", primitivesString); +} + const DEFAULT_CHAT_STATE = { sessions: [createEmptySession()], currentSessionIndex: 0, @@ -409,7 +429,7 @@ export const useChatStore = createPersistStore( }); // get recent messages - const recentMessages = get().getMessagesWithMemory(); + const recentMessages = await get().getMessagesWithMemory(); const sendMessages = recentMessages.concat(userMessage); const messageIndex = session.messages.length + 1; @@ -508,7 +528,7 @@ export const useChatStore = createPersistStore( } }, - getMessagesWithMemory() { + async getMessagesWithMemory() { const session = get().currentSession(); const modelConfig = session.mask.modelConfig; const clearContextIndex = session.clearContextIndex ?? 0; @@ -524,18 +544,26 @@ export const useChatStore = createPersistStore( (session.mask.modelConfig.model.startsWith("gpt-") || session.mask.modelConfig.model.startsWith("chatgpt-")); + const mcpSystemPrompt = await getMcpSystemPrompt(); + var systemPrompts: ChatMessage[] = []; systemPrompts = shouldInjectSystemPrompts ? [ createMessage({ role: "system", - content: fillTemplateWith("", { - ...modelConfig, - template: DEFAULT_SYSTEM_TEMPLATE, - }), + content: + fillTemplateWith("", { + ...modelConfig, + template: DEFAULT_SYSTEM_TEMPLATE, + }) + mcpSystemPrompt, }), ] - : []; + : [ + createMessage({ + role: "system", + content: mcpSystemPrompt, + }), + ]; if (shouldInjectSystemPrompts) { console.log( "[Global System Prompt] ", @@ -796,12 +824,12 @@ export const useChatStore = createPersistStore( ? JSON.stringify(result) : String(result); get().onUserInput( - `\`\`\`json:mcp:${mcpRequest.clientId}\n${mcpResponse}\n\`\`\``, + `\`\`\`json:mcp-response:${mcpRequest.clientId}\n${mcpResponse}\n\`\`\``, [], true, ); }) - .catch((error) => showToast(String(error))); + .catch((error) => showToast("MCP execution failed", error)); } } catch (error) { console.error("[MCP Error]", error);