Mirror of https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git

Improve the Mistral model's inference results

This commit is contained in:
parent a6337e9f23
commit 238eb70986
@@ -110,8 +110,17 @@ async function* transformBedrockStream(
       }
       // Handle Mistral models
       else if (modelId.startsWith("mistral.mistral")) {
-        const text =
-          parsed.output || parsed.outputs?.[0]?.text || parsed.completion || "";
+        let text = "";
+        if (parsed.outputs?.[0]?.text) {
+          text = parsed.outputs[0].text;
+        } else if (parsed.output) {
+          text = parsed.output;
+        } else if (parsed.completion) {
+          text = parsed.completion;
+        } else if (typeof parsed === "string") {
+          text = parsed;
+        }
+
         if (text) {
           yield `data: ${JSON.stringify({
             delta: { text },
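The rewritten branch probes each known Mistral response shape in turn, since Bedrock can deliver the text under outputs[0].text, output, completion, or as a bare string depending on the invocation mode. A minimal sketch of the same fallback chain as a standalone helper; the function name and sample payloads are illustrative, not part of the commit:

    // Hypothetical helper mirroring the fallback chain in the hunk above.
    function extractMistralText(parsed: any): string {
      if (parsed?.outputs?.[0]?.text) return parsed.outputs[0].text;
      if (parsed?.output) return parsed.output;
      if (parsed?.completion) return parsed.completion;
      if (typeof parsed === "string") return parsed;
      return "";
    }

    // Different payload shapes resolve to the same text (sample data):
    extractMistralText({ outputs: [{ text: "Hi" }] }); // "Hi"
    extractMistralText({ completion: "Hi" }); // "Hi"
    extractMistralText("Hi"); // "Hi"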
@@ -176,7 +185,17 @@ function validateRequest(body: any, modelId: string): void {
       throw new Error("prompt is required for LLaMA3 models");
     }
   } else if (modelId.startsWith("mistral.mistral")) {
-    if (!bodyContent.prompt) throw new Error("Mistral requires a prompt");
+    if (!bodyContent.prompt) {
+      throw new Error("prompt is required for Mistral models");
+    }
+    if (
+      !bodyContent.prompt.startsWith("<s>[INST]") ||
+      !bodyContent.prompt.includes("[/INST]")
+    ) {
+      throw new Error(
+        "Mistral prompt must be wrapped in <s>[INST] and [/INST] tags",
+      );
+    }
   } else if (modelId.startsWith("amazon.titan")) {
     if (!bodyContent.inputText) throw new Error("Titan requires inputText");
   }
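The validation now fails fast with an error message consistent with the LLaMA3 branch, and additionally enforces the <s>[INST] ... [/INST] wrapper that mistral.mistral text-completion models expect, matching the prompt builder changed later in this commit. A hedged illustration of what passes and what throws; the predicate name is hypothetical:

    // Hypothetical predicate equivalent to the check added above.
    function isWrappedMistralPrompt(prompt: string): boolean {
      return prompt.startsWith("<s>[INST]") && prompt.includes("[/INST]");
    }

    isWrappedMistralPrompt("<s>[INST] What is 2 + 2? [/INST]"); // true
    isWrappedMistralPrompt("What is 2 + 2?"); // false -> validateRequest throws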
@@ -247,29 +266,27 @@ async function requestBedrock(req: NextRequest) {
   } else {
     endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
   }
-  requestBody = JSON.stringify(bodyJson.body || bodyJson);
-
-  // console.log("Request to AWS Bedrock:", {
-  //   endpoint,
-  //   modelId,
-  //   body: requestBody,
-  // });
 
+  // Set content type and accept headers for Mistral models
   const headers = await sign({
     method: "POST",
     url: endpoint,
     region: awsRegion,
     accessKeyId: awsAccessKey,
     secretAccessKey: awsSecretKey,
-    body: requestBody,
+    body: JSON.stringify(bodyJson.body || bodyJson),
     service: "bedrock",
     isStreaming: shouldStream !== "false",
+    ...(modelId.startsWith("mistral.mistral") && {
+      contentType: "application/json",
+      accept: "application/json",
+    }),
   });
 
   const res = await fetch(endpoint, {
     method: "POST",
     headers,
-    body: requestBody,
+    body: JSON.stringify(bodyJson.body || bodyJson),
     redirect: "manual",
     // @ts-ignore
     duplex: "half",
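Two things change here: the requestBody variable is dropped in favor of serializing bodyJson at each use site, and the SigV4 sign() call gains contentType/accept entries only for Mistral models via a conditional spread. The spread is a no-op for other models, because `false && {...}` evaluates to false and spreading a non-object into an object literal adds no keys. A small sketch of that pattern, assuming nothing about sign() itself; the model id is an example:

    // Conditional spread: extra keys appear only when the condition holds.
    const modelId = "mistral.mistral-7b-instruct-v0:2"; // example id, assumed
    const signOptions = {
      method: "POST",
      ...(modelId.startsWith("mistral.mistral") && {
        contentType: "application/json",
        accept: "application/json",
      }),
    };
    // => { method: "POST", contentType: "application/json", accept: "application/json" }
    // For a non-Mistral id, signOptions would be just { method: "POST" }.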
@@ -74,16 +74,30 @@ export class BedrockApi implements LLMApi {
     try {
       // Handle Titan models
       if (modelId.startsWith("amazon.titan")) {
-        if (res?.delta?.text) return res.delta.text;
-        return res?.outputText || "";
+        let text = "";
+        if (res?.delta?.text) {
+          text = res.delta.text;
+        } else {
+          text = res?.outputText || "";
+        }
+        // Clean up Titan response by removing leading question mark and whitespace
+        return text.replace(/^[\s?]+/, "");
       }
 
       // Handle LLaMA models
       if (modelId.startsWith("us.meta.llama3")) {
-        if (res?.delta?.text) return res.delta.text;
-        if (res?.generation) return res.generation;
-        if (typeof res?.output === "string") return res.output;
-        if (typeof res === "string") return res;
+        if (res?.delta?.text) {
+          return res.delta.text;
+        }
+        if (res?.generation) {
+          return res.generation;
+        }
+        if (typeof res?.output === "string") {
+          return res.output;
+        }
+        if (typeof res === "string") {
+          return res;
+        }
         return "";
       }
 
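For Titan, the delta/outputText lookup is unchanged in effect, but the result is now passed through text.replace(/^[\s?]+/, ""), stripping any leading run of whitespace and question marks that Titan responses sometimes start with. A quick check of what the regex does and does not remove (sample strings):

    // The Titan cleanup regex: anchored at the start, so only a leading
    // run of whitespace/question marks is stripped, never interior text.
    const cleanTitan = (s: string) => s.replace(/^[\s?]+/, "");

    cleanTitan("? \nHello there"); // "Hello there"
    cleanTitan("Is it kept? Yes"); // unchanged: the prefix is neither whitespace nor "?"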
@@ -127,9 +141,19 @@ export class BedrockApi implements LLMApi {
             ...messages,
           ]
         : messages;
+
+      // Format messages without role prefixes for Titan
       const inputText = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+        .map((m) => {
+          // Include system message as a prefix instruction
+          if (m.role === "system") {
+            return getMessageTextContent(m);
+          }
+          // For user/assistant messages, just include the content
+          return getMessageTextContent(m);
+        })
+        .join("\n\n");
+
       return {
         body: {
           inputText,
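Titan prompts are now built without "role:" prefixes and joined with blank lines; note that both branches of the new map callback return the same getMessageTextContent(m), so the role check documents intent rather than changing output. A sketch of the resulting inputText with sample messages (data illustrative, not from the commit):

    // Illustrative flattening for Titan (sample messages, assumed shape).
    const msgs = [
      { role: "system", content: "You are concise." },
      { role: "user", content: "Summarize SSE in one line." },
    ];
    const inputText = msgs.map((m) => m.content).join("\n\n");
    // "You are concise.\n\nSummarize SSE in one line."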
@@ -142,29 +166,25 @@ export class BedrockApi implements LLMApi {
       };
     }
 
-    // Handle LLaMA3 models - simplified format
+    // Handle LLaMA3 models
     if (model.startsWith("us.meta.llama3")) {
-      const allMessages = systemMessage
-        ? [
-            { role: "system" as MessageRole, content: systemMessage },
-            ...messages,
-          ]
-        : messages;
-
-      const prompt = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+      // Only include the last user message for LLaMA
+      const lastMessage = messages[messages.length - 1];
+      const prompt = getMessageTextContent(lastMessage);
 
       return {
         contentType: "application/json",
         accept: "application/json",
         body: {
-          prompt,
+          prompt: prompt,
+          max_gen_len: modelConfig.max_tokens || 256,
+          temperature: modelConfig.temperature || 0.5,
+          top_p: 0.9,
        },
      };
    }
 
-    // Handle Mistral models
+    // Handle Mistral models with correct instruction format
     if (model.startsWith("mistral.mistral")) {
       const allMessages = systemMessage
         ? [
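The LLaMA3 request is simplified to send only the last message as the prompt, and the body now spells out the generation parameters with explicit fallbacks. For reference, the invoke payload this produces looks roughly like the following (sample values assumed):

    // Approximate shape of the LLaMA3 invoke body built above.
    const llamaBody = {
      contentType: "application/json",
      accept: "application/json",
      body: {
        prompt: "What is retrieval-augmented generation?", // last message only
        max_gen_len: 256, // fallback when modelConfig.max_tokens is unset
        temperature: 0.5, // fallback when modelConfig.temperature is unset
        top_p: 0.9,
      },
    };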
@@ -172,14 +192,21 @@ export class BedrockApi implements LLMApi {
             ...messages,
           ]
         : messages;
-      const prompt = allMessages
-        .map((m) => `${m.role}: ${getMessageTextContent(m)}`)
-        .join("\n");
+
+      // Format messages as a conversation with instruction tags
+      const prompt = `<s>[INST] ${allMessages
+        .map((m) => getMessageTextContent(m))
+        .join("\n")} [/INST]`;
+
       return {
+        contentType: "application/json",
+        accept: "application/json",
         body: {
           prompt,
-          temperature: modelConfig.temperature || 0.7,
           max_tokens: modelConfig.max_tokens || 4096,
+          temperature: modelConfig.temperature || 0.5,
+          top_p: 0.9,
+          top_k: 50,
        },
      };
    }
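The Mistral prompt builder now emits exactly the <s>[INST] ... [/INST] wrapper that validateRequest enforces earlier in this commit, joining all message contents (system message included) inside a single instruction block. With sample messages, the constructed prompt looks like this (data illustrative):

    // Illustrative run of the Mistral prompt construction above.
    const all = [
      { content: "You are a helpful assistant." },
      { content: "Explain AWS SigV4 in one sentence." },
    ];
    const prompt = `<s>[INST] ${all.map((m) => m.content).join("\n")} [/INST]`;
    // "<s>[INST] You are a helpful assistant.\nExplain AWS SigV4 in one sentence. [/INST]"
    // This is exactly the wrapper the new validation checks for.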
@@ -258,7 +285,6 @@ export class BedrockApi implements LLMApi {
       systemMessage,
       modelConfig,
     );
-    // console.log("Request body:", JSON.stringify(requestBody, null, 2));
 
     const controller = new AbortController();
     options.onController?.(controller);
@@ -338,7 +364,6 @@ export class BedrockApi implements LLMApi {
         } catch (e) {}
       }
       const message = this.extractMessage(chunkJson, modelConfig.model);
-      // console.log("Extracted message:", message);
       return message;
     } catch (e) {
       console.error("Error parsing chunk:", e);