import threading
from dataclasses import dataclass
from typing import Callable, Generator, override

import fastrtc
import librosa
import numpy as np
from ten_vad import TenVad


@dataclass
class VADEvent:
    interrupt_signal: bool | None = None
    full_audio: tuple[int, np.ndarray] | None = None


global_ten_vad: TenVad | None = None
global_vad_lock = threading.Lock()


def global_vad_process(audio_data: np.ndarray) -> float:
    """
    Process audio data (hop_size=256) with the global TenVad instance.
    Returns: speech probability.
    """
    global global_ten_vad
    with global_vad_lock:
        if global_ten_vad is None:
            global_ten_vad = TenVad()
        prob, _ = global_ten_vad.process(audio_data)
        return prob


class RealtimeVAD:
    def __init__(
        self,
        src_sr: int = 24000,
        start_threshold: float = 0.8,
        end_threshold: float = 0.7,
        pad_start_s: float = 0.6,
        min_positive_s: float = 0.4,
        min_silence_s: float = 1.2,
    ):
        self.src_sr = src_sr
        self.vad_sr = 16000
        self.hop_size = 256
        self.start_threshold = start_threshold
        self.end_threshold = end_threshold
        self.pad_start_s = pad_start_s
        self.min_positive_s = min_positive_s
        self.min_silence_s = min_silence_s

        self.vad_buffer = np.array([], dtype=np.int16)
        """
        VAD buffer holding audio for VAD processing.
        Stores 16kHz int16 PCM; consumed and trimmed in `hop_size`-sample steps.
        """
        self.src_buffer = np.array([], dtype=np.int16)
        """
        Source buffer holding the original audio data.
        Stores int16 PCM at the original sample rate (24kHz).
        Cut when a pause is detected (after `min_silence_s`).
        Kept as a sliding window of `pad_start_s` seconds while inactive.
        """
        self.vad_buffer_offset = 0
        self.src_buffer_offset = 0

        self.active = False
        self.interrupt_signal = False
        self.sum_positive_s = 0.0
        self.silence_start_s: float | None = None

    def process(self, audio_data: np.ndarray):
        if audio_data.ndim == 2:
            # FastRTC style [channels, samples]; keep the first channel only
            audio_data = audio_data[0]

        # Append to buffers
        self.src_buffer = np.concatenate((self.src_buffer, audio_data))
        # Resample to 16kHz for TEN VAD: convert to float32 in [-1, 1],
        # resample, then convert back to int16
        vad_audio_data = librosa.resample(
            audio_data.astype(np.float32) / 32768.0,
            orig_sr=self.src_sr,
            target_sr=self.vad_sr,
        )
        vad_audio_data = (vad_audio_data * 32767.0).round().astype(np.int16)
        self.vad_buffer = np.concatenate((self.vad_buffer, vad_audio_data))
        vad_buffer_size = self.vad_buffer.shape[0]

        def process_chunk(chunk_offset_s: float, vad_chunk: np.ndarray):
            speech_prob = global_vad_process(vad_chunk)
            hop_s = self.hop_size / self.vad_sr

            if not self.active:
                if speech_prob >= self.start_threshold:
                    # Speech onset: become active, keeping `pad_start_s` of
                    # pre-roll already sitting in the source buffer
                    self.active = True
                    self.sum_positive_s = hop_s
                    print(f"[VAD] Active at {chunk_offset_s:.2f}s, {speech_prob=:.3f}")
                else:
                    # Still inactive: slide the source buffer window so it
                    # retains only the last `pad_start_s` seconds
                    new_src_offset = int(
                        (chunk_offset_s - self.pad_start_s) * self.src_sr
                    )
                    cut_pos = new_src_offset - self.src_buffer_offset
                    if cut_pos > 0:
                        self.src_buffer = self.src_buffer[cut_pos:]
                        self.src_buffer_offset = new_src_offset
                return

            chunk_src_pos = int(chunk_offset_s * self.src_sr)
            if speech_prob >= self.end_threshold:
                self.silence_start_s = None
                self.sum_positive_s += hop_s
                if (
                    not self.interrupt_signal
                    and self.sum_positive_s >= self.min_positive_s
                ):
                    # Enough accumulated speech: signal an interruption once
                    self.interrupt_signal = True
                    yield VADEvent(interrupt_signal=True)
                    print(
                        f"[VAD] Interrupt signal at {chunk_offset_s:.2f}s, "
                        f"{speech_prob=:.3f}"
                    )
            elif self.silence_start_s is None:
                self.silence_start_s = chunk_offset_s

            if (
                self.silence_start_s is not None
                and chunk_offset_s - self.silence_start_s >= self.min_silence_s
            ):
                # Inactive now: emit the full utterance and reset state
                cut_pos = chunk_src_pos - self.src_buffer_offset
                if self.interrupt_signal:
                    webrtc_audio = self.src_buffer[np.newaxis, :cut_pos]
                    yield VADEvent(full_audio=(self.src_sr, webrtc_audio))
                    print(
                        f"[VAD] Full audio at {chunk_offset_s:.2f}s, "
                        f"{webrtc_audio.shape=}"
                    )
                self.src_buffer = self.src_buffer[cut_pos:]
                self.src_buffer_offset = chunk_src_pos
                self.active = False
                self.interrupt_signal = False
                self.sum_positive_s = 0.0
                self.silence_start_s = None

        # Feed the VAD buffer to TEN VAD one `hop_size` chunk at a time;
        # leftover samples stay buffered for the next call
        processed_samples = 0
        for chunk_pos in range(0, vad_buffer_size - self.hop_size, self.hop_size):
            processed_samples = chunk_pos + self.hop_size
            chunk_offset_s = (self.vad_buffer_offset + chunk_pos) / self.vad_sr
            vad_chunk = self.vad_buffer[chunk_pos : chunk_pos + self.hop_size]
            yield from process_chunk(chunk_offset_s, vad_chunk)
        self.vad_buffer = self.vad_buffer[processed_samples:]
        self.vad_buffer_offset += processed_samples


def init_global_ten_vad(input_sample_rate: int = 24000):
    """
    Call this once at the start of the program to avoid latency on first use.
    No-op if already initialized.
    """
    global global_ten_vad
    require_warmup = False
    with global_vad_lock:
        if global_ten_vad is None:
            global_ten_vad = TenVad()
            require_warmup = True
    if require_warmup:
        print("[VAD] Initializing global TenVad...")
        realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
        for _ in range(25):
            # Warmup with 1 second of silence (25 frames x 960 samples @ 24kHz)
            for _ in realtime_vad.process(np.zeros(960, dtype=np.int16)):
                pass
        print("[VAD] Global VAD initialized")


type StreamerGenerator = Generator[fastrtc.tracks.EmitType, None, None]
type StreamerFn = Callable[[tuple[int, np.ndarray], str], StreamerGenerator]


class VADStreamHandler(fastrtc.StreamHandler):
    def __init__(
        self,
        streamer_fn: StreamerFn,
        input_sample_rate: int = 24000,
    ):
        super().__init__(
            "mono",
            24000,
            None,
            input_sample_rate,
            30,
        )
        self.streamer_fn = streamer_fn
        self.realtime_vad = RealtimeVAD(src_sr=input_sample_rate)
        self.generator: StreamerGenerator | None = None
        init_global_ten_vad()

    @override
    def emit(self) -> fastrtc.tracks.EmitType:
        # Drain the current streamer generator, one chunk per call
        if self.generator is None:
            return None
        try:
            return next(self.generator)
        except StopIteration:
            self.generator = None
            return None

    @override
    def receive(self, frame: tuple[int, np.ndarray]):
        _, audio_data = frame
        for event in self.realtime_vad.process(audio_data):
            if event.interrupt_signal:
                # User started speaking: drop pending output immediately
                self.generator = None
                self.clear_queue()
            if event.full_audio is not None:
                # Utterance complete: start a new streamer run on it
                self.wait_for_args_sync()
                self.latest_args[0] = event.full_audio
                self.generator = self.streamer_fn(*self.latest_args)

    @override
    def copy(self):
        return VADStreamHandler(
            self.streamer_fn,
            input_sample_rate=self.input_sample_rate,
        )
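

# --- Usage sketch (illustrative, not part of the module) ---
# A minimal wiring example, assuming fastrtc's `Stream` API with
# `modality="audio"` / `mode="send-receive"` and its built-in Gradio UI.
# `echo_streamer` is a hypothetical streamer_fn that simply plays the
# captured utterance back; a real app would transcribe it and stream a
# synthesized reply instead.
if __name__ == "__main__":

    def echo_streamer(
        audio: tuple[int, np.ndarray], extra: str = ""
    ) -> StreamerGenerator:
        # Echo the full utterance back as one chunk. The second positional
        # argument (from any additional UI inputs) is unused here and
        # defaulted in case latest_args carries only the audio.
        yield audio

    stream = fastrtc.Stream(
        handler=VADStreamHandler(echo_streamer),
        modality="audio",
        mode="send-receive",
    )
    stream.ui.launch()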