diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index bf1313eb0..7da14a17b 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -80,54 +80,16 @@ async function requestBedrock(req: NextRequest) { } catch (e) { throw new Error(`Invalid JSON in request body: ${e}`); } - // console.log( - // "[Bedrock Request] original Body:", - // JSON.stringify(bodyJson, null, 2), - // ); - - // Extract tool configuration if present - let tools: any[] | undefined; - if (bodyJson.tools) { - tools = bodyJson.tools; - delete bodyJson.tools; // Remove from main request body - } - + console.log("[Bedrock Request] Initiating request"); // Get endpoint and prepare request const endpoint = getBedrockEndpoint( credentials.region, modelId, shouldStream, ); - - console.log("[Bedrock Request] Initiating request"); - - // Handle tools for different models - const isMistralLargeModel = modelId - .toLowerCase() - .includes("mistral.mistral-large"); - const isClaudeModel = modelId.toLowerCase().includes("claude"); - const requestBody: any = { ...bodyJson, }; - - if (tools && tools.length > 0) { - if (isMistralLargeModel) { - // Add tools in Mistral's format - requestBody.tool_choice = "auto"; - requestBody.tools = tools.map((tool) => ({ - type: "function", - function: { - name: tool.name, - description: tool.description, - parameters: tool.input_schema, - }, - })); - } else if (isClaudeModel) { - requestBody.tools = tools; - } - } - // Sign request const headers = await sign({ method: "POST", diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index b2bd2c2bb..cd08cb3bb 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -9,7 +9,7 @@ import { } from "@/app/store"; import { preProcessImageContent } from "@/app/utils/chat"; import { getMessageTextContent, isVisionModel } from "@/app/utils"; -import { ApiPath, BEDROCK_BASE_URL } from "@/app/constant"; +import { ApiPath, BEDROCK_BASE_URL, REQUEST_TIMEOUT_MS } from "@/app/constant"; import { getClientConfig } from "@/app/config/client"; import { extractMessage, @@ -18,8 +18,6 @@ import { parseEventData, sign, } from "@/app/utils/aws"; -import { RequestPayload } from "./openai"; -import { REQUEST_TIMEOUT_MS } from "@/app/constant"; import { prettyObject } from "@/app/utils/format"; import Locale from "@/app/locales"; @@ -35,6 +33,15 @@ const MistralMapper = { assistant: "assistant", } as const; type MistralRole = keyof typeof MistralMapper; + +interface Tool { + function?: { + name?: string; + description?: string; + parameters?: any; + }; +} + export class BedrockApi implements LLMApi { speech(options: SpeechOptions): Promise { throw new Error("Speech not implemented for Bedrock."); @@ -44,8 +51,15 @@ export class BedrockApi implements LLMApi { const model = modelConfig.model; const visionModel = isVisionModel(modelConfig.model); + // Get tools if available + const [tools] = usePluginStore + .getState() + .getAsTools(useChatStore.getState().currentSession().mask?.plugin || []); + + const toolsArray = (tools as Tool[]) || []; + // Handle Nova models - if (model.startsWith("us.amazon.nova")) { + if (model.includes("amazon.nova")) { // Extract system message if present const systemMessage = messages.find((m) => m.role === "system"); const conversationMessages = messages.filter((m) => m.role !== "system"); @@ -107,6 +121,26 @@ export class BedrockApi implements LLMApi { ]; } + // Add tools if available - now in correct format + if (toolsArray.length > 0) { + requestBody.toolConfig = { + tools: toolsArray.map((tool) => ({ + toolSpec: { + name: tool?.function?.name || "", + description: tool?.function?.description || "", + inputSchema: { + json: { + type: "object", + properties: tool?.function?.parameters || {}, + required: Object.keys(tool?.function?.parameters || {}), + }, + }, + }, + })), + // toolChoice: { auto: {} } + }; + } + return requestBody; } @@ -160,18 +194,33 @@ export class BedrockApi implements LLMApi { } // Handle Mistral models - if (model.startsWith("mistral.mistral")) { + if (model.includes("mistral.mistral")) { const formattedMessages = messages.map((message) => ({ role: MistralMapper[message.role as MistralRole] || "user", content: getMessageTextContent(message), })); - return { + const requestBody: any = { messages: formattedMessages, max_tokens: modelConfig.max_tokens || 4096, temperature: modelConfig.temperature || 0.7, top_p: modelConfig.top_p || 0.9, }; + + // Add tools if available + if (toolsArray.length > 0) { + requestBody.tool_choice = "auto"; + requestBody.tools = toolsArray.map((tool) => ({ + type: "function", + function: { + name: tool?.function?.name, + description: tool?.function?.description, + parameters: tool?.function?.parameters, + }, + })); + } + + return requestBody; } // Handle Claude models @@ -254,6 +303,16 @@ export class BedrockApi implements LLMApi { top_p: modelConfig.top_p || 0.9, top_k: modelConfig.top_k || 5, }; + + // Add tools if available for Claude models + if (toolsArray.length > 0 && model.includes("anthropic.claude")) { + requestBody.tools = toolsArray.map((tool) => ({ + name: tool?.function?.name || "", + description: tool?.function?.description || "", + input_schema: tool?.function?.parameters || {}, + })); + } + return requestBody; } @@ -333,23 +392,18 @@ export class BedrockApi implements LLMApi { chatPath, finalRequestBody, headers, - // @ts-ignore - tools.map((tool) => ({ - name: tool?.function?.name, - description: tool?.function?.description, - input_schema: tool?.function?.parameters, - })), funcs, controller, // processToolMessage, include tool_calls message and tool call results ( - requestPayload: RequestPayload, + requestPayload: any[], toolCallMessage: any, toolCallResult: any[], ) => { const modelId = modelConfig.model; - const isMistral = modelId.startsWith("mistral.mistral"); + const isMistral = modelId.includes("mistral.mistral"); const isClaude = modelId.includes("anthropic.claude"); + const isNova = modelId.includes("amazon.nova"); if (isClaude) { // Format for Claude @@ -385,7 +439,9 @@ export class BedrockApi implements LLMApi { ); } else if (isMistral) { // Format for Mistral + // @ts-ignore requestPayload?.messages?.splice( + // @ts-ignore requestPayload?.messages?.length, 0, { @@ -408,6 +464,47 @@ export class BedrockApi implements LLMApi { content: result.content, })), ); + } else if (isNova) { + // Format for Nova + // @ts-ignore + requestPayload?.messages?.splice( + // @ts-ignore + requestPayload?.messages?.length, + 0, + { + role: "assistant", + content: [ + { + text: "", // Add empty text content to satisfy type requirements + tool_calls: toolCallMessage.tool_calls.map( + (tool: ChatMessageTool) => ({ + id: tool.id, + name: tool?.function?.name, + arguments: tool?.function?.arguments + ? JSON.parse(tool?.function?.arguments) + : {}, + }), + ), + }, + ], + }, + ...toolCallResult.map((result) => ({ + role: "user", + content: [ + { + toolUseId: result.tool_call_id, + content: [ + { + json: + typeof result.content === "string" + ? JSON.parse(result.content) + : result.content, + }, + ], + }, + ], + })), + ); } else { console.warn( `[Bedrock Client] Unhandled model type for tool calls: ${modelId}`, @@ -457,7 +554,6 @@ function bedrockStream( chatPath: string, requestPayload: any, headers: any, - tools: any[], funcs: Record, controller: AbortController, processToolMessage: ( @@ -512,8 +608,13 @@ function bedrockStream( return Promise.all( toolCallMessage.tool_calls.map((tool) => { options?.onBeforeTool?.(tool); + const funcName = tool?.function?.name || tool?.name; + if (!funcName || !funcs[funcName]) { + console.error(`Function ${funcName} not found in funcs:`, funcs); + return Promise.reject(`Function ${funcName} not found`); + } return Promise.resolve( - funcs[tool.function.name]( + funcs[funcName]( tool?.function?.arguments ? JSON.parse(tool?.function?.arguments) : {}, @@ -547,7 +648,7 @@ function bedrockStream( return e.toString(); }) .then((content) => ({ - name: tool.function.name, + name: funcName, role: "tool", content, tool_call_id: tool.id, @@ -558,7 +659,7 @@ function bedrockStream( setTimeout(() => { console.debug("[BedrockAPI for toolCallResult] restart"); running = false; - bedrockChatApi(chatPath, headers, requestPayload, tools); + bedrockChatApi(chatPath, headers, requestPayload); }, 60); }); } @@ -577,7 +678,6 @@ function bedrockStream( chatPath: string, headers: any, requestPayload: any, - tools: any, ) { const requestTimeoutId = setTimeout( () => controller.abort(), @@ -588,10 +688,7 @@ function bedrockStream( const res = await fetch(chatPath, { method: "POST", headers, - body: JSON.stringify({ - ...requestPayload, - tools: tools && tools.length ? tools : undefined, - }), + body: JSON.stringify(requestPayload), redirect: "manual", // @ts-ignore duplex: "half", @@ -699,5 +796,5 @@ function bedrockStream( } console.debug("[BedrockAPI] start"); - bedrockChatApi(chatPath, headers, requestPayload, tools); + bedrockChatApi(chatPath, headers, requestPayload); } diff --git a/app/utils/aws.ts b/app/utils/aws.ts index 6e5943885..161766435 100644 --- a/app/utils/aws.ts +++ b/app/utils/aws.ts @@ -327,6 +327,31 @@ export function processMessage( if (!data) return { remainText, index }; try { + // Handle Nova's tool calls + // console.log("processMessage data=========================",data); + if ( + data.stopReason === "tool_use" && + data.output?.message?.content?.[0]?.toolUse + ) { + const toolUse = data.output.message.content[0].toolUse; + index += 1; + runTools.push({ + id: `tool-${Date.now()}`, + type: "function", + function: { + name: toolUse.name, + arguments: JSON.stringify(toolUse.input), + }, + }); + return { remainText, index }; + } + + // Handle Nova's text content + if (data.output?.message?.content?.[0]?.text) { + remainText += data.output.message.content[0].text; + return { remainText, index }; + } + // Handle Nova's messageStart event if (data.messageStart) { return { remainText, index }; @@ -382,7 +407,7 @@ export function processMessage( return { remainText, index }; } - // Handle tool calls + // Handle tool calls for other models if (data.choices?.[0]?.message?.tool_calls) { for (const toolCall of data.choices[0].message.tool_calls) { index += 1;