From 293679fa64a1c5b186d16f92a565ce1b0a25b267 Mon Sep 17 00:00:00 2001 From: Hk-Gosuto Date: Tue, 28 Nov 2023 18:59:00 +0800 Subject: [PATCH] refactor: refactor agent api route --- app/api/langchain-tools/edge_tools.ts | 94 ++++++ app/api/langchain-tools/nodejs_tools.ts | 35 ++ app/api/langchain/tool/agent/agentapi.ts | 316 +++++++++++++++++++ app/api/langchain/tool/agent/edge/route.ts | 284 ++--------------- app/api/langchain/tool/agent/nodejs/route.ts | 295 ++--------------- 5 files changed, 502 insertions(+), 522 deletions(-) create mode 100644 app/api/langchain-tools/edge_tools.ts create mode 100644 app/api/langchain-tools/nodejs_tools.ts create mode 100644 app/api/langchain/tool/agent/agentapi.ts diff --git a/app/api/langchain-tools/edge_tools.ts b/app/api/langchain-tools/edge_tools.ts new file mode 100644 index 000000000..f67de5e8e --- /dev/null +++ b/app/api/langchain-tools/edge_tools.ts @@ -0,0 +1,94 @@ +import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv"; +import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator"; +import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator"; +import { BaseLanguageModel } from "langchain/dist/base_language"; +import { Calculator } from "langchain/tools/calculator"; +import { WebBrowser } from "langchain/tools/webbrowser"; +import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; +import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; +import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; +import { Tool, DynamicTool } from "langchain/tools"; +import * as langchainTools from "langchain/tools"; +import { Embeddings } from "langchain/dist/embeddings/base.js"; +import { promises } from "dns"; + +export class EdgeTool { + private apiKey: string | undefined; + + private baseUrl: string; + + private model: BaseLanguageModel; + + private embeddings: Embeddings; + + private callback?: (data: string) => Promise; + + constructor( + apiKey: string | undefined, + baseUrl: string, + model: BaseLanguageModel, + embeddings: Embeddings, + callback?: (data: string) => Promise, + ) { + this.apiKey = apiKey; + this.baseUrl = baseUrl; + this.model = model; + this.embeddings = embeddings; + this.callback = callback; + } + + async getCustomTools(): Promise { + // let searchTool: Tool = new DuckDuckGo(); + // if (process.env.CHOOSE_SEARCH_ENGINE) { + // switch (process.env.CHOOSE_SEARCH_ENGINE) { + // case "google": + // searchTool = new GoogleSearch(); + // break; + // case "baidu": + // searchTool = new BaiduSearch(); + // break; + // } + // } + // if (process.env.BING_SEARCH_API_KEY) { + // let bingSearchTool = new langchainTools["BingSerpAPI"]( + // process.env.BING_SEARCH_API_KEY, + // ); + // searchTool = new DynamicTool({ + // name: "bing_search", + // description: bingSearchTool.description, + // func: async (input: string) => bingSearchTool.call(input), + // }); + // } + // if (process.env.SERPAPI_API_KEY) { + // let serpAPITool = new langchainTools["SerpAPI"]( + // process.env.SERPAPI_API_KEY, + // ); + // searchTool = new DynamicTool({ + // name: "google_search", + // description: serpAPITool.description, + // func: async (input: string) => serpAPITool.call(input), + // }); + // } + const webBrowserTool = new WebBrowser({ + model: this.model, + embeddings: this.embeddings, + }); + const calculatorTool = new Calculator(); + const dallEAPITool = new DallEAPIWrapper( + this.apiKey, + this.baseUrl, + this.callback, + ); + dallEAPITool.returnDirect = true; + const stableDiffusionTool = new StableDiffusionWrapper(); + const arxivAPITool = new ArxivAPIWrapper(); + return [ + // searchTool, + calculatorTool, + webBrowserTool, + dallEAPITool, + stableDiffusionTool, + arxivAPITool, + ]; + } +} diff --git a/app/api/langchain-tools/nodejs_tools.ts b/app/api/langchain-tools/nodejs_tools.ts new file mode 100644 index 000000000..af01dd2f2 --- /dev/null +++ b/app/api/langchain-tools/nodejs_tools.ts @@ -0,0 +1,35 @@ +import { BaseLanguageModel } from "langchain/dist/base_language"; +import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser"; + +import { Embeddings } from "langchain/dist/embeddings/base.js"; + +export class NodeJSTool { + private apiKey: string | undefined; + + private baseUrl: string; + + private model: BaseLanguageModel; + + private embeddings: Embeddings; + + private callback?: (data: string) => Promise; + + constructor( + apiKey: string | undefined, + baseUrl: string, + model: BaseLanguageModel, + embeddings: Embeddings, + callback?: (data: string) => Promise, + ) { + this.apiKey = apiKey; + this.baseUrl = baseUrl; + this.model = model; + this.embeddings = embeddings; + this.callback = callback; + } + + async getCustomTools(): Promise { + const pdfBrowserTool = new PDFBrowser(this.model, this.embeddings); + return [pdfBrowserTool]; + } +} diff --git a/app/api/langchain/tool/agent/agentapi.ts b/app/api/langchain/tool/agent/agentapi.ts new file mode 100644 index 000000000..e4ca14840 --- /dev/null +++ b/app/api/langchain/tool/agent/agentapi.ts @@ -0,0 +1,316 @@ +import { NextRequest, NextResponse } from "next/server"; +import { getServerSideConfig } from "@/app/config/server"; +import { auth } from "../../../auth"; + +import { ChatOpenAI } from "langchain/chat_models/openai"; +import { BaseCallbackHandler } from "langchain/callbacks"; + +import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema"; +import { BufferMemory, ChatMessageHistory } from "langchain/memory"; +import { initializeAgentExecutorWithOptions } from "langchain/agents"; +import { ACCESS_CODE_PREFIX } from "@/app/constant"; +import { OpenAI } from "langchain/llms/openai"; +import { OpenAIEmbeddings } from "langchain/embeddings/openai"; + +import * as langchainTools from "langchain/tools"; +import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; +import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; +import { WebBrowser } from "langchain/tools/webbrowser"; +import { Calculator } from "langchain/tools/calculator"; +import { DynamicTool, Tool } from "langchain/tools"; +import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator"; +import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; +import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; +import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator"; +import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv"; +import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser"; + +export interface RequestMessage { + role: string; + content: string; +} + +export interface RequestBody { + messages: RequestMessage[]; + model: string; + stream?: boolean; + temperature: number; + presence_penalty?: number; + frequency_penalty?: number; + top_p?: number; + baseUrl?: string; + apiKey?: string; + maxIterations: number; + returnIntermediateSteps: boolean; + useTools: (undefined | string)[]; +} + +export class ResponseBody { + isSuccess: boolean = true; + message!: string; + isToolMessage: boolean = false; + toolName?: string; +} + +export interface ToolInput { + input: string; +} + +export class AgentApi { + private encoder: TextEncoder; + private transformStream: TransformStream; + private writer: WritableStreamDefaultWriter; + + constructor( + encoder: TextEncoder, + transformStream: TransformStream, + writer: WritableStreamDefaultWriter, + ) { + this.encoder = encoder; + this.transformStream = transformStream; + this.writer = writer; + } + + async getHandler(reqBody: any) { + var writer = this.writer; + var encoder = this.encoder; + return BaseCallbackHandler.fromMethods({ + async handleLLMNewToken(token: string) { + if (token) { + var response = new ResponseBody(); + response.message = token; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + } + }, + async handleChainError(err, runId, parentRunId, tags) { + console.log("[handleChainError]", err, "writer error"); + var response = new ResponseBody(); + response.isSuccess = false; + response.message = err; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + await writer.close(); + }, + async handleChainEnd(outputs, runId, parentRunId, tags) { + console.log("[handleChainEnd]"); + await writer.ready; + await writer.close(); + }, + async handleLLMEnd() { + // await writer.ready; + // await writer.close(); + }, + async handleLLMError(e: Error) { + console.log("[handleLLMError]", e, "writer error"); + var response = new ResponseBody(); + response.isSuccess = false; + response.message = e.message; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + await writer.close(); + }, + handleLLMStart(llm, _prompts: string[]) { + // console.log("handleLLMStart: I'm the second handler!!", { llm }); + }, + handleChainStart(chain) { + // console.log("handleChainStart: I'm the second handler!!", { chain }); + }, + async handleAgentAction(action) { + try { + console.log("[handleAgentAction]", action.tool); + if (!reqBody.returnIntermediateSteps) return; + var response = new ResponseBody(); + response.isToolMessage = true; + response.message = JSON.stringify(action.toolInput); + response.toolName = action.tool; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + } catch (ex) { + console.error("[handleAgentAction]", ex); + var response = new ResponseBody(); + response.isSuccess = false; + response.message = (ex as Error).message; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + await writer.close(); + } + }, + handleToolStart(tool, input) { + console.log("[handleToolStart]", { tool }); + }, + async handleToolEnd(output, runId, parentRunId, tags) { + console.log("[handleToolEnd]", { output, runId, parentRunId, tags }); + }, + handleAgentEnd(action, runId, parentRunId, tags) { + console.log("[handleAgentEnd]"); + }, + }); + } + + async getApiHandler( + req: NextRequest, + reqBody: RequestBody, + customTools: any[], + ) { + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + try { + const authResult = auth(req); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + const serverConfig = getServerSideConfig(); + + // const reqBody: RequestBody = await req.json(); + const authToken = req.headers.get("Authorization") ?? ""; + const token = authToken.trim().replaceAll("Bearer ", "").trim(); + const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); + let useTools = reqBody.useTools ?? []; + let apiKey = serverConfig.apiKey; + if (isOpenAiKey && token) { + apiKey = token; + } + + let baseUrl = "https://api.openai.com/v1"; + if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; + if ( + reqBody.baseUrl?.startsWith("http://") || + reqBody.baseUrl?.startsWith("https://") + ) + baseUrl = reqBody.baseUrl; + if (!baseUrl.endsWith("/v1")) + baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; + console.log("[baseUrl]", baseUrl); + + var handler = await this.getHandler(reqBody); + + let searchTool: Tool = new DuckDuckGo(); + if (process.env.CHOOSE_SEARCH_ENGINE) { + switch (process.env.CHOOSE_SEARCH_ENGINE) { + case "google": + searchTool = new GoogleSearch(); + break; + case "baidu": + searchTool = new BaiduSearch(); + break; + } + } + if (process.env.BING_SEARCH_API_KEY) { + let bingSearchTool = new langchainTools["BingSerpAPI"]( + process.env.BING_SEARCH_API_KEY, + ); + searchTool = new DynamicTool({ + name: "bing_search", + description: bingSearchTool.description, + func: async (input: string) => bingSearchTool.call(input), + }); + } + if (process.env.SERPAPI_API_KEY) { + let serpAPITool = new langchainTools["SerpAPI"]( + process.env.SERPAPI_API_KEY, + ); + searchTool = new DynamicTool({ + name: "google_search", + description: serpAPITool.description, + func: async (input: string) => serpAPITool.call(input), + }); + } + + const tools = []; + + if (useTools.includes("web-search")) tools.push(searchTool); + console.log(customTools); + + customTools.forEach((customTool) => { + if (customTool) { + if (useTools.includes(customTool.name)) { + tools.push(customTool); + } + } + }); + + useTools.forEach((toolName) => { + if (toolName) { + var tool = langchainTools[ + toolName as keyof typeof langchainTools + ] as any; + if (tool) { + tools.push(new tool()); + } + } + }); + + const pastMessages = new Array(); + + reqBody.messages + .slice(0, reqBody.messages.length - 1) + .forEach((message) => { + if (message.role === "system") + pastMessages.push(new SystemMessage(message.content)); + if (message.role === "user") + pastMessages.push(new HumanMessage(message.content)); + if (message.role === "assistant") + pastMessages.push(new AIMessage(message.content)); + }); + + const memory = new BufferMemory({ + memoryKey: "chat_history", + returnMessages: true, + inputKey: "input", + outputKey: "output", + chatHistory: new ChatMessageHistory(pastMessages), + }); + + const llm = new ChatOpenAI( + { + modelName: reqBody.model, + openAIApiKey: apiKey, + temperature: reqBody.temperature, + streaming: reqBody.stream, + topP: reqBody.top_p, + presencePenalty: reqBody.presence_penalty, + frequencyPenalty: reqBody.frequency_penalty, + }, + { basePath: baseUrl }, + ); + const executor = await initializeAgentExecutorWithOptions(tools, llm, { + agentType: "openai-functions", + returnIntermediateSteps: reqBody.returnIntermediateSteps, + maxIterations: reqBody.maxIterations, + memory: memory, + }); + + executor.call( + { + input: reqBody.messages.slice(-1)[0].content, + }, + [handler], + ); + + console.log("returning response"); + return new Response(this.transformStream.readable, { + headers: { "Content-Type": "text/event-stream" }, + }); + } catch (e) { + return new Response(JSON.stringify({ error: (e as any).message }), { + status: 500, + headers: { "Content-Type": "application/json" }, + }); + } + } +} diff --git a/app/api/langchain/tool/agent/edge/route.ts b/app/api/langchain/tool/agent/edge/route.ts index 398a6eef9..43f7c1b1c 100644 --- a/app/api/langchain/tool/agent/edge/route.ts +++ b/app/api/langchain/tool/agent/edge/route.ts @@ -1,63 +1,12 @@ import { NextRequest, NextResponse } from "next/server"; -import { getServerSideConfig } from "@/app/config/server"; -import { auth } from "../../../../auth"; - -import { ChatOpenAI } from "langchain/chat_models/openai"; -import { BaseCallbackHandler } from "langchain/callbacks"; - -import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema"; -import { BufferMemory, ChatMessageHistory } from "langchain/memory"; -import { initializeAgentExecutorWithOptions } from "langchain/agents"; +import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; +import { auth } from "@/app/api/auth"; +import { EdgeTool } from "../../../../langchain-tools/edge_tools"; import { ACCESS_CODE_PREFIX } from "@/app/constant"; +import { getServerSideConfig } from "@/app/config/server"; import { OpenAI } from "langchain/llms/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; -import * as langchainTools from "langchain/tools"; -import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; -import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; -import { WebBrowser } from "langchain/tools/webbrowser"; -import { Calculator } from "langchain/tools/calculator"; -import { DynamicTool, Tool } from "langchain/tools"; -import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator"; -import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; -import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; -import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator"; -import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv"; -import dynamic from "next/dynamic"; - -const serverConfig = getServerSideConfig(); - -interface RequestMessage { - role: string; - content: string; -} - -interface RequestBody { - messages: RequestMessage[]; - model: string; - stream?: boolean; - temperature: number; - presence_penalty?: number; - frequency_penalty?: number; - top_p?: number; - baseUrl?: string; - apiKey?: string; - maxIterations: number; - returnIntermediateSteps: boolean; - useTools: (undefined | string)[]; -} - -class ResponseBody { - isSuccess: boolean = true; - message!: string; - isToolMessage: boolean = false; - toolName?: string; -} - -interface ToolInput { - input: string; -} - async function handle(req: NextRequest) { if (req.method === "OPTIONS") { return NextResponse.json({ body: "OK" }, { status: 200 }); @@ -70,14 +19,17 @@ async function handle(req: NextRequest) { }); } + const serverConfig = getServerSideConfig(); + const encoder = new TextEncoder(); const transformStream = new TransformStream(); const writer = transformStream.writable.getWriter(); + const reqBody: RequestBody = await req.json(); const authToken = req.headers.get("Authorization") ?? ""; const token = authToken.trim().replaceAll("Bearer ", "").trim(); const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); - let useTools = reqBody.useTools ?? []; + let apiKey = serverConfig.apiKey; if (isOpenAiKey && token) { apiKey = token; @@ -95,122 +47,6 @@ async function handle(req: NextRequest) { baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; console.log("[baseUrl]", baseUrl); - const handler = BaseCallbackHandler.fromMethods({ - async handleLLMNewToken(token: string) { - // console.log("[Token]", token); - if (token) { - var response = new ResponseBody(); - response.message = token; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - } - }, - async handleChainError(err, runId, parentRunId, tags) { - console.log("[handleChainError]", err, "writer error"); - var response = new ResponseBody(); - response.isSuccess = false; - response.message = err; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - await writer.close(); - }, - async handleChainEnd(outputs, runId, parentRunId, tags) { - console.log("[handleChainEnd]"); - await writer.ready; - await writer.close(); - }, - async handleLLMEnd() { - // await writer.ready; - // await writer.close(); - }, - async handleLLMError(e: Error) { - console.log("[handleLLMError]", e, "writer error"); - var response = new ResponseBody(); - response.isSuccess = false; - response.message = e.message; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - await writer.close(); - }, - handleLLMStart(llm, _prompts: string[]) { - // console.log("handleLLMStart: I'm the second handler!!", { llm }); - }, - handleChainStart(chain) { - // console.log("handleChainStart: I'm the second handler!!", { chain }); - }, - async handleAgentAction(action) { - try { - console.log("[handleAgentAction]", action.tool); - if (!reqBody.returnIntermediateSteps) return; - var response = new ResponseBody(); - response.isToolMessage = true; - response.message = JSON.stringify(action.toolInput); - response.toolName = action.tool; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - } catch (ex) { - console.error("[handleAgentAction]", ex); - var response = new ResponseBody(); - response.isSuccess = false; - response.message = (ex as Error).message; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - await writer.close(); - } - }, - handleToolStart(tool, input) { - console.log("[handleToolStart]", { tool }); - }, - async handleToolEnd(output, runId, parentRunId, tags) { - console.log("[handleToolEnd]", { output, runId, parentRunId, tags }); - }, - handleAgentEnd(action, runId, parentRunId, tags) { - console.log("[handleAgentEnd]"); - }, - }); - - let searchTool: Tool = new DuckDuckGo(); - if (process.env.CHOOSE_SEARCH_ENGINE) { - switch (process.env.CHOOSE_SEARCH_ENGINE) { - case "google": - searchTool = new GoogleSearch(); - break; - case "baidu": - searchTool = new BaiduSearch(); - break; - } - } - if (process.env.BING_SEARCH_API_KEY) { - let bingSearchTool = new langchainTools["BingSerpAPI"]( - process.env.BING_SEARCH_API_KEY, - ); - searchTool = new DynamicTool({ - name: "bing_search", - description: bingSearchTool.description, - func: async (input: string) => bingSearchTool.call(input), - }); - } - if (process.env.SERPAPI_API_KEY) { - let serpAPITool = new langchainTools["SerpAPI"]( - process.env.SERPAPI_API_KEY, - ); - searchTool = new DynamicTool({ - name: "google_search", - description: serpAPITool.description, - func: async (input: string) => serpAPITool.call(input), - }); - } - const model = new OpenAI( { temperature: 0, @@ -226,97 +62,25 @@ async function handle(req: NextRequest) { { basePath: baseUrl }, ); - const tools = [ - // new RequestsGetTool(), - // new RequestsPostTool(), - ]; - const webBrowserTool = new WebBrowser({ model, embeddings }); - const calculatorTool = new Calculator(); - const dallEAPITool = new DallEAPIWrapper( + var dalleCallback = async (data: string) => { + var response = new ResponseBody(); + response.message = data; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + }; + + var edgeTool = new EdgeTool( apiKey, baseUrl, - async (data: string) => { - var response = new ResponseBody(); - response.message = data; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - }, + model, + embeddings, + dalleCallback, ); - dallEAPITool.returnDirect = true; - const stableDiffusionTool = new StableDiffusionWrapper(); - const arxivAPITool = new ArxivAPIWrapper(); - if (useTools.includes("web-search")) tools.push(searchTool); - if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool); - if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool); - if (useTools.includes(dallEAPITool.name)) tools.push(dallEAPITool); - if (useTools.includes(stableDiffusionTool.name)) - tools.push(stableDiffusionTool); - if (useTools.includes(arxivAPITool.name)) tools.push(arxivAPITool); - - useTools.forEach((toolName) => { - if (toolName) { - var tool = langchainTools[ - toolName as keyof typeof langchainTools - ] as any; - if (tool) { - tools.push(new tool()); - } - } - }); - - const pastMessages = new Array(); - - reqBody.messages - .slice(0, reqBody.messages.length - 1) - .forEach((message) => { - if (message.role === "system") - pastMessages.push(new SystemMessage(message.content)); - if (message.role === "user") - pastMessages.push(new HumanMessage(message.content)); - if (message.role === "assistant") - pastMessages.push(new AIMessage(message.content)); - }); - - const memory = new BufferMemory({ - memoryKey: "chat_history", - returnMessages: true, - inputKey: "input", - outputKey: "output", - chatHistory: new ChatMessageHistory(pastMessages), - }); - - const llm = new ChatOpenAI( - { - modelName: reqBody.model, - openAIApiKey: apiKey, - temperature: reqBody.temperature, - streaming: reqBody.stream, - topP: reqBody.top_p, - presencePenalty: reqBody.presence_penalty, - frequencyPenalty: reqBody.frequency_penalty, - }, - { basePath: baseUrl }, - ); - const executor = await initializeAgentExecutorWithOptions(tools, llm, { - agentType: "openai-functions", - returnIntermediateSteps: reqBody.returnIntermediateSteps, - maxIterations: reqBody.maxIterations, - memory: memory, - }); - - executor.call( - { - input: reqBody.messages.slice(-1)[0].content, - }, - [handler], - ); - - console.log("returning response"); - return new Response(transformStream.readable, { - headers: { "Content-Type": "text/event-stream" }, - }); + var tools = await edgeTool.getCustomTools(); + var agentApi = new AgentApi(encoder, transformStream, writer); + return await agentApi.getApiHandler(req, reqBody, tools); } catch (e) { return new Response(JSON.stringify({ error: (e as any).message }), { status: 500, diff --git a/app/api/langchain/tool/agent/nodejs/route.ts b/app/api/langchain/tool/agent/nodejs/route.ts index 05ec45733..debbd6c21 100644 --- a/app/api/langchain/tool/agent/nodejs/route.ts +++ b/app/api/langchain/tool/agent/nodejs/route.ts @@ -1,63 +1,12 @@ import { NextRequest, NextResponse } from "next/server"; -import { getServerSideConfig } from "@/app/config/server"; -import { auth } from "../../../../auth"; - -import { ChatOpenAI } from "langchain/chat_models/openai"; -import { BaseCallbackHandler } from "langchain/callbacks"; - -import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema"; -import { BufferMemory, ChatMessageHistory } from "langchain/memory"; -import { initializeAgentExecutorWithOptions } from "langchain/agents"; +import { AgentApi, RequestBody, ResponseBody } from "../agentapi"; +import { auth } from "@/app/api/auth"; +import { EdgeTool } from "../../../../langchain-tools/edge_tools"; import { ACCESS_CODE_PREFIX } from "@/app/constant"; +import { getServerSideConfig } from "@/app/config/server"; import { OpenAI } from "langchain/llms/openai"; import { OpenAIEmbeddings } from "langchain/embeddings/openai"; - -import * as langchainTools from "langchain/tools"; -import { HttpGetTool } from "@/app/api/langchain-tools/http_get"; -import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search"; -import { WebBrowser } from "langchain/tools/webbrowser"; -import { Calculator } from "langchain/tools/calculator"; -import { DynamicTool, Tool } from "langchain/tools"; -import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator"; -import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search"; -import { GoogleSearch } from "@/app/api/langchain-tools/google_search"; -import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator"; -import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv"; -import dynamic from "next/dynamic"; -import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser"; - -const serverConfig = getServerSideConfig(); - -interface RequestMessage { - role: string; - content: string; -} - -interface RequestBody { - messages: RequestMessage[]; - model: string; - stream?: boolean; - temperature: number; - presence_penalty?: number; - frequency_penalty?: number; - top_p?: number; - baseUrl?: string; - apiKey?: string; - maxIterations: number; - returnIntermediateSteps: boolean; - useTools: (undefined | string)[]; -} - -class ResponseBody { - isSuccess: boolean = true; - message!: string; - isToolMessage: boolean = false; - toolName?: string; -} - -interface ToolInput { - input: string; -} +import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools"; async function handle(req: NextRequest) { if (req.method === "OPTIONS") { @@ -71,14 +20,17 @@ async function handle(req: NextRequest) { }); } + const serverConfig = getServerSideConfig(); + const encoder = new TextEncoder(); const transformStream = new TransformStream(); const writer = transformStream.writable.getWriter(); + const reqBody: RequestBody = await req.json(); const authToken = req.headers.get("Authorization") ?? ""; const token = authToken.trim().replaceAll("Bearer ", "").trim(); const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX); - let useTools = reqBody.useTools ?? []; + let apiKey = serverConfig.apiKey; if (isOpenAiKey && token) { apiKey = token; @@ -96,122 +48,6 @@ async function handle(req: NextRequest) { baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`; console.log("[baseUrl]", baseUrl); - const handler = BaseCallbackHandler.fromMethods({ - async handleLLMNewToken(token: string) { - // console.log("[Token]", token); - if (token) { - var response = new ResponseBody(); - response.message = token; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - } - }, - async handleChainError(err, runId, parentRunId, tags) { - console.log("[handleChainError]", err, "writer error"); - var response = new ResponseBody(); - response.isSuccess = false; - response.message = err; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - await writer.close(); - }, - async handleChainEnd(outputs, runId, parentRunId, tags) { - console.log("[handleChainEnd]"); - await writer.ready; - await writer.close(); - }, - async handleLLMEnd() { - // await writer.ready; - // await writer.close(); - }, - async handleLLMError(e: Error) { - console.log("[handleLLMError]", e, "writer error"); - var response = new ResponseBody(); - response.isSuccess = false; - response.message = e.message; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - await writer.close(); - }, - handleLLMStart(llm, _prompts: string[]) { - // console.log("handleLLMStart: I'm the second handler!!", { llm }); - }, - handleChainStart(chain) { - // console.log("handleChainStart: I'm the second handler!!", { chain }); - }, - async handleAgentAction(action) { - try { - console.log("[handleAgentAction]", action.tool); - if (!reqBody.returnIntermediateSteps) return; - var response = new ResponseBody(); - response.isToolMessage = true; - response.message = JSON.stringify(action.toolInput); - response.toolName = action.tool; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - } catch (ex) { - console.error("[handleAgentAction]", ex); - var response = new ResponseBody(); - response.isSuccess = false; - response.message = (ex as Error).message; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - await writer.close(); - } - }, - handleToolStart(tool, input) { - console.log("[handleToolStart]", { tool }); - }, - async handleToolEnd(output, runId, parentRunId, tags) { - console.log("[handleToolEnd]", { output, runId, parentRunId, tags }); - }, - handleAgentEnd(action, runId, parentRunId, tags) { - console.log("[handleAgentEnd]"); - }, - }); - - let searchTool: Tool = new DuckDuckGo(); - if (process.env.CHOOSE_SEARCH_ENGINE) { - switch (process.env.CHOOSE_SEARCH_ENGINE) { - case "google": - searchTool = new GoogleSearch(); - break; - case "baidu": - searchTool = new BaiduSearch(); - break; - } - } - if (process.env.BING_SEARCH_API_KEY) { - let bingSearchTool = new langchainTools["BingSerpAPI"]( - process.env.BING_SEARCH_API_KEY, - ); - searchTool = new DynamicTool({ - name: "bing_search", - description: bingSearchTool.description, - func: async (input: string) => bingSearchTool.call(input), - }); - } - if (process.env.SERPAPI_API_KEY) { - let serpAPITool = new langchainTools["SerpAPI"]( - process.env.SERPAPI_API_KEY, - ); - searchTool = new DynamicTool({ - name: "google_search", - description: serpAPITool.description, - func: async (input: string) => serpAPITool.call(input), - }); - } - const model = new OpenAI( { temperature: 0, @@ -227,99 +63,34 @@ async function handle(req: NextRequest) { { basePath: baseUrl }, ); - const tools = [ - // new RequestsGetTool(), - // new RequestsPostTool(), - ]; - const webBrowserTool = new WebBrowser({ model, embeddings }); - const calculatorTool = new Calculator(); - const dallEAPITool = new DallEAPIWrapper( + var dalleCallback = async (data: string) => { + var response = new ResponseBody(); + response.message = data; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + }; + + var edgeTool = new EdgeTool( apiKey, baseUrl, - async (data: string) => { - var response = new ResponseBody(); - response.message = data; - await writer.ready; - await writer.write( - encoder.encode(`data: ${JSON.stringify(response)}\n\n`), - ); - }, + model, + embeddings, + dalleCallback, ); - dallEAPITool.returnDirect = true; - const stableDiffusionTool = new StableDiffusionWrapper(); - const arxivAPITool = new ArxivAPIWrapper(); - const pdfBrowserTool = new PDFBrowser(model, embeddings); - if (useTools.includes("web-search")) tools.push(searchTool); - if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool); - if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool); - if (useTools.includes(dallEAPITool.name)) tools.push(dallEAPITool); - if (useTools.includes(stableDiffusionTool.name)) - tools.push(stableDiffusionTool); - if (useTools.includes(arxivAPITool.name)) tools.push(arxivAPITool); - if (useTools.includes(pdfBrowserTool.name)) tools.push(pdfBrowserTool); - - useTools.forEach((toolName) => { - if (toolName) { - var tool = langchainTools[ - toolName as keyof typeof langchainTools - ] as any; - if (tool) { - tools.push(new tool()); - } - } - }); - - const pastMessages = new Array(); - - reqBody.messages - .slice(0, reqBody.messages.length - 1) - .forEach((message) => { - if (message.role === "system") - pastMessages.push(new SystemMessage(message.content)); - if (message.role === "user") - pastMessages.push(new HumanMessage(message.content)); - if (message.role === "assistant") - pastMessages.push(new AIMessage(message.content)); - }); - - const memory = new BufferMemory({ - memoryKey: "chat_history", - returnMessages: true, - inputKey: "input", - outputKey: "output", - chatHistory: new ChatMessageHistory(pastMessages), - }); - - const llm = new ChatOpenAI( - { - modelName: reqBody.model, - openAIApiKey: apiKey, - temperature: reqBody.temperature, - streaming: reqBody.stream, - topP: reqBody.top_p, - presencePenalty: reqBody.presence_penalty, - frequencyPenalty: reqBody.frequency_penalty, - }, - { basePath: baseUrl }, + var nodejsTool = new NodeJSTool( + apiKey, + baseUrl, + model, + embeddings, + dalleCallback, ); - const executor = await initializeAgentExecutorWithOptions(tools, llm, { - agentType: "openai-functions", - returnIntermediateSteps: reqBody.returnIntermediateSteps, - maxIterations: reqBody.maxIterations, - memory: memory, - }); - - executor.call( - { - input: reqBody.messages.slice(-1)[0].content, - }, - [handler], - ); - - console.log("returning response"); - return new Response(transformStream.readable, { - headers: { "Content-Type": "text/event-stream" }, - }); + var edgeTools = await edgeTool.getCustomTools(); + var nodejsTools = await nodejsTool.getCustomTools(); + edgeTools.push(nodejsTools); + var agentApi = new AgentApi(encoder, transformStream, writer); + return await agentApi.getApiHandler(req, reqBody, nodejsTools); } catch (e) { return new Response(JSON.stringify({ error: (e as any).message }), { status: 500,