// ChatGPT-Next-Web/app/components/TensorFlow.tsx
import React, { useState, useEffect, useRef } from "react";
import * as tf from "@tensorflow/tfjs";
import { VoicePrint } from "./voice-print/voice-print";
import styles from "./TensorFlow.module.scss";
// Voice recognition status
enum VoiceRecognitionStatus {
  IDLE = "Idle",
  RECORDING = "Recording",
  TRAINING = "Training",
  RECOGNIZING = "Recognizing",
  TRAINED = "Trained",
  MATCHED = "Matched",
  NOT_MATCHED = "Not matched",
  ERROR = "Error",
}
// Voiceprint feature extraction parameters
const SAMPLE_RATE = 16000; // Sample rate (Hz)
const FFT_SIZE = 1024; // FFT size
const MEL_BINS = 40; // Number of mel filter banks
const FRAME_LENGTH = 25; // Frame length (ms)
const FRAME_STEP = 10; // Frame step (ms)
const FEATURE_LENGTH = 100; // Length of the feature sequence (frames)
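// For reference, the mel scale that MEL_BINS alludes to uses the standard
// HTK-style conversion. These helpers are illustrative only and are not
// wired into the simplified pipeline below.
function hzToMel(hz: number): number {
  return 2595 * Math.log10(1 + hz / 700);
}
function melToHz(mel: number): number {
  return 700 * (Math.pow(10, mel / 2595) - 1);
}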
const TensorFlow: React.FC = () => {
  // State
  const [status, setStatus] = useState<VoiceRecognitionStatus>(
    VoiceRecognitionStatus.IDLE,
  );
  const [message, setMessage] = useState<string>("");
  const [isRecording, setIsRecording] = useState<boolean>(false);
  const [isTrained, setIsTrained] = useState<boolean>(false);
  const [matchScore, setMatchScore] = useState<number>(0);
  const [frequencies, setFrequencies] = useState<Uint8Array | undefined>(
    undefined,
  );
  // Refs
  const audioContextRef = useRef<AudioContext | null>(null);
  const analyserRef = useRef<AnalyserNode | null>(null);
  const mediaStreamRef = useRef<MediaStream | null>(null);
  const recordedChunksRef = useRef<Float32Array[]>([]);
  const modelRef = useRef<tf.LayersModel | null>(null);
  const voiceprintRef = useRef<Float32Array | null>(null);
  const animationFrameRef = useRef<number | null>(null);
  // Initialization
  useEffect(() => {
    // Check for a previously saved voiceprint
    const savedVoiceprint = localStorage.getItem("userVoiceprint");
    if (savedVoiceprint) {
      try {
        voiceprintRef.current = new Float32Array(JSON.parse(savedVoiceprint));
        setIsTrained(true);
        setStatus(VoiceRecognitionStatus.TRAINED);
        setMessage("Loaded saved voiceprint");
      } catch (error) {
        console.error("Failed to load saved voiceprint:", error);
      }
    }
    // Load the TensorFlow model
    loadModel();
    return () => {
      stopRecording();
      if (animationFrameRef.current) {
        cancelAnimationFrame(animationFrameRef.current);
      }
    };
  }, []);
  // Build the voiceprint recognition model
  const loadModel = async () => {
    try {
      // A small convolutional network over the audio features
      const model = tf.sequential();
      // Convolutional layers over the feature sequence
      model.add(
        tf.layers.conv1d({
          inputShape: [FEATURE_LENGTH, MEL_BINS],
          filters: 32,
          kernelSize: 3,
          activation: "relu",
        }),
      );
      model.add(tf.layers.maxPooling1d({ poolSize: 2 }));
      model.add(
        tf.layers.conv1d({
          filters: 64,
          kernelSize: 3,
          activation: "relu",
        }),
      );
      model.add(tf.layers.maxPooling1d({ poolSize: 2 }));
      model.add(tf.layers.flatten());
      // Fully connected layers
      model.add(tf.layers.dense({ units: 128, activation: "relu" }));
      model.add(tf.layers.dropout({ rate: 0.5 }));
      // Output layer: the voiceprint embedding vector
      model.add(tf.layers.dense({ units: 64, activation: "linear" }));
      // Compile the model (note: this component never calls fit(), so the
      // network keeps its random initial weights and acts as a fixed
      // feature extractor)
      model.compile({
        optimizer: "adam",
        loss: "meanSquaredError",
      });
      modelRef.current = model;
      console.log("Voiceprint model loaded");
    } catch (error) {
      console.error("Failed to load model:", error);
      setStatus(VoiceRecognitionStatus.ERROR);
      setMessage("Failed to load model");
    }
  };
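  // Shape walk-through for the model above (TF.js defaults: "valid" padding,
  // pool stride = poolSize), assuming FEATURE_LENGTH = 100 and MEL_BINS = 40:
  //   [100, 40] --conv1d(k=3, f=32)--> [98, 32] --maxPool(2)--> [49, 32]
  //             --conv1d(k=3, f=64)--> [47, 64] --maxPool(2)--> [23, 64]
  //             --flatten--> 1472 --dense--> 128 --dense--> 64 (embedding)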
  // Start recording
  const startRecording = async (isTraining: boolean = false) => {
    try {
      if (isRecording) return;
      // Reset the recorded audio
      recordedChunksRef.current = [];
      // Request microphone access
      const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
      mediaStreamRef.current = stream;
      // Create the audio context
      const audioContext = new (window.AudioContext ||
        (window as any).webkitAudioContext)();
      audioContextRef.current = audioContext;
      // Analyser node for the spectrum visualization
      const analyser = audioContext.createAnalyser();
      analyser.fftSize = FFT_SIZE;
      analyserRef.current = analyser;
      // Audio source from the microphone stream
      const source = audioContext.createMediaStreamSource(stream);
      source.connect(analyser);
      // Processor node to capture raw samples (createScriptProcessor is
      // deprecated in favor of AudioWorklet, but it keeps this demo simple)
      const processor = audioContext.createScriptProcessor(4096, 1, 1);
      // Collect audio data
      processor.onaudioprocess = (e) => {
        const inputData = e.inputBuffer.getChannelData(0);
        recordedChunksRef.current.push(new Float32Array(inputData));
      };
      // Wire up the audio graph
      analyser.connect(processor);
      processor.connect(audioContext.destination);
      // Update state
      setIsRecording(true);
      setStatus(
        isTraining
          ? VoiceRecognitionStatus.RECORDING
          : VoiceRecognitionStatus.RECOGNIZING,
      );
      setMessage(
        isTraining
          ? "Please speak for 3-5 seconds for training..."
          : "Please speak for voiceprint recognition...",
      );
      // Start the spectrum visualization
      startVisualization();
      // In training mode, stop automatically after 5 seconds
      if (isTraining) {
        setTimeout(() => {
          stopRecording();
          trainVoiceprint();
        }, 5000);
      }
    } catch (error) {
      console.error("Failed to start recording:", error);
      setStatus(VoiceRecognitionStatus.ERROR);
      setMessage("Cannot access the microphone; please check permissions");
    }
  };
  // Stop recording
  const stopRecording = () => {
    // Guard on the refs rather than the `isRecording` state: the setTimeout
    // in startRecording holds a stale closure in which `isRecording` is still
    // false, so a state-based guard would return early and never stop.
    if (!mediaStreamRef.current && !audioContextRef.current) return;
    // Stop all audio tracks
    if (mediaStreamRef.current) {
      mediaStreamRef.current.getTracks().forEach((track) => track.stop());
      mediaStreamRef.current = null;
    }
    // Close the audio context
    if (audioContextRef.current) {
      audioContextRef.current.close();
      audioContextRef.current = null;
    }
    // Stop the visualization
    if (animationFrameRef.current) {
      cancelAnimationFrame(animationFrameRef.current);
      animationFrameRef.current = null;
    }
    setIsRecording(false);
    setFrequencies(undefined);
  };
  // Start the spectrum visualization
  const startVisualization = () => {
    const analyser = analyserRef.current;
    if (!analyser) return;
    const bufferLength = analyser.frequencyBinCount; // FFT_SIZE / 2 bins
    const dataArray = new Uint8Array(bufferLength);
    const updateVisualization = () => {
      if (!analyser) return;
      analyser.getByteFrequencyData(dataArray);
      // Copy the buffer: React bails out of re-renders when the state value
      // is the same reference, and getByteFrequencyData mutates in place
      setFrequencies(new Uint8Array(dataArray));
      animationFrameRef.current = requestAnimationFrame(updateVisualization);
    };
    updateVisualization();
  };
  // Extract audio features
  const extractFeatures = async (
    audioData: Float32Array[],
  ): Promise<tf.Tensor | null> => {
    try {
      // Merge all recorded chunks into one buffer
      const mergedData = new Float32Array(
        audioData.reduce((acc, chunk) => acc + chunk.length, 0),
      );
      let offset = 0;
      for (const chunk of audioData) {
        mergedData.set(chunk, offset);
        offset += chunk.length;
      }
      // Compute a (heavily simplified) mel spectrogram. A real application
      // would use proper signal processing here, e.g. MFCCs
      // (mel-frequency cepstral coefficients).
      const frameLength = Math.round((SAMPLE_RATE * FRAME_LENGTH) / 1000);
      const frameStep = Math.round((SAMPLE_RATE * FRAME_STEP) / 1000);
      // Split the signal into overlapping frames (a stand-in for a proper
      // short-time Fourier transform; use a dedicated DSP library in practice)
      const frames: number[][] = [];
      for (let i = 0; i + frameLength <= mergedData.length; i += frameStep) {
        const frame = mergedData.slice(i, i + frameLength);
        frames.push(Array.from(frame));
      }
      // Cap the number of frames
      const limitedFrames = frames.slice(0, FEATURE_LENGTH);
      // Zero-pad if there are too few frames
      while (limitedFrames.length < FEATURE_LENGTH) {
        limitedFrames.push(new Array(frameLength).fill(0));
      }
      // Build the feature tensor
      const featureTensor = tf.tensor(limitedFrames);
      // Simplified mel-spectrogram computation; a real implementation would
      // apply an FFT followed by a mel filter bank
      const melSpectrogram = tf.tidy(() => {
        // Magnitude of the raw samples (note: tf.abs is not an FFT; it only
        // stands in for the spectral magnitude here)
        const fftMag = featureTensor.abs();
        // Reduce to MEL_BINS columns
        const reshaped = fftMag.reshape([FEATURE_LENGTH, -1]);
        const melFeatures = reshaped.slice([0, 0], [FEATURE_LENGTH, MEL_BINS]);
        // Normalize
        const normalized = melFeatures.div(tf.scalar(255.0));
        return normalized.expandDims(0); // Add the batch dimension
      });
      return melSpectrogram;
    } catch (error) {
      console.error("Feature extraction failed:", error);
      return null;
    }
  };
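  // Illustrative sketch only (not wired into extractFeatures above): a real
  // STFT pipeline would apply a window function to each frame before the FFT
  // to reduce spectral leakage. A Hamming window, for reference:
  const applyHammingWindow = (frame: Float32Array): Float32Array => {
    const windowed = new Float32Array(frame.length);
    const n = frame.length;
    for (let i = 0; i < n; i++) {
      windowed[i] =
        frame[i] * (0.54 - 0.46 * Math.cos((2 * Math.PI * i) / (n - 1)));
    }
    return windowed;
  };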
  // "Train" the voiceprint model. Note that no gradient updates happen here:
  // the randomly initialized network serves as a fixed feature extractor,
  // and training simply records the user's embedding.
  const trainVoiceprint = async () => {
    if (recordedChunksRef.current.length === 0 || !modelRef.current) {
      setStatus(VoiceRecognitionStatus.ERROR);
      setMessage("No recorded audio, or the model is not loaded");
      return;
    }
    setStatus(VoiceRecognitionStatus.TRAINING);
    setMessage("Training the voiceprint model...");
    try {
      // Extract features
      const features = await extractFeatures(recordedChunksRef.current);
      if (!features) throw new Error("Feature extraction failed");
      // Run the model to get the voiceprint embedding
      const voiceprint = tf.tidy(() => {
        // Forward pass
        const prediction = modelRef.current!.predict(features) as tf.Tensor;
        // L2-normalize the embedding
        return tf.div(prediction, tf.norm(prediction));
      });
      // Save the voiceprint
      const voiceprintData = await voiceprint.data();
      voiceprintRef.current = new Float32Array(voiceprintData);
      // Persist to localStorage
      localStorage.setItem(
        "userVoiceprint",
        JSON.stringify(Array.from(voiceprintData)),
      );
      setIsTrained(true);
      setStatus(VoiceRecognitionStatus.TRAINED);
      setMessage("Voiceprint trained and saved");
      // Clean up
      voiceprint.dispose();
      features.dispose();
    } catch (error) {
      console.error("Training failed:", error);
      setStatus(VoiceRecognitionStatus.ERROR);
      setMessage("Voiceprint training failed");
    }
  };
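  // Hypothetical enrollment helper, shown for illustration and not called
  // anywhere in this component: averaging embeddings from several recordings
  // usually yields a more stable voiceprint than a single 5-second sample.
  const averageVoiceprints = (embeddings: Float32Array[]): Float32Array => {
    const dim = embeddings[0].length;
    const avg = new Float32Array(dim);
    for (const e of embeddings) {
      for (let i = 0; i < dim; i++) avg[i] += e[i] / embeddings.length;
    }
    // Re-normalize to unit length so cosine similarity stays a plain dot product
    const norm = Math.sqrt(avg.reduce((s, v) => s + v * v, 0));
    for (let i = 0; i < dim; i++) avg[i] /= norm;
    return avg;
  };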
  // Recognize the voiceprint
  const recognizeVoice = async () => {
    if (!isTrained || !voiceprintRef.current) {
      setStatus(VoiceRecognitionStatus.ERROR);
      setMessage("Please train a voiceprint first");
      return;
    }
    if (recordedChunksRef.current.length === 0 || !modelRef.current) {
      setStatus(VoiceRecognitionStatus.ERROR);
      setMessage("No recorded audio, or the model is not loaded");
      return;
    }
    try {
      // Extract features
      const features = await extractFeatures(recordedChunksRef.current);
      if (!features) throw new Error("Feature extraction failed");
      // Run the model to get the current voiceprint embedding
      const currentVoiceprint = tf.tidy(() => {
        // Forward pass
        const prediction = modelRef.current!.predict(features) as tf.Tensor;
        // L2-normalize the embedding
        return tf.div(prediction, tf.norm(prediction));
      });
      // Cosine similarity against the saved voiceprint: both vectors are
      // unit-length, so the dot product alone is the cosine
      const similarity = tf.tidy(() => {
        const savedVoiceprint = tf.tensor1d(voiceprintRef.current!);
        const dotProduct = tf.sum(
          tf.mul(currentVoiceprint.reshape([-1]), savedVoiceprint),
        );
        return dotProduct;
      });
      // Similarity score in [-1, 1]; closer to 1 means more similar
      const similarityScore = await similarity.data();
      const score = similarityScore[0];
      setMatchScore(score);
      // Same-speaker decision (the threshold is tunable)
      const threshold = 0.7;
      const isMatch = score > threshold;
      setStatus(
        isMatch
          ? VoiceRecognitionStatus.MATCHED
          : VoiceRecognitionStatus.NOT_MATCHED,
      );
      setMessage(
        isMatch
          ? `Voiceprint matched! Similarity: ${(score * 100).toFixed(2)}%`
          : `Voiceprint did not match. Similarity: ${(score * 100).toFixed(2)}%`,
      );
      // Clean up
      currentVoiceprint.dispose();
      features.dispose();
      similarity.dispose();
    } catch (error) {
      console.error("Recognition failed:", error);
      setStatus(VoiceRecognitionStatus.ERROR);
      setMessage("Voiceprint recognition failed");
    }
  };
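  // Plain-TypeScript reference for the tensor math above. Because both
  // voiceprints are L2-normalized, their dot product equals the cosine of
  // the angle between them; the general form is shown here for clarity.
  // This helper is illustrative and unused by the component.
  const cosineSimilarity = (a: Float32Array, b: Float32Array): number => {
    let dot = 0;
    let normA = 0;
    let normB = 0;
    for (let i = 0; i < a.length; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB));
  };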
  // Clear the trained voiceprint
  const clearTrainedData = () => {
    localStorage.removeItem("userVoiceprint");
    voiceprintRef.current = null;
    setIsTrained(false);
    setStatus(VoiceRecognitionStatus.IDLE);
    setMessage("Voiceprint data cleared");
  };
  return (
    <div className={styles.voiceRecognitionContainer}>
      <h2 className={styles.title}>Voiceprint Recognition</h2>
      <div className={styles.statusContainer}>
        <div className={styles.statusIndicator}>
          {/* Assumes the SCSS module defines a lowercase class for each
              status value */}
          <div
            className={`${styles.statusDot} ${styles[status.toLowerCase()]}`}
          ></div>
          <span className={styles.statusText}>{status}</span>
        </div>
        <p className={styles.message}>{message}</p>
      </div>
      <div className={styles.visualizerContainer}>
        <VoicePrint frequencies={frequencies} isActive={isRecording} />
      </div>
      <div className={styles.controlsContainer}>
        <div className={styles.trainingControls}>
          <h3>Training</h3>
          <button
            className={styles.button}
            onClick={() => startRecording(true)}
            disabled={isRecording}
          >
            Start Training Recording
          </button>
          <button
            className={styles.button}
            onClick={clearTrainedData}
            disabled={!isTrained}
          >
            Clear Voiceprint Data
          </button>
        </div>
        <div className={styles.recognitionControls}>
          <h3>Recognition</h3>
          <button
            className={styles.button}
            onClick={() => startRecording(false)}
            disabled={isRecording || !isTrained}
          >
            Start Recognition
          </button>
          <button
            className={styles.button}
            onClick={() => {
              stopRecording();
              recognizeVoice();
            }}
            disabled={!isRecording}
          >
            Stop and Recognize
          </button>
        </div>
      </div>
      {status === VoiceRecognitionStatus.MATCHED ||
      status === VoiceRecognitionStatus.NOT_MATCHED ? (
        <div className={styles.resultContainer}>
          <div className={styles.scoreBar}>
            <div
              className={styles.scoreIndicator}
              style={{ width: `${Math.max(0, matchScore * 100)}%` }}
            ></div>
          </div>
          <div className={styles.scoreValue}>
            Similarity: {(matchScore * 100).toFixed(2)}%
          </div>
        </div>
      ) : null}
    </div>
  );
};
export default TensorFlow;