Improve inference for the LLaMA and Mistral models
commit 513cf1b206 (parent 238eb70986)
@@ -15,7 +15,8 @@ function parseEventData(chunk: Uint8Array): any {
     // AWS Bedrock wraps the response in a 'body' field
     if (typeof parsed.body === "string") {
       try {
-        return JSON.parse(parsed.body);
+        const bodyJson = JSON.parse(parsed.body);
+        return bodyJson;
       } catch (e) {
         return { output: parsed.body };
       }
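For context, a minimal sketch of the whole parser after this change, under the assumption that each chunk arrives as UTF-8 JSON (the decoding step and the fallthrough return are not shown in the diff):

// Sketch: decode one Bedrock event chunk, unwrapping the nested 'body' field.
function parseEventData(chunk: Uint8Array): any {
  const parsed = JSON.parse(new TextDecoder("utf-8").decode(chunk));
  // AWS Bedrock wraps the response in a 'body' field
  if (typeof parsed.body === "string") {
    try {
      const bodyJson = JSON.parse(parsed.body);
      return bodyJson;
    } catch (e) {
      // Not valid JSON: surface the raw string as plain output text
      return { output: parsed.body };
    }
  }
  return parsed;
}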
@@ -89,10 +90,12 @@ async function* transformBedrockStream(
         })}\n\n`;
       }
     }
-    // Handle LLaMA3 models
-    else if (modelId.startsWith("us.meta.llama3")) {
+    // Handle LLaMA models
+    else if (modelId.startsWith("us.meta.llama")) {
       let text = "";
-      if (parsed.generation) {
+      if (parsed.outputs?.[0]?.text) {
+        text = parsed.outputs[0].text;
+      } else if (parsed.generation) {
         text = parsed.generation;
       } else if (parsed.output) {
         text = parsed.output;
@@ -101,8 +104,6 @@ async function* transformBedrockStream(
       }

       if (text) {
-        // Clean up any control characters or invalid JSON characters
-        text = text.replace(/[\x00-\x1F\x7F-\x9F]/g, "");
         yield `data: ${JSON.stringify({
           delta: { text },
         })}\n\n`;
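The branch now tries the payload shapes in a fixed order. The same precedence, pulled out as a standalone helper for illustration (field names come from the diff; the function itself is not in the codebase):

// Newer LLaMA responses carry outputs[0].text; older ones use generation or output.
function extractLlamaText(parsed: any): string {
  if (parsed.outputs?.[0]?.text) return parsed.outputs[0].text;
  if (parsed.generation) return parsed.generation;
  if (parsed.output) return parsed.output;
  return "";
}

// All three shapes resolve to "Hi":
extractLlamaText({ outputs: [{ text: "Hi" }] });
extractLlamaText({ generation: "Hi" });
extractLlamaText({ output: "Hi" });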
@@ -162,6 +163,7 @@ async function* transformBedrockStream(
 function validateRequest(body: any, modelId: string): void {
   if (!modelId) throw new Error("Model ID is required");

+  // Handle nested body structure
   const bodyContent = body.body || body;

   if (modelId.startsWith("anthropic.claude")) {
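The `body.body || body` fallback lets the validator accept both the wrapped and the flat request shape; a quick illustration (payload values are invented):

const wrapped = { body: { prompt: "Hello", max_gen_len: 512 } };
const flat = { prompt: "Hello", max_gen_len: 512 };

// Both resolve to { prompt: "Hello", max_gen_len: 512 }:
const contentA = wrapped.body || wrapped;
const contentB = (flat as any).body || flat;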
@@ -180,22 +182,20 @@ function validateRequest(body: any, modelId: string): void {
     } else if (typeof body.prompt !== "string") {
       throw new Error("prompt is required for Claude 2 and earlier");
     }
-  } else if (modelId.startsWith("us.meta.llama3")) {
-    if (!bodyContent.prompt) {
-      throw new Error("prompt is required for LLaMA3 models");
+  } else if (modelId.startsWith("us.meta.llama")) {
+    if (!bodyContent.prompt || typeof bodyContent.prompt !== "string") {
+      throw new Error("prompt string is required for LLaMA models");
+    }
+    if (
+      !bodyContent.max_gen_len ||
+      typeof bodyContent.max_gen_len !== "number"
+    ) {
+      throw new Error("max_gen_len must be a positive number for LLaMA models");
     }
   } else if (modelId.startsWith("mistral.mistral")) {
     if (!bodyContent.prompt) {
       throw new Error("prompt is required for Mistral models");
     }
-    if (
-      !bodyContent.prompt.startsWith("<s>[INST]") ||
-      !bodyContent.prompt.includes("[/INST]")
-    ) {
-      throw new Error(
-        "Mistral prompt must be wrapped in <s>[INST] and [/INST] tags",
-      );
-    }
   } else if (modelId.startsWith("amazon.titan")) {
     if (!bodyContent.inputText) throw new Error("Titan requires inputText");
   }
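With the stricter LLaMA branch, a request that omits max_gen_len (or sends it as a string) now fails fast instead of reaching Bedrock. For example (the model ID and payloads here are illustrative):

// Passes: prompt is a string and max_gen_len is a number.
validateRequest(
  { prompt: "User: Hello", max_gen_len: 512 },
  "us.meta.llama3-2-11b-instruct-v1:0",
);

// Throws: "max_gen_len must be a positive number for LLaMA models"
validateRequest(
  { prompt: "User: Hello" },
  "us.meta.llama3-2-11b-instruct-v1:0",
);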
@@ -250,7 +250,6 @@ async function requestBedrock(req: NextRequest) {
   try {
     // Determine the endpoint and request body based on model type
     let endpoint;
-    let requestBody;

     const bodyText = await req.clone().text();
     if (!bodyText) {
@@ -258,6 +257,10 @@ async function requestBedrock(req: NextRequest) {
     }

     const bodyJson = JSON.parse(bodyText);
+
+    // Debug log the request body
+    console.log("Original request body:", JSON.stringify(bodyJson, null, 2));
+
     validateRequest(bodyJson, modelId);

     // For all models, use standard endpoints
@@ -267,26 +270,44 @@ async function requestBedrock(req: NextRequest) {
       endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
     }

-    // Set content type and accept headers for Mistral models
+    // Set additional headers based on model type
+    const additionalHeaders: Record<string, string> = {};
+    if (
+      modelId.startsWith("us.meta.llama") ||
+      modelId.startsWith("mistral.mistral")
+    ) {
+      additionalHeaders["content-type"] = "application/json";
+      additionalHeaders["accept"] = "application/json";
+    }
+
+    // For Mistral models, unwrap the body object
+    const finalRequestBody =
+      modelId.startsWith("mistral.mistral") && bodyJson.body
+        ? bodyJson.body
+        : bodyJson;
+
+    // Set content type and accept headers for specific models
     const headers = await sign({
       method: "POST",
       url: endpoint,
       region: awsRegion,
       accessKeyId: awsAccessKey,
       secretAccessKey: awsSecretKey,
-      body: JSON.stringify(bodyJson.body || bodyJson),
+      body: JSON.stringify(finalRequestBody),
       service: "bedrock",
       isStreaming: shouldStream !== "false",
-      ...(modelId.startsWith("mistral.mistral") && {
-        contentType: "application/json",
-        accept: "application/json",
-      }),
+      additionalHeaders,
     });

+    // Debug log the final request body
+    // console.log("Final request endpoint:", endpoint);
+    // console.log(headers);
+    // console.log("Final request body:", JSON.stringify(finalRequestBody, null, 2));
+
    const res = await fetch(endpoint, {
       method: "POST",
       headers,
-      body: JSON.stringify(bodyJson.body || bodyJson),
+      body: JSON.stringify(finalRequestBody),
       redirect: "manual",
       // @ts-ignore
       duplex: "half",
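The unwrap only fires when a Mistral request actually carries a body wrapper, so payloads that already arrive flat pass through untouched. Condensed to its core (model ID and values invented):

const modelId = "mistral.mistral-large-2402-v1:0";
const unwrap = (bodyJson: any) =>
  modelId.startsWith("mistral.mistral") && bodyJson.body
    ? bodyJson.body
    : bodyJson;

unwrap({ body: { prompt: "Hi", max_tokens: 4096 } }); // -> { prompt, max_tokens }
unwrap({ prompt: "Hi", max_tokens: 4096 }); // -> unchanged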
@@ -85,14 +85,17 @@ export class BedrockApi implements LLMApi {
     }

     // Handle LLaMA models
-    if (modelId.startsWith("us.meta.llama3")) {
+    if (modelId.startsWith("us.meta.llama")) {
       if (res?.delta?.text) {
         return res.delta.text;
       }
       if (res?.generation) {
         return res.generation;
       }
-      if (typeof res?.output === "string") {
+      if (res?.outputs?.[0]?.text) {
+        return res.outputs[0].text;
+      }
+      if (res?.output) {
         return res.output;
       }
       if (typeof res === "string") {
@@ -103,11 +106,28 @@ export class BedrockApi implements LLMApi {

     // Handle Mistral models
     if (modelId.startsWith("mistral.mistral")) {
-      if (res?.delta?.text) return res.delta.text;
-      return res?.outputs?.[0]?.text || res?.output || res?.completion || "";
+      if (res?.delta?.text) {
+        return res.delta.text;
+      }
+      if (res?.outputs?.[0]?.text) {
+        return res.outputs[0].text;
+      }
+      if (res?.content?.[0]?.text) {
+        return res.content[0].text;
+      }
+      if (res?.output) {
+        return res.output;
+      }
+      if (res?.completion) {
+        return res.completion;
+      }
+      if (typeof res === "string") {
+        return res;
+      }
+      return "";
     }

-    // Handle Claude models and fallback cases
+    // Handle Claude models
     if (res?.content?.[0]?.text) return res.content[0].text;
     if (res?.messages?.[0]?.content?.[0]?.text)
       return res.messages[0].content[0].text;
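Unrolling the chained `||` into explicit checks makes the precedence visible: a payload carrying both `outputs` and `completion` now clearly resolves to `outputs[0].text`. The same cascade as a standalone helper (the wrapper function is illustrative, not from the codebase):

function extractMistralText(res: any): string {
  if (res?.delta?.text) return res.delta.text;
  if (res?.outputs?.[0]?.text) return res.outputs[0].text;
  if (res?.content?.[0]?.text) return res.content[0].text;
  if (res?.output) return res.output;
  if (res?.completion) return res.completion;
  if (typeof res === "string") return res;
  return "";
}

extractMistralText({ outputs: [{ text: "A" }], completion: "B" }); // -> "A"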
@@ -142,14 +162,11 @@ export class BedrockApi implements LLMApi {
         ]
       : messages;

-    // Format messages without role prefixes for Titan
     const inputText = allMessages
       .map((m) => {
-        // Include system message as a prefix instruction
         if (m.role === "system") {
           return getMessageTextContent(m);
         }
-        // For user/assistant messages, just include the content
         return getMessageTextContent(m);
       })
       .join("\n\n");
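For Titan, every message (system or not) contributes its bare text and the turns are joined with blank lines; for example (message contents invented):

// Produces "You are terse.\n\nHi\n\nHello!"
const inputText = [
  { role: "system", content: "You are terse." },
  { role: "user", content: "Hi" },
  { role: "assistant", content: "Hello!" },
]
  .map((m) => m.content)
  .join("\n\n");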
@@ -166,25 +183,39 @@ export class BedrockApi implements LLMApi {
       };
     }

-    // Handle LLaMA3 models
-    if (model.startsWith("us.meta.llama3")) {
-      // Only include the last user message for LLaMA
-      const lastMessage = messages[messages.length - 1];
-      const prompt = getMessageTextContent(lastMessage);
+    // Handle LLaMA models
+    if (model.startsWith("us.meta.llama")) {
+      const allMessages = systemMessage
+        ? [
+            { role: "system" as MessageRole, content: systemMessage },
+            ...messages,
+          ]
+        : messages;
+
+      const prompt = allMessages
+        .map((m) => {
+          const content = getMessageTextContent(m);
+          if (m.role === "system") {
+            return `System: ${content}`;
+          } else if (m.role === "user") {
+            return `User: ${content}`;
+          } else if (m.role === "assistant") {
+            return `Assistant: ${content}`;
+          }
+          return content;
+        })
+        .join("\n\n");
+
       return {
-        contentType: "application/json",
-        accept: "application/json",
-        body: {
-          prompt: prompt,
-          max_gen_len: modelConfig.max_tokens || 256,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
-        },
+        prompt,
+        max_gen_len: modelConfig.max_tokens || 512,
+        temperature: modelConfig.temperature || 0.6,
+        top_p: modelConfig.top_p || 0.9,
+        stop: ["User:", "System:", "Assistant:", "\n\n"],
       };
     }

-    // Handle Mistral models with correct instruction format
+    // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
       const allMessages = systemMessage
         ? [
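The rewritten builder keeps the whole conversation instead of only the last user message, tagging each turn with its role; the stop sequences include those same role tags, presumably so the model does not fabricate a next turn itself. A worked example of the resulting request body (message text and config values invented):

// messages: system "Be brief.", user "Hi", assistant "Hello!", user "Bye?"
const requestBody = {
  prompt: "System: Be brief.\n\nUser: Hi\n\nAssistant: Hello!\n\nUser: Bye?",
  max_gen_len: 512,
  temperature: 0.6,
  top_p: 0.9,
  stop: ["User:", "System:", "Assistant:", "\n\n"],
};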
@@ -193,25 +224,29 @@ export class BedrockApi implements LLMApi {
           ]
         : messages;

-      // Format messages as a conversation with instruction tags
-      const prompt = `<s>[INST] ${allMessages
-        .map((m) => getMessageTextContent(m))
-        .join("\n")} [/INST]`;
+      const formattedConversation = allMessages
+        .map((m) => {
+          const content = getMessageTextContent(m);
+          if (m.role === "system") {
+            return content;
+          } else if (m.role === "user") {
+            return content;
+          } else if (m.role === "assistant") {
+            return content;
+          }
+          return content;
+        })
+        .join("\n");
+
+      // Format according to Mistral's requirements
       return {
-        contentType: "application/json",
-        accept: "application/json",
-        body: {
-          prompt,
+        prompt: formattedConversation,
         max_tokens: modelConfig.max_tokens || 4096,
-          temperature: modelConfig.temperature || 0.5,
-          top_p: 0.9,
-          top_k: 50,
-        },
+        temperature: modelConfig.temperature || 0.7,
       };
     }

-    // Handle Claude models (existing implementation)
+    // Handle Claude models
     const isClaude3 = model.startsWith("anthropic.claude-3");
     const formattedMessages = messages
       .filter(
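Note that this also drops the `<s>[INST] ... [/INST]` wrapping, matching the removed validator check above: the conversation is now sent as plain newline-joined turns (and since every role branch returns the bare content, the map is effectively a pass-through). For example (contents invented):

const requestBody = {
  prompt: ["You are terse.", "Hi"].join("\n"), // "You are terse.\nHi"
  max_tokens: 4096,
  temperature: 0.7,
};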
@@ -253,12 +288,14 @@ export class BedrockApi implements LLMApi {
     });

     return {
+      body: {
       anthropic_version: "bedrock-2023-05-31",
       max_tokens: modelConfig.max_tokens,
       messages: formattedMessages,
       ...(systemMessage && { system: systemMessage }),
       temperature: modelConfig.temperature,
-      ...(isClaude3 && { top_k: 5 }),
+      ...(isClaude3 && { top_k: modelConfig.top_k || 50 }),
+      },
     };
   }

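On the wire the Claude payload is therefore wrapped in a body envelope, which the send path below unwraps again for non-LLaMA/Mistral models; roughly this shape (values invented, message format assumed from the Anthropic-on-Bedrock convention):

const claudeRequestBody = {
  body: {
    anthropic_version: "bedrock-2023-05-31",
    max_tokens: 4096,
    messages: [{ role: "user", content: [{ type: "text", text: "Hi" }] }],
    system: "Be brief.",
    temperature: 0.7,
    top_k: 50, // Claude 3 only; now modelConfig.top_k || 50 instead of the fixed 5
  },
};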
@@ -301,6 +338,13 @@ export class BedrockApi implements LLMApi {
     const headers = getHeaders();
     headers.ModelID = modelConfig.model;

+    // For LLaMA and Mistral models, send the request body directly without the 'body' wrapper
+    const finalRequestBody =
+      modelConfig.model.startsWith("us.meta.llama") ||
+      modelConfig.model.startsWith("mistral.mistral")
+        ? requestBody
+        : requestBody.body;
+
     if (options.config.stream) {
       let index = -1;
       let currentToolArgs = "";
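This is where the two body conventions meet: the LLaMA and Mistral builders return flat payloads that go out as-is, while the Claude builder's body envelope is stripped before sending. Reduced to a sketch (the helper and model IDs are illustrative):

const pick = (model: string, requestBody: any) =>
  model.startsWith("us.meta.llama") || model.startsWith("mistral.mistral")
    ? requestBody // already flat: { prompt, max_gen_len, ... }
    : requestBody.body; // unwrap the { body: ... } envelope

pick("us.meta.llama3-2-11b-instruct-v1:0", { prompt: "User: Hi", max_gen_len: 512 });
// -> sent unchanged
pick("anthropic.claude-3-5-sonnet-20240620-v1:0", { body: { max_tokens: 4096 } });
// -> { max_tokens: 4096 }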
@@ -312,7 +356,7 @@ export class BedrockApi implements LLMApi {

       return stream(
         chatPath,
-        requestBody,
+        finalRequestBody,
         headers,
         (tools as ToolDefinition[]).map((tool) => ({
           name: tool?.function?.name,
@@ -420,7 +464,7 @@ export class BedrockApi implements LLMApi {
       const res = await fetch(chatPath, {
         method: "POST",
         headers,
-        body: JSON.stringify(requestBody),
+        body: JSON.stringify(finalRequestBody),
       });

       const resJson = await res.json();