From 35f5dc52e71195100ff477d4f07a68c444516a50 Mon Sep 17 00:00:00 2001 From: Hk-Gosuto Date: Mon, 27 Nov 2023 22:53:20 +0800 Subject: [PATCH] feat: support switch nodejs plugin --- Dockerfile | 1 + app/api/langchain/tool/agent/edge/route.ts | 331 ++++++++++++++++++ .../tool/agent/{ => nodejs}/route.ts | 10 +- app/client/platforms/openai.ts | 4 +- app/components/plugin.tsx | 6 +- 5 files changed, 343 insertions(+), 9 deletions(-) create mode 100644 app/api/langchain/tool/agent/edge/route.ts rename app/api/langchain/tool/agent/{ => nodejs}/route.ts (98%) diff --git a/Dockerfile b/Dockerfile index 720a0cfe9..7676a3fb8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,6 +17,7 @@ RUN apk update && apk add --no-cache git ENV OPENAI_API_KEY="" ENV CODE="" +ENV NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN=1 WORKDIR /app COPY --from=deps /app/node_modules ./node_modules diff --git a/app/api/langchain/tool/agent/edge/route.ts b/app/api/langchain/tool/agent/edge/route.ts new file mode 100644 index 000000000..398a6eef9 --- /dev/null +++ b/app/api/langchain/tool/agent/edge/route.ts @@ -0,0 +1,331 @@ +import { NextRequest, NextResponse } from "next/server"; +import { getServerSideConfig } from "@/app/config/server"; +import { auth } from "../../../../auth"; + +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { BaseCallbackHandler } from "langchain/callbacks"; + +import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema"; +import { BufferMemory, ChatMessageHistory } from "langchain/memory"; +import { initializeAgentExecutorWithOptions } from "langchain/agents"; +import { ACCESS_CODE_PREFIX } from "@/app/constant"; +import { OpenAI } from "langchain/llms/openai"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; + +import * as langchainTools from "langchain/tools"; +import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; +import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; +import { WebBrowser } from "langchain/tools/webbrowser"; +import { Calculator } from "langchain/tools/calculator"; +import { DynamicTool, Tool } from "langchain/tools"; +import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator"; +import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; +import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; +import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator"; +import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv"; +import dynamic from "next/dynamic"; + +const serverConfig = getServerSideConfig(); + +interface RequestMessage { + role: string; + content: string; +} + +interface RequestBody { + messages: RequestMessage[]; + model: string; + stream?: boolean; + temperature: number; + presence_penalty?: number; + frequency_penalty?: number; + top_p?: number; + baseUrl?: string; + apiKey?: string; + maxIterations: number; + returnIntermediateSteps: boolean; + useTools: (undefined | string)[]; +} + +class ResponseBody { + isSuccess: boolean = true; + message!: string; + isToolMessage: boolean = false; + toolName?: string; +} + +interface ToolInput { + input: string; +} + +async function handle(req: NextRequest) { + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + try { + const authResult = auth(req); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + const encoder = new TextEncoder(); + const transformStream = new TransformStream(); + 
const writer = transformStream.writable.getWriter(); + const reqBody: RequestBody = await req.json(); + const authToken = req.headers.get("Authorization") ?? ""; + const token = authToken.trim().replaceAll("Bearer ", "").trim(); + const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); + let useTools = reqBody.useTools ?? []; + let apiKey = serverConfig.apiKey; + if (isOpenAiKey && token) { + apiKey = token; + } + + // support base url + let baseUrl = "https://api.openai.com/v1"; + if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; + if ( + reqBody.baseUrl?.startsWith("http://") || + reqBody.baseUrl?.startsWith("https://") + ) + baseUrl = reqBody.baseUrl; + if (!baseUrl.endsWith("/v1")) + baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; + console.log("[baseUrl]", baseUrl); + + const handler = BaseCallbackHandler.fromMethods({ + async handleLLMNewToken(token: string) { + // console.log("[Token]", token); + if (token) { + var response = new ResponseBody(); + response.message = token; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + } + }, + async handleChainError(err, runId, parentRunId, tags) { + console.log("[handleChainError]", err, "writer error"); + var response = new ResponseBody(); + response.isSuccess = false; + response.message = err; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + await writer.close(); + }, + async handleChainEnd(outputs, runId, parentRunId, tags) { + console.log("[handleChainEnd]"); + await writer.ready; + await writer.close(); + }, + async handleLLMEnd() { + // await writer.ready; + // await writer.close(); + }, + async handleLLMError(e: Error) { + console.log("[handleLLMError]", e, "writer error"); + var response = new ResponseBody(); + response.isSuccess = false; + response.message = e.message; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + await writer.close(); + }, + handleLLMStart(llm, _prompts: string[]) { + // console.log("handleLLMStart: I'm the second handler!!", { llm }); + }, + handleChainStart(chain) { + // console.log("handleChainStart: I'm the second handler!!", { chain }); + }, + async handleAgentAction(action) { + try { + console.log("[handleAgentAction]", action.tool); + if (!reqBody.returnIntermediateSteps) return; + var response = new ResponseBody(); + response.isToolMessage = true; + response.message = JSON.stringify(action.toolInput); + response.toolName = action.tool; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + } catch (ex) { + console.error("[handleAgentAction]", ex); + var response = new ResponseBody(); + response.isSuccess = false; + response.message = (ex as Error).message; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + await writer.close(); + } + }, + handleToolStart(tool, input) { + console.log("[handleToolStart]", { tool }); + }, + async handleToolEnd(output, runId, parentRunId, tags) { + console.log("[handleToolEnd]", { output, runId, parentRunId, tags }); + }, + handleAgentEnd(action, runId, parentRunId, tags) { + console.log("[handleAgentEnd]"); + }, + }); + + let searchTool: Tool = new DuckDuckGo(); + if (process.env.CHOOSE_SEARCH_ENGINE) { + switch (process.env.CHOOSE_SEARCH_ENGINE) { + case "google": + searchTool = new GoogleSearch(); + break; + case "baidu": + searchTool = new 
BaiduSearch(); + break; + } + } + if (process.env.BING_SEARCH_API_KEY) { + let bingSearchTool = new langchainTools["BingSerpAPI"]( + process.env.BING_SEARCH_API_KEY, + ); + searchTool = new DynamicTool({ + name: "bing_search", + description: bingSearchTool.description, + func: async (input: string) => bingSearchTool.call(input), + }); + } + if (process.env.SERPAPI_API_KEY) { + let serpAPITool = new langchainTools["SerpAPI"]( + process.env.SERPAPI_API_KEY, + ); + searchTool = new DynamicTool({ + name: "google_search", + description: serpAPITool.description, + func: async (input: string) => serpAPITool.call(input), + }); + } + + const model = new OpenAI( + { + temperature: 0, + modelName: reqBody.model, + openAIApiKey: apiKey, + }, + { basePath: baseUrl }, + ); + const embeddings = new OpenAIEmbeddings( + { + openAIApiKey: apiKey, + }, + { basePath: baseUrl }, + ); + + const tools = [ + // new RequestsGetTool(), + // new RequestsPostTool(), + ]; + const webBrowserTool = new WebBrowser({ model, embeddings }); + const calculatorTool = new Calculator(); + const dallEAPITool = new DallEAPIWrapper( + apiKey, + baseUrl, + async (data: string) => { + var response = new ResponseBody(); + response.message = data; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + }, + ); + dallEAPITool.returnDirect = true; + const stableDiffusionTool = new StableDiffusionWrapper(); + const arxivAPITool = new ArxivAPIWrapper(); + if (useTools.includes("web-search")) tools.push(searchTool); + if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool); + if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool); + if (useTools.includes(dallEAPITool.name)) tools.push(dallEAPITool); + if (useTools.includes(stableDiffusionTool.name)) + tools.push(stableDiffusionTool); + if (useTools.includes(arxivAPITool.name)) tools.push(arxivAPITool); + + useTools.forEach((toolName) => { + if (toolName) { + var tool = langchainTools[ + toolName as keyof typeof langchainTools + ] as any; + if (tool) { + tools.push(new tool()); + } + } + }); + + const pastMessages = new Array(); + + reqBody.messages + .slice(0, reqBody.messages.length - 1) + .forEach((message) => { + if (message.role === "system") + pastMessages.push(new SystemMessage(message.content)); + if (message.role === "user") + pastMessages.push(new HumanMessage(message.content)); + if (message.role === "assistant") + pastMessages.push(new AIMessage(message.content)); + }); + + const memory = new BufferMemory({ + memoryKey: "chat_history", + returnMessages: true, + inputKey: "input", + outputKey: "output", + chatHistory: new ChatMessageHistory(pastMessages), + }); + + const llm = new ChatOpenAI( + { + modelName: reqBody.model, + openAIApiKey: apiKey, + temperature: reqBody.temperature, + streaming: reqBody.stream, + topP: reqBody.top_p, + presencePenalty: reqBody.presence_penalty, + frequencyPenalty: reqBody.frequency_penalty, + }, + { basePath: baseUrl }, + ); + const executor = await initializeAgentExecutorWithOptions(tools, llm, { + agentType: "openai-functions", + returnIntermediateSteps: reqBody.returnIntermediateSteps, + maxIterations: reqBody.maxIterations, + memory: memory, + }); + + executor.call( + { + input: reqBody.messages.slice(-1)[0].content, + }, + [handler], + ); + + console.log("returning response"); + return new Response(transformStream.readable, { + headers: { "Content-Type": "text/event-stream" }, + }); + } catch (e) { + return new Response(JSON.stringify({ error: (e as 
any).message }), { + status: 500, + headers: { "Content-Type": "application/json" }, + }); + } +} + +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; diff --git a/app/api/langchain/tool/agent/route.ts b/app/api/langchain/tool/agent/nodejs/route.ts similarity index 98% rename from app/api/langchain/tool/agent/route.ts rename to app/api/langchain/tool/agent/nodejs/route.ts index ca68f9328..05ec45733 100644 --- a/app/api/langchain/tool/agent/route.ts +++ b/app/api/langchain/tool/agent/nodejs/route.ts @@ -1,6 +1,6 @@ import { NextRequest, NextResponse } from "next/server"; import { getServerSideConfig } from "@/app/config/server"; -import { auth } from "../../../auth"; +import { auth } from "../../../../auth"; import { ChatOpenAI } from "langchain/chat_models/openai"; import { BaseCallbackHandler } from "langchain/callbacks"; @@ -23,6 +23,7 @@ import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator"; import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv"; +import dynamic from "next/dynamic"; import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser"; const serverConfig = getServerSideConfig(); @@ -247,6 +248,7 @@ async function handle(req: NextRequest) { dallEAPITool.returnDirect = true; const stableDiffusionTool = new StableDiffusionWrapper(); const arxivAPITool = new ArxivAPIWrapper(); + const pdfBrowserTool = new PDFBrowser(model, embeddings); if (useTools.includes("web-search")) tools.push(searchTool); if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool); if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool); @@ -254,9 +256,7 @@ async function handle(req: NextRequest) { if (useTools.includes(stableDiffusionTool.name)) tools.push(stableDiffusionTool); if (useTools.includes(arxivAPITool.name)) tools.push(arxivAPITool); - - const pdfBrowserTool = new PDFBrowser(model, embeddings); - tools.push(pdfBrowserTool); + if (useTools.includes(pdfBrowserTool.name)) tools.push(pdfBrowserTool); useTools.forEach((toolName) => { if (toolName) { @@ -331,4 +331,4 @@ async function handle(req: NextRequest) { export const GET = handle; export const POST = handle; -export const runtime = "edge"; +export const runtime = "nodejs"; diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index d07d23a66..1080658bb 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -255,7 +255,9 @@ export class ChatGPTApi implements LLMApi { options.onController?.(controller); try { - const path = "/api/langchain/tool/agent"; + let path = "/api/langchain/tool/agent/"; + const enableNodeJSPlugin = !!process.env.NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN; + path = enableNodeJSPlugin ? path + "nodejs" : path + "edge"; const chatPayload = { method: "POST", body: JSON.stringify(requestPayload), diff --git a/app/components/plugin.tsx b/app/components/plugin.tsx index f81be1d2f..cf425bb8c 100644 --- a/app/components/plugin.tsx +++ b/app/components/plugin.tsx @@ -218,7 +218,7 @@ export function PluginPage() { }); }; - const serverConfig = getServerSideConfig(); + const enableNodeJSPlugin = !!process.env.NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN; return ( @@ -260,7 +260,7 @@ export function PluginPage() {
{m.name}
- {m.onlyNodeRuntime && serverConfig.isVercel && (
+ {m.onlyNodeRuntime && !enableNodeJSPlugin && (
{Locale.Plugin.RuntimeWarning}
@@ -274,7 +274,7 @@ export function PluginPage() {
{ updatePluginEnableStatus(m.id, e.currentTarget.checked);
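
Note on the runtime toggle introduced by this patch: the Dockerfile defaults NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN to 1, the client in app/client/platforms/openai.ts sends agent requests to /api/langchain/tool/agent/nodejs when that variable is non-empty and to /api/langchain/tool/agent/edge otherwise, and app/components/plugin.tsx shows Locale.Plugin.RuntimeWarning for plugins marked onlyNodeRuntime when the flag is off. A minimal sketch of how a deployment might set the flag (the .env.local filename and the comments are illustrative, not part of the patch; as a NEXT_PUBLIC_* variable it is baked in at build time):

    # .env.local (illustrative): any non-empty value routes agent calls to the
    # Node.js runtime endpoint, which also exposes Node-only tools such as PDFBrowser
    NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN=1

    # leave the variable unset or empty to use the Edge runtime endpoint instead
    # NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN=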