refactor: refactor agent api route

This commit is contained in:
Hk-Gosuto 2023-11-28 18:59:00 +08:00
parent 35f5dc52e7
commit 293679fa64
5 changed files with 502 additions and 522 deletions

View File

@ -0,0 +1,94 @@
import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv";
import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator";
import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator";
import { BaseLanguageModel } from "langchain/dist/base_language";
import { Calculator } from "langchain/tools/calculator";
import { WebBrowser } from "langchain/tools/webbrowser";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { Tool, DynamicTool } from "langchain/tools";
import * as langchainTools from "langchain/tools";
import { Embeddings } from "langchain/dist/embeddings/base.js";
import { promises } from "dns";
/**
 * Factory for the custom langchain tools that are available on the Edge
 * runtime. A route constructs one of these and feeds the result of
 * getCustomTools() into the shared agent handler.
 */
export class EdgeTool {
  // Parameter properties keep the constructor signature identical while
  // storing every argument on the instance.
  constructor(
    private apiKey: string | undefined,
    private baseUrl: string,
    private model: BaseLanguageModel,
    private embeddings: Embeddings,
    private callback?: (data: string) => Promise<void>,
  ) {}

  /**
   * Builds the edge-compatible custom tools in a fixed order:
   * calculator, web browser, DALL·E, Stable Diffusion, arXiv.
   * (Search-engine tool selection is handled by the shared agent API, not here.)
   */
  async getCustomTools(): Promise<any[]> {
    const imageTool = new DallEAPIWrapper(
      this.apiKey,
      this.baseUrl,
      this.callback,
    );
    // Image output is returned directly to the user instead of being fed
    // back into another agent step.
    imageTool.returnDirect = true;

    return [
      new Calculator(),
      new WebBrowser({ model: this.model, embeddings: this.embeddings }),
      imageTool,
      new StableDiffusionWrapper(),
      new ArxivAPIWrapper(),
    ];
  }
}

View File

@ -0,0 +1,35 @@
import { BaseLanguageModel } from "langchain/dist/base_language";
import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser";
import { Embeddings } from "langchain/dist/embeddings/base.js";
/**
 * Factory for the custom langchain tools that require the Node.js runtime.
 */
export class NodeJSTool {
  // apiKey, baseUrl and callback are stored but currently unused by the
  // only tool built here; they mirror the EdgeTool constructor shape.
  constructor(
    private apiKey: string | undefined,
    private baseUrl: string,
    private model: BaseLanguageModel,
    private embeddings: Embeddings,
    private callback?: (data: string) => Promise<void>,
  ) {}

  /** Builds the Node-only custom tools (currently just the PDF browser). */
  async getCustomTools(): Promise<any[]> {
    return [new PDFBrowser(this.model, this.embeddings)];
  }
}

View File

@ -0,0 +1,316 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "@/app/config/server";
import { auth } from "../../../auth";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { BaseCallbackHandler } from "langchain/callbacks";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { initializeAgentExecutorWithOptions } from "langchain/agents";
import { ACCESS_CODE_PREFIX } from "@/app/constant";
import { OpenAI } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import * as langchainTools from "langchain/tools";
import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { WebBrowser } from "langchain/tools/webbrowser";
import { Calculator } from "langchain/tools/calculator";
import { DynamicTool, Tool } from "langchain/tools";
import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator";
import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv";
import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser";
/** A single chat message as submitted by the client. */
export interface RequestMessage {
  // Chat role; the handler maps "system", "user" and "assistant" to
  // langchain message types (anything else is silently dropped from history).
  role: string;
  // Plain-text message content.
  content: string;
}
/** JSON payload accepted by the agent chat API routes. */
export interface RequestBody {
  // Full conversation; the last entry is the new input, the rest become
  // chat history.
  messages: RequestMessage[];
  // OpenAI model name passed through to ChatOpenAI.
  model: string;
  stream?: boolean;
  temperature: number;
  presence_penalty?: number;
  frequency_penalty?: number;
  top_p?: number;
  // Optional endpoint override; only honored when it starts with
  // http:// or https://.
  baseUrl?: string;
  // NOTE(review): appears unused by getApiHandler — the key is taken from
  // the Authorization header instead; confirm before relying on it.
  apiKey?: string;
  // Upper bound on agent reasoning steps.
  maxIterations: number;
  // When true, tool invocations are streamed back as intermediate steps.
  returnIntermediateSteps: boolean;
  // Names of the tools the agent is allowed to use.
  useTools: (undefined | string)[];
}
/** Shape of each SSE `data:` payload streamed back to the client. */
export class ResponseBody {
  isSuccess: boolean = true;
  message!: string;
  // True when this chunk reports a tool invocation rather than LLM text.
  isToolMessage: boolean = false;
  // Name of the invoked tool; set alongside isToolMessage.
  toolName?: string;
}
/** Generic single-string input wrapper for a tool call. */
export interface ToolInput {
  input: string;
}
/**
 * Shared implementation of the agent chat API route, used by both the edge
 * and nodejs runtime routes. Agent output is streamed to the client as
 * server-sent events (`data: <ResponseBody JSON>\n\n`) written through the
 * TransformStream supplied by the caller.
 */
export class AgentApi {
  private encoder: TextEncoder;
  private transformStream: TransformStream;
  private writer: WritableStreamDefaultWriter<any>;

  constructor(
    encoder: TextEncoder,
    transformStream: TransformStream,
    writer: WritableStreamDefaultWriter<any>,
  ) {
    this.encoder = encoder;
    this.transformStream = transformStream;
    this.writer = writer;
  }

  /**
   * Creates the langchain callback handler that forwards agent activity
   * (LLM tokens, tool invocations, errors) onto the SSE stream.
   *
   * NOTE(review): the writer is closed from handleChainEnd, handleChainError
   * and handleLLMError; if more than one fires for a single run the later
   * close() will reject — confirm langchain only invokes one of them per run.
   */
  async getHandler(reqBody: any) {
    // Captured in locals because `this` is not bound inside fromMethods.
    var writer = this.writer;
    var encoder = this.encoder;
    return BaseCallbackHandler.fromMethods({
      async handleLLMNewToken(token: string) {
        // Stream each generated token to the client as it arrives.
        if (token) {
          var response = new ResponseBody();
          response.message = token;
          await writer.ready;
          await writer.write(
            encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
          );
        }
      },
      async handleChainError(err, runId, parentRunId, tags) {
        console.log("[handleChainError]", err, "writer error");
        var response = new ResponseBody();
        response.isSuccess = false;
        // FIX: `err` may be an Error instance, which JSON.stringify renders
        // as "{}"; send a readable string so the client sees the message.
        response.message = err instanceof Error ? err.message : String(err);
        await writer.ready;
        await writer.write(
          encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
        );
        await writer.close();
      },
      async handleChainEnd(outputs, runId, parentRunId, tags) {
        // The chain finished: end the SSE stream.
        console.log("[handleChainEnd]");
        await writer.ready;
        await writer.close();
      },
      async handleLLMEnd() {
        // Intentionally empty: the stream is closed in handleChainEnd,
        // since a single chain run may involve several LLM calls.
        // await writer.ready;
        // await writer.close();
      },
      async handleLLMError(e: Error) {
        console.log("[handleLLMError]", e, "writer error");
        var response = new ResponseBody();
        response.isSuccess = false;
        response.message = e.message;
        await writer.ready;
        await writer.write(
          encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
        );
        await writer.close();
      },
      handleLLMStart(llm, _prompts: string[]) {
        // console.log("handleLLMStart: I'm the second handler!!", { llm });
      },
      handleChainStart(chain) {
        // console.log("handleChainStart: I'm the second handler!!", { chain });
      },
      async handleAgentAction(action) {
        // Report which tool the agent chose, but only when the client asked
        // for intermediate steps.
        try {
          console.log("[handleAgentAction]", action.tool);
          if (!reqBody.returnIntermediateSteps) return;
          var response = new ResponseBody();
          response.isToolMessage = true;
          response.message = JSON.stringify(action.toolInput);
          response.toolName = action.tool;
          await writer.ready;
          await writer.write(
            encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
          );
        } catch (ex) {
          console.error("[handleAgentAction]", ex);
          var response = new ResponseBody();
          response.isSuccess = false;
          response.message = (ex as Error).message;
          await writer.ready;
          await writer.write(
            encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
          );
          await writer.close();
        }
      },
      handleToolStart(tool, input) {
        console.log("[handleToolStart]", { tool });
      },
      async handleToolEnd(output, runId, parentRunId, tags) {
        console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
      },
      handleAgentEnd(action, runId, parentRunId, tags) {
        console.log("[handleAgentEnd]");
      },
    });
  }

  /**
   * Runs the agent once for this request and returns an SSE Response backed
   * by the transform stream's readable side.
   *
   * @param req         incoming request (used for auth and OPTIONS handling)
   * @param reqBody     parsed request payload
   * @param customTools runtime-specific tools supplied by the calling route
   * @returns a streaming text/event-stream Response, a 401 JSON response on
   *          auth failure, or a 500 JSON response on setup errors
   */
  async getApiHandler(
    req: NextRequest,
    reqBody: RequestBody,
    customTools: any[],
  ) {
    if (req.method === "OPTIONS") {
      return NextResponse.json({ body: "OK" }, { status: 200 });
    }
    try {
      const authResult = auth(req);
      if (authResult.error) {
        return NextResponse.json(authResult, {
          status: 401,
        });
      }

      const serverConfig = getServerSideConfig();

      // A bearer token that is not an access code is treated as the caller's
      // own OpenAI key and overrides the server-side key.
      const authToken = req.headers.get("Authorization") ?? "";
      const token = authToken.trim().replaceAll("Bearer ", "").trim();
      const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);

      let useTools = reqBody.useTools ?? [];
      let apiKey = serverConfig.apiKey;
      if (isOpenAiKey && token) {
        apiKey = token;
      }

      // Resolve the API base URL: server config first, then a client
      // override (only when it is an absolute http(s) URL), then make sure
      // the "/v1" suffix is present.
      let baseUrl = "https://api.openai.com/v1";
      if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
      if (
        reqBody.baseUrl?.startsWith("http://") ||
        reqBody.baseUrl?.startsWith("https://")
      )
        baseUrl = reqBody.baseUrl;
      if (!baseUrl.endsWith("/v1"))
        baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
      console.log("[baseUrl]", baseUrl);

      var handler = await this.getHandler(reqBody);

      // Pick the web-search tool: DuckDuckGo by default, overridable via
      // CHOOSE_SEARCH_ENGINE, BING_SEARCH_API_KEY or SERPAPI_API_KEY
      // (later checks win).
      let searchTool: Tool = new DuckDuckGo();
      if (process.env.CHOOSE_SEARCH_ENGINE) {
        switch (process.env.CHOOSE_SEARCH_ENGINE) {
          case "google":
            searchTool = new GoogleSearch();
            break;
          case "baidu":
            searchTool = new BaiduSearch();
            break;
        }
      }
      if (process.env.BING_SEARCH_API_KEY) {
        let bingSearchTool = new langchainTools["BingSerpAPI"](
          process.env.BING_SEARCH_API_KEY,
        );
        searchTool = new DynamicTool({
          name: "bing_search",
          description: bingSearchTool.description,
          func: async (input: string) => bingSearchTool.call(input),
        });
      }
      if (process.env.SERPAPI_API_KEY) {
        let serpAPITool = new langchainTools["SerpAPI"](
          process.env.SERPAPI_API_KEY,
        );
        searchTool = new DynamicTool({
          name: "google_search",
          description: serpAPITool.description,
          func: async (input: string) => serpAPITool.call(input),
        });
      }

      // Assemble the tool set the client opted into.
      const tools: any[] = [];
      if (useTools.includes("web-search")) tools.push(searchTool);
      customTools.forEach((customTool) => {
        if (customTool) {
          if (useTools.includes(customTool.name)) {
            tools.push(customTool);
          }
        }
      });
      // Remaining requested names are looked up among langchain's built-in
      // tool exports and instantiated with no arguments.
      useTools.forEach((toolName) => {
        if (toolName) {
          var tool = langchainTools[
            toolName as keyof typeof langchainTools
          ] as any;
          if (tool) {
            tools.push(new tool());
          }
        }
      });

      // All but the last message become chat history; the last one is the
      // new input. Unrecognized roles are dropped.
      const pastMessages = new Array();
      reqBody.messages
        .slice(0, reqBody.messages.length - 1)
        .forEach((message) => {
          if (message.role === "system")
            pastMessages.push(new SystemMessage(message.content));
          if (message.role === "user")
            pastMessages.push(new HumanMessage(message.content));
          if (message.role === "assistant")
            pastMessages.push(new AIMessage(message.content));
        });

      const memory = new BufferMemory({
        memoryKey: "chat_history",
        returnMessages: true,
        inputKey: "input",
        outputKey: "output",
        chatHistory: new ChatMessageHistory(pastMessages),
      });

      const llm = new ChatOpenAI(
        {
          modelName: reqBody.model,
          openAIApiKey: apiKey,
          temperature: reqBody.temperature,
          streaming: reqBody.stream,
          topP: reqBody.top_p,
          presencePenalty: reqBody.presence_penalty,
          frequencyPenalty: reqBody.frequency_penalty,
        },
        { basePath: baseUrl },
      );

      const executor = await initializeAgentExecutorWithOptions(tools, llm, {
        agentType: "openai-functions",
        returnIntermediateSteps: reqBody.returnIntermediateSteps,
        maxIterations: reqBody.maxIterations,
        memory: memory,
      });

      // Intentionally not awaited: the agent keeps writing to the stream
      // while the readable side is returned to the client immediately.
      // FIX: attach a catch so a rejection not routed through the callback
      // handler cannot surface as an unhandled promise rejection.
      executor
        .call(
          {
            input: reqBody.messages.slice(-1)[0].content,
          },
          [handler],
        )
        .catch((e) => console.error("[executor]", e));

      console.log("returning response");
      return new Response(this.transformStream.readable, {
        headers: { "Content-Type": "text/event-stream" },
      });
    } catch (e) {
      return new Response(JSON.stringify({ error: (e as any).message }), {
        status: 500,
        headers: { "Content-Type": "application/json" },
      });
    }
  }
}

View File

@ -1,63 +1,12 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "@/app/config/server";
import { auth } from "../../../../auth";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { BaseCallbackHandler } from "langchain/callbacks";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { initializeAgentExecutorWithOptions } from "langchain/agents";
import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth";
import { EdgeTool } from "../../../../langchain-tools/edge_tools";
import { ACCESS_CODE_PREFIX } from "@/app/constant";
import { getServerSideConfig } from "@/app/config/server";
import { OpenAI } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import * as langchainTools from "langchain/tools";
import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { WebBrowser } from "langchain/tools/webbrowser";
import { Calculator } from "langchain/tools/calculator";
import { DynamicTool, Tool } from "langchain/tools";
import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator";
import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv";
import dynamic from "next/dynamic";
const serverConfig = getServerSideConfig();
interface RequestMessage {
role: string;
content: string;
}
interface RequestBody {
messages: RequestMessage[];
model: string;
stream?: boolean;
temperature: number;
presence_penalty?: number;
frequency_penalty?: number;
top_p?: number;
baseUrl?: string;
apiKey?: string;
maxIterations: number;
returnIntermediateSteps: boolean;
useTools: (undefined | string)[];
}
class ResponseBody {
isSuccess: boolean = true;
message!: string;
isToolMessage: boolean = false;
toolName?: string;
}
interface ToolInput {
input: string;
}
async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
@ -70,14 +19,17 @@ async function handle(req: NextRequest) {
});
}
const serverConfig = getServerSideConfig();
const encoder = new TextEncoder();
const transformStream = new TransformStream();
const writer = transformStream.writable.getWriter();
const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let useTools = reqBody.useTools ?? [];
let apiKey = serverConfig.apiKey;
if (isOpenAiKey && token) {
apiKey = token;
@ -95,122 +47,6 @@ async function handle(req: NextRequest) {
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
console.log("[baseUrl]", baseUrl);
const handler = BaseCallbackHandler.fromMethods({
async handleLLMNewToken(token: string) {
// console.log("[Token]", token);
if (token) {
var response = new ResponseBody();
response.message = token;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
}
},
async handleChainError(err, runId, parentRunId, tags) {
console.log("[handleChainError]", err, "writer error");
var response = new ResponseBody();
response.isSuccess = false;
response.message = err;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
await writer.close();
},
async handleChainEnd(outputs, runId, parentRunId, tags) {
console.log("[handleChainEnd]");
await writer.ready;
await writer.close();
},
async handleLLMEnd() {
// await writer.ready;
// await writer.close();
},
async handleLLMError(e: Error) {
console.log("[handleLLMError]", e, "writer error");
var response = new ResponseBody();
response.isSuccess = false;
response.message = e.message;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
await writer.close();
},
handleLLMStart(llm, _prompts: string[]) {
// console.log("handleLLMStart: I'm the second handler!!", { llm });
},
handleChainStart(chain) {
// console.log("handleChainStart: I'm the second handler!!", { chain });
},
async handleAgentAction(action) {
try {
console.log("[handleAgentAction]", action.tool);
if (!reqBody.returnIntermediateSteps) return;
var response = new ResponseBody();
response.isToolMessage = true;
response.message = JSON.stringify(action.toolInput);
response.toolName = action.tool;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
} catch (ex) {
console.error("[handleAgentAction]", ex);
var response = new ResponseBody();
response.isSuccess = false;
response.message = (ex as Error).message;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
await writer.close();
}
},
handleToolStart(tool, input) {
console.log("[handleToolStart]", { tool });
},
async handleToolEnd(output, runId, parentRunId, tags) {
console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
},
handleAgentEnd(action, runId, parentRunId, tags) {
console.log("[handleAgentEnd]");
},
});
let searchTool: Tool = new DuckDuckGo();
if (process.env.CHOOSE_SEARCH_ENGINE) {
switch (process.env.CHOOSE_SEARCH_ENGINE) {
case "google":
searchTool = new GoogleSearch();
break;
case "baidu":
searchTool = new BaiduSearch();
break;
}
}
if (process.env.BING_SEARCH_API_KEY) {
let bingSearchTool = new langchainTools["BingSerpAPI"](
process.env.BING_SEARCH_API_KEY,
);
searchTool = new DynamicTool({
name: "bing_search",
description: bingSearchTool.description,
func: async (input: string) => bingSearchTool.call(input),
});
}
if (process.env.SERPAPI_API_KEY) {
let serpAPITool = new langchainTools["SerpAPI"](
process.env.SERPAPI_API_KEY,
);
searchTool = new DynamicTool({
name: "google_search",
description: serpAPITool.description,
func: async (input: string) => serpAPITool.call(input),
});
}
const model = new OpenAI(
{
temperature: 0,
@ -226,97 +62,25 @@ async function handle(req: NextRequest) {
{ basePath: baseUrl },
);
const tools = [
// new RequestsGetTool(),
// new RequestsPostTool(),
];
const webBrowserTool = new WebBrowser({ model, embeddings });
const calculatorTool = new Calculator();
const dallEAPITool = new DallEAPIWrapper(
var dalleCallback = async (data: string) => {
var response = new ResponseBody();
response.message = data;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
};
var edgeTool = new EdgeTool(
apiKey,
baseUrl,
async (data: string) => {
var response = new ResponseBody();
response.message = data;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
},
model,
embeddings,
dalleCallback,
);
dallEAPITool.returnDirect = true;
const stableDiffusionTool = new StableDiffusionWrapper();
const arxivAPITool = new ArxivAPIWrapper();
if (useTools.includes("web-search")) tools.push(searchTool);
if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool);
if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool);
if (useTools.includes(dallEAPITool.name)) tools.push(dallEAPITool);
if (useTools.includes(stableDiffusionTool.name))
tools.push(stableDiffusionTool);
if (useTools.includes(arxivAPITool.name)) tools.push(arxivAPITool);
useTools.forEach((toolName) => {
if (toolName) {
var tool = langchainTools[
toolName as keyof typeof langchainTools
] as any;
if (tool) {
tools.push(new tool());
}
}
});
const pastMessages = new Array();
reqBody.messages
.slice(0, reqBody.messages.length - 1)
.forEach((message) => {
if (message.role === "system")
pastMessages.push(new SystemMessage(message.content));
if (message.role === "user")
pastMessages.push(new HumanMessage(message.content));
if (message.role === "assistant")
pastMessages.push(new AIMessage(message.content));
});
const memory = new BufferMemory({
memoryKey: "chat_history",
returnMessages: true,
inputKey: "input",
outputKey: "output",
chatHistory: new ChatMessageHistory(pastMessages),
});
const llm = new ChatOpenAI(
{
modelName: reqBody.model,
openAIApiKey: apiKey,
temperature: reqBody.temperature,
streaming: reqBody.stream,
topP: reqBody.top_p,
presencePenalty: reqBody.presence_penalty,
frequencyPenalty: reqBody.frequency_penalty,
},
{ basePath: baseUrl },
);
const executor = await initializeAgentExecutorWithOptions(tools, llm, {
agentType: "openai-functions",
returnIntermediateSteps: reqBody.returnIntermediateSteps,
maxIterations: reqBody.maxIterations,
memory: memory,
});
executor.call(
{
input: reqBody.messages.slice(-1)[0].content,
},
[handler],
);
console.log("returning response");
return new Response(transformStream.readable, {
headers: { "Content-Type": "text/event-stream" },
});
var tools = await edgeTool.getCustomTools();
var agentApi = new AgentApi(encoder, transformStream, writer);
return await agentApi.getApiHandler(req, reqBody, tools);
} catch (e) {
return new Response(JSON.stringify({ error: (e as any).message }), {
status: 500,

View File

@ -1,63 +1,12 @@
import { NextRequest, NextResponse } from "next/server";
import { getServerSideConfig } from "@/app/config/server";
import { auth } from "../../../../auth";
import { ChatOpenAI } from "langchain/chat_models/openai";
import { BaseCallbackHandler } from "langchain/callbacks";
import { AIMessage, HumanMessage, SystemMessage } from "langchain/schema";
import { BufferMemory, ChatMessageHistory } from "langchain/memory";
import { initializeAgentExecutorWithOptions } from "langchain/agents";
import { AgentApi, RequestBody, ResponseBody } from "../agentapi";
import { auth } from "@/app/api/auth";
import { EdgeTool } from "../../../../langchain-tools/edge_tools";
import { ACCESS_CODE_PREFIX } from "@/app/constant";
import { getServerSideConfig } from "@/app/config/server";
import { OpenAI } from "langchain/llms/openai";
import { OpenAIEmbeddings } from "langchain/embeddings/openai";
import * as langchainTools from "langchain/tools";
import { HttpGetTool } from "@/app/api/langchain-tools/http_get";
import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { WebBrowser } from "langchain/tools/webbrowser";
import { Calculator } from "langchain/tools/calculator";
import { DynamicTool, Tool } from "langchain/tools";
import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator";
import { BaiduSearch } from "@/app/api/langchain-tools/baidu_search";
import { GoogleSearch } from "@/app/api/langchain-tools/google_search";
import { StableDiffusionWrapper } from "@/app/api/langchain-tools/stable_diffusion_image_generator";
import { ArxivAPIWrapper } from "@/app/api/langchain-tools/arxiv";
import dynamic from "next/dynamic";
import { PDFBrowser } from "@/app/api/langchain-tools/pdf_browser";
const serverConfig = getServerSideConfig();
interface RequestMessage {
role: string;
content: string;
}
interface RequestBody {
messages: RequestMessage[];
model: string;
stream?: boolean;
temperature: number;
presence_penalty?: number;
frequency_penalty?: number;
top_p?: number;
baseUrl?: string;
apiKey?: string;
maxIterations: number;
returnIntermediateSteps: boolean;
useTools: (undefined | string)[];
}
class ResponseBody {
isSuccess: boolean = true;
message!: string;
isToolMessage: boolean = false;
toolName?: string;
}
interface ToolInput {
input: string;
}
import { NodeJSTool } from "@/app/api/langchain-tools/nodejs_tools";
async function handle(req: NextRequest) {
if (req.method === "OPTIONS") {
@ -71,14 +20,17 @@ async function handle(req: NextRequest) {
});
}
const serverConfig = getServerSideConfig();
const encoder = new TextEncoder();
const transformStream = new TransformStream();
const writer = transformStream.writable.getWriter();
const reqBody: RequestBody = await req.json();
const authToken = req.headers.get("Authorization") ?? "";
const token = authToken.trim().replaceAll("Bearer ", "").trim();
const isOpenAiKey = !token.startsWith(ACCESS_CODE_PREFIX);
let useTools = reqBody.useTools ?? [];
let apiKey = serverConfig.apiKey;
if (isOpenAiKey && token) {
apiKey = token;
@ -96,122 +48,6 @@ async function handle(req: NextRequest) {
baseUrl = baseUrl.endsWith("/") ? `${baseUrl}v1` : `${baseUrl}/v1`;
console.log("[baseUrl]", baseUrl);
const handler = BaseCallbackHandler.fromMethods({
async handleLLMNewToken(token: string) {
// console.log("[Token]", token);
if (token) {
var response = new ResponseBody();
response.message = token;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
}
},
async handleChainError(err, runId, parentRunId, tags) {
console.log("[handleChainError]", err, "writer error");
var response = new ResponseBody();
response.isSuccess = false;
response.message = err;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
await writer.close();
},
async handleChainEnd(outputs, runId, parentRunId, tags) {
console.log("[handleChainEnd]");
await writer.ready;
await writer.close();
},
async handleLLMEnd() {
// await writer.ready;
// await writer.close();
},
async handleLLMError(e: Error) {
console.log("[handleLLMError]", e, "writer error");
var response = new ResponseBody();
response.isSuccess = false;
response.message = e.message;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
await writer.close();
},
handleLLMStart(llm, _prompts: string[]) {
// console.log("handleLLMStart: I'm the second handler!!", { llm });
},
handleChainStart(chain) {
// console.log("handleChainStart: I'm the second handler!!", { chain });
},
async handleAgentAction(action) {
try {
console.log("[handleAgentAction]", action.tool);
if (!reqBody.returnIntermediateSteps) return;
var response = new ResponseBody();
response.isToolMessage = true;
response.message = JSON.stringify(action.toolInput);
response.toolName = action.tool;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
} catch (ex) {
console.error("[handleAgentAction]", ex);
var response = new ResponseBody();
response.isSuccess = false;
response.message = (ex as Error).message;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
await writer.close();
}
},
handleToolStart(tool, input) {
console.log("[handleToolStart]", { tool });
},
async handleToolEnd(output, runId, parentRunId, tags) {
console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
},
handleAgentEnd(action, runId, parentRunId, tags) {
console.log("[handleAgentEnd]");
},
});
let searchTool: Tool = new DuckDuckGo();
if (process.env.CHOOSE_SEARCH_ENGINE) {
switch (process.env.CHOOSE_SEARCH_ENGINE) {
case "google":
searchTool = new GoogleSearch();
break;
case "baidu":
searchTool = new BaiduSearch();
break;
}
}
if (process.env.BING_SEARCH_API_KEY) {
let bingSearchTool = new langchainTools["BingSerpAPI"](
process.env.BING_SEARCH_API_KEY,
);
searchTool = new DynamicTool({
name: "bing_search",
description: bingSearchTool.description,
func: async (input: string) => bingSearchTool.call(input),
});
}
if (process.env.SERPAPI_API_KEY) {
let serpAPITool = new langchainTools["SerpAPI"](
process.env.SERPAPI_API_KEY,
);
searchTool = new DynamicTool({
name: "google_search",
description: serpAPITool.description,
func: async (input: string) => serpAPITool.call(input),
});
}
const model = new OpenAI(
{
temperature: 0,
@ -227,99 +63,34 @@ async function handle(req: NextRequest) {
{ basePath: baseUrl },
);
const tools = [
// new RequestsGetTool(),
// new RequestsPostTool(),
];
const webBrowserTool = new WebBrowser({ model, embeddings });
const calculatorTool = new Calculator();
const dallEAPITool = new DallEAPIWrapper(
var dalleCallback = async (data: string) => {
var response = new ResponseBody();
response.message = data;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
};
var edgeTool = new EdgeTool(
apiKey,
baseUrl,
async (data: string) => {
var response = new ResponseBody();
response.message = data;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
},
model,
embeddings,
dalleCallback,
);
dallEAPITool.returnDirect = true;
const stableDiffusionTool = new StableDiffusionWrapper();
const arxivAPITool = new ArxivAPIWrapper();
const pdfBrowserTool = new PDFBrowser(model, embeddings);
if (useTools.includes("web-search")) tools.push(searchTool);
if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool);
if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool);
if (useTools.includes(dallEAPITool.name)) tools.push(dallEAPITool);
if (useTools.includes(stableDiffusionTool.name))
tools.push(stableDiffusionTool);
if (useTools.includes(arxivAPITool.name)) tools.push(arxivAPITool);
if (useTools.includes(pdfBrowserTool.name)) tools.push(pdfBrowserTool);
useTools.forEach((toolName) => {
if (toolName) {
var tool = langchainTools[
toolName as keyof typeof langchainTools
] as any;
if (tool) {
tools.push(new tool());
}
}
});
const pastMessages = new Array();
reqBody.messages
.slice(0, reqBody.messages.length - 1)
.forEach((message) => {
if (message.role === "system")
pastMessages.push(new SystemMessage(message.content));
if (message.role === "user")
pastMessages.push(new HumanMessage(message.content));
if (message.role === "assistant")
pastMessages.push(new AIMessage(message.content));
});
const memory = new BufferMemory({
memoryKey: "chat_history",
returnMessages: true,
inputKey: "input",
outputKey: "output",
chatHistory: new ChatMessageHistory(pastMessages),
});
const llm = new ChatOpenAI(
{
modelName: reqBody.model,
openAIApiKey: apiKey,
temperature: reqBody.temperature,
streaming: reqBody.stream,
topP: reqBody.top_p,
presencePenalty: reqBody.presence_penalty,
frequencyPenalty: reqBody.frequency_penalty,
},
{ basePath: baseUrl },
var nodejsTool = new NodeJSTool(
apiKey,
baseUrl,
model,
embeddings,
dalleCallback,
);
const executor = await initializeAgentExecutorWithOptions(tools, llm, {
agentType: "openai-functions",
returnIntermediateSteps: reqBody.returnIntermediateSteps,
maxIterations: reqBody.maxIterations,
memory: memory,
});
executor.call(
{
input: reqBody.messages.slice(-1)[0].content,
},
[handler],
);
console.log("returning response");
return new Response(transformStream.readable, {
headers: { "Content-Type": "text/event-stream" },
});
var edgeTools = await edgeTool.getCustomTools();
var nodejsTools = await nodejsTool.getCustomTools();
edgeTools.push(nodejsTools);
var agentApi = new AgentApi(encoder, transformStream, writer);
return await agentApi.getApiHandler(req, reqBody, nodejsTools);
} catch (e) {
return new Response(JSON.stringify({ error: (e as any).message }), {
status: 500,