mirror of https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git
synced 2025-05-23 22:20:23 +09:00
Improve Mistral tool use and fix Llama 3 message formatting
This commit is contained in:
parent
15d0600642
commit
e6633753a4
@@ -105,26 +105,17 @@ async function requestBedrock(req: NextRequest) {
   console.log("[Bedrock Request] Model ID:", modelId);

   // Handle tools for different models
-  const isMistralModel = modelId.toLowerCase().includes("mistral");
+  const isMistralLargeModel = modelId
+    .toLowerCase()
+    .includes("mistral.mistral-large");
   const isClaudeModel = modelId.toLowerCase().includes("claude");

-  const requestBody = {
+  const requestBody: any = {
     ...bodyJson,
   };

   if (tools && tools.length > 0) {
-    if (isClaudeModel) {
-      // Claude models already have correct tool format
-      requestBody.tools = tools;
-    } else if (isMistralModel) {
-      // Format messages for Mistral
-      if (typeof requestBody.prompt === "string") {
-        requestBody.messages = [
-          { role: "user", content: requestBody.prompt },
-        ];
-        delete requestBody.prompt;
-      }
-
+    if (isMistralLargeModel) {
       // Add tools in Mistral's format
       requestBody.tool_choice = "auto";
       requestBody.tools = tools.map((tool) => ({
@@ -135,6 +126,8 @@ async function requestBedrock(req: NextRequest) {
           parameters: tool.input_schema,
         },
       }));
+    } else if (isClaudeModel) {
+      requestBody.tools = tools;
     }
   }
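Note: the map callback above is truncated by the hunk boundary, so the wrapper fields around `parameters` are not shown. Assuming Mistral's OpenAI-compatible tool format (a hypothetical sketch, not part of the diff), a Claude-style tool definition would translate roughly like this:

    // Hypothetical sketch: Claude-style tool in, Mistral-style tool out.
    const claudeTool = {
      name: "get_weather",                        // illustrative values
      description: "Look up the current weather",
      input_schema: {
        type: "object",
        properties: { city: { type: "string" } },
        required: ["city"],
      },
    };

    const mistralTool = {
      type: "function",                           // assumed wrapper, elided by the hunk
      function: {
        name: claudeTool.name,
        description: claudeTool.description,
        parameters: claudeTool.input_schema,      // the mapping visible in the hunk
      },
    };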
@@ -21,7 +21,22 @@ const ClaudeMapper = {
   system: "user",
 } as const;

+const MistralMapper = {
+  system: "system",
+  user: "user",
+  assistant: "assistant",
+} as const;
+
 type ClaudeRole = keyof typeof ClaudeMapper;
+type MistralRole = keyof typeof MistralMapper;
+
+interface Tool {
+  function?: {
+    name?: string;
+    description?: string;
+    parameters?: any;
+  };
+}

 export class BedrockApi implements LLMApi {
   speech(options: SpeechOptions): Promise<ArrayBuffer> {
@@ -30,7 +45,6 @@ export class BedrockApi implements LLMApi {

   formatRequestBody(messages: ChatOptions["messages"], modelConfig: any) {
     const model = modelConfig.model;
-
     const visionModel = isVisionModel(modelConfig.model);

     // Handle Titan models
@@ -53,37 +67,27 @@ export class BedrockApi implements LLMApi {

     // Handle LLaMA models
     if (model.includes("meta.llama")) {
-      // Format conversation for Llama models
-      let prompt = "";
-      let systemPrompt = "";
+      let prompt = "<|begin_of_text|>";

       // Extract system message if present
       const systemMessage = messages.find((m) => m.role === "system");
       if (systemMessage) {
-        systemPrompt = getMessageTextContent(systemMessage);
+        prompt += `<|start_header_id|>system<|end_header_id|>\n${getMessageTextContent(
+          systemMessage,
+        )}<|eot_id|>`;
       }

       // Format the conversation
       const conversationMessages = messages.filter((m) => m.role !== "system");
-      prompt = `<s>[INST] <<SYS>>\n${
-        systemPrompt || "You are a helpful, respectful and honest assistant."
-      }\n<</SYS>>\n\n`;
-
-      for (let i = 0; i < conversationMessages.length; i++) {
-        const message = conversationMessages[i];
+      for (const message of conversationMessages) {
+        const role = message.role === "assistant" ? "assistant" : "user";
         const content = getMessageTextContent(message);
-        if (i === 0 && message.role === "user") {
-          // First user message goes in the same [INST] block as system prompt
-          prompt += `${content} [/INST]`;
-        } else {
-          if (message.role === "user") {
-            prompt += `\n\n<s>[INST] ${content} [/INST]`;
-          } else {
-            prompt += ` ${content} </s>`;
-          }
-        }
+        prompt += `<|start_header_id|>${role}<|end_header_id|>\n${content}<|eot_id|>`;
       }

+      // Add the final assistant header to prompt completion
+      prompt += "<|start_header_id|>assistant<|end_header_id|>";
+
       return {
         prompt,
         max_gen_len: modelConfig.max_tokens || 512,
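Note: this is the fix for the Llama 3 message-format problem named in the commit title. The old branch emitted the Llama 2 `<s>[INST] ... [/INST]` template, which Llama 3 instruct models do not expect; the new branch builds the Llama 3 header/eot chat template. A minimal sketch of the resulting prompt (conversation invented for illustration):

    // For messages: system "You are concise.", user "What is Bedrock?"
    const examplePrompt =
      "<|begin_of_text|>" +
      "<|start_header_id|>system<|end_header_id|>\nYou are concise.<|eot_id|>" +
      "<|start_header_id|>user<|end_header_id|>\nWhat is Bedrock?<|eot_id|>" +
      "<|start_header_id|>assistant<|end_header_id|>"; // trailing header cues the reply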
@@ -94,9 +98,8 @@ export class BedrockApi implements LLMApi {

     // Handle Mistral models
     if (model.startsWith("mistral.mistral")) {
-      // Format messages for Mistral's chat format
       const formattedMessages = messages.map((message) => ({
-        role: message.role,
+        role: MistralMapper[message.role as MistralRole] || "user",
         content: getMessageTextContent(message),
       }));

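Note: the `|| "user"` fallback matters because message roles are not guaranteed to be limited to the three keys Mistral accepts. A small sketch of the normalization (example roles invented for illustration):

    const incoming = ["system", "assistant", "function"]; // "function" is hypothetical
    const normalized = incoming.map(
      (r) => MistralMapper[r as MistralRole] || "user",
    );
    // => ["system", "assistant", "user"]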
@@ -234,6 +237,11 @@ export class BedrockApi implements LLMApi {
     });

     const finalRequestBody = this.formatRequestBody(messages, modelConfig);
+    console.log(
+      "[Bedrock Client] Request Body:",
+      JSON.stringify(finalRequestBody, null, 2),
+    );

     if (shouldStream) {
       let index = -1;
       const [tools, funcs] = usePluginStore
@@ -253,6 +261,7 @@ export class BedrockApi implements LLMApi {
         })),
         funcs,
         controller,
+        // parseSSE
         (text: string, runTools: ChatMessageTool[]) => {
           // console.log("parseSSE", text, runTools);
           let chunkJson:
@@ -304,36 +313,73 @@ export class BedrockApi implements LLMApi {
         ) => {
           // reset index value
           index = -1;
-          // @ts-ignore
-          requestPayload?.messages?.splice(
-            // @ts-ignore
-            requestPayload?.messages?.length,
-            0,
-            {
-              role: "assistant",
-              content: toolCallMessage.tool_calls.map(
-                (tool: ChatMessageTool) => ({
-                  type: "tool_use",
-                  id: tool.id,
-                  name: tool?.function?.name,
-                  input: tool?.function?.arguments
-                    ? JSON.parse(tool?.function?.arguments)
-                    : {},
-                }),
-              ),
-            },
-            // @ts-ignore
-            ...toolCallResult.map((result) => ({
-              role: "user",
-              content: [
-                {
-                  type: "tool_result",
-                  tool_use_id: result.tool_call_id,
-                  content: result.content,
-                },
-              ],
-            })),
-          );
+          const modelId = modelConfig.model;
+          const isMistral = modelId.startsWith("mistral.mistral");
+          const isClaude = modelId.includes("anthropic.claude");
+
+          if (isClaude) {
+            // Format for Claude
+            // @ts-ignore
+            requestPayload?.messages?.splice(
+              // @ts-ignore
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    type: "tool_use",
+                    id: tool.id,
+                    name: tool?.function?.name,
+                    input: tool?.function?.arguments
+                      ? JSON.parse(tool?.function?.arguments)
+                      : {},
+                  }),
+                ),
+              },
+              // @ts-ignore
+              ...toolCallResult.map((result) => ({
+                role: "user",
+                content: [
+                  {
+                    type: "tool_result",
+                    tool_use_id: result.tool_call_id,
+                    content: result.content,
+                  },
+                ],
+              })),
+            );
+          } else if (isMistral) {
+            // Format for Mistral
+            requestPayload?.messages?.splice(
+              requestPayload?.messages?.length,
+              0,
+              {
+                role: "assistant",
+                content: "",
+                // @ts-ignore
+                tool_calls: toolCallMessage.tool_calls.map(
+                  (tool: ChatMessageTool) => ({
+                    id: tool.id,
+                    function: {
+                      name: tool?.function?.name,
+                      arguments: tool?.function?.arguments || "{}",
+                    },
+                  }),
+                ),
+              },
+              ...toolCallResult.map((result) => ({
+                role: "tool",
+                tool_call_id: result.tool_call_id,
+                content: result.content,
+              })),
+            );
+          } else {
+            console.warn(
+              `[Bedrock Client] Unhandled model type for tool calls: ${modelId}`,
+            );
+          }
         },
         options,
       );
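Note: the effect of the new Mistral branch is to replay the tool round-trip in Mistral's own message schema (assistant `tool_calls` plus a `tool` result message) rather than Claude's `tool_use`/`tool_result` content blocks. A sketch of what gets appended to `requestPayload.messages` (ids and values invented for illustration):

    requestPayload.messages.push(
      {
        role: "assistant",
        content: "",
        tool_calls: [
          {
            id: "call_123", // illustrative
            function: { name: "get_weather", arguments: '{"city":"Tokyo"}' },
          },
        ],
      },
      { role: "tool", tool_call_id: "call_123", content: '{"temp_c":21}' },
    );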
@@ -368,6 +414,7 @@ export class BedrockApi implements LLMApi {
       options.onError?.(e as Error);
     }
   }
+
   path(path: string): string {
     const accessStore = useAccessStore.getState();
     let baseUrl = accessStore.useCustomConfig ? accessStore.bedrockUrl : "";
@@ -342,12 +342,9 @@ const bedrockModels = [
   // Meta Llama Models
   "us.meta.llama3-1-8b-instruct-v1:0",
   "us.meta.llama3-1-70b-instruct-v1:0",
-  "us.meta.llama3-2-1b-instruct-v1:0",
-  "us.meta.llama3-2-3b-instruct-v1:0",
   "us.meta.llama3-2-11b-instruct-v1:0",
   "us.meta.llama3-2-90b-instruct-v1:0",
   // Mistral Models
-  "mistral.mistral-7b-instruct-v0:2",
   "mistral.mistral-large-2402-v1:0",
   "mistral.mistral-large-2407-v1:0",
 ];
app/utils/aws.ts (123 changed lines)
@@ -245,6 +245,7 @@ export async function sign({
 export function parseEventData(chunk: Uint8Array): any {
   const decoder = new TextDecoder();
   const text = decoder.decode(chunk);
+  // console.info("[AWS Parse ] parsing:", text);
   try {
     const parsed = JSON.parse(text);
     // AWS Bedrock wraps the response in a 'body' field
@@ -317,7 +318,10 @@ export function extractMessage(res: any, modelId: string = ""): string {

   // Handle Mistral model response format
   if (modelId.toLowerCase().includes("mistral")) {
-    return res?.outputs?.[0]?.text || "";
+    if (res.choices?.[0]?.message?.content) {
+      return res.choices[0].message.content;
+    }
+    return res.output || "";
   }

   // Handle Llama model response format
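Note: the old code only read Mistral's raw-completion shape (`res.outputs[0].text`); with chat-style requests and tool use, Mistral models on Bedrock return an OpenAI-style `choices` array instead, which the new code prefers. A sketch of the new path (response object invented for illustration):

    const res = {
      choices: [{ message: { role: "assistant", content: "Hello!" } }],
    };
    extractMessage(res, "mistral.mistral-large-2407-v1:0"); // => "Hello!"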
@@ -334,9 +338,7 @@ export async function* transformBedrockStream(
   modelId: string,
 ) {
   const reader = stream.getReader();
-  let accumulatedText = "";
-  let toolCallStarted = false;
-  let currentToolCall = null;
+  let toolInput = "";

   try {
     while (true) {
@@ -349,90 +351,54 @@ export async function* transformBedrockStream(

           // console.log("parseEventData=========================");
           // console.log(parsed);

           // Handle Mistral models
           if (modelId.toLowerCase().includes("mistral")) {
-            // If we have content, accumulate it
-            if (
-              parsed.choices?.[0]?.message?.role === "assistant" &&
-              parsed.choices?.[0]?.message?.content
-            ) {
-              accumulatedText += parsed.choices?.[0]?.message?.content;
-              // console.log("accumulatedText=========================");
-              // console.log(accumulatedText);
-              // Check for tool call in the accumulated text
-              if (!toolCallStarted && accumulatedText.includes("```json")) {
-                const jsonMatch = accumulatedText.match(
-                  /```json\s*({[\s\S]*?})\s*```/,
-                );
-                if (jsonMatch) {
-                  try {
-                    const toolData = JSON.parse(jsonMatch[1]);
-                    currentToolCall = {
-                      id: `tool-${Date.now()}`,
-                      name: toolData.name,
-                      arguments: toolData.arguments,
-                    };
-
-                    // Emit tool call start
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_start",
-                      content_block: {
-                        type: "tool_use",
-                        id: currentToolCall.id,
-                        name: currentToolCall.name,
-                      },
-                    })}\n\n`;
-
-                    // Emit tool arguments
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_delta",
-                      delta: {
-                        type: "input_json_delta",
-                        partial_json: JSON.stringify(currentToolCall.arguments),
-                      },
-                    })}\n\n`;
-
-                    // Emit tool call stop
-                    yield `data: ${JSON.stringify({
-                      type: "content_block_stop",
-                    })}\n\n`;
-
-                    // Clear the accumulated text after processing the tool call
-                    accumulatedText = accumulatedText.replace(
-                      /```json\s*{[\s\S]*?}\s*```/,
-                      "",
-                    );
-                    toolCallStarted = false;
-                    currentToolCall = null;
-                  } catch (e) {
-                    console.error("Failed to parse tool JSON:", e);
-                  }
-                }
-              }
-            }
-            // emit the text content if it's not empty
-            if (parsed.choices?.[0]?.message?.content.trim()) {
-              yield `data: ${JSON.stringify({
-                delta: { text: parsed.choices?.[0]?.message?.content },
-              })}\n\n`;
-            }
-            // Handle stop reason if present
-            if (parsed.choices?.[0]?.stop_reason) {
-              yield `data: ${JSON.stringify({
-                delta: { stop_reason: parsed.choices[0].stop_reason },
-              })}\n\n`;
-            }
-          }
-          // Handle Llama models
-          else if (modelId.toLowerCase().includes("llama")) {
-            if (parsed.generation) {
-              yield `data: ${JSON.stringify({
-                delta: { text: parsed.generation },
-              })}\n\n`;
-            }
-            if (parsed.stop_reason) {
-              yield `data: ${JSON.stringify({
-                delta: { stop_reason: parsed.stop_reason },
-              })}\n\n`;
-            }
+            // Handle tool calls
+            if (parsed.choices?.[0]?.message?.tool_calls) {
+              const toolCalls = parsed.choices[0].message.tool_calls;
+              for (const toolCall of toolCalls) {
+                // Emit tool call start
+                yield `data: ${JSON.stringify({
+                  type: "content_block_start",
+                  content_block: {
+                    type: "tool_use",
+                    id: toolCall.id || `tool-${Date.now()}`,
+                    name: toolCall.function?.name,
+                  },
+                })}\n\n`;
+
+                // Emit tool arguments
+                if (toolCall.function?.arguments) {
+                  yield `data: ${JSON.stringify({
+                    type: "content_block_delta",
+                    delta: {
+                      type: "input_json_delta",
+                      partial_json: toolCall.function.arguments,
+                    },
+                  })}\n\n`;
+                }
+
+                // Emit tool call stop
+                yield `data: ${JSON.stringify({
+                  type: "content_block_stop",
+                })}\n\n`;
+              }
+              continue;
+            }
+
+            // Handle regular content
+            const content = parsed.choices?.[0]?.message?.content;
+            if (content?.trim()) {
+              yield `data: ${JSON.stringify({
+                delta: { text: content },
+              })}\n\n`;
+            }
+
+            // Handle stop reason
+            if (parsed.choices?.[0]?.finish_reason) {
+              yield `data: ${JSON.stringify({
+                delta: { stop_reason: parsed.choices[0].finish_reason },
+              })}\n\n`;
+            }
           }
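Note: the rewrite drops the fragile heuristic of accumulating text and regex-matching a ```json block, and instead reads structured `tool_calls` straight off the parsed chunk, so tool calls no longer depend on the model emitting markdown. For a chunk carrying one tool call, the generator now yields three Anthropic-style SSE frames, roughly (values invented for illustration):

    data: {"type":"content_block_start","content_block":{"type":"tool_use","id":"call_123","name":"get_weather"}}

    data: {"type":"content_block_delta","delta":{"type":"input_json_delta","partial_json":"{\"city\":\"Tokyo\"}"}}

    data: {"type":"content_block_stop"}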
@@ -469,8 +435,9 @@ export async function* transformBedrockStream(
             })}\n\n`;
           }
         }
-        } else {
-          // Handle other model text responses
+      }
+      // Handle other models
+      else {
         const text = parsed.outputText || parsed.generation || "";
         if (text) {
           yield `data: ${JSON.stringify({