From 238eb70986ad744a31bd5388a9cdfac355856c8e Mon Sep 17 00:00:00 2001
From: glay
Date: Sat, 23 Nov 2024 16:27:19 +0800
Subject: [PATCH] Improve the inference results of Mistral models
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/api/bedrock.ts              | 41 ++++++++++++-----
 app/client/platforms/bedrock.ts | 79 ++++++++++++++++++++++-----------
 2 files changed, 81 insertions(+), 39 deletions(-)

diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts
index e342a8867..5739415ba 100644
--- a/app/api/bedrock.ts
+++ b/app/api/bedrock.ts
@@ -110,8 +110,17 @@ async function* transformBedrockStream(
     }
     // Handle Mistral models
     else if (modelId.startsWith("mistral.mistral")) {
-      const text =
-        parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
+      let text = "";
+      if (parsed.outputs?.[0]?.text) {
+        text = parsed.outputs[0].text;
+      } else if (parsed.output) {
+        text = parsed.output;
+      } else if (parsed.completion) {
+        text = parsed.completion;
+      } else if (typeof parsed === "string") {
+        text = parsed;
+      }
+
      if (text) {
        yield `data: ${JSON.stringify({
          delta: { text },
@@ -176,7 +185,17 @@ function validateRequest(body: any, modelId: string): void {
       throw new Error("prompt is required for LLaMA3 models");
     }
   } else if (modelId.startsWith("mistral.mistral")) {
-    if (!bodyContent.prompt) throw new Error("Mistral requires a prompt");
+    if (!bodyContent.prompt) {
+      throw new Error("prompt is required for Mistral models");
+    }
+    if (
+      !bodyContent.prompt.startsWith("[INST]") ||
+      !bodyContent.prompt.includes("[/INST]")
+    ) {
+      throw new Error(
+        "Mistral prompt must be wrapped in [INST] and [/INST] tags",
+      );
+    }
   } else if (modelId.startsWith("amazon.titan")) {
     if (!bodyContent.inputText) throw new Error("Titan requires inputText");
   }
@@ -247,29 +266,27 @@ async function requestBedrock(req: NextRequest) {
   } else {
     endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
   }
-  requestBody = JSON.stringify(bodyJson.body || bodyJson);
-
-  // console.log("Request to AWS Bedrock:", {
-  //   endpoint,
-  //   modelId,
-  //   body: requestBody,
-  // });
+  // Set content type and accept headers for Mistral models

   const headers = await sign({
     method: "POST",
     url: endpoint,
     region: awsRegion,
     accessKeyId: awsAccessKey,
     secretAccessKey: awsSecretKey,
-    body: requestBody,
+    body: JSON.stringify(bodyJson.body || bodyJson),
     service: "bedrock",
     isStreaming: shouldStream !== "false",
+    ...(modelId.startsWith("mistral.mistral") && {
+      contentType: "application/json",
+      accept: "application/json",
+    }),
   });

   const res = await fetch(endpoint, {
     method: "POST",
     headers,
-    body: requestBody,
+    body: JSON.stringify(bodyJson.body || bodyJson),
     redirect: "manual",
     // @ts-ignore
     duplex: "half",
diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts
index 0f7d73022..aff4c5582 100644
--- a/app/client/platforms/bedrock.ts
+++ b/app/client/platforms/bedrock.ts
@@ -74,16 +74,30 @@ export class BedrockApi implements LLMApi {
     try {
       // Handle Titan models
       if (modelId.startsWith("amazon.titan")) {
-        if (res?.delta?.text) return res.delta.text;
-        return res?.outputText || "";
+        let text = "";
+        if (res?.delta?.text) {
+          text = res.delta.text;
+        } else {
+          text = res?.outputText || "";
+        }
+        // Clean up Titan response by removing leading question mark and whitespace
+        return text.replace(/^[\s?]+/, "");
       }

       // Handle LLaMA models
       if (modelId.startsWith("us.meta.llama3")) {
-        if (res?.delta?.text) return res.delta.text;
-        if (res?.generation) return res.generation;
-        if (typeof res?.output === "string") return res.output;
-        if (typeof res === "string") return res;
+        if (res?.delta?.text) {
+          return res.delta.text;
+        }
+        if (res?.generation) {
+          return res.generation;
+        }
+        if (typeof res?.output === "string") {
+          return res.output;
+        }
+        if (typeof res === "string") {
+          return res;
+        }
         return "";
       }

@@ -127,9 +141,19 @@ export class BedrockApi implements LLMApi {
             ...messages,
           ]
         : messages;
+
+      // Format messages without role prefixes for Titan
       const inputText = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+        .map((m) => {
+          // Include system message as a prefix instruction
+          if (m.role === "system") {
+            return getMessageTextContent(m);
+          }
+          // For user/assistant messages, just include the content
+          return getMessageTextContent(m);
+        })
+        .join("\n\n");
+
       return {
         body: {
           inputText,
@@ -142,29 +166,25 @@
       };
     }

-    // Handle LLaMA3 models - simplified format
+    // Handle LLaMA3 models
     if (model.startsWith("us.meta.llama3")) {
-      const allMessages = systemMessage
-        ? [
-            { role: "system" as MessageRole, content: systemMessage },
-            ...messages,
-          ]
-        : messages;
-
-      const prompt = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+      // Only include the last user message for LLaMA
+      const lastMessage = messages[messages.length - 1];
+      const prompt = getMessageTextContent(lastMessage);

       return {
         contentType: "application/json",
         accept: "application/json",
         body: {
-          prompt,
+          prompt: prompt,
+          max_gen_len: modelConfig.max_tokens || 256,
+          temperature: modelConfig.temperature || 0.5,
+          top_p: 0.9,
         },
       };
     }

-    // Handle Mistral models
+    // Handle Mistral models with correct instruction format
     if (model.startsWith("mistral.mistral")) {
       const allMessages = systemMessage
         ? [
@@ -172,14 +192,21 @@
             ...messages,
           ]
         : messages;
-      const prompt = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+
+      // Format messages as a conversation with instruction tags
+      const prompt = `[INST] ${allMessages
+        .map((m) => getMessageTextContent(m))
+        .join("\n")} [/INST]`;
+
       return {
+        contentType: "application/json",
+        accept: "application/json",
         body: {
           prompt,
-          temperature: modelConfig.temperature || 0.7,
           max_tokens: modelConfig.max_tokens || 4096,
+          temperature: modelConfig.temperature || 0.5,
+          top_p: 0.9,
+          top_k: 50,
         },
       };
     }
@@ -258,7 +285,6 @@
         systemMessage,
         modelConfig,
       );
-      // console.log("Request body:", JSON.stringify(requestBody, null, 2));

       const controller = new AbortController();
       options.onController?.(controller);
@@ -338,7 +364,6 @@
           } catch (e) {}
         }
         const message = this.extractMessage(chunkJson, modelConfig.model);
-        // console.log("Extracted message:", message);
         return message;
       } catch (e) {
         console.error("Error parsing chunk:", e);
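
For reference, the [INST] contract this patch enforces can be exercised end to end outside the app. Below is a minimal TypeScript sketch, assuming the hypothetical names ChatMessage, buildMistralPrompt, and validateMistralPrompt (none of which exist in the repo); it mirrors the client-side formatting added in app/client/platforms/bedrock.ts and the server-side check added to validateRequest in app/api/bedrock.ts.

// Minimal sketch of the [INST] round-trip enforced by this patch.
// buildMistralPrompt and validateMistralPrompt are hypothetical names.
interface ChatMessage {
  role: "system" | "user" | "assistant";
  content: string;
}

// Client side: collapse the conversation into one instruction block,
// as the Mistral branch in app/client/platforms/bedrock.ts now does.
function buildMistralPrompt(messages: ChatMessage[]): string {
  const joined = messages.map((m) => m.content).join("\n");
  return `[INST] ${joined} [/INST]`;
}

// Server side: reject prompts without instruction tags,
// as validateRequest in app/api/bedrock.ts now does.
function validateMistralPrompt(prompt: string): void {
  if (!prompt.startsWith("[INST]") || !prompt.includes("[/INST]")) {
    throw new Error("Mistral prompt must be wrapped in [INST] and [/INST] tags");
  }
}

// Both halves agree on the wire format:
const prompt = buildMistralPrompt([
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Hello" },
]);
validateMistralPrompt(prompt); // passes

Keeping both halves in sync matters: the server now rejects any Mistral request whose prompt lacks the tags, so every client path must go through the same wrapping step.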
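Similarly, the fallback order the stream transformer now applies to Mistral chunks (outputs[0].text, then output, then completion, then a bare string) can be stated as a small standalone sketch; MistralChunk and extractMistralText are illustrative names, not repo code. The bare-string case is hoisted to the front here only for type narrowing; the precedence among the object fields matches transformBedrockStream above.

// Illustrative chunk shape; real Bedrock payloads may carry extra fields.
type MistralChunk =
  | string
  | {
      outputs?: { text?: string }[];
      output?: string;
      completion?: string;
    };

// Mirrors the truthiness-based fallback chain in transformBedrockStream.
function extractMistralText(parsed: MistralChunk): string {
  if (typeof parsed === "string") return parsed;
  return parsed.outputs?.[0]?.text || parsed.output || parsed.completion || "";
}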