feat: MCP message type

This commit is contained in:
Kadxy 2024-12-29 09:24:52 +08:00
parent e1ba8f1b0f
commit fe67f79050
4 changed files with 103 additions and 21 deletions

View File

@ -3,8 +3,9 @@
import { createClient, executeRequest } from "./client"; import { createClient, executeRequest } from "./client";
import { MCPClientLogger } from "./logger"; import { MCPClientLogger } from "./logger";
import conf from "./mcp_config.json"; import conf from "./mcp_config.json";
import { McpRequestMessage } from "./types";
const logger = new MCPClientLogger("MCP Server"); const logger = new MCPClientLogger("MCP Actions");
// Use Map to store all clients // Use Map to store all clients
const clientsMap = new Map<string, any>(); const clientsMap = new Map<string, any>();
@ -51,7 +52,10 @@ export async function initializeMcpClients() {
} }
// Execute MCP request // Execute MCP request
export async function executeMcpAction(clientId: string, request: any) { export async function executeMcpAction(
clientId: string,
request: McpRequestMessage,
) {
try { try {
// Find the corresponding client // Find the corresponding client
const client = clientsMap.get(clientId); const client = clientsMap.get(clientId);
@ -61,6 +65,7 @@ export async function executeMcpAction(clientId: string, request: any) {
} }
logger.info(`Executing MCP request for ${clientId}`); logger.info(`Executing MCP request for ${clientId}`);
// Execute request and return result // Execute request and return result
return await executeRequest(client, request); return await executeRequest(client, request);
} catch (error) { } catch (error) {

View File

@ -1,6 +1,7 @@
import { Client } from "@modelcontextprotocol/sdk/client/index.js"; import { Client } from "@modelcontextprotocol/sdk/client/index.js";
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js"; import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
import { MCPClientLogger } from "./logger"; import { MCPClientLogger } from "./logger";
import { McpRequestMessage } from "./types";
import { z } from "zod"; import { z } from "zod";
export interface ServerConfig { export interface ServerConfig {
@ -79,6 +80,9 @@ export async function listPrimitives(client: Client) {
} }
/** Execute a request */ /** Execute a request */
export async function executeRequest(client: Client, request: any) { export async function executeRequest(
client: Client,
request: McpRequestMessage,
) {
return client.request(request, z.any()); return client.request(request, z.any());
} }

61
app/mcp/types.ts Normal file
View File

@ -0,0 +1,61 @@
// ref: https://spec.modelcontextprotocol.io/specification/basic/messages/
import { z } from "zod";
export interface McpRequestMessage {
jsonrpc?: "2.0";
id?: string | number;
method: "tools/call" | string;
params?: {
[key: string]: unknown;
};
}
export const McpRequestMessageSchema: z.ZodType<McpRequestMessage> = z.object({
jsonrpc: z.literal("2.0").optional(),
id: z.union([z.string(), z.number()]).optional(),
method: z.string(),
params: z.record(z.unknown()).optional(),
});
export interface McpResponseMessage {
jsonrpc?: "2.0";
id?: string | number;
result?: {
[key: string]: unknown;
};
error?: {
code: number;
message: string;
data?: unknown;
};
}
export const McpResponseMessageSchema: z.ZodType<McpResponseMessage> = z.object(
{
jsonrpc: z.literal("2.0").optional(),
id: z.union([z.string(), z.number()]).optional(),
result: z.record(z.unknown()).optional(),
error: z
.object({
code: z.number(),
message: z.string(),
data: z.unknown().optional(),
})
.optional(),
},
);
export interface McpNotifications {
jsonrpc?: "2.0";
method: string;
params?: {
[key: string]: unknown;
};
}
export const McpNotificationsSchema: z.ZodType<McpNotifications> = z.object({
jsonrpc: z.literal("2.0").optional(),
method: z.string(),
params: z.record(z.unknown()).optional(),
});

View File

@ -1,4 +1,9 @@
import { getMessageTextContent, trimTopic } from "../utils"; import {
getMessageTextContent,
isDalle3,
safeLocalStorage,
trimTopic,
} from "../utils";
import { indexedDBStorage } from "@/app/utils/indexedDB-storage"; import { indexedDBStorage } from "@/app/utils/indexedDB-storage";
import { nanoid } from "nanoid"; import { nanoid } from "nanoid";
@ -14,14 +19,13 @@ import {
DEFAULT_INPUT_TEMPLATE, DEFAULT_INPUT_TEMPLATE,
DEFAULT_MODELS, DEFAULT_MODELS,
DEFAULT_SYSTEM_TEMPLATE, DEFAULT_SYSTEM_TEMPLATE,
GEMINI_SUMMARIZE_MODEL,
KnowledgeCutOffDate, KnowledgeCutOffDate,
ServiceProvider,
StoreKey, StoreKey,
SUMMARIZE_MODEL, SUMMARIZE_MODEL,
GEMINI_SUMMARIZE_MODEL,
ServiceProvider,
} from "../constant"; } from "../constant";
import Locale, { getLang } from "../locales"; import Locale, { getLang } from "../locales";
import { isDalle3, safeLocalStorage } from "../utils";
import { prettyObject } from "../utils/format"; import { prettyObject } from "../utils/format";
import { createPersistStore } from "../utils/store"; import { createPersistStore } from "../utils/store";
import { estimateTokenLength } from "../utils/token"; import { estimateTokenLength } from "../utils/token";
@ -55,6 +59,7 @@ export type ChatMessage = RequestMessage & {
model?: ModelType; model?: ModelType;
tools?: ChatMessageTool[]; tools?: ChatMessageTool[];
audio_url?: string; audio_url?: string;
isMcpResponse?: boolean;
}; };
export function createMessage(override: Partial<ChatMessage>): ChatMessage { export function createMessage(override: Partial<ChatMessage>): ChatMessage {
@ -368,20 +373,22 @@ export const useChatStore = createPersistStore(
get().summarizeSession(false, targetSession); get().summarizeSession(false, targetSession);
}, },
async onUserInput(content: string, attachImages?: string[]) { async onUserInput(
content: string,
attachImages?: string[],
isMcpResponse?: boolean,
) {
const session = get().currentSession(); const session = get().currentSession();
const modelConfig = session.mask.modelConfig; const modelConfig = session.mask.modelConfig;
const userContent = fillTemplateWith(content, modelConfig); // MCP Response no need to fill template
console.log("[User Input] after template: ", userContent); let mContent: string | MultimodalContent[] = isMcpResponse
? content
: fillTemplateWith(content, modelConfig);
let mContent: string | MultimodalContent[] = userContent; if (!isMcpResponse && attachImages && attachImages.length > 0) {
if (attachImages && attachImages.length > 0) {
mContent = [ mContent = [
...(userContent ...(content ? [{ type: "text" as const, text: content }] : []),
? [{ type: "text" as const, text: userContent }]
: []),
...attachImages.map((url) => ({ ...attachImages.map((url) => ({
type: "image_url" as const, type: "image_url" as const,
image_url: { url }, image_url: { url },
@ -392,6 +399,7 @@ export const useChatStore = createPersistStore(
let userMessage: ChatMessage = createMessage({ let userMessage: ChatMessage = createMessage({
role: "user", role: "user",
content: mContent, content: mContent,
isMcpResponse,
}); });
const botMessage: ChatMessage = createMessage({ const botMessage: ChatMessage = createMessage({
@ -770,9 +778,10 @@ export const useChatStore = createPersistStore(
lastInput, lastInput,
}); });
}, },
/** check if the message contains MCP JSON and execute the MCP action */
checkMcpJson(message: ChatMessage) { checkMcpJson(message: ChatMessage) {
const content = const content = getMessageTextContent(message);
typeof message.content === "string" ? message.content : "";
if (isMcpJson(content)) { if (isMcpJson(content)) {
try { try {
const mcpRequest = extractMcpJson(content); const mcpRequest = extractMcpJson(content);
@ -782,11 +791,14 @@ export const useChatStore = createPersistStore(
executeMcpAction(mcpRequest.clientId, mcpRequest.mcp) executeMcpAction(mcpRequest.clientId, mcpRequest.mcp)
.then((result) => { .then((result) => {
console.log("[MCP Response]", result); console.log("[MCP Response]", result);
// 直接使用onUserInput发送结果 const mcpResponse =
get().onUserInput(
typeof result === "object" typeof result === "object"
? JSON.stringify(result) ? JSON.stringify(result)
: String(result), : String(result);
get().onUserInput(
`\`\`\`json:mcp:${mcpRequest.clientId}\n${mcpResponse}\n\`\`\``,
[],
true,
); );
}) })
.catch((error) => showToast(String(error))); .catch((error) => showToast(String(error)));