diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts index fe1ef38d7..edac751b0 100644 --- a/app/client/platforms/openai.ts +++ b/app/client/platforms/openai.ts @@ -20,6 +20,7 @@ import { preProcessImageContent, uploadImage, base64Image2Blob, + stream, } from "@/app/utils/chat"; import { cloudflareAIGatewayUrl } from "@/app/utils/cloudflare"; import { DalleSize, DalleQuality, DalleStyle } from "@/app/typing"; @@ -238,52 +239,30 @@ export class ChatGPTApi implements LLMApi { isDalle3 ? OpenaiPath.ImagePath : OpenaiPath.ChatPath, ); } - const chatPayload = { - method: "POST", - body: JSON.stringify(requestPayload), - signal: controller.signal, - headers: getHeaders(), - }; - - // make a fetch request - const requestTimeoutId = setTimeout( - () => controller.abort(), - isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow. - ); - if (shouldStream) { - let responseText = ""; - let remainText = ""; - let finished = false; - let running = false; - let runTools: ChatMessageTool[] = []; - - // animate response to make it looks smooth - function animateResponseText() { - if (finished || controller.signal.aborted) { - responseText += remainText; - console.log("[Response Animation] finished"); - if (responseText?.length === 0) { - options.onError?.(new Error("empty response from server")); - } - return; - } - - if (remainText.length > 0) { - const fetchCount = Math.max(1, Math.round(remainText.length / 60)); - const fetchText = remainText.slice(0, fetchCount); - responseText += fetchText; - remainText = remainText.slice(fetchCount); - options.onUpdate?.(responseText, fetchText); - } - - requestAnimationFrame(animateResponseText); - } - - // start animaion - animateResponseText(); - - // TODO 后面这里是从选择的plugins中获取function列表 + const tools = [ + { + type: "function", + function: { + name: "get_current_weather", + description: "Get the current weather", + parameters: { + type: "object", + properties: { + location: { + type: "string", + description: "The city and country, eg. San Francisco, USA", + }, + format: { + type: "string", + enum: ["celsius", "fahrenheit"], + }, + }, + required: ["location", "format"], + }, + }, + }, + ]; const funcs = { get_current_weather: (args: any) => { console.log("call get_current_weather", args); @@ -292,221 +271,61 @@ export class ChatGPTApi implements LLMApi { }); }, }; - const finish = () => { - if (!finished) { - console.log("try run tools", runTools.length, finished, running); - if (!running && runTools.length > 0) { - const toolCallMessage = { - role: "assistant", - tool_calls: [...runTools], + stream( + chatPath, + requestPayload, + getHeaders(), + tools, + funcs, + controller, + (text: string, runTools: ChatMessageTool[]) => { + console.log("parseSSE", text, runTools); + const json = JSON.parse(text); + const choices = json.choices as Array<{ + delta: { + content: string; + tool_calls: ChatMessageTool[]; }; - running = true; - runTools.splice(0, runTools.length); // empty runTools - return Promise.all( - toolCallMessage.tool_calls.map((tool) => { - options?.onBeforeTool?.(tool); - return Promise.resolve( - // @ts-ignore - funcs[tool.function.name]( - // @ts-ignore - JSON.parse(tool.function.arguments), - ), - ) - .then((content) => { - options?.onAfterTool?.({ - ...tool, - content, - isError: false, - }); - return content; - }) - .catch((e) => { - options?.onAfterTool?.({ ...tool, isError: true }); - return e.toString(); - }) - .then((content) => ({ - role: "tool", - content, - tool_call_id: tool.id, - })); - }), - ).then((toolCallResult) => { - console.log("end runTools", toolCallMessage, toolCallResult); + }>; + const tool_calls = choices[0]?.delta?.tool_calls; + if (tool_calls?.length > 0) { + const index = tool_calls[0]?.index; + const id = tool_calls[0]?.id; + const args = tool_calls[0]?.function?.arguments; + if (id) { + runTools.push({ + id, + type: tool_calls[0]?.type, + function: { + name: tool_calls[0]?.function?.name as string, + arguments: args, + }, + }); + } else { // @ts-ignore - requestPayload?.messages?.splice( - // @ts-ignore - requestPayload?.messages?.length, - 0, - toolCallMessage, - ...toolCallResult, - ); - setTimeout(() => { - // call again - console.log("start again"); - running = false; - chatApi(chatPath, requestPayload as RequestPayload); // call fetchEventSource - }, 60); - }); - console.log("try run tools", runTools.length, finished); - return; + runTools[index]["function"]["arguments"] += args; + } } - if (running) { - return; - } - finished = true; - options.onFinish(responseText + remainText); - } + + console.log("runTools", runTools); + return choices[0]?.delta?.content; + }, + options, + ); + } else { + const chatPayload = { + method: "POST", + body: JSON.stringify(requestPayload), + signal: controller.signal, + headers: getHeaders(), }; - controller.signal.onabort = finish; + // make a fetch request + const requestTimeoutId = setTimeout( + () => controller.abort(), + isDalle3 ? REQUEST_TIMEOUT_MS * 2 : REQUEST_TIMEOUT_MS, // dalle3 using b64_json is slow. + ); - function chatApi(chatPath: string, requestPayload: RequestPayload) { - const chatPayload = { - method: "POST", - body: JSON.stringify({ - ...requestPayload, - // TODO 这里暂时写死的,后面从store.tools中按照当前session中选择的获取 - tools: [ - { - type: "function", - function: { - name: "get_current_weather", - description: "Get the current weather", - parameters: { - type: "object", - properties: { - location: { - type: "string", - description: - "The city and country, eg. San Francisco, USA", - }, - format: { - type: "string", - enum: ["celsius", "fahrenheit"], - }, - }, - required: ["location", "format"], - }, - }, - }, - ], - }), - signal: controller.signal, - headers: getHeaders(), - }; - console.log("chatApi", chatPath, requestPayload, chatPayload); - fetchEventSource(chatPath, { - ...chatPayload, - async onopen(res) { - clearTimeout(requestTimeoutId); - const contentType = res.headers.get("content-type"); - console.log( - "[OpenAI] request response content type: ", - contentType, - ); - - if (contentType?.startsWith("text/plain")) { - responseText = await res.clone().text(); - return finish(); - } - - if ( - !res.ok || - !res.headers - .get("content-type") - ?.startsWith(EventStreamContentType) || - res.status !== 200 - ) { - const responseTexts = [responseText]; - let extraInfo = await res.clone().text(); - try { - const resJson = await res.clone().json(); - extraInfo = prettyObject(resJson); - } catch {} - - if (res.status === 401) { - responseTexts.push(Locale.Error.Unauthorized); - } - - if (extraInfo) { - responseTexts.push(extraInfo); - } - - responseText = responseTexts.join("\n\n"); - - return finish(); - } - }, - onmessage(msg) { - if (msg.data === "[DONE]" || finished) { - return finish(); - } - const text = msg.data; - try { - const json = JSON.parse(text); - const choices = json.choices as Array<{ - delta: { - content: string; - tool_calls: ChatMessageTool[]; - }; - }>; - console.log("choices", choices); - const delta = choices[0]?.delta?.content; - const tool_calls = choices[0]?.delta?.tool_calls; - const textmoderation = json?.prompt_filter_results; - - if (delta) { - remainText += delta; - } - if (tool_calls?.length > 0) { - const index = tool_calls[0]?.index; - const id = tool_calls[0]?.id; - const args = tool_calls[0]?.function?.arguments; - if (id) { - runTools.push({ - id, - type: tool_calls[0]?.type, - function: { - name: tool_calls[0]?.function?.name as string, - arguments: args, - }, - }); - } else { - // @ts-ignore - runTools[index]["function"]["arguments"] += args; - } - } - - console.log("runTools", runTools); - - if ( - textmoderation && - textmoderation.length > 0 && - ServiceProvider.Azure - ) { - const contentFilterResults = - textmoderation[0]?.content_filter_results; - console.log( - `[${ServiceProvider.Azure}] [Text Moderation] flagged categories result:`, - contentFilterResults, - ); - } - } catch (e) { - console.error("[Request] parse error", text, msg); - } - }, - onclose() { - finish(); - }, - onerror(e) { - options.onError?.(e); - throw e; - }, - openWhenHidden: true, - }); - } - chatApi(chatPath, requestPayload as RequestPayload); // call fetchEventSource - } else { const res = await fetch(chatPath, chatPayload); clearTimeout(requestTimeoutId); diff --git a/app/utils/chat.ts b/app/utils/chat.ts index 6a296e576..1289695b9 100644 --- a/app/utils/chat.ts +++ b/app/utils/chat.ts @@ -1,5 +1,15 @@ -import { CACHE_URL_PREFIX, UPLOAD_URL } from "@/app/constant"; +import { + CACHE_URL_PREFIX, + UPLOAD_URL, + REQUEST_TIMEOUT_MS, +} from "@/app/constant"; import { RequestMessage } from "@/app/client/api"; +import Locale from "@/app/locales"; +import { + EventStreamContentType, + fetchEventSource, +} from "@fortaine/fetch-event-source"; +import { prettyObject } from "./format"; export function compressImage(file: Blob, maxSize: number): Promise { return new Promise((resolve, reject) => { @@ -142,3 +152,198 @@ export function removeImage(imageUrl: string) { credentials: "include", }); } + +export function stream( + chatPath: string, + requestPayload: any, + headers: any, + tools: any[], + funcs: any, + controller: AbortController, + parseSSE: (text: string, runTools: any[]) => string | undefined, + options: any, +) { + let responseText = ""; + let remainText = ""; + let finished = false; + let running = false; + let runTools: any[] = []; + + // animate response to make it looks smooth + function animateResponseText() { + if (finished || controller.signal.aborted) { + responseText += remainText; + console.log("[Response Animation] finished"); + if (responseText?.length === 0) { + options.onError?.(new Error("empty response from server")); + } + return; + } + + if (remainText.length > 0) { + const fetchCount = Math.max(1, Math.round(remainText.length / 60)); + const fetchText = remainText.slice(0, fetchCount); + responseText += fetchText; + remainText = remainText.slice(fetchCount); + options.onUpdate?.(responseText, fetchText); + } + + requestAnimationFrame(animateResponseText); + } + + // start animaion + animateResponseText(); + + const finish = () => { + if (!finished) { + console.log("try run tools", runTools.length, finished, running); + if (!running && runTools.length > 0) { + const toolCallMessage = { + role: "assistant", + tool_calls: [...runTools], + }; + running = true; + runTools.splice(0, runTools.length); // empty runTools + return Promise.all( + toolCallMessage.tool_calls.map((tool) => { + options?.onBeforeTool?.(tool); + return Promise.resolve( + // @ts-ignore + funcs[tool.function.name]( + // @ts-ignore + JSON.parse(tool.function.arguments), + ), + ) + .then((content) => { + options?.onAfterTool?.({ + ...tool, + content, + isError: false, + }); + return content; + }) + .catch((e) => { + options?.onAfterTool?.({ ...tool, isError: true }); + return e.toString(); + }) + .then((content) => ({ + role: "tool", + content, + tool_call_id: tool.id, + })); + }), + ).then((toolCallResult) => { + console.log("end runTools", toolCallMessage, toolCallResult); + // @ts-ignore + requestPayload?.messages?.splice( + // @ts-ignore + requestPayload?.messages?.length, + 0, + toolCallMessage, + ...toolCallResult, + ); + setTimeout(() => { + // call again + console.log("start again"); + running = false; + chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource + }, 60); + }); + console.log("try run tools", runTools.length, finished); + return; + } + if (running) { + return; + } + finished = true; + options.onFinish(responseText + remainText); + } + }; + + controller.signal.onabort = finish; + + function chatApi( + chatPath: string, + headers: any, + requestPayload: any, + tools: any, + ) { + const chatPayload = { + method: "POST", + body: JSON.stringify({ + ...requestPayload, + tools, + }), + signal: controller.signal, + headers, + }; + const requestTimeoutId = setTimeout( + () => controller.abort(), + REQUEST_TIMEOUT_MS, + ); + fetchEventSource(chatPath, { + ...chatPayload, + async onopen(res) { + clearTimeout(requestTimeoutId); + const contentType = res.headers.get("content-type"); + console.log("[Request] response content type: ", contentType); + + if (contentType?.startsWith("text/plain")) { + responseText = await res.clone().text(); + return finish(); + } + + if ( + !res.ok || + !res.headers + .get("content-type") + ?.startsWith(EventStreamContentType) || + res.status !== 200 + ) { + const responseTexts = [responseText]; + let extraInfo = await res.clone().text(); + try { + const resJson = await res.clone().json(); + extraInfo = prettyObject(resJson); + } catch {} + + if (res.status === 401) { + responseTexts.push(Locale.Error.Unauthorized); + } + + if (extraInfo) { + responseTexts.push(extraInfo); + } + + responseText = responseTexts.join("\n\n"); + + return finish(); + } + }, + onmessage(msg) { + if (msg.data === "[DONE]" || finished) { + return finish(); + } + const text = msg.data; + try { + const chunk = parseSSE(msg.data, runTools); + if (chunk) { + remainText += chunk; + } + } catch (e) { + console.error("[Request] parse error", text, msg); + } + }, + onclose() { + finish(); + }, + onerror(e) { + options?.onError?.(e); + throw e; + }, + openWhenHidden: true, + }); + console.log("chatApi", chatPath, requestPayload, tools); + } + chatApi(chatPath, headers, requestPayload, tools); // call fetchEventSource +}