diff --git a/app/api/langchain-tools/dalle_image_generator.ts b/app/api/langchain-tools/dalle_image_generator.ts index 0abd5d60e..c13a801d5 100644 --- a/app/api/langchain-tools/dalle_image_generator.ts +++ b/app/api/langchain-tools/dalle_image_generator.ts @@ -1,20 +1,27 @@ -import { Tool } from "langchain/tools"; +import { StructuredTool } from "langchain/tools"; +import { z } from "zod"; import S3FileStorage from "../../utils/r2_file_storage"; -export class DallEAPIWrapper extends Tool { +export class DallEAPIWrapper extends StructuredTool { name = "dalle_image_generator"; n = 1; - size: "256x256" | "512x512" | "1024x1024" | null = "1024x1024"; apiKey: string; baseURL?: string; - constructor(apiKey?: string | undefined, baseURL?: string | undefined) { + callback?: (data: string) => Promise; + + constructor( + apiKey?: string | undefined, + baseURL?: string | undefined, + callback?: (data: string) => Promise, + ) { super(); if (!apiKey) { throw new Error("OpenAI API key not set."); } this.apiKey = apiKey; this.baseURL = baseURL; + this.callback = callback; } async saveImageFromUrl(url: string) { @@ -24,23 +31,35 @@ export class DallEAPIWrapper extends Tool { return await S3FileStorage.put(`${Date.now()}.png`, buffer); } + schema = z.object({ + prompt: z + .string() + .describe( + 'input must be a english prompt. you can set `quality: "hd"` for enhanced detail.', + ), + size: z + .enum(["1024x1024", "1024x1792", "1792x1024"]) + .default("1024x1024") + .describe("images size"), + }); + /** @ignore */ - async _call(prompt: string) { + async _call({ prompt, size }: z.infer) { let image_url; const apiUrl = `${this.baseURL}/images/generations`; - try { const requestOptions = { - method: 'POST', + method: "POST", headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${this.apiKey}` + "Content-Type": "application/json", + Authorization: `Bearer ${this.apiKey}`, }, body: JSON.stringify({ + model: "dall-e-3", prompt: prompt, n: this.n, - size: this.size - }) + size: size, + }), }; const response = await fetch(apiUrl, requestOptions); const json = await response.json(); @@ -50,13 +69,18 @@ export class DallEAPIWrapper extends Tool { console.error("[DALL-E]", e); } if (!image_url) return "No image was generated"; - let filePath = await this.saveImageFromUrl(image_url); - console.log(filePath); - return filePath; + try { + let filePath = await this.saveImageFromUrl(image_url); + console.log("[DALL-E]", filePath); + var imageMarkdown = `![img](${filePath})`; + if (this.callback != null) await this.callback(imageMarkdown); + return imageMarkdown; + } catch (e) { + if (this.callback != null) + await this.callback("Image upload to R2 storage failed"); + return "Image upload to R2 storage failed"; + } } - description = `openai's dall-e image generator. - input must be a english prompt. - output will be the image link url. - use markdown to display images. like: ![img](/api/file/xxx.png)`; + description = `openai's dall-e 3 image generator.`; } diff --git a/app/api/langchain/tool/agent/route.ts b/app/api/langchain/tool/agent/route.ts index 10d03ffa1..d1232bca5 100644 --- a/app/api/langchain/tool/agent/route.ts +++ b/app/api/langchain/tool/agent/route.ts @@ -107,7 +107,7 @@ async function handle(req: NextRequest) { } }, async handleChainError(err, runId, parentRunId, tags) { - console.log(err, "writer error"); + console.log("[handleChainError]", err, "writer error"); var response = new ResponseBody(); response.isSuccess = false; response.message = err; @@ -118,6 +118,7 @@ async function handle(req: NextRequest) { await writer.close(); }, async handleChainEnd(outputs, runId, parentRunId, tags) { + console.log("[handleChainEnd]"); await writer.ready; await writer.close(); }, @@ -126,7 +127,7 @@ async function handle(req: NextRequest) { // await writer.close(); }, async handleLLMError(e: Error) { - console.log(e, "writer error"); + console.log("[handleLLMError]", e, "writer error"); var response = new ResponseBody(); response.isSuccess = false; response.message = e.message; @@ -144,11 +145,7 @@ async function handle(req: NextRequest) { }, async handleAgentAction(action) { try { - console.log( - "agent (llm)", - `tool: ${action.tool} toolInput: ${action.toolInput}`, - { action }, - ); + console.log("[handleAgentAction]", action.tool); if (!reqBody.returnIntermediateSteps) return; var response = new ResponseBody(); response.isToolMessage = true; @@ -171,7 +168,13 @@ async function handle(req: NextRequest) { } }, handleToolStart(tool, input) { - console.log("handleToolStart", { tool, input }); + console.log("[handleToolStart]", { tool }); + }, + async handleToolEnd(output, runId, parentRunId, tags) { + console.log("[handleToolEnd]", { output, runId, parentRunId, tags }); + }, + handleAgentEnd(action, runId, parentRunId, tags) { + console.log("[handleAgentEnd]"); }, }); @@ -228,7 +231,19 @@ async function handle(req: NextRequest) { ]; const webBrowserTool = new WebBrowser({ model, embeddings }); const calculatorTool = new Calculator(); - const dallEAPITool = new DallEAPIWrapper(apiKey, baseUrl); + const dallEAPITool = new DallEAPIWrapper( + apiKey, + baseUrl, + async (data: string) => { + var response = new ResponseBody(); + response.message = data; + await writer.ready; + await writer.write( + encoder.encode(`data: ${JSON.stringify(response)}\n\n`), + ); + }, + ); + dallEAPITool.returnDirect = true; const stableDiffusionTool = new StableDiffusionWrapper(); const arxivAPITool = new ArxivAPIWrapper(); if (useTools.includes("web-search")) tools.push(searchTool);