diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts
index f2e008bf5..79063e03c 100644
--- a/app/api/bedrock.ts
+++ b/app/api/bedrock.ts
@@ -105,26 +105,17 @@ async function requestBedrock(req: NextRequest) {
   console.log("[Bedrock Request] Model ID:", modelId);
 
   // Handle tools for different models
-  const isMistralModel = modelId.toLowerCase().includes("mistral");
+  const isMistralLargeModel = modelId
+    .toLowerCase()
+    .includes("mistral.mistral-large");
   const isClaudeModel = modelId.toLowerCase().includes("claude");
 
-  const requestBody = {
+  const requestBody: any = {
     ...bodyJson,
   };
 
   if (tools && tools.length > 0) {
-    if (isClaudeModel) {
-      // Claude models already have correct tool format
-      requestBody.tools = tools;
-    } else if (isMistralModel) {
-      // Format messages for Mistral
-      if (typeof requestBody.prompt === "string") {
-        requestBody.messages = [
-          { role: "user", content: requestBody.prompt },
-        ];
-        delete requestBody.prompt;
-      }
-
+    if (isMistralLargeModel) {
       // Add tools in Mistral's format
       requestBody.tool_choice = "auto";
       requestBody.tools = tools.map((tool) => ({
@@ -135,6 +126,8 @@ async function requestBedrock(req: NextRequest) {
           parameters: tool.input_schema,
         },
       }));
+    } else if (isClaudeModel) {
+      requestBody.tools = tools;
     }
   }
diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts
index 1799e4dbc..856c67b47 100644
--- a/app/client/platforms/bedrock.ts
+++ b/app/client/platforms/bedrock.ts
@@ -21,7 +21,22 @@ const ClaudeMapper = {
   system: "user",
 } as const;
 
+const MistralMapper = {
+  system: "system",
+  user: "user",
+  assistant: "assistant",
+} as const;
+
 type ClaudeRole = keyof typeof ClaudeMapper;
+type MistralRole = keyof typeof MistralMapper;
+
+interface Tool {
+  function?: {
+    name?: string;
+    description?: string;
+    parameters?: any;
+  };
+}
 
 export class BedrockApi implements LLMApi {
   speech(options: SpeechOptions): Promise<ArrayBuffer> {
@@ -30,7 +45,6 @@ export class BedrockApi implements LLMApi {
   }
 
   formatRequestBody(messages: ChatOptions["messages"], modelConfig: any) {
     const model = modelConfig.model;
-    const visionModel = isVisionModel(modelConfig.model);
 
     // Handle Titan models
@@ -53,37 +67,27 @@ export class BedrockApi implements LLMApi {
 
     // Handle LLaMA models
     if (model.includes("meta.llama")) {
-      // Format conversation for Llama models
-      let prompt = "";
-      let systemPrompt = "";
+      let prompt = "<|begin_of_text|>";
 
      // Extract system message if present
       const systemMessage = messages.find((m) => m.role === "system");
       if (systemMessage) {
-        systemPrompt = getMessageTextContent(systemMessage);
+        prompt += `<|start_header_id|>system<|end_header_id|>\n${getMessageTextContent(
+          systemMessage,
+        )}<|eot_id|>`;
       }
 
       // Format the conversation
       const conversationMessages = messages.filter((m) => m.role !== "system");
-      prompt = `[INST] <<SYS>>\n${
-        systemPrompt || "You are a helpful, respectful and honest assistant."
-      }\n<</SYS>>\n\n`;
-
-      for (let i = 0; i < conversationMessages.length; i++) {
-        const message = conversationMessages[i];
+      for (const message of conversationMessages) {
+        const role = message.role === "assistant" ? "assistant" : "user";
         const content = getMessageTextContent(message);
-        if (i === 0 && message.role === "user") {
-          // First user message goes in the same [INST] block as system prompt
-          prompt += `${content} [/INST]`;
-        } else {
-          if (message.role === "user") {
-            prompt += `\n\n[INST] ${content} [/INST]`;
-          } else {
-            prompt += ` ${content} `;
-          }
-        }
+        prompt += `<|start_header_id|>${role}<|end_header_id|>\n${content}<|eot_id|>`;
       }
 
+      // Add the final assistant header to prompt completion
+      prompt += "<|start_header_id|>assistant<|end_header_id|>";
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
@@ -94,9 +98,8 @@ export class BedrockApi implements LLMApi {
 
     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
-      // Format messages for Mistral's chat format
       const formattedMessages = messages.map((message) => ({
-        role: message.role,
+        role: MistralMapper[message.role as MistralRole] || "user",
         content: getMessageTextContent(message),
       }));
 
@@ -234,6 +237,11 @@ export class BedrockApi implements LLMApi {
     });
 
     const finalRequestBody = this.formatRequestBody(messages, modelConfig);
+    console.log(
+      "[Bedrock Client] Request Body:",
+      JSON.stringify(finalRequestBody, null, 2),
+    );
+
     if (shouldStream) {
       let index = -1;
       const [tools, funcs] = usePluginStore
@@ -253,6 +261,7 @@ export class BedrockApi implements LLMApi {
          })),
          funcs,
          controller,
+          // parseSSE
          (text: string, runTools: ChatMessageTool[]) => {
            // console.log("parseSSE", text, runTools);
            let chunkJson:
@@ -304,36 +313,73 @@ export class BedrockApi implements LLMApi {
          ) => {
            // reset index value
            index = -1;
-            // @ts-ignore
-            requestPayload?.messages?.splice(
+
+            const modelId = modelConfig.model;
+            const isMistral = modelId.startsWith("mistral.mistral");
+            const isClaude = modelId.includes("anthropic.claude");
+
+            if (isClaude) {
+              // Format for Claude
              // @ts-ignore
-              requestPayload?.messages?.length,
-              0,
-              {
-                role: "assistant",
-                content: toolCallMessage.tool_calls.map(
-                  (tool: ChatMessageTool) => ({
-                    type: "tool_use",
-                    id: tool.id,
-                    name: tool?.function?.name,
-                    input: tool?.function?.arguments
-                      ? JSON.parse(tool?.function?.arguments)
-                      : {},
-                  }),
-                ),
-              },
-              // @ts-ignore
-              ...toolCallResult.map((result) => ({
-                role: "user",
-                content: [
-                  {
-                    type: "tool_result",
-                    tool_use_id: result.tool_call_id,
-                    content: result.content,
-                  },
-                ],
-              })),
-            );
+              requestPayload?.messages?.splice(
+                // @ts-ignore
+                requestPayload?.messages?.length,
+                0,
+                {
+                  role: "assistant",
+                  content: toolCallMessage.tool_calls.map(
+                    (tool: ChatMessageTool) => ({
+                      type: "tool_use",
+                      id: tool.id,
+                      name: tool?.function?.name,
+                      input: tool?.function?.arguments
+                        ? JSON.parse(tool?.function?.arguments)
+                        : {},
+                    }),
+                  ),
+                },
+                // @ts-ignore
+                ...toolCallResult.map((result) => ({
+                  role: "user",
+                  content: [
+                    {
+                      type: "tool_result",
+                      tool_use_id: result.tool_call_id,
+                      content: result.content,
+                    },
+                  ],
+                })),
+              );
+            } else if (isMistral) {
+              // Format for Mistral
+              requestPayload?.messages?.splice(
+                requestPayload?.messages?.length,
+                0,
+                {
+                  role: "assistant",
+                  content: "",
+                  // @ts-ignore
+                  tool_calls: toolCallMessage.tool_calls.map(
+                    (tool: ChatMessageTool) => ({
+                      id: tool.id,
+                      function: {
+                        name: tool?.function?.name,
+                        arguments: tool?.function?.arguments || "{}",
+                      },
+                    }),
+                  ),
+                },
+                ...toolCallResult.map((result) => ({
+                  role: "tool",
+                  tool_call_id: result.tool_call_id,
+                  content: result.content,
+                })),
+              );
+            } else {
+              console.warn(
+                `[Bedrock Client] Unhandled model type for tool calls: ${modelId}`,
+              );
+            }
          },
          options,
        );
@@ -368,6 +414,7 @@ export class BedrockApi implements LLMApi {
       options.onError?.(e as Error);
     }
   }
+
   path(path: string): string {
     const accessStore = useAccessStore.getState();
     let baseUrl = accessStore.useCustomConfig ? accessStore.bedrockUrl : "";
diff --git a/app/constant.ts b/app/constant.ts
index 75ed0a403..8efee805f 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -342,12 +342,9 @@ const bedrockModels = [
   // Meta Llama Models
   "us.meta.llama3-1-8b-instruct-v1:0",
   "us.meta.llama3-1-70b-instruct-v1:0",
-  "us.meta.llama3-2-1b-instruct-v1:0",
-  "us.meta.llama3-2-3b-instruct-v1:0",
   "us.meta.llama3-2-11b-instruct-v1:0",
   "us.meta.llama3-2-90b-instruct-v1:0",
   // Mistral Models
-  "mistral.mistral-7b-instruct-v0:2",
   "mistral.mistral-large-2402-v1:0",
   "mistral.mistral-large-2407-v1:0",
 ];
diff --git a/app/utils/aws.ts b/app/utils/aws.ts
index 1dd23d067..cb23f60e2 100644
--- a/app/utils/aws.ts
+++ b/app/utils/aws.ts
@@ -245,6 +245,7 @@ export async function sign({
 export function parseEventData(chunk: Uint8Array): any {
   const decoder = new TextDecoder();
   const text = decoder.decode(chunk);
+  // console.info("[AWS Parse ] parsing:", text);
   try {
     const parsed = JSON.parse(text);
     // AWS Bedrock wraps the response in a 'body' field
@@ -317,7 +318,10 @@ export function extractMessage(res: any, modelId: string = ""): string {
 
   // Handle Mistral model response format
   if (modelId.toLowerCase().includes("mistral")) {
-    return res?.outputs?.[0]?.text || "";
+    if (res.choices?.[0]?.message?.content) {
+      return res.choices[0].message.content;
+    }
+    return res.output || "";
   }
 
   // Handle Llama model response format
@@ -334,9 +338,7 @@ export async function* transformBedrockStream(
   modelId: string,
 ) {
   const reader = stream.getReader();
-  let accumulatedText = "";
-  let toolCallStarted = false;
-  let currentToolCall = null;
+  let toolInput = "";
 
   try {
     while (true) {
@@ -349,90 +351,54 @@
          // console.log("parseEventData=========================");
          // console.log(parsed);
 
+          // Handle Mistral models
          if (modelId.toLowerCase().includes("mistral")) {
-            // If we have content, accumulate it
-            if (
-              parsed.choices?.[0]?.message?.role === "assistant" &&
-              parsed.choices?.[0]?.message?.content
-            ) {
-              accumulatedText += parsed.choices?.[0]?.message?.content;
-              // console.log("accumulatedText=========================");
-              // console.log(accumulatedText);
-              // Check for tool call in the accumulated text
-              if (!toolCallStarted && accumulatedText.includes("```json")) {
-                const jsonMatch = accumulatedText.match(
-                  /```json\s*({[\s\S]*?})\s*```/,
-                );
-                if (jsonMatch) {
-                  try {
-                    const toolData = JSON.parse(jsonMatch[1]);
-                    currentToolCall = {
-                      id: `tool-${Date.now()}`,
-                      name: toolData.name,
-                      arguments: toolData.arguments,
-                    };
+            // Handle tool calls
+            if (parsed.choices?.[0]?.message?.tool_calls) {
+              const toolCalls = parsed.choices[0].message.tool_calls;
+              for (const toolCall of toolCalls) {
+                // Emit tool call start
+                yield `data: ${JSON.stringify({
+                  type: "content_block_start",
+                  content_block: {
+                    type: "tool_use",
+                    id: toolCall.id || `tool-${Date.now()}`,
+                    name: toolCall.function?.name,
+                  },
+                })}\n\n`;
 
-                    // Emit tool call start
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_start",
-                      content_block: {
-                        type: "tool_use",
-                        id: currentToolCall.id,
-                        name: currentToolCall.name,
-                      },
-                    })}\n\n`;
-
-                    // Emit tool arguments
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_delta",
-                      delta: {
-                        type: "input_json_delta",
-                        partial_json: JSON.stringify(currentToolCall.arguments),
-                      },
-                    })}\n\n`;
-
-                    // Emit tool call stop
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_stop",
-                    })}\n\n`;
-
-                    // Clear the accumulated text after processing the tool call
-                    accumulatedText = accumulatedText.replace(
-                      /```json\s*{[\s\S]*?}\s*```/,
-                      "",
-                    );
-                    toolCallStarted = false;
-                    currentToolCall = null;
-                  } catch (e) {
-                    console.error("Failed to parse tool JSON:", e);
-                  }
+                // Emit tool arguments
+                if (toolCall.function?.arguments) {
+                  yield `data: ${JSON.stringify({
+                    type: "content_block_delta",
+                    delta: {
+                      type: "input_json_delta",
+                      partial_json: toolCall.function.arguments,
+                    },
+                  })}\n\n`;
                }
-              }
-              // emit the text content if it's not empty
-              if (parsed.choices?.[0]?.message?.content.trim()) {
+
+                // Emit tool call stop
                yield `data: ${JSON.stringify({
-                  delta: { text: parsed.choices?.[0]?.message?.content },
-                })}\n\n`;
-              }
-              // Handle stop reason if present
-              if (parsed.choices?.[0]?.stop_reason) {
-                yield `data: ${JSON.stringify({
-                  delta: { stop_reason: parsed.choices[0].stop_reason },
+                  type: "content_block_stop",
                })}\n\n`;
              }
+              continue;
            }
-          }
-          // Handle Llama models
-          else if (modelId.toLowerCase().includes("llama")) {
-            if (parsed.generation) {
+
+            // Handle regular content
+            const content = parsed.choices?.[0]?.message?.content;
+            if (content?.trim()) {
              yield `data: ${JSON.stringify({
-                delta: { text: parsed.generation },
+                delta: { text: content },
              })}\n\n`;
            }
-            if (parsed.stop_reason) {
+
+            // Handle stop reason
+            if (parsed.choices?.[0]?.finish_reason) {
              yield `data: ${JSON.stringify({
-                delta: { stop_reason: parsed.stop_reason },
+                delta: { stop_reason: parsed.choices[0].finish_reason },
              })}\n\n`;
            }
          }
@@ -469,8 +435,9 @@
                })}\n\n`;
              }
            }
-          } else {
-            // Handle other model text responses
+          }
+          // Handle other models
+          else {
            const text = parsed.outputText || parsed.generation || "";
            if (text) {
              yield `data: ${JSON.stringify({