From 4254fd34f9a972dc3c1143ffe054ca15bf167e88 Mon Sep 17 00:00:00 2001
From: glay
Date: Sat, 7 Dec 2024 14:20:59 +0800
Subject: [PATCH] Add Bedrock's latest Nova models, including support for
 image parsing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/client/platforms/bedrock.ts | 90 ++++++++++++++++++++++++---------
 app/constant.ts                 |  8 +--
 app/utils.ts                    |  2 +
 app/utils/aws.ts                | 45 +++++++++++++----
 4 files changed, 109 insertions(+), 36 deletions(-)

diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts
index 9f5932698..b2bd2c2bb 100644
--- a/app/client/platforms/bedrock.ts
+++ b/app/client/platforms/bedrock.ts
@@ -46,20 +46,68 @@ export class BedrockApi implements LLMApi {
 
     // Handle Nova models
     if (model.startsWith("us.amazon.nova")) {
-      return {
+      // Extract system message if present
+      const systemMessage = messages.find((m) => m.role === "system");
+      const conversationMessages = messages.filter((m) => m.role !== "system");
+
+      const requestBody: any = {
+        schemaVersion: "messages-v1",
+        messages: conversationMessages.map((message) => {
+          const content = Array.isArray(message.content)
+            ? message.content
+            : [{ text: getMessageTextContent(message) }];
+
+          return {
+            role: message.role,
+            content: content.map((item: any) => {
+              // Handle text content
+              if (item.text || typeof item === "string") {
+                return { text: item.text || item };
+              }
+
+              // Handle image content
+              if (item.image_url?.url) {
+                const { url = "" } = item.image_url;
+                const colonIndex = url.indexOf(":");
+                const semicolonIndex = url.indexOf(";");
+                const comma = url.indexOf(",");
+
+                // Extract format from mime type
+                const mimeType = url.slice(colonIndex + 1, semicolonIndex);
+                const format = mimeType.split("/")[1];
+                const data = url.slice(comma + 1);
+
+                return {
+                  image: {
+                    format,
+                    source: {
+                      bytes: data,
+                    },
+                  },
+                };
+              }
+              return item;
+            }),
+          };
+        }),
         inferenceConfig: {
-          max_tokens: modelConfig.max_tokens || 1000,
+          temperature: modelConfig.temperature ?? 0.7,
+          top_p: modelConfig.top_p ?? 0.9,
+          top_k: modelConfig.top_k ?? 50,
+          max_new_tokens: modelConfig.max_tokens ?? 1000,
         },
-        messages: messages.map((message) => ({
-          role: message.role,
-          content: [
-            {
-              type: "text",
-              text: getMessageTextContent(message),
-            },
-          ],
-        })),
       };
+
+      // Add system message if present
+      if (systemMessage) {
+        requestBody.system = [
+          {
+            text: getMessageTextContent(systemMessage),
+          },
+        ];
+      }
+
+      return requestBody;
     }
 
     // Handle Titan models
@@ -426,10 +474,9 @@ function bedrockStream(
   let runTools: any[] = [];
   let responseRes: Response;
   let index = -1;
-  let chunks: Uint8Array[] = []; // store binary data chunks in an array
-  let pendingChunk: Uint8Array | null = null; // store any incomplete data chunk
+  let chunks: Uint8Array[] = [];
+  let pendingChunk: Uint8Array | null = null;
 
-  // Animate response to make it looks smooth
   function animateResponseText() {
     if (finished || controller.signal.aborted) {
       responseText += remainText;
@@ -451,7 +498,6 @@ function bedrockStream(
     requestAnimationFrame(animateResponseText);
   }
 
-  // Start animation
   animateResponseText();
 
   const finish = () => {
@@ -462,7 +508,7 @@
        tool_calls: [...runTools],
      };
      running = true;
-     runTools.splice(0, runTools.length); // empty runTools
+     runTools.splice(0, runTools.length);
      return Promise.all(
        toolCallMessage.tool_calls.map((tool) => {
          options?.onBeforeTool?.(tool);
@@ -510,7 +556,6 @@ function bedrockStream(
      ).then((toolCallResult) => {
        processToolMessage(requestPayload, toolCallMessage, toolCallResult);
        setTimeout(() => {
-         // call again
          console.debug("[BedrockAPI for toolCallResult] restart");
          running = false;
          bedrockChatApi(chatPath, headers, requestPayload, tools);
@@ -562,13 +607,11 @@ function bedrockStream(
          contentType,
        );
 
-        // Handle non-stream responses
        if (contentType?.startsWith("text/plain")) {
          responseText = await res.text();
          return finish();
        }
 
-        // Handle error responses
        if (
          !res.ok ||
          res.status !== 200 ||
@@ -593,7 +636,6 @@ function bedrockStream(
          return finish();
        }
 
-        // Process the stream using chunks
        const reader = res.body?.getReader();
        if (!reader) {
          throw new Error("No response body reader available");
@@ -603,7 +645,6 @@ function bedrockStream(
        while (true) {
          const { done, value } = await reader.read();
          if (done) {
-            // Process final pending chunk
            if (pendingChunk) {
              try {
                const parsed = parseEventData(pendingChunk);
@@ -624,10 +665,8 @@ function bedrockStream(
            break;
          }
 
-          // Add new chunk to queue
          chunks.push(value);
 
-          // Process chunk queue
          const result = processChunks(
            chunks,
            pendingChunk,
@@ -648,6 +687,11 @@ function bedrockStream(
        finish();
      }
    } catch (e) {
+      // @ts-ignore
+      if (e.name === "AbortError") {
+        console.log("[Bedrock Client] Aborted by user");
+        return;
+      }
      console.error("[Bedrock Request] error", e);
      options.onError?.(e);
      throw e;
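
For reference, the Nova branch above assembles a "messages-v1" payload rather than the
Anthropic-style body used before. A minimal sketch of what it produces for a vision request,
not part of the patch source: the prompt strings are invented, the image entry assumes the
incoming data URL was "data:image/png;base64,...", and the inferenceConfig values are the
fallback defaults from the hunk above.

    // TypeScript sketch of the request body built by the Nova branch (illustrative values).
    const exampleNovaRequestBody = {
      schemaVersion: "messages-v1",
      system: [{ text: "You are a helpful assistant." }],
      messages: [
        {
          role: "user",
          content: [
            { text: "Describe this image." },
            {
              image: {
                format: "png", // parsed from the data URL's mime type
                source: { bytes: "<base64 payload after the comma>" },
              },
            },
          ],
        },
      ],
      inferenceConfig: {
        temperature: 0.7,
        top_p: 0.9,
        top_k: 50,
        max_new_tokens: 1000,
      },
    };
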
diff --git a/app/constant.ts b/app/constant.ts
index d53c52846..1c47ae372 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -329,10 +329,10 @@ const openaiModels = [
 ];
 
 const bedrockModels = [
-  // Amazon Titan Models
-  "amazon.titan-text-express-v1",
-  "amazon.titan-text-lite-v1",
-  "amazon.titan-tg1-large",
+  // Amazon Nova Models
+  "us.amazon.nova-micro-v1:0",
+  "us.amazon.nova-lite-v1:0",
+  "us.amazon.nova-pro-v1:0",
   // Claude Models
   "anthropic.claude-3-haiku-20240307-v1:0",
   "anthropic.claude-3-5-haiku-20241022-v1:0",
diff --git a/app/utils.ts b/app/utils.ts
index 30c2dde5d..a2a4f21dc 100644
--- a/app/utils.ts
+++ b/app/utils.ts
@@ -264,6 +264,8 @@ export function isVisionModel(model: string) {
     "learnlm",
     "qwen-vl",
     "qwen2-vl",
+    "nova-lite",
+    "nova-pro",
   ];
   const isGpt4Turbo =
     model.includes("gpt-4-turbo") && !model.includes("preview");
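
With the two keywords added above, and assuming isVisionModel matches by substring as the
surrounding code suggests, the multimodal Nova variants are treated as vision models while
the text-only nova-micro is not:

    // Illustrative expectations (model IDs as registered in app/constant.ts):
    isVisionModel("us.amazon.nova-lite-v1:0");  // true, matches "nova-lite"
    isVisionModel("us.amazon.nova-pro-v1:0");   // true, matches "nova-pro"
    isVisionModel("us.amazon.nova-micro-v1:0"); // false, nova-micro is text-only
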
diff --git a/app/utils/aws.ts b/app/utils/aws.ts
index 912df4811..6e5943885 100644
--- a/app/utils/aws.ts
+++ b/app/utils/aws.ts
@@ -327,14 +327,35 @@ export function processMessage(
   if (!data) return { remainText, index };
 
   try {
-    // Handle message_start event
-    if (data.type === "message_start") {
-      // Keep existing text but mark the start of a new message
-      console.debug("[Message Start] Current text:", remainText);
+    // Handle Nova's messageStart event
+    if (data.messageStart) {
       return { remainText, index };
     }
 
-    // Handle content_block_start event
+    // Handle Nova's contentBlockDelta event
+    if (data.contentBlockDelta) {
+      if (data.contentBlockDelta.delta?.text) {
+        remainText += data.contentBlockDelta.delta.text;
+      }
+      return { remainText, index };
+    }
+
+    // Handle Nova's contentBlockStop event
+    if (data.contentBlockStop) {
+      return { remainText, index };
+    }
+
+    // Handle Nova's messageStop event
+    if (data.messageStop) {
+      return { remainText, index };
+    }
+
+    // Handle message_start event (for other models)
+    if (data.type === "message_start") {
+      return { remainText, index };
+    }
+
+    // Handle content_block_start event (for other models)
     if (data.type === "content_block_start") {
       if (data.content_block?.type === "tool_use") {
         index += 1;
@@ -350,13 +371,12 @@ export function processMessage(
       return { remainText, index };
     }
 
-    // Handle content_block_delta event
+    // Handle content_block_delta event (for other models)
     if (data.type === "content_block_delta") {
       if (data.delta?.type === "input_json_delta" && runTools[index]) {
         runTools[index].function.arguments += data.delta.partial_json;
       } else if (data.delta?.type === "text_delta") {
         const newText = data.delta.text || "";
-        // console.debug("[Text Delta] Adding:", newText);
         remainText += newText;
       }
       return { remainText, index };
@@ -398,7 +418,6 @@ export function processMessage(
 
     // Only append if we have new text
     if (newText) {
-      // console.debug("[New Text] Adding:", newText);
      remainText += newText;
    }
  } catch (e) {
@@ -530,8 +549,16 @@ export function extractMessage(res: any, modelId: string = ""): string {
 
   let message = "";
 
+  // Handle Nova model response format
+  if (modelId.toLowerCase().includes("nova")) {
+    if (res.output?.message?.content?.[0]?.text) {
+      message = res.output.message.content[0].text;
+    } else {
+      message = res.output || "";
+    }
+  }
   // Handle Mistral model response format
-  if (modelId.toLowerCase().includes("mistral")) {
+  else if (modelId.toLowerCase().includes("mistral")) {
     if (res.choices?.[0]?.message?.content) {
       message = res.choices[0].message.content;
     } else {
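
The Nova streaming branches above dispatch on top-level event keys rather than the "type"
field used by the other models. A sketch of the only event shapes those branches actually
read (all other fields are ignored by this code; the text value is illustrative):

    // Parsed Nova stream events as consumed by processMessage above.
    const deltaEvent = { contentBlockDelta: { delta: { text: "partial output..." } } };
    const stopEvent = { messageStop: {} };

Likewise, the Nova branch added to extractMessage reads the non-streaming response at
output.message.content[0].text. A minimal sketch of the shape that branch accepts (fields
outside the path being read are omitted):

    const novaResponse = {
      output: { message: { content: [{ text: "full completion text" }] } },
    };
    // extractMessage(novaResponse, "us.amazon.nova-lite-v1:0") returns "full completion text".
    // If the path is missing, the code falls back to res.output, then "".
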