feat: support langchain vision mode

This commit is contained in:
Hk-Gosuto 2024-04-12 14:49:50 +08:00
parent 668ee60e53
commit 094f4ea6b9
2 changed files with 22 additions and 11 deletions

View File

@@ -37,10 +37,11 @@ import {
HumanMessage, HumanMessage,
AIMessage, AIMessage,
} from "@langchain/core/messages"; } from "@langchain/core/messages";
import { MultimodalContent } from "@/app/client/api";
export interface RequestMessage { export interface RequestMessage {
role: string; role: string;
content: string; content: string | MultimodalContent[];
} }
export interface RequestBody { export interface RequestBody {
@@ -324,11 +325,18 @@ export class AgentApi {
reqBody.messages reqBody.messages
.slice(0, reqBody.messages.length - 1) .slice(0, reqBody.messages.length - 1)
.forEach((message) => { .forEach((message) => {
if (message.role === "system") if (message.role === "system" && typeof message.content === "string")
pastMessages.push(new SystemMessage(message.content)); pastMessages.push(new SystemMessage(message.content));
if (message.role === "user") if (message.role === "user")
pastMessages.push(new HumanMessage(message.content)); typeof message.content === "string"
if (message.role === "assistant") ? pastMessages.push(new HumanMessage(message.content))
: pastMessages.push(
new HumanMessage({ content: message.content }),
);
if (
message.role === "assistant" &&
typeof message.content === "string"
)
pastMessages.push(new AIMessage(message.content)); pastMessages.push(new AIMessage(message.content));
}); });
@@ -370,7 +378,7 @@ export class AgentApi {
const MEMORY_KEY = "chat_history"; const MEMORY_KEY = "chat_history";
const prompt = ChatPromptTemplate.fromMessages([ const prompt = ChatPromptTemplate.fromMessages([
new MessagesPlaceholder(MEMORY_KEY), new MessagesPlaceholder(MEMORY_KEY),
["user", "{input}"], new MessagesPlaceholder("input"),
new MessagesPlaceholder("agent_scratchpad"), new MessagesPlaceholder("agent_scratchpad"),
]); ]);
const modelWithTools = llm.bind({ const modelWithTools = llm.bind({
@@ -378,9 +386,7 @@ export class AgentApi {
}); });
const runnableAgent = RunnableSequence.from([ const runnableAgent = RunnableSequence.from([
{ {
input: (i: { input: string; steps: ToolsAgentStep[] }) => { input: (i) => i.input,
return i.input;
},
agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => { agent_scratchpad: (i: { input: string; steps: ToolsAgentStep[] }) => {
return formatToOpenAIToolMessages(i.steps); return formatToOpenAIToolMessages(i.steps);
}, },
@@ -401,11 +407,15 @@ export class AgentApi {
agent: runnableAgent, agent: runnableAgent,
tools, tools,
}); });
const lastMessageContent = reqBody.messages.slice(-1)[0].content;
const lastHumanMessage =
typeof lastMessageContent === "string"
? new HumanMessage(lastMessageContent)
: new HumanMessage({ content: lastMessageContent });
executor executor
.invoke( .invoke(
{ {
input: reqBody.messages.slice(-1)[0].content, input: [lastHumanMessage],
signal: this.controller.signal, signal: this.controller.signal,
}, },
{ {

View File

@@ -399,9 +399,10 @@ export class ChatGPTApi implements LLMApi {
} }
async toolAgentChat(options: AgentChatOptions) { async toolAgentChat(options: AgentChatOptions) {
const visionModel = isVisionModel(options.config.model);
const messages = options.messages.map((v) => ({ const messages = options.messages.map((v) => ({
role: v.role, role: v.role,
content: getMessageTextContent(v), content: visionModel ? v.content : getMessageTextContent(v),
})); }));
const modelConfig = { const modelConfig = {