# maya1-txt2speech / maya1 / snac_decoder.py
# Author: Rajkumar Pramanik ("RJproz")
# initial commit 3c92819
import torch
import numpy as np
import asyncio
from typing import List, Optional, Tuple
from snac import SNAC
from .constants import (
CODE_END_TOKEN_ID,
CODE_TOKEN_OFFSET,
SNAC_MODEL_NAME,
SNAC_SAMPLE_RATE,
SNAC_TOKENS_PER_FRAME,
)
class SNACDecoder:
    """
    SNAC decoder for maya1.

    Unpacks 7-token SNAC frames into the codec's three hierarchical code
    levels and decodes them to 24 kHz mono audio waveforms.

    The unpacking logic is the EXACT INVERSE of the training preprocessing
    function ``pack_snac_to_7_and_offset()``. CRITICAL: any mismatch in
    unpacking will produce garbage audio.

    Optionally supports async batching so concurrent requests share one
    decoder forward pass.
    """

    def __init__(
        self,
        device: str = "cuda",
        compile_decoder: bool = False,
        enable_batching: bool = False,
        max_batch_size: int = 64,
        batch_timeout_ms: int = 15,
    ):
        """
        Initialize SNAC decoder.

        Args:
            device: Device for SNAC model (cuda/cpu)
            compile_decoder: Use torch.compile for speedup
            enable_batching: Enable async batching
            max_batch_size: Max sequences to batch together
            batch_timeout_ms: Max wait time before processing a batch
        """
        self.device = device
        self.enable_batching = enable_batching
        self.max_batch_size = max_batch_size
        self.batch_timeout_ms = batch_timeout_ms

        print(f"Loading SNAC 24kHz model to {device}...")
        self.snac_model = SNAC.from_pretrained(SNAC_MODEL_NAME).eval().to(device)

        if compile_decoder:
            print("Compiling SNAC decoder with torch.compile...")
            self._compile_model()

        # Batching infrastructure (only allocated when requested).
        if enable_batching:
            self.request_queue = asyncio.Queue()
            self.batch_processor_task = None
            self._running = False
            print(f"Batching enabled (max_batch={max_batch_size}, timeout={batch_timeout_ms}ms)")

        print("SNAC decoder initialized")

    def _compile_model(self):
        """Warm up the decoder at several sizes, then wrap it with torch.compile."""
        # Warm-up passes populate kernels/caches for a few representative
        # frame counts before compilation kicks in.
        for frames in [4, 16, 32]:
            dummy_codes = [
                torch.randint(0, 4096, (1, frames), device=self.device),
                torch.randint(0, 4096, (1, frames * 2), device=self.device),
                torch.randint(0, 4096, (1, frames * 4), device=self.device),
            ]
            with torch.inference_mode():
                z_q = self.snac_model.quantizer.from_codes(dummy_codes)
                _ = self.snac_model.decoder(z_q)

        # Apply compilation: heavier autotuning on the decoder (the hot path),
        # lighter mode on the quantizer.
        self.snac_model.decoder = torch.compile(
            self.snac_model.decoder,
            mode="max-autotune"
        )
        self.snac_model.quantizer = torch.compile(
            self.snac_model.quantizer,
            mode="reduce-overhead"
        )
        print(f"SNAC decoder compiled")

    def unpack_snac_from_7(self, vocab_ids: List[int]) -> List[List[int]]:
        """
        Unpack 7-token SNAC frames to 3 hierarchical levels.

        This is the EXACT INVERSE of the training preprocessing function
        `pack_snac_to_7_and_offset()`.

        Frame structure:
            [slot0, slot1, slot2, slot3, slot4, slot5, slot6]
        Unpacking:
            - slot0: L1[i]
            - slot1: L2[2*i]     (even index)
            - slot2: L3[4*i + 0]
            - slot3: L3[4*i + 1]
            - slot4: L2[2*i + 1] (odd index)
            - slot5: L3[4*i + 2]
            - slot6: L3[4*i + 3]

        Args:
            vocab_ids: List of SNAC token IDs (128266-156937).
                Should be divisible by 7; trailing partial frames are dropped.

        Returns:
            [L1, L2, L3] where, for n complete frames:
                L1: n elements (coarse level)
                L2: 2n elements (medium level)
                L3: 4n elements (fine level)
        """
        # Strip EOS token if present.
        if vocab_ids and vocab_ids[-1] == CODE_END_TOKEN_ID:
            vocab_ids = vocab_ids[:-1]

        # Keep only complete frames (divisible by 7).
        frames = len(vocab_ids) // SNAC_TOKENS_PER_FRAME
        vocab_ids = vocab_ids[:frames * SNAC_TOKENS_PER_FRAME]
        if frames == 0:
            return [[], [], []]

        l1, l2, l3 = [], [], []
        for i in range(frames):
            # Extract the 7 slots for this frame.
            slots = vocab_ids[i*7:(i+1)*7]
            # Subtract the base offset and mod 4096 to recover the original
            # per-level codebook indices (each level uses codes 0-4095; the
            # 7 slots occupy consecutive 4096-wide ID ranges).
            l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
            l2.extend([
                (slots[1] - CODE_TOKEN_OFFSET) % 4096,  # Even index
                (slots[4] - CODE_TOKEN_OFFSET) % 4096,  # Odd index
            ])
            l3.extend([
                (slots[2] - CODE_TOKEN_OFFSET) % 4096,
                (slots[3] - CODE_TOKEN_OFFSET) % 4096,
                (slots[5] - CODE_TOKEN_OFFSET) % 4096,
                (slots[6] - CODE_TOKEN_OFFSET) % 4096,
            ])
        return [l1, l2, l3]

    @torch.inference_mode()
    def decode(
        self,
        snac_tokens: List[int],
        trim_warmup: bool = True,
        trim_amount: Optional[int] = None,
        use_sliding_window: bool = False
    ) -> Optional[np.ndarray]:
        """
        Decode SNAC tokens to an audio waveform.

        Args:
            snac_tokens: List of SNAC token IDs (7*n tokens)
            trim_warmup: Whether to trim SNAC warmup samples (default: True)
            trim_amount: Number of samples to trim (default: 2048).
                Can be set to a smaller value (e.g., 512) for intermediate chunks.
            use_sliding_window: If True, only return middle 2048 samples
                (for sliding-window streaming)

        Returns:
            Audio waveform as numpy array (float32, 24kHz mono), shape (samples,).
            Returns None if not enough tokens.
        """
        if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
            print(f"Not enough SNAC tokens: {len(snac_tokens)} < {SNAC_TOKENS_PER_FRAME}")
            return None

        # Unpack to 3 levels.
        levels = self.unpack_snac_from_7(snac_tokens)
        if not levels[0]:  # No frames after unpacking
            return None

        # Convert each level to a [1, T] long tensor as SNAC expects.
        codes = [
            torch.tensor(level, dtype=torch.long, device=self.device).unsqueeze(0)
            for level in levels
        ]

        # Decode through SNAC: codes -> latents -> waveform.
        z_q = self.snac_model.quantizer.from_codes(codes)
        audio = self.snac_model.decoder(z_q)

        # SNAC decoder outputs [batch, 1, samples]; take the mono channel.
        audio = audio[0, 0].cpu().numpy()

        if use_sliding_window:
            # Sliding-window mode: keep only the middle 2048 samples. This
            # eliminates popping/cracking when using overlapping 28-token
            # windows; shorter output (final chunk) is kept in full.
            if len(audio) >= 4096:
                audio = audio[2048:4096]
        else:
            # Standard mode: trim decoder warm-up samples from the front.
            if trim_warmup:
                if trim_amount is None:
                    trim_amount = 2048  # Default full trim
                if len(audio) > trim_amount:
                    audio = audio[trim_amount:]

        return audio

    @staticmethod
    def _float_to_pcm16(audio: np.ndarray) -> bytes:
        """
        Convert a float waveform to little-endian int16 PCM bytes.

        Clip to [-1, 1] first: the SNAC decoder can overshoot that range
        slightly, and casting out-of-range floats to int16 wraps around,
        producing loud pops in the output audio.
        """
        clipped = np.clip(audio, -1.0, 1.0)
        return (clipped * 32767).astype(np.int16).tobytes()

    def decode_to_bytes(
        self,
        snac_tokens: List[int],
        trim_warmup: bool = True,
        use_sliding_window: bool = False,
        trim_amount: Optional[int] = None
    ) -> Optional[bytes]:
        """
        Decode SNAC tokens to audio bytes (int16 PCM).

        Args:
            snac_tokens: List of SNAC token IDs
            trim_warmup: Whether to trim SNAC warmup samples (default: True)
            use_sliding_window: If True, only return middle 2048 samples
                (for sliding-window streaming)
            trim_amount: Optional custom trim size forwarded to decode()
                (default None -> decode()'s 2048-sample default)

        Returns:
            Audio as bytes (int16 PCM, 24kHz mono), or None if decode fails.
        """
        audio = self.decode(
            snac_tokens,
            trim_warmup=trim_warmup,
            trim_amount=trim_amount,
            use_sliding_window=use_sliding_window,
        )
        if audio is None:
            return None
        return self._float_to_pcm16(audio)

    def validate_tokens(self, snac_tokens: List[int]) -> bool:
        """
        Validate SNAC tokens before decoding.

        Args:
            snac_tokens: List of SNAC token IDs

        Returns:
            True if valid, False otherwise
        """
        # Check minimum length (at least one full frame).
        if len(snac_tokens) < SNAC_TOKENS_PER_FRAME:
            print(f"Too few tokens: {len(snac_tokens)}")
            return False

        # Divisibility by 7 is only a warning; decode() truncates.
        if len(snac_tokens) % SNAC_TOKENS_PER_FRAME != 0:
            print(f" Warning: Token count {len(snac_tokens)} not divisible by 7")
            print(f" Will truncate to {(len(snac_tokens) // 7) * 7}")

        # Check token range: 156937 == CODE_TOKEN_OFFSET + 7*4096 - 1,
        # the last ID of the 7th 4096-wide slot.
        for i, token_id in enumerate(snac_tokens):
            if token_id < CODE_TOKEN_OFFSET or token_id > 156937:
                print(f" Invalid token at position {i}: {token_id}")
                print(f" Expected range: [{CODE_TOKEN_OFFSET}, 156937]")
                return False

        return True

    # ========== Async Batching Methods ==========

    @property
    def is_running(self) -> bool:
        """Check if the batch processor is running."""
        return self._running if self.enable_batching else False

    async def start_batch_processor(self):
        """Start the background batch processor task (no-op unless batching is enabled)."""
        if not self.enable_batching:
            return
        if self._running:
            print("Batch processor already running")
            return
        self._running = True
        self.batch_processor_task = asyncio.create_task(self._batch_processor_loop())
        print("Batch processor started")

    async def stop_batch_processor(self):
        """Stop the background batch processor task and wait for it to exit."""
        if not self.enable_batching:
            return
        if not self._running:
            return
        self._running = False
        if self.batch_processor_task:
            self.batch_processor_task.cancel()
            try:
                await self.batch_processor_task
            except asyncio.CancelledError:
                pass
        print("Batch processor stopped")

    async def decode_single_async(
        self,
        snac_tokens: List[int],
        trim_warmup: bool = True,
        use_sliding_window: bool = False
    ) -> Optional[bytes]:
        """
        Async decode for batching support.

        Queues the request and waits for batched processing. Falls back to
        a synchronous decode when batching is disabled.

        Args:
            snac_tokens: List of SNAC token IDs
            trim_warmup: Whether to trim SNAC warmup samples (default: True)
            use_sliding_window: If True, only return middle 2048 samples
                (for sliding-window streaming)

        Returns:
            Audio bytes or None if decode fails
        """
        if not self.enable_batching:
            # Fallback to synchronous decode.
            return self.decode_to_bytes(
                snac_tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window
            )

        # Create the result future on the running loop (asyncio.Future()
        # without an explicit loop is a deprecated pattern).
        result_future = asyncio.get_running_loop().create_future()

        # Enqueue the request with its per-request flags, then await the result.
        await self.request_queue.put((snac_tokens, trim_warmup, use_sliding_window, result_future))
        return await result_future

    async def _batch_processor_loop(self):
        """Background task that repeatedly collects and processes decode batches."""
        while self._running:
            try:
                batch = await self._collect_batch()
                if not batch:
                    continue
                await self._process_batch(batch)
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive on unexpected errors; individual
                # requests get their exceptions via their futures.
                print(f"Batch processor error: {e}")
                import traceback
                traceback.print_exc()

    async def _collect_batch(self) -> List[Tuple[List[int], bool, bool, asyncio.Future]]:
        """
        Collect queued requests into a batch.

        Waits up to batch_timeout_ms for a first request, then fills the
        batch within a single timeout budget measured from that first
        request. This bounds the added latency at ~batch_timeout_ms total
        (rather than max_batch_size * batch_timeout_ms if each queue read
        got a fresh timeout).

        Returns:
            List of (tokens, trim_warmup, use_sliding_window, future) tuples
            (possibly empty).
        """
        batch = []
        timeout_sec = self.batch_timeout_ms / 1000.0
        loop = asyncio.get_running_loop()

        try:
            # Wait for the first request (blocking up to the timeout).
            first_item = await asyncio.wait_for(
                self.request_queue.get(),
                timeout=timeout_sec
            )
        except asyncio.TimeoutError:
            # No requests in the timeout period.
            return batch

        batch.append(first_item)

        # Fill the remainder of the batch within the shared deadline.
        deadline = loop.time() + timeout_sec
        while len(batch) < self.max_batch_size:
            remaining = deadline - loop.time()
            if remaining <= 0:
                break
            try:
                item = await asyncio.wait_for(
                    self.request_queue.get(),
                    timeout=remaining
                )
            except asyncio.TimeoutError:
                break  # Budget exhausted; process what we have
            batch.append(item)

        return batch

    @torch.inference_mode()
    async def _process_batch(self, batch: List[Tuple[List[int], bool, bool, asyncio.Future]]):
        """
        Process a batch of decode requests, fulfilling each request's future.

        Args:
            batch: List of (tokens, trim_warmup, use_sliding_window, future) tuples
        """
        if not batch:
            return

        # Extract components.
        token_sequences = [item[0] for item in batch]
        trim_warmup_flags = [item[1] for item in batch]
        sliding_window_flags = [item[2] for item in batch]
        futures = [item[3] for item in batch]

        # Only equal-length sequences can share one stacked forward pass.
        lengths = [len(tokens) for tokens in token_sequences]
        can_batch_efficiently = len(set(lengths)) == 1

        if can_batch_efficiently and len(batch) > 1:
            # Efficient path: all sequences decoded in one batched call.
            try:
                audio_bytes_list = await self._decode_batch_same_length(
                    token_sequences, trim_warmup_flags, sliding_window_flags
                )
                for future, audio_bytes in zip(futures, audio_bytes_list):
                    if not future.done():
                        future.set_result(audio_bytes)
            except Exception as e:
                # Propagate the failure to every waiter.
                for future in futures:
                    if not future.done():
                        future.set_exception(e)
        else:
            # Sequential decode (mixed lengths or a single item).
            for tokens, trim_warmup, use_sliding_window, future in batch:
                try:
                    audio_bytes = self.decode_to_bytes(
                        tokens, trim_warmup=trim_warmup, use_sliding_window=use_sliding_window
                    )
                    if not future.done():
                        future.set_result(audio_bytes)
                except Exception as e:
                    if not future.done():
                        future.set_exception(e)

    async def _decode_batch_same_length(
        self,
        token_sequences: List[List[int]],
        trim_warmup_flags: List[bool],
        sliding_window_flags: List[bool]
    ) -> List[Optional[bytes]]:
        """
        Decode multiple equal-length sequences in one batched forward pass.

        Args:
            token_sequences: List of token sequences (all same length)
            trim_warmup_flags: Per-sequence trim_warmup flags
            sliding_window_flags: Per-sequence use_sliding_window flags

        Returns:
            List of audio bytes (None for sequences with no valid frames),
            in the same order as token_sequences.
        """
        if not token_sequences:
            return []

        # Unpack all sequences and keep only those with at least one frame.
        unpacked_list = [self.unpack_snac_from_7(tokens) for tokens in token_sequences]
        valid_indices = [i for i, levels in enumerate(unpacked_list) if levels[0]]
        if not valid_indices:
            return [None] * len(token_sequences)

        # Stack into batched code tensors:
        # [batch, frames], [batch, 2*frames], [batch, 4*frames]
        codes = [
            torch.stack([
                torch.tensor(unpacked_list[i][level_idx], dtype=torch.long, device=self.device)
                for i in valid_indices
            ], dim=0)
            for level_idx in range(3)
        ]

        # Batched decode -> [batch, 1, samples].
        z_q = self.snac_model.quantizer.from_codes(codes)
        audio_batch = self.snac_model.decoder(z_q)

        # Post-process each sequence with its own flags and convert to PCM.
        audio_bytes_list = [None] * len(token_sequences)
        for batch_idx, orig_idx in enumerate(valid_indices):
            audio = audio_batch[batch_idx, 0].detach().cpu().numpy()

            if sliding_window_flags[orig_idx]:
                # Sliding-window mode: keep the middle 2048 samples only.
                if len(audio) >= 4096:
                    audio = audio[2048:4096]
            else:
                # Standard mode: trim warm-up samples if requested.
                if trim_warmup_flags[orig_idx] and len(audio) > 2048:
                    audio = audio[2048:]

            audio_bytes_list[orig_idx] = self._float_to_pcm16(audio)

        return audio_bytes_list