import { create } from "zustand";
import { persist } from "zustand/middleware";

import { trimTopic } from "../utils";

import Locale, { getLang } from "../locales";
import { showToast } from "../components/ui-lib";
import { ModelConfig, ModelType, useAppConfig } from "./config";
import { createEmptyMask, Mask } from "./mask";
import {
  DEFAULT_INPUT_TEMPLATE,
  DEFAULT_SYSTEM_TEMPLATE,
  StoreKey,
} from "../constant";
import { api, RequestMessage } from "../client/api";
import { ChatControllerPool } from "../client/controller";
import { prettyObject } from "../utils/format";
import { estimateTokenLength } from "../utils/token";

/**
 * A message stored in a chat session: the request payload plus UI/bookkeeping
 * fields (timestamp, streaming/error flags, id, model, optional tool prompt).
 */
export type ChatMessage = RequestMessage & {
  date: string;
  streaming?: boolean;
  isError?: boolean;
  id?: number;
  model?: ModelType;
  toolPrompt?: string;
};

/**
 * Builds a message with default fields (`id` from `Date.now()`, localized
 * `date`, role `"user"`, empty content), then applies `override` on top.
 *
 * @param override Fields that replace the defaults.
 * @returns The assembled message.
 */
export function createMessage(override: Partial<ChatMessage>): ChatMessage {
  return {
    id: Date.now(),
    date: new Date().toLocaleString(),
    role: "user",
    content: "",
    toolPrompt: undefined,
    ...override,
  };
}

/** Per-session counters. Only `charCount` is updated by this store so far. */
export interface ChatStat {
  tokenCount: number;
  wordCount: number;
  charCount: number;
}

/** One conversation, including its messages, summary state and mask. */
export interface ChatSession {
  id: number;
  topic: string;

  memoryPrompt: string;
  messages: ChatMessage[];
  stat: ChatStat;
  lastUpdate: number;
  lastSummarizeIndex: number;
  clearContextIndex?: number;

  mask: Mask;
  webSearch: boolean;
}

export const DEFAULT_TOPIC = Locale.Store.DefaultTopic;
export const BOT_HELLO: ChatMessage = createMessage({
  role: "assistant",
  content: Locale.Store.BotHello,
});

/** Creates a session with no messages, the default topic and an empty mask. */
function createEmptySession(): ChatSession {
  return {
    id: Date.now() + Math.random(),
    topic: DEFAULT_TOPIC,
    memoryPrompt: "",
    messages: [],
    stat: {
      tokenCount: 0,
      wordCount: 0,
      charCount: 0,
    },
    lastUpdate: Date.now(),
    lastSummarizeIndex: 0,

    mask: createEmptyMask(),
    webSearch: false,
  };
}

/** State and actions exposed by {@link useChatStore}. */
interface ChatStore {
  sessions: ChatSession[];
  currentSessionIndex: number;
  globalId: number;
  clearSessions: () => void;
  moveSession: (from: number, to: number) => void;
  selectSession: (index: number) => void;
  newSession: (mask?: Mask) => void;
  deleteSession: (index: number) => void;
  currentSession: () => ChatSession;
  nextSession: (delta: number) => void;
  onNewMessage: (message: ChatMessage) => void;
  onUserInput: (content: string) => Promise<void>;
  summarizeSession: () => void;
  updateStat: (message: ChatMessage) => void;
  updateCurrentSession: (updater: (session: ChatSession) => void) => void;
  updateMessage: (
    sessionIndex: number,
    messageIndex: number,
    updater: (message?: ChatMessage) => void,
  ) => void;
  resetSession: () => void;
  getMessagesWithMemory: () => ChatMessage[];
  getMemoryPrompt: () => ChatMessage;

  clearAllData: () => void;
}

/** Sums the estimated token length of every message's content. */
function countMessages(msgs: ChatMessage[]) {
  return msgs.reduce((pre, cur) => pre + estimateTokenLength(cur.content), 0);
}

/**
 * Expands `{{model}}`, `{{time}}`, `{{lang}}` and `{{input}}` in the model's
 * template (or the default input template when none is configured).
 *
 * @param input Text substituted for `{{input}}`.
 * @param modelConfig Supplies the model name and the optional template.
 * @returns The template with all known placeholders replaced.
 */
function fillTemplateWith(input: string, modelConfig: ModelConfig) {
  const vars = {
    model: modelConfig.model,
    time: new Date().toLocaleString(),
    lang: getLang(),
    input: input,
  };

  let output = modelConfig.template ?? DEFAULT_INPUT_TEMPLATE;

  // the template must contain {{input}}; append it when missing
  const inputVar = "{{input}}";
  if (!output.includes(inputVar)) {
    output += "\n" + inputVar;
  }

  Object.entries(vars).forEach(([name, value]) => {
    // A replacer function inserts the value literally. A plain string
    // replacement would interpret `$$`, `$&`, etc. found in user input.
    output = output.replaceAll(`{{${name}}}`, () => value);
  });

  return output;
}

/**
 * Persisted chat store: holds all sessions and the actions that create,
 * select, reorder, delete and talk to them.
 */
export const useChatStore = create<ChatStore>()(
  persist(
    (set, get) => ({
      sessions: [createEmptySession()],
      currentSessionIndex: 0,
      globalId: 0,

      /** Replaces all sessions with a single empty one and selects it. */
      clearSessions() {
        set(() => ({
          sessions: [createEmptySession()],
          currentSessionIndex: 0,
        }));
      },

      /** Makes the session at `index` the current one. */
      selectSession(index: number) {
        set({
          currentSessionIndex: index,
        });
      },

      /**
       * Moves a session from position `from` to position `to` and adjusts
       * the current index so it keeps pointing at the same session.
       */
      moveSession(from: number, to: number) {
        set((state) => {
          const { sessions, currentSessionIndex: oldIndex } = state;

          // move the session
          const newSessions = [...sessions];
          const session = newSessions[from];
          newSessions.splice(from, 1);
          newSessions.splice(to, 0, session);

          // recompute the current session index
          let newIndex = oldIndex === from ? to : oldIndex;
          if (oldIndex > from && oldIndex <= to) {
            newIndex -= 1;
          } else if (oldIndex < from && oldIndex >= to) {
            newIndex += 1;
          }

          return {
            currentSessionIndex: newIndex,
            sessions: newSessions,
          };
        });
      },

      /**
       * Prepends a new session and selects it. When a mask is given, its
       * model config is layered over the global model config and the mask
       * name becomes the topic.
       */
      newSession(mask) {
        const session = createEmptySession();

        set(() => ({ globalId: get().globalId + 1 }));
        session.id = get().globalId;

        if (mask) {
          const config = useAppConfig.getState();
          const globalModelConfig = config.modelConfig;

          session.mask = {
            ...mask,
            modelConfig: {
              ...globalModelConfig,
              ...mask.modelConfig,
            },
          };
          session.topic = mask.name;
        }

        set((state) => ({
          currentSessionIndex: 0,
          sessions: [session].concat(state.sessions),
        }));
      },

      /** Selects the session `delta` positions away, wrapping around. */
      nextSession(delta) {
        const n = get().sessions.length;
        // double modulo keeps the result in [0, n) even when delta < -n
        const limit = (x: number) => ((x % n) + n) % n;
        const i = get().currentSessionIndex;
        get().selectSession(limit(i + delta));
      },

      /**
       * Removes the session at `index` and shows a toast that can restore
       * the previous state. Deleting the only session leaves one fresh
       * empty session in its place.
       */
      deleteSession(index) {
        const deletingLastSession = get().sessions.length === 1;
        const deletedSession = get().sessions.at(index);

        if (!deletedSession) return;

        const sessions = get().sessions.slice();
        sessions.splice(index, 1);

        const currentIndex = get().currentSessionIndex;
        let nextIndex = Math.min(
          currentIndex - Number(index < currentIndex),
          sessions.length - 1,
        );

        if (deletingLastSession) {
          nextIndex = 0;
          sessions.push(createEmptySession());
        }

        // snapshot for the undo-delete action
        const restoreState = {
          currentSessionIndex: get().currentSessionIndex,
          sessions: get().sessions.slice(),
        };

        set(() => ({
          currentSessionIndex: nextIndex,
          sessions,
        }));

        showToast(
          Locale.Home.DeleteToast,
          {
            text: Locale.Home.Revert,
            onClick() {
              set(() => restoreState);
            },
          },
          5000,
        );
      },

      /**
       * Returns the current session. An out-of-range index is first clamped
       * into the valid range and written back to the store.
       */
      currentSession() {
        let index = get().currentSessionIndex;
        const sessions = get().sessions;

        if (index < 0 || index >= sessions.length) {
          index = Math.min(sessions.length - 1, Math.max(0, index));
          set(() => ({ currentSessionIndex: index }));
        }

        const session = sessions[index];

        return session;
      },

      /**
       * Called when a bot reply has finished: refreshes the message array
       * reference, bumps `lastUpdate`, updates stats and runs summarization.
       */
      onNewMessage(message) {
        get().updateCurrentSession((session) => {
          session.messages = session.messages.concat();
          session.lastUpdate = Date.now();
        });
        get().updateStat(message);
        get().summarizeSession();
      },

      /**
       * Handles a user submission: stores the user message, optionally runs
       * a web search and attaches its prompt, then requests a streamed reply
       * that is written into a placeholder bot message.
       *
       * @param content Raw text typed by the user.
       */
      async onUserInput(content) {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;

        const userContent = fillTemplateWith(content, modelConfig);
        console.log("[User Input] after template: ", userContent);

        const userMessage: ChatMessage = createMessage({
          role: "user",
          content: userContent,
        });

        const botMessage: ChatMessage = createMessage({
          role: "assistant",
          streaming: true,
          id: userMessage.id! + 1,
          model: modelConfig.model,
        });

        // get recent messages
        const recentMessages = get().getMessagesWithMemory();
        // sendMessages holds a reference to userMessage, so the toolPrompt
        // assigned below is visible on the outgoing message as well
        const sendMessages = recentMessages.concat(userMessage);
        const sessionIndex = get().currentSessionIndex;
        const messageIndex = get().currentSession().messages.length + 1;

        // save the user's message with the raw (un-templated) content
        get().updateCurrentSession((session) => {
          const savedUserMessage = {
            ...userMessage,
            content,
          };
          session.messages.push(savedUserMessage);
        });

        if (session.webSearch) {
          const query = encodeURIComponent(content);
          const searchResult = await api.searchTool.call(query);
          console.log("[Tools] ", searchResult);
          // Prompt text and whitespace are kept verbatim; it is split into
          // pieces only so the embedded line break is explicit.
          const webSearchPrompt =
            " Using the provided web search results, write a comprehensive reply to the given query." +
            " If the provided search results refer to multiple subjects with the same name, write separate answers for each subject." +
            " Make sure to cite results using `[[number](URL)]` notation after the reference. \n" +
            `Web search json results: """ ${JSON.stringify(searchResult)} """` +
            ` Current date: """ ${new Date().toISOString()} """` +
            ` Query: """ ${content} """` +
            ` Reply in ${getLang()} and markdown.`;
          userMessage.toolPrompt = webSearchPrompt;
        }

        // save the bot's placeholder message
        get().updateCurrentSession((session) => {
          session.messages.push(botMessage);
        });

        // make request
        api.llm.chat({
          messages: sendMessages,
          config: { ...modelConfig, stream: true },
          onUpdate(message) {
            botMessage.streaming = true;
            if (message) {
              botMessage.content = message;
            }
            // new array reference so the store sees a change
            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
          },
          onFinish(message) {
            botMessage.streaming = false;
            if (message) {
              botMessage.content = message;
              get().onNewMessage(botMessage);
            }
            ChatControllerPool.remove(
              sessionIndex,
              botMessage.id ?? messageIndex,
            );
          },
          onError(error) {
            const isAborted = error.message.includes("aborted");
            botMessage.content =
              "\n\n" +
              prettyObject({
                error: true,
                message: error.message,
              });
            botMessage.streaming = false;
            // an aborted request is not flagged as an error
            userMessage.isError = !isAborted;
            botMessage.isError = !isAborted;

            get().updateCurrentSession((session) => {
              session.messages = session.messages.concat();
            });
            ChatControllerPool.remove(
              sessionIndex,
              botMessage.id ?? messageIndex,
            );

            console.error("[Chat] failed ", error);
          },
          onController(controller) {
            // collect controller for stop/retry
            ChatControllerPool.addController(
              sessionIndex,
              botMessage.id ?? messageIndex,
              controller,
            );
          },
        });
      },

      /**
       * Returns a system message carrying the session's memory prompt, or
       * one with empty content when there is no memory prompt.
       */
      getMemoryPrompt() {
        const session = get().currentSession();

        return {
          role: "system",
          content:
            session.memoryPrompt.length > 0
              ? Locale.Store.Prompt.History(session.memoryPrompt)
              : "",
          date: "",
        } as ChatMessage;
      },

      /**
       * Assembles the context for the next request: optional system prompt,
       * optional long-term memory, the mask's context prompts, and as many
       * recent non-error messages as fit the configured limits.
       */
      getMessagesWithMemory() {
        const session = get().currentSession();
        const modelConfig = session.mask.modelConfig;
        const clearContextIndex = session.clearContextIndex ?? 0;
        const messages = session.messages.slice();
        const totalMessageCount = session.messages.length;

        // in-context prompts
        const contextPrompts = session.mask.context.slice();

        // system prompts, to get close to OpenAI Web ChatGPT
        // only injected when the mask defines no context prompts
        const shouldInjectSystemPrompts = contextPrompts.length === 0;
        const systemPrompts = shouldInjectSystemPrompts
          ? [
              createMessage({
                role: "system",
                content: fillTemplateWith("", {
                  ...modelConfig,
                  template: DEFAULT_SYSTEM_TEMPLATE,
                }),
              }),
            ]
          : [];
        if (shouldInjectSystemPrompts) {
          console.log(
            "[Global System Prompt] ",
            systemPrompts.at(0)?.content ?? "empty",
          );
        }

        // long term memory
        const shouldSendLongTermMemory =
          modelConfig.sendMemory &&
          session.memoryPrompt &&
          session.memoryPrompt.length > 0 &&
          session.lastSummarizeIndex > clearContextIndex;
        const longTermMemoryPrompts = shouldSendLongTermMemory
          ? [get().getMemoryPrompt()]
          : [];
        const longTermMemoryStartIndex = session.lastSummarizeIndex;

        // short term memory
        const shortTermMemoryStartIndex = Math.max(
          0,
          totalMessageCount - modelConfig.historyMessageCount,
        );

        // The returned list is built from these parts, in order:
        // 0. system prompt: to get close to OpenAI Web ChatGPT
        // 1. long term memory: summarized memory messages
        // 2. pre-defined in-context prompts
        // 3. short term memory: latest n messages
        // (the newest input message is appended by the caller)
        const memoryStartIndex = shouldSendLongTermMemory
          ? Math.min(longTermMemoryStartIndex, shortTermMemoryStartIndex)
          : shortTermMemoryStartIndex;
        // if the user has cleared the history, exclude the cleared part too
        const contextStartIndex = Math.max(clearContextIndex, memoryStartIndex);
        const maxTokenThreshold = modelConfig.max_tokens;

        // walk backwards, collecting recent messages until the token budget
        // is used up or the context start is reached
        const reversedRecentMessages: ChatMessage[] = [];
        for (
          let i = totalMessageCount - 1, tokenCount = 0;
          i >= contextStartIndex && tokenCount < maxTokenThreshold;
          i -= 1
        ) {
          const msg = messages[i];
          if (!msg || msg.isError) continue;
          tokenCount += estimateTokenLength(msg.content);
          reversedRecentMessages.push(msg);
        }

        // concat all messages
        const recentMessages = [
          ...systemPrompts,
          ...longTermMemoryPrompts,
          ...contextPrompts,
          ...reversedRecentMessages.reverse(),
        ];

        return recentMessages;
      },

      /**
       * Runs `updater` on the addressed message (or `undefined` when the
       * session or message does not exist) and publishes the sessions.
       */
      updateMessage(
        sessionIndex: number,
        messageIndex: number,
        updater: (message?: ChatMessage) => void,
      ) {
        const sessions = get().sessions;
        const session = sessions.at(sessionIndex);
        const messages = session?.messages;
        updater(messages?.at(messageIndex));
        set(() => ({ sessions }));
      },

      /** Clears the current session's messages and memory prompt. */
      resetSession() {
        get().updateCurrentSession((session) => {
          session.messages = [];
          session.memoryPrompt = "";
        });
      },

      /**
       * Requests a topic for the session once enough has been said, and
       * requests a compressed memory prompt once the unsummarized history
       * exceeds the configured threshold.
       */
      summarizeSession() {
        const session = get().currentSession();

        const messages = session.messages;

        // summarize the topic once the chat reaches this estimated length
        const SUMMARIZE_MIN_LEN = 50;
        if (
          session.topic === DEFAULT_TOPIC &&
          countMessages(messages) >= SUMMARIZE_MIN_LEN
        ) {
          const topicMessages = messages.concat(
            createMessage({
              role: "user",
              content: Locale.Store.Prompt.Topic,
            }),
          );
          api.llm.chat({
            messages: topicMessages,
            config: {
              model: "gpt-3.5-turbo",
            },
            onFinish(message) {
              get().updateCurrentSession(
                (session) =>
                  (session.topic =
                    message.length > 0 ? trimTopic(message) : DEFAULT_TOPIC),
              );
            },
          });
        }

        const modelConfig = session.mask.modelConfig;
        const summarizeIndex = Math.max(
          session.lastSummarizeIndex,
          session.clearContextIndex ?? 0,
        );
        // drop error messages, then keep only the not-yet-summarized tail
        let toBeSummarizedMsgs = messages
          .filter((msg) => !msg.isError)
          .slice(summarizeIndex);

        const historyMsgLength = countMessages(toBeSummarizedMsgs);

        // Parentheses are required: `>` binds tighter than `??`, so without
        // them the 4000 fallback would never apply.
        if (historyMsgLength > (modelConfig?.max_tokens ?? 4000)) {
          const n = toBeSummarizedMsgs.length;
          toBeSummarizedMsgs = toBeSummarizedMsgs.slice(
            Math.max(0, n - modelConfig.historyMessageCount),
          );
        }

        // add memory prompt
        toBeSummarizedMsgs.unshift(get().getMemoryPrompt());

        const lastSummarizeIndex = session.messages.length;

        console.log(
          "[Chat History] ",
          toBeSummarizedMsgs,
          historyMsgLength,
          modelConfig.compressMessageLengthThreshold,
        );

        if (
          historyMsgLength > modelConfig.compressMessageLengthThreshold &&
          modelConfig.sendMemory
        ) {
          api.llm.chat({
            messages: toBeSummarizedMsgs.concat({
              role: "system",
              content: Locale.Store.Prompt.Summarize,
              date: "",
            }),
            config: { ...modelConfig, stream: true },
            onUpdate(message) {
              session.memoryPrompt = message;
            },
            onFinish(message) {
              console.log("[Memory] ", message);
              session.lastSummarizeIndex = lastSummarizeIndex;
            },
            onError(err) {
              console.error("[Summarize] ", err);
            },
          });
        }
      },

      /** Adds the message's content length to the session's `charCount`. */
      updateStat(message) {
        get().updateCurrentSession((session) => {
          session.stat.charCount += message.content.length;
          // TODO: should update chat count and word count
        });
      },

      /**
       * Runs `updater` on the session at the current index (mutating it in
       * place) and publishes the sessions array to the store.
       */
      updateCurrentSession(updater) {
        const sessions = get().sessions;
        const index = get().currentSessionIndex;
        updater(sessions[index]);
        set(() => ({ sessions }));
      },

      /** Clears all of `localStorage` and reloads the page. */
      clearAllData() {
        localStorage.clear();
        location.reload();
      },
    }),
    {
      name: StoreKey.Chat,
      version: 2,
      /**
       * Upgrades persisted state. Versions below 2 have their sessions
       * rebuilt from empty sessions, keeping only topic and messages.
       */
      migrate(persistedState, version) {
        const state = persistedState as any;
        const newState = JSON.parse(JSON.stringify(state)) as ChatStore;

        if (version < 2) {
          newState.globalId = 0;
          newState.sessions = [];

          const oldSessions = state.sessions;
          for (const oldSession of oldSessions) {
            const newSession = createEmptySession();
            newSession.topic = oldSession.topic;
            newSession.messages = [...oldSession.messages];
            newSession.mask.modelConfig.sendMemory = true;
            newSession.mask.modelConfig.historyMessageCount = 4;
            newSession.mask.modelConfig.compressMessageLengthThreshold = 1000;
            newState.sessions.push(newSession);
          }
        }

        return newState;
      },
    },
  ),
);