完善总结功能的代码逻辑

This commit is contained in:
glay 2024-11-23 15:13:52 +08:00
parent ff88421904
commit a6337e9f23
3 changed files with 37 additions and 21 deletions

View File

@ -199,7 +199,6 @@ async function requestBedrock(req: NextRequest) {
} }
const [_, credentials] = authHeader.split("Bearer "); const [_, credentials] = authHeader.split("Bearer ");
console.log("credentials===============" + credentials);
const [encryptedRegion, encryptedAccessKey, encryptedSecretKey] = const [encryptedRegion, encryptedAccessKey, encryptedSecretKey] =
credentials.split(":"); credentials.split(":");
@ -218,6 +217,7 @@ async function requestBedrock(req: NextRequest) {
} }
let modelId = req.headers.get("ModelID"); let modelId = req.headers.get("ModelID");
let shouldStream = req.headers.get("ShouldStream");
if (!awsRegion || !awsAccessKey || !awsSecretKey || !modelId) { if (!awsRegion || !awsAccessKey || !awsSecretKey || !modelId) {
throw new Error("Missing required AWS credentials or model ID"); throw new Error("Missing required AWS credentials or model ID");
} }
@ -232,7 +232,6 @@ async function requestBedrock(req: NextRequest) {
// Determine the endpoint and request body based on model type // Determine the endpoint and request body based on model type
let endpoint; let endpoint;
let requestBody; let requestBody;
let additionalHeaders = {};
const bodyText = await req.clone().text(); const bodyText = await req.clone().text();
if (!bodyText) { if (!bodyText) {
@ -242,15 +241,19 @@ async function requestBedrock(req: NextRequest) {
const bodyJson = JSON.parse(bodyText); const bodyJson = JSON.parse(bodyText);
validateRequest(bodyJson, modelId); validateRequest(bodyJson, modelId);
// For all other models, use standard endpoint // For all models, use standard endpoints
endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`; if (shouldStream === "false") {
endpoint = `${baseEndpoint}/model/${modelId}/invoke`;
} else {
endpoint = `${baseEndpoint}/model/${modelId}/invoke-with-response-stream`;
}
requestBody = JSON.stringify(bodyJson.body || bodyJson); requestBody = JSON.stringify(bodyJson.body || bodyJson);
console.log("Request to AWS Bedrock:", { // console.log("Request to AWS Bedrock:", {
endpoint, // endpoint,
modelId, // modelId,
body: requestBody, // body: requestBody,
}); // });
const headers = await sign({ const headers = await sign({
method: "POST", method: "POST",
@ -260,14 +263,12 @@ async function requestBedrock(req: NextRequest) {
secretAccessKey: awsSecretKey, secretAccessKey: awsSecretKey,
body: requestBody, body: requestBody,
service: "bedrock", service: "bedrock",
isStreaming: shouldStream !== "false",
}); });
const res = await fetch(endpoint, { const res = await fetch(endpoint, {
method: "POST", method: "POST",
headers: { headers,
...headers,
...additionalHeaders,
},
body: requestBody, body: requestBody,
redirect: "manual", redirect: "manual",
// @ts-ignore // @ts-ignore
@ -290,6 +291,15 @@ async function requestBedrock(req: NextRequest) {
throw new Error("Empty response from Bedrock"); throw new Error("Empty response from Bedrock");
} }
// Handle non-streaming response
if (shouldStream === "false") {
const responseText = await res.text();
console.error("AWS Bedrock shouldStream === false:", responseText);
const parsed = parseEventData(new TextEncoder().encode(responseText));
return NextResponse.json(parsed);
}
// Handle streaming response
const transformedStream = transformBedrockStream(res.body, modelId); const transformedStream = transformBedrockStream(res.body, modelId);
const stream = new ReadableStream({ const stream = new ReadableStream({
async start(controller) { async start(controller) {

View File

@ -391,6 +391,7 @@ export class BedrockApi implements LLMApi {
options, options,
); );
} else { } else {
headers.ShouldStream = "false";
const res = await fetch(chatPath, { const res = await fetch(chatPath, {
method: "POST", method: "POST",
headers, headers,

View File

@ -50,6 +50,8 @@ export interface SignParams {
secretAccessKey: string; secretAccessKey: string;
body: string; body: string;
service: string; service: string;
isStreaming?: boolean;
additionalHeaders?: Record<string, string>;
} }
function hmac( function hmac(
@ -133,6 +135,8 @@ export async function sign({
secretAccessKey, secretAccessKey,
body, body,
service, service,
isStreaming = true,
additionalHeaders = {},
}: SignParams): Promise<Record<string, string>> { }: SignParams): Promise<Record<string, string>> {
const endpoint = new URL(url); const endpoint = new URL(url);
const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1)); const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1));
@ -145,14 +149,20 @@ export async function sign({
const payloadHash = SHA256(body).toString(Hex); const payloadHash = SHA256(body).toString(Hex);
const headers: Record<string, string> = { const headers: Record<string, string> = {
accept: "application/vnd.amazon.eventstream", accept: isStreaming
? "application/vnd.amazon.eventstream"
: "application/json",
"content-type": "application/json", "content-type": "application/json",
host: endpoint.host, host: endpoint.host,
"x-amz-content-sha256": payloadHash, "x-amz-content-sha256": payloadHash,
"x-amz-date": amzDate, "x-amz-date": amzDate,
"x-amzn-bedrock-accept": "*/*", ...additionalHeaders,
}; };
if (isStreaming) {
headers["x-amzn-bedrock-accept"] = "*/*";
}
const sortedHeaderKeys = Object.keys(headers).sort((a, b) => const sortedHeaderKeys = Object.keys(headers).sort((a, b) =>
a.toLowerCase().localeCompare(b.toLowerCase()), a.toLowerCase().localeCompare(b.toLowerCase()),
); );
@ -195,12 +205,7 @@ export async function sign({
].join(", "); ].join(", ");
return { return {
Accept: headers.accept, ...headers,
"Content-Type": headers["content-type"],
Host: headers.host,
"X-Amz-Content-Sha256": headers["x-amz-content-sha256"],
"X-Amz-Date": headers["x-amz-date"],
"X-Amzn-Bedrock-Accept": headers["x-amzn-bedrock-accept"],
Authorization: authorization, Authorization: authorization,
}; };
} }