Improve inference for llama and mistral models

glay 2024-11-23 18:23:20 +08:00
parent 238eb70986
commit 513cf1b206
2 changed files with 134 additions and 69 deletions

File 1 of 2: the server-side Bedrock proxy route (filename not shown in this view)

@@ -15,7 +15,8 @@ function parseEventData(chunk: Uint8Array): any {
     // AWS Bedrock wraps the response in a 'body' field
     if (typeof parsed.body === "string") {
       try {
-        return JSON.parse(parsed.body);
+        const bodyJson = JSON.parse(parsed.body);
+        return bodyJson;
       } catch (e) {
         return { output: parsed.body };
       }
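Note: the change above is behavior-neutral; it only binds the parsed body to a named variable before returning it. For orientation, here is a minimal sketch of the whole parser, assuming (the diff does not show this) that the undiffed lines decode the chunk as UTF-8 JSON with TextDecoder:

    // Sketch only: the decode/fallback logic outside the hunk is assumed, not confirmed.
    function parseEventDataSketch(chunk: Uint8Array): any {
      const text = new TextDecoder("utf-8").decode(chunk);
      try {
        const parsed = JSON.parse(text);
        // AWS Bedrock wraps the response in a 'body' field
        if (typeof parsed.body === "string") {
          try {
            const bodyJson = JSON.parse(parsed.body);
            return bodyJson;
          } catch (e) {
            return { output: parsed.body };
          }
        }
        return parsed; // assumed fallback for unwrapped events
      } catch (e) {
        return { output: text };
      }
    }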
@@ -89,10 +90,12 @@ async function* transformBedrockStream(
         })}\n\n`;
       }
     }
-    // Handle LLaMA3 models
-    else if (modelId.startsWith("us.meta.llama3")) {
+    // Handle LLaMA models
+    else if (modelId.startsWith("us.meta.llama")) {
       let text = "";
-      if (parsed.generation) {
+      if (parsed.outputs?.[0]?.text) {
+        text = parsed.outputs[0].text;
+      } else if (parsed.generation) {
         text = parsed.generation;
       } else if (parsed.output) {
         text = parsed.output;
@@ -101,8 +104,6 @@ async function* transformBedrockStream(
       }
       if (text) {
-        // Clean up any control characters or invalid JSON characters
-        text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
         yield `data: ${JSON.stringify({
           delta: { text },
         })}\n\n`;
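Note: dropping the scrubbing step cannot break the SSE framing, because JSON.stringify escapes every character in the \x00-\x1F range (as \uXXXX) and characters in \x7F-\x9F are legal inside JSON strings; the practical difference is that control characters now reach the client verbatim instead of being stripped. A quick self-contained check:

    // Control characters survive JSON encoding as escapes, so the frame stays parseable.
    const text = "hello\u0001world";
    const frame = `data: ${JSON.stringify({ delta: { text } })}\n\n`;
    console.log(frame); // data: {"delta":{"text":"hello\u0001world"}}
    console.log(JSON.parse(frame.slice(6)).delta.text.length); // 11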
@@ -162,6 +163,7 @@ async function* transformBedrockStream(
 function validateRequest(body: any, modelId: string): void {
   if (!modelId) throw new Error("Model ID is required");

+  // Handle nested body structure
   const bodyContent = body.body || body;

   if (modelId.startsWith("anthropic.claude")) {
@@ -180,22 +182,20 @@ function validateRequest(body: any, modelId: string): void {
     } else if (typeof body.prompt !== "string") {
       throw new Error("prompt is required for Claude 2 and earlier");
     }
-  } else if (modelId.startsWith("us.meta.llama3")) {
-    if (!bodyContent.prompt) {
-      throw new Error("prompt is required for LLaMA3 models");
+  } else if (modelId.startsWith("us.meta.llama")) {
+    if (!bodyContent.prompt || typeof bodyContent.prompt !== "string") {
+      throw new Error("prompt string is required for LLaMA models");
+    }
+    if (
+      !bodyContent.max_gen_len ||
+      typeof bodyContent.max_gen_len !== "number"
+    ) {
+      throw new Error("max_gen_len must be a positive number for LLaMA models");
     }
   } else if (modelId.startsWith("mistral.mistral")) {
     if (!bodyContent.prompt) {
       throw new Error("prompt is required for Mistral models");
     }
-    if (
-      !bodyContent.prompt.startsWith("<s>[INST]") ||
-      !bodyContent.prompt.includes("[/INST]")
-    ) {
-      throw new Error(
-        "Mistral prompt must be wrapped in <s>[INST] and [/INST] tags",
-      );
-    }
   } else if (modelId.startsWith("amazon.titan")) {
     if (!bodyContent.inputText) throw new Error("Titan requires inputText");
   }
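Note: the rewritten LLaMA branch both tightens the prompt check (it must now be a string) and adds a max_gen_len requirement. One caveat: the guard only checks truthiness and type, so despite the error wording a negative max_gen_len still passes. Illustration with hypothetical payloads (the model ID is just an example):

    // Passes: string prompt plus numeric max_gen_len.
    validateRequest(
      { prompt: "User: Hi\n\nAssistant:", max_gen_len: 512 },
      "us.meta.llama3-1-70b-instruct-v1:0",
    );

    // Throws: max_gen_len is missing.
    try {
      validateRequest({ prompt: "Hi" }, "us.meta.llama3-1-70b-instruct-v1:0");
    } catch (e) {
      console.log((e as Error).message);
      // "max_gen_len must be a positive number for LLaMA models"
    }

    // Also passes, although -5 is not positive.
    validateRequest(
      { prompt: "Hi", max_gen_len: -5 },
      "us.meta.llama3-1-70b-instruct-v1:0",
    );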
@@ -250,7 +250,6 @@ async function requestBedrock(req: NextRequest) {
   try {
     // Determine the endpoint and request body based on model type
     let endpoint;
-    let requestBody;

     const bodyText = await req.clone().text();
     if (!bodyText) {
@@ -258,6 +257,10 @@ async function requestBedrock(req: NextRequest) {
     }

     const bodyJson = JSON.parse(bodyText);
+
+    // Debug log the request body
+    console.log("Original request body:", JSON.stringify(bodyJson, null, 2));
+
     validateRequest(bodyJson, modelId);

     // For all models, use standard endpoints
@@ -267,26 +270,44 @@ async function requestBedrock(req: NextRequest) {
       endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
     }

-    // Set content type and accept headers for Mistral models
+    // Set additional headers based on model type
+    const additionalHeaders: Record<string, string> = {};
+    if (
+      modelId.startsWith("us.meta.llama") ||
+      modelId.startsWith("mistral.mistral")
+    ) {
+      additionalHeaders["content-type"] = "application/json";
+      additionalHeaders["accept"] = "application/json";
+    }
+
+    // For Mistral models, unwrap the body object
+    const finalRequestBody =
+      modelId.startsWith("mistral.mistral") && bodyJson.body
+        ? bodyJson.body
+        : bodyJson;
+
+    // Set content type and accept headers for specific models
     const headers = await sign({
       method: "POST",
       url: endpoint,
       region: awsRegion,
       accessKeyId: awsAccessKey,
       secretAccessKey: awsSecretKey,
-      body: JSON.stringify(bodyJson.body || bodyJson),
+      body: JSON.stringify(finalRequestBody),
       service: "bedrock",
       isStreaming: shouldStream !== "false",
-      ...(modelId.startsWith("mistral.mistral") && {
-        contentType: "application/json",
-        accept: "application/json",
-      }),
+      additionalHeaders,
     });

+    // Debug log the final request body
+    // console.log("Final request endpoint:", endpoint);
+    // console.log(headers);
+    // console.log("Final request body:", JSON.stringify(finalRequestBody, null, 2));
+
     const res = await fetch(endpoint, {
       method: "POST",
       headers,
-      body: JSON.stringify(bodyJson.body || bodyJson),
+      body: JSON.stringify(finalRequestBody),
       redirect: "manual",
       // @ts-ignore
       duplex: "half",

File 2 of 2: the client-side BedrockApi implementation (filename not shown in this view)

@@ -85,14 +85,17 @@ export class BedrockApi implements LLMApi {
     }

     // Handle LLaMA models
-    if (modelId.startsWith("us.meta.llama3")) {
+    if (modelId.startsWith("us.meta.llama")) {
       if (res?.delta?.text) {
         return res.delta.text;
       }
       if (res?.generation) {
         return res.generation;
       }
-      if (typeof res?.output === "string") {
+      if (res?.outputs?.[0]?.text) {
+        return res.outputs[0].text;
+      }
+      if (res?.output) {
         return res.output;
       }
       if (typeof res === "string") {
@@ -103,11 +106,28 @@ export class BedrockApi implements LLMApi {
     // Handle Mistral models
     if (modelId.startsWith("mistral.mistral")) {
-      if (res?.delta?.text) return res.delta.text;
-      return res?.outputs?.[0]?.text || res?.output || res?.completion || "";
+      if (res?.delta?.text) {
+        return res.delta.text;
+      }
+      if (res?.outputs?.[0]?.text) {
+        return res.outputs[0].text;
+      }
+      if (res?.content?.[0]?.text) {
+        return res.content[0].text;
+      }
+      if (res?.output) {
+        return res.output;
+      }
+      if (res?.completion) {
+        return res.completion;
+      }
+      if (typeof res === "string") {
+        return res;
+      }
+      return "";
     }

-    // Handle Claude models and fallback cases
+    // Handle Claude models
     if (res?.content?.[0]?.text) return res.content[0].text;
     if (res?.messages?.[0]?.content?.[0]?.text)
       return res.messages[0].content[0].text;
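Note: the Mistral branch is now an explicit fallback chain rather than a one-line || expression; behaviorally it differs only by the added content[0].text and plain-string cases. Assuming the surrounding method is the class's message extractor (its name lies outside the hunk), hypothetical response fragments resolve like this:

    // { delta: { text: "Hel" } }        -> "Hel"   (streaming delta)
    // { outputs: [{ text: "Hello" }] }  -> "Hello" (non-streaming outputs array)
    // { completion: "Hi" }              -> "Hi"    (legacy completion field)
    // { }                               -> ""      (nothing recognized)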
@@ -142,14 +162,11 @@ export class BedrockApi implements LLMApi {
           ]
         : messages;

-      // Format messages without role prefixes for Titan
       const inputText = allMessages
         .map((m) => {
-          // Include system message as a prefix instruction
           if (m.role === "system") {
             return getMessageTextContent(m);
           }
-          // For user/assistant messages, just include the content
           return getMessageTextContent(m);
         })
         .join("\n\n");
@@ -166,25 +183,39 @@ export class BedrockApi implements LLMApi {
       };
     }

-    // Handle LLaMA3 models
-    if (model.startsWith("us.meta.llama3")) {
-      // Only include the last user message for LLaMA
-      const lastMessage = messages[messages.length - 1];
-      const prompt = getMessageTextContent(lastMessage);
+    // Handle LLaMA models
+    if (model.startsWith("us.meta.llama")) {
+      const allMessages = systemMessage
+        ? [
+            { role: "system" as MessageRole, content: systemMessage },
+            ...messages,
+          ]
+        : messages;
+
+      const prompt = allMessages
+        .map((m) => {
+          const content = getMessageTextContent(m);
+          if (m.role === "system") {
+            return `System: ${content}`;
+          } else if (m.role === "user") {
+            return `User: ${content}`;
+          } else if (m.role === "assistant") {
+            return `Assistant: ${content}`;
+          }
+          return content;
+        })
+        .join("\n\n");

       return {
-        contentType: "application/json",
-        accept: "application/json",
-        body: {
-          prompt: prompt,
-          max_gen_len: modelConfig.max_tokens || 256,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
-        },
+        prompt,
+        max_gen_len: modelConfig.max_tokens || 512,
+        temperature: modelConfig.temperature || 0.6,
+        top_p: modelConfig.top_p || 0.9,
+        stop: ["User:", "System:", "Assistant:", "\n\n"],
       };
     }

-    // Handle Mistral models with correct instruction format
+    // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
       const allMessages = systemMessage
         ? [
@@ -193,25 +224,29 @@ export class BedrockApi implements LLMApi {
           ]
         : messages;

-      // Format messages as a conversation with instruction tags
-      const prompt = `<s>[INST] ${allMessages
-        .map((m) => getMessageTextContent(m))
-        .join("\n")} [/INST]`;
+      const formattedConversation = allMessages
+        .map((m) => {
+          const content = getMessageTextContent(m);
+          if (m.role === "system") {
+            return content;
+          } else if (m.role === "user") {
+            return content;
+          } else if (m.role === "assistant") {
+            return content;
+          }
+          return content;
+        })
+        .join("\n");

-      // Format according to Mistral's requirements
       return {
-        contentType: "application/json",
-        accept: "application/json",
-        body: {
-          prompt,
-          max_tokens: modelConfig.max_tokens || 4096,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
-          top_k: 50,
-        },
+        prompt: formattedConversation,
+        max_tokens: modelConfig.max_tokens || 4096,
+        temperature: modelConfig.temperature || 0.7,
       };
     }

-    // Handle Claude models (existing implementation)
+    // Handle Claude models
     const isClaude3 = model.startsWith("anthropic.claude-3");
     const formattedMessages = messages
       .filter(
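Note: both builders now flatten the entire conversation into one prompt string instead of sending only the last user message (LLaMA) or an <s>[INST] ... [/INST] wrapper (Mistral); dropping the wrapper is also why the server-side [INST] validation was removed above. In the Mistral mapper the three role branches all return content unchanged, so the if/else chain is currently equivalent to returning content directly, presumably a placeholder for future role-specific formatting. With a hypothetical conversation the results look like this:

    // system: "Be brief.", user: "Hi", assistant: "Hello"
    // LLaMA prompt   -> "System: Be brief.\n\nUser: Hi\n\nAssistant: Hello"
    // Mistral prompt -> "Be brief.\nHi\nHello"

Note also that the LLaMA stop list includes "\n\n", the same separator the prompt itself uses, so generation halts at the first blank line the model emits.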
@@ -253,12 +288,14 @@ export class BedrockApi implements LLMApi {
     });

     return {
-      anthropic_version: "bedrock-2023-05-31",
-      max_tokens: modelConfig.max_tokens,
-      messages: formattedMessages,
-      ...(systemMessage && { system: systemMessage }),
-      temperature: modelConfig.temperature,
-      ...(isClaude3 && { top_k: 5 }),
+      body: {
+        anthropic_version: "bedrock-2023-05-31",
+        max_tokens: modelConfig.max_tokens,
+        messages: formattedMessages,
+        ...(systemMessage && { system: systemMessage }),
+        temperature: modelConfig.temperature,
+        ...(isClaude3 && { top_k: modelConfig.top_k || 50 }),
+      },
     };
   }
@@ -301,6 +338,13 @@ export class BedrockApi implements LLMApi {
     const headers = getHeaders();
     headers.ModelID = modelConfig.model;

+    // For LLaMA and Mistral models, send the request body directly without the 'body' wrapper
+    const finalRequestBody =
+      modelConfig.model.startsWith("us.meta.llama") ||
+      modelConfig.model.startsWith("mistral.mistral")
+        ? requestBody
+        : requestBody.body;
+
     if (options.config.stream) {
       let index = -1;
       let currentToolArgs = "";
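Note: taken together with the Claude builder change above, both paths now hand the proxy an unwrapped model payload: the LLaMA/Mistral builders return a flat object, and for Claude finalRequestBody picks out requestBody.body, so the 'body' wrapper exists only transiently inside the client. The server-side bodyJson.body || bodyJson unwrap then degrades to a no-op for flat payloads. Hypothetical shapes on the wire:

    // Claude:  requestBody = { body: { anthropic_version, ... } } -> sends requestBody.body
    // LLaMA:   requestBody = { prompt, max_gen_len, ... }         -> sent as-is
    // Mistral: requestBody = { prompt, max_tokens, ... }          -> sent as-is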
@@ -312,7 +356,7 @@ export class BedrockApi implements LLMApi {
       return stream(
         chatPath,
-        requestBody,
+        finalRequestBody,
         headers,
         (tools as ToolDefinition[]).map((tool) => ({
           name: tool?.function?.name,
@@ -420,7 +464,7 @@ export class BedrockApi implements LLMApi {
       const res = await fetch(chatPath, {
         method: "POST",
         headers,
-        body: JSON.stringify(requestBody),
+        body: JSON.stringify(finalRequestBody),
       });
       const resJson = await res.json();