mirror of
https://github.com/ChatGPTNextWeb/ChatGPT-Next-Web.git
synced 2025-06-07 21:30:19 +09:00
feat: support dall-e 3
This commit is contained in:
parent
d5566f8c56
commit
8272245f7b
@ -1,20 +1,27 @@
|
|||||||
import { Tool } from "langchain/tools";
|
import { StructuredTool } from "langchain/tools";
|
||||||
|
import { z } from "zod";
|
||||||
import S3FileStorage from "../../utils/r2_file_storage";
|
import S3FileStorage from "../../utils/r2_file_storage";
|
||||||
|
|
||||||
export class DallEAPIWrapper extends Tool {
|
export class DallEAPIWrapper extends StructuredTool {
|
||||||
name = "dalle_image_generator";
|
name = "dalle_image_generator";
|
||||||
n = 1;
|
n = 1;
|
||||||
size: "256x256" | "512x512" | "1024x1024" | null = "1024x1024";
|
|
||||||
apiKey: string;
|
apiKey: string;
|
||||||
baseURL?: string;
|
baseURL?: string;
|
||||||
|
|
||||||
constructor(apiKey?: string | undefined, baseURL?: string | undefined) {
|
callback?: (data: string) => Promise<void>;
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
apiKey?: string | undefined,
|
||||||
|
baseURL?: string | undefined,
|
||||||
|
callback?: (data: string) => Promise<void>,
|
||||||
|
) {
|
||||||
super();
|
super();
|
||||||
if (!apiKey) {
|
if (!apiKey) {
|
||||||
throw new Error("OpenAI API key not set.");
|
throw new Error("OpenAI API key not set.");
|
||||||
}
|
}
|
||||||
this.apiKey = apiKey;
|
this.apiKey = apiKey;
|
||||||
this.baseURL = baseURL;
|
this.baseURL = baseURL;
|
||||||
|
this.callback = callback;
|
||||||
}
|
}
|
||||||
|
|
||||||
async saveImageFromUrl(url: string) {
|
async saveImageFromUrl(url: string) {
|
||||||
@ -24,23 +31,35 @@ export class DallEAPIWrapper extends Tool {
|
|||||||
return await S3FileStorage.put(`${Date.now()}.png`, buffer);
|
return await S3FileStorage.put(`${Date.now()}.png`, buffer);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
schema = z.object({
|
||||||
|
prompt: z
|
||||||
|
.string()
|
||||||
|
.describe(
|
||||||
|
'input must be a english prompt. you can set `quality: "hd"` for enhanced detail.',
|
||||||
|
),
|
||||||
|
size: z
|
||||||
|
.enum(["1024x1024", "1024x1792", "1792x1024"])
|
||||||
|
.default("1024x1024")
|
||||||
|
.describe("images size"),
|
||||||
|
});
|
||||||
|
|
||||||
/** @ignore */
|
/** @ignore */
|
||||||
async _call(prompt: string) {
|
async _call({ prompt, size }: z.infer<typeof this.schema>) {
|
||||||
let image_url;
|
let image_url;
|
||||||
const apiUrl = `${this.baseURL}/images/generations`;
|
const apiUrl = `${this.baseURL}/images/generations`;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const requestOptions = {
|
const requestOptions = {
|
||||||
method: 'POST',
|
method: "POST",
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json',
|
"Content-Type": "application/json",
|
||||||
'Authorization': `Bearer ${this.apiKey}`
|
Authorization: `Bearer ${this.apiKey}`,
|
||||||
},
|
},
|
||||||
body: JSON.stringify({
|
body: JSON.stringify({
|
||||||
|
model: "dall-e-3",
|
||||||
prompt: prompt,
|
prompt: prompt,
|
||||||
n: this.n,
|
n: this.n,
|
||||||
size: this.size
|
size: size,
|
||||||
})
|
}),
|
||||||
};
|
};
|
||||||
const response = await fetch(apiUrl, requestOptions);
|
const response = await fetch(apiUrl, requestOptions);
|
||||||
const json = await response.json();
|
const json = await response.json();
|
||||||
@ -50,13 +69,18 @@ export class DallEAPIWrapper extends Tool {
|
|||||||
console.error("[DALL-E]", e);
|
console.error("[DALL-E]", e);
|
||||||
}
|
}
|
||||||
if (!image_url) return "No image was generated";
|
if (!image_url) return "No image was generated";
|
||||||
|
try {
|
||||||
let filePath = await this.saveImageFromUrl(image_url);
|
let filePath = await this.saveImageFromUrl(image_url);
|
||||||
console.log(filePath);
|
console.log("[DALL-E]", filePath);
|
||||||
return filePath;
|
var imageMarkdown = ``;
|
||||||
|
if (this.callback != null) await this.callback(imageMarkdown);
|
||||||
|
return imageMarkdown;
|
||||||
|
} catch (e) {
|
||||||
|
if (this.callback != null)
|
||||||
|
await this.callback("Image upload to R2 storage failed");
|
||||||
|
return "Image upload to R2 storage failed";
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
description = `openai's dall-e image generator.
|
description = `openai's dall-e 3 image generator.`;
|
||||||
input must be a english prompt.
|
|
||||||
output will be the image link url.
|
|
||||||
use markdown to display images. like: `;
|
|
||||||
}
|
}
|
||||||
|
@ -107,7 +107,7 @@ async function handle(req: NextRequest) {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
async handleChainError(err, runId, parentRunId, tags) {
|
async handleChainError(err, runId, parentRunId, tags) {
|
||||||
console.log(err, "writer error");
|
console.log("[handleChainError]", err, "writer error");
|
||||||
var response = new ResponseBody();
|
var response = new ResponseBody();
|
||||||
response.isSuccess = false;
|
response.isSuccess = false;
|
||||||
response.message = err;
|
response.message = err;
|
||||||
@ -118,6 +118,7 @@ async function handle(req: NextRequest) {
|
|||||||
await writer.close();
|
await writer.close();
|
||||||
},
|
},
|
||||||
async handleChainEnd(outputs, runId, parentRunId, tags) {
|
async handleChainEnd(outputs, runId, parentRunId, tags) {
|
||||||
|
console.log("[handleChainEnd]");
|
||||||
await writer.ready;
|
await writer.ready;
|
||||||
await writer.close();
|
await writer.close();
|
||||||
},
|
},
|
||||||
@ -126,7 +127,7 @@ async function handle(req: NextRequest) {
|
|||||||
// await writer.close();
|
// await writer.close();
|
||||||
},
|
},
|
||||||
async handleLLMError(e: Error) {
|
async handleLLMError(e: Error) {
|
||||||
console.log(e, "writer error");
|
console.log("[handleLLMError]", e, "writer error");
|
||||||
var response = new ResponseBody();
|
var response = new ResponseBody();
|
||||||
response.isSuccess = false;
|
response.isSuccess = false;
|
||||||
response.message = e.message;
|
response.message = e.message;
|
||||||
@ -144,11 +145,7 @@ async function handle(req: NextRequest) {
|
|||||||
},
|
},
|
||||||
async handleAgentAction(action) {
|
async handleAgentAction(action) {
|
||||||
try {
|
try {
|
||||||
console.log(
|
console.log("[handleAgentAction]", action.tool);
|
||||||
"agent (llm)",
|
|
||||||
`tool: ${action.tool} toolInput: ${action.toolInput}`,
|
|
||||||
{ action },
|
|
||||||
);
|
|
||||||
if (!reqBody.returnIntermediateSteps) return;
|
if (!reqBody.returnIntermediateSteps) return;
|
||||||
var response = new ResponseBody();
|
var response = new ResponseBody();
|
||||||
response.isToolMessage = true;
|
response.isToolMessage = true;
|
||||||
@ -171,7 +168,13 @@ async function handle(req: NextRequest) {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
handleToolStart(tool, input) {
|
handleToolStart(tool, input) {
|
||||||
console.log("handleToolStart", { tool, input });
|
console.log("[handleToolStart]", { tool });
|
||||||
|
},
|
||||||
|
async handleToolEnd(output, runId, parentRunId, tags) {
|
||||||
|
console.log("[handleToolEnd]", { output, runId, parentRunId, tags });
|
||||||
|
},
|
||||||
|
handleAgentEnd(action, runId, parentRunId, tags) {
|
||||||
|
console.log("[handleAgentEnd]");
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -228,7 +231,19 @@ async function handle(req: NextRequest) {
|
|||||||
];
|
];
|
||||||
const webBrowserTool = new WebBrowser({ model, embeddings });
|
const webBrowserTool = new WebBrowser({ model, embeddings });
|
||||||
const calculatorTool = new Calculator();
|
const calculatorTool = new Calculator();
|
||||||
const dallEAPITool = new DallEAPIWrapper(apiKey, baseUrl);
|
const dallEAPITool = new DallEAPIWrapper(
|
||||||
|
apiKey,
|
||||||
|
baseUrl,
|
||||||
|
async (data: string) => {
|
||||||
|
var response = new ResponseBody();
|
||||||
|
response.message = data;
|
||||||
|
await writer.ready;
|
||||||
|
await writer.write(
|
||||||
|
encoder.encode(`data: ${JSON.stringify(response)}\n\n`),
|
||||||
|
);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
dallEAPITool.returnDirect = true;
|
||||||
const stableDiffusionTool = new StableDiffusionWrapper();
|
const stableDiffusionTool = new StableDiffusionWrapper();
|
||||||
const arxivAPITool = new ArxivAPIWrapper();
|
const arxivAPITool = new ArxivAPIWrapper();
|
||||||
if (useTools.includes("web-search")) tools.push(searchTool);
|
if (useTools.includes("web-search")) tools.push(searchTool);
|
||||||
|
Loading…
Reference in New Issue
Block a user