diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts
index 5739415ba..fa52363ae 100644
--- a/app/api/bedrock.ts
+++ b/app/api/bedrock.ts
@@ -15,7 +15,8 @@ function parseEventData(chunk: Uint8Array): any {
   // AWS Bedrock wraps the response in a 'body' field
   if (typeof parsed.body === "string") {
     try {
-      return JSON.parse(parsed.body);
+      const bodyJson = JSON.parse(parsed.body);
+      return bodyJson;
     } catch (e) {
       return { output: parsed.body };
     }
@@ -89,10 +90,12 @@ async function* transformBedrockStream(
       })}\n\n`;
     }
   }
-  // Handle LLaMA3 models
-  else if (modelId.startsWith("us.meta.llama3")) {
+  // Handle LLaMA models
+  else if (modelId.startsWith("us.meta.llama")) {
     let text = "";
-    if (parsed.generation) {
+    if (parsed.outputs?.[0]?.text) {
+      text = parsed.outputs[0].text;
+    } else if (parsed.generation) {
       text = parsed.generation;
     } else if (parsed.output) {
       text = parsed.output;
@@ -101,8 +104,6 @@ async function* transformBedrockStream(
     }

     if (text) {
-      // Clean up any control characters or invalid JSON characters
-      text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
       yield `data: ${JSON.stringify({
         delta: { text },
       })}\n\n`;
@@ -162,6 +163,7 @@ async function* transformBedrockStream(
 function validateRequest(body: any, modelId: string): void {
   if (!modelId) throw new Error("Model ID is required");

+  // Handle nested body structure
   const bodyContent = body.body || body;

   if (modelId.startsWith("anthropic.claude")) {
@@ -180,22 +182,20 @@ function validateRequest(body: any, modelId: string): void {
     } else if (typeof body.prompt !== "string") {
       throw new Error("prompt is required for Claude 2 and earlier");
     }
-  } else if (modelId.startsWith("us.meta.llama3")) {
-    if (!bodyContent.prompt) {
-      throw new Error("prompt is required for LLaMA3 models");
+  } else if (modelId.startsWith("us.meta.llama")) {
+    if (!bodyContent.prompt || typeof bodyContent.prompt !== "string") {
+      throw new Error("prompt string is required for LLaMA models");
+    }
+    if (
+      !bodyContent.max_gen_len ||
+      typeof bodyContent.max_gen_len !== "number"
+    ) {
+      throw new Error("max_gen_len must be a positive number for LLaMA models");
     }
   } else if (modelId.startsWith("mistral.mistral")) {
     if (!bodyContent.prompt) {
       throw new Error("prompt is required for Mistral models");
     }
-    if (
-      !bodyContent.prompt.startsWith("[INST]") ||
-      !bodyContent.prompt.includes("[/INST]")
-    ) {
-      throw new Error(
-        "Mistral prompt must be wrapped in [INST] and [/INST] tags",
-      );
-    }
   } else if (modelId.startsWith("amazon.titan")) {
     if (!bodyContent.inputText) throw new Error("Titan requires inputText");
   }
@@ -250,7 +250,6 @@ async function requestBedrock(req: NextRequest) {
   try {
     // Determine the endpoint and request body based on model type
     let endpoint;
-    let requestBody;
     const bodyText = await req.clone().text();

     if (!bodyText) {
@@ -258,6 +257,10 @@ async function requestBedrock(req: NextRequest) {
     }

     const bodyJson = JSON.parse(bodyText);
+
+    // Debug log the request body
+    console.log("Original request body:", JSON.stringify(bodyJson, null, 2));
+
     validateRequest(bodyJson, modelId);

     // For all models, use standard endpoints
@@ -267,26 +270,44 @@ async function requestBedrock(req: NextRequest) {
       endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
     }

-    // Set content type and accept headers for Mistral models
+    // Set additional headers based on model type
+    const additionalHeaders: Record<string, string> = {};
+    if (
+      modelId.startsWith("us.meta.llama") ||
+      modelId.startsWith("mistral.mistral")
+    ) {
+      additionalHeaders["content-type"] = "application/json";
+      additionalHeaders["accept"] = "application/json";
+    }
+
+    // For Mistral models, unwrap the body object
+    const finalRequestBody =
+      modelId.startsWith("mistral.mistral") && bodyJson.body
+        ? bodyJson.body
+        : bodyJson;
+
+    // Set content type and accept headers for specific models
     const headers = await sign({
       method: "POST",
       url: endpoint,
       region: awsRegion,
       accessKeyId: awsAccessKey,
       secretAccessKey: awsSecretKey,
-      body: JSON.stringify(bodyJson.body || bodyJson),
+      body: JSON.stringify(finalRequestBody),
       service: "bedrock",
       isStreaming: shouldStream !== "false",
-      ...(modelId.startsWith("mistral.mistral") && {
-        contentType: "application/json",
-        accept: "application/json",
-      }),
+      additionalHeaders,
     });

+    // Debug log the final request body
+    // console.log("Final request endpoint:", endpoint);
+    // console.log(headers);
+    // console.log("Final request body:", JSON.stringify(finalRequestBody, null, 2));
+
     const res = await fetch(endpoint, {
       method: "POST",
       headers,
-      body: JSON.stringify(bodyJson.body || bodyJson),
+      body: JSON.stringify(finalRequestBody),
       redirect: "manual",
       // @ts-ignore
       duplex: "half",
diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts
index aff4c5582..5c661c86f 100644
--- a/app/client/platforms/bedrock.ts
+++ b/app/client/platforms/bedrock.ts
@@ -85,14 +85,17 @@ export class BedrockApi implements LLMApi {
     }

     // Handle LLaMA models
-    if (modelId.startsWith("us.meta.llama3")) {
+    if (modelId.startsWith("us.meta.llama")) {
       if (res?.delta?.text) {
         return res.delta.text;
       }
       if (res?.generation) {
         return res.generation;
       }
-      if (typeof res?.output === "string") {
+      if (res?.outputs?.[0]?.text) {
+        return res.outputs[0].text;
+      }
+      if (res?.output) {
         return res.output;
       }
       if (typeof res === "string") {
@@ -103,11 +106,28 @@ export class BedrockApi implements LLMApi {

     // Handle Mistral models
     if (modelId.startsWith("mistral.mistral")) {
-      if (res?.delta?.text) return res.delta.text;
-      return res?.outputs?.[0]?.text || res?.output || res?.completion || "";
+      if (res?.delta?.text) {
+        return res.delta.text;
+      }
+      if (res?.outputs?.[0]?.text) {
+        return res.outputs[0].text;
+      }
+      if (res?.content?.[0]?.text) {
+        return res.content[0].text;
+      }
+      if (res?.output) {
+        return res.output;
+      }
+      if (res?.completion) {
+        return res.completion;
+      }
+      if (typeof res === "string") {
+        return res;
+      }
+      return "";
     }

-    // Handle Claude models and fallback cases
+    // Handle Claude models
     if (res?.content?.[0]?.text) return res.content[0].text;
     if (res?.messages?.[0]?.content?.[0]?.text)
       return res.messages[0].content[0].text;
@@ -142,14 +162,11 @@ export class BedrockApi implements LLMApi {
           ]
         : messages;

-      // Format messages without role prefixes for Titan
       const inputText = allMessages
         .map((m) => {
-          // Include system message as a prefix instruction
           if (m.role === "system") {
             return getMessageTextContent(m);
           }
-          // For user/assistant messages, just include the content
           return getMessageTextContent(m);
         })
         .join("\n\n");
@@ -166,25 +183,39 @@ export class BedrockApi implements LLMApi {
       };
     }

-    // Handle LLaMA3 models
-    if (model.startsWith("us.meta.llama3")) {
-      // Only include the last user message for LLaMA
-      const lastMessage = messages[messages.length - 1];
-      const prompt = getMessageTextContent(lastMessage);
+    // Handle LLaMA models
+    if (model.startsWith("us.meta.llama")) {
+      const allMessages = systemMessage
+        ? [
+            { role: "system" as MessageRole, content: systemMessage },
+            ...messages,
+          ]
+        : messages;
+
+      const prompt = allMessages
+        .map((m) => {
+          const content = getMessageTextContent(m);
+          if (m.role === "system") {
+            return `System: ${content}`;
+          } else if (m.role === "user") {
+            return `User: ${content}`;
+          } else if (m.role === "assistant") {
+            return `Assistant: ${content}`;
+          }
+          return content;
+        })
+        .join("\n\n");

       return {
-        contentType: "application/json",
-        accept: "application/json",
-        body: {
-          prompt: prompt,
-          max_gen_len: modelConfig.max_tokens || 256,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
-        },
+        prompt,
+        max_gen_len: modelConfig.max_tokens || 512,
+        temperature: modelConfig.temperature || 0.6,
+        top_p: modelConfig.top_p || 0.9,
+        stop: ["User:", "System:", "Assistant:", "\n\n"],
       };
     }

-    // Handle Mistral models with correct instruction format
+    // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
       const allMessages = systemMessage
         ? [
@@ -193,25 +224,29 @@ export class BedrockApi implements LLMApi {
           ]
         : messages;

-      // Format messages as a conversation with instruction tags
-      const prompt = `[INST] ${allMessages
-        .map((m) => getMessageTextContent(m))
-        .join("\n")} [/INST]`;
+      const formattedConversation = allMessages
+        .map((m) => {
+          const content = getMessageTextContent(m);
+          if (m.role === "system") {
+            return content;
+          } else if (m.role === "user") {
+            return content;
+          } else if (m.role === "assistant") {
+            return content;
+          }
+          return content;
+        })
+        .join("\n");
+
+      // Format according to Mistral's requirements
       return {
-        contentType: "application/json",
-        accept: "application/json",
-        body: {
-          prompt,
-          max_tokens: modelConfig.max_tokens || 4096,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
-          top_k: 50,
-        },
+        prompt: formattedConversation,
+        max_tokens: modelConfig.max_tokens || 4096,
+        temperature: modelConfig.temperature || 0.7,
       };
     }

-    // Handle Claude models (existing implementation)
+    // Handle Claude models
     const isClaude3 = model.startsWith("anthropic.claude-3");
     const formattedMessages = messages
       .filter(
@@ -253,12 +288,14 @@ export class BedrockApi implements LLMApi {
     });

     return {
-      anthropic_version: "bedrock-2023-05-31",
-      max_tokens: modelConfig.max_tokens,
-      messages: formattedMessages,
-      ...(systemMessage && { system: systemMessage }),
-      temperature: modelConfig.temperature,
-      ...(isClaude3 && { top_k: 5 }),
+      body: {
+        anthropic_version: "bedrock-2023-05-31",
+        max_tokens: modelConfig.max_tokens,
+        messages: formattedMessages,
+        ...(systemMessage && { system: systemMessage }),
+        temperature: modelConfig.temperature,
+        ...(isClaude3 && { top_k: modelConfig.top_k || 50 }),
+      },
     };
   }

@@ -301,6 +338,13 @@ export class BedrockApi implements LLMApi {
     const headers = getHeaders();
     headers.ModelID = modelConfig.model;

+    // For LLaMA and Mistral models, send the request body directly without the 'body' wrapper
+    const finalRequestBody =
+      modelConfig.model.startsWith("us.meta.llama") ||
+      modelConfig.model.startsWith("mistral.mistral")
+        ? requestBody
+        : requestBody.body;
+
     if (options.config.stream) {
       let index = -1;
       let currentToolArgs = "";
@@ -312,7 +356,7 @@ export class BedrockApi implements LLMApi {

       return stream(
         chatPath,
-        requestBody,
+        finalRequestBody,
         headers,
         (tools as ToolDefinition[]).map((tool) => ({
           name: tool?.function?.name,
@@ -420,7 +464,7 @@ export class BedrockApi implements LLMApi {
     const res = await fetch(chatPath, {
       method: "POST",
       headers,
-      body: JSON.stringify(requestBody),
+      body: JSON.stringify(finalRequestBody),
     });

     const resJson = await res.json();
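
For reference, a minimal sketch (not part of the patch) of the two wire formats this change settles on: the LLaMA and Mistral builders now return flat bodies that are sent as-is, while the Claude builder keeps its payload wrapped in `body`, which the client unwraps (`requestBody.body`) before sending. The `Msg` type and sample messages below are illustrative stand-ins for the app's own message types.

// Sketch under the assumptions above; runnable with ts-node or deno.
type Role = "system" | "user" | "assistant";
interface Msg {
  role: Role;
  content: string;
}

const messages: Msg[] = [
  { role: "system", content: "You are a concise assistant." },
  { role: "user", content: "Name one AWS region." },
];

// LLaMA path: role-prefixed transcript joined by blank lines, flat body.
const llamaBody = {
  prompt: messages
    .map((m) =>
      m.role === "system"
        ? `System: ${m.content}`
        : m.role === "user"
          ? `User: ${m.content}`
          : `Assistant: ${m.content}`,
    )
    .join("\n\n"),
  max_gen_len: 512,
  temperature: 0.6,
  top_p: 0.9,
  stop: ["User:", "System:", "Assistant:", "\n\n"],
};

// Claude path: the builder output keeps the `body` wrapper; the client then
// sends requestBody.body, so the inner object is what goes over the wire.
const claudeRequest = {
  body: {
    anthropic_version: "bedrock-2023-05-31",
    max_tokens: 2048,
    system: messages[0].content,
    messages: messages
      .filter((m) => m.role !== "system")
      .map((m) => ({
        role: m.role,
        content: [{ type: "text", text: m.content }],
      })),
    temperature: 0.7,
  },
};

console.log(JSON.stringify(llamaBody, null, 2));
console.log(JSON.stringify(claudeRequest, null, 2));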