Merge pull request #20 from Hk-Gosuto/feat/dalle-plugin

Feature: Add OpenAI DALL-E Plugin
This commit is contained in:
Hk-Gosuto 2023-09-18 12:29:49 +08:00 committed by GitHub
commit 4db31d4136
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1312 additions and 2 deletions

4
.gitignore vendored
View File

@ -43,4 +43,6 @@ dev
.env
*.key
*.key.pub
*.key.pub
/public/uploads

View File

@ -46,6 +46,8 @@
- 其它
- [Wiki](https://js.langchain.com/docs/api/tools/classes/WikipediaQueryRun)
- DALL-E
- DALL-E 插件需要配置 R2 存储,请参考 [Cloudflare R2 服务配置指南](./docs/cloudflare-r2-cn.md) 配置
@ -161,6 +163,21 @@ OpenAI 接口代理 URL如果你手动配置了 openai 接口代理,请填
如果你不想让用户查询余额,将此环境变量设置为 1 即可。
### `R2_ACCOUNT_ID` (可选)
Cloudflare R2 帐户 ID使用 `DALL-E` 插件时需要配置。
### `R2_ACCESS_KEY_ID` (可选)
Cloudflare R2 访问密钥 ID使用 `DALL-E` 插件时需要配置。
### `R2_SECRET_ACCESS_KEY` (可选)
Cloudflare R2 机密访问密钥,使用 `DALL-E` 插件时需要配置。
### `R2_BUCKET` (可选)
Cloudflare R2 Bucket 名称,使用 `DALL-E` 插件时需要配置。
## 部署
### 容器部署 (推荐)

View File

@ -0,0 +1,36 @@
import { NextRequest, NextResponse } from "next/server";
import { auth } from "../../auth";
import S3FileStorage from "../../../utils/r2_file_storage";
async function handle(
req: NextRequest,
{ params }: { params: { path: string[] } },
) {
if (req.method === "OPTIONS") {
return NextResponse.json({ body: "OK" }, { status: 200 });
}
const authResult = auth(req);
if (authResult.error) {
return NextResponse.json(authResult, {
status: 401,
});
}
try {
var file = await S3FileStorage.get(params.path[0]);
return new Response(file?.transformToWebStream(), {
headers: {
"Content-Type": "image/png",
},
});
} catch (e) {
return new Response("not found", {
status: 404,
});
}
}
export const GET = handle;
export const runtime = "edge";

View File

@ -0,0 +1,52 @@
import { Tool } from "langchain/tools";
import OpenAI from "openai";
import S3FileStorage from "../../utils/r2_file_storage";
export class DallEAPIWrapper extends Tool {
name = "dalle_image_generator";
n = 1;
size: "256x256" | "512x512" | "1024x1024" | null = "1024x1024";
apiKey: string;
baseURL?: string;
constructor(apiKey?: string | undefined, baseURL?: string | undefined) {
super();
if (!apiKey) {
throw new Error("OpenAI API key not set.");
}
this.apiKey = apiKey;
this.baseURL = baseURL;
}
async saveImageFromUrl(url: string) {
const response = await fetch(url);
const content = await response.arrayBuffer();
const buffer = Buffer.from(content);
return await S3FileStorage.put(`${Date.now()}.png`, buffer);
}
/** @ignore */
async _call(prompt: string) {
const openai = new OpenAI({
apiKey: this.apiKey,
baseURL: this.baseURL,
});
const response = await openai.images.generate({
prompt: prompt,
n: this.n,
size: this.size,
});
let image_url = response.data[0].url;
console.log(image_url);
if (!image_url) return "No image was generated";
let filePath = await this.saveImageFromUrl(image_url);
console.log(filePath);
return filePath;
}
description = `openai's dall-e image generator.
input must be a english prompt.
output will be the image link url.
use markdown to display images. like: ![img](/api/file/xxx.png)`;
}

View File

@ -18,6 +18,7 @@ import { DuckDuckGo } from "@/app/api/langchain-tools/duckduckgo_search";
import { WebBrowser } from "langchain/tools/webbrowser";
import { Calculator } from "langchain/tools/calculator";
import { DynamicTool, Tool } from "langchain/tools";
import { DallEAPIWrapper } from "@/app/api/langchain-tools/dalle_image_generator";
const serverConfig = getServerSideConfig();
@ -214,9 +215,11 @@ async function handle(req: NextRequest) {
];
const webBrowserTool = new WebBrowser({ model, embeddings });
const calculatorTool = new Calculator();
const dallEAPITool = new DallEAPIWrapper(apiKey, baseUrl);
if (useTools.includes("web-search")) tools.push(searchTool);
if (useTools.includes(webBrowserTool.name)) tools.push(webBrowserTool);
if (useTools.includes(calculatorTool.name)) tools.push(calculatorTool);
if (useTools.includes(dallEAPITool.name)) tools.push(dallEAPITool);
useTools.forEach((toolName) => {
if (toolName) {

View File

@ -38,4 +38,14 @@ export const CN_PLUGINS: BuiltinPlugin[] = [
createdAt: 1694235989000,
enable: false,
},
{
name: "DALL·E",
toolName: "dalle_image_generator",
lang: "cn",
description:
"DALL·E 可以根据自然语言的描述创建逼真的图像和艺术。使用本插件需要配置 Cloudflare R2 对象存储服务。",
builtin: true,
createdAt: 1694703673000,
enable: false,
},
];

View File

@ -40,4 +40,14 @@ export const EN_PLUGINS: BuiltinPlugin[] = [
createdAt: 1694235989000,
enable: false,
},
{
name: "DALL·E",
toolName: "dalle_image_generator",
lang: "en",
description:
"DALL·E 2 is an AI system that can create realistic images and art from a description in natural language. Using this plugin requires configuring Cloudflare R2 object storage service.",
builtin: true,
createdAt: 1694703673000,
enable: false,
},
];

View File

@ -0,0 +1,61 @@
import {
S3Client,
ListBucketsCommand,
ListObjectsV2Command,
GetObjectCommand,
PutObjectCommand,
} from "@aws-sdk/client-s3";
import { getSignedUrl } from "@aws-sdk/s3-request-presigner";
const R2_ACCOUNT_ID = process.env.R2_ACCOUNT_ID;
const R2_ACCESS_KEY_ID = process.env.R2_ACCESS_KEY_ID;
const R2_SECRET_ACCESS_KEY = process.env.R2_SECRET_ACCESS_KEY;
const R2_BUCKET = process.env.R2_BUCKET;
const getR2Client = () => {
return new S3Client({
region: "auto",
endpoint: `https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com`,
credentials: {
accessKeyId: R2_ACCESS_KEY_ID!,
secretAccessKey: R2_SECRET_ACCESS_KEY!,
},
});
};
export default class S3FileStorage {
static async get(fileName: string) {
const file = await getR2Client().send(
new GetObjectCommand({
Bucket: R2_BUCKET,
Key: fileName,
}),
);
if (!file) {
throw new Error("not found.");
}
return file.Body;
}
static async put(fileName: string, data: Buffer) {
const signedUrl = await getSignedUrl(
getR2Client(),
new PutObjectCommand({
Bucket: R2_BUCKET,
Key: fileName,
}),
{ expiresIn: 60 },
);
console.log(signedUrl);
await fetch(signedUrl, {
method: "PUT",
body: data,
});
return `/api/file/${fileName}`;
}
}

46
docs/cloudflare-r2-cn.md Normal file
View File

@ -0,0 +1,46 @@
# Cloudflare R2 服务配置指南
## 如何配置 R2 服务
登录到 dash.cloudflare.com 并在左侧菜单进入 R2。
1. 复制右侧 "账户ID"
2. 点击 "创建存储桶"。
3. 自定义配置一个存储桶名称,记录下来用于后面配置环境变量,点击 "创建存储桶"。
4. 进入 "设置",点击 "添加 CORS 策略",将下面内容粘贴上去并点击 "保存"。
```json
[
{
"AllowedOrigins": [
"*"
],
"AllowedMethods": [
"PUT",
"DELETE",
"GET"
],
"AllowedHeaders": [
"content-type"
],
"MaxAgeSeconds": 3000
}
]
```
5. 回到 R2 主菜单,点击右侧 "管理 R2 API 令牌"。
6. 点击 "创建 API 令牌",权限选择为 "管理员读和写"TTL 选择为 "永久",点击 "创建 API 令牌"。
7. 复制 "访问密钥 ID" 和 "机密访问密钥",点击 "完成"。
8. 回到 ChatGPT-Next-Web-LangChain 项目修改环境变量。按照以下信息填写:
- `R2_ACCOUNT_ID=账户ID`
- `R2_ACCESS_KEY_ID=访问密钥 ID`
- `R2_SECRET_ACCESS_KEY=机密访问密钥`
- `R2_BUCKET=存储桶名称`
9. Enjoy.

View File

@ -16,6 +16,8 @@
"proxy-dev": "sh ./scripts/init-proxy.sh && proxychains -f ./scripts/proxychains.conf yarn dev"
},
"dependencies": {
"@aws-sdk/client-s3": "^3.414.0",
"@aws-sdk/s3-request-presigner": "^3.414.0",
"@fortaine/fetch-event-source": "^3.0.6",
"@hello-pangea/dnd": "^16.3.0",
"@svgr/webpack": "^6.5.1",
@ -24,6 +26,7 @@
"cheerio": "^1.0.0-rc.12",
"duck-duck-scrape": "^2.2.4",
"emoji-picker-react": "^4.5.1",
"encoding": "^0.1.13",
"fuse.js": "^6.6.2",
"html-entities": "^2.4.0",
"html-to-image": "^1.11.11",
@ -34,6 +37,7 @@
"nanoid": "^4.0.2",
"next": "^13.4.9",
"node-fetch": "^3.3.1",
"openai": "^4.6.0",
"react": "^18.2.0",
"react-dom": "^18.2.0",
"react-markdown": "^8.0.7",

1071
yarn.lock

File diff suppressed because it is too large Load Diff