Merge pull request #4906 from ConnectAI-E/feature-gemini-streaming

gemini using real sse format response #3677 #3688
This commit is contained in:
Lloyd Zhou 2024-07-03 10:58:27 +08:00 committed by GitHub
commit 78e2b41e0c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 88 additions and 66 deletions

View File

@ -63,7 +63,9 @@ async function handle(
); );
} }
const fetchUrl = `${baseUrl}/${path}?key=${key}`; const fetchUrl = `${baseUrl}/${path}?key=${key}${
req?.nextUrl?.searchParams?.get("alt") == "sse" ? "&alt=sse" : ""
}`;
const fetchOptions: RequestInit = { const fetchOptions: RequestInit = {
headers: { headers: {
"Content-Type": "application/json", "Content-Type": "application/json",

View File

@ -3,6 +3,12 @@ import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api";
import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store";
import { getClientConfig } from "@/app/config/client"; import { getClientConfig } from "@/app/config/client";
import { DEFAULT_API_HOST } from "@/app/constant"; import { DEFAULT_API_HOST } from "@/app/constant";
import Locale from "../../locales";
import {
EventStreamContentType,
fetchEventSource,
} from "@fortaine/fetch-event-source";
import { prettyObject } from "@/app/utils/format";
import { import {
getMessageTextContent, getMessageTextContent,
getMessageImages, getMessageImages,
@ -20,7 +26,7 @@ export class GeminiProApi implements LLMApi {
); );
} }
async chat(options: ChatOptions): Promise<void> { async chat(options: ChatOptions): Promise<void> {
// const apiClient = this; const apiClient = this;
let multimodal = false; let multimodal = false;
const messages = options.messages.map((v) => { const messages = options.messages.map((v) => {
let parts: any[] = [{ text: getMessageTextContent(v) }]; let parts: any[] = [{ text: getMessageTextContent(v) }];
@ -120,7 +126,9 @@ export class GeminiProApi implements LLMApi {
if (!baseUrl) { if (!baseUrl) {
baseUrl = isApp baseUrl = isApp
? DEFAULT_API_HOST + "/api/proxy/google/" + Google.ChatPath(modelConfig.model) ? DEFAULT_API_HOST +
"/api/proxy/google/" +
Google.ChatPath(modelConfig.model)
: this.path(Google.ChatPath(modelConfig.model)); : this.path(Google.ChatPath(modelConfig.model));
} }
@ -139,16 +147,15 @@ export class GeminiProApi implements LLMApi {
() => controller.abort(), () => controller.abort(),
REQUEST_TIMEOUT_MS, REQUEST_TIMEOUT_MS,
); );
if (shouldStream) { if (shouldStream) {
let responseText = ""; let responseText = "";
let remainText = ""; let remainText = "";
let finished = false; let finished = false;
let existingTexts: string[] = [];
const finish = () => { const finish = () => {
finished = true; finished = true;
options.onFinish(existingTexts.join("")); options.onFinish(responseText + remainText);
}; };
// animate response to make it looks smooth // animate response to make it looks smooth
@ -173,72 +180,85 @@ export class GeminiProApi implements LLMApi {
// start animaion // start animaion
animateResponseText(); animateResponseText();
fetch( controller.signal.onabort = finish;
baseUrl.replace("generateContent", "streamGenerateContent"),
chatPayload,
)
.then((response) => {
const reader = response?.body?.getReader();
const decoder = new TextDecoder();
let partialData = "";
return reader?.read().then(function processText({ // https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb
done, const chatPath =
value, baseUrl.replace("generateContent", "streamGenerateContent") +
}): Promise<any> { (baseUrl.indexOf("?") > -1 ? "&alt=sse" : "?alt=sse");
if (done) { fetchEventSource(chatPath, {
if (response.status !== 200) { ...chatPayload,
try { async onopen(res) {
let data = JSON.parse(ensureProperEnding(partialData)); clearTimeout(requestTimeoutId);
if (data && data[0].error) { const contentType = res.headers.get("content-type");
options.onError?.(new Error(data[0].error.message)); console.log(
} else { "[Gemini] request response content type: ",
options.onError?.(new Error("Request failed")); contentType,
} );
} catch (_) {
options.onError?.(new Error("Request failed"));
}
}
console.log("Stream complete"); if (contentType?.startsWith("text/plain")) {
// options.onFinish(responseText + remainText); responseText = await res.clone().text();
finished = true; return finish();
return Promise.resolve(); }
}
partialData += decoder.decode(value, { stream: true });
if (
!res.ok ||
!res.headers
.get("content-type")
?.startsWith(EventStreamContentType) ||
res.status !== 200
) {
const responseTexts = [responseText];
let extraInfo = await res.clone().text();
try { try {
let data = JSON.parse(ensureProperEnding(partialData)); const resJson = await res.clone().json();
extraInfo = prettyObject(resJson);
} catch {}
const textArray = data.reduce( if (res.status === 401) {
(acc: string[], item: { candidates: any[] }) => { responseTexts.push(Locale.Error.Unauthorized);
const texts = item.candidates.map((candidate) =>
candidate.content.parts
.map((part: { text: any }) => part.text)
.join(""),
);
return acc.concat(texts);
},
[],
);
if (textArray.length > existingTexts.length) {
const deltaArray = textArray.slice(existingTexts.length);
existingTexts = textArray;
remainText += deltaArray.join("");
}
} catch (error) {
// console.log("[Response Animation] error: ", error,partialData);
// skip error message when parsing json
} }
return reader.read().then(processText); if (extraInfo) {
}); responseTexts.push(extraInfo);
}) }
.catch((error) => {
console.error("Error:", error); responseText = responseTexts.join("\n\n");
});
return finish();
}
},
onmessage(msg) {
if (msg.data === "[DONE]" || finished) {
return finish();
}
const text = msg.data;
try {
const json = JSON.parse(text);
const delta = apiClient.extractMessage(json);
if (delta) {
remainText += delta;
}
const blockReason = json?.promptFeedback?.blockReason;
if (blockReason) {
// being blocked
console.log(`[Google] [Safety Ratings] result:`, blockReason);
}
} catch (e) {
console.error("[Request] parse error", text, msg);
}
},
onclose() {
finish();
},
onerror(e) {
options.onError?.(e);
throw e;
},
openWhenHidden: true,
});
} else { } else {
const res = await fetch(baseUrl, chatPayload); const res = await fetch(baseUrl, chatPayload);
clearTimeout(requestTimeoutId); clearTimeout(requestTimeoutId);
@ -252,7 +272,7 @@ export class GeminiProApi implements LLMApi {
), ),
); );
} }
const message = this.extractMessage(resJson); const message = apiClient.extractMessage(resJson);
options.onFinish(message); options.onFinish(message);
} }
} catch (e) { } catch (e) {