From 094f4ea6b99524b0d28e433a22bb81dede95e5c0 Mon Sep 17 00:00:00 2001
From: Hk-Gosuto
Date: Fri, 12 Apr 2024 14:49:50 +0800
Subject: [PATCH] feat: support langchain vision mode

---
 app/api/langchain/tool/agent/agentapi.ts | 30 ++++++++++++++++--------
 app/client/platforms/openai.ts           |  3 ++-
 2 files changed, 22 insertions(+), 11 deletions(-)

diff --git a/app/api/langchain/tool/agent/agentapi.ts b/app/api/langchain/tool/agent/agentapi.ts
index cbbfa8f5e..7edf7201d 100644
--- a/app/api/langchain/tool/agent/agentapi.ts
+++ b/app/api/langchain/tool/agent/agentapi.ts
@@ -37,10 +37,11 @@ import {
   HumanMessage,
   AIMessage,
 } from "@langchain/core/messages";
+import { MultimodalContent } from "@/app/client/api";
 
 export interface RequestMessage {
   role: string;
-  content: string;
+  content: string | MultimodalContent[];
 }
 
 export interface RequestBody {
@@ -324,11 +325,18 @@ export class AgentApi {
     reqBody.messages
       .slice(0, reqBody.messages.length - 1)
       .forEach((message) => {
-        if (message.role === "system")
+        if (message.role === "system" && typeof message.content === "string")
           pastMessages.push(new SystemMessage(message.content));
         if (message.role === "user")
-          pastMessages.push(new HumanMessage(message.content));
-        if (message.role === "assistant")
+          typeof message.content === "string"
+            ? pastMessages.push(new HumanMessage(message.content))
+            : pastMessages.push(
+                new HumanMessage({ content: message.content }),
+              );
+        if (
+          message.role === "assistant" &&
+          typeof message.content === "string"
+        )
           pastMessages.push(new AIMessage(message.content));
       });
 
@@ -370,7 +378,7 @@
     const MEMORY_KEY = "chat_history";
     const prompt = ChatPromptTemplate.fromMessages([
       new MessagesPlaceholder(MEMORY_KEY),
-      ["user", "{input}"],
+      new MessagesPlaceholder("input"),
       new MessagesPlaceholder("agent_scratchpad"),
     ]);
     const modelWithTools = llm.bind({
@@ -378,9 +386,7 @@
     });
     const runnableAgent = RunnableSequence.from([
       {
-        input: (i: { input: string; steps: ToolsAgentStep[] }) => {
-          return i.input;
-        },
+        input: (i) => i.input,
         agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => {
           return formatToOpenAIToolMessages(i.steps);
         },
@@ -401,11 +407,15 @@
       agent: runnableAgent,
       tools,
     });
-
+    const lastMessageContent = reqBody.messages.slice(-1)[0].content;
+    const lastHumanMessage =
+      typeof lastMessageContent === "string"
+        ? new HumanMessage(lastMessageContent)
+        : new HumanMessage({ content: lastMessageContent });
     executor
       .invoke(
         {
-          input: reqBody.messages.slice(-1)[0].content,
+          input: [lastHumanMessage],
           signal: this.controller.signal,
         },
         {
diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index a1d7d05d3..5a5f74b33 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -399,9 +399,10 @@ export class ChatGPTApi implements LLMApi {
   }
 
   async toolAgentChat(options: AgentChatOptions) {
+    const visionModel = isVisionModel(options.config.model);
     const messages = options.messages.map((v) => ({
       role: v.role,
-      content: getMessageTextContent(v),
+      content: visionModel ? v.content : getMessageTextContent(v),
     }));
 
     const modelConfig = {
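
Reviewer note (not part of the patch): a minimal sketch of the message flow
this change enables, assuming MultimodalContent follows the OpenAI-style
content-part shape ({ type: "text" | "image_url", ... }) used by the repo's
client API. Array content has to go through HumanMessage's object
constructor so LangChain forwards the parts to the vision model untouched:

    import { HumanMessage } from "@langchain/core/messages";
    import { MultimodalContent } from "@/app/client/api";

    // Assumed shape: content is either plain text or OpenAI-style parts.
    const content: string | MultimodalContent[] = [
      { type: "text", text: "What is in this picture?" },
      { type: "image_url", image_url: { url: "data:image/jpeg;base64,..." } },
    ];

    // Mirrors the branch added in agentapi.ts: strings keep the plain
    // constructor; content-part arrays use the object form.
    const lastHumanMessage =
      typeof content === "string"
        ? new HumanMessage(content)
        : new HumanMessage({ content });

On the client side, openai.ts gates this on isVisionModel(...): only
vision-capable models receive the raw content parts, while every other model
still gets the flattened text from getMessageTextContent(v).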
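
Reviewer note (not part of the patch): why the prompt switches from a
template slot to a placeholder. A ["user", "{input}"] entry can only
interpolate a string, whereas a MessagesPlaceholder accepts whole message
objects, so a HumanMessage carrying image parts passes through intact. That
is also why invoke() now receives input: [lastHumanMessage] (an array of
messages) rather than a raw string. A sketch of the resulting prompt:

    import {
      ChatPromptTemplate,
      MessagesPlaceholder,
    } from "@langchain/core/prompts";

    const MEMORY_KEY = "chat_history";
    // Every slot is now a message-list placeholder; no string templating
    // is applied to "input", so multimodal content is never stringified.
    const prompt = ChatPromptTemplate.fromMessages([
      new MessagesPlaceholder(MEMORY_KEY),
      new MessagesPlaceholder("input"),
      new MessagesPlaceholder("agent_scratchpad"),
    ]);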