feat: support dall-e 3

This commit is contained in:
Hk-Gosuto 2023-11-07 17:57:53 +08:00
parent d5566f8c56
commit 8272245f7b
2 changed files with 66 additions and 27 deletions

View File

@ -1,20 +1,27 @@
import { Tool } from "langchain/tools"; import { StructuredTool } from "langchain/tools";
import { z } from "zod";
import S3FileStorage from "../../utils/r2_file_storage"; import S3FileStorage from "../../utils/r2_file_storage";
export class DallEAPIWrapper extends Tool { export class DallEAPIWrapper extends StructuredTool {
name = "dalle_image_generator"; name = "dalle_image_generator";
n = 1; n = 1;
size: "256x256" | "512x512" | "1024x1024" | null = "1024x1024";
apiKey: string; apiKey: string;
baseURL?: string; baseURL?: string;
constructor(apiKey?: string | undefined, baseURL?: string | undefined) { callback?: (data: string) => Promise<void>;
constructor(
apiKey?: string | undefined,
baseURL?: string | undefined,
callback?: (data: string) => Promise<void>,
) {
super(); super();
if (!apiKey) { if (!apiKey) {
throw new Error("OpenAI API key not set."); throw new Error("OpenAI API key not set.");
} }
this.apiKey = apiKey; this.apiKey = apiKey;
this.baseURL = baseURL; this.baseURL = baseURL;
this.callback = callback;
} }
async saveImageFromUrl(url: string) { async saveImageFromUrl(url: string) {
@ -24,23 +31,35 @@ export class DallEAPIWrapper extends Tool {
return await S3FileStorage.put(`${Date.now()}.png`, buffer); return await S3FileStorage.put(`${Date.now()}.png`, buffer);
} }
schema = z.object({
prompt: z
.string()
.describe(
'input must be a english prompt. you can set `quality: "hd"` for enhanced detail.',
),
size: z
.enum(["1024x1024", "1024x1792", "1792x1024"])
.default("1024x1024")
.describe("images size"),
});
/** @ignore */ /** @ignore */
async _call(prompt: string) { async _call({ prompt, size }: z.infer<typeof this.schema>) {
let image_url; let image_url;
const apiUrl = `${this.baseURL}/images/generations`; const apiUrl = `${this.baseURL}/images/generations`;
try { try {
const requestOptions = { const requestOptions = {
method: 'POST', method: "POST",
headers: { headers: {
'Content-Type': 'application/json', "Content-Type": "application/json",
'Authorization': `Bearer ${this.apiKey}` Authorization: `Bearer ${this.apiKey}`,
}, },
body: JSON.stringify({ body: JSON.stringify({
model: "dall-e-3",
prompt: prompt, prompt: prompt,
n: this.n, n: this.n,
size: this.size size: size,
}) }),
}; };
const response = await fetch(apiUrl, requestOptions); const response = await fetch(apiUrl, requestOptions);
const json = await response.json(); const json = await response.json();
@ -50,13 +69,18 @@ export class DallEAPIWrapper extends Tool {
console.error("[DALL-E]", e); console.error("[DALL-E]", e);
} }
if (!image_url) return "No image was generated"; if (!image_url) return "No image was generated";
try {
let filePath = await this.saveImageFromUrl(image_url); let filePath = await this.saveImageFromUrl(image_url);
console.log(filePath); console.log("[DALL-E]", filePath);
return filePath; var imageMarkdown = `![img](${filePath})`;
if (this.callback != null) await this.callback(imageMarkdown);
return imageMarkdown;
} catch (e) {
if (this.callback != null)
await this.callback("Image upload to R2 storage failed");
return "Image upload to R2 storage failed";
}
} }
description = `openai's dall-e image generator. description = `openai's dall-e 3 image generator.`;
input must be a english prompt.
output will be the image link url.
use markdown to display images. like: ![img](/api/file/xxx.png)`;
} }

View File

@ -107,7 +107,7 @@ async function handle(req: NextRequest) {
} }
}, },
async handleChainError(err, runId, parentRunId, tags) { async handleChainError(err, runId, parentRunId, tags) {
console.log(err, "writer error"); console.log("[handleChainError]", err, "writer error");
var response = new ResponseBody(); var response = new ResponseBody();
response.isSuccess = false; response.isSuccess = false;
response.message = err; response.message = err;
@ -118,6 +118,7 @@ async function handle(req: NextRequest) {
await writer.close(); await writer.close();
}, },
async handleChainEnd(outputs, runId, parentRunId, tags) { async handleChainEnd(outputs, runId, parentRunId, tags) {
console.log("[handleChainEnd]");
await writer.ready; await writer.ready;
await writer.close(); await writer.close();
}, },
@ -126,7 +127,7 @@ async function handle(req: NextRequest) {
// await writer.close(); // await writer.close();
}, },
async handleLLMError(e: Error) { async handleLLMError(e: Error) {
console.log(e, "writer error"); console.log("[handleLLMError]", e, "writer error");
var response = new ResponseBody(); var response = new ResponseBody();
response.isSuccess = false; response.isSuccess = false;
response.message = e.message; response.message = e.message;
@ -144,11 +145,7 @@ async function handle(req: NextRequest) {
}, },
async handleAgentAction(action) { async handleAgentAction(action) {
try { try {
console.log( console.log("[handleAgentAction]", action.tool);
"agent (llm)",
`tool: ${action.tool} toolInput: ${action.toolInput}`,
{ action },
);
if (!reqBody.returnIntermediateSteps) return; if (!reqBody.returnIntermediateSteps) return;
var response = new ResponseBody(); var response = new ResponseBody();
response.isToolMessage = true; response.isToolMessage = true;
@ -171,7 +168,13 @@ async function handle(req: NextRequest) {
} }
}, },
handleToolStart(tool, input) { handleToolStart(tool, input) {
console.log("handleToolStart", { tool, input }); console.log("[handleToolStart]", { tool });
},
async handleToolEnd(output, runId, parentRunId, tags) {
console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
},
handleAgentEnd(action, runId, parentRunId, tags) {
console.log("[handleAgentEnd]");
}, },
}); });
@ -228,7 +231,19 @@ async function handle(req: NextRequest) {
]; ];
const webBrowserTool = new WebBrowser({ model, embeddings }); const webBrowserTool = new WebBrowser({ model, embeddings });
const calculatorTool = new Calculator(); const calculatorTool = new Calculator();
const dallEAPITool = new DallEAPIWrapper(apiKey, baseUrl); const dallEAPITool = new DallEAPIWrapper(
apiKey,
baseUrl,
async (data: string) => {
var response = new ResponseBody();
response.message = data;
await writer.ready;
await writer.write(
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
);
},
);
dallEAPITool.returnDirect = true;
const stableDiffusionTool = new StableDiffusionWrapper(); const stableDiffusionTool = new StableDiffusionWrapper();
const arxivAPITool = new ArxivAPIWrapper(); const arxivAPITool = new ArxivAPIWrapper();
if (useTools.includes("web-search")) tools.push(searchTool); if (useTools.includes("web-search")) tools.push(searchTool);