diff --git a/README.md b/README.md index 3529c7c1e..1a6abcadb 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ - [ ] 插件列表页面开发 - [ ] 支持开关指定插件 - [ ] 支持添加自定义插件 -- [ ] 支持 Agent 参数配置( agentType, maxIterations, returnIntermediateSteps 等) +- [x] 支持 Agent 参数配置( ~~agentType~~, maxIterations, returnIntermediateSteps 等) - [x] 支持 ChatSession 级别插件功能开关 仅在使用 `0613` 版本模型时会出现插件开关,其它模型默认为关闭状态,开关也不会显示。 diff --git a/app/api/langchain/tool/agent/route.ts b/app/api/langchain/tool/agent/route.ts index c2b38716d..9b428f02f 100644 --- a/app/api/langchain/tool/agent/route.ts +++ b/app/api/langchain/tool/agent/route.ts @@ -35,6 +35,8 @@ interface RequestBody { presence_penalty?: number; frequency_penalty?: number; top_p?: number; + maxIterations: number; + returnIntermediateSteps: boolean; } class ResponseBody { @@ -120,6 +122,7 @@ async function handle(req: NextRequest) { `tool: ${action.tool} toolInput: ${action.toolInput}`, { action }, ); + if (!reqBody.returnIntermediateSteps) return; var response = new ResponseBody(); response.isToolMessage = true; let toolInput = (action.toolInput); @@ -202,8 +205,8 @@ async function handle(req: NextRequest) { }); const executor = await initializeAgentExecutorWithOptions(tools, llm, { agentType: "openai-functions", - returnIntermediateSteps: true, - maxIterations: 3, + returnIntermediateSteps: reqBody.returnIntermediateSteps, + maxIterations: reqBody.maxIterations, memory: memory, }); diff --git a/app/client/api.ts b/app/client/api.ts index fde18a242..1bf4a1e6a 100644 --- a/app/client/api.ts +++ b/app/client/api.ts @@ -23,6 +23,11 @@ export interface LLMConfig { frequency_penalty?: number; } +export interface LLMAgentConfig { + maxIterations: number; + returnIntermediateSteps: boolean; +} + export interface ChatOptions { messages: RequestMessage[]; config: LLMConfig; @@ -33,6 +38,17 @@ export interface ChatOptions { onController?: (controller: AbortController) => void; } +export interface AgentChatOptions { + messages: RequestMessage[]; + config: LLMConfig; + agentConfig: LLMAgentConfig; + onToolUpdate?: (toolName: string, toolInput: string) => void; + onUpdate?: (message: string, chunk: string) => void; + onFinish: (message: string) => void; + onError?: (err: Error) => void; + onController?: (controller: AbortController) => void; +} + export interface LLMUsage { used: number; total: number; @@ -45,7 +61,7 @@ export interface LLMModel { export abstract class LLMApi { abstract chat(options: ChatOptions): Promise; - abstract toolAgentChat(options: ChatOptions): Promise; + abstract toolAgentChat(options: AgentChatOptions): Promise; abstract usage(): Promise; abstract models(): Promise; } diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index 05309d479..08caeaff0 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -6,7 +6,14 @@ import { } from "@/app/constant"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; -import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api"; +import { + AgentChatOptions, + ChatOptions, + getHeaders, + LLMApi, + LLMModel, + LLMUsage, +} from "../api"; import Locale from "../../locales"; import { EventStreamContentType, @@ -188,7 +195,7 @@ export class ChatGPTApi implements LLMApi { } } - async toolAgentChat(options: ChatOptions) { + async toolAgentChat(options: AgentChatOptions) { const messages = options.messages.map((v) => ({ role: v.role, content: v.content, @@ -210,6 +217,8 @@ export class ChatGPTApi implements LLMApi { presence_penalty: modelConfig.presence_penalty, frequency_penalty: modelConfig.frequency_penalty, top_p: modelConfig.top_p, + maxIterations: options.agentConfig.maxIterations, + returnIntermediateSteps: options.agentConfig.returnIntermediateSteps, }; console.log("[Request] openai payload: ", requestPayload); diff --git a/app/components/chat.tsx b/app/components/chat.tsx index e22f2f276..b8a1604c7 100644 --- a/app/components/chat.tsx +++ b/app/components/chat.tsx @@ -511,7 +511,7 @@ export function ChatActions(props: { icon={} /> - {currentModel.endsWith("0613") && ( + {config.pluginConfig.enable && currentModel.endsWith("0613") && ( void) => void; +}) { + return ( + <> + + + props.updateConfig( + (config) => (config.enable = e.currentTarget.checked), + ) + } + > + + + + props.updateConfig( + (config) => + (config.maxIterations = e.currentTarget.valueAsNumber), + ) + } + > + + + + props.updateConfig( + (config) => + (config.returnIntermediateSteps = e.currentTarget.checked), + ) + } + > + + + ); +} diff --git a/app/components/settings.tsx b/app/components/settings.tsx index 1e6ef7139..a5ace6284 100644 --- a/app/components/settings.tsx +++ b/app/components/settings.tsx @@ -49,6 +49,7 @@ import { Avatar, AvatarPicker } from "./emoji"; import { getClientConfig } from "../config/client"; import { useSyncStore } from "../store/sync"; import { nanoid } from "nanoid"; +import { PluginConfigList } from "./plugin-config"; function EditPromptModal(props: { id: string; onClose: () => void }) { const promptStore = usePromptStore(); @@ -739,6 +740,17 @@ export function Settings() { setShowPromptModal(false)} /> )} + + { + const pluginConfig = { ...config.pluginConfig }; + updater(pluginConfig); + config.update((config) => (config.pluginConfig = pluginConfig)); + }} + /> + + diff --git a/app/locales/cn.ts b/app/locales/cn.ts index a2fc40180..ff869f2f4 100644 --- a/app/locales/cn.ts +++ b/app/locales/cn.ts @@ -261,6 +261,20 @@ const cn = { Title: "频率惩罚度 (frequency_penalty)", SubTitle: "值越大,越有可能降低重复字词", }, + Plugin: { + Enable: { + Title: "启用插件", + SubTitle: "启用插件调用功能", + }, + MaxIteration: { + Title: "最大迭代数", + SubTitle: "插件调用最大迭代数", + }, + ReturnIntermediateStep: { + Title: "返回中间步骤", + SubTitle: "是否返回插件调用的中间步骤", + }, + }, }, Store: { DefaultTopic: "新的聊天", diff --git a/app/locales/en.ts b/app/locales/en.ts index 21d2922f1..9cad0b59e 100644 --- a/app/locales/en.ts +++ b/app/locales/en.ts @@ -265,6 +265,20 @@ const en: LocaleType = { SubTitle: "A larger value decreasing the likelihood to repeat the same line", }, + Plugin: { + Enable: { + Title: "Enable Plugin", + SubTitle: "Enable plugin invocation", + }, + MaxIteration: { + Title: "Max Iterations", + SubTitle: "Max of plugin iterations", + }, + ReturnIntermediateStep: { + Title: "Return Intermediate Steps", + SubTitle: "Return Intermediate Steps", + }, + }, }, Store: { DefaultTopic: "New Conversation", diff --git a/app/store/chat.ts b/app/store/chat.ts index acfaad0ba..d38d188b5 100644 --- a/app/store/chat.ts +++ b/app/store/chat.ts @@ -305,6 +305,9 @@ export const useChatStore = create()( const sendMessages = recentMessages.concat(userMessage); const messageIndex = get().currentSession().messages.length + 1; + const config = useAppConfig.getState(); + const pluginConfig = useAppConfig.getState().pluginConfig; + // save user's and bot's message get().updateCurrentSession((session) => { const savedUserMessage = { @@ -315,11 +318,12 @@ export const useChatStore = create()( session.messages.push(botMessage); }); - if (session.mask.usePlugins) { + if (config.pluginConfig.enable && session.mask.usePlugins) { console.log("[ToolAgent] start"); api.llm.toolAgentChat({ messages: sendMessages, config: { ...modelConfig, stream: true }, + agentConfig: { ...pluginConfig }, onUpdate(message) { botMessage.streaming = true; if (message) { diff --git a/app/store/config.ts b/app/store/config.ts index a638d51dd..19e77eb5b 100644 --- a/app/store/config.ts +++ b/app/store/config.ts @@ -51,6 +51,12 @@ export const DEFAULT_CONFIG = { enableInjectSystemPrompts: true, template: DEFAULT_INPUT_TEMPLATE, }, + + pluginConfig: { + enable: true, + maxIterations: 3, + returnIntermediateSteps: true, + }, }; export type ChatConfig = typeof DEFAULT_CONFIG; @@ -63,6 +69,7 @@ export type ChatConfigStore = ChatConfig & { }; export type ModelConfig = ChatConfig["modelConfig"]; +export type PluginConfig = ChatConfig["pluginConfig"]; export function limitNumber( x: number,