#!/usr/bin/env python3
"""Compare TRT engine output against the PyTorch teacher and student decoders.

Encodes real audio through the teacher VAE, decodes the latents with four
backends (teacher PyTorch, student PyTorch, teacher TRT FP16, student TRT
FP16), benchmarks them, prints SNR comparisons and saves WAVs.
"""

import math
import time
from pathlib import Path

import torch
import torch.nn as nn
from torch.nn.utils import weight_norm

# Sample rate (Hz) the input fixture must have; also used for the output WAVs.
SAMPLE_RATE = 48000


# -- Student model (must match training script) --
# NOTE: attribute names below are state_dict keys of the checkpoint loaded in
# main(); renaming any of them breaks load_state_dict().


class Snake1d(nn.Module):
    """Element-wise activation ``x + sin(a * x)**2 / (b + 1e-9)``.

    ``alpha`` and ``beta`` are learnable tensors of shape
    ``(1, hidden_dim, 1)`` initialised to zero. With ``logscale=True`` they
    are exponentiated before use, so the initial ``a`` and ``b`` are both 1.
    """

    def __init__(self, hidden_dim, logscale=True):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
        self.logscale = logscale

    def forward(self, x):
        shape = x.shape
        a = self.alpha if not self.logscale else torch.exp(self.alpha)
        b = self.beta if not self.logscale else torch.exp(self.beta)
        # Flatten trailing dims so the (1, C, 1) parameters broadcast.
        x = x.reshape(shape[0], shape[1], -1)
        # The 1e-9 keeps the reciprocal finite when b is zero.
        x = x + (b + 1e-9).reciprocal() * torch.sin(a * x).pow(2)
        return x.reshape(shape)


class FastResidualUnit(nn.Module):
    """Residual unit: Snake1d -> dilated k=7 conv -> Snake1d -> k=1 conv.

    The result is added to the input. Padding is chosen so the dilated
    convolution preserves length.
    """

    def __init__(self, dim, dilation=1):
        super().__init__()
        pad = ((7 - 1) * dilation) // 2
        self.snake1 = Snake1d(dim)
        self.conv1 = weight_norm(
            nn.Conv1d(dim, dim, 7, dilation=dilation, padding=pad)
        )
        self.snake2 = Snake1d(dim)
        self.conv2 = weight_norm(nn.Conv1d(dim, dim, 1))

    def forward(self, x):
        h = self.conv1(self.snake1(x))
        h = self.conv2(self.snake2(h))
        # If the conv path came out shorter, centre-crop the skip path to fit.
        pad = (x.shape[-1] - h.shape[-1]) // 2
        if pad > 0:
            x = x[..., pad:-pad]
        return x + h


class FastDecoderBlock(nn.Module):
    """Upsampling block: Snake1d, transposed conv, then two residual units.

    The transposed convolution uses kernel ``2 * stride`` and padding
    ``ceil(stride / 2)``; the residual units use dilations 1 and 3.
    """

    def __init__(self, in_dim, out_dim, stride=1):
        super().__init__()
        self.snake1 = Snake1d(in_dim)
        self.conv_t = weight_norm(
            nn.ConvTranspose1d(
                in_dim,
                out_dim,
                2 * stride,
                stride=stride,
                padding=math.ceil(stride / 2),
            )
        )
        self.res1 = FastResidualUnit(out_dim, 1)
        self.res2 = FastResidualUnit(out_dim, 3)

    def forward(self, x):
        x = self.snake1(x)
        x = self.conv_t(x)
        x = self.res1(x)
        x = self.res2(x)
        return x


class FastOobleckDecoder(nn.Module):
    """Student decoder mapping latents to audio.

    Args:
        channels: Base channel width; each stage uses a multiple of it.
        input_channels: Number of latent channels.
        audio_channels: Number of output audio channels.
        upsampling_ratios: Stride of each decoder block, in order.
            Defaults to ``[10, 6, 4, 4, 2]``.
        channel_multiples: Width multipliers, indexed from the output side.
            Defaults to ``[1, 2, 4, 8, 8]``.
    """

    def __init__(
        self,
        channels=128,
        input_channels=64,
        audio_channels=2,
        upsampling_ratios=None,
        channel_multiples=None,
    ):
        super().__init__()
        upsampling_ratios = upsampling_ratios or [10, 6, 4, 4, 2]
        channel_multiples = channel_multiples or [1, 2, 4, 8, 8]
        # Leading 1 so the last block ends at exactly `channels` wide.
        cm = [1] + channel_multiples
        n_stages = len(upsampling_ratios)
        self.conv1 = weight_norm(
            nn.Conv1d(input_channels, channels * cm[-1], 7, padding=3)
        )
        # Walk the multipliers from widest to narrowest.
        blocks = [
            FastDecoderBlock(
                channels * cm[n_stages - i],
                channels * cm[n_stages - i - 1],
                s,
            )
            for i, s in enumerate(upsampling_ratios)
        ]
        self.blocks = nn.ModuleList(blocks)
        self.final_snake = Snake1d(channels)
        self.conv2 = weight_norm(
            nn.Conv1d(channels, audio_channels, 7, padding=3, bias=False)
        )

    def forward(self, z):
        x = self.conv1(z)
        for b in self.blocks:
            x = b(x)
        return self.conv2(self.final_snake(x))


def snr(ref, gen):
    """Return the signal-to-noise ratio of ``gen`` against ``ref`` in dB.

    Both tensors are truncated along the last dimension to the shorter
    length. The result is a Python float; a 1e-10 term in the denominator
    keeps it finite when the two signals are identical.
    """
    min_len = min(ref.shape[-1], gen.shape[-1])
    ref, gen = ref[..., :min_len], gen[..., :min_len]
    noise = ref - gen
    return 10 * torch.log10(
        (ref ** 2).mean() / ((noise ** 2).mean() + 1e-10)
    ).item()


def _timed(run, sync):
    """Return the wall-clock seconds taken by ``run()`` followed by ``sync()``."""
    # Drain pending GPU work so it is not billed to this measurement.
    torch.cuda.synchronize()
    start = time.perf_counter()
    run()
    sync()
    return time.perf_counter() - start


def _mean_ms(samples):
    """Return the mean of ``samples`` (seconds) in milliseconds."""
    return sum(samples) / len(samples) * 1000


def main():
    """Run the decode comparison, the speed benchmark, and write the WAVs.

    Raises:
        ValueError: If the input fixture is not sampled at ``SAMPLE_RATE``.
        RuntimeError: If a TRT engine rejects the latent shape or fails to
            enqueue.
    """
    # Deferred like the diffusers/polygraphy imports below, so the model
    # classes and snr() can be imported without the audio I/O dependency.
    import soundfile as sf

    device = "cuda"
    out_dir = Path("research_program/vae_distillation/results/trt_verify_60s")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load audio
    audio_path = "tests/fixtures/techno.wav"
    data, sr = sf.read(audio_path, dtype="float32")
    if sr != SAMPLE_RATE:
        # Explicit raise: an assert would be stripped under `python -O`.
        raise ValueError(
            f"{audio_path}: expected {SAMPLE_RATE} Hz audio, got {sr} Hz"
        )
    # [channels, samples]; NOTE(review): assumes a stereo fixture.
    waveform = torch.tensor(data, dtype=torch.float32).T

    # Take 60s clip (or full file if shorter)
    samples_60s = SAMPLE_RATE * 60
    clip = waveform[:, :samples_60s].unsqueeze(0).to(device)
    peak = clip.abs().max()
    if peak > 1e-6:  # skip normalisation for (near-)silent input
        clip = clip / peak

    # Load teacher VAE
    print("Loading teacher VAE...")
    from diffusers import AutoencoderOobleck

    vae = AutoencoderOobleck.from_pretrained(
        "ACE-Step/Ace-Step1.5", subfolder="vae"
    )
    vae = vae.to(device, dtype=torch.float32).eval()

    # Encode
    print("Encoding...")
    with torch.no_grad():
        z = vae.encode(clip).latent_dist.sample()
    print(f"Latent shape: {list(z.shape)}")

    # Decode with teacher (PyTorch)
    print("Decoding with teacher (PyTorch)...")
    with torch.no_grad():
        teacher_audio = vae.decoder(z)

    # Load student (PyTorch)
    print("Loading student...")
    # SECURITY: weights_only=False unpickles arbitrary objects. Only load
    # checkpoints from a trusted source.
    ckpt = torch.load(
        "research_program/vae_distillation/checkpoints/student_step620000.pt",
        map_location=device,
        weights_only=False,
    )
    student = FastOobleckDecoder().to(device).eval()
    student.load_state_dict(ckpt["student_state_dict"])

    print("Decoding with student (PyTorch)...")
    with torch.no_grad():
        student_audio = student(z)

    # Load TRT engines
    from polygraphy.backend.common import bytes_from_path
    from polygraphy.backend.trt import engine_from_bytes
    from polygraphy import cuda as pg_cuda

    stream = pg_cuda.Stream()
    # Must stay alive: its device pointer is bound to every TRT context.
    lat = z.float().contiguous()

    def trt_decode(engine_path_str):
        """Run one engine on ``lat``.

        Returns ``(audio_copy, engine, ctx, out_buf)``. The caller must keep
        ``out_buf`` alive for as long as ``ctx`` may execute, because ``ctx``
        holds its raw device pointer.
        """
        engine = engine_from_bytes(bytes_from_path(engine_path_str))
        ctx = engine.create_execution_context()
        if not ctx.set_input_shape("latents", tuple(lat.shape)):
            raise RuntimeError(
                f"{engine_path_str}: engine rejected latent shape "
                f"{tuple(lat.shape)}"
            )
        ctx.set_tensor_address("latents", lat.data_ptr())
        out_shape = tuple(ctx.get_tensor_shape("audio"))
        buf = torch.empty(out_shape, dtype=torch.float32, device=device)
        ctx.set_tensor_address("audio", buf.data_ptr())
        # The latents were produced on PyTorch's stream; make sure they are
        # complete before TRT reads them on its own stream.
        torch.cuda.synchronize()
        if not ctx.execute_async_v3(stream.ptr):
            raise RuntimeError(f"{engine_path_str}: TRT enqueue failed")
        stream.synchronize()
        # Clone so later benchmark runs, which overwrite buf, cannot touch
        # the result used for the SNR figures and the WAV.
        return buf.clone(), engine, ctx, buf

    # The engine and buffer handles look unused but are held on purpose so
    # that neither is released while its context is still being executed.
    print("Decoding with teacher (TRT FP16)...")
    teacher_trt_path = (
        "trt_engines/vae_decode_fp16_240s/vae_decode_fp16_240s.engine"
    )
    (
        teacher_trt_audio,
        teacher_trt_engine,
        teacher_trt_ctx,
        teacher_trt_buf,
    ) = trt_decode(teacher_trt_path)

    print("Decoding with student (TRT FP16)...")
    student_trt_path = (
        "trt_engines/dreamvae_decode_fp16_240s/dreamvae_decode_fp16_240s.engine"
    )
    (
        trt_audio,
        student_trt_engine,
        student_trt_ctx,
        student_trt_buf,
    ) = trt_decode(student_trt_path)

    # Trim to common length
    min_len = min(
        teacher_audio.shape[-1],
        student_audio.shape[-1],
        teacher_trt_audio.shape[-1],
        trt_audio.shape[-1],
        clip.shape[-1],
    )
    teacher_audio = teacher_audio[..., :min_len]
    student_audio = student_audio[..., :min_len]
    teacher_trt_audio = teacher_trt_audio[..., :min_len]
    trt_audio = trt_audio[..., :min_len]
    original = clip[..., :min_len]

    # Speed benchmark (20 trials each)
    n_warmup = 3
    n_trials = 20
    duration_s = min_len / SAMPLE_RATE
    print(f"\nSpeed benchmark ({duration_s:.1f}s audio, 20 trials)...")

    # Warmup
    with torch.no_grad():
        for _ in range(n_warmup):
            _ = vae.decoder(z)
            _ = student(z)
            teacher_trt_ctx.execute_async_v3(stream.ptr)
            stream.synchronize()
            student_trt_ctx.execute_async_v3(stream.ptr)
            stream.synchronize()

    teacher_pt_times = []
    student_pt_times = []
    teacher_trt_times = []
    student_trt_times = []

    # (result list, work to time, how to wait for it). The backends are
    # interleaved inside each trial, so drift affects all of them alike.
    backends = [
        (teacher_pt_times, lambda: vae.decoder(z), torch.cuda.synchronize),
        (student_pt_times, lambda: student(z), torch.cuda.synchronize),
        (
            teacher_trt_times,
            lambda: teacher_trt_ctx.execute_async_v3(stream.ptr),
            stream.synchronize,
        ),
        (
            student_trt_times,
            lambda: student_trt_ctx.execute_async_v3(stream.ptr),
            stream.synchronize,
        ),
    ]
    with torch.no_grad():
        for _ in range(n_trials):
            for times, run, sync in backends:
                times.append(_timed(run, sync))

    avg_tp = _mean_ms(teacher_pt_times)
    avg_sp = _mean_ms(student_pt_times)
    avg_tt = _mean_ms(teacher_trt_times)
    avg_st = _mean_ms(student_trt_times)

    print(f" Teacher (PyTorch): {avg_tp:.1f} ms")
    print(f" Teacher (TRT FP16): {avg_tt:.1f} ms ({avg_tp/avg_tt:.2f}x vs teacher PT)")
    print(f" Student (PyTorch): {avg_sp:.1f} ms ({avg_tp/avg_sp:.2f}x vs teacher PT)")
    print(f" Student (TRT FP16): {avg_st:.1f} ms ({avg_tp/avg_st:.2f}x vs teacher PT)")

    # Comparisons
    print(f"\n{'='*60}")
    print("SNR comparisons (higher = closer match):")
    print(f"{'='*60}")
    print(f" Teacher (PT) vs original: {snr(original, teacher_audio):.1f} dB")
    print(f" Teacher (TRT) vs original: {snr(original, teacher_trt_audio):.1f} dB")
    print(f" Student (PT) vs original: {snr(original, student_audio):.1f} dB")
    print(f" Student (TRT) vs original: {snr(original, trt_audio):.1f} dB")
    print(f" Student (PT) vs teacher (PT): {snr(teacher_audio, student_audio):.1f} dB")
    print(f" Student (TRT) vs teacher (PT): {snr(teacher_audio, trt_audio):.1f} dB")
    print(f" Teacher TRT vs teacher PT: {snr(teacher_audio, teacher_trt_audio):.1f} dB")
    print(f" Student TRT vs student PT: {snr(student_audio, trt_audio):.1f} dB")
    print()

    # Save WAVs
    for name, audio in [
        ("original", original),
        ("teacher_pytorch", teacher_audio),
        ("teacher_trt_fp16", teacher_trt_audio),
        ("student_pytorch", student_audio),
        ("student_trt_fp16", trt_audio),
    ]:
        path = out_dir / f"{name}.wav"
        # Drop the batch dim and transpose to the (frames, channels) layout
        # passed to sf.write.
        sf.write(str(path), audio[0].cpu().numpy().T, SAMPLE_RATE)
        print(f" Saved: {path}")

    print(f"\nListen and compare in {out_dir}")


if __name__ == "__main__":
    main()