Improve Mistral tool use and fix the Llama 3 message format

glay 2024-11-25 20:08:21 +08:00
parent 15d0600642
commit e6633753a4
4 changed files with 151 additions and 147 deletions

View File

@@ -105,26 +105,17 @@ async function requestBedrock(req: NextRequest) {
   console.log("[Bedrock Request] Model ID:", modelId);

   // Handle tools for different models
-  const isMistralModel = modelId.toLowerCase().includes("mistral");
+  const isMistralLargeModel = modelId
+    .toLowerCase()
+    .includes("mistral.mistral-large");
   const isClaudeModel = modelId.toLowerCase().includes("claude");
-  const requestBody = {
+  const requestBody: any = {
     ...bodyJson,
   };

   if (tools && tools.length > 0) {
-    if (isClaudeModel) {
-      // Claude models already have correct tool format
-      requestBody.tools = tools;
-    } else if (isMistralModel) {
-      // Format messages for Mistral
-      if (typeof requestBody.prompt === "string") {
-        requestBody.messages = [
-          { role: "user", content: requestBody.prompt },
-        ];
-        delete requestBody.prompt;
-      }
+    if (isMistralLargeModel) {
       // Add tools in Mistral's format
       requestBody.tool_choice = "auto";
       requestBody.tools = tools.map((tool) => ({
@@ -135,6 +126,8 @@ async function requestBedrock(req: NextRequest) {
           parameters: tool.input_schema,
         },
       }));
+    } else if (isClaudeModel) {
+      requestBody.tools = tools;
     }
   }
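Note: the branch above reshapes Anthropic-style tool definitions into Mistral's OpenAI-compatible function schema. A sketch of the mapping, assuming the lines elided between the two hunks wrap each tool in the usual type/function envelope (tool name and fields illustrative):

    // input, Claude-style:
    { name: "get_weather", description: "Get current weather", input_schema: { type: "object", properties: { city: { type: "string" } } } }

    // output, as sent to Mistral Large:
    { type: "function", function: { name: "get_weather", description: "Get current weather", parameters: { type: "object", properties: { city: { type: "string" } } } } }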

View File

@@ -21,7 +21,22 @@ const ClaudeMapper = {
   system: "user",
 } as const;

+const MistralMapper = {
+  system: "system",
+  user: "user",
+  assistant: "assistant",
+} as const;
+
 type ClaudeRole = keyof typeof ClaudeMapper;
+type MistralRole = keyof typeof MistralMapper;
+
+interface Tool {
+  function?: {
+    name?: string;
+    description?: string;
+    parameters?: any;
+  };
+}

 export class BedrockApi implements LLMApi {
   speech(options: SpeechOptions): Promise<ArrayBuffer> {
@@ -30,7 +45,6 @@ export class BedrockApi implements LLMApi {
   formatRequestBody(messages: ChatOptions["messages"], modelConfig: any) {
     const model = modelConfig.model;
     const visionModel = isVisionModel(modelConfig.model);
-
     // Handle Titan models
@@ -53,37 +67,27 @@ export class BedrockApi implements LLMApi {
     // Handle LLaMA models
     if (model.includes("meta.llama")) {
-      // Format conversation for Llama models
-      let prompt = "";
-      let systemPrompt = "";
+      let prompt = "<|begin_of_text|>";

       // Extract system message if present
       const systemMessage = messages.find((m) => m.role === "system");
       if (systemMessage) {
-        systemPrompt = getMessageTextContent(systemMessage);
+        prompt += `<|start_header_id|>system<|end_header_id|>\n${getMessageTextContent(
+          systemMessage,
+        )}<|eot_id|>`;
       }

       // Format the conversation
       const conversationMessages = messages.filter((m) => m.role !== "system");
-      prompt = `<s>[INST] <<SYS>>\n${
-        systemPrompt || "You are a helpful, respectful and honest assistant."
-      }\n<</SYS>>\n\n`;
-
-      for (let i = 0; i < conversationMessages.length; i++) {
-        const message = conversationMessages[i];
+      for (const message of conversationMessages) {
+        const role = message.role === "assistant" ? "assistant" : "user";
         const content = getMessageTextContent(message);
-
-        if (i === 0 && message.role === "user") {
-          // First user message goes in the same [INST] block as system prompt
-          prompt += `${content} [/INST]`;
-        } else {
-          if (message.role === "user") {
-            prompt += `\n\n<s>[INST] ${content} [/INST]`;
-          } else {
-            prompt += ` ${content} </s>`;
-          }
-        }
+        prompt += `<|start_header_id|>${role}<|end_header_id|>\n${content}<|eot_id|>`;
       }

+      // Add the final assistant header to prompt completion
+      prompt += "<|start_header_id|>assistant<|end_header_id|>";
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
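Note: the rewritten branch emits the Llama 3 chat template rather than the Llama 2-era <s>[INST] <<SYS>> format the old code produced. For a system message plus one user turn, the generated prompt looks like this (content illustrative):

    <|begin_of_text|><|start_header_id|>system<|end_header_id|>
    You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
    Hello!<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The trailing assistant header cues the model to generate the next assistant turn.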
@@ -94,9 +98,8 @@ export class BedrockApi implements LLMApi {
     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
-      // Format messages for Mistral's chat format
       const formattedMessages = messages.map((message) => ({
-        role: message.role,
+        role: MistralMapper[message.role as MistralRole] || "user",
         content: getMessageTextContent(message),
       }));
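Note: MistralMapper guarantees that only the roles system, user, and assistant ever reach the Mistral API; any other role falls back to "user". The resulting payload is a plain role/content list, e.g. (illustrative):

    [
      { role: "system", content: "You are helpful." },
      { role: "user", content: "Hello" },
    ]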
@@ -234,6 +237,11 @@ export class BedrockApi implements LLMApi {
     });

     const finalRequestBody = this.formatRequestBody(messages, modelConfig);
+    console.log(
+      "[Bedrock Client] Request Body:",
+      JSON.stringify(finalRequestBody, null, 2),
+    );

     if (shouldStream) {
       let index = -1;
       const [tools, funcs] = usePluginStore
@@ -253,6 +261,7 @@ export class BedrockApi implements LLMApi {
         })),
         funcs,
         controller,
+        // parseSSE
         (text: string, runTools: ChatMessageTool[]) => {
           // console.log("parseSSE", text, runTools);
           let chunkJson:
@@ -304,36 +313,73 @@ export class BedrockApi implements LLMApi {
         ) => {
           // reset index value
           index = -1;
-          // @ts-ignore
-          requestPayload?.messages?.splice(
+
+          const modelId = modelConfig.model;
+          const isMistral = modelId.startsWith("mistral.mistral");
+          const isClaude = modelId.includes("anthropic.claude");
+
+          if (isClaude) {
+            // Format for Claude
             // @ts-ignore
-            requestPayload?.messages?.length,
-            0,
-            {
-              role: "assistant",
-              content: toolCallMessage.tool_calls.map(
-                (tool: ChatMessageTool) => ({
-                  type: "tool_use",
-                  id: tool.id,
-                  name: tool?.function?.name,
-                  input: tool?.function?.arguments
-                    ? JSON.parse(tool?.function?.arguments)
-                    : {},
-                }),
-              ),
-            },
-            // @ts-ignore
-            ...toolCallResult.map((result) => ({
-              role: "user",
-              content: [
-                {
-                  type: "tool_result",
-                  tool_use_id: result.tool_call_id,
-                  content: result.content,
-                },
-              ],
-            })),
-          );
+            requestPayload?.messages?.splice(
+              // @ts-ignore
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    type: "tool_use",
+                    id: tool.id,
+                    name: tool?.function?.name,
+                    input: tool?.function?.arguments
+                      ? JSON.parse(tool?.function?.arguments)
+                      : {},
+                  }),
+                ),
+              },
+              // @ts-ignore
+              ...toolCallResult.map((result) => ({
+                role: "user",
+                content: [
+                  {
+                    type: "tool_result",
+                    tool_use_id: result.tool_call_id,
+                    content: result.content,
+                  },
+                ],
+              })),
+            );
+          } else if (isMistral) {
+            // Format for Mistral
+            requestPayload?.messages?.splice(
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: "",
+                // @ts-ignore
+                tool_calls: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    id: tool.id,
+                    function: {
+                      name: tool?.function?.name,
+                      arguments: tool?.function?.arguments || "{}",
+                    },
+                  }),
+                ),
+              },
+              ...toolCallResult.map((result) => ({
+                role: "tool",
+                tool_call_id: result.tool_call_id,
+                content: result.content,
+              })),
+            );
+          } else {
+            console.warn(
+              `[Bedrock Client] Unhandled model type for tool calls: ${modelId}`,
+            );
+          }
         },
         options,
       );
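Note: the two branches mirror each provider's tool-calling wire format. Claude gets tool_use blocks in an assistant message followed by tool_result blocks in a user message, while Mistral gets an assistant message carrying tool_calls followed by role "tool" messages. A sketch of what the Mistral branch appends after one round-trip (ids and values illustrative):

    { role: "assistant", content: "", tool_calls: [
      { id: "call_0", function: { name: "get_weather", arguments: "{\"city\":\"Paris\"}" } },
    ] },
    { role: "tool", tool_call_id: "call_0", content: "22°C, sunny" }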
@@ -368,6 +414,7 @@ export class BedrockApi implements LLMApi {
       options.onError?.(e as Error);
     }
   }
+
   path(path: string): string {
     const accessStore = useAccessStore.getState();
     let baseUrl = accessStore.useCustomConfig ? accessStore.bedrockUrl : "";

View File

@@ -342,12 +342,9 @@ const bedrockModels = [
   // Meta Llama Models
   "us.meta.llama3-1-8b-instruct-v1:0",
   "us.meta.llama3-1-70b-instruct-v1:0",
-  "us.meta.llama3-2-1b-instruct-v1:0",
-  "us.meta.llama3-2-3b-instruct-v1:0",
   "us.meta.llama3-2-11b-instruct-v1:0",
   "us.meta.llama3-2-90b-instruct-v1:0",

   // Mistral Models
-  "mistral.mistral-7b-instruct-v0:2",
   "mistral.mistral-large-2402-v1:0",
   "mistral.mistral-large-2407-v1:0",
 ];

View File

@@ -245,6 +245,7 @@ export async function sign({
 export function parseEventData(chunk: Uint8Array): any {
   const decoder = new TextDecoder();
   const text = decoder.decode(chunk);
+  // console.info("[AWS Parse ] parsing:", text);
   try {
     const parsed = JSON.parse(text);
     // AWS Bedrock wraps the response in a 'body' field
@@ -317,7 +318,10 @@ export function extractMessage(res: any, modelId: string = ""): string {
   // Handle Mistral model response format
   if (modelId.toLowerCase().includes("mistral")) {
-    return res?.outputs?.[0]?.text || "";
+    if (res.choices?.[0]?.message?.content) {
+      return res.choices[0].message.content;
+    }
+    return res.output || "";
   }

   // Handle Llama model response format
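Note: extractMessage now prefers the OpenAI-style choices payload that Mistral chat models return through Bedrock, with res.output kept as a fallback; the old outputs[0].text path is dropped. The two shapes handled (values illustrative):

    // preferred:
    { choices: [{ message: { role: "assistant", content: "Hello!" } }] }

    // fallback:
    { output: "Hello!" }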
@@ -334,9 +338,7 @@ export async function* transformBedrockStream(
   modelId: string,
 ) {
   const reader = stream.getReader();
-  let accumulatedText = "";
-  let toolCallStarted = false;
-  let currentToolCall = null;
+  let toolInput = "";

   try {
     while (true) {
@@ -349,90 +351,54 @@ export async function* transformBedrockStream(
           // console.log("parseEventData=========================");
           // console.log(parsed);

           // Handle Mistral models
           if (modelId.toLowerCase().includes("mistral")) {
-            // If we have content, accumulate it
-            if (
-              parsed.choices?.[0]?.message?.role === "assistant" &&
-              parsed.choices?.[0]?.message?.content
-            ) {
-              accumulatedText += parsed.choices?.[0]?.message?.content;
-              // console.log("accumulatedText=========================");
-              // console.log(accumulatedText);
-
-              // Check for tool call in the accumulated text
-              if (!toolCallStarted && accumulatedText.includes("```json")) {
-                const jsonMatch = accumulatedText.match(
-                  /```json\s*({[\s\S]*?})\s*```/,
-                );
-                if (jsonMatch) {
-                  try {
-                    const toolData = JSON.parse(jsonMatch[1]);
-                    currentToolCall = {
-                      id: `tool-${Date.now()}`,
-                      name: toolData.name,
-                      arguments: toolData.arguments,
-                    };
-
-                    // Emit tool call start
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_start",
-                      content_block: {
-                        type: "tool_use",
-                        id: currentToolCall.id,
-                        name: currentToolCall.name,
-                      },
-                    })}\n\n`;
-
-                    // Emit tool arguments
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_delta",
-                      delta: {
-                        type: "input_json_delta",
-                        partial_json: JSON.stringify(currentToolCall.arguments),
-                      },
-                    })}\n\n`;
-
-                    // Emit tool call stop
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_stop",
-                    })}\n\n`;
-
-                    // Clear the accumulated text after processing the tool call
-                    accumulatedText = accumulatedText.replace(
-                      /```json\s*{[\s\S]*?}\s*```/,
-                      "",
-                    );
-                    toolCallStarted = false;
-                    currentToolCall = null;
-                  } catch (e) {
-                    console.error("Failed to parse tool JSON:", e);
-                  }
-                }
-              }
-
-              // emit the text content if it's not empty
-              if (parsed.choices?.[0]?.message?.content.trim()) {
-                yield `data: ${JSON.stringify({
-                  delta: { text: parsed.choices?.[0]?.message?.content },
-                })}\n\n`;
-              }
-
-              // Handle stop reason if present
-              if (parsed.choices?.[0]?.stop_reason) {
-                yield `data: ${JSON.stringify({
-                  delta: { stop_reason: parsed.choices[0].stop_reason },
-                })}\n\n`;
-              }
-              continue;
-            }
-          }
-          // Handle Llama models
-          else if (modelId.toLowerCase().includes("llama")) {
-            if (parsed.generation) {
-              yield `data: ${JSON.stringify({
-                delta: { text: parsed.generation },
-              })}\n\n`;
-            }
-            if (parsed.stop_reason) {
-              yield `data: ${JSON.stringify({
-                delta: { stop_reason: parsed.stop_reason },
-              })}\n\n`;
-            }
+            // Handle tool calls
+            if (parsed.choices?.[0]?.message?.tool_calls) {
+              const toolCalls = parsed.choices[0].message.tool_calls;
+              for (const toolCall of toolCalls) {
+                // Emit tool call start
+                yield `data: ${JSON.stringify({
+                  type: "content_block_start",
+                  content_block: {
+                    type: "tool_use",
+                    id: toolCall.id || `tool-${Date.now()}`,
+                    name: toolCall.function?.name,
+                  },
+                })}\n\n`;
+
+                // Emit tool arguments
+                if (toolCall.function?.arguments) {
+                  yield `data: ${JSON.stringify({
+                    type: "content_block_delta",
+                    delta: {
+                      type: "input_json_delta",
+                      partial_json: toolCall.function.arguments,
+                    },
+                  })}\n\n`;
+                }
+
+                // Emit tool call stop
+                yield `data: ${JSON.stringify({
+                  type: "content_block_stop",
+                })}\n\n`;
+              }
+            }
+
+            // Handle regular content
+            const content = parsed.choices?.[0]?.message?.content;
+            if (content?.trim()) {
+              yield `data: ${JSON.stringify({
+                delta: { text: content },
+              })}\n\n`;
+            }
+
+            // Handle stop reason
+            if (parsed.choices?.[0]?.finish_reason) {
+              yield `data: ${JSON.stringify({
+                delta: { stop_reason: parsed.choices[0].finish_reason },
+              })}\n\n`;
+            }
           }
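Note: since the Mistral response now carries structured tool_calls, the transformer no longer accumulates text and scrapes markdown-fenced JSON out of it; each call is re-emitted as Claude-style SSE content blocks the client already assembles. For a single tool call the stream carries roughly (id and arguments illustrative):

    data: {"type":"content_block_start","content_block":{"type":"tool_use","id":"call_0","name":"get_weather"}}
    data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"city\":\"Paris\"}"}}
    data: {"type":"content_block_stop"}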
@@ -469,8 +435,9 @@ export async function* transformBedrockStream(
             })}\n\n`;
           }
         }
-      } else {
-        // Handle other model text responses
+      }
+      // Handle other models
+      else {
         const text = parsed.outputText || parsed.generation || "";
         if (text) {
           yield `data: ${JSON.stringify({