From afbf5eb541f4f4f3a53ba4aeeeb26e541cd0753c Mon Sep 17 00:00:00 2001 From: glay Date: Tue, 5 Nov 2024 14:27:52 +0800 Subject: [PATCH] =?UTF-8?q?=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20.e?= =?UTF-8?q?nv.template=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/?= =?UTF-8?q?api/auth.ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/?= =?UTF-8?q?api/bedrock.ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20a?= =?UTF-8?q?pp/client/api.ts=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20?= =?UTF-8?q?=20app/client/platforms/bedrock.ts=20=09=E4=BF=AE=E6=94=B9?= =?UTF-8?q?=EF=BC=9A=20=20=20=20=20app/components/settings.tsx=20=09?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/config/server.ts?= =?UTF-8?q?=20=09=E4=BF=AE=E6=94=B9=EF=BC=9A=20=20=20=20=20app/constant.t?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.template | 7 +- app/api/auth.ts | 10 +- app/api/bedrock.ts | 215 +++++++++++++++++--------------- app/client/api.ts | 15 ++- app/client/platforms/bedrock.ts | 53 ++++---- app/components/settings.tsx | 8 +- app/config/server.ts | 8 +- app/constant.ts | 8 +- app/store/access.ts | 13 +- 9 files changed, 184 insertions(+), 153 deletions(-) diff --git a/.env.template b/.env.template index 82f44216a..8c0c4c1cc 100644 --- a/.env.template +++ b/.env.template @@ -66,4 +66,9 @@ ANTHROPIC_API_VERSION= ANTHROPIC_URL= ### (optional) -WHITE_WEBDAV_ENDPOINTS= \ No newline at end of file +WHITE_WEBDAV_ENDPOINTS= + +### bedrock (optional) +AWS_REGION= +AWS_ACCESS_KEY= +AWS_SECRET_KEY= \ No newline at end of file diff --git a/app/api/auth.ts b/app/api/auth.ts index 1a0ae0b43..bb8ee1474 100644 --- a/app/api/auth.ts +++ b/app/api/auth.ts @@ -54,18 +54,18 @@ export function auth(req: NextRequest, modelProvider: ModelProvider) { } // Special handling for Bedrock if (modelProvider === ModelProvider.Bedrock) { - const region = req.headers.get("X-Region"); - const accessKeyId = req.headers.get("X-Access-Key"); - const secretKey = req.headers.get("X-Secret-Key"); + const region = serverConfig.awsRegion; + const accessKeyId = serverConfig.awsAccessKey; + const secretAccessKey = serverConfig.awsSecretKey; console.log("[Auth] Bedrock credentials:", { region, accessKeyId: accessKeyId ? "***" : undefined, - secretKey: secretKey ? "***" : undefined, + secretKey: secretAccessKey ? "***" : undefined, }); // Check if AWS credentials are provided - if (!region || !accessKeyId || !secretKey) { + if (!region || !accessKeyId || !secretAccessKey) { return { error: true, msg: "Missing AWS credentials. Please configure Region, Access Key ID, and Secret Access Key in settings.", diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index e3ca645bd..8b5ddc47e 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -1,7 +1,6 @@ -import { ModelProvider } from "../constant"; +import { getServerSideConfig } from "../config/server"; import { prettyObject } from "../utils/format"; import { NextRequest, NextResponse } from "next/server"; -import { auth } from "./auth"; import { BedrockRuntimeClient, ConverseStreamCommand, @@ -16,6 +15,15 @@ import { type ToolResultContentBlock, } from "@aws-sdk/client-bedrock-runtime"; +// 解密函数 +function decrypt(str: string): string { + try { + return Buffer.from(str, "base64").toString().split("").reverse().join(""); + } catch { + return ""; + } +} + // Constants and Types const ALLOWED_PATH = new Set(["converse"]); @@ -92,26 +100,6 @@ type DocumentFormat = | "txt" | "md"; -// Validation Functions -function validateModelId(modelId: string): string | null { - if ( - modelId.startsWith("meta.llama") && - !modelId.includes("inference-profile") - ) { - return "Llama models require an inference profile. Please use the full inference profile ARN."; - } - return null; -} - -function validateDocumentSize(base64Data: string): boolean { - const sizeInBytes = (base64Data.length * 3) / 4; - const maxSize = 4.5 * 1024 * 1024; - if (sizeInBytes > maxSize) { - throw new Error("Document size exceeds 4.5 MB limit"); - } - return true; -} - function validateImageSize(base64Data: string): boolean { const sizeInBytes = (base64Data.length * 3) / 4; const maxSize = 3.75 * 1024 * 1024; @@ -147,21 +135,6 @@ function convertContentToAWSBlock(item: ContentItem): ContentBlock | null { } } - if (item.type === "document" && item.document) { - validateDocumentSize(item.document.source.bytes); - return { - document: { - format: item.document.format, - name: item.document.name, - source: { - bytes: Uint8Array.from( - Buffer.from(item.document.source.bytes, "base64"), - ), - }, - }, - }; - } - if (item.type === "tool_use" && item.tool_use) { return { toolUse: { @@ -373,15 +346,48 @@ export async function handle( ); } - const authResult = auth(req, ModelProvider.Bedrock); - if (authResult.error) { - return NextResponse.json(authResult, { - status: 401, - }); + const serverConfig = getServerSideConfig(); + + // 首先尝试使用环境变量中的凭证 + let region = serverConfig.awsRegion; + let accessKeyId = serverConfig.awsAccessKey; + let secretAccessKey = serverConfig.awsSecretKey; + let sessionToken = undefined; + + // 如果环境变量中没有配置,则尝试使用前端传来的加密凭证 + if (!region || !accessKeyId || !secretAccessKey) { + // 解密前端传来的凭证 + region = decrypt(req.headers.get("X-Region") ?? ""); + accessKeyId = decrypt(req.headers.get("X-Access-Key") ?? ""); + secretAccessKey = decrypt(req.headers.get("X-Secret-Key") ?? ""); + sessionToken = req.headers.get("X-Session-Token") + ? decrypt(req.headers.get("X-Session-Token") ?? "") + : undefined; + } + + if (!region || !accessKeyId || !secretAccessKey) { + return NextResponse.json( + { + error: true, + msg: "AWS credentials not found in environment variables or request headers", + }, + { + status: 401, + }, + ); } try { - const response = await handleConverseRequest(req); + const client = new BedrockRuntimeClient({ + region, + credentials: { + accessKeyId, + secretAccessKey, + sessionToken, + }, + }); + + const response = await handleConverseRequest(req, client); return response; } catch (e) { console.error("[Bedrock] ", e); @@ -396,42 +402,14 @@ export async function handle( } } -async function handleConverseRequest(req: NextRequest) { - const region = req.headers.get("X-Region") || "us-west-2"; - const accessKeyId = req.headers.get("X-Access-Key") || ""; - const secretAccessKey = req.headers.get("X-Secret-Key") || ""; - const sessionToken = req.headers.get("X-Session-Token"); - - if (!accessKeyId || !secretAccessKey) { - return NextResponse.json( - { - error: true, - message: "Missing AWS credentials", - }, - { - status: 401, - }, - ); - } - - const client = new BedrockRuntimeClient({ - region, - credentials: { - accessKeyId, - secretAccessKey, - sessionToken: sessionToken || undefined, - }, - }); - +async function handleConverseRequest( + req: NextRequest, + client: BedrockRuntimeClient, +) { try { const body = (await req.json()) as ConverseRequest; const { modelId } = body; - const validationError = validateModelId(modelId); - if (validationError) { - throw new Error(validationError); - } - console.log("[Bedrock] Invoking model:", modelId); const command = new ConverseStreamCommand(formatRequestBody(body)); @@ -455,8 +433,9 @@ async function handleConverseRequest(req: NextRequest) { if ("messageStart" in output && output.messageStart?.role) { controller.enqueue( `data: ${JSON.stringify({ - type: "messageStart", - role: output.messageStart.role, + stream: { + messageStart: { role: output.messageStart.role }, + }, })}\n\n`, ); } else if ( @@ -465,9 +444,13 @@ async function handleConverseRequest(req: NextRequest) { ) { controller.enqueue( `data: ${JSON.stringify({ - type: "contentBlockStart", - index: output.contentBlockStart.contentBlockIndex, - start: output.contentBlockStart.start, + stream: { + contentBlockStart: { + contentBlockIndex: + output.contentBlockStart.contentBlockIndex, + start: output.contentBlockStart.start, + }, + }, })}\n\n`, ); } else if ( @@ -477,15 +460,30 @@ async function handleConverseRequest(req: NextRequest) { if ("text" in output.contentBlockDelta.delta) { controller.enqueue( `data: ${JSON.stringify({ - type: "text", - content: output.contentBlockDelta.delta.text, + stream: { + contentBlockDelta: { + delta: { text: output.contentBlockDelta.delta.text }, + contentBlockIndex: + output.contentBlockDelta.contentBlockIndex, + }, + }, })}\n\n`, ); } else if ("toolUse" in output.contentBlockDelta.delta) { controller.enqueue( `data: ${JSON.stringify({ - type: "toolUse", - input: output.contentBlockDelta.delta.toolUse?.input, + stream: { + contentBlockDelta: { + delta: { + toolUse: { + input: + output.contentBlockDelta.delta.toolUse?.input, + }, + }, + contentBlockIndex: + output.contentBlockDelta.contentBlockIndex, + }, + }, })}\n\n`, ); } @@ -495,26 +493,36 @@ async function handleConverseRequest(req: NextRequest) { ) { controller.enqueue( `data: ${JSON.stringify({ - type: "contentBlockStop", - index: output.contentBlockStop.contentBlockIndex, + stream: { + contentBlockStop: { + contentBlockIndex: + output.contentBlockStop.contentBlockIndex, + }, + }, })}\n\n`, ); } else if ("messageStop" in output && output.messageStop) { controller.enqueue( `data: ${JSON.stringify({ - type: "messageStop", - stopReason: output.messageStop.stopReason, - additionalModelResponseFields: - output.messageStop.additionalModelResponseFields, + stream: { + messageStop: { + stopReason: output.messageStop.stopReason, + additionalModelResponseFields: + output.messageStop.additionalModelResponseFields, + }, + }, })}\n\n`, ); } else if ("metadata" in output && output.metadata) { controller.enqueue( `data: ${JSON.stringify({ - type: "metadata", - usage: output.metadata.usage, - metrics: output.metadata.metrics, - trace: output.metadata.trace, + stream: { + metadata: { + usage: output.metadata.usage, + metrics: output.metadata.metrics, + trace: output.metadata.trace, + }, + }, })}\n\n`, ); } @@ -522,14 +530,17 @@ async function handleConverseRequest(req: NextRequest) { controller.close(); } catch (error) { const errorResponse = { - type: "error", - error: - error instanceof Error ? error.constructor.name : "UnknownError", - message: error instanceof Error ? error.message : "Unknown error", - ...(error instanceof ModelStreamErrorException && { - originalStatusCode: error.originalStatusCode, - originalMessage: error.originalMessage, - }), + stream: { + error: + error instanceof Error + ? error.constructor.name + : "UnknownError", + message: error instanceof Error ? error.message : "Unknown error", + ...(error instanceof ModelStreamErrorException && { + originalStatusCode: error.originalStatusCode, + originalMessage: error.originalMessage, + }), + }, }; controller.enqueue(`data: ${JSON.stringify(errorResponse)}\n\n`); controller.close(); diff --git a/app/client/api.ts b/app/client/api.ts index 05ce8a236..e7ba2fd5d 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -261,7 +261,7 @@ export function getHeaders(ignoreHeaders: boolean = false) { const apiKey = isGoogle ? accessStore.googleApiKey : isBedrock - ? accessStore.awsAccessKeyId // Use AWS access key for Bedrock + ? accessStore.awsAccessKey // Use AWS access key for Bedrock : isAzure ? accessStore.azureApiKey : isAnthropic @@ -322,12 +322,15 @@ export function getHeaders(ignoreHeaders: boolean = false) { const authHeader = getAuthHeader(); if (isBedrock) { - // Add AWS credentials for Bedrock - headers["X-Region"] = accessStore.awsRegion; - headers["X-Access-Key"] = accessStore.awsAccessKeyId; - headers["X-Secret-Key"] = accessStore.awsSecretAccessKey; + // 简单加密 AWS credentials + const encrypt = (str: string) => + Buffer.from(str.split("").reverse().join("")).toString("base64"); + + headers["X-Region"] = encrypt(accessStore.awsRegion); + headers["X-Access-Key"] = encrypt(accessStore.awsAccessKey); + headers["X-Secret-Key"] = encrypt(accessStore.awsSecretKey); if (accessStore.awsSessionToken) { - headers["X-Session-Token"] = accessStore.awsSessionToken; + headers["X-Session-Token"] = encrypt(accessStore.awsSessionToken); } } else { const bearerToken = getBearerToken( diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index 5e1f9f0a2..b44070352 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -8,7 +8,6 @@ import { SpeechOptions, } from "../api"; import { - useAccessStore, useAppConfig, usePluginStore, useChatStore, @@ -60,7 +59,6 @@ export class BedrockApi implements LLMApi { async chat(options: ChatOptions): Promise { const visionModel = isVisionModel(options.config.model); - const accessStore = useAccessStore.getState(); const shouldStream = !!options.config.stream; const modelConfig = { ...useAppConfig.getState().modelConfig, @@ -69,12 +67,6 @@ export class BedrockApi implements LLMApi { model: options.config.model, }, }; - const headers: Record = { - ...getHeaders(), - "X-Region": accessStore.awsRegion, - "X-Access-Key": accessStore.awsAccessKeyId, - "X-Secret-Key": accessStore.awsSecretAccessKey, - }; // try get base64image from local cache image_url const messages: ChatOptions["messages"] = []; @@ -196,7 +188,7 @@ export class BedrockApi implements LLMApi { return stream( conversePath, requestBody, - headers, + getHeaders(), Array.isArray(tools) ? tools.map((tool: any) => ({ name: tool?.function?.name, @@ -208,14 +200,20 @@ export class BedrockApi implements LLMApi { controller, // parseSSE (text: string, runTools: ChatMessageTool[]) => { - const event = JSON.parse(text); + const parsed = JSON.parse(text); + const event = parsed.stream; - if (event.type === "messageStart") { + if (!event) { + console.warn("[Bedrock] Unexpected event format:", parsed); return ""; } - if (event.type === "contentBlockStart" && event.start?.toolUse) { - const { toolUseId, name } = event.start.toolUse; + if (event.messageStart) { + return ""; + } + + if (event.contentBlockStart?.start?.toolUse) { + const { toolUseId, name } = event.contentBlockStart.start.toolUse; currentToolUse = { id: toolUseId, type: "function", @@ -228,21 +226,34 @@ export class BedrockApi implements LLMApi { return ""; } - if (event.type === "text" && event.content) { - return event.content; + if (event.contentBlockDelta?.delta?.text) { + return event.contentBlockDelta.delta.text; } if ( - event.type === "toolUse" && - event.input && + event.contentBlockDelta?.delta?.toolUse?.input && currentToolUse?.function ) { - currentToolUse.function.arguments += event.input; + currentToolUse.function.arguments += + event.contentBlockDelta.delta.toolUse.input; return ""; } - if (event.type === "error") { - throw new Error(event.message || "Unknown error"); + if ( + event.internalServerException || + event.modelStreamErrorException || + event.validationException || + event.throttlingException || + event.serviceUnavailableException + ) { + const errorMessage = + event.internalServerException?.message || + event.modelStreamErrorException?.message || + event.validationException?.message || + event.throttlingException?.message || + event.serviceUnavailableException?.message || + "Unknown error"; + throw new Error(errorMessage); } return ""; @@ -284,7 +295,7 @@ export class BedrockApi implements LLMApi { try { const response = await fetch(conversePath, { method: "POST", - headers, + headers: getHeaders(), body: JSON.stringify(requestBody), signal: controller.signal, }); diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 9c6d9793c..ddd6a5c15 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -988,12 +988,12 @@ export function Settings() { > { accessStore.update( - (access) => (access.awsAccessKeyId = e.currentTarget.value), + (access) => (access.awsAccessKey = e.currentTarget.value), ); }} /> @@ -1004,12 +1004,12 @@ export function Settings() { > { accessStore.update( - (access) => (access.awsSecretAccessKey = e.currentTarget.value), + (access) => (access.awsSecretKey = e.currentTarget.value), ); }} /> diff --git a/app/config/server.ts b/app/config/server.ts index 7e130aa0e..5250b0610 100644 --- a/app/config/server.ts +++ b/app/config/server.ts @@ -13,8 +13,9 @@ declare global { OPENAI_ORG_ID?: string; // openai only // bedrock only - BEDROCK_URL?: string; + BEDROCK_REGION?: string; BEDROCK_API_KEY?: string; + BEDROCK_API_SECRET?: string; VERCEL?: string; BUILD_MODE?: "standalone" | "export"; @@ -173,8 +174,9 @@ export const getServerSideConfig = () => { openaiOrgId: process.env.OPENAI_ORG_ID, isBedrock, - bedrockUrl: process.env.BEDROCK_URL, - bedrockApiKey: getApiKey(process.env.BEDROCK_API_KEY), + awsRegion: process.env.AWS_REGION, + awsAccessKey: process.env.AWS_ACCESS_KEY, + awsSecretKey: process.env.AWS_SECRET_KEY, isStability, stabilityUrl: process.env.STABILITY_URL, diff --git a/app/constant.ts b/app/constant.ts index 7d7e099cd..1b8aa49d8 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -230,6 +230,10 @@ export const XAI = { ChatPath: "v1/chat/completions", }; +export const Bedrock = { + ChatPath: "converse", +}; + export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang // export const DEFAULT_SYSTEM_TEMPLATE = ` // You are ChatGPT, a large language model trained by {{ServiceProvider}}. @@ -312,9 +316,11 @@ const openaiModels = [ const bedrockModels = [ // Claude Models "anthropic.claude-3-haiku-20240307-v1:0", + "anthropic.claude-3-5-haiku-20241022-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", - "anthropic.claude-3-opus-20240229-v1:0", "anthropic.claude-3-5-sonnet-20241022-v2:0", + "anthropic.claude-3-opus-20240229-v1:0", + // Meta Llama Models "us.meta.llama3-2-11b-instruct-v1:0", "us.meta.llama3-2-90b-instruct-v1:0", diff --git a/app/store/access.ts b/app/store/access.ts index 11127cbed..b5f765cfc 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -60,14 +60,11 @@ const DEFAULT_ACCESS_STATE = { openaiApiKey: "", // bedrock - bedrockUrl: DEFAULT_BEDROCK_URL, - bedrockApiKey: "", awsRegion: "", - awsAccessKeyId: "", - awsSecretAccessKey: "", + awsAccessKey: "", + awsSecretKey: "", awsSessionToken: "", awsCognitoUser: false, - awsInferenceProfile: "", // Added inference profile field // azure azureUrl: "", @@ -154,11 +151,7 @@ export const useAccessStore = createPersistStore( }, isValidBedrock() { - return ensure(get(), [ - "awsAccessKeyId", - "awsSecretAccessKey", - "awsRegion", - ]); + return ensure(get(), ["awsAccessKey", "awsSecretKey", "awsRegion"]); }, isValidAzure() {