Enhance encryption security with additional safeguards.

This commit is contained in:
glay 2024-12-08 23:28:59 +08:00
parent 26b9fa97cd
commit f5ae086d3c
6 changed files with 245 additions and 141 deletions

View File

@ -32,14 +32,20 @@ async function getBedrockCredentials(
const [encryptedRegion, encryptedAccessKey, encryptedSecretKey] = const [encryptedRegion, encryptedAccessKey, encryptedSecretKey] =
credentials.split(":"); credentials.split(":");
// console.log("===========encryptedRegion",encryptedRegion);
// console.log("===========encryptedAccessKey",encryptedAccessKey);
// console.log("===========encryptedSecretKey",encryptedSecretKey);
if (!encryptedRegion || !encryptedAccessKey || !encryptedSecretKey) { if (!encryptedRegion || !encryptedAccessKey || !encryptedSecretKey) {
throw new Error("Invalid Authorization header format"); throw new Error("Invalid Authorization header format");
} }
const encryptionKey = req.headers.get("XEncryptionKey") || ""; const encryptionKey = req.headers.get("XEncryptionKey") || "";
// console.log("===========encryptionKey",encryptionKey);
// Decrypt the credentials // Decrypt the credentials
awsRegion = decrypt(encryptedRegion, encryptionKey); [awsRegion, awsAccessKey, awsSecretKey] = await Promise.all([
awsAccessKey = decrypt(encryptedAccessKey, encryptionKey); decrypt(encryptedRegion, encryptionKey),
awsSecretKey = decrypt(encryptedSecretKey, encryptionKey); decrypt(encryptedAccessKey, encryptionKey),
decrypt(encryptedSecretKey, encryptionKey),
]);
if (!awsRegion || !awsAccessKey || !awsSecretKey) { if (!awsRegion || !awsAccessKey || !awsSecretKey) {
throw new Error( throw new Error(

View File

@ -23,7 +23,6 @@ import { SparkApi } from "./platforms/iflytek";
import { XAIApi } from "./platforms/xai"; import { XAIApi } from "./platforms/xai";
import { ChatGLMApi } from "./platforms/glm"; import { ChatGLMApi } from "./platforms/glm";
import { BedrockApi } from "./platforms/bedrock"; import { BedrockApi } from "./platforms/bedrock";
import { encrypt } from "../utils/aws";
export const ROLES = ["system", "user", "assistant"] as const; export const ROLES = ["system", "user", "assistant"] as const;
export type MessageRole = (typeof ROLES)[number]; export type MessageRole = (typeof ROLES)[number];
@ -276,16 +275,6 @@ export function getHeaders(ignoreHeaders: boolean = false) {
? accessStore.iflytekApiKey && accessStore.iflytekApiSecret ? accessStore.iflytekApiKey && accessStore.iflytekApiSecret
? accessStore.iflytekApiKey + ":" + accessStore.iflytekApiSecret ? accessStore.iflytekApiKey + ":" + accessStore.iflytekApiSecret
: "" : ""
: isBedrock
? accessStore.awsRegion &&
accessStore.awsAccessKey &&
accessStore.awsSecretKey
? encrypt(accessStore.awsRegion, accessStore.encryptionKey) +
":" +
encrypt(accessStore.awsAccessKey, accessStore.encryptionKey) +
":" +
encrypt(accessStore.awsSecretKey, accessStore.encryptionKey)
: ""
: accessStore.openaiApiKey; : accessStore.openaiApiKey;
return { return {
isBedrock, isBedrock,

View File

@ -20,6 +20,7 @@ import {
} from "@/app/utils/aws"; } from "@/app/utils/aws";
import { prettyObject } from "@/app/utils/format"; import { prettyObject } from "@/app/utils/format";
import Locale from "@/app/locales"; import Locale from "@/app/locales";
import { encrypt } from "@/app/utils/aws";
const ClaudeMapper = { const ClaudeMapper = {
assistant: "assistant", assistant: "assistant",
@ -41,6 +42,66 @@ interface Tool {
parameters?: any; parameters?: any;
}; };
} }
const isApp = !!getClientConfig()?.isApp;
// const isApp = true;
async function getBedrockHeaders(
modelId: string,
chatPath: string,
finalRequestBody: any,
shouldStream: boolean,
): Promise<Record<string, string>> {
const accessStore = useAccessStore.getState();
const bedrockHeaders = isApp
? await sign({
method: "POST",
url: chatPath,
region: accessStore.awsRegion,
accessKeyId: accessStore.awsAccessKey,
secretAccessKey: accessStore.awsSecretKey,
body: finalRequestBody,
service: "bedrock",
headers: {},
isStreaming: shouldStream,
})
: getHeaders();
if (!isApp) {
const { awsRegion, awsAccessKey, awsSecretKey, encryptionKey } =
accessStore;
const bedrockHeadersConfig = {
XModelID: modelId,
XEncryptionKey: encryptionKey,
ShouldStream: String(shouldStream),
Authorization: await createAuthHeader(
awsRegion,
awsAccessKey,
awsSecretKey,
encryptionKey,
),
};
Object.assign(bedrockHeaders, bedrockHeadersConfig);
}
return bedrockHeaders;
}
// Helper function to create Authorization header
async function createAuthHeader(
region: string,
accessKey: string,
secretKey: string,
encryptionKey: string,
): Promise<string> {
const encryptedValues = await Promise.all([
encrypt(region, encryptionKey),
encrypt(accessKey, encryptionKey),
encrypt(secretKey, encryptionKey),
]);
return `Bearer ${encryptedValues.join(":")}`;
}
export class BedrockApi implements LLMApi { export class BedrockApi implements LLMApi {
speech(options: SpeechOptions): Promise<ArrayBuffer> { speech(options: SpeechOptions): Promise<ArrayBuffer> {
@ -343,32 +404,11 @@ export class BedrockApi implements LLMApi {
let finalRequestBody = this.formatRequestBody(messages, modelConfig); let finalRequestBody = this.formatRequestBody(messages, modelConfig);
try { try {
const isApp = !!getClientConfig()?.isApp;
// const isApp = true;
const bedrockAPIPath = `${BEDROCK_BASE_URL}/model/${ const bedrockAPIPath = `${BEDROCK_BASE_URL}/model/${
modelConfig.model modelConfig.model
}/invoke${shouldStream ? "-with-response-stream" : ""}`; }/invoke${shouldStream ? "-with-response-stream" : ""}`;
const chatPath = isApp ? bedrockAPIPath : ApiPath.Bedrock + "/chat"; const chatPath = isApp ? bedrockAPIPath : ApiPath.Bedrock + "/chat";
const headers = isApp
? await sign({
method: "POST",
url: chatPath,
region: accessStore.awsRegion,
accessKeyId: accessStore.awsAccessKey,
secretAccessKey: accessStore.awsSecretKey,
body: finalRequestBody,
service: "bedrock",
isStreaming: shouldStream,
})
: getHeaders();
if (!isApp) {
headers.XModelID = modelConfig.model;
headers.XEncryptionKey = accessStore.encryptionKey;
headers.ShouldStream = shouldStream + "";
}
if (process.env.NODE_ENV !== "production") { if (process.env.NODE_ENV !== "production") {
console.debug("[Bedrock Client] Request:", { console.debug("[Bedrock Client] Request:", {
path: chatPath, path: chatPath,
@ -385,9 +425,9 @@ export class BedrockApi implements LLMApi {
useChatStore.getState().currentSession().mask?.plugin || [], useChatStore.getState().currentSession().mask?.plugin || [],
); );
return bedrockStream( return bedrockStream(
modelConfig.model,
chatPath, chatPath,
finalRequestBody, finalRequestBody,
headers,
funcs, funcs,
controller, controller,
// processToolMessage, include tool_calls message and tool call results // processToolMessage, include tool_calls message and tool call results
@ -513,9 +553,15 @@ export class BedrockApi implements LLMApi {
try { try {
controller.signal.onabort = () => controller.signal.onabort = () =>
options.onFinish("", new Response(null, { status: 400 })); options.onFinish("", new Response(null, { status: 400 }));
const newHeaders = await getBedrockHeaders(
modelConfig.model,
chatPath,
JSON.stringify(finalRequestBody),
shouldStream,
);
const res = await fetch(chatPath, { const res = await fetch(chatPath, {
method: "POST", method: "POST",
headers: headers, headers: newHeaders,
body: JSON.stringify(finalRequestBody), body: JSON.stringify(finalRequestBody),
}); });
const contentType = res.headers.get("content-type"); const contentType = res.headers.get("content-type");
@ -547,9 +593,9 @@ export class BedrockApi implements LLMApi {
} }
function bedrockStream( function bedrockStream(
modelId: string,
chatPath: string, chatPath: string,
requestPayload: any, requestPayload: any,
headers: any,
funcs: Record<string, Function>, funcs: Record<string, Function>,
controller: AbortController, controller: AbortController,
processToolMessage: ( processToolMessage: (
@ -655,7 +701,7 @@ function bedrockStream(
setTimeout(() => { setTimeout(() => {
console.debug("[BedrockAPI for toolCallResult] restart"); console.debug("[BedrockAPI for toolCallResult] restart");
running = false; running = false;
bedrockChatApi(chatPath, headers, requestPayload); bedrockChatApi(modelId, chatPath, requestPayload, true);
}, 60); }, 60);
}); });
} }
@ -671,19 +717,26 @@ function bedrockStream(
controller.signal.onabort = finish; controller.signal.onabort = finish;
async function bedrockChatApi( async function bedrockChatApi(
modelId: string,
chatPath: string, chatPath: string,
headers: any,
requestPayload: any, requestPayload: any,
shouldStream: boolean,
) { ) {
const requestTimeoutId = setTimeout( const requestTimeoutId = setTimeout(
() => controller.abort(), () => controller.abort(),
REQUEST_TIMEOUT_MS, REQUEST_TIMEOUT_MS,
); );
const newHeaders = await getBedrockHeaders(
modelId,
chatPath,
JSON.stringify(requestPayload),
shouldStream,
);
try { try {
const res = await fetch(chatPath, { const res = await fetch(chatPath, {
method: "POST", method: "POST",
headers, headers: newHeaders,
body: JSON.stringify(requestPayload), body: JSON.stringify(requestPayload),
redirect: "manual", redirect: "manual",
// @ts-ignore // @ts-ignore
@ -792,5 +845,5 @@ function bedrockStream(
} }
console.debug("[BedrockAPI] start"); console.debug("[BedrockAPI] start");
bedrockChatApi(chatPath, headers, requestPayload); bedrockChatApi(modelId, chatPath, requestPayload, true);
} }

View File

@ -276,26 +276,18 @@ export function PasswordInput(
}, },
) { ) {
const [visible, setVisible] = useState(false); const [visible, setVisible] = useState(false);
const [displayValue, setDisplayValue] = useState(props.value as string); const [isEditing, setIsEditing] = useState(false);
const { maskWhenShow, ...inputProps } = props; const { maskWhenShow, onChange, value, ...inputProps } = props;
useEffect(() => {
if (maskWhenShow && visible && props.value) {
setDisplayValue(maskSensitiveValue(props.value as string));
} else {
setDisplayValue(props.value as string);
}
}, [visible, props.value, maskWhenShow]);
function changeVisibility() { function changeVisibility() {
setVisible(!visible); setVisible(!visible);
} }
const handleChange = (e: React.ChangeEvent<HTMLInputElement>) => { // Get display value - use masked value only when showing and maskWhenShow is true and not editing
if (props.onChange) { const displayValue =
props.onChange(e); maskWhenShow && visible && value && !isEditing
} ? maskSensitiveValue(value as string)
}; : value;
return ( return (
<div className={"password-input-container"}> <div className={"password-input-container"}>
@ -308,7 +300,9 @@ export function PasswordInput(
<input <input
{...inputProps} {...inputProps}
value={displayValue} value={displayValue}
onChange={handleChange} onChange={onChange}
onFocus={() => setIsEditing(true)}
onBlur={() => setIsEditing(false)}
type={visible ? "text" : "password"} type={visible ? "text" : "password"}
className={"password-input"} className={"password-input"}
/> />

View File

@ -1,9 +1,3 @@
import SHA256 from "crypto-js/sha256";
import HmacSHA256 from "crypto-js/hmac-sha256";
import Hex from "crypto-js/enc-hex";
import Utf8 from "crypto-js/enc-utf8";
import { AES, enc, lib, PBKDF2, mode, pad, algo } from "crypto-js";
// Types and Interfaces // Types and Interfaces
export interface BedrockCredentials { export interface BedrockCredentials {
region: string; region: string;
@ -15,89 +9,128 @@ export interface BedrockCredentials {
type ParsedEvent = Record<string, any>; type ParsedEvent = Record<string, any>;
type EventResult = ParsedEvent[]; type EventResult = ParsedEvent[];
// Encryption utilities
function generateSalt(): string {
const salt = lib.WordArray.random(128 / 8);
return salt.toString(enc.Base64);
}
function generateIV(): string {
const iv = lib.WordArray.random(128 / 8);
return iv.toString(enc.Base64);
}
function deriveKey(password: string, salt: string): lib.WordArray {
// Use PBKDF2 with SHA256 for key derivation
return PBKDF2(password, salt, {
keySize: 256 / 32,
iterations: 10000,
hasher: algo.SHA256,
});
}
// Using a dot as separator since it's not used in Base64 // Using a dot as separator since it's not used in Base64
const SEPARATOR = "."; const SEPARATOR = ".";
export function encrypt(data: string, encryptionKey: string): string { // Unified crypto utilities for both frontend and backend
async function generateKey(
password: string,
salt: Uint8Array,
): Promise<CryptoKey> {
const enc = new TextEncoder();
const keyMaterial = await crypto.subtle.importKey(
"raw",
enc.encode(password),
{ name: "PBKDF2" },
false,
["deriveBits", "deriveKey"],
);
return crypto.subtle.deriveKey(
{
name: "PBKDF2",
salt,
iterations: 100000,
hash: "SHA-256",
},
keyMaterial,
{ name: "AES-GCM", length: 256 },
false,
["encrypt", "decrypt"],
);
}
function arrayBufferToBase64(buffer: ArrayBuffer | Uint8Array): string {
const bytes = buffer instanceof Uint8Array ? buffer : new Uint8Array(buffer);
return btoa(String.fromCharCode(...bytes));
}
function base64ToArrayBuffer(base64: string): Uint8Array {
const binaryString = atob(base64);
const bytes = new Uint8Array(binaryString.length);
for (let i = 0; i < binaryString.length; i++) {
bytes[i] = binaryString.charCodeAt(i);
}
return bytes;
}
export async function encrypt(
data: string,
encryptionKey: string,
): Promise<string> {
if (!data) return ""; if (!data) return "";
if (!encryptionKey) { if (!encryptionKey) {
throw new Error("Encryption key is required for AWS credential encryption"); throw new Error("Encryption key is required for AWS credential encryption");
} }
try { try {
// Generate salt and IV const enc = new TextEncoder();
const salt = generateSalt(); const salt = crypto.getRandomValues(new Uint8Array(16));
const iv = generateIV(); const iv = crypto.getRandomValues(new Uint8Array(12));
const key = await generateKey(encryptionKey, salt);
// Derive key using PBKDF2 const encrypted = await crypto.subtle.encrypt(
const key = deriveKey(encryptionKey, salt); {
name: "AES-GCM",
iv,
},
key,
enc.encode(data),
);
// Encrypt the data // Convert to base64 strings
const encrypted = AES.encrypt(data, key, { const encryptedBase64 = arrayBufferToBase64(encrypted);
iv: enc.Base64.parse(iv), const saltBase64 = arrayBufferToBase64(salt);
mode: mode.CBC, const ivBase64 = arrayBufferToBase64(iv);
padding: pad.Pkcs7,
});
// Combine salt, IV, and encrypted data return [saltBase64, ivBase64, encryptedBase64].join(SEPARATOR);
// Format: salt.iv.encryptedData
return [salt, iv, encrypted.toString()].join(SEPARATOR);
} catch (error) { } catch (error) {
console.error("[Encryption Error]:", error);
throw new Error("Failed to encrypt AWS credentials"); throw new Error("Failed to encrypt AWS credentials");
} }
} }
export function decrypt(encryptedData: string, encryptionKey: string): string { export async function decrypt(
encryptedData: string,
encryptionKey: string,
): Promise<string> {
if (!encryptedData) return ""; if (!encryptedData) return "";
if (!encryptionKey) { if (!encryptionKey) {
throw new Error("Encryption key is required for AWS credential decryption"); throw new Error("Encryption key is required for AWS credential decryption");
} }
try {
let components = encryptedData.split(SEPARATOR);
const [salt, iv, data] = components;
// For new format, use the provided salt and IV
const key = deriveKey(encryptionKey, salt);
const decrypted = AES.decrypt(data, key, {
iv: enc.Base64.parse(iv),
mode: mode.CBC,
padding: pad.Pkcs7,
});
const result = decrypted.toString(enc.Utf8); try {
if (!result) { const [saltBase64, ivBase64, cipherBase64] = encryptedData.split(SEPARATOR);
throw new Error("Failed to decrypt AWS credentials");
} // Convert base64 strings back to Uint8Arrays
return result; const salt = base64ToArrayBuffer(saltBase64);
const iv = base64ToArrayBuffer(ivBase64);
const cipherData = base64ToArrayBuffer(cipherBase64);
const key = await generateKey(encryptionKey, salt);
const decrypted = await crypto.subtle.decrypt(
{
name: "AES-GCM",
iv,
},
key,
cipherData,
);
const dec = new TextDecoder();
return dec.decode(decrypted);
} catch (error) { } catch (error) {
console.error("[Decryption Error]:", error);
throw new Error("Failed to decrypt AWS credentials"); throw new Error("Failed to decrypt AWS credentials");
} }
} }
export function maskSensitiveValue(value: string): string { export function maskSensitiveValue(value: string): string {
if (!value) return ""; if (!value) return "";
if (value.length <= 4) return value; if (value.length <= 6) return value;
// Use constant-time operations to prevent timing attacks const masked = "*".repeat(value.length - 6);
const masked = Buffer.alloc(value.length - 4, "*").toString(); return value.slice(0, 3) + masked + value.slice(-3);
return value.slice(0, 2) + masked + value.slice(-2);
} }
// AWS Signing // AWS Signing
@ -113,26 +146,33 @@ export interface SignParams {
isStreaming?: boolean; isStreaming?: boolean;
} }
function hmac( async function createHmac(
key: string | CryptoJS.lib.WordArray, key: ArrayBuffer | Uint8Array,
data: string, data: string,
): CryptoJS.lib.WordArray { ): Promise<ArrayBuffer> {
if (typeof key === "string") { const encoder = new TextEncoder();
key = Utf8.parse(key); const keyData = key instanceof Uint8Array ? key : new Uint8Array(key);
} const keyObject = await crypto.subtle.importKey(
return HmacSHA256(data, key); "raw",
keyData,
{ name: "HMAC", hash: "SHA-256" },
false,
["sign"],
);
return crypto.subtle.sign("HMAC", keyObject, encoder.encode(data));
} }
function getSigningKey( async function getSigningKey(
secretKey: string, secretKey: string,
dateStamp: string, dateStamp: string,
region: string, region: string,
service: string, service: string,
): CryptoJS.lib.WordArray { ): Promise<ArrayBuffer> {
const kDate = hmac("AWS4" + secretKey, dateStamp); const encoder = new TextEncoder();
const kRegion = hmac(kDate, region); const kDate = await createHmac(encoder.encode("AWS4" + secretKey), dateStamp);
const kService = hmac(kRegion, service); const kRegion = await createHmac(kDate, region);
const kSigning = hmac(kService, "aws4_request"); const kService = await createHmac(kRegion, service);
const kSigning = await createHmac(kService, "aws4_request");
return kSigning; return kSigning;
} }
@ -202,7 +242,14 @@ export async function sign({
const dateStamp = amzDate.slice(0, 8); const dateStamp = amzDate.slice(0, 8);
const bodyString = typeof body === "string" ? body : JSON.stringify(body); const bodyString = typeof body === "string" ? body : JSON.stringify(body);
const payloadHash = SHA256(bodyString).toString(Hex); const encoder = new TextEncoder();
const payloadBuffer = await crypto.subtle.digest(
"SHA-256",
encoder.encode(bodyString),
);
const payloadHash = Array.from(new Uint8Array(payloadBuffer))
.map((b) => b.toString(16).padStart(2, "0"))
.join("");
const headers: Record<string, string> = { const headers: Record<string, string> = {
accept: isStreaming accept: isStreaming
@ -215,6 +262,7 @@ export async function sign({
...customHeaders, ...customHeaders,
}; };
// Add x-amzn-bedrock-accept header for streaming requests
if (isStreaming) { if (isStreaming) {
headers["x-amzn-bedrock-accept"] = "*/*"; headers["x-amzn-bedrock-accept"] = "*/*";
} }
@ -244,20 +292,34 @@ export async function sign({
const algorithm = "AWS4-HMAC-SHA256"; const algorithm = "AWS4-HMAC-SHA256";
const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`; const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`;
const canonicalRequestHash = Array.from(
new Uint8Array(
await crypto.subtle.digest("SHA-256", encoder.encode(canonicalRequest)),
),
)
.map((b) => b.toString(16).padStart(2, "0"))
.join("");
const stringToSign = [ const stringToSign = [
algorithm, algorithm,
amzDate, amzDate,
credentialScope, credentialScope,
SHA256(canonicalRequest).toString(Hex), canonicalRequestHash,
].join("\n"); ].join("\n");
const signingKey = getSigningKey( const signingKey = await getSigningKey(
secretAccessKey, secretAccessKey,
dateStamp, dateStamp,
region, region,
service, service,
); );
const signature = hmac(signingKey, stringToSign).toString(Hex);
const signature = Array.from(
new Uint8Array(await createHmac(signingKey, stringToSign)),
)
.map((b) => b.toString(16).padStart(2, "0"))
.join("");
const authorization = [ const authorization = [
`${algorithm} Credential=${accessKeyId}/${credentialScope}`, `${algorithm} Credential=${accessKeyId}/${credentialScope}`,
@ -278,7 +340,9 @@ export async function sign({
// Bedrock utilities // Bedrock utilities
function decodeBase64(base64String: string): string { function decodeBase64(base64String: string): string {
try { try {
return Buffer.from(base64String, "base64").toString("utf-8"); const bytes = Buffer.from(base64String, "base64");
const decoder = new TextDecoder("utf-8");
return decoder.decode(bytes);
} catch (e) { } catch (e) {
console.error("[Base64 Decode Error]:", e); console.error("[Base64 Decode Error]:", e);
return ""; return "";
@ -286,7 +350,7 @@ function decodeBase64(base64String: string): string {
} }
export function parseEventData(chunk: Uint8Array): EventResult { export function parseEventData(chunk: Uint8Array): EventResult {
const decoder = new TextDecoder(); const decoder = new TextDecoder("utf-8");
const text = decoder.decode(chunk); const text = decoder.decode(chunk);
const results: EventResult = []; const results: EventResult = [];

View File

@ -24,11 +24,9 @@
"@hello-pangea/dnd": "^16.5.0", "@hello-pangea/dnd": "^16.5.0",
"@next/third-parties": "^14.1.0", "@next/third-parties": "^14.1.0",
"@svgr/webpack": "^6.5.1", "@svgr/webpack": "^6.5.1",
"@types/crypto-js": "^4.2.2",
"@vercel/analytics": "^0.1.11", "@vercel/analytics": "^0.1.11",
"@vercel/speed-insights": "^1.0.2", "@vercel/speed-insights": "^1.0.2",
"axios": "^1.7.5", "axios": "^1.7.5",
"crypto-js": "^4.2.0",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"emoji-picker-react": "^4.9.2", "emoji-picker-react": "^4.9.2",
"fuse.js": "^7.0.0", "fuse.js": "^7.0.0",
@ -93,4 +91,4 @@
"lint-staged/yaml": "^2.2.2" "lint-staged/yaml": "^2.2.2"
}, },
"packageManager": "yarn@1.22.19" "packageManager": "yarn@1.22.19"
} }