Improve the inference results of the Mistral model

glay 2024-11-23 16:27:19 +08:00
parent a6337e9f23
commit 238eb70986
2 changed files with 81 additions and 39 deletions

View File

@@ -110,8 +110,17 @@ async function* transformBedrockStream(
   }
   // Handle Mistral models
   else if (modelId.startsWith("mistral.mistral")) {
-    const text =
-      parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
+    let text = "";
+    if (parsed.outputs?.[0]?.text) {
+      text = parsed.outputs[0].text;
+    } else if (parsed.output) {
+      text = parsed.output;
+    } else if (parsed.completion) {
+      text = parsed.completion;
+    } else if (typeof parsed === "string") {
+      text = parsed;
+    }
     if (text) {
       yield `data: ${JSON.stringify({
         delta: { text },
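The new cascade prefers the streaming shape ({ outputs: [{ text }] }) over the non-streaming output/completion fields, which the old one-liner checked in the wrong order. A minimal standalone sketch of the same precedence; extractMistralText is a hypothetical helper and the sample payloads are illustrative, not captured Bedrock responses:

// Sketch only: streaming chunks carry outputs[0].text; fall back to the
// non-streaming fields, then to a bare string payload.
type MistralChunk =
  | string
  | { outputs?: { text?: string }[]; output?: string; completion?: string };

function extractMistralText(parsed: MistralChunk): string {
  if (typeof parsed === "string") return parsed;
  return parsed.outputs?.[0]?.text ?? parsed.output ?? parsed.completion ?? "";
}

console.log(extractMistralText({ outputs: [{ text: "Hello" }] })); // "Hello"
console.log(extractMistralText({ completion: "Hi" })); // "Hi"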
@@ -176,7 +185,17 @@ function validateRequest(body: any, modelId: string): void {
       throw new Error("prompt is required for LLaMA3 models");
     }
   } else if (modelId.startsWith("mistral.mistral")) {
-    if (!bodyContent.prompt) throw new Error("Mistral requires a prompt");
+    if (!bodyContent.prompt) {
+      throw new Error("prompt is required for Mistral models");
+    }
+    if (
+      !bodyContent.prompt.startsWith("<s>[INST]") ||
+      !bodyContent.prompt.includes("[/INST]")
+    ) {
+      throw new Error(
+        "Mistral prompt must be wrapped in <s>[INST] and [/INST] tags",
+      );
+    }
   } else if (modelId.startsWith("amazon.titan")) {
     if (!bodyContent.inputText) throw new Error("Titan requires inputText");
   }
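For reference, here is what the new guard accepts and rejects, tested in isolation. validateMistralPrompt is a hypothetical name and the prompts are made up, but the checks mirror the diff above:

// Sketch of the guard on its own.
function validateMistralPrompt(prompt?: string): void {
  if (!prompt) {
    throw new Error("prompt is required for Mistral models");
  }
  if (!prompt.startsWith("<s>[INST]") || !prompt.includes("[/INST]")) {
    throw new Error(
      "Mistral prompt must be wrapped in <s>[INST] and [/INST] tags",
    );
  }
}

validateMistralPrompt("<s>[INST] What is AWS Bedrock? [/INST]"); // passes
// validateMistralPrompt("What is AWS Bedrock?"); // throws: missing tags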
@@ -247,29 +266,27 @@ async function requestBedrock(req: NextRequest) {
   } else {
     endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
   }
-  requestBody = JSON.stringify(bodyJson.body || bodyJson);
-  // console.log("Request to AWS Bedrock:", {
-  //   endpoint,
-  //   modelId,
-  //   body: requestBody,
-  // });
+  // Set content type and accept headers for Mistral models
   const headers = await sign({
     method: "POST",
     url: endpoint,
     region: awsRegion,
     accessKeyId: awsAccessKey,
     secretAccessKey: awsSecretKey,
-    body: requestBody,
+    body: JSON.stringify(bodyJson.body || bodyJson),
     service: "bedrock",
     isStreaming: shouldStream !== "false",
+    ...(modelId.startsWith("mistral.mistral") && {
+      contentType: "application/json",
+      accept: "application/json",
+    }),
   });
   const res = await fetch(endpoint, {
     method: "POST",
     headers,
-    body: requestBody,
+    body: JSON.stringify(bodyJson.body || bodyJson),
     redirect: "manual",
     // @ts-ignore
     duplex: "half",
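The conditional spread appended to the sign() options is a standard TypeScript idiom: when the guard is false, `false && {...}` evaluates to false, and spreading a falsy value into an object literal contributes nothing. A standalone sketch, with an example model id:

// Sketch of the conditional-spread pattern used above.
const modelId = "mistral.mistral-7b-instruct-v0:2"; // example id
const opts = {
  service: "bedrock",
  ...(modelId.startsWith("mistral.mistral") && {
    contentType: "application/json",
    accept: "application/json",
  }),
};
console.log(opts); // { service: "bedrock", contentType: "application/json", accept: "application/json" }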

View File

@@ -74,16 +74,30 @@ export class BedrockApi implements LLMApi {
     try {
       // Handle Titan models
       if (modelId.startsWith("amazon.titan")) {
-        if (res?.delta?.text) return res.delta.text;
-        return res?.outputText || "";
+        let text = "";
+        if (res?.delta?.text) {
+          text = res.delta.text;
+        } else {
+          text = res?.outputText || "";
+        }
+        // Clean up Titan response by removing leading question mark and whitespace
+        return text.replace(/^[\s?]+/, "");
       }
       // Handle LLaMA models
       if (modelId.startsWith("us.meta.llama3")) {
-        if (res?.delta?.text) return res.delta.text;
-        if (res?.generation) return res.generation;
-        if (typeof res?.output === "string") return res.output;
-        if (typeof res === "string") return res;
+        if (res?.delta?.text) {
+          return res.delta.text;
+        }
+        if (res?.generation) {
+          return res.generation;
+        }
+        if (typeof res?.output === "string") {
+          return res.output;
+        }
+        if (typeof res === "string") {
+          return res;
+        }
         return "";
       }
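The Titan cleanup strips any run of leading whitespace and question marks before returning the text; for example (inputs made up):

// Sketch of the cleanup regex on sample strings.
const clean = (text: string) => text.replace(/^[\s?]+/, "");
console.log(clean("? Hello there")); // "Hello there"
console.log(clean("\n ?Answer: 42")); // "Answer: 42"
console.log(clean("No prefix")); // unchanged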
@@ -127,9 +141,19 @@ export class BedrockApi implements LLMApi {
             ...messages,
           ]
         : messages;
+      // Format messages without role prefixes for Titan
       const inputText = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+        .map((m) => {
+          // Include system message as a prefix instruction
+          if (m.role === "system") {
+            return getMessageTextContent(m);
+          }
+          // For user/assistant messages, just include the content
+          return getMessageTextContent(m);
+        })
+        .join("\n\n");
       return {
         body: {
           inputText,
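Note that both branches of the new map callback return the same value, so the net effect today is joining the raw message texts (system message included) with blank lines:

// Sketch of the effective Titan inputText after this change (sample messages).
const msgs = ["You are concise.", "What is Bedrock?"];
console.log(msgs.join("\n\n"));
// You are concise.
//
// What is Bedrock?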
@@ -142,29 +166,25 @@ export class BedrockApi implements LLMApi {
       };
     }
-    // Handle LLaMA3 models - simplified format
+    // Handle LLaMA3 models
     if (model.startsWith("us.meta.llama3")) {
-      const allMessages = systemMessage
-        ? [
-            { role: "system" as MessageRole, content: systemMessage },
-            ...messages,
-          ]
-        : messages;
-      const prompt = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+      // Only include the last user message for LLaMA
+      const lastMessage = messages[messages.length - 1];
+      const prompt = getMessageTextContent(lastMessage);
       return {
         contentType: "application/json",
         accept: "application/json",
         body: {
-          prompt,
-          max_gen_len: modelConfig.max_tokens || 256,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
+          prompt: prompt,
         },
       };
     }
-    // Handle Mistral models
+    // Handle Mistral models with correct instruction format
     if (model.startsWith("mistral.mistral")) {
       const allMessages = systemMessage
         ? [
@@ -172,14 +192,21 @@ export class BedrockApi implements LLMApi {
             ...messages,
           ]
         : messages;
-      const prompt = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+      // Format messages as a conversation with instruction tags
+      const prompt = `<s>[INST] ${allMessages
+        .map((m) => getMessageTextContent(m))
+        .join("\n")} [/INST]`;
       return {
+        contentType: "application/json",
+        accept: "application/json",
         body: {
           prompt,
-          temperature: modelConfig.temperature || 0.7,
           max_tokens: modelConfig.max_tokens || 4096,
+          temperature: modelConfig.temperature || 0.5,
+          top_p: 0.9,
+          top_k: 50,
         },
       };
     }
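With this change the whole conversation is flattened into a single [INST] block; for a hypothetical system-plus-user exchange the request prompt comes out as:

// Sketch: the prompt string the new code builds from two message texts.
const texts = ["You are a helpful assistant.", "Summarize AWS Bedrock."];
const prompt = `<s>[INST] ${texts.join("\n")} [/INST]`;
console.log(prompt);
// <s>[INST] You are a helpful assistant.
// Summarize AWS Bedrock. [/INST]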
@@ -258,7 +285,6 @@ export class BedrockApi implements LLMApi {
         systemMessage,
         modelConfig,
       );
-      // console.log("Request body:", JSON.stringify(requestBody, null, 2));
       const controller = new AbortController();
       options.onController?.(controller);
@@ -338,7 +364,6 @@ export class BedrockApi implements LLMApi {
         } catch (e) {}
       }
       const message = this.extractMessage(chunkJson, modelConfig.model);
-      // console.log("Extracted message:", message);
       return message;
     } catch (e) {
       console.error("Error parsing chunk:", e);