增加bedrock最新nova模型,优化后台代码

This commit is contained in:
glay 2024-12-07 16:49:26 +08:00
parent 4254fd34f9
commit 57dc44a54f
3 changed files with 148 additions and 64 deletions

View File

@ -80,54 +80,16 @@ async function requestBedrock(req: NextRequest) {
} catch (e) {
throw new Error(`Invalid JSON in request body: ${e}`);
}
// console.log(
// "[Bedrock Request] original Body:",
// JSON.stringify(bodyJson, null, 2),
// );
// Extract tool configuration if present
let tools: any[] | undefined;
if (bodyJson.tools) {
tools = bodyJson.tools;
delete bodyJson.tools; // Remove from main request body
}
console.log("[Bedrock Request] Initiating request");
// Get endpoint and prepare request
const endpoint = getBedrockEndpoint(
credentials.region,
modelId,
shouldStream,
);
console.log("[Bedrock Request] Initiating request");
// Handle tools for different models
const isMistralLargeModel = modelId
.toLowerCase()
.includes("mistral.mistral-large");
const isClaudeModel = modelId.toLowerCase().includes("claude");
const requestBody: any = {
...bodyJson,
};
if (tools && tools.length > 0) {
if (isMistralLargeModel) {
// Add tools in Mistral's format
requestBody.tool_choice = "auto";
requestBody.tools = tools.map((tool) => ({
type: "function",
function: {
name: tool.name,
description: tool.description,
parameters: tool.input_schema,
},
}));
} else if (isClaudeModel) {
requestBody.tools = tools;
}
}
// Sign request
const headers = await sign({
method: "POST",

View File

@ -9,7 +9,7 @@ import {
} from "@/app/store";
import { preProcessImageContent } from "@/app/utils/chat";
import { getMessageTextContent, isVisionModel } from "@/app/utils";
import { ApiPath, BEDROCK_BASE_URL } from "@/app/constant";
import { ApiPath, BEDROCK_BASE_URL, REQUEST_TIMEOUT_MS } from "@/app/constant";
import { getClientConfig } from "@/app/config/client";
import {
extractMessage,
@ -18,8 +18,6 @@ import {
parseEventData,
sign,
} from "@/app/utils/aws";
import { RequestPayload } from "./openai";
import { REQUEST_TIMEOUT_MS } from "@/app/constant";
import { prettyObject } from "@/app/utils/format";
import Locale from "@/app/locales";
@ -35,6 +33,15 @@ const MistralMapper = {
assistant: "assistant",
} as const;
type MistralRole = keyof typeof MistralMapper;
interface Tool {
function?: {
name?: string;
description?: string;
parameters?: any;
};
}
export class BedrockApi implements LLMApi {
speech(options: SpeechOptions): Promise<ArrayBuffer> {
throw new Error("Speech not implemented for Bedrock.");
@ -44,8 +51,15 @@ export class BedrockApi implements LLMApi {
const model = modelConfig.model;
const visionModel = isVisionModel(modelConfig.model);
// Get tools if available
const [tools] = usePluginStore
.getState()
.getAsTools(useChatStore.getState().currentSession().mask?.plugin || []);
const toolsArray = (tools as Tool[]) || [];
// Handle Nova models
if (model.startsWith("us.amazon.nova")) {
if (model.includes("amazon.nova")) {
// Extract system message if present
const systemMessage = messages.find((m) => m.role === "system");
const conversationMessages = messages.filter((m) => m.role !== "system");
@ -107,6 +121,26 @@ export class BedrockApi implements LLMApi {
];
}
// Add tools if available - now in correct format
if (toolsArray.length > 0) {
requestBody.toolConfig = {
tools: toolsArray.map((tool) => ({
toolSpec: {
name: tool?.function?.name || "",
description: tool?.function?.description || "",
inputSchema: {
json: {
type: "object",
properties: tool?.function?.parameters || {},
required: Object.keys(tool?.function?.parameters || {}),
},
},
},
})),
// toolChoice: { auto: {} }
};
}
return requestBody;
}
@ -160,18 +194,33 @@ export class BedrockApi implements LLMApi {
}
// Handle Mistral models
if (model.startsWith("mistral.mistral")) {
if (model.includes("mistral.mistral")) {
const formattedMessages = messages.map((message) => ({
role: MistralMapper[message.role as MistralRole] || "user",
content: getMessageTextContent(message),
}));
return {
const requestBody: any = {
messages: formattedMessages,
max_tokens: modelConfig.max_tokens || 4096,
temperature: modelConfig.temperature || 0.7,
top_p: modelConfig.top_p || 0.9,
};
// Add tools if available
if (toolsArray.length > 0) {
requestBody.tool_choice = "auto";
requestBody.tools = toolsArray.map((tool) => ({
type: "function",
function: {
name: tool?.function?.name,
description: tool?.function?.description,
parameters: tool?.function?.parameters,
},
}));
}
return requestBody;
}
// Handle Claude models
@ -254,6 +303,16 @@ export class BedrockApi implements LLMApi {
top_p: modelConfig.top_p || 0.9,
top_k: modelConfig.top_k || 5,
};
// Add tools if available for Claude models
if (toolsArray.length > 0 && model.includes("anthropic.claude")) {
requestBody.tools = toolsArray.map((tool) => ({
name: tool?.function?.name || "",
description: tool?.function?.description || "",
input_schema: tool?.function?.parameters || {},
}));
}
return requestBody;
}
@ -333,23 +392,18 @@ export class BedrockApi implements LLMApi {
chatPath,
finalRequestBody,
headers,
// @ts-ignore
tools.map((tool) => ({
name: tool?.function?.name,
description: tool?.function?.description,
input_schema: tool?.function?.parameters,
})),
funcs,
controller,
// processToolMessage, include tool_calls message and tool call results
(
requestPayload: RequestPayload,
requestPayload: any[],
toolCallMessage: any,
toolCallResult: any[],
) => {
const modelId = modelConfig.model;
const isMistral = modelId.startsWith("mistral.mistral");
const isMistral = modelId.includes("mistral.mistral");
const isClaude = modelId.includes("anthropic.claude");
const isNova = modelId.includes("amazon.nova");
if (isClaude) {
// Format for Claude
@ -385,7 +439,9 @@ export class BedrockApi implements LLMApi {
);
} else if (isMistral) {
// Format for Mistral
// @ts-ignore
requestPayload?.messages?.splice(
// @ts-ignore
requestPayload?.messages?.length,
0,
{
@ -408,6 +464,47 @@ export class BedrockApi implements LLMApi {
content: result.content,
})),
);
} else if (isNova) {
// Format for Nova
// @ts-ignore
requestPayload?.messages?.splice(
// @ts-ignore
requestPayload?.messages?.length,
0,
{
role: "assistant",
content: [
{
text: "", // Add empty text content to satisfy type requirements
tool_calls: toolCallMessage.tool_calls.map(
(tool: ChatMessageTool) => ({
id: tool.id,
name: tool?.function?.name,
arguments: tool?.function?.arguments
? JSON.parse(tool?.function?.arguments)
: {},
}),
),
},
],
},
...toolCallResult.map((result) => ({
role: "user",
content: [
{
toolUseId: result.tool_call_id,
content: [
{
json:
typeof result.content === "string"
? JSON.parse(result.content)
: result.content,
},
],
},
],
})),
);
} else {
console.warn(
`[Bedrock Client] Unhandled model type for tool calls: ${modelId}`,
@ -457,7 +554,6 @@ function bedrockStream(
chatPath: string,
requestPayload: any,
headers: any,
tools: any[],
funcs: Record<string, Function>,
controller: AbortController,
processToolMessage: (
@ -512,8 +608,13 @@ function bedrockStream(
return Promise.all(
toolCallMessage.tool_calls.map((tool) => {
options?.onBeforeTool?.(tool);
const funcName = tool?.function?.name || tool?.name;
if (!funcName || !funcs[funcName]) {
console.error(`Function ${funcName} not found in funcs:`, funcs);
return Promise.reject(`Function ${funcName} not found`);
}
return Promise.resolve(
funcs[tool.function.name](
funcs[funcName](
tool?.function?.arguments
? JSON.parse(tool?.function?.arguments)
: {},
@ -547,7 +648,7 @@ function bedrockStream(
return e.toString();
})
.then((content) => ({
name: tool.function.name,
name: funcName,
role: "tool",
content,
tool_call_id: tool.id,
@ -558,7 +659,7 @@ function bedrockStream(
setTimeout(() => {
console.debug("[BedrockAPI for toolCallResult] restart");
running = false;
bedrockChatApi(chatPath, headers, requestPayload, tools);
bedrockChatApi(chatPath, headers, requestPayload);
}, 60);
});
}
@ -577,7 +678,6 @@ function bedrockStream(
chatPath: string,
headers: any,
requestPayload: any,
tools: any,
) {
const requestTimeoutId = setTimeout(
() => controller.abort(),
@ -588,10 +688,7 @@ function bedrockStream(
const res = await fetch(chatPath, {
method: "POST",
headers,
body: JSON.stringify({
...requestPayload,
tools: tools && tools.length ? tools : undefined,
}),
body: JSON.stringify(requestPayload),
redirect: "manual",
// @ts-ignore
duplex: "half",
@ -699,5 +796,5 @@ function bedrockStream(
}
console.debug("[BedrockAPI] start");
bedrockChatApi(chatPath, headers, requestPayload, tools);
bedrockChatApi(chatPath, headers, requestPayload);
}

View File

@ -327,6 +327,31 @@ export function processMessage(
if (!data) return { remainText, index };
try {
// Handle Nova's tool calls
// console.log("processMessage data=========================",data);
if (
data.stopReason === "tool_use" &&
data.output?.message?.content?.[0]?.toolUse
) {
const toolUse = data.output.message.content[0].toolUse;
index += 1;
runTools.push({
id: `tool-${Date.now()}`,
type: "function",
function: {
name: toolUse.name,
arguments: JSON.stringify(toolUse.input),
},
});
return { remainText, index };
}
// Handle Nova's text content
if (data.output?.message?.content?.[0]?.text) {
remainText += data.output.message.content[0].text;
return { remainText, index };
}
// Handle Nova's messageStart event
if (data.messageStart) {
return { remainText, index };
@ -382,7 +407,7 @@ export function processMessage(
return { remainText, index };
}
// Handle tool calls
// Handle tool calls for other models
if (data.choices?.[0]?.message?.tool_calls) {
for (const toolCall of data.choices[0].message.tool_calls) {
index += 1;