diff --git a/.env.template b/.env.template index ff6d3cdbf..7be2f86f3 100644 --- a/.env.template +++ b/.env.template @@ -50,4 +50,22 @@ DISABLE_FAST_LINK= # (optional) # Default: 1 # If your project is not deployed on Vercel, set this value to 1. -NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN=1 \ No newline at end of file +NEXT_PUBLIC_ENABLE_NODEJS_PLUGIN=1 + +# (optional) +# Default: Empty +# If you want to enable RAG, set this value to 1. +ENABLE_RAG= + +# (optional) +# Default: Empty +# Model used when RAG vectorized data. +RAG_EMBEDDING_MODEL=text-embedding-ada-002 + +# Configuration is required when turning on RAG. +# Default: Empty +QDRANT_URL= + +# Configuration is required when turning on RAG. +# Default: Empty +QDRANT_API_KEY= \ No newline at end of file diff --git a/README.md b/README.md index 60ddc1374..7e82c2404 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ > [!WARNING] > 本项目插件功能基于 [OpenAI API 函数调用](https://platform.openai.com/docs/guides/function-calling) 功能实现,转发 GitHub Copilot 接口或类似实现的模拟接口并不能正常调用插件功能! 
-![cover](./docs/images/gpt-vision-example.jpg) +![cover](./docs/images/rag-example.jpg) ![plugin-example](./docs/images/plugin-example.png) @@ -35,6 +35,9 @@ ## 主要功能 +- RAG 功能 (预览) + - 配置请参考文档[RAG 功能配置说明](./docs/rag-cn.md) + - 除插件工具外,与原项目保持一致 [ChatGPT-Next-Web 主要功能](https://github.com/Yidadaa/ChatGPT-Next-Web#主要功能) - 支持 OpenAI TTS(文本转语音)https://github.com/Hk-Gosuto/ChatGPT-Next-Web-LangChain/issues/208 @@ -142,7 +145,7 @@ - [x] 支持语音输入 https://github.com/Hk-Gosuto/ChatGPT-Next-Web-LangChain/issues/208 -- [ ] 支持其他类型文件上传 https://github.com/Hk-Gosuto/ChatGPT-Next-Web-LangChain/issues/77 +- [x] 支持其他类型文件上传 https://github.com/Hk-Gosuto/ChatGPT-Next-Web-LangChain/issues/77 - [ ] 支持 Azure Storage https://github.com/Hk-Gosuto/ChatGPT-Next-Web-LangChain/issues/217 @@ -295,11 +298,9 @@ docker run -d -p 3000:3000 \ | [简体中文](./docs/synchronise-chat-logs-cn.md) | [English](./docs/synchronise-chat-logs-en.md) | [Italiano](./docs/synchronise-chat-logs-es.md) | [日本語](./docs/synchronise-chat-logs-ja.md) | [한국어](./docs/synchronise-chat-logs-ko.md) -## 贡献者 +## Star History - - - +[![Star History Chart](https://api.star-history.com/svg?repos=Hk-Gosuto/ChatGPT-Next-Web-LangChain&type=Date)](https://star-history.com/#Hk-Gosuto/ChatGPT-Next-Web-LangChain&Date) ## 捐赠 diff --git a/app/api/config/route.ts b/app/api/config/route.ts index db84fba17..aed95a1b4 100644 --- a/app/api/config/route.ts +++ b/app/api/config/route.ts @@ -13,6 +13,7 @@ const DANGER_CONFIG = { hideBalanceQuery: serverConfig.hideBalanceQuery, disableFastLink: serverConfig.disableFastLink, customModels: serverConfig.customModels, + isEnableRAG: serverConfig.isEnableRAG, }; declare global { diff --git a/app/api/file/[...path]/route.ts b/app/api/file/[...path]/route.ts index f6fbfd52f..e253389a0 100644 --- a/app/api/file/[...path]/route.ts +++ b/app/api/file/[...path]/route.ts @@ -2,6 +2,7 @@ import { getServerSideConfig } from "@/app/config/server"; import LocalFileStorage from "@/app/utils/local_file_storage"; import 
S3FileStorage from "@/app/utils/s3_file_storage"; import { NextRequest, NextResponse } from "next/server"; +import mime from "mime"; async function handle( req: NextRequest, @@ -13,19 +14,27 @@ async function handle( try { const serverConfig = getServerSideConfig(); + const fileName = params.path[0]; + const contentType = mime.getType(fileName); + if (serverConfig.isStoreFileToLocal) { - var fileBuffer = await LocalFileStorage.get(params.path[0]); + var fileBuffer = await LocalFileStorage.get(fileName); return new Response(fileBuffer, { headers: { - "Content-Type": "image/png", + "Content-Type": contentType ?? "application/octet-stream", }, }); } else { - var file = await S3FileStorage.get(params.path[0]); - return new Response(file?.transformToWebStream(), { - headers: { - "Content-Type": "image/png", - }, + var file = await S3FileStorage.get(fileName); + if (file) { + return new Response(file?.transformToWebStream(), { + headers: { + "Content-Type": contentType ?? "application/octet-stream", + }, + }); + } + return new Response("not found", { + status: 404, }); } } catch (e) { diff --git a/app/api/file/upload/route.ts b/app/api/file/upload/route.ts index 65991477a..c372065e7 100644 --- a/app/api/file/upload/route.ts +++ b/app/api/file/upload/route.ts @@ -4,6 +4,7 @@ import { auth } from "@/app/api/auth"; import LocalFileStorage from "@/app/utils/local_file_storage"; import { getServerSideConfig } from "@/app/config/server"; import S3FileStorage from "@/app/utils/s3_file_storage"; +import path from "path"; async function handle(req: NextRequest) { if (req.method === "OPTIONS") { @@ -19,20 +20,14 @@ async function handle(req: NextRequest) { try { const formData = await req.formData(); - const image = formData.get("file") as File; + const file = formData.get("file") as File; + const fileData = await file.arrayBuffer(); + const originalFileName = file?.name; - const imageReader = image.stream().getReader(); - const imageData: number[] = []; - - while (true) { - const 
{ done, value } = await imageReader.read(); - if (done) break; - imageData.push(...value); - } - - const buffer = Buffer.from(imageData); - - var fileName = `${Date.now()}.png`; + if (!fileData) throw new Error("Get file buffer error"); + const buffer = Buffer.from(fileData); + const fileType = path.extname(originalFileName).slice(1); + var fileName = `${Date.now()}.${fileType}`; var filePath = ""; const serverConfig = getServerSideConfig(); if (serverConfig.isStoreFileToLocal) { diff --git a/app/api/langchain-tools/nodejs_tools.ts b/app/api/langchain-tools/nodejs_tools.ts index 6b8d9c084..f4033df53 100644 --- a/app/api/langchain-tools/nodejs_tools.ts +++ b/app/api/langchain-tools/nodejs_tools.ts @@ -10,16 +10,15 @@ import { WolframAlphaTool } from "@/app/api/langchain-tools/wolframalpha"; import { BilibiliVideoInfoTool } from "./bilibili_vid_info"; import { BilibiliVideoSearchTool } from "./bilibili_vid_search"; import { BilibiliMusicRecognitionTool } from "./bilibili_music_recognition"; +import { RAGSearch } from "./rag_search"; export class NodeJSTool { private apiKey: string | undefined; - private baseUrl: string; - private model: BaseLanguageModel; - private embeddings: Embeddings; - + private sessionId: string; + private ragEmbeddings: Embeddings; private callback?: (data: string) => Promise; constructor( @@ -27,12 +26,16 @@ export class NodeJSTool { baseUrl: string, model: BaseLanguageModel, embeddings: Embeddings, + sessionId: string, + ragEmbeddings: Embeddings, callback?: (data: string) => Promise, ) { this.apiKey = apiKey; this.baseUrl = baseUrl; this.model = model; this.embeddings = embeddings; + this.sessionId = sessionId; + this.ragEmbeddings = ragEmbeddings; this.callback = callback; } @@ -66,6 +69,9 @@ export class NodeJSTool { bilibiliVideoSearchTool, bilibiliMusicRecognitionTool, ]; + if (!!process.env.ENABLE_RAG) { + tools.push(new RAGSearch(this.sessionId, this.model, this.ragEmbeddings)); + } return tools; } } diff --git 
a/app/api/langchain-tools/rag_search.ts b/app/api/langchain-tools/rag_search.ts new file mode 100644 index 000000000..c3db3c4c3 --- /dev/null +++ b/app/api/langchain-tools/rag_search.ts @@ -0,0 +1,79 @@ +import { Tool } from "@langchain/core/tools"; +import { CallbackManagerForToolRun } from "@langchain/core/callbacks/manager"; +import { BaseLanguageModel } from "langchain/dist/base_language"; +import { formatDocumentsAsString } from "langchain/util/document"; +import { Embeddings } from "langchain/dist/embeddings/base.js"; +import { RunnableSequence } from "@langchain/core/runnables"; +import { StringOutputParser } from "@langchain/core/output_parsers"; +import { Pinecone } from "@pinecone-database/pinecone"; +import { PineconeStore } from "@langchain/pinecone"; +import { getServerSideConfig } from "@/app/config/server"; +import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant"; + +export class RAGSearch extends Tool { + static lc_name() { + return "RAGSearch"; + } + + get lc_namespace() { + return [...super.lc_namespace, "ragsearch"]; + } + + private sessionId: string; + private model: BaseLanguageModel; + private embeddings: Embeddings; + + constructor( + sessionId: string, + model: BaseLanguageModel, + embeddings: Embeddings, + ) { + super(); + this.sessionId = sessionId; + this.model = model; + this.embeddings = embeddings; + } + + /** @ignore */ + async _call(inputs: string, runManager?: CallbackManagerForToolRun) { + const serverConfig = getServerSideConfig(); + if (!serverConfig.isEnableRAG) + throw new Error("env ENABLE_RAG not configured"); + // const pinecone = new Pinecone(); + // const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!); + // const vectorStore = await PineconeStore.fromExistingIndex(this.embeddings, { + // pineconeIndex, + // }); + const vectorStore = await QdrantVectorStore.fromExistingCollection( + this.embeddings, + { + url: process.env.QDRANT_URL, + apiKey: process.env.QDRANT_API_KEY, + collectionName: 
this.sessionId,
+      },
+    );
+
+    let context;
+    const returnCount = serverConfig.ragReturnCount
+      ? parseInt(serverConfig.ragReturnCount, 10)
+      : 4;
+    console.log("[rag-search]", { inputs, returnCount });
+    // const results = await vectorStore.similaritySearch(inputs, returnCount, {
+    //   sessionId: this.sessionId,
+    // });
+    const results = await vectorStore.similaritySearch(inputs, returnCount);
+    context = formatDocumentsAsString(results);
+    console.log("[rag-search]", { context });
+    return context;
+    // const input = `Text:${context}\n\nQuestion:${inputs}\n\nI need you to answer the question based on the text.`;
+
+    // console.log("[rag-search]", input);
+
+    // const chain = RunnableSequence.from([this.model, new StringOutputParser()]);
+    // return chain.invoke(input, runManager?.getChild());
+  }
+
+  name = "rag-search";
+
+  description = `It is used to query documents entered by the user. The input content is the keywords extracted from the user's question, and multiple keywords are separated by spaces and passed in.`;
+}
diff --git a/app/api/langchain/rag/search/route.ts b/app/api/langchain/rag/search/route.ts
new file mode 100644
index 000000000..8c5aae6ea
--- /dev/null
+++ b/app/api/langchain/rag/search/route.ts
@@ -0,0 +1,120 @@
+import { NextRequest, NextResponse } from "next/server";
+import { auth } from "@/app/api/auth";
+import { ACCESS_CODE_PREFIX, ModelProvider } from "@/app/constant";
+import { OpenAIEmbeddings } from "@langchain/openai";
+import { Pinecone } from "@pinecone-database/pinecone";
+import { PineconeStore } from "@langchain/pinecone";
+import { QdrantVectorStore } from "@langchain/community/vectorstores/qdrant";
+import { getServerSideConfig } from "@/app/config/server";
+
+interface RequestBody {
+  sessionId: string;
+  query: string;
+  baseUrl?: string;
+}
+
+async function handle(req: NextRequest) {
+  if (req.method === "OPTIONS") {
+    return NextResponse.json({ body: "OK" }, { status: 200 });
+  }
+  try {
+    const authResult = auth(req, 
ModelProvider.GPT);
+    if (authResult.error) {
+      return NextResponse.json(authResult, {
+        status: 401,
+      });
+    }
+
+    const reqBody: RequestBody = await req.json();
+    const authToken = req.headers.get("Authorization") ?? "";
+    const token = authToken.trim().replaceAll("Bearer ", "").trim();
+    const serverConfig = getServerSideConfig();
+    // const pinecone = new Pinecone();
+    // const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!);
+    const apiKey = getOpenAIApiKey(token);
+    const baseUrl = getOpenAIBaseUrl(reqBody.baseUrl);
+    const embeddings = new OpenAIEmbeddings(
+      {
+        modelName: serverConfig.ragEmbeddingModel ?? "text-embedding-3-large",
+        openAIApiKey: apiKey,
+      },
+      { basePath: baseUrl },
+    );
+    // const vectorStore = await PineconeStore.fromExistingIndex(embeddings, {
+    //   pineconeIndex,
+    // });
+    // const results = await vectorStore.similaritySearch(reqBody.query, 4, {
+    //   sessionId: reqBody.sessionId,
+    // });
+    const vectorStore = await QdrantVectorStore.fromExistingCollection(
+      embeddings,
+      {
+        url: process.env.QDRANT_URL,
+        apiKey: process.env.QDRANT_API_KEY,
+        collectionName: reqBody.sessionId,
+      },
+    );
+    const returnCount = serverConfig.ragReturnCount
+      ? parseInt(serverConfig.ragReturnCount, 10)
+      : 4;
+    const response = await vectorStore.similaritySearch(
+      reqBody.query,
+      returnCount,
+    );
+    return NextResponse.json(response, {
+      status: 200,
+    });
+  } catch (e) {
+    console.error(e);
+    return new Response(JSON.stringify({ error: (e as any).message }), {
+      status: 500,
+      headers: { "Content-Type": "application/json" },
+    });
+  }
+}
+
+function getOpenAIApiKey(token: string) {
+  const serverConfig = getServerSideConfig();
+  const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX);
+
+  let apiKey = serverConfig.apiKey;
+  if (isApiKey && token) {
+    apiKey = token;
+  }
+  return apiKey;
+}
+
+function getOpenAIBaseUrl(reqBaseUrl: string | undefined) {
+  const serverConfig = getServerSideConfig();
+  let baseUrl = "https://api.openai.com/v1";
+  if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl;
+  if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://"))
+    baseUrl = reqBaseUrl;
+  if (!baseUrl.endsWith("/v1"))
+    baseUrl = baseUrl.endsWith("/") ? 
`${baseUrl}v1` : `${baseUrl}/v1`; + console.log("[baseUrl]", baseUrl); + return baseUrl; +} + +export const POST = handle; + +export const runtime = "nodejs"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; diff --git a/app/api/langchain/rag/store/route.ts b/app/api/langchain/rag/store/route.ts new file mode 100644 index 000000000..9ded033d9 --- /dev/null +++ b/app/api/langchain/rag/store/route.ts @@ -0,0 +1,221 @@ +import { NextRequest, NextResponse } from "next/server"; +import { auth } from "@/app/api/auth"; +import { ACCESS_CODE_PREFIX, ModelProvider } from "@/app/constant"; +import { OpenAI, OpenAIEmbeddings } from "@langchain/openai"; +import { PDFLoader } from "langchain/document_loaders/fs/pdf"; +import { TextLoader } from "langchain/document_loaders/fs/text"; +import { CSVLoader } from "langchain/document_loaders/fs/csv"; +import { DocxLoader } from "langchain/document_loaders/fs/docx"; +import { EPubLoader } from "langchain/document_loaders/fs/epub"; +import { JSONLoader } from "langchain/document_loaders/fs/json"; +import { JSONLinesLoader } from "langchain/document_loaders/fs/json"; +import { OpenAIWhisperAudio } from "langchain/document_loaders/fs/openai_whisper_audio"; +// import { PPTXLoader } from "langchain/document_loaders/fs/pptx"; +import { SRTLoader } from "langchain/document_loaders/fs/srt"; +import { RecursiveCharacterTextSplitter } from "langchain/text_splitter"; +import { Pinecone } from "@pinecone-database/pinecone"; +import { PineconeStore } from "@langchain/pinecone"; +import { getServerSideConfig } from "@/app/config/server"; +import { FileInfo } from "@/app/client/platforms/utils"; +import mime from "mime"; +import LocalFileStorage from "@/app/utils/local_file_storage"; +import S3FileStorage from "@/app/utils/s3_file_storage"; +import { QdrantVectorStore } from 
"@langchain/community/vectorstores/qdrant"; + +interface RequestBody { + sessionId: string; + fileInfos: FileInfo[]; + baseUrl?: string; +} + +function getLoader( + fileName: string, + fileBlob: Blob, + openaiApiKey: string, + openaiBaseUrl: string, +) { + const extension = fileName.split(".").pop(); + switch (extension) { + case "txt": + case "md": + return new TextLoader(fileBlob); + case "pdf": + return new PDFLoader(fileBlob); + case "docx": + return new DocxLoader(fileBlob); + case "csv": + return new CSVLoader(fileBlob); + case "json": + return new JSONLoader(fileBlob); + // case 'pptx': + // return new PPTXLoader(fileBlob); + case "srt": + return new SRTLoader(fileBlob); + case "mp3": + return new OpenAIWhisperAudio(fileBlob, { + clientOptions: { + apiKey: openaiApiKey, + baseURL: openaiBaseUrl, + }, + }); + default: + throw new Error(`Unsupported file type: ${extension}`); + } +} + +async function handle(req: NextRequest) { + if (req.method === "OPTIONS") { + return NextResponse.json({ body: "OK" }, { status: 200 }); + } + try { + const authResult = auth(req, ModelProvider.GPT); + if (authResult.error) { + return NextResponse.json(authResult, { + status: 401, + }); + } + + const reqBody: RequestBody = await req.json(); + const authToken = req.headers.get("Authorization") ?? 
""; + const token = authToken.trim().replaceAll("Bearer ", "").trim(); + const apiKey = getOpenAIApiKey(token); + const baseUrl = getOpenAIBaseUrl(reqBody.baseUrl); + const serverConfig = getServerSideConfig(); + // const pinecone = new Pinecone(); + // const pineconeIndex = pinecone.Index(serverConfig.pineconeIndex!); + const embeddings = new OpenAIEmbeddings( + { + modelName: serverConfig.ragEmbeddingModel, + openAIApiKey: apiKey, + }, + { basePath: baseUrl }, + ); + // https://js.langchain.com/docs/integrations/vectorstores/pinecone + // https://js.langchain.com/docs/integrations/vectorstores/qdrant + // process files + for (let i = 0; i < reqBody.fileInfos.length; i++) { + const fileInfo = reqBody.fileInfos[i]; + const contentType = mime.getType(fileInfo.fileName); + // get file buffer + var fileBuffer: Buffer | undefined; + if (serverConfig.isStoreFileToLocal) { + fileBuffer = await LocalFileStorage.get(fileInfo.fileName); + } else { + var file = await S3FileStorage.get(fileInfo.fileName); + var fileByteArray = await file?.transformToByteArray(); + if (fileByteArray) fileBuffer = Buffer.from(fileByteArray); + } + if (!fileBuffer || !contentType) { + console.error(`get ${fileInfo.fileName} buffer fail`); + continue; + } + // load file to docs + const fileBlob = bufferToBlob(fileBuffer, contentType); + const loader = getLoader(fileInfo.fileName, fileBlob, apiKey, baseUrl); + const docs = await loader.load(); + // modify doc meta + docs.forEach((doc) => { + doc.metadata = { + ...doc.metadata, + sessionId: reqBody.sessionId, + sourceFileName: fileInfo.originalFilename, + fileName: fileInfo.fileName, + }; + }); + // split + const chunkSize = serverConfig.ragChunkSize + ? parseInt(serverConfig.ragChunkSize, 10) + : 2000; + const chunkOverlap = serverConfig.ragChunkOverlap + ? 
parseInt(serverConfig.ragChunkOverlap, 10) + : 200; + const textSplitter = new RecursiveCharacterTextSplitter({ + chunkSize: chunkSize, + chunkOverlap: chunkOverlap, + }); + const splits = await textSplitter.splitDocuments(docs); + const vectorStore = await QdrantVectorStore.fromDocuments( + splits, + embeddings, + { + url: process.env.QDRANT_URL, + apiKey: process.env.QDRANT_API_KEY, + collectionName: reqBody.sessionId, + }, + ); + // await PineconeStore.fromDocuments(splits, embeddings, { + // pineconeIndex, + // maxConcurrency: 5, + // }); + // const vectorStore = await PineconeStore.fromExistingIndex(embeddings, { + // pineconeIndex, + // }); + } + return NextResponse.json( + { + sessionId: reqBody.sessionId, + }, + { + status: 200, + }, + ); + } catch (e) { + console.error(e); + return new Response(JSON.stringify({ error: (e as any).message }), { + status: 500, + headers: { "Content-Type": "application/json" }, + }); + } +} + +function bufferToBlob(buffer: Buffer, mimeType?: string): Blob { + const arrayBuffer: ArrayBuffer = buffer.buffer.slice( + buffer.byteOffset, + buffer.byteOffset + buffer.byteLength, + ); + return new Blob([arrayBuffer], { type: mimeType || "" }); +} +function getOpenAIApiKey(token: string) { + const serverConfig = getServerSideConfig(); + const isApiKey = !token.startsWith(ACCESS_CODE_PREFIX); + + let apiKey = serverConfig.apiKey; + if (isApiKey && token) { + apiKey = token; + } + return apiKey; +} +function getOpenAIBaseUrl(reqBaseUrl: string | undefined) { + const serverConfig = getServerSideConfig(); + let baseUrl = "https://api.openai.com/v1"; + if (serverConfig.baseUrl) baseUrl = serverConfig.baseUrl; + if (reqBaseUrl?.startsWith("http://") || reqBaseUrl?.startsWith("https://")) + baseUrl = reqBaseUrl; + if (!baseUrl.endsWith("/v1")) + baseUrl = baseUrl.endsWith("/") ? 
`${baseUrl}v1` : `${baseUrl}/v1`; + console.log("[baseUrl]", baseUrl); + return baseUrl; +} + +export const POST = handle; + +export const runtime = "nodejs"; +export const preferredRegion = [ + "arn1", + "bom1", + "cdg1", + "cle1", + "cpt1", + "dub1", + "fra1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "lhr1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; diff --git a/app/api/langchain/tool/agent/agentapi.ts b/app/api/langchain/tool/agent/agentapi.ts index 5f51a8cf4..cbbfa8f5e 100644 --- a/app/api/langchain/tool/agent/agentapi.ts +++ b/app/api/langchain/tool/agent/agentapi.ts @@ -44,6 +44,7 @@ export interface RequestMessage { } export interface RequestBody { + chatSessionId: string; messages: RequestMessage[]; isAzure: boolean; azureApiVersion?: string; diff --git a/app/api/langchain/tool/agent/nodejs/route.ts b/app/api/langchain/tool/agent/nodejs/route.ts index e8f6c80b8..e8169373b 100644 --- a/app/api/langchain/tool/agent/nodejs/route.ts +++ b/app/api/langchain/tool/agent/nodejs/route.ts @@ -44,6 +44,13 @@ async function handle(req: NextRequest) { }, { basePath: baseUrl }, ); + const ragEmbeddings = new OpenAIEmbeddings( + { + modelName: process.env.RAG_EMBEDDING_MODEL ?? 
"text-embedding-3-large", + openAIApiKey: apiKey, + }, + { basePath: baseUrl }, + ); var dalleCallback = async (data: string) => { var response = new ResponseBody(); @@ -62,6 +69,8 @@ async function handle(req: NextRequest) { baseUrl, model, embeddings, + reqBody.chatSessionId, + ragEmbeddings, dalleCallback, ); var nodejsTools = await nodejsTool.getCustomTools(); diff --git a/app/client/api.ts b/app/client/api.ts index a44741d93..77759b318 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -7,7 +7,7 @@ import { } from "../constant"; import { ChatMessage, ModelType, useAccessStore, useChatStore } from "../store"; import { ChatGPTApi } from "./platforms/openai"; -import { FileApi } from "./platforms/utils"; +import { FileApi, FileInfo } from "./platforms/utils"; import { GeminiProApi } from "./platforms/google"; export const ROLES = ["system", "user", "assistant"] as const; export type MessageRole = (typeof ROLES)[number]; @@ -27,6 +27,7 @@ export interface MultimodalContent { export interface RequestMessage { role: MessageRole; content: string | MultimodalContent[]; + fileInfos?: FileInfo[]; } export interface LLMConfig { @@ -74,6 +75,7 @@ export interface ChatOptions { } export interface AgentChatOptions { + chatSessionId?: string; messages: RequestMessage[]; config: LLMConfig; agentConfig: LLMAgentConfig; @@ -84,6 +86,13 @@ export interface AgentChatOptions { onController?: (controller: AbortController) => void; } +export interface CreateRAGStoreOptions { + chatSessionId: string; + fileInfos: FileInfo[]; + onError?: (err: Error) => void; + onController?: (controller: AbortController) => void; +} + export interface LLMUsage { used: number; total: number; @@ -106,6 +115,7 @@ export abstract class LLMApi { abstract speech(options: SpeechOptions): Promise; abstract transcription(options: TranscriptionOptions): Promise; abstract toolAgentChat(options: AgentChatOptions): Promise; + abstract createRAGStore(options: CreateRAGStoreOptions): Promise; abstract 
usage(): Promise; abstract models(): Promise; } @@ -213,8 +223,8 @@ export function getHeaders(ignoreHeaders?: boolean) { const apiKey = isGoogle ? accessStore.googleApiKey : isAzure - ? accessStore.azureApiKey - : accessStore.openaiApiKey; + ? accessStore.azureApiKey + : accessStore.openaiApiKey; const makeBearer = (s: string) => `${isGoogle || isAzure ? "" : "Bearer "}${s.trim()}`; diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index d18b8966b..ffc1e977f 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -2,6 +2,7 @@ import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant"; import { AgentChatOptions, ChatOptions, + CreateRAGStoreOptions, getHeaders, LLMApi, LLMModel, @@ -19,6 +20,9 @@ import { } from "@/app/utils"; export class GeminiProApi implements LLMApi { + createRAGStore(options: CreateRAGStoreOptions): Promise { + throw new Error("Method not implemented."); + } transcription(options: TranscriptionOptions): Promise { throw new Error("Method not implemented."); } diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 44c095742..0ff9ad706 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -12,6 +12,7 @@ import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { AgentChatOptions, ChatOptions, + CreateRAGStoreOptions, getHeaders, LLMApi, LLMModel, @@ -362,6 +363,34 @@ export class ChatGPTApi implements LLMApi { } } + async createRAGStore(options: CreateRAGStoreOptions): Promise { + try { + const accessStore = useAccessStore.getState(); + const isAzure = accessStore.provider === ServiceProvider.Azure; + let baseUrl = isAzure ? 
accessStore.azureUrl : accessStore.openaiUrl; + const requestPayload = { + sessionId: options.chatSessionId, + fileInfos: options.fileInfos, + baseUrl: baseUrl, + }; + console.log("[Request] rag store payload: ", requestPayload); + const controller = new AbortController(); + options.onController?.(controller); + let path = "/api/langchain/rag/store"; + const chatPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), + }; + const res = await fetch(path, chatPayload); + if (res.status !== 200) throw new Error(await res.text()); + } catch (e) { + console.log("[Request] failed to make a chat reqeust", e); + options.onError?.(e as Error); + } + } + async toolAgentChat(options: AgentChatOptions) { const messages = options.messages.map((v) => ({ role: v.role, @@ -379,6 +408,7 @@ export class ChatGPTApi implements LLMApi { const isAzure = accessStore.provider === ServiceProvider.Azure; let baseUrl = isAzure ? accessStore.azureUrl : accessStore.openaiUrl; const requestPayload = { + chatSessionId: options.chatSessionId, messages, isAzure, azureApiVersion: accessStore.azureApiVersion, diff --git a/app/client/platforms/utils.ts b/app/client/platforms/utils.ts index 9973bb19a..543b96a12 100644 --- a/app/client/platforms/utils.ts +++ b/app/client/platforms/utils.ts @@ -1,7 +1,16 @@ import { getHeaders } from "../api"; +export interface FileInfo { + originalFilename: string; + fileName: string; + filePath: string; + size: number; +} + export class FileApi { - async upload(file: any): Promise { + async upload(file: any): Promise { + const fileName = file.name; + const fileSize = file.size; const formData = new FormData(); formData.append("file", file); var headers = getHeaders(true); @@ -16,6 +25,8 @@ export class FileApi { const resJson = await res.json(); console.log(resJson); return { + originalFilename: fileName, + size: fileSize, fileName: resJson.fileName, filePath: resJson.filePath, }; diff --git 
a/app/components/chat.module.scss b/app/components/chat.module.scss index d9d97666e..703d029a7 100644 --- a/app/components/chat.module.scss +++ b/app/components/chat.module.scss @@ -1,5 +1,69 @@ @import "../styles/animation.scss"; +.attach-files { + position: absolute; + left: 30px; + bottom: 32px; + display: flex; +} + +.attach-file { + cursor: default; + width: 64px; + height: 64px; + border: rgba($color: #888, $alpha: 0.2) 1px solid; + border-radius: 5px; + margin-right: 10px; + background-size: cover; + background-position: center; + background-color: var(--second); + display: flex; + position: relative; + justify-content: center; + align-items: center; + + .attach-file-info { + top: 5px; + width: 100%; + position: absolute; + font-size: 12px; + font-weight: bolder; + text-align: center; + word-wrap: break-word; + word-break: break-all; + -webkit-line-clamp: 3; + -webkit-box-orient: vertical; + line-height: 1.5; + overflow: hidden; + text-overflow: ellipsis; + display: -webkit-box; + } + + .attach-file-mask { + width: 100%; + height: 100%; + opacity: 0; + transition: all ease 0.2s; + z-index: 999; + } + + .attach-file-mask:hover { + opacity: 1; + } + + .delete-file { + width: 24px; + height: 24px; + cursor: pointer; + display: flex; + align-items: center; + justify-content: center; + border-radius: 5px; + float: right; + background-color: var(--white); + } +} + .attach-images { position: absolute; left: 30px; @@ -232,10 +296,12 @@ animation: slide-in ease 0.3s; - $linear: linear-gradient(to right, - rgba(0, 0, 0, 0), - rgba(0, 0, 0, 1), - rgba(0, 0, 0, 0)); + $linear: linear-gradient( + to right, + rgba(0, 0, 0, 0), + rgba(0, 0, 0, 1), + rgba(0, 0, 0, 0) + ); mask-image: $linear; @mixin show { @@ -368,7 +434,7 @@ } } -.chat-message-user>.chat-message-container { +.chat-message-user > .chat-message-container { align-items: flex-end; } @@ -454,6 +520,17 @@ transition: all ease 0.3s; } +.chat-message-item-files { + display: grid; + grid-template-columns: 
repeat(var(--file-count), auto); + grid-gap: 5px; +} + +.chat-message-item-file { + text-decoration: none; + color: #aaa; +} + .chat-message-item-image { width: 100%; margin-top: 10px; @@ -482,23 +559,27 @@ border: rgba($color: #888, $alpha: 0.2) 1px solid; } - @media only screen and (max-width: 600px) { - $calc-image-width: calc(100vw/3*2/var(--image-count)); + $calc-image-width: calc(100vw / 3 * 2 / var(--image-count)); .chat-message-item-image-multi { width: $calc-image-width; height: $calc-image-width; } - + .chat-message-item-image { - max-width: calc(100vw/3*2); + max-width: calc(100vw / 3 * 2); } } @media screen and (min-width: 600px) { - $max-image-width: calc(calc(1200px - var(--sidebar-width))/3*2/var(--image-count)); - $image-width: calc(calc(var(--window-width) - var(--sidebar-width))/3*2/var(--image-count)); + $max-image-width: calc( + calc(1200px - var(--sidebar-width)) / 3 * 2 / var(--image-count) + ); + $image-width: calc( + calc(var(--window-width) - var(--sidebar-width)) / 3 * 2 / + var(--image-count) + ); .chat-message-item-image-multi { width: $image-width; @@ -508,7 +589,7 @@ } .chat-message-item-image { - max-width: calc(calc(1200px - var(--sidebar-width))/3*2); + max-width: calc(calc(1200px - var(--sidebar-width)) / 3 * 2); } } @@ -526,7 +607,7 @@ z-index: 1; } -.chat-message-user>.chat-message-container>.chat-message-item { +.chat-message-user > .chat-message-container > .chat-message-item { background-color: var(--second); &:hover { @@ -637,7 +718,8 @@ min-height: 68px; } -.chat-input:focus {} +.chat-input:focus { +} .chat-input-send { background-color: var(--primary); @@ -656,4 +738,4 @@ .chat-input-send { bottom: 30px; } -} \ No newline at end of file +} diff --git a/app/components/chat.tsx b/app/components/chat.tsx index 6c33f6bc5..01e9eb62a 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -69,6 +69,7 @@ import { isVisionModel, compressImage, isFirefox, + isSupportRAGModel, } from "../utils"; import dynamic from 
"next/dynamic"; @@ -116,6 +117,7 @@ import { SpeechApi, WebTranscriptionApi, } from "../utils/speech"; +import { FileInfo } from "../client/platforms/utils"; const ttsPlayer = createTTSPlayer(); @@ -460,6 +462,8 @@ function useScrollToBottom( export function ChatActions(props: { uploadImage: () => void; setAttachImages: (images: string[]) => void; + uploadFile: () => void; + setAttachFiles: (files: FileInfo[]) => void; setUploading: (uploading: boolean) => void; showPromptModal: () => void; scrollToBottom: () => void; @@ -502,10 +506,19 @@ export function ChatActions(props: { ); const [showModelSelector, setShowModelSelector] = useState(false); const [showUploadImage, setShowUploadImage] = useState(false); + const [showUploadFile, setShowUploadFile] = useState(false); + + const accessStore = useAccessStore(); + const isEnableRAG = useMemo( + () => accessStore.enableRAG(), + // eslint-disable-next-line react-hooks/exhaustive-deps + [], + ); useEffect(() => { const show = isVisionModel(currentModel); setShowUploadImage(show); + setShowUploadFile(isEnableRAG && !show && isSupportRAGModel(currentModel)); if (!show) { props.setAttachImages([]); props.setUploading(false); @@ -555,6 +568,14 @@ export function ChatActions(props: { icon={props.uploading ? : } /> )} + + {showUploadFile && ( + : } + /> + )} void }) { ); } +export function DeleteFileButton(props: { deleteFile: () => void }) { + return ( +
+ +
+ ); +} + function _Chat() { type RenderMessage = ChatMessage & { preview?: boolean }; @@ -743,6 +772,7 @@ function _Chat() { const navigate = useNavigate(); const [attachImages, setAttachImages] = useState([]); const [uploading, setUploading] = useState(false); + const [attachFiles, setAttachFiles] = useState([]); // prompt hints const promptStore = usePromptStore(); @@ -848,9 +878,10 @@ function _Chat() { } setIsLoading(true); chatStore - .onUserInput(userInput, attachImages) + .onUserInput(userInput, attachImages, attachFiles) .then(() => setIsLoading(false)); setAttachImages([]); + setAttachFiles([]); localStorage.setItem(LAST_INPUT_KEY, userInput); setUserInput(""); setPromptHints([]); @@ -1010,7 +1041,9 @@ function _Chat() { setIsLoading(true); const textContent = getMessageTextContent(userMessage); const images = getMessageImages(userMessage); - chatStore.onUserInput(textContent, images).then(() => setIsLoading(false)); + chatStore + .onUserInput(textContent, images, userMessage.fileInfos) + .then(() => setIsLoading(false)); inputRef.current?.focus(); }; @@ -1077,34 +1110,36 @@ function _Chat() { // preview messages const renderMessages = useMemo(() => { - return context - .concat(session.messages as RenderMessage[]) - .concat( - isLoading - ? [ - { - ...createMessage({ - role: "assistant", - content: "……", - }), - preview: true, - }, - ] - : [], - ) - .concat( - userInput.length > 0 && config.sendPreviewBubble - ? [ - { - ...createMessage({ - role: "user", - content: userInput, - }), - preview: true, - }, - ] - : [], - ); + return ( + context + .concat(session.messages as RenderMessage[]) + // .concat( + // isLoading + // ? [ + // { + // ...createMessage({ + // role: "assistant", + // content: "……", + // }), + // preview: true, + // }, + // ] + // : [], + // ) + .concat( + userInput.length > 0 && config.sendPreviewBubble + ? 
[ + { + ...createMessage({ + role: "user", + content: userInput, + }), + preview: true, + }, + ] + : [], + ) + ); }, [ config.sendPreviewBubble, context, @@ -1324,6 +1359,53 @@ function _Chat() { setAttachImages(images); } + async function uploadFile() { + const uploadFiles: FileInfo[] = []; + uploadFiles.push(...attachFiles); + + uploadFiles.push( + ...(await new Promise((res, rej) => { + const fileInput = document.createElement("input"); + fileInput.type = "file"; + fileInput.accept = ".pdf,.txt,.md,.json,.csv,.docx,.srt,.mp3"; + fileInput.multiple = true; + fileInput.onchange = (event: any) => { + setUploading(true); + const files = event.target.files; + const api = new ClientApi(); + const fileDatas: FileInfo[] = []; + for (let i = 0; i < files.length; i++) { + const file = event.target.files[i]; + api.file + .upload(file) + .then((fileInfo) => { + console.log(fileInfo); + fileDatas.push(fileInfo); + if ( + fileDatas.length === 3 || + fileDatas.length === files.length + ) { + setUploading(false); + res(fileDatas); + } + }) + .catch((e) => { + setUploading(false); + rej(e); + }); + } + }; + fileInput.click(); + })), + ); + + const filesLength = uploadFiles.length; + if (filesLength > 5) { + uploadFiles.splice(5, filesLength - 5); + } + setAttachFiles(uploadFiles); + } + return (
@@ -1582,6 +1664,29 @@ function _Chat() { parentRef={scrollRef} defaultShow={i >= messages.length - 6} /> + {message.fileInfos && message.fileInfos.length > 0 && ( + + )} {getMessageImages(message).length == 1 && ( setShowPromptModal(true)} scrollToBottom={scrollToBottom} @@ -1651,7 +1758,7 @@ function _Chat() { />