Audio-to-Audio
PyTorch
ONNX
Safetensors
TensorRT
English
fast_oobleck_decoder
ace-step
audio
vae
knowledge-distillation
music-generation
streaming
dreamvae
custom_code
Instructions to use daydreamlive/DreamVAE with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- TensorRT
How to use daydreamlive/DreamVAE with TensorRT:
# No code snippets available yet for this library.
# To use this model, check the repository files and the library's documentation.
# Want to help? PRs adding snippets are welcome at:
# https://github.com/huggingface/huggingface.js
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """Compare TRT student engine output against PyTorch teacher and student. | |
| Encodes real audio through the VAE, decodes it with all four paths (teacher and | |
| student, each in PyTorch and TensorRT FP16), benchmarks decode speed, saves WAVs, | |
| and prints SNR comparisons. | |
| """ | |
| import math | |
| import time | |
| from pathlib import Path | |
| import soundfile as sf | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn.utils import weight_norm | |
# -- Student model (must match training script) --
class Snake1d(nn.Module):
    """Snake activation: ``x + 1 / (beta + 1e-9) * sin(alpha * x) ** 2``, per channel.

    Args:
        hidden_dim: number of channels (dim 1 of the input); ``alpha`` and
            ``beta`` are each shaped ``[1, hidden_dim, 1]`` and start at zero.
        logscale: when True, ``alpha``/``beta`` are stored in log space and
            exponentiated in ``forward``.

    The parameter names ``alpha``/``beta`` are part of the checkpoint format
    loaded via ``load_state_dict``, so they must not be renamed.
    """

    def __init__(self, hidden_dim, logscale=True):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.logscale = logscale

    def forward(self, x):
        # Flatten every dim after the channel axis so the [1, C, 1] parameters
        # broadcast, then restore the caller's shape at the end.
        orig_shape = x.shape
        if self.logscale:
            alpha, beta = self.alpha.exp(), self.beta.exp()
        else:
            alpha, beta = self.alpha, self.beta
        flat = x.reshape(orig_shape[0], orig_shape[1], -1)
        inv_beta = (beta + 1e-9).reciprocal()  # epsilon avoids division by zero
        out = flat + inv_beta * torch.sin(alpha * flat).pow(2)
        return out.reshape(orig_shape)
class FastResidualUnit(nn.Module):
    """Residual unit: Snake -> dilated k=7 conv -> Snake -> 1x1 conv, plus skip.

    Args:
        dim: channel count (input and output).
        dilation: dilation of the kernel-7 convolution.

    Submodule names (``snake1``, ``conv1``, ``snake2``, ``conv2``) are part of
    the checkpoint format and must not change.
    """

    def __init__(self, dim, dilation=1):
        super().__init__()
        kernel_size = 7
        # (k - 1) * dilation // 2 padding keeps the conv output the same length
        # as its input for this odd kernel.
        same_pad = (kernel_size - 1) * dilation // 2
        self.snake1 = Snake1d(dim)
        self.conv1 = weight_norm(
            nn.Conv1d(dim, dim, kernel_size, dilation=dilation, padding=same_pad)
        )
        self.snake2 = Snake1d(dim)
        self.conv2 = weight_norm(nn.Conv1d(dim, dim, 1))

    def forward(self, x):
        y = self.snake1(x)
        y = self.conv1(y)
        y = self.snake2(y)
        y = self.conv2(y)
        # Center-crop the skip path if the branch came out shorter; with the
        # padding above the lengths match, so this is only defensive.
        trim = (x.shape[-1] - y.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + y
class FastDecoderBlock(nn.Module):
    """One upsampling stage: Snake -> transposed conv -> two residual units.

    Args:
        in_dim: input channel count.
        out_dim: output channel count.
        stride: upsampling factor of the transposed convolution. With kernel
            ``2 * stride`` and padding ``ceil(stride / 2)``, length L becomes
            ``L * stride`` (``L * stride - 1`` for odd strides).

    The residual units use dilations 1 and 3. Submodule names are part of the
    checkpoint format.
    """

    def __init__(self, in_dim, out_dim, stride=1):
        super().__init__()
        self.snake1 = Snake1d(in_dim)
        upsample = nn.ConvTranspose1d(
            in_dim,
            out_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )
        self.conv_t = weight_norm(upsample)
        self.res1 = FastResidualUnit(out_dim, 1)
        self.res2 = FastResidualUnit(out_dim, 3)

    def forward(self, x):
        for layer in (self.snake1, self.conv_t, self.res1, self.res2):
            x = layer(x)
        return x
class FastOobleckDecoder(nn.Module):
    """Student decoder: latents -> ``conv1`` -> upsampling blocks -> Snake -> ``conv2``.

    Args:
        channels: base channel width; stage widths are ``channels`` times the
            entries of ``channel_multiples``.
        input_channels: latent channel count fed to ``conv1``.
        audio_channels: output waveform channel count.
        upsampling_ratios: per-block upsampling strides. Falsy (None or empty)
            means the default ``[10, 6, 4, 4, 2]``.
        channel_multiples: per-stage width multipliers. Falsy means the default
            ``[1, 2, 4, 8, 8]``.

    Blocks are walked from the widest multiplier down to ``[1]``, so the final
    block outputs ``channels`` channels for ``final_snake``/``conv2``.
    """

    def __init__(self, channels=128, input_channels=64, audio_channels=2, upsampling_ratios=None, channel_multiples=None):
        super().__init__()
        ratios = upsampling_ratios or [10, 6, 4, 4, 2]
        multiples = [1] + (channel_multiples or [1, 2, 4, 8, 8])
        self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * multiples[-1], 7, padding=3))
        n_blocks = len(ratios)
        self.blocks = nn.ModuleList(
            FastDecoderBlock(
                channels * multiples[n_blocks - idx],
                channels * multiples[n_blocks - idx - 1],
                stride,
            )
            for idx, stride in enumerate(ratios)
        )
        self.final_snake = Snake1d(channels)
        self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, 7, padding=3, bias=False))

    def forward(self, z):
        h = self.conv1(z)
        for block in self.blocks:
            h = block(h)
        h = self.final_snake(h)
        return self.conv2(h)
def snr(ref, gen):
    """Return the SNR of ``gen`` against ``ref`` in dB, as a Python float.

    Both tensors are trimmed to the shorter last-dimension length first. A
    1e-10 floor on the noise power keeps identical inputs finite
    (``10 * log10(signal_power / 1e-10)``) instead of infinite.
    """
    n = min(ref.shape[-1], gen.shape[-1])
    ref = ref[..., :n]
    gen = gen[..., :n]
    signal_power = ref.pow(2).mean()
    noise_power = (ref - gen).pow(2).mean() + 1e-10
    return 10 * torch.log10(signal_power / noise_power).item()
def _load_clip(audio_path, device, sample_rate, max_seconds):
    """Load up to ``max_seconds`` of audio as a peak-normalized ``[1, C, samples]`` tensor.

    Raises:
        ValueError: if the file's sample rate is not ``sample_rate``.

    Mono files are duplicated to two channels because the student decoder
    defaults to ``audio_channels=2`` (the teacher VAE is assumed stereo too).
    """
    # always_2d=True yields [frames, channels] even for mono files, so the
    # transpose below always produces [channels, samples].
    data, sr = sf.read(audio_path, dtype="float32", always_2d=True)
    if sr != sample_rate:
        raise ValueError(f"{audio_path}: expected {sample_rate} Hz audio, got {sr} Hz")
    waveform = torch.tensor(data, dtype=torch.float32).T  # [channels, samples]
    if waveform.shape[0] == 1:
        waveform = waveform.repeat(2, 1)
    # Take the first max_seconds (or the full file if shorter).
    clip = waveform[:, : sample_rate * max_seconds].unsqueeze(0).to(device)
    peak = clip.abs().max()
    if peak > 1e-6:
        clip = clip / peak
    return clip


def _trt_enqueue(ctx, stream):
    """Run ``ctx`` once on ``stream`` and wait for it; raise if TRT rejects the launch."""
    if not ctx.execute_async_v3(stream.ptr):
        raise RuntimeError("TensorRT execute_async_v3 failed to enqueue inference")
    stream.synchronize()


def _load_trt_decoder(engine_path, latents, stream, device):
    """Deserialize a TRT decoder engine, bind it to ``latents``, and decode once.

    Returns:
        ``(audio, engine, ctx, out_buf)``. ``audio`` is a copy of the first
        decode. ``engine`` must outlive ``ctx``. ``out_buf`` is the device
        buffer ``ctx`` writes into. The context stores raw device pointers to
        ``latents`` and ``out_buf``, so the caller must keep both tensors alive
        for as long as ``ctx`` may be executed again.

    Raises:
        RuntimeError: if the engine rejects the latent shape or the launch.
    """
    from polygraphy.backend.common import bytes_from_path
    from polygraphy.backend.trt import engine_from_bytes

    engine = engine_from_bytes(bytes_from_path(engine_path))
    ctx = engine.create_execution_context()
    if not ctx.set_input_shape("latents", tuple(latents.shape)):
        raise RuntimeError(f"{engine_path}: engine rejected latent shape {tuple(latents.shape)}")
    ctx.set_tensor_address("latents", latents.data_ptr())
    # NOTE(review): assumes the engine's "audio" output is float32 -- confirm
    # if engines are ever built with FP16 I/O.
    out_buf = torch.empty(tuple(ctx.get_tensor_shape("audio")), dtype=torch.float32, device=device)
    ctx.set_tensor_address("audio", out_buf.data_ptr())
    # TRT runs on a different stream than torch. Finish all pending torch work
    # (which produced `latents` and may still touch memory now handed to
    # `out_buf`) before TRT reads or writes these buffers.
    torch.cuda.synchronize()
    _trt_enqueue(ctx, stream)
    return out_buf.clone(), engine, ctx, out_buf


def _benchmark_ms(runners, trials, warmup):
    """Return the mean wall-clock milliseconds per call for each zero-arg callable.

    Backends are interleaved within each trial, and the whole device is
    synchronized around every call so each timing covers its GPU work.
    """
    for _ in range(warmup):
        for run in runners.values():
            run()
    times = {name: [] for name in runners}
    for _ in range(trials):
        for name, run in runners.items():
            torch.cuda.synchronize()
            t0 = time.perf_counter()
            run()
            torch.cuda.synchronize()
            times[name].append(time.perf_counter() - t0)
    return {name: sum(ts) / len(ts) * 1000 for name, ts in times.items()}


def main():
    """Decode one clip with teacher/student in PyTorch and TensorRT, then compare.

    Prints speed and SNR comparisons and writes one WAV per decode path (plus
    the normalized original) to the results directory.
    """
    from diffusers import AutoencoderOobleck
    from polygraphy import cuda as pg_cuda

    device = "cuda"
    sample_rate = 48000
    out_dir = Path("research_program/vae_distillation/results/trt_verify_60s")
    out_dir.mkdir(parents=True, exist_ok=True)

    clip = _load_clip("tests/fixtures/techno.wav", device, sample_rate, max_seconds=60)

    print("Loading teacher VAE...")
    vae = AutoencoderOobleck.from_pretrained("ACE-Step/Ace-Step1.5", subfolder="vae")
    vae = vae.to(device, dtype=torch.float32).eval()

    print("Encoding...")
    with torch.no_grad():
        z = vae.encode(clip).latent_dist.sample()
    print(f"Latent shape: {list(z.shape)}")

    print("Decoding with teacher (PyTorch)...")
    with torch.no_grad():
        teacher_audio = vae.decoder(z)

    print("Loading student...")
    # weights_only=False unpickles arbitrary objects: only point this at a
    # trusted, locally produced checkpoint.
    ckpt = torch.load("research_program/vae_distillation/checkpoints/student_step620000.pt",
                      map_location=device, weights_only=False)
    student = FastOobleckDecoder().to(device).eval()
    student.load_state_dict(ckpt["student_state_dict"])
    print("Decoding with student (PyTorch)...")
    with torch.no_grad():
        student_audio = student(z)

    # The TRT contexts hold raw pointers into `lat` and into each engine's
    # output buffer; all of them stay referenced until main() returns.
    stream = pg_cuda.Stream()
    lat = z.float().contiguous()

    print("Decoding with teacher (TRT FP16)...")
    teacher_trt_audio, teacher_trt_engine, teacher_trt_ctx, teacher_trt_buf = _load_trt_decoder(
        "trt_engines/vae_decode_fp16_240s/vae_decode_fp16_240s.engine", lat, stream, device)
    print("Decoding with student (TRT FP16)...")
    trt_audio, student_trt_engine, student_trt_ctx, student_trt_buf = _load_trt_decoder(
        "trt_engines/dreamvae_decode_fp16_240s/dreamvae_decode_fp16_240s.engine", lat, stream, device)

    # Trim everything to a common length.
    min_len = min(a.shape[-1] for a in (teacher_audio, student_audio, teacher_trt_audio, trt_audio, clip))
    teacher_audio = teacher_audio[..., :min_len]
    student_audio = student_audio[..., :min_len]
    teacher_trt_audio = teacher_trt_audio[..., :min_len]
    trt_audio = trt_audio[..., :min_len]
    original = clip[..., :min_len]

    trials = 20
    duration_s = min_len / sample_rate
    print(f"\nSpeed benchmark ({duration_s:.1f}s audio, {trials} trials)...")
    with torch.no_grad():
        avg = _benchmark_ms(
            {
                "teacher_pt": lambda: vae.decoder(z),
                "student_pt": lambda: student(z),
                "teacher_trt": lambda: _trt_enqueue(teacher_trt_ctx, stream),
                "student_trt": lambda: _trt_enqueue(student_trt_ctx, stream),
            },
            trials=trials,
            warmup=3,
        )
    avg_tp, avg_sp = avg["teacher_pt"], avg["student_pt"]
    avg_tt, avg_st = avg["teacher_trt"], avg["student_trt"]
    print(f" Teacher (PyTorch): {avg_tp:.1f} ms")
    print(f" Teacher (TRT FP16): {avg_tt:.1f} ms ({avg_tp/avg_tt:.2f}x vs teacher PT)")
    print(f" Student (PyTorch): {avg_sp:.1f} ms ({avg_tp/avg_sp:.2f}x vs teacher PT)")
    print(f" Student (TRT FP16): {avg_st:.1f} ms ({avg_tp/avg_st:.2f}x vs teacher PT)")

    print(f"\n{'='*60}")
    print("SNR comparisons (higher = closer match):")
    print(f"{'='*60}")
    comparisons = [
        ("Teacher (PT) vs original:", original, teacher_audio),
        ("Teacher (TRT) vs original:", original, teacher_trt_audio),
        ("Student (PT) vs original:", original, student_audio),
        ("Student (TRT) vs original:", original, trt_audio),
        ("Student (PT) vs teacher (PT):", teacher_audio, student_audio),
        ("Student (TRT) vs teacher (PT):", teacher_audio, trt_audio),
        ("Teacher TRT vs teacher PT:", teacher_audio, teacher_trt_audio),
        ("Student TRT vs student PT:", student_audio, trt_audio),
    ]
    for label, ref, gen in comparisons:
        print(f" {label} {snr(ref, gen):.1f} dB")
    print()

    outputs = [
        ("original", original),
        ("teacher_pytorch", teacher_audio),
        ("teacher_trt_fp16", teacher_trt_audio),
        ("student_pytorch", student_audio),
        ("student_trt_fp16", trt_audio),
    ]
    for name, audio in outputs:
        path = out_dir / f"{name}.wav"
        sf.write(str(path), audio[0].cpu().numpy().T, sample_rate)
        print(f" Saved: {path}")
    print(f"\nListen and compare in {out_dir}")


if __name__ == "__main__":
    main()