From a18cb2c525150d78461c786561a9bdd02940feb5 Mon Sep 17 00:00:00 2001
From: Hk-Gosuto
Date: Sun, 24 Mar 2024 11:42:06 +0800
Subject: [PATCH] feat: #226

---
 app/api/common.ts                 | 11 +++++++++--
 app/client/platforms/openai.ts    |  4 ++--
 app/components/button.module.scss | 10 ++++++++++
 app/components/button.tsx         | 17 +++++++++++++++--
 app/components/chat.tsx           | 18 ++++++++++++++----
 app/components/stt-config.tsx     | 24 ++++++++++++++++++++++--
 app/constant.ts                   |  3 +++
 app/icons/three-dots-white.svg    | 14 ++++++++++++++
 app/locales/cn.ts                 |  4 ++++
 app/locales/en.ts                 |  4 ++++
 app/store/config.ts               | 11 +++++++++++
 app/utils/speech.ts               | 10 ++++------
 12 files changed, 112 insertions(+), 18 deletions(-)
 create mode 100644 app/icons/three-dots-white.svg

diff --git a/app/api/common.ts b/app/api/common.ts
index 190f87805..cd521c324 100644
--- a/app/api/common.ts
+++ b/app/api/common.ts
@@ -67,9 +67,16 @@ export async function requestOpenai(req: NextRequest) {
   let jsonBody;
   let clonedBody;
 
-  if (req.method !== "GET" && req.method !== "HEAD") {
+  const contentType = req.headers.get("Content-Type");
+  if (
+    req.method !== "GET" &&
+    req.method !== "HEAD" &&
+    contentType?.includes("json")
+  ) {
     clonedBody = await req.text();
     jsonBody = JSON.parse(clonedBody) as { model?: string };
+  } else {
+    clonedBody = req.body;
   }
   if (serverConfig.isAzure) {
     baseUrl = `${baseUrl}/${jsonBody?.model}`;
@@ -77,7 +84,7 @@
   const fetchUrl = `${baseUrl}/${path}`;
   const fetchOptions: RequestInit = {
     headers: {
-      "Content-Type": "application/json",
+      "Content-Type": contentType ?? "application/json",
       "Cache-Control": "no-store",
       [authHeaderName]: authValue,
       ...(serverConfig.openaiOrgId && {
diff --git a/app/client/platforms/openai.ts b/app/client/platforms/openai.ts
index 5cc7f07e7..6be3e74b1 100644
--- a/app/client/platforms/openai.ts
+++ b/app/client/platforms/openai.ts
@@ -143,11 +143,12 @@ export class ChatGPTApi implements LLMApi {
     try {
       const path = this.path(OpenaiPath.TranscriptionPath, options.model);
+      const headers = getHeaders(true);
       const payload = {
         method: "POST",
         body: formData,
         signal: controller.signal,
-        headers: getHeaders(true),
+        headers: headers,
       };
 
       // make a fetch request
@@ -155,7 +156,6 @@ export class ChatGPTApi implements LLMApi {
       const requestTimeoutId = setTimeout(
         () => controller.abort(),
         REQUEST_TIMEOUT_MS,
       );
-
       const res = await fetch(path, payload);
       clearTimeout(requestTimeoutId);
       const json = await res.json();
diff --git a/app/components/button.module.scss b/app/components/button.module.scss
index e332df2d2..1ce62c627 100644
--- a/app/components/button.module.scss
+++ b/app/components/button.module.scss
@@ -65,6 +65,16 @@
   align-items: center;
 }
 
+.icon-button-loading-icon {
+  width: 40px;
+  height: 16px;
+  display: flex;
+  align-items: center;
+  justify-content: center;
+  fill: white;
+  stroke: white;
+}
+
 @media only screen and (max-width: 600px) {
   .icon-button {
     padding: 16px;
diff --git a/app/components/button.tsx b/app/components/button.tsx
index 7a5633924..8316661fd 100644
--- a/app/components/button.tsx
+++ b/app/components/button.tsx
@@ -4,6 +4,8 @@ import styles from "./button.module.scss";
 
 export type ButtonType = "primary" | "danger" | null;
 
+import LoadingIcon from "../icons/three-dots-white.svg";
+
 export function IconButton(props: {
   onClick?: () => void;
   icon?: JSX.Element;
@@ -16,6 +18,7 @@
   disabled?: boolean;
   tabIndex?: number;
   autoFocus?: boolean;
+  loading?: boolean;
 }) {
   return (
   );
 }
diff --git a/app/components/chat.tsx b/app/components/chat.tsx
index 98fb06077..4ad66a8ce 100644
--- a/app/components/chat.tsx
+++ b/app/components/chat.tsx
@@ -91,6 +91,7 @@ import {
 import { useNavigate } from "react-router-dom";
 import {
   CHAT_PAGE_SIZE,
+  DEFAULT_STT_ENGINE,
   LAST_INPUT_KEY,
   ModelProvider,
   Path,
@@ -806,10 +807,10 @@ function _Chat() {
   };
 
   const [isListening, setIsListening] = useState(false);
+  const [isTranscription, setIsTranscription] = useState(false);
   const [speechApi, setSpeechApi] = useState(null);
 
   const startListening = async () => {
-    console.log(speechApi);
     if (speechApi) {
       await speechApi.start();
       setIsListening(true);
@@ -818,6 +819,8 @@ function _Chat() {
 
   const stopListening = async () => {
     if (speechApi) {
+      if (config.sttConfig.engine !== DEFAULT_STT_ENGINE)
+        setIsTranscription(true);
       await speechApi.stop();
       setIsListening(false);
     }
@@ -826,6 +829,8 @@ function _Chat() {
   const onRecognitionEnd = (finalTranscript: string) => {
     console.log(finalTranscript);
     if (finalTranscript) setUserInput(finalTranscript);
+    if (config.sttConfig.engine !== DEFAULT_STT_ENGINE)
+      setIsTranscription(false);
   };
 
   const doSubmit = (userInput: string) => {
@@ -899,9 +904,13 @@ function _Chat() {
     });
     // eslint-disable-next-line react-hooks/exhaustive-deps
     setSpeechApi(
-      new OpenAITranscriptionApi((transcription) =>
-        onRecognitionEnd(transcription),
-      ),
+      config.sttConfig.engine === DEFAULT_STT_ENGINE
+        ? new WebTranscriptionApi((transcription) =>
+            onRecognitionEnd(transcription),
+          )
+        : new OpenAITranscriptionApi((transcription) =>
+            onRecognitionEnd(transcription),
+          ),
     );
   }, []);
 
@@ -1695,6 +1704,7 @@ function _Chat() {
                     onClick={async () =>
                       isListening ? await stopListening() : await startListening()
                     }
+                    loading={isTranscription}
                   />
                 ) : (
diff --git a/app/components/stt-config.tsx b/app/components/stt-config.tsx
   );
 }
diff --git a/app/constant.ts b/app/constant.ts
index d50cfff71..4712a1019 100644
--- a/app/constant.ts
+++ b/app/constant.ts
@@ -134,6 +134,9 @@ export const DEFAULT_TTS_VOICES = [
   "shimmer",
 ];
 
+export const DEFAULT_STT_ENGINE = "WebAPI";
+export const DEFAULT_STT_ENGINES = ["WebAPI", "OpenAI Whisper"];
+
 export const DEFAULT_MODELS = [
   {
     name: "gpt-4",
diff --git a/app/icons/three-dots-white.svg b/app/icons/three-dots-white.svg
new file mode 100644
index 000000000..cf5dfe7f7
--- /dev/null
+++ b/app/icons/three-dots-white.svg
@@ -0,0 +1,14 @@
diff --git a/app/locales/cn.ts b/app/locales/cn.ts
index 5661f943a..0595fa30d 100644
--- a/app/locales/cn.ts
+++ b/app/locales/cn.ts
@@ -402,6 +402,10 @@ const cn = {
         Title: "启用语音转文本",
         SubTitle: "启用语音转文本",
       },
+      Engine: {
+        Title: "转换引擎",
+        SubTitle: "音频转换引擎",
+      },
     },
   },
   Store: {
diff --git a/app/locales/en.ts b/app/locales/en.ts
index 1c4350ee8..5c562b5a4 100644
--- a/app/locales/en.ts
+++ b/app/locales/en.ts
@@ -408,6 +408,10 @@ const en: LocaleType = {
         Title: "Enable STT",
         SubTitle: "Enable Speech-to-Text",
       },
+      Engine: {
+        Title: "STT Engine",
+        SubTitle: "Speech-to-Text Engine",
+      },
     },
   },
   Store: {
diff --git a/app/store/config.ts b/app/store/config.ts
index df0cfe9db..1d9c56835 100644
--- a/app/store/config.ts
+++ b/app/store/config.ts
@@ -5,6 +5,8 @@ import {
   DEFAULT_INPUT_TEMPLATE,
   DEFAULT_MODELS,
   DEFAULT_SIDEBAR_WIDTH,
+  DEFAULT_STT_ENGINE,
+  DEFAULT_STT_ENGINES,
   DEFAULT_TTS_MODEL,
   DEFAULT_TTS_MODELS,
   DEFAULT_TTS_VOICE,
@@ -17,6 +19,8 @@ export type ModelType = (typeof DEFAULT_MODELS)[number]["name"];
 export type TTSModelType = (typeof DEFAULT_TTS_MODELS)[number];
 export type TTSVoiceType = (typeof DEFAULT_TTS_VOICES)[number];
 
+export type STTEngineType = (typeof DEFAULT_STT_ENGINES)[number];
+
 export enum SubmitKey {
   Enter = "Enter",
   CtrlEnter = "Ctrl + Enter",
@@ -81,6 +85,7 @@ export const DEFAULT_CONFIG = {
 
   sttConfig: {
     enable: false,
+    engine: DEFAULT_STT_ENGINE,
   },
 };
 
@@ -116,6 +121,12 @@ export const TTSConfigValidator = {
   },
 };
 
+export const STTConfigValidator = {
+  engine(x: string) {
+    return x as STTEngineType;
+  },
+};
+
 export const ModalConfigValidator = {
   model(x: string) {
     return x as ModelType;
diff --git a/app/utils/speech.ts b/app/utils/speech.ts
index 0c74fc0d9..e993859b5 100644
--- a/app/utils/speech.ts
+++ b/app/utils/speech.ts
@@ -31,7 +31,7 @@ export class OpenAITranscriptionApi extends SpeechApi {
   }
 
   async start(): Promise<void> {
-    // @ts-ignore
+    // @ts-ignore prettier-ignore
     navigator.getUserMedia =
       navigator.getUserMedia ||
       navigator.webkitGetUserMedia ||
@@ -103,20 +103,18 @@ export class WebTranscriptionApi extends SpeechApi {
     this.recognitionInstance.onresult = (event: any) => {
       const result = event.results[event.results.length - 1];
       if (result.isFinal) {
-        if (!this.isListening) {
-          this.onTranscriptionReceived(result[0].transcript);
-        }
+        this.onTranscription(result[0].transcript);
       }
     };
   }
 
   async start(): Promise<void> {
-    await this.recognitionInstance.start();
     this.listeningStatus = true;
+    await this.recognitionInstance.start();
   }
 
   async stop(): Promise<void> {
-    await this.recognitionInstance.stop();
     this.listeningStatus = false;
+    await this.recognitionInstance.stop();
   }
 }