Improve Mistral tool use and fix the Llama3 message format

glay 2024-11-25 20:08:21 +08:00
parent 15d0600642
commit e6633753a4
4 changed files with 151 additions and 147 deletions

View File

@@ -105,26 +105,17 @@ async function requestBedrock(req: NextRequest) {
   console.log("[Bedrock Request] Model ID:", modelId);

   // Handle tools for different models
-  const isMistralModel = modelId.toLowerCase().includes("mistral");
+  const isMistralLargeModel = modelId
+    .toLowerCase()
+    .includes("mistral.mistral-large");
   const isClaudeModel = modelId.toLowerCase().includes("claude");

-  const requestBody = {
+  const requestBody: any = {
     ...bodyJson,
   };

   if (tools && tools.length > 0) {
-    if (isClaudeModel) {
-      // Claude models already have correct tool format
-      requestBody.tools = tools;
-    } else if (isMistralModel) {
-      // Format messages for Mistral
-      if (typeof requestBody.prompt === "string") {
-        requestBody.messages = [
-          { role: "user", content: requestBody.prompt },
-        ];
-        delete requestBody.prompt;
-      }
-
+    if (isMistralLargeModel) {
       // Add tools in Mistral's format
       requestBody.tool_choice = "auto";
       requestBody.tools = tools.map((tool) => ({
@@ -135,6 +126,8 @@ async function requestBedrock(req: NextRequest) {
           parameters: tool.input_schema,
         },
       }));
+    } else if (isClaudeModel) {
+      requestBody.tools = tools;
     }
   }
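Note: the hunk above routes tool definitions per model family: Claude tools pass through unchanged, while Mistral Large gets an OpenAI-style function wrapper. A minimal sketch of the mapping, assuming the wrapper fields elided at the hunk boundary ("type"/"function") from the visible `parameters: tool.input_schema` line:

// Claude-style tool definition as sent by the client (assumed shape).
interface ClaudeTool {
  name?: string;
  description?: string;
  input_schema?: any;
}

// Mistral's function-calling schema; names here are illustrative.
const toMistralTool = (tool: ClaudeTool) => ({
  type: "function",
  function: {
    name: tool.name,
    description: tool.description,
    parameters: tool.input_schema,
  },
});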

View File

@@ -21,7 +21,22 @@ const ClaudeMapper = {
   system: "user",
 } as const;

+const MistralMapper = {
+  system: "system",
+  user: "user",
+  assistant: "assistant",
+} as const;
+
+type ClaudeRole = keyof typeof ClaudeMapper;
+type MistralRole = keyof typeof MistralMapper;
+
+interface Tool {
+  function?: {
+    name?: string;
+    description?: string;
+    parameters?: any;
+  };
+}
+
 export class BedrockApi implements LLMApi {
   speech(options: SpeechOptions): Promise<ArrayBuffer> {
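Note: MistralMapper narrows incoming chat roles to the three values Mistral accepts. A one-line sketch of the intended lookup, with unknown roles degrading to "user" rather than failing the request:

const toMistralRole = (role: string): string =>
  MistralMapper[role as MistralRole] || "user";
// toMistralRole("system") === "system"; toMistralRole("tool") === "user"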
@@ -30,7 +45,6 @@ export class BedrockApi implements LLMApi {
   formatRequestBody(messages: ChatOptions["messages"], modelConfig: any) {
     const model = modelConfig.model;
     const visionModel = isVisionModel(modelConfig.model);
-
     // Handle Titan models
@@ -53,37 +67,27 @@ export class BedrockApi implements LLMApi {
     // Handle LLaMA models
     if (model.includes("meta.llama")) {
       // Format conversation for Llama models
-      let prompt = "";
-      let systemPrompt = "";
+      let prompt = "<|begin_of_text|>";

       // Extract system message if present
       const systemMessage = messages.find((m) => m.role === "system");
       if (systemMessage) {
-        systemPrompt = getMessageTextContent(systemMessage);
+        prompt += `<|start_header_id|>system<|end_header_id|>\n${getMessageTextContent(
+          systemMessage,
+        )}<|eot_id|>`;
       }

       // Format the conversation
       const conversationMessages = messages.filter((m) => m.role !== "system");
-      prompt = `<s>[INST] <<SYS>>\n${
-        systemPrompt || "You are a helpful, respectful and honest assistant."
-      }\n<</SYS>>\n\n`;
-
-      for (let i = 0; i < conversationMessages.length; i++) {
-        const message = conversationMessages[i];
+      for (const message of conversationMessages) {
+        const role = message.role === "assistant" ? "assistant" : "user";
         const content = getMessageTextContent(message);
-
-        if (i === 0 && message.role === "user") {
-          // First user message goes in the same [INST] block as system prompt
-          prompt += `${content} [/INST]`;
-        } else {
-          if (message.role === "user") {
-            prompt += `\n\n<s>[INST] ${content} [/INST]`;
-          } else {
-            prompt += ` ${content} </s>`;
-          }
-        }
+        prompt += `<|start_header_id|>${role}<|end_header_id|>\n${content}<|eot_id|>`;
       }

+      // Add the final assistant header to prompt completion
+      prompt += "<|start_header_id|>assistant<|end_header_id|>";
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
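Note: the rewrite above replaces the Llama 2 [INST]/<<SYS>> template with the Llama 3 header and end-of-turn token format. For a one-turn chat, the generated prompt looks like this (message content is illustrative):

const examplePrompt =
  "<|begin_of_text|>" +
  "<|start_header_id|>system<|end_header_id|>\nYou are helpful.<|eot_id|>" +
  "<|start_header_id|>user<|end_header_id|>\nHi there!<|eot_id|>" +
  "<|start_header_id|>assistant<|end_header_id|>"; // the model completes from here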
@@ -94,9 +98,8 @@ export class BedrockApi implements LLMApi {
     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
       // Format messages for Mistral's chat format
       const formattedMessages = messages.map((message) => ({
-        role: message.role,
+        role: MistralMapper[message.role as MistralRole] || "user",
         content: getMessageTextContent(message),
       }));
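Note: with the mapper applied, a system + user exchange serializes to the flat chat shape Mistral expects (values illustrative):

const formattedMessages = [
  { role: "system", content: "You are helpful." },
  { role: "user", content: "What is the weather in Paris?" },
];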
@@ -234,6 +237,11 @@ export class BedrockApi implements LLMApi {
     });

     const finalRequestBody = this.formatRequestBody(messages, modelConfig);
+    console.log(
+      "[Bedrock Client] Request Body:",
+      JSON.stringify(finalRequestBody, null, 2),
+    );
+
     if (shouldStream) {
       let index = -1;
       const [tools, funcs] = usePluginStore
@@ -253,6 +261,7 @@ export class BedrockApi implements LLMApi {
         })),
         funcs,
         controller,
+        // parseSSE
         (text: string, runTools: ChatMessageTool[]) => {
           // console.log("parseSSE", text, runTools);
           let chunkJson:
@@ -304,36 +313,73 @@
         ) => {
           // reset index value
           index = -1;
-          // @ts-ignore
-          requestPayload?.messages?.splice(
-            // @ts-ignore
-            requestPayload?.messages?.length,
-            0,
-            {
-              role: "assistant",
-              content: toolCallMessage.tool_calls.map(
-                (tool: ChatMessageTool) => ({
-                  type: "tool_use",
-                  id: tool.id,
-                  name: tool?.function?.name,
-                  input: tool?.function?.arguments
-                    ? JSON.parse(tool?.function?.arguments)
-                    : {},
-                }),
-              ),
-            },
-            // @ts-ignore
-            ...toolCallResult.map((result) => ({
-              role: "user",
-              content: [
-                {
-                  type: "tool_result",
-                  tool_use_id: result.tool_call_id,
-                  content: result.content,
-                },
-              ],
-            })),
-          );
+          const modelId = modelConfig.model;
+          const isMistral = modelId.startsWith("mistral.mistral");
+          const isClaude = modelId.includes("anthropic.claude");
+
+          if (isClaude) {
+            // Format for Claude
+            requestPayload?.messages?.splice(
+              // @ts-ignore
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    type: "tool_use",
+                    id: tool.id,
+                    name: tool?.function?.name,
+                    input: tool?.function?.arguments
+                      ? JSON.parse(tool?.function?.arguments)
+                      : {},
+                  }),
+                ),
+              },
+              // @ts-ignore
+              ...toolCallResult.map((result) => ({
+                role: "user",
+                content: [
+                  {
+                    type: "tool_result",
+                    tool_use_id: result.tool_call_id,
+                    content: result.content,
+                  },
+                ],
+              })),
+            );
+          } else if (isMistral) {
+            // Format for Mistral
+            requestPayload?.messages?.splice(
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: "",
+                // @ts-ignore
+                tool_calls: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    id: tool.id,
+                    function: {
+                      name: tool?.function?.name,
+                      arguments: tool?.function?.arguments || "{}",
+                    },
+                  }),
+                ),
+              },
+              ...toolCallResult.map((result) => ({
+                role: "tool",
+                tool_call_id: result.tool_call_id,
+                content: result.content,
+              })),
+            );
+          } else {
+            console.warn(
+              `[Bedrock Client] Unhandled model type for tool calls: ${modelId}`,
+            );
+          }
         },
         options,
       );
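Note: the branch above appends a different message pair for each model family after a tool round-trip. Illustrative values for both shapes:

// Claude: tool_use content blocks, with results echoed back as a user turn.
const claudeToolTurn = [
  {
    role: "assistant",
    content: [
      { type: "tool_use", id: "t1", name: "get_weather", input: { city: "Paris" } },
    ],
  },
  {
    role: "user",
    content: [{ type: "tool_result", tool_use_id: "t1", content: "18°C" }],
  },
];

// Mistral: OpenAI-style tool_calls with stringified arguments, results as role "tool".
const mistralToolTurn = [
  {
    role: "assistant",
    content: "",
    tool_calls: [
      { id: "t1", function: { name: "get_weather", arguments: '{"city":"Paris"}' } },
    ],
  },
  { role: "tool", tool_call_id: "t1", content: "18°C" },
];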
@@ -368,6 +414,7 @@ export class BedrockApi implements LLMApi {
       options.onError?.(e as Error);
     }
   }
+
   path(path: string): string {
     const accessStore = useAccessStore.getState();
     let baseUrl = accessStore.useCustomConfig ? accessStore.bedrockUrl : "";

View File

@@ -342,12 +342,9 @@ const bedrockModels = [
   // Meta Llama Models
   "us.meta.llama3-1-8b-instruct-v1:0",
   "us.meta.llama3-1-70b-instruct-v1:0",
-  "us.meta.llama3-2-1b-instruct-v1:0",
-  "us.meta.llama3-2-3b-instruct-v1:0",
-  "us.meta.llama3-2-11b-instruct-v1:0",
-  "us.meta.llama3-2-90b-instruct-v1:0",

   // Mistral Models
   "mistral.mistral-7b-instruct-v0:2",
   "mistral.mistral-large-2402-v1:0",
+  "mistral.mistral-large-2407-v1:0",
 ];

View File

@@ -245,6 +245,7 @@ export async function sign({
 export function parseEventData(chunk: Uint8Array): any {
   const decoder = new TextDecoder();
   const text = decoder.decode(chunk);
+  // console.info("[AWS Parse ] parsing:", text);
   try {
     const parsed = JSON.parse(text);
     // AWS Bedrock wraps the response in a 'body' field
@@ -317,7 +318,10 @@ export function extractMessage(res: any, modelId: string = ""): string {
   // Handle Mistral model response format
   if (modelId.toLowerCase().includes("mistral")) {
-    return res?.outputs?.[0]?.text || "";
+    if (res.choices?.[0]?.message?.content) {
+      return res.choices[0].message.content;
+    }
+    return res.output || "";
   }

   // Handle Llama model response format
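Note: extractMessage now prefers Mistral's chat-completions shape and falls back to a plain output field. The two response shapes involved (values illustrative):

const chatShape = { choices: [{ message: { content: "Hello" } }] }; // preferred
const textShape = { output: "Hello" }; // fallback; "" when neither field is present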
@@ -334,9 +338,7 @@ export async function* transformBedrockStream(
   modelId: string,
 ) {
   const reader = stream.getReader();
-  let accumulatedText = "";
-  let toolCallStarted = false;
   let currentToolCall = null;
   let toolInput = "";

   try {
     while (true) {
@@ -349,90 +351,54 @@
       // console.log("parseEventData=========================");
       // console.log(parsed);

       // Handle Mistral models
       if (modelId.toLowerCase().includes("mistral")) {
-        // If we have content, accumulate it
-        if (
-          parsed.choices?.[0]?.message?.role === "assistant" &&
-          parsed.choices?.[0]?.message?.content
-        ) {
-          accumulatedText += parsed.choices?.[0]?.message?.content;
-          // console.log("accumulatedText=========================");
-          // console.log(accumulatedText);
-
-          // Check for tool call in the accumulated text
-          if (!toolCallStarted && accumulatedText.includes("```json")) {
-            const jsonMatch = accumulatedText.match(
-              /```json\s*({[\s\S]*?})\s*```/,
-            );
-            if (jsonMatch) {
-              try {
-                const toolData = JSON.parse(jsonMatch[1]);
-                currentToolCall = {
-                  id: `tool-${Date.now()}`,
-                  name: toolData.name,
-                  arguments: toolData.arguments,
-                };
-
-                // Emit tool call start
-                yield `data: ${JSON.stringify({
-                  type: "content_block_start",
-                  content_block: {
-                    type: "tool_use",
-                    id: currentToolCall.id,
-                    name: currentToolCall.name,
-                  },
-                })}\n\n`;
-
-                // Emit tool arguments
-                yield `data: ${JSON.stringify({
-                  type: "content_block_delta",
-                  delta: {
-                    type: "input_json_delta",
-                    partial_json: JSON.stringify(currentToolCall.arguments),
-                  },
-                })}\n\n`;
-
-                // Emit tool call stop
-                yield `data: ${JSON.stringify({
-                  type: "content_block_stop",
-                })}\n\n`;
-
-                // Clear the accumulated text after processing the tool call
-                accumulatedText = accumulatedText.replace(
-                  /```json\s*{[\s\S]*?}\s*```/,
-                  "",
-                );
-                toolCallStarted = false;
-                currentToolCall = null;
-              } catch (e) {
-                console.error("Failed to parse tool JSON:", e);
-              }
-            }
-          }
-
-          // emit the text content if it's not empty
-          if (parsed.choices?.[0]?.message?.content.trim()) {
-            yield `data: ${JSON.stringify({
-              delta: { text: parsed.choices?.[0]?.message?.content },
-            })}\n\n`;
-          }
-
-          // Handle stop reason if present
-          if (parsed.choices?.[0]?.stop_reason) {
-            yield `data: ${JSON.stringify({
-              delta: { stop_reason: parsed.choices[0].stop_reason },
-            })}\n\n`;
-          }
+        // Handle tool calls
+        if (parsed.choices?.[0]?.message?.tool_calls) {
+          const toolCalls = parsed.choices[0].message.tool_calls;
+          for (const toolCall of toolCalls) {
+            // Emit tool call start
+            yield `data: ${JSON.stringify({
+              type: "content_block_start",
+              content_block: {
+                type: "tool_use",
+                id: toolCall.id || `tool-${Date.now()}`,
+                name: toolCall.function?.name,
+              },
+            })}\n\n`;
+
+            // Emit tool arguments
+            if (toolCall.function?.arguments) {
+              yield `data: ${JSON.stringify({
+                type: "content_block_delta",
+                delta: {
+                  type: "input_json_delta",
+                  partial_json: toolCall.function.arguments,
+                },
+              })}\n\n`;
+            }
+
+            // Emit tool call stop
+            yield `data: ${JSON.stringify({
+              type: "content_block_stop",
+            })}\n\n`;
+          }
           continue;
         }
       }
       // Handle Llama models
       else if (modelId.toLowerCase().includes("llama")) {
-        if (parsed.generation) {
+        // Handle regular content
+        const content = parsed.choices?.[0]?.message?.content;
+        if (content?.trim()) {
           yield `data: ${JSON.stringify({
-            delta: { text: parsed.generation },
+            delta: { text: content },
           })}\n\n`;
         }
-        if (parsed.stop_reason) {
+
+        // Handle stop reason
+        if (parsed.choices?.[0]?.finish_reason) {
           yield `data: ${JSON.stringify({
-            delta: { stop_reason: parsed.stop_reason },
+            delta: { stop_reason: parsed.choices[0].finish_reason },
           })}\n\n`;
         }
       }
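Note: for each streamed Mistral tool call, the generator emits Anthropic-style SSE frames so the downstream parser can treat every model uniformly. One complete call produces three frames like these (payload illustrative):

// data: {"type":"content_block_start","content_block":{"type":"tool_use","id":"t1","name":"get_weather"}}
// data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"city\":\"Paris\"}"}}
// data: {"type":"content_block_stop"}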
@@ -469,8 +435,9 @@ export async function* transformBedrockStream(
             })}\n\n`;
           }
         }
-      } else {
-        // Handle other model text responses
+      }
+      // Handle other models
+      else {
         const text = parsed.outputText || parsed.generation || "";
         if (text) {
           yield `data: ${JSON.stringify({