From a6337e9f235596685c1b7062447e1bb5d917eb80 Mon Sep 17 00:00:00 2001 From: glay Date: Sat, 23 Nov 2024 15:13:52 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E6=80=BB=E7=BB=93=E5=8A=9F?= =?UTF-8?q?=E8=83=BD=E7=9A=84=E4=BB=A3=E7=A0=81=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/bedrock.ts | 36 +++++++++++++++++++++------------ app/client/platforms/bedrock.ts | 1 + app/utils/aws.ts | 21 +++++++++++-------- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index 9677bf5ca..e342a8867 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -199,7 +199,6 @@ async function requestBedrock(req: NextRequest) { } const [_, credentials] = authHeader.split("Bearer "); - console.log("credentials===============" + credentials); const [encryptedRegion, encryptedAccessKey, encryptedSecretKey] = credentials.split(":"); @@ -218,6 +217,7 @@ async function requestBedrock(req: NextRequest) { } let modelId = req.headers.get("ModelID"); + let shouldStream = req.headers.get("ShouldStream"); if (!awsRegion || !awsAccessKey || !awsSecretKey || !modelId) { throw new Error("Missing required AWS credentials or model ID"); } @@ -232,7 +232,6 @@ async function requestBedrock(req: NextRequest) { // Determine the endpoint and request body based on model type let endpoint; let requestBody; - let additionalHeaders = {}; const bodyText = await req.clone().text(); if (!bodyText) { @@ -242,15 +241,19 @@ async function requestBedrock(req: NextRequest) { const bodyJson = JSON.parse(bodyText); validateRequest(bodyJson, modelId); - // For all other models, use standard endpoint - endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`; + // For all models, use standard endpoints + if (shouldStream === "false") { + endpoint = `${baseEndpoint}/model/${modelId}/invoke`; + } else { + endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`; + } requestBody = JSON.stringify(bodyJson.body || bodyJson); - console.log("Request to AWS Bedrock:", { - endpoint, - modelId, - body: requestBody, - }); + // console.log("Request to AWS Bedrock:", { + // endpoint, + // modelId, + // body: requestBody, + // }); const headers = await sign({ method: "POST", @@ -260,14 +263,12 @@ async function requestBedrock(req: NextRequest) { secretAccessKey: awsSecretKey, body: requestBody, service: "bedrock", + isStreaming: shouldStream !== "false", }); const res = await fetch(endpoint, { method: "POST", - headers: { - ...headers, - ...additionalHeaders, - }, + headers, body: requestBody, redirect: "manual", // @ts-ignore @@ -290,6 +291,15 @@ async function requestBedrock(req: NextRequest) { throw new Error("Empty response from Bedrock"); } + // Handle non-streaming response + if (shouldStream === "false") { + const responseText = await res.text(); + console.error("AWS Bedrock shouldStream === false:", responseText); + const parsed = parseEventData(new TextEncoder().encode(responseText)); + return NextResponse.json(parsed); + } + + // Handle streaming response const transformedStream = transformBedrockStream(res.body, modelId); const stream = new ReadableStream({ async start(controller) { diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index d55173746..0f7d73022 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -391,6 +391,7 @@ export class BedrockApi implements LLMApi { options, ); } else { + headers.ShouldStream = "false"; const res = await fetch(chatPath, { method: "POST", headers, diff --git a/app/utils/aws.ts b/app/utils/aws.ts index 92c5cc6b5..d2997412f 100644 --- a/app/utils/aws.ts +++ b/app/utils/aws.ts @@ -50,6 +50,8 @@ export interface SignParams { secretAccessKey: string; body: string; service: string; + isStreaming?: boolean; + additionalHeaders?: Record; } function hmac( @@ -133,6 +135,8 @@ export async function sign({ secretAccessKey, body, service, + isStreaming = true, + additionalHeaders = {}, }: SignParams): Promise> { const endpoint = new URL(url); const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1)); @@ -145,14 +149,20 @@ export async function sign({ const payloadHash = SHA256(body).toString(Hex); const headers: Record = { - accept: "application/vnd.amazon.eventstream", + accept: isStreaming + ? "application/vnd.amazon.eventstream" + : "application/json", "content-type": "application/json", host: endpoint.host, "x-amz-content-sha256": payloadHash, "x-amz-date": amzDate, - "x-amzn-bedrock-accept": "*/*", + ...additionalHeaders, }; + if (isStreaming) { + headers["x-amzn-bedrock-accept"] = "*/*"; + } + const sortedHeaderKeys = Object.keys(headers).sort((a, b) => a.toLowerCase().localeCompare(b.toLowerCase()), ); @@ -195,12 +205,7 @@ export async function sign({ ].join(", "); return { - Accept: headers.accept, - "Content-Type": headers["content-type"], - Host: headers.host, - "X-Amz-Content-Sha256": headers["x-amz-content-sha256"], - "X-Amz-Date": headers["x-amz-date"], - "X-Amzn-Bedrock-Accept": headers["x-amzn-bedrock-accept"], + ...headers, Authorization: authorization, }; }