diff --git a/app/api/bedrock.ts b/app/api/bedrock.ts index c8709722e..1417e9475 100644 --- a/app/api/bedrock.ts +++ b/app/api/bedrock.ts @@ -1,131 +1,186 @@ -import { getServerSideConfig } from "../config/server"; -import { prettyObject } from "../utils/format"; import { NextRequest, NextResponse } from "next/server"; -import { decrypt } from "../utils/encryption"; -import { - BedrockRuntimeClient, - ConverseStreamCommand, - ConverseStreamCommandInput, - Message, - ContentBlock, - ConverseStreamOutput, -} from "@aws-sdk/client-bedrock-runtime"; +import { sign, decrypt } from "../utils/aws"; -const ALLOWED_PATH = new Set(["converse"]); +const ALLOWED_PATH = new Set(["chat", "models"]); -// AWS Credential Validation Function -function validateAwsCredentials( - region: string, - accessKeyId: string, - secretAccessKey: string, -): boolean { - const regionRegex = /^[a-z]{2}-[a-z]+-\d+$/; - const accessKeyRegex = /^(AKIA|A3T|ASIA)[A-Z0-9]{16}$/; - - return ( - regionRegex.test(region) && - accessKeyRegex.test(accessKeyId) && - secretAccessKey.length === 40 - ); +function parseEventData(chunk: Uint8Array): any { + const decoder = new TextDecoder(); + const text = decoder.decode(chunk); + try { + return JSON.parse(text); + } catch (e) { + try { + const base64Match = text.match(/:"([A-Za-z0-9+/=]+)"/); + if (base64Match) { + const decoded = Buffer.from(base64Match[1], "base64").toString("utf-8"); + return JSON.parse(decoded); + } + const eventMatch = text.match(/:event-type[^\{]+({.*})/); + if (eventMatch) { + return JSON.parse(eventMatch[1]); + } + } catch (innerError) {} + } + return null; } -export interface ConverseRequest { - modelId: string; - messages: { - role: "user" | "assistant" | "system"; - content: string | any[]; - }[]; - inferenceConfig?: { - maxTokens?: number; - temperature?: number; - topP?: number; - stopSequences?: string[]; - }; - tools?: { - name: string; - description?: string; - input_schema: any; - }[]; - stream?: boolean; +async function* transformBedrockStream(stream: ReadableStream) { + const reader = stream.getReader(); + try { + while (true) { + const { done, value } = await reader.read(); + if (done) break; + + const parsed = parseEventData(value); + if (parsed) { + if (parsed.type === "content_block_delta") { + if (parsed.delta?.type === "text_delta") { + yield `data: ${JSON.stringify({ + delta: { text: parsed.delta.text }, + })}\n\n`; + } else if (parsed.delta?.type === "input_json_delta") { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } + } else if ( + parsed.type === "message_delta" && + parsed.delta?.stop_reason + ) { + yield `data: ${JSON.stringify({ + delta: { stop_reason: parsed.delta.stop_reason }, + })}\n\n`; + } else if ( + parsed.type === "content_block_start" && + parsed.content_block?.type === "tool_use" + ) { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } else if (parsed.type === "content_block_stop") { + yield `data: ${JSON.stringify(parsed)}\n\n`; + } + } + } + } finally { + reader.releaseLock(); + } } -function supportsToolUse(modelId: string): boolean { - return modelId.toLowerCase().includes("claude-3"); +function validateRequest(body: any, modelId: string): void { + if (!modelId) throw new Error("Model ID is required"); + + if (modelId.startsWith("anthropic.claude")) { + if ( + !body.anthropic_version || + body.anthropic_version !== "bedrock-2023-05-31" + ) { + throw new Error("anthropic_version must be 'bedrock-2023-05-31'"); + } + if (typeof body.max_tokens !== "number" || body.max_tokens < 0) { + throw new Error("max_tokens must be a positive number"); + } + if (modelId.startsWith("anthropic.claude-3")) { + if (!Array.isArray(body.messages)) + throw new Error("messages array is required for Claude 3"); + } else if (typeof body.prompt !== "string") { + throw new Error("prompt is required for Claude 2 and earlier"); + } + } else if (modelId.startsWith("meta.llama")) { + if (!body.prompt) throw new Error("Llama requires a prompt"); + } else if (modelId.startsWith("mistral.mistral")) { + if (!Array.isArray(body.messages)) + throw new Error("Mistral requires a messages array"); + } else if (modelId.startsWith("amazon.titan")) { + if (!body.inputText) throw new Error("Titan requires inputText"); + } } -function formatRequestBody( - request: ConverseRequest, -): ConverseStreamCommandInput { - const messages: Message[] = request.messages.map((msg) => ({ - role: msg.role === "system" ? "user" : msg.role, - content: Array.isArray(msg.content) - ? msg.content.map((item) => { - if (item.type === "tool_use") { - return { - toolUse: { - toolUseId: item.id, - name: item.name, - input: item.input || "{}", - }, - } as ContentBlock; - } - if (item.type === "tool_result") { - return { - toolResult: { - toolUseId: item.tool_use_id, - content: [{ text: item.content || ";" }], - status: "success", - }, - } as ContentBlock; - } - if (item.type === "text") { - return { text: item.text || ";" } as ContentBlock; - } - if (item.type === "image") { - return { - image: { - format: item.source.media_type.split("/")[1] as - | "png" - | "jpeg" - | "gif" - | "webp", - source: { - bytes: Uint8Array.from( - Buffer.from(item.source.data, "base64"), - ), - }, - }, - } as ContentBlock; - } - return { text: ";" } as ContentBlock; - }) - : [{ text: msg.content || ";" } as ContentBlock], - })); +async function requestBedrock(req: NextRequest) { + const controller = new AbortController(); + const awsRegion = req.headers.get("X-Region") ?? ""; + const awsAccessKey = req.headers.get("X-Access-Key") ?? ""; + const awsSecretKey = req.headers.get("X-Secret-Key") ?? ""; + const awsSessionToken = req.headers.get("X-Session-Token"); + const modelId = req.headers.get("X-Model-Id") ?? ""; - const input: ConverseStreamCommandInput = { - modelId: request.modelId, - messages, - ...(request.inferenceConfig && { - inferenceConfig: request.inferenceConfig, - }), - }; - - if (request.tools?.length && supportsToolUse(request.modelId)) { - input.toolConfig = { - tools: request.tools.map((tool) => ({ - toolSpec: { - name: tool.name, - description: tool.description, - inputSchema: { - json: tool.input_schema, - }, - }, - })), - toolChoice: { auto: {} }, - }; + if (!awsRegion || !awsAccessKey || !awsSecretKey || !modelId) { + throw new Error("Missing required AWS credentials or model ID"); } - return input; + const decryptedAccessKey = decrypt(awsAccessKey); + const decryptedSecretKey = decrypt(awsSecretKey); + const decryptedSessionToken = awsSessionToken + ? decrypt(awsSessionToken) + : undefined; + + if (!decryptedAccessKey || !decryptedSecretKey) { + throw new Error("Failed to decrypt AWS credentials"); + } + + const endpoint = `https://bedrock-runtime.${awsRegion}.amazonaws.com/model/${modelId}/invoke-with-response-stream`; + const timeoutId = setTimeout(() => controller.abort(), 10 * 60 * 1000); + + try { + const bodyText = await req.clone().text(); + const bodyJson = JSON.parse(bodyText); + validateRequest(bodyJson, modelId); + const canonicalBody = JSON.stringify(bodyJson); + + const headers = await sign({ + method: "POST", + url: endpoint, + region: awsRegion, + accessKeyId: decryptedAccessKey, + secretAccessKey: decryptedSecretKey, + sessionToken: decryptedSessionToken, + body: canonicalBody, + service: "bedrock", + }); + + const res = await fetch(endpoint, { + method: "POST", + headers, + body: canonicalBody, + redirect: "manual", + // @ts-ignore + duplex: "half", + signal: controller.signal, + }); + + if (!res.ok) { + const error = await res.text(); + try { + const errorJson = JSON.parse(error); + throw new Error(errorJson.message || error); + } catch { + throw new Error(error); + } + } + + const transformedStream = transformBedrockStream(res.body!); + const stream = new ReadableStream({ + async start(controller) { + try { + for await (const chunk of transformedStream) { + controller.enqueue(new TextEncoder().encode(chunk)); + } + controller.close(); + } catch (err) { + controller.error(err); + } + }, + }); + + return new Response(stream, { + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + "X-Accel-Buffering": "no", + }, + }); + } catch (e) { + throw e; + } finally { + clearTimeout(timeoutId); + } } export async function handle( @@ -139,166 +194,16 @@ export async function handle( const subpath = params.path.join("/"); if (!ALLOWED_PATH.has(subpath)) { return NextResponse.json( - { error: true, msg: "Path not allowed: " + subpath }, + { error: true, msg: "you are not allowed to request " + subpath }, { status: 403 }, ); } - const serverConfig = getServerSideConfig(); - let region = serverConfig.awsRegion; - let accessKeyId = serverConfig.awsAccessKey; - let secretAccessKey = serverConfig.awsSecretKey; - let sessionToken = undefined; - - // Attempt to get credentials from headers if not in server config - if (!region || !accessKeyId || !secretAccessKey) { - region = decrypt(req.headers.get("X-Region") ?? ""); - accessKeyId = decrypt(req.headers.get("X-Access-Key") ?? ""); - secretAccessKey = decrypt(req.headers.get("X-Secret-Key") ?? ""); - sessionToken = req.headers.get("X-Session-Token") - ? decrypt(req.headers.get("X-Session-Token") ?? "") - : undefined; - } - - // Validate AWS credentials - if (!validateAwsCredentials(region, accessKeyId, secretAccessKey)) { - return NextResponse.json( - { - error: true, - msg: "Invalid AWS credentials. Please check your region, access key, and secret key.", - }, - { status: 401 }, - ); - } - try { - const client = new BedrockRuntimeClient({ - region, - credentials: { - accessKeyId, - secretAccessKey, - sessionToken, - }, - }); - - const body = (await req.json()) as ConverseRequest; - const command = new ConverseStreamCommand(formatRequestBody(body)); - const response = await client.send(command); - - if (!response.stream) { - throw new Error("No stream in response"); - } - - // If stream is false, accumulate the response and return as JSON - if (body.stream === false) { - let fullResponse = { - content: "", - }; - - const responseStream = - response.stream as AsyncIterable; - for await (const event of responseStream) { - if ( - "contentBlockDelta" in event && - event.contentBlockDelta?.delta && - "text" in event.contentBlockDelta.delta && - event.contentBlockDelta.delta.text - ) { - fullResponse.content += event.contentBlockDelta.delta.text; - } - } - - return NextResponse.json(fullResponse); - } - - // Otherwise, return streaming response - const stream = new ReadableStream({ - async start(controller) { - try { - const responseStream = - response.stream as AsyncIterable; - for await (const event of responseStream) { - if ( - "contentBlockStart" in event && - event.contentBlockStart?.start?.toolUse && - event.contentBlockStart.contentBlockIndex !== undefined - ) { - controller.enqueue( - `data: ${JSON.stringify({ - type: "content_block", - content_block: { - type: "tool_use", - id: event.contentBlockStart.start.toolUse.toolUseId, - name: event.contentBlockStart.start.toolUse.name, - }, - index: event.contentBlockStart.contentBlockIndex, - })}\n\n`, - ); - } else if ( - "contentBlockDelta" in event && - event.contentBlockDelta?.delta && - event.contentBlockDelta.contentBlockIndex !== undefined - ) { - const delta = event.contentBlockDelta.delta; - - if ("text" in delta && delta.text) { - controller.enqueue( - `data: ${JSON.stringify({ - type: "content_block_delta", - delta: { - type: "text_delta", - text: delta.text, - }, - index: event.contentBlockDelta.contentBlockIndex, - })}\n\n`, - ); - } else if ("toolUse" in delta && delta.toolUse?.input) { - controller.enqueue( - `data: ${JSON.stringify({ - type: "content_block_delta", - delta: { - type: "input_json_delta", - partial_json: delta.toolUse.input, - }, - index: event.contentBlockDelta.contentBlockIndex, - })}\n\n`, - ); - } - } else if ( - "contentBlockStop" in event && - event.contentBlockStop?.contentBlockIndex !== undefined - ) { - controller.enqueue( - `data: ${JSON.stringify({ - type: "content_block_stop", - index: event.contentBlockStop.contentBlockIndex, - })}\n\n`, - ); - } - } - controller.close(); - } catch (error) { - console.error("[Bedrock] Stream error:", error); - controller.error(error); - } - }, - }); - - return new Response(stream, { - headers: { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - Connection: "keep-alive", - }, - }); + return await requestBedrock(req); } catch (e) { - console.error("[Bedrock] Error:", e); return NextResponse.json( - { - error: true, - message: e instanceof Error ? e.message : "Unknown error", - details: prettyObject(e), - }, + { error: true, msg: e instanceof Error ? e.message : "Unknown error" }, { status: 500 }, ); } diff --git a/app/client/api.ts b/app/client/api.ts index 003cd8874..feb1c93a2 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -23,7 +23,7 @@ import { SparkApi } from "./platforms/iflytek"; import { XAIApi } from "./platforms/xai"; import { ChatGLMApi } from "./platforms/glm"; import { BedrockApi } from "./platforms/bedrock"; -import { encrypt } from "../utils/encryption"; +import { encrypt } from "../utils/aws"; export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; diff --git a/app/client/platforms/bedrock.ts b/app/client/platforms/bedrock.ts index 8dca4eff5..4c6371b17 100644 --- a/app/client/platforms/bedrock.ts +++ b/app/client/platforms/bedrock.ts @@ -1,30 +1,13 @@ -import { ApiPath } from "../../constant"; -import { ChatOptions, getHeaders, LLMApi, SpeechOptions } from "../api"; +import { ChatOptions, LLMApi, SpeechOptions } from "../api"; import { useAppConfig, usePluginStore, useChatStore, + useAccessStore, ChatMessageTool, } from "../../store"; -import { getMessageTextContent, isVisionModel } from "../../utils"; -import { fetch } from "../../utils/stream"; import { preProcessImageContent, stream } from "../../utils/chat"; -import { RequestPayload } from "./openai"; - -export type MultiBlockContent = { - type: "image" | "text"; - source?: { - type: string; - media_type: string; - data: string; - }; - text?: string; -}; - -export type AnthropicMessage = { - role: (typeof ClaudeMapper)[keyof typeof ClaudeMapper]; - content: string | MultiBlockContent[]; -}; +import { getMessageTextContent, isVisionModel } from "../../utils"; const ClaudeMapper = { assistant: "assistant", @@ -32,62 +15,52 @@ const ClaudeMapper = { system: "user", } as const; +interface ToolDefinition { + function?: { + name: string; + description?: string; + parameters?: any; + }; +} + export class BedrockApi implements LLMApi { speech(options: SpeechOptions): Promise { throw new Error("Speech not implemented for Bedrock."); } extractMessage(res: any) { - console.log("[Response] Bedrock not stream response: ", res); - if (res.error) { - return "```\n" + JSON.stringify(res, null, 4) + "\n```"; - } - return res?.content ?? res; + if (res?.content?.[0]?.text) return res.content[0].text; + if (res?.messages?.[0]?.content?.[0]?.text) + return res.messages[0].content[0].text; + if (res?.delta?.text) return res.delta.text; + return ""; } - async chat(options: ChatOptions): Promise { + async chat(options: ChatOptions) { const visionModel = isVisionModel(options.config.model); - const shouldStream = !!options.config.stream; + const isClaude3 = options.config.model.startsWith("anthropic.claude-3"); + const modelConfig = { ...useAppConfig.getState().modelConfig, ...useChatStore.getState().currentSession().mask.modelConfig, - ...{ - model: options.config.model, - }, + model: options.config.model, }; - // try get base64image from local cache image_url - const messages: ChatOptions["messages"] = []; - for (const v of options.messages) { - const content = await preProcessImageContent(v.content); - messages.push({ role: v.role, content }); - } - - const keys = ["system", "user"]; - - // roles must alternate between "user" and "assistant" in claude, so add a fake assistant message between two user messages - for (let i = 0; i < messages.length - 1; i++) { - const message = messages[i]; - const nextMessage = messages[i + 1]; - - if (keys.includes(message.role) && keys.includes(nextMessage.role)) { - messages[i] = [ - message, - { - role: "assistant", - content: ";", - }, - ] as any; + let systemMessage = ""; + const messages = []; + for (const msg of options.messages) { + const content = await preProcessImageContent(msg.content); + if (msg.role === "system") { + systemMessage = getMessageTextContent(msg); + } else { + messages.push({ role: msg.role, content }); } } - const prompt = messages - .flat() - .filter((v) => { - if (!v.content) return false; - if (typeof v.content === "string" && !v.content.trim()) return false; - return true; - }) + const formattedMessages = messages + .filter( + (v) => v.content && (typeof v.content !== "string" || v.content.trim()), + ) .map((v) => { const { role, content } = v; const insideRole = ClaudeMapper[role] ?? "user"; @@ -95,200 +68,201 @@ export class BedrockApi implements LLMApi { if (!visionModel || typeof content === "string") { return { role: insideRole, - content: getMessageTextContent(v), + content: [{ type: "text", text: getMessageTextContent(v) }], }; } + return { role: insideRole, content: content .filter((v) => v.image_url || v.text) .map(({ type, text, image_url }) => { - if (type === "text") { - return { - type, - text: text!, - }; - } + if (type === "text") return { type, text: text! }; + const { url = "" } = image_url || {}; const colonIndex = url.indexOf(":"); const semicolonIndex = url.indexOf(";"); const comma = url.indexOf(","); - const mimeType = url.slice(colonIndex + 1, semicolonIndex); - const encodeType = url.slice(semicolonIndex + 1, comma); - const data = url.slice(comma + 1); - return { - type: "image" as const, + type: "image", source: { - type: encodeType, - media_type: mimeType, - data, + type: url.slice(semicolonIndex + 1, comma), + media_type: url.slice(colonIndex + 1, semicolonIndex), + data: url.slice(comma + 1), }, }; }), }; }); - if (prompt[0]?.role === "assistant") { - prompt.unshift({ - role: "user", - content: ";", - }); - } - const requestBody = { - modelId: options.config.model, - messages: prompt, - inferenceConfig: { - maxTokens: modelConfig.max_tokens, + anthropic_version: "bedrock-2023-05-31", + max_tokens: modelConfig.max_tokens, + messages: formattedMessages, + ...(systemMessage && { system: systemMessage }), + ...(modelConfig.temperature !== undefined && { temperature: modelConfig.temperature, - topP: modelConfig.top_p, - stopSequences: [], - }, - stream: shouldStream, + }), + ...(modelConfig.top_p !== undefined && { top_p: modelConfig.top_p }), + ...(isClaude3 && { top_k: 5 }), }; - const conversePath = `${ApiPath.Bedrock}/converse`; const controller = new AbortController(); options.onController?.(controller); - if (shouldStream) { - let currentToolUse: ChatMessageTool | null = null; - let index = -1; - const [tools, funcs] = usePluginStore - .getState() - .getAsTools( - useChatStore.getState().currentSession().mask?.plugin || [], - ); - return stream( - conversePath, - requestBody, - getHeaders(), - // @ts-ignore - tools.map((tool) => ({ - name: tool?.function?.name, - description: tool?.function?.description, - input_schema: tool?.function?.parameters, - })), - funcs, - controller, - // parseSSE - (text: string, runTools: ChatMessageTool[]) => { - // console.log("parseSSE", text, runTools); - let chunkJson: - | undefined - | { - type: "content_block_delta" | "content_block_stop"; - content_block?: { - type: "tool_use"; - id: string; - name: string; - }; - delta?: { - type: "text_delta" | "input_json_delta"; - text?: string; - partial_json?: string; - }; - index: number; - }; - chunkJson = JSON.parse(text); - - if (chunkJson?.content_block?.type == "tool_use") { - index += 1; - const id = chunkJson?.content_block.id; - const name = chunkJson?.content_block.name; - runTools.push({ - id, - type: "function", - function: { - name, - arguments: "", - }, - }); - } - if ( - chunkJson?.delta?.type == "input_json_delta" && - chunkJson?.delta?.partial_json - ) { - // @ts-ignore - runTools[index]["function"]["arguments"] += - chunkJson?.delta?.partial_json; - } - return chunkJson?.delta?.text; - }, - // processToolMessage, include tool_calls message and tool call results - ( - requestPayload: RequestPayload, - toolCallMessage: any, - toolCallResult: any[], - ) => { - // reset index value - index = -1; - // @ts-ignore - requestPayload?.messages?.splice( - // @ts-ignore - requestPayload?.messages?.length, - 0, - { - role: "assistant", - content: toolCallMessage.tool_calls.map( - (tool: ChatMessageTool) => ({ - type: "tool_use", - id: tool.id, - name: tool?.function?.name, - input: tool?.function?.arguments - ? JSON.parse(tool?.function?.arguments) - : {}, - }), - ), - }, - // @ts-ignore - ...toolCallResult.map((result) => ({ - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: result.tool_call_id, - content: result.content, - }, - ], - })), - ); - }, - options, + const accessStore = useAccessStore.getState(); + if (!accessStore.isValidBedrock()) { + throw new Error( + "Invalid AWS credentials. Please check your configuration.", ); - } else { - const payload = { - method: "POST", - body: JSON.stringify(requestBody), - signal: controller.signal, - headers: { - ...getHeaders(), // get common headers - }, + } + + try { + const apiEndpoint = "/api/bedrock/chat"; + const headers = { + "Content-Type": "application/json", + "X-Region": accessStore.awsRegion, + "X-Access-Key": accessStore.awsAccessKey, + "X-Secret-Key": accessStore.awsSecretKey, + "X-Model-Id": modelConfig.model, + ...(accessStore.awsSessionToken && { + "X-Session-Token": accessStore.awsSessionToken, + }), }; - try { - controller.signal.onabort = () => - options.onFinish("", new Response(null, { status: 400 })); + if (options.config.stream) { + let index = -1; + let currentToolArgs = ""; + const [tools, funcs] = usePluginStore + .getState() + .getAsTools( + useChatStore.getState().currentSession().mask?.plugin || [], + ); + + return stream( + apiEndpoint, + requestBody, + headers, + (tools as ToolDefinition[]).map((tool) => ({ + name: tool?.function?.name, + description: tool?.function?.description, + input_schema: tool?.function?.parameters, + })), + funcs, + controller, + (text: string, runTools: ChatMessageTool[]) => { + try { + const chunkJson = JSON.parse(text); + if (chunkJson?.content_block?.type === "tool_use") { + index += 1; + currentToolArgs = ""; + const id = chunkJson.content_block?.id; + const name = chunkJson.content_block?.name; + if (id && name) { + runTools.push({ + id, + type: "function", + function: { name, arguments: "" }, + }); + } + } else if ( + chunkJson?.delta?.type === "input_json_delta" && + chunkJson.delta?.partial_json + ) { + currentToolArgs += chunkJson.delta.partial_json; + try { + JSON.parse(currentToolArgs); + if (index >= 0 && index < runTools.length) { + runTools[index].function!.arguments = currentToolArgs; + } + } catch (e) {} + } else if ( + chunkJson?.type === "content_block_stop" && + currentToolArgs && + index >= 0 && + index < runTools.length + ) { + try { + if (currentToolArgs.trim().endsWith(",")) { + currentToolArgs = currentToolArgs.slice(0, -1) + "}"; + } else if (!currentToolArgs.endsWith("}")) { + currentToolArgs += "}"; + } + JSON.parse(currentToolArgs); + runTools[index].function!.arguments = currentToolArgs; + } catch (e) {} + } + return this.extractMessage(chunkJson); + } catch (e) { + return ""; + } + }, + ( + requestPayload: any, + toolCallMessage: any, + toolCallResult: any[], + ) => { + index = -1; + currentToolArgs = ""; + if (requestPayload?.messages) { + requestPayload.messages.splice( + requestPayload.messages.length, + 0, + { + role: "assistant", + content: [ + { + type: "text", + text: JSON.stringify( + toolCallMessage.tool_calls.map( + (tool: ChatMessageTool) => ({ + type: "tool_use", + id: tool.id, + name: tool?.function?.name, + input: tool?.function?.arguments + ? JSON.parse(tool?.function?.arguments) + : {}, + }), + ), + ), + }, + ], + }, + ...toolCallResult.map((result) => ({ + role: "user", + content: [ + { + type: "text", + text: `Tool '${result.tool_call_id}' returned: ${result.content}`, + }, + ], + })), + ); + } + }, + options, + ); + } else { + const res = await fetch(apiEndpoint, { + method: "POST", + headers, + body: JSON.stringify(requestBody), + }); - const res = await fetch(conversePath, payload); const resJson = await res.json(); - const message = this.extractMessage(resJson); options.onFinish(message, res); - } catch (e) { - console.error("failed to chat", e); - options.onError?.(e as Error); } + } catch (e) { + options.onError?.(e as Error); } } + async usage() { - return { - used: 0, - total: 0, - }; + return { used: 0, total: 0 }; } + async models() { return []; } diff --git a/app/components/ui-lib.tsx b/app/components/ui-lib.tsx index 5f183c8c4..4b7a4798a 100644 --- a/app/components/ui-lib.tsx +++ b/app/components/ui-lib.tsx @@ -11,7 +11,7 @@ import MaxIcon from "../icons/max.svg"; import MinIcon from "../icons/min.svg"; import Locale from "../locales"; -import { maskSensitiveValue } from "../utils/encryption"; +import { maskSensitiveValue } from "../utils/aws"; import { createRoot } from "react-dom/client"; import React, { diff --git a/app/constant.ts b/app/constant.ts index 822b4158e..32b051c76 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -241,9 +241,10 @@ export const ChatGLM = { }; export const Bedrock = { - ChatPath: "converse", + ChatPath: "model", // Simplified path since we'll append the full path in bedrock.ts ApiVersion: "2023-11-01", - getEndpoint: (region: string = "us-west-2") =>`https://bedrock-runtime.${region}.amazonaws.com`, + getEndpoint: (region: string = "us-west-2") => + `https://bedrock-runtime.${region}.amazonaws.com`, }; export const DEFAULT_INPUT_TEMPLATE = `{{input}}`; // input / time / model / lang @@ -326,19 +327,43 @@ const openaiModels = [ ]; const bedrockModels = [ + // Amazon Titan Models + "amazon.titan-text-express-v1", + "amazon.titan-text-lite-v1", + "amazon.titan-text-agile-v1", + + // Cohere Models + "cohere.command-light-text-v14", + "cohere.command-r-plus-v1:0", + "cohere.command-r-v1:0", + "cohere.command-text-v14", + // Claude Models "anthropic.claude-3-haiku-20240307-v1:0", "anthropic.claude-3-5-haiku-20241022-v1:0", "anthropic.claude-3-sonnet-20240229-v1:0", "anthropic.claude-3-5-sonnet-20241022-v2:0", "anthropic.claude-3-opus-20240229-v1:0", + "anthropic.claude-2.1", + "anthropic.claude-v2", + "anthropic.claude-v1", + "anthropic.claude-instant-v1", // Meta Llama Models - "us.meta.llama3-2-11b-instruct-v1:0", - "us.meta.llama3-2-90b-instruct-v1:0", - //Mistral + "meta.llama2-13b-chat-v1", + "meta.llama2-70b-chat-v1", + "meta.llama3-8b-instruct-v1:0", + "meta.llama3-2-11b-instruct-v1:0", + "meta.llama3-2-90b-instruct-v1:0", + + // Mistral Models + "mistral.mistral-7b-instruct-v0:2", "mistral.mistral-large-2402-v1:0", "mistral.mistral-large-2407-v1:0", + + // AI21 Models + "ai21.j2-mid-v1", + "ai21.j2-ultra-v1", ]; const googleModels = [ diff --git a/app/store/access.ts b/app/store/access.ts index b628d09e3..75a8123a2 100644 --- a/app/store/access.ts +++ b/app/store/access.ts @@ -4,7 +4,6 @@ import { StoreKey, ApiPath, OPENAI_BASE_URL, - BEDROCK_BASE_URL, ANTHROPIC_BASE_URL, GEMINI_BASE_URL, BAIDU_BASE_URL, @@ -23,14 +22,12 @@ import { createPersistStore } from "../utils/store"; import { ensure } from "../utils/clone"; import { DEFAULT_CONFIG } from "./config"; import { getModelProvider } from "../utils/model"; -import { encrypt, decrypt } from "../utils/encryption"; let fetchState = 0; // 0 not fetch, 1 fetching, 2 done const isApp = getClientConfig()?.buildMode === "export"; const DEFAULT_OPENAI_URL = isApp ? OPENAI_BASE_URL : ApiPath.OpenAI; -const DEFAULT_BEDROCK_URL = isApp ? BEDROCK_BASE_URL : ApiPath.Bedrock; const DEFAULT_GOOGLE_URL = isApp ? GEMINI_BASE_URL : ApiPath.Google; @@ -64,13 +61,6 @@ const DEFAULT_ACCESS_STATE = { openaiUrl: DEFAULT_OPENAI_URL, openaiApiKey: "", - // bedrock - awsRegion: "", - awsAccessKey: "", - awsSecretKey: "", - awsSessionToken: "", - awsCognitoUser: false, - // azure azureUrl: "", azureApiKey: "", @@ -126,6 +116,12 @@ const DEFAULT_ACCESS_STATE = { chatglmUrl: DEFAULT_CHATGLM_URL, chatglmApiKey: "", + // aws bedrock + awsRegion: "", + awsAccessKey: "", + awsSecretKey: "", + awsSessionToken: "", + // server config needCode: true, hideUserApiKey: false, @@ -139,9 +135,6 @@ const DEFAULT_ACCESS_STATE = { edgeTTSVoiceName: "zh-CN-YunxiNeural", }; -type AccessState = typeof DEFAULT_ACCESS_STATE; -type BedrockCredentialKey = "awsAccessKey" | "awsSecretKey" | "awsSessionToken"; - export const useAccessStore = createPersistStore( { ...DEFAULT_ACCESS_STATE }, @@ -162,46 +155,6 @@ export const useAccessStore = createPersistStore( return ensure(get(), ["openaiApiKey"]); }, - isValidBedrock() { - const state = get(); - return ( - ensure(state, ["awsAccessKey", "awsSecretKey", "awsRegion"]) && - this.validateAwsCredentials( - this.getDecryptedAwsCredential("awsAccessKey"), - this.getDecryptedAwsCredential("awsSecretKey"), - state.awsRegion, - ) - ); - }, - - validateAwsCredentials( - accessKey: string, - secretKey: string, - region: string, - ) { - // Comprehensive AWS credential validation - const accessKeyRegex = /^(AKIA|A3T|ASIA)[A-Z0-9]{16}$/; - const regionRegex = /^[a-z]{2}-[a-z]+-\d+$/; - - return ( - accessKeyRegex.test(accessKey) && // Validate access key format - secretKey.length === 40 && // Validate secret key length - regionRegex.test(region) && // Validate region format - accessKey !== "" && - secretKey !== "" && - region !== "" - ); - }, - - setEncryptedAwsCredential(key: BedrockCredentialKey, value: string) { - set({ [key]: encrypt(value) }); - }, - - getDecryptedAwsCredential(key: BedrockCredentialKey): string { - const encryptedValue = get()[key]; - return encryptedValue ? decrypt(encryptedValue) : ""; - }, - isValidAzure() { return ensure(get(), ["azureUrl", "azureApiKey", "azureApiVersion"]); }, @@ -233,6 +186,7 @@ export const useAccessStore = createPersistStore( isValidMoonshot() { return ensure(get(), ["moonshotApiKey"]); }, + isValidIflytek() { return ensure(get(), ["iflytekApiKey"]); }, @@ -245,13 +199,16 @@ export const useAccessStore = createPersistStore( return ensure(get(), ["chatglmApiKey"]); }, + isValidBedrock() { + return ensure(get(), ["awsRegion", "awsAccessKey", "awsSecretKey"]); + }, + isAuthorized() { this.fetch(); // has token or has code or disabled access control return ( this.isValidOpenAI() || - this.isValidBedrock() || this.isValidAzure() || this.isValidGoogle() || this.isValidAnthropic() || @@ -263,6 +220,7 @@ export const useAccessStore = createPersistStore( this.isValidIflytek() || this.isValidXAI() || this.isValidChatGLM() || + this.isValidBedrock() || !this.enabledAccessControl() || (this.enabledAccessControl() && ensure(get(), ["accessCode"])) ); @@ -290,28 +248,8 @@ export const useAccessStore = createPersistStore( return res; }) .then((res: DangerConfig) => { - console.log("[Config] received DangerConfig server configuration"); + console.log("[Config] got config from server", res); set(() => ({ ...res })); - return res; - }) - .then((res: Partial) => { - console.log("[Config] received AccessState server configuration"); - // Encrypt Bedrock-related sensitive data before storing - const encryptedRes = { ...res }; - const keysToEncrypt: BedrockCredentialKey[] = [ - "awsAccessKey", - "awsSecretKey", - "awsSessionToken", - ]; - - keysToEncrypt.forEach((key) => { - const value = encryptedRes[key]; - if (value) { - (encryptedRes[key] as string) = encrypt(value as string); - } - }); - - set(() => ({ ...encryptedRes })); }) .catch(() => { console.error("[Config] failed to fetch config"); diff --git a/app/utils/aws.ts b/app/utils/aws.ts new file mode 100644 index 000000000..dfa0a92fe --- /dev/null +++ b/app/utils/aws.ts @@ -0,0 +1,236 @@ +import SHA256 from "crypto-js/sha256"; +import HmacSHA256 from "crypto-js/hmac-sha256"; +import Hex from "crypto-js/enc-hex"; +import Utf8 from "crypto-js/enc-utf8"; +import { AES, enc } from "crypto-js"; + +const SECRET_KEY = + process.env.ENCRYPTION_KEY || + "your-secret-key-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; +if (!SECRET_KEY || SECRET_KEY.length < 32) { + throw new Error( + "ENCRYPTION_KEY environment variable must be set with at least 32 characters", + ); +} + +export function encrypt(data: string): string { + if (!data) return ""; + try { + return AES.encrypt(data, SECRET_KEY).toString(); + } catch (error) { + console.error("Encryption failed:", error); + return data; + } +} + +export function decrypt(encryptedData: string): string { + if (!encryptedData) return ""; + try { + // Try to decrypt + const bytes = AES.decrypt(encryptedData, SECRET_KEY); + const decrypted = bytes.toString(enc.Utf8); + + // If decryption results in empty string but input wasn't empty, + // the input might already be decrypted + if (!decrypted && encryptedData) { + return encryptedData; + } + return decrypted; + } catch (error) { + // If decryption fails, the input might already be decrypted + return encryptedData; + } +} + +export function maskSensitiveValue(value: string): string { + if (!value) return ""; + if (value.length <= 4) return value; + return "*".repeat(value.length - 4) + value.slice(-4); +} + +export interface SignParams { + method: string; + url: string; + region: string; + accessKeyId: string; + secretAccessKey: string; + sessionToken?: string; + body: string; + service: string; +} + +function hmac( + key: string | CryptoJS.lib.WordArray, + data: string, +): CryptoJS.lib.WordArray { + if (typeof key === "string") { + key = Utf8.parse(key); + } + return HmacSHA256(data, key); +} + +function getSigningKey( + secretKey: string, + dateStamp: string, + region: string, + service: string, +): CryptoJS.lib.WordArray { + const kDate = hmac("AWS4" + secretKey, dateStamp); + const kRegion = hmac(kDate, region); + const kService = hmac(kRegion, service); + const kSigning = hmac(kService, "aws4_request"); + return kSigning; +} + +function normalizeHeaderValue(value: string): string { + return value.replace(/\s+/g, " ").trim(); +} + +function encodeURIComponent_RFC3986(str: string): string { + return encodeURIComponent(str) + .replace( + /[!'()*]/g, + (c) => "%" + c.charCodeAt(0).toString(16).toUpperCase(), + ) + .replace(/[-_.~]/g, (c) => c); // RFC 3986 unreserved characters +} + +function encodeURI_RFC3986(uri: string): string { + // Handle empty or root path + if (!uri || uri === "/") return ""; + + // Split the path into segments, preserving empty segments for double slashes + const segments = uri.split("/"); + + return segments + .map((segment) => { + if (!segment) return ""; + + // Special handling for Bedrock model paths + if (segment.includes("model/")) { + const parts = segment.split(/(model\/)/); + return parts + .map((part) => { + if (part === "model/") return part; + // Handle the model identifier part + if (part.includes(".") || part.includes(":")) { + return part + .split(/([.:])/g) + .map((subpart, i) => { + if (i % 2 === 1) return subpart; // Don't encode separators + return encodeURIComponent_RFC3986(subpart); + }) + .join(""); + } + return encodeURIComponent_RFC3986(part); + }) + .join(""); + } + + // Handle invoke-with-response-stream without encoding + if (segment === "invoke-with-response-stream") { + return segment; + } + + return encodeURIComponent_RFC3986(segment); + }) + .join("/"); +} + +export async function sign({ + method, + url, + region, + accessKeyId, + secretAccessKey, + sessionToken, + body, + service, +}: SignParams): Promise> { + const endpoint = new URL(url); + const canonicalUri = "/" + encodeURI_RFC3986(endpoint.pathname.slice(1)); + const canonicalQueryString = endpoint.search.slice(1); // Remove leading '?' + + // Create a date stamp and time stamp in ISO8601 format + const now = new Date(); + const amzDate = now.toISOString().replace(/[:-]|\.\d{3}/g, ""); + const dateStamp = amzDate.slice(0, 8); + + // Calculate the hash of the payload + const payloadHash = SHA256(body).toString(Hex); + + // Define headers with normalized values + const headers: Record = { + accept: "application/vnd.amazon.eventstream", + "content-type": "application/json", + host: endpoint.host, + "x-amz-content-sha256": payloadHash, + "x-amz-date": amzDate, + "x-amzn-bedrock-accept": "*/*", + }; + + // Add session token if present + if (sessionToken) { + headers["x-amz-security-token"] = sessionToken; + } + + // Get sorted header keys (case-insensitive) + const sortedHeaderKeys = Object.keys(headers).sort((a, b) => + a.toLowerCase().localeCompare(b.toLowerCase()), + ); + + // Create canonical headers string with normalized values + const canonicalHeaders = sortedHeaderKeys + .map( + (key) => `${key.toLowerCase()}:${normalizeHeaderValue(headers[key])}\n`, + ) + .join(""); + + // Create signed headers string + const signedHeaders = sortedHeaderKeys + .map((key) => key.toLowerCase()) + .join(";"); + + // Create canonical request + const canonicalRequest = [ + method.toUpperCase(), + canonicalUri, + canonicalQueryString, + canonicalHeaders, + signedHeaders, + payloadHash, + ].join("\n"); + + // Create the string to sign + const algorithm = "AWS4-HMAC-SHA256"; + const credentialScope = `${dateStamp}/${region}/${service}/aws4_request`; + const stringToSign = [ + algorithm, + amzDate, + credentialScope, + SHA256(canonicalRequest).toString(Hex), + ].join("\n"); + + // Calculate the signature + const signingKey = getSigningKey(secretAccessKey, dateStamp, region, service); + const signature = hmac(signingKey, stringToSign).toString(Hex); + + // Create the authorization header + const authorization = [ + `${algorithm} Credential=${accessKeyId}/${credentialScope}`, + `SignedHeaders=${signedHeaders}`, + `Signature=${signature}`, + ].join(", "); + + // Return headers with proper casing for the request + return { + Accept: headers.accept, + "Content-Type": headers["content-type"], + Host: headers.host, + "X-Amz-Content-Sha256": headers["x-amz-content-sha256"], + "X-Amz-Date": headers["x-amz-date"], + "X-Amzn-Bedrock-Accept": headers["x-amzn-bedrock-accept"], + ...(sessionToken && { "X-Amz-Security-Token": sessionToken }), + Authorization: authorization, + }; +} diff --git a/app/utils/encryption.ts b/app/utils/encryption.ts deleted file mode 100644 index ef750d9ab..000000000 --- a/app/utils/encryption.ts +++ /dev/null @@ -1,35 +0,0 @@ -import { AES, enc } from "crypto-js"; - -const SECRET_KEY = - process.env.ENCRYPTION_KEY || - "your-secret-key-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"; // Replace this with a secure, randomly generated key -if (!SECRET_KEY || SECRET_KEY.length < 32) { - throw new Error( - "ENCRYPTION_KEY environment variable must be set with at least 32 characters", - ); -} - -export function encrypt(data: string): string { - try { - return AES.encrypt(data, SECRET_KEY).toString(); - } catch (error) { - console.error("Encryption failed:", error); - return data; // Fallback to unencrypted data if encryption fails - } -} - -export function decrypt(encryptedData: string): string { - try { - const bytes = AES.decrypt(encryptedData, SECRET_KEY); - return bytes.toString(enc.Utf8); - } catch (error) { - console.error("Decryption failed:", error); - return encryptedData; // Fallback to the original data if decryption fails - } -} - -export function maskSensitiveValue(value: string): string { - if (!value) return ""; - if (value.length <= 4) return value; - return "*".repeat(value.length - 4) + value.slice(-4); -} diff --git a/package.json b/package.json index 304aa40b4..57a63bcac 100644 --- a/package.json +++ b/package.json @@ -20,7 +20,6 @@ "test:ci": "jest --ci" }, "dependencies": { - "@aws-sdk/client-bedrock-runtime": "^3.679.0", "@fortaine/fetch-event-source": "^3.0.6", "@hello-pangea/dnd": "^16.5.0", "@next/third-parties": "^14.1.0",