From f5ae086d3c031ad4c81aa0e60db1ce8f8d266017 Mon Sep 17 00:00:00 2001 From: glay Date: Sun, 8 Dec 2024 23:28:59 +0800 Subject: [PATCH] Enhance encryption security with additional safeguards. --- app/api/bedrock.ts | 12 +- app/client/api.ts | 11 -- app/client/platforms/bedrock.ts | 109 ++++++++++++---- app/components/ui-lib.tsx | 26 ++-- app/utils/aws.ts | 224 ++++++++++++++++++++------------ package.json | 4 +- 6 files changed, 245 insertions(+), 141 deletions(-) diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index 7da14a17b..85b4937e3 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -32,14 +32,20 @@ async function getBedrockCredentials( const [encryptedRegion, encryptedAccessKey, encryptedSecretKey] = credentials.split(":"); + // console.log("===========encryptedRegion",encryptedRegion); + // console.log("===========encryptedAccessKey",encryptedAccessKey); + // console.log("===========encryptedSecretKey",encryptedSecretKey); if (!encryptedRegion || !encryptedAccessKey || !encryptedSecretKey) { throw new Error("Invalid Authorization header format"); } const encryptionKey = req.headers.get("XEncryptionKey") || ""; + // console.log("===========encryptionKey",encryptionKey); // Decrypt the credentials - awsRegion = decrypt(encryptedRegion, encryptionKey); - awsAccessKey = decrypt(encryptedAccessKey, encryptionKey); - awsSecretKey = decrypt(encryptedSecretKey, encryptionKey); + [awsRegion, awsAccessKey, awsSecretKey] = await Promise.all([ + decrypt(encryptedRegion, encryptionKey), + decrypt(encryptedAccessKey, encryptionKey), + decrypt(encryptedSecretKey, encryptionKey), + ]); if (!awsRegion || !awsAccessKey || !awsSecretKey) { throw new Error( diff --git a/app/client/api.ts b/app/client/api.ts index 06537d1de..c476b23d6 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -23,7 +23,6 @@ import { SparkApi } from "./platforms/iflytek"; import { XAIApi } from "./platforms/xai"; import { ChatGLMApi } from "./platforms/glm"; import { BedrockApi } from "./platforms/bedrock"; -import { encrypt } from "../utils/aws"; export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; @@ -276,16 +275,6 @@ export function getHeaders(ignoreHeaders: boolean = false) { ? accessStore.iflytekApiKey && accessStore.iflytekApiSecret ? accessStore.iflytekApiKey + ":" + accessStore.iflytekApiSecret : "" - : isBedrock - ? accessStore.awsRegion && - accessStore.awsAccessKey && - accessStore.awsSecretKey - ? encrypt(accessStore.awsRegion, accessStore.encryptionKey) + - ":" + - encrypt(accessStore.awsAccessKey, accessStore.encryptionKey) + - ":" + - encrypt(accessStore.awsSecretKey, accessStore.encryptionKey) - : "" : accessStore.openaiApiKey; return { isBedrock, diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index 7311c8a66..63de78347 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -20,6 +20,7 @@ import { } from "@/app/utils/aws"; import { prettyObject } from "@/app/utils/format"; import Locale from "@/app/locales"; +import { encrypt } from "@/app/utils/aws"; const ClaudeMapper = { assistant: "assistant", @@ -41,6 +42,66 @@ interface Tool { parameters?: any; }; } +const isApp = !!getClientConfig()?.isApp; +// const isApp = true; +async function getBedrockHeaders( + modelId: string, + chatPath: string, + finalRequestBody: any, + shouldStream: boolean, +): Promise> { + const accessStore = useAccessStore.getState(); + const bedrockHeaders = isApp + ? await sign({ + method: "POST", + url: chatPath, + region: accessStore.awsRegion, + accessKeyId: accessStore.awsAccessKey, + secretAccessKey: accessStore.awsSecretKey, + body: finalRequestBody, + service: "bedrock", + headers: {}, + isStreaming: shouldStream, + }) + : getHeaders(); + + if (!isApp) { + const { awsRegion, awsAccessKey, awsSecretKey, encryptionKey } = + accessStore; + + const bedrockHeadersConfig = { + XModelID: modelId, + XEncryptionKey: encryptionKey, + ShouldStream: String(shouldStream), + Authorization: await createAuthHeader( + awsRegion, + awsAccessKey, + awsSecretKey, + encryptionKey, + ), + }; + + Object.assign(bedrockHeaders, bedrockHeadersConfig); + } + + return bedrockHeaders; +} + +// Helper function to create Authorization header +async function createAuthHeader( + region: string, + accessKey: string, + secretKey: string, + encryptionKey: string, +): Promise { + const encryptedValues = await Promise.all([ + encrypt(region, encryptionKey), + encrypt(accessKey, encryptionKey), + encrypt(secretKey, encryptionKey), + ]); + + return `Bearer ${encryptedValues.join(":")}`; +} export class BedrockApi implements LLMApi { speech(options: SpeechOptions): Promise { @@ -343,32 +404,11 @@ export class BedrockApi implements LLMApi { let finalRequestBody = this.formatRequestBody(messages, modelConfig); try { - const isApp = !!getClientConfig()?.isApp; - // const isApp = true; const bedrockAPIPath = `${BEDROCK_BASE_URL}/model/${ modelConfig.model }/invoke${shouldStream ? "-with-response-stream" : ""}`; const chatPath = isApp ? bedrockAPIPath : ApiPath.Bedrock + "/chat"; - const headers = isApp - ? await sign({ - method: "POST", - url: chatPath, - region: accessStore.awsRegion, - accessKeyId: accessStore.awsAccessKey, - secretAccessKey: accessStore.awsSecretKey, - body: finalRequestBody, - service: "bedrock", - isStreaming: shouldStream, - }) - : getHeaders(); - - if (!isApp) { - headers.XModelID = modelConfig.model; - headers.XEncryptionKey = accessStore.encryptionKey; - headers.ShouldStream = shouldStream + ""; - } - if (process.env.NODE_ENV !== "production") { console.debug("[Bedrock Client] Request:", { path: chatPath, @@ -385,9 +425,9 @@ export class BedrockApi implements LLMApi { useChatStore.getState().currentSession().mask?.plugin || [], ); return bedrockStream( + modelConfig.model, chatPath, finalRequestBody, - headers, funcs, controller, // processToolMessage, include tool_calls message and tool call results @@ -513,9 +553,15 @@ export class BedrockApi implements LLMApi { try { controller.signal.onabort = () => options.onFinish("", new Response(null, { status: 400 })); + const newHeaders = await getBedrockHeaders( + modelConfig.model, + chatPath, + JSON.stringify(finalRequestBody), + shouldStream, + ); const res = await fetch(chatPath, { method: "POST", - headers: headers, + headers: newHeaders, body: JSON.stringify(finalRequestBody), }); const contentType = res.headers.get("content-type"); @@ -547,9 +593,9 @@ export class BedrockApi implements LLMApi { } function bedrockStream( + modelId: string, chatPath: string, requestPayload: any, - headers: any, funcs: Record, controller: AbortController, processToolMessage: ( @@ -655,7 +701,7 @@ function bedrockStream( setTimeout(() => { console.debug("[BedrockAPI for toolCallResult] restart"); running = false; - bedrockChatApi(chatPath, headers, requestPayload); + bedrockChatApi(modelId, chatPath, requestPayload, true); }, 60); }); } @@ -671,19 +717,26 @@ function bedrockStream( controller.signal.onabort = finish; async function bedrockChatApi( + modelId: string, chatPath: string, - headers: any, requestPayload: any, + shouldStream: boolean, ) { const requestTimeoutId = setTimeout( () => controller.abort(), REQUEST_TIMEOUT_MS, ); + const newHeaders = await getBedrockHeaders( + modelId, + chatPath, + JSON.stringify(requestPayload), + shouldStream, + ); try { const res = await fetch(chatPath, { method: "POST", - headers, + headers: newHeaders, body: JSON.stringify(requestPayload), redirect: "manual", // @ts-ignore @@ -792,5 +845,5 @@ function bedrockStream( } console.debug("[BedrockAPI] start"); - bedrockChatApi(chatPath, headers, requestPayload); + bedrockChatApi(modelId, chatPath, requestPayload, true); } diff --git a/app/components/ui-lib.tsx b/app/components/ui-lib.tsx index 4b7a4798a..14956103c 100644 --- a/app/components/ui-lib.tsx +++ b/app/components/ui-lib.tsx @@ -276,26 +276,18 @@ export function PasswordInput( }, ) { const [visible, setVisible] = useState(false); - const [displayValue, setDisplayValue] = useState(props.value as string); - const { maskWhenShow, ...inputProps } = props; - - useEffect(() => { - if (maskWhenShow && visible && props.value) { - setDisplayValue(maskSensitiveValue(props.value as string)); - } else { - setDisplayValue(props.value as string); - } - }, [visible, props.value, maskWhenShow]); + const [isEditing, setIsEditing] = useState(false); + const { maskWhenShow, onChange, value, ...inputProps } = props; function changeVisibility() { setVisible(!visible); } - const handleChange = (e: React.ChangeEvent) => { - if (props.onChange) { - props.onChange(e); - } - }; + // Get display value - use masked value only when showing and maskWhenShow is true and not editing + const displayValue = + maskWhenShow && visible && value && !isEditing + ? maskSensitiveValue(value as string) + : value; return (
@@ -308,7 +300,9 @@ export function PasswordInput( setIsEditing(true)} + onBlur={() => setIsEditing(false)} type={visible ? "text" : "password"} className={"password-input"} /> diff --git a/app/utils/aws.ts b/app/utils/aws.ts index f612a702a..ec781319d 100644 --- a/app/utils/aws.ts +++ b/app/utils/aws.ts @@ -1,9 +1,3 @@ -import SHA256 from "crypto-js/sha256"; -import HmacSHA256 from "crypto-js/hmac-sha256"; -import Hex from "crypto-js/enc-hex"; -import Utf8 from "crypto-js/enc-utf8"; -import { AES, enc, lib, PBKDF2, mode, pad, algo } from "crypto-js"; - // Types and Interfaces export interface BedrockCredentials { region: string; @@ -15,89 +9,128 @@ export interface BedrockCredentials { type ParsedEvent = Record; type EventResult = ParsedEvent[]; -// Encryption utilities -function generateSalt(): string { - const salt = lib.WordArray.random(128 / 8); - return salt.toString(enc.Base64); -} - -function generateIV(): string { - const iv = lib.WordArray.random(128 / 8); - return iv.toString(enc.Base64); -} - -function deriveKey(password: string, salt: string): lib.WordArray { - // Use PBKDF2 with SHA256 for key derivation - return PBKDF2(password, salt, { - keySize: 256 / 32, - iterations: 10000, - hasher: algo.SHA256, - }); -} - // Using a dot as separator since it's not used in Base64 const SEPARATOR = "."; -export function encrypt(data: string, encryptionKey: string): string { +// Unified crypto utilities for both frontend and backend +async function generateKey( + password: string, + salt: Uint8Array, +): Promise { + const enc = new TextEncoder(); + const keyMaterial = await crypto.subtle.importKey( + "raw", + enc.encode(password), + { name: "PBKDF2" }, + false, + ["deriveBits", "deriveKey"], + ); + + return crypto.subtle.deriveKey( + { + name: "PBKDF2", + salt, + iterations: 100000, + hash: "SHA-256", + }, + keyMaterial, + { name: "AES-GCM", length: 256 }, + false, + ["encrypt", "decrypt"], + ); +} + +function arrayBufferToBase64(buffer: ArrayBuffer | Uint8Array): string { + const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer); + return btoa(String.fromCharCode(...bytes)); +} + +function base64ToArrayBuffer(base64: string): Uint8Array { + const binaryString = atob(base64); + const bytes = new Uint8Array(binaryString.length); + for (let i = 0; i < binaryString.length; i++) { + bytes[i] = binaryString.charCodeAt(i); + } + return bytes; +} + +export async function encrypt( + data: string, + encryptionKey: string, +): Promise { if (!data) return ""; if (!encryptionKey) { throw new Error("Encryption key is required for AWS credential encryption"); } + try { - // Generate salt and IV - const salt = generateSalt(); - const iv = generateIV(); + const enc = new TextEncoder(); + const salt = crypto.getRandomValues(new Uint8Array(16)); + const iv = crypto.getRandomValues(new Uint8Array(12)); + const key = await generateKey(encryptionKey, salt); - // Derive key using PBKDF2 - const key = deriveKey(encryptionKey, salt); + const encrypted = await crypto.subtle.encrypt( + { + name: "AES-GCM", + iv, + }, + key, + enc.encode(data), + ); - // Encrypt the data - const encrypted = AES.encrypt(data, key, { - iv: enc.Base64.parse(iv), - mode: mode.CBC, - padding: pad.Pkcs7, - }); + // Convert to base64 strings + const encryptedBase64 = arrayBufferToBase64(encrypted); + const saltBase64 = arrayBufferToBase64(salt); + const ivBase64 = arrayBufferToBase64(iv); - // Combine salt, IV, and encrypted data - // Format: salt.iv.encryptedData - return [salt, iv, encrypted.toString()].join(SEPARATOR); + return [saltBase64, ivBase64, encryptedBase64].join(SEPARATOR); } catch (error) { + console.error("[Encryption Error]:", error); throw new Error("Failed to encrypt AWS credentials"); } } -export function decrypt(encryptedData: string, encryptionKey: string): string { +export async function decrypt( + encryptedData: string, + encryptionKey: string, +): Promise { if (!encryptedData) return ""; if (!encryptionKey) { throw new Error("Encryption key is required for AWS credential decryption"); } - try { - let components = encryptedData.split(SEPARATOR); - const [salt, iv, data] = components; - // For new format, use the provided salt and IV - const key = deriveKey(encryptionKey, salt); - const decrypted = AES.decrypt(data, key, { - iv: enc.Base64.parse(iv), - mode: mode.CBC, - padding: pad.Pkcs7, - }); - const result = decrypted.toString(enc.Utf8); - if (!result) { - throw new Error("Failed to decrypt AWS credentials"); - } - return result; + try { + const [saltBase64, ivBase64, cipherBase64] = encryptedData.split(SEPARATOR); + + // Convert base64 strings back to Uint8Arrays + const salt = base64ToArrayBuffer(saltBase64); + const iv = base64ToArrayBuffer(ivBase64); + const cipherData = base64ToArrayBuffer(cipherBase64); + + const key = await generateKey(encryptionKey, salt); + + const decrypted = await crypto.subtle.decrypt( + { + name: "AES-GCM", + iv, + }, + key, + cipherData, + ); + + const dec = new TextDecoder(); + return dec.decode(decrypted); } catch (error) { + console.error("[Decryption Error]:", error); throw new Error("Failed to decrypt AWS credentials"); } } export function maskSensitiveValue(value: string): string { if (!value) return ""; - if (value.length <= 4) return value; - // Use constant-time operations to prevent timing attacks - const masked = Buffer.alloc(value.length - 4, "*").toString(); - return value.slice(0, 2) + masked + value.slice(-2); + if (value.length <= 6) return value; + const masked = "*".repeat(value.length - 6); + return value.slice(0, 3) + masked + value.slice(-3); } // AWS Signing @@ -113,26 +146,33 @@ export interface SignParams { isStreaming?: boolean; } -function hmac( - key: string | CryptoJS.lib.WordArray, +async function createHmac( + key: ArrayBuffer | Uint8Array, data: string, -): CryptoJS.lib.WordArray { - if (typeof key === "string") { - key = Utf8.parse(key); - } - return HmacSHA256(data, key); +): Promise { + const encoder = new TextEncoder(); + const keyData = key instanceof Uint8Array ? key : new Uint8Array(key); + const keyObject = await crypto.subtle.importKey( + "raw", + keyData, + { name: "HMAC", hash: "SHA-256" }, + false, + ["sign"], + ); + return crypto.subtle.sign("HMAC", keyObject, encoder.encode(data)); } -function getSigningKey( +async function getSigningKey( secretKey: string, dateStamp: string, region: string, service: string, -): CryptoJS.lib.WordArray { - const kDate = hmac("AWS4" + secretKey, dateStamp); - const kRegion = hmac(kDate, region); - const kService = hmac(kRegion, service); - const kSigning = hmac(kService, "aws4_request"); +): Promise { + const encoder = new TextEncoder(); + const kDate = await createHmac(encoder.encode("AWS4" + secretKey), dateStamp); + const kRegion = await createHmac(kDate, region); + const kService = await createHmac(kRegion, service); + const kSigning = await createHmac(kService, "aws4_request"); return kSigning; } @@ -202,7 +242,14 @@ export async function sign({ const dateStamp = amzDate.slice(0, 8); const bodyString = typeof body === "string" ? body : JSON.stringify(body); - const payloadHash = SHA256(bodyString).toString(Hex); + const encoder = new TextEncoder(); + const payloadBuffer = await crypto.subtle.digest( + "SHA-256", + encoder.encode(bodyString), + ); + const payloadHash = Array.from(new Uint8Array(payloadBuffer)) + .map((b) => b.toString(16).padStart(2, "0")) + .join(""); const headers: Record = { accept: isStreaming @@ -215,6 +262,7 @@ export async function sign({ ...customHeaders, }; + // Add x-amzn-bedrock-accept header for streaming requests if (isStreaming) { headers["x-amzn-bedrock-accept"] = "*/*"; } @@ -244,20 +292,34 @@ export async function sign({ const algorithm = "AWS4-HMAC-SHA256"; const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`; + + const canonicalRequestHash = Array.from( + new Uint8Array( + await crypto.subtle.digest("SHA-256", encoder.encode(canonicalRequest)), + ), + ) + .map((b) => b.toString(16).padStart(2, "0")) + .join(""); + const stringToSign = [ algorithm, amzDate, credentialScope, - SHA256(canonicalRequest).toString(Hex), + canonicalRequestHash, ].join("\n"); - const signingKey = getSigningKey( + const signingKey = await getSigningKey( secretAccessKey, dateStamp, region, service, ); - const signature = hmac(signingKey, stringToSign).toString(Hex); + + const signature = Array.from( + new Uint8Array(await createHmac(signingKey, stringToSign)), + ) + .map((b) => b.toString(16).padStart(2, "0")) + .join(""); const authorization = [ `${algorithm} Credential=${accessKeyId}/${credentialScope}`, @@ -278,7 +340,9 @@ export async function sign({ // Bedrock utilities function decodeBase64(base64String: string): string { try { - return Buffer.from(base64String, "base64").toString("utf-8"); + const bytes = Buffer.from(base64String, "base64"); + const decoder = new TextDecoder("utf-8"); + return decoder.decode(bytes); } catch (e) { console.error("[Base64 Decode Error]:", e); return ""; @@ -286,7 +350,7 @@ function decodeBase64(base64String: string): string { } export function parseEventData(chunk: Uint8Array): EventResult { - const decoder = new TextDecoder(); + const decoder = new TextDecoder("utf-8"); const text = decoder.decode(chunk); const results: EventResult = []; diff --git a/package.json b/package.json index 57a63bcac..e53c5ba89 100644 --- a/package.json +++ b/package.json @@ -24,11 +24,9 @@ "@hello-pangea/dnd": "^16.5.0", "@next/third-parties": "^14.1.0", "@svgr/webpack": "^6.5.1", - "@types/crypto-js": "^4.2.2", "@vercel/analytics": "^0.1.11", "@vercel/speed-insights": "^1.0.2", "axios": "^1.7.5", - "crypto-js": "^4.2.0", "clsx": "^2.1.1", "emoji-picker-react": "^4.9.2", "fuse.js": "^7.0.0", @@ -93,4 +91,4 @@ "lint-staged/yaml": "^2.2.2" }, "packageManager": "yarn@1.22.19" -} +} \ No newline at end of file