Improve Mistral tool use and fix the Llama 3 message format
This commit is contained in:
parent 15d0600642
commit e6633753a4
```diff
@@ -105,26 +105,17 @@ async function requestBedrock(req: NextRequest) {
   console.log("[Bedrock Request] Model ID:", modelId);

   // Handle tools for different models
   const isMistralModel = modelId.toLowerCase().includes("mistral");
   const isMistralLargeModel = modelId
     .toLowerCase()
     .includes("mistral.mistral-large");
   const isClaudeModel = modelId.toLowerCase().includes("claude");

-  const requestBody = {
+  const requestBody: any = {
     ...bodyJson,
   };

   if (tools && tools.length > 0) {
     if (isClaudeModel) {
       // Claude models already have correct tool format
       requestBody.tools = tools;
     } else if (isMistralModel) {
       // Format messages for Mistral
       if (typeof requestBody.prompt === "string") {
         requestBody.messages = [
           { role: "user", content: requestBody.prompt },
         ];
         delete requestBody.prompt;
       }

       if (isMistralLargeModel) {
         // Add tools in Mistral's format
         requestBody.tool_choice = "auto";
         requestBody.tools = tools.map((tool) => ({
```
```diff
@@ -135,6 +126,8 @@ async function requestBedrock(req: NextRequest) {
             parameters: tool.input_schema,
           },
         }));
       } else if (isClaudeModel) {
         requestBody.tools = tools;
       }
     }
```
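For reference, a tool-enabled Mistral Large request assembled by this route would look roughly as follows. This is a sketch: the tool definition is hypothetical, and the lines elided between the two hunks presumably wrap each tool in Mistral's OpenAI-style `{ type, function }` envelope.

```ts
// Hypothetical requestBody after the Mistral Large branch runs
const requestBody: any = {
  messages: [{ role: "user", content: "What's the weather in Paris?" }],
  tool_choice: "auto",
  tools: [
    {
      type: "function",
      function: {
        name: "get_weather", // assumed to come from the client-supplied tool
        description: "Look up the current weather",
        // mapped from tool.input_schema
        parameters: { type: "object", properties: { city: { type: "string" } } },
      },
    },
  ],
};
```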
```diff
@@ -21,7 +21,22 @@ const ClaudeMapper = {
   system: "user",
 } as const;

+const MistralMapper = {
+  system: "system",
+  user: "user",
+  assistant: "assistant",
+} as const;
+
 type ClaudeRole = keyof typeof ClaudeMapper;
+type MistralRole = keyof typeof MistralMapper;
+
+interface Tool {
+  function?: {
+    name?: string;
+    description?: string;
+    parameters?: any;
+  };
+}

 export class BedrockApi implements LLMApi {
   speech(options: SpeechOptions): Promise<ArrayBuffer> {
```
```diff
@@ -30,7 +45,6 @@ export class BedrockApi implements LLMApi {

   formatRequestBody(messages: ChatOptions["messages"], modelConfig: any) {
     const model = modelConfig.model;
-
     const visionModel = isVisionModel(modelConfig.model);

     // Handle Titan models
```
```diff
@@ -53,37 +67,27 @@ export class BedrockApi implements LLMApi {

     // Handle LLaMA models
     if (model.includes("meta.llama")) {
-      // Format conversation for Llama models
-      let prompt = "";
-      let systemPrompt = "";
+      let prompt = "<|begin_of_text|>";

       // Extract system message if present
       const systemMessage = messages.find((m) => m.role === "system");
       if (systemMessage) {
-        systemPrompt = getMessageTextContent(systemMessage);
+        prompt += `<|start_header_id|>system<|end_header_id|>\n${getMessageTextContent(
+          systemMessage,
+        )}<|eot_id|>`;
       }

       // Format the conversation
       const conversationMessages = messages.filter((m) => m.role !== "system");
-      prompt = `<s>[INST] <<SYS>>\n${
-        systemPrompt || "You are a helpful, respectful and honest assistant."
-      }\n<</SYS>>\n\n`;
-
-      for (let i = 0; i < conversationMessages.length; i++) {
-        const message = conversationMessages[i];
+      for (const message of conversationMessages) {
         const role = message.role === "assistant" ? "assistant" : "user";
         const content = getMessageTextContent(message);
-        if (i === 0 && message.role === "user") {
-          // First user message goes in the same [INST] block as system prompt
-          prompt += `${content} [/INST]`;
-        } else {
-          if (message.role === "user") {
-            prompt += `\n\n<s>[INST] ${content} [/INST]`;
-          } else {
-            prompt += ` ${content} </s>`;
-          }
-        }
+        prompt += `<|start_header_id|>${role}<|end_header_id|>\n${content}<|eot_id|>`;
       }

+      // Add the final assistant header to prompt completion
+      prompt += "<|start_header_id|>assistant<|end_header_id|>";
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
```
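The net effect is a switch from the Llama 2 `<s>[INST] <<SYS>>` template to the Llama 3 header-token template. For a one-turn conversation, the new code serializes to something like this (message text is illustrative):

```
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
Hello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```

The trailing assistant header leaves the prompt open for the model to complete.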
```diff
@@ -94,9 +98,8 @@ export class BedrockApi implements LLMApi {

     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
       // Format messages for Mistral's chat format
       const formattedMessages = messages.map((message) => ({
-        role: message.role,
+        role: MistralMapper[message.role as MistralRole] || "user",
         content: getMessageTextContent(message),
       }));
```
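`MistralMapper` restricts history to the three roles Mistral accepts; any role outside the map falls back to `user`. A quick sketch:

```ts
// An unexpected role is coerced to "user" by the || fallback
const role = MistralMapper["function" as MistralRole] || "user"; // -> "user"
```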
```diff
@@ -234,6 +237,11 @@ export class BedrockApi implements LLMApi {
     });

     const finalRequestBody = this.formatRequestBody(messages, modelConfig);
+    console.log(
+      "[Bedrock Client] Request Body:",
+      JSON.stringify(finalRequestBody, null, 2),
+    );

     if (shouldStream) {
       let index = -1;
       const [tools, funcs] = usePluginStore
```
```diff
@@ -253,6 +261,7 @@ export class BedrockApi implements LLMApi {
         })),
         funcs,
         controller,
+        // parseSSE
         (text: string, runTools: ChatMessageTool[]) => {
           // console.log("parseSSE", text, runTools);
           let chunkJson:
```
```diff
@@ -304,36 +313,73 @@ export class BedrockApi implements LLMApi {
         ) => {
           // reset index value
           index = -1;
-          // @ts-ignore
-          requestPayload?.messages?.splice(
-            // @ts-ignore
-            requestPayload?.messages?.length,
-            0,
-            {
-              role: "assistant",
-              content: toolCallMessage.tool_calls.map(
-                (tool: ChatMessageTool) => ({
-                  type: "tool_use",
-                  id: tool.id,
-                  name: tool?.function?.name,
-                  input: tool?.function?.arguments
-                    ? JSON.parse(tool?.function?.arguments)
-                    : {},
-                }),
-              ),
-            },
-            // @ts-ignore
-            ...toolCallResult.map((result) => ({
-              role: "user",
-              content: [
-                {
-                  type: "tool_result",
-                  tool_use_id: result.tool_call_id,
-                  content: result.content,
-                },
-              ],
-            })),
-          );
+
+          const modelId = modelConfig.model;
+          const isMistral = modelId.startsWith("mistral.mistral");
+          const isClaude = modelId.includes("anthropic.claude");
+
+          if (isClaude) {
+            // Format for Claude
+            requestPayload?.messages?.splice(
+              // @ts-ignore
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    type: "tool_use",
+                    id: tool.id,
+                    name: tool?.function?.name,
+                    input: tool?.function?.arguments
+                      ? JSON.parse(tool?.function?.arguments)
+                      : {},
+                  }),
+                ),
+              },
+              // @ts-ignore
+              ...toolCallResult.map((result) => ({
+                role: "user",
+                content: [
+                  {
+                    type: "tool_result",
+                    tool_use_id: result.tool_call_id,
+                    content: result.content,
+                  },
+                ],
+              })),
+            );
+          } else if (isMistral) {
+            // Format for Mistral
+            requestPayload?.messages?.splice(
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: "",
+                // @ts-ignore
+                tool_calls: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    id: tool.id,
+                    function: {
+                      name: tool?.function?.name,
+                      arguments: tool?.function?.arguments || "{}",
+                    },
+                  }),
+                ),
+              },
+              ...toolCallResult.map((result) => ({
+                role: "tool",
+                tool_call_id: result.tool_call_id,
+                content: result.content,
+              })),
+            );
+          } else {
+            console.warn(
+              `[Bedrock Client] Unhandled model type for tool calls: ${modelId}`,
+            );
+          }
         },
         options,
       );
```
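After a tool round-trip, the two providers therefore see differently shaped history. Roughly, with hypothetical tool data:

```ts
// Claude: tool_use / tool_result content blocks on assistant/user turns
const claudeTurns = [
  {
    role: "assistant",
    content: [
      { type: "tool_use", id: "call_0", name: "get_weather", input: { city: "Paris" } },
    ],
  },
  {
    role: "user",
    content: [
      { type: "tool_result", tool_use_id: "call_0", content: '{"temp_c":21}' },
    ],
  },
];

// Mistral: OpenAI-style tool_calls plus a dedicated "tool" role
const mistralTurns = [
  {
    role: "assistant",
    content: "",
    tool_calls: [
      { id: "call_0", function: { name: "get_weather", arguments: '{"city":"Paris"}' } },
    ],
  },
  { role: "tool", tool_call_id: "call_0", content: '{"temp_c":21}' },
];
```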
```diff
@@ -368,6 +414,7 @@ export class BedrockApi implements LLMApi {
       options.onError?.(e as Error);
     }
   }

   path(path: string): string {
     const accessStore = useAccessStore.getState();
     let baseUrl = accessStore.useCustomConfig ? accessStore.bedrockUrl : "";
```
```diff
@@ -342,12 +342,9 @@ const bedrockModels = [
   // Meta Llama Models
   "us.meta.llama3-1-8b-instruct-v1:0",
   "us.meta.llama3-1-70b-instruct-v1:0",
   "us.meta.llama3-2-1b-instruct-v1:0",
   "us.meta.llama3-2-3b-instruct-v1:0",
   "us.meta.llama3-2-11b-instruct-v1:0",
   "us.meta.llama3-2-90b-instruct-v1:0",
   // Mistral Models
   "mistral.mistral-7b-instruct-v0:2",
   "mistral.mistral-large-2402-v1:0",
   "mistral.mistral-large-2407-v1:0",
 ];
```
app/utils/aws.ts
```diff
@@ -245,6 +245,7 @@ export async function sign({
 export function parseEventData(chunk: Uint8Array): any {
   const decoder = new TextDecoder();
   const text = decoder.decode(chunk);
+  // console.info("[AWS Parse] parsing:", text);
   try {
     const parsed = JSON.parse(text);
     // AWS Bedrock wraps the response in a 'body' field
```
```diff
@@ -317,7 +318,10 @@ export function extractMessage(res: any, modelId: string = ""): string {

   // Handle Mistral model response format
   if (modelId.toLowerCase().includes("mistral")) {
-    return res?.outputs?.[0]?.text || "";
+    if (res.choices?.[0]?.message?.content) {
+      return res.choices[0].message.content;
+    }
+    return res.output || "";
   }

   // Handle Llama model response format
```
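The old branch read Bedrock's classic Mistral shape (`outputs[0].text`); the updated one prefers the OpenAI-style `choices` shape and falls back to `output`. A sketch of the two payloads now accepted (values illustrative):

```ts
// OpenAI-style chat shape, checked first
const chatShape = { choices: [{ message: { role: "assistant", content: "Hi!" } }] };
// Fallback shape
const plainShape = { output: "Hi!" };
```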
```diff
@@ -334,9 +338,7 @@ export async function* transformBedrockStream(
   modelId: string,
 ) {
   const reader = stream.getReader();
   let accumulatedText = "";
-  let toolCallStarted = false;
-  let currentToolCall = null;
-  let toolInput = "";

   try {
     while (true) {
```
````diff
@@ -349,90 +351,54 @@ export async function* transformBedrockStream(

         // console.log("parseEventData=========================");
         // console.log(parsed);

         // Handle Mistral models
         if (modelId.toLowerCase().includes("mistral")) {
-          // If we have content, accumulate it
-          if (
-            parsed.choices?.[0]?.message?.role === "assistant" &&
-            parsed.choices?.[0]?.message?.content
-          ) {
-            accumulatedText += parsed.choices?.[0]?.message?.content;
-            // console.log("accumulatedText=========================");
-            // console.log(accumulatedText);
-            // Check for tool call in the accumulated text
-            if (!toolCallStarted && accumulatedText.includes("```json")) {
-              const jsonMatch = accumulatedText.match(
-                /```json\s*({[\s\S]*?})\s*```/,
-              );
-              if (jsonMatch) {
-                try {
-                  const toolData = JSON.parse(jsonMatch[1]);
-                  currentToolCall = {
-                    id: `tool-${Date.now()}`,
-                    name: toolData.name,
-                    arguments: toolData.arguments,
-                  };
-
-                  // Emit tool call start
-                  yield `data: ${JSON.stringify({
-                    type: "content_block_start",
-                    content_block: {
-                      type: "tool_use",
-                      id: currentToolCall.id,
-                      name: currentToolCall.name,
-                    },
-                  })}\n\n`;
-
-                  // Emit tool arguments
-                  yield `data: ${JSON.stringify({
-                    type: "content_block_delta",
-                    delta: {
-                      type: "input_json_delta",
-                      partial_json: JSON.stringify(currentToolCall.arguments),
-                    },
-                  })}\n\n`;
-
-                  // Emit tool call stop
-                  yield `data: ${JSON.stringify({
-                    type: "content_block_stop",
-                  })}\n\n`;
-
-                  // Clear the accumulated text after processing the tool call
-                  accumulatedText = accumulatedText.replace(
-                    /```json\s*{[\s\S]*?}\s*```/,
-                    "",
-                  );
-                  toolCallStarted = false;
-                  currentToolCall = null;
-                } catch (e) {
-                  console.error("Failed to parse tool JSON:", e);
-                }
-              }
-            }
-          }
+          // Handle tool calls
+          if (parsed.choices?.[0]?.message?.tool_calls) {
+            const toolCalls = parsed.choices[0].message.tool_calls;
+            for (const toolCall of toolCalls) {
+              // Emit tool call start
+              yield `data: ${JSON.stringify({
+                type: "content_block_start",
+                content_block: {
+                  type: "tool_use",
+                  id: toolCall.id || `tool-${Date.now()}`,
+                  name: toolCall.function?.name,
+                },
+              })}\n\n`;
+
+              // Emit tool arguments
+              if (toolCall.function?.arguments) {
+                yield `data: ${JSON.stringify({
+                  type: "content_block_delta",
+                  delta: {
+                    type: "input_json_delta",
+                    partial_json: toolCall.function.arguments,
+                  },
+                })}\n\n`;
+              }
+            }
+          }
+
+          // emit the text content if it's not empty
+          if (parsed.choices?.[0]?.message?.content?.trim()) {
+            yield `data: ${JSON.stringify({
+              delta: { text: parsed.choices?.[0]?.message?.content },
+            })}\n\n`;
+          }
+          // Handle stop reason if present
+          if (parsed.choices?.[0]?.stop_reason) {
+            yield `data: ${JSON.stringify({
+              delta: { stop_reason: parsed.choices[0].stop_reason },
+            })}\n\n`;
+          }
+          continue;
         }
         // Handle Llama models
         else if (modelId.toLowerCase().includes("llama")) {
-          if (parsed.generation) {
+          // Handle regular content
+          const content = parsed.choices?.[0]?.message?.content;
+          if (content?.trim()) {
             yield `data: ${JSON.stringify({
-              delta: { text: parsed.generation },
+              delta: { text: content },
             })}\n\n`;
           }
-          if (parsed.stop_reason) {
+
+          // Handle stop reason
+          if (parsed.choices?.[0]?.finish_reason) {
             yield `data: ${JSON.stringify({
-              delta: { stop_reason: parsed.stop_reason },
+              delta: { stop_reason: parsed.choices[0].finish_reason },
             })}\n\n`;
           }
         }
````
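Concretely, a single Mistral chunk carrying one tool call is re-emitted as Claude-style SSE events, roughly like this (IDs and arguments are hypothetical):

```
data: {"type":"content_block_start","content_block":{"type":"tool_use","id":"call_0","name":"get_weather"}}

data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"city\":\"Paris\"}"}}
```

Plain text and stop reasons flow through as `data: {"delta":{"text":"..."}}` and `data: {"delta":{"stop_reason":"..."}}` chunks.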
```diff
@@ -469,8 +435,9 @@ export async function* transformBedrockStream(
               })}\n\n`;
             }
           }
-        } else {
-          // Handle other model text responses
+        }
+        // Handle other models
+        else {
           const text = parsed.outputText || parsed.generation || "";
           if (text) {
             yield `data: ${JSON.stringify({
```