"""
Maya1 Streaming Pipeline - Sliding Window Approach

Implements sliding window technique for smooth streaming without artifacts.
"""

import asyncio
from typing import AsyncGenerator, Optional

from vllm import SamplingParams

from .constants import (
    CODE_END_TOKEN_ID,
    SNAC_MIN_ID,
    SNAC_MAX_ID,
    DEFAULT_TEMPERATURE,
    DEFAULT_TOP_P,
    DEFAULT_MAX_TOKENS,
    DEFAULT_MIN_TOKENS,
    DEFAULT_REPETITION_PENALTY,
    DEFAULT_SEED,
)


class Maya1SlidingWindowPipeline:
    """
    Streaming TTS pipeline using sliding window approach.

    Decodes overlapping 28-token windows (4 frames) and keeps only the
    middle 2048 samples of each decode for smooth audio continuity. Once
    generation ends, the audio that follows the middle slice of the last
    decoded window is flushed so the end of the utterance is not dropped.
    """

    # Sliding window configuration
    WINDOW_SIZE = 28  # 4 frames (7 tokens per frame)
    YIELD_STRIDE = 7  # Advance the window by 1 frame per decode
    MIDDLE_SAMPLES = 2048  # Keep middle 2048 samples from each decode

    # Output PCM format, as documented on generate_speech_stream
    # (int16 PCM, 24kHz mono). Used for slicing and duration reporting.
    SAMPLE_RATE = 24000
    BYTES_PER_SAMPLE = 2

    def __init__(self, model, prompt_builder, snac_decoder):
        """
        Initialize sliding window streaming pipeline.

        Args:
            model: Maya1Model instance
            prompt_builder: Maya1PromptBuilder instance
            snac_decoder: SNACDecoder instance
        """
        self.model = model
        self.prompt_builder = prompt_builder
        self.snac_decoder = snac_decoder
        print("Sliding window pipeline initialized")

    def _middle_byte_range(self, audio_bytes) -> tuple:
        """
        Return (start_byte, end_byte) of the centered MIDDLE_SAMPLES slice.

        The start is clamped to 0 so that a decode shorter than
        MIDDLE_SAMPLES yields the whole buffer; a negative start would make
        Python slice from the end of the buffer instead.
        """
        total_samples = len(audio_bytes) // self.BYTES_PER_SAMPLE
        start_sample = max(0, (total_samples - self.MIDDLE_SAMPLES) // 2)
        end_sample = start_sample + self.MIDDLE_SAMPLES
        return (
            start_sample * self.BYTES_PER_SAMPLE,
            end_sample * self.BYTES_PER_SAMPLE,
        )

    async def generate_speech_stream(
        self,
        description: str,
        text: str,
        temperature: float = DEFAULT_TEMPERATURE,
        top_p: float = DEFAULT_TOP_P,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
        seed: Optional[int] = None,
    ) -> AsyncGenerator[bytes, None]:
        """
        Generate speech audio with sliding window streaming.

        Args:
            description: Voice description
            text: Text to synthesize (may include tags)
            temperature: Sampling temperature
            top_p: Nucleus sampling
            max_tokens: Max SNAC tokens to generate
            repetition_penalty: Prevent loops
            seed: Random seed; DEFAULT_SEED is used when None

        Yields:
            Audio bytes (int16 PCM, 24kHz mono). One chunk per decoded
            window (its middle slice), followed by one final chunk holding
            the audio after the middle slice of the last decoded window.
            Nothing is yielded if fewer than WINDOW_SIZE SNAC tokens are
            generated.
        """
        # Build prompt
        prompt = self.prompt_builder.build_prefix(description, text)

        # Configure sampling
        sampling_params = SamplingParams(
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            min_tokens=DEFAULT_MIN_TOKENS,
            repetition_penalty=repetition_penalty,
            stop_token_ids=[CODE_END_TOKEN_ID],
            seed=seed if seed is not None else DEFAULT_SEED,
        )

        snac_buffer = []  # every valid SNAC token seen so far
        last_yield_position = 0  # start index of the next window to decode
        chunk_count = 0
        total_tokens_seen = 0
        emitted_bytes = 0
        # Audio after the middle slice of the most recent successful decode.
        # Overwritten on every window; only the last one is ever flushed.
        pending_tail = b""

        async for output in self.model.generate_stream(prompt, sampling_params):
            # token_ids is treated as a cumulative list, so only the suffix
            # past what was already consumed is new.
            generated_token_ids = output.outputs[0].token_ids
            new_tokens = generated_token_ids[total_tokens_seen:]
            total_tokens_seen = len(generated_token_ids)

            # Collect SNAC codes, ignoring anything at or after the end token.
            reached_end = False
            for token_id in new_tokens:
                if token_id == CODE_END_TOKEN_ID:
                    reached_end = True
                    break
                if SNAC_MIN_ID <= token_id <= SNAC_MAX_ID:
                    snac_buffer.append(token_id)

            # Decode every full window that is now available.
            while len(snac_buffer) >= last_yield_position + self.WINDOW_SIZE:
                window_end = last_yield_position + self.WINDOW_SIZE
                window = snac_buffer[last_yield_position:window_end]

                audio_bytes = await self.snac_decoder.decode_single_async(window)
                if audio_bytes:
                    start_byte, end_byte = self._middle_byte_range(audio_bytes)
                    audio_chunk = audio_bytes[start_byte:end_byte]
                    pending_tail = audio_bytes[end_byte:]

                    chunk_count += 1
                    if chunk_count == 1:
                        print(" First chunk ready")
                    emitted_bytes += len(audio_chunk)
                    yield audio_chunk
                else:
                    # A failed decode must not leave an older window's tail
                    # behind to be flushed as if it were the end of the audio.
                    pending_tail = b""

                # Move forward by stride
                last_yield_position += self.YIELD_STRIDE

            if reached_end:
                break

        # Final chunk. The loop above only stops once fewer than WINDOW_SIZE
        # tokens remain past the cursor, so no further full window can be
        # decoded here. The last decoded window already ends on the last
        # complete frame, so the samples after its middle slice are the end
        # of the utterance and are emitted as-is.
        if pending_tail:
            chunk_count += 1
            emitted_bytes += len(pending_tail)
            yield pending_tail

        # Report the duration of the audio actually streamed.
        duration = emitted_bytes / (self.BYTES_PER_SAMPLE * self.SAMPLE_RATE)
        print(f"Streamed {chunk_count} chunks (~{duration:.1f}s audio)")