using stream_fetch in App

This commit is contained in:
lloydzhou 2024-09-29 19:44:09 +08:00
parent 2d920f7ccc
commit 3898c507c4
5 changed files with 156 additions and 137 deletions

View File

@ -288,12 +288,16 @@ export function showPlugins(provider: ServiceProvider, model: string) {
} }
export function adapter(config: Record<string, unknown>) { export function adapter(config: Record<string, unknown>) {
const { baseURL, url, params, ...rest } = config; const { baseURL, url, params, method, data, ...rest } = config;
const path = baseURL ? `${baseURL}${url}` : url; const path = baseURL ? `${baseURL}${url}` : url;
const fetchUrl = params const fetchUrl = params
? `${path}?${new URLSearchParams(params as any).toString()}` ? `${path}?${new URLSearchParams(params as any).toString()}`
: path; : path;
return fetch(fetchUrl as string, rest) return fetch(fetchUrl as string, {
...rest,
method,
body: method.toUpperCase() == "GET" ? undefined : data,
})
.then((res) => res.text()) .then((res) => res.text())
.then((data) => ({ data })); .then((data) => ({ data }));
} }

View File

@ -10,6 +10,7 @@ import {
fetchEventSource, fetchEventSource,
} from "@fortaine/fetch-event-source"; } from "@fortaine/fetch-event-source";
import { prettyObject } from "./format"; import { prettyObject } from "./format";
import { fetch as tauriFetch } from "./stream";
export function compressImage(file: Blob, maxSize: number): Promise<string> { export function compressImage(file: Blob, maxSize: number): Promise<string> {
return new Promise((resolve, reject) => { return new Promise((resolve, reject) => {
@ -287,6 +288,7 @@ export function stream(
REQUEST_TIMEOUT_MS, REQUEST_TIMEOUT_MS,
); );
fetchEventSource(chatPath, { fetchEventSource(chatPath, {
fetch: tauriFetch,
...chatPayload, ...chatPayload,
async onopen(res) { async onopen(res) {
clearTimeout(requestTimeoutId); clearTimeout(requestTimeoutId);

View File

@ -1,100 +1,94 @@
// using tauri register_uri_scheme_protocol, register `stream:` protocol // using tauri command to send request
// see src-tauri/src/stream.rs, and src-tauri/src/main.rs // see src-tauri/src/stream.rs, and src-tauri/src/main.rs
// 1. window.fetch(`stream://localhost/${fetchUrl}`), get request_id // 1. invoke('stream_fetch', {url, method, headers, body}), get response with headers.
// 2. listen event: `stream-response` multi times to get response headers and body // 2. listen event: `stream-response` multi times to get body
type ResponseEvent = { type ResponseEvent = {
id: number; id: number;
payload: { payload: {
request_id: number; request_id: number;
status?: number; status?: number;
error?: string;
name?: string;
value?: string;
chunk?: number[]; chunk?: number[];
}; };
}; };
export function fetch(url: string, options?: RequestInit): Promise<any> { export function fetch(url: string, options?: RequestInit): Promise<any> {
if (window.__TAURI__) { if (window.__TAURI__) {
const tauriUri = window.__TAURI__.convertFileSrc(url, "stream"); const { signal, method = "GET", headers = {}, body = [] } = options || {};
const { signal, ...rest } = options || {}; return window.__TAURI__
return window .invoke("stream_fetch", {
.fetch(tauriUri, rest) method,
.then((r) => r.text()) url,
.then((rid) => parseInt(rid)) headers,
.then((request_id: number) => { // TODO FormData
// 1. using event to get status and statusText and headers, and resolve it body:
let resolve: Function | undefined; typeof body === "string"
let reject: Function | undefined; ? Array.from(new TextEncoder().encode(body))
let status: number; : [],
let writable: WritableStream | undefined; })
let writer: WritableStreamDefaultWriter | undefined; .then(
const headers = new Headers(); (res: {
request_id: number;
status: number;
status_text: string;
headers: Record<string, string>;
}) => {
const { request_id, status, status_text: statusText, headers } = res;
console.log("send request_id", request_id, status, statusText);
let unlisten: Function | undefined; let unlisten: Function | undefined;
const ts = new TransformStream();
const writer = ts.writable.getWriter();
if (signal) { const close = () => {
signal.addEventListener("abort", () => {
// Reject the promise with the abort reason.
unlisten && unlisten(); unlisten && unlisten();
reject && reject(signal.reason); writer.ready.then(() => {
try {
writer.releaseLock();
} catch (e) {
console.error(e);
}
ts.writable.close();
}); });
};
const response = new Response(ts.readable, {
status,
statusText,
headers,
});
if (signal) {
signal.addEventListener("abort", () => close());
} }
// @ts-ignore 2. listen response multi times, and write to Response.body // @ts-ignore 2. listen response multi times, and write to Response.body
window.__TAURI__.event window.__TAURI__.event
.listen("stream-response", (e: ResponseEvent) => { .listen("stream-response", (e: ResponseEvent) => {
const { id, payload } = e; const { id, payload } = e;
const { const { request_id: rid, chunk, status } = payload;
request_id: rid,
status: _status,
name,
value,
error,
chunk,
} = payload;
if (request_id != rid) { if (request_id != rid) {
return; return;
} }
/** if (chunk) {
* 1. get status code
* 2. get headers
* 3. start get body, then resolve response
* 4. get body chunk
*/
if (error) {
unlisten && unlisten();
return reject && reject(error);
} else if (_status) {
status = _status;
} else if (name && value) {
headers.append(name, value);
} else if (chunk) {
if (resolve) {
const ts = new TransformStream();
writable = ts.writable;
writer = writable.getWriter();
resolve(new Response(ts.readable, { status, headers }));
resolve = undefined;
}
writer && writer &&
writer.ready.then(() => { writer.ready.then(() => {
writer && writer.write(new Uint8Array(chunk)); writer && writer.write(new Uint8Array(chunk));
}); });
} else if (_status === 0) { } else if (status === 0) {
// end of body // end of body
unlisten && unlisten(); close();
writer &&
writer.ready.then(() => {
writer && writer.releaseLock();
writable && writable.close();
});
} }
}) })
.then((u: Function) => (unlisten = u)); .then((u: Function) => (unlisten = u));
return new Promise( return response;
(_resolve, _reject) => ([resolve, reject] = [_resolve, _reject]), },
); )
.catch((e) => {
console.error("stream error", e);
throw e;
}); });
} }
return window.fetch(url, options); return window.fetch(url, options);
} }
if (undefined !== window) {
window.tauriFetch = fetch;
}

View File

@ -5,10 +5,8 @@ mod stream;
fn main() { fn main() {
tauri::Builder::default() tauri::Builder::default()
.invoke_handler(tauri::generate_handler![stream::stream_fetch])
.plugin(tauri_plugin_window_state::Builder::default().build()) .plugin(tauri_plugin_window_state::Builder::default().build())
.register_uri_scheme_protocol("stream", move |app_handle, request| {
stream::stream(app_handle, request)
})
.run(tauri::generate_context!()) .run(tauri::generate_context!())
.expect("error while running tauri application"); .expect("error while running tauri application");
} }

View File

@ -1,30 +1,25 @@
//
//
use std::error::Error; use std::error::Error;
use futures_util::{StreamExt}; use futures_util::{StreamExt};
use reqwest::Client; use reqwest::Client;
use tauri::{ Manager, AppHandle }; use reqwest::header::{HeaderName, HeaderMap};
use tauri::http::{Request, ResponseBuilder};
use tauri::http::Response;
static mut REQUEST_COUNTER: u32 = 0; static mut REQUEST_COUNTER: u32 = 0;
#[derive(Clone, serde::Serialize)] #[derive(Clone, serde::Serialize)]
pub struct ErrorPayload { pub struct StreamResponse {
request_id: u32,
error: String,
}
#[derive(Clone, serde::Serialize)]
pub struct StatusPayload {
request_id: u32, request_id: u32,
status: u16, status: u16,
status_text: String,
headers: HashMap<String, String>
} }
#[derive(Clone, serde::Serialize)] #[derive(Clone, serde::Serialize)]
pub struct HeaderPayload { pub struct EndPayload {
request_id: u32, request_id: u32,
name: String, status: u16,
value: String,
} }
#[derive(Clone, serde::Serialize)] #[derive(Clone, serde::Serialize)]
@ -33,64 +28,90 @@ pub struct ChunkPayload {
chunk: bytes::Bytes, chunk: bytes::Bytes,
} }
pub fn stream(app_handle: &AppHandle, request: &Request) -> Result<Response, Box<dyn Error>> { use std::collections::HashMap;
#[derive(serde::Serialize)]
pub struct CustomResponse {
message: String,
other_val: usize,
}
#[tauri::command]
pub async fn stream_fetch(
window: tauri::Window,
method: String,
url: String,
headers: HashMap<String, String>,
body: Vec<u8>,
) -> Result<StreamResponse, String> {
let mut request_id = 0; let mut request_id = 0;
let event_name = "stream-response"; let event_name = "stream-response";
unsafe { unsafe {
REQUEST_COUNTER += 1; REQUEST_COUNTER += 1;
request_id = REQUEST_COUNTER; request_id = REQUEST_COUNTER;
} }
let path = request.uri().to_string().replace("stream://localhost/", "").replace("http://stream.localhost/", "");
let path = percent_encoding::percent_decode(path.as_bytes())
.decode_utf8_lossy()
.to_string();
// println!("path : {}", path);
let client = Client::new();
let handle = app_handle.app_handle();
// send http request
let body = reqwest::Body::from(request.body().clone());
let response_future = client.request(request.method().clone(), path)
.headers(request.headers().clone())
.body(body).send();
// get response and emit to client let mut _headers = HeaderMap::new();
tauri::async_runtime::spawn(async move { for (key, value) in headers {
let res = response_future.await; _headers.insert(key.parse::<HeaderName>().unwrap(), value.parse().unwrap());
match res {
Ok(res) => {
handle.emit_all(event_name, StatusPayload{ request_id, status: res.status().as_u16() }).unwrap();
for (name, value) in res.headers() {
handle.emit_all(event_name, HeaderPayload {
request_id,
name: name.to_string(),
value: std::str::from_utf8(value.as_bytes()).unwrap().to_string()
}).unwrap();
} }
let body = bytes::Bytes::from(body);
let response_future = Client::new().request(
method.parse::<reqwest::Method>().map_err(|err| format!("failed to parse method: {}", err))?,
url.parse::<reqwest::Url>().map_err(|err| format!("failed to parse url: {}", err))?
).headers(_headers).body(body).send();
let res = response_future.await;
let response = match res {
Ok(res) => {
println!("Error: {:?}", res);
// get response and emit to client
// .register_uri_scheme_protocol("stream", move |app_handle, request| {
let mut headers = HashMap::new();
for (name, value) in res.headers() {
headers.insert(
name.as_str().to_string(),
std::str::from_utf8(value.as_bytes()).unwrap().to_string()
);
}
let status = res.status().as_u16();
tauri::async_runtime::spawn(async move {
let mut stream = res.bytes_stream(); let mut stream = res.bytes_stream();
while let Some(chunk) = stream.next().await { while let Some(chunk) = stream.next().await {
match chunk { match chunk {
Ok(bytes) => { Ok(bytes) => {
handle.emit_all(event_name, ChunkPayload{ request_id, chunk: bytes }).unwrap(); println!("chunk: {:?}", bytes);
window.emit(event_name, ChunkPayload{ request_id, chunk: bytes }).unwrap();
} }
Err(err) => { Err(err) => {
println!("Error: {:?}", err); println!("Error: {:?}", err);
} }
} }
} }
handle.emit_all(event_name, StatusPayload { request_id, status: 0 }).unwrap(); window.emit(event_name, EndPayload { request_id, status: 0 }).unwrap();
});
StreamResponse {
request_id,
status,
status_text: "OK".to_string(),
headers,
}
} }
Err(err) => { Err(err) => {
println!("Error: {:?}", err.source().expect("REASON").to_string()); println!("Error: {:?}", err.source().expect("REASON").to_string());
handle.emit_all(event_name, ErrorPayload { StreamResponse {
request_id, request_id,
error: err.source().expect("REASON").to_string() status: 599,
}).unwrap(); status_text: err.source().expect("REASON").to_string(),
headers: HashMap::new(),
} }
} }
}); };
return ResponseBuilder::new() Ok(response)
.header("Access-Control-Allow-Origin", "*")
.status(200).body(request_id.to_string().into())
} }