DreamVAE / scripts /verify_trt_audio.py
ryanontheinside's picture
DreamVAE initial release
53e74f7
#!/usr/bin/env python3
"""Compare TRT student engine output against PyTorch teacher and student.
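Encodes real audio through the teacher VAE, decodes the latents with all four
backends (teacher and student, each in PyTorch and TensorRT FP16), benchmarks
decode speed, prints SNR comparisons and saves WAVs.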
"""
import math
import time
from pathlib import Path
import soundfile as sf
import torch
import torch.nn as nn
from torch.nn.utils import weight_norm
# -- Student model (must match training script) --
class Snake1d(nn.Module):
    """Snake activation with learnable per-channel parameters.

    Computes ``x + sin(alpha * x) ** 2 / (beta + 1e-9)`` where ``alpha`` and
    ``beta`` have shape ``(1, hidden_dim, 1)`` and broadcast over dim 1 of the
    input. With ``logscale=True`` the stored parameters are exponentiated
    before use, so the zero initialisation yields ``alpha = beta = 1``.
    """

    def __init__(self, hidden_dim, logscale=True):
        super().__init__()
        param_shape = (1, hidden_dim, 1)
        # Attribute names are state_dict keys; keep them as-is.
        self.alpha = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))
        self.logscale = logscale

    def forward(self, x):
        """Apply the activation and return a tensor with the input's shape."""
        if self.logscale:
            freq = torch.exp(self.alpha)
            mag = torch.exp(self.beta)
        else:
            freq = self.alpha
            mag = self.beta

        in_shape = x.shape
        # Collapse every trailing dim so the (1, C, 1) parameters broadcast.
        flat = x.reshape(in_shape[0], in_shape[1], -1)
        inv_mag = (mag + 1e-9).reciprocal()
        activated = flat + inv_mag * torch.sin(freq * flat).pow(2)
        return activated.reshape(in_shape)
class FastResidualUnit(nn.Module):
    """Residual unit: Snake -> dilated 7-tap conv -> Snake -> 1x1 conv, plus skip.

    The dilated convolution is padded by ``(7 - 1) * dilation // 2`` on each
    side, so with stride 1 its output length equals its input length.
    """

    def __init__(self, dim, dilation=1):
        super().__init__()
        kernel_size = 7
        same_padding = (kernel_size - 1) * dilation // 2
        # Submodule names and order are part of the checkpoint layout.
        self.snake1 = Snake1d(dim)
        self.conv1 = weight_norm(
            nn.Conv1d(dim, dim, kernel_size, dilation=dilation, padding=same_padding)
        )
        self.snake2 = Snake1d(dim)
        self.conv2 = weight_norm(nn.Conv1d(dim, dim, 1))

    def forward(self, x):
        """Return ``x`` plus the residual branch, centre-cropping ``x`` if needed."""
        residual = self.conv2(self.snake2(self.conv1(self.snake1(x))))
        # If the branch came out shorter than the input, crop the skip path
        # symmetrically so both operands line up.
        trim = (x.shape[-1] - residual.shape[-1]) // 2
        skip = x[..., trim:-trim] if trim > 0 else x
        return skip + residual
class FastDecoderBlock(nn.Module):
    """Upsampling stage: Snake -> transposed conv -> two residual units.

    The transposed convolution uses kernel ``2 * stride``, the given
    ``stride`` and padding ``ceil(stride / 2)``; the residual units use
    dilations 1 and 3.
    """

    def __init__(self, in_dim, out_dim, stride=1):
        super().__init__()
        self.snake1 = Snake1d(in_dim)
        upsample = nn.ConvTranspose1d(
            in_dim,
            out_dim,
            2 * stride,
            stride=stride,
            padding=math.ceil(stride / 2),
        )
        self.conv_t = weight_norm(upsample)
        self.res1 = FastResidualUnit(out_dim, 1)
        self.res2 = FastResidualUnit(out_dim, 3)

    def forward(self, x):
        """Run the input through every stage of the block in order."""
        for stage in (self.snake1, self.conv_t, self.res1, self.res2):
            x = stage(x)
        return x
class FastOobleckDecoder(nn.Module):
    """Student decoder mapping latents to audio through stacked upsampling blocks.

    Layout: a 7-tap input conv to ``channels * channel_multiples[-1]`` features,
    one ``FastDecoderBlock`` per entry of ``upsampling_ratios`` (channel widths
    walk the multiples list from last to first, ending at ``channels``), a
    final Snake activation and a bias-free 7-tap conv to ``audio_channels``.

    ``upsampling_ratios`` defaults to ``[10, 6, 4, 4, 2]`` and
    ``channel_multiples`` to ``[1, 2, 4, 8, 8]`` when omitted or empty.
    """

    def __init__(self, channels=128, input_channels=64, audio_channels=2, upsampling_ratios=None, channel_multiples=None):
        super().__init__()
        if not upsampling_ratios:
            upsampling_ratios = [10, 6, 4, 4, 2]
        if not channel_multiples:
            channel_multiples = [1, 2, 4, 8, 8]

        # Width per stage boundary; the leading 1 makes the last block end at `channels`.
        widths = [channels * mult for mult in [1] + channel_multiples]
        n_stages = len(upsampling_ratios)

        self.conv1 = weight_norm(nn.Conv1d(input_channels, widths[-1], 7, padding=3))
        self.blocks = nn.ModuleList(
            [
                FastDecoderBlock(widths[n_stages - idx], widths[n_stages - idx - 1], ratio)
                for idx, ratio in enumerate(upsampling_ratios)
            ]
        )
        self.final_snake = Snake1d(channels)
        self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, 7, padding=3, bias=False))

    def forward(self, z):
        """Decode latents ``z`` and return the output of the final conv."""
        hidden = self.conv1(z)
        for block in self.blocks:
            hidden = block(hidden)
        hidden = self.final_snake(hidden)
        return self.conv2(hidden)
def snr(ref, gen):
    """Return the signal-to-noise ratio of ``gen`` against ``ref`` in dB.

    Both tensors are cropped along the last dimension to the shorter length.
    The result is ``10 * log10(mean(ref**2) / (mean((ref - gen)**2) + 1e-10))``
    as a Python float; the epsilon keeps identical signals finite.
    """
    common = min(ref.shape[-1], gen.shape[-1])
    ref_cropped = ref[..., :common]
    gen_cropped = gen[..., :common]

    signal_power = ref_cropped.pow(2).mean()
    noise_power = (ref_cropped - gen_cropped).pow(2).mean()
    ratio = signal_power / (noise_power + 1e-10)
    return 10 * torch.log10(ratio).item()
def _load_normalized_clip(audio_path, sample_rate, max_seconds, device):
    """Read an audio file and return a peak-normalised ``[1, channels, samples]`` clip.

    Args:
        audio_path: Path of the file handed to ``soundfile.read``.
        sample_rate: Required sample rate in Hz.
        max_seconds: Clip length cap; shorter files are used in full.
        device: Torch device the clip is moved to.

    Raises:
        ValueError: If the file's sample rate differs from ``sample_rate`` or
            the decoded array is not 2-D (e.g. a mono file).
    """
    data, sr = sf.read(audio_path, dtype="float32")
    # Explicit raises instead of `assert`, which is stripped under `python -O`.
    if sr != sample_rate:
        raise ValueError(f"{audio_path}: expected {sample_rate} Hz audio, got {sr} Hz")
    if data.ndim != 2:
        raise ValueError(
            f"{audio_path}: expected multi-channel audio, got array of shape {data.shape}"
        )
    waveform = torch.tensor(data, dtype=torch.float32).T  # [channels, samples]
    clip = waveform[:, : sample_rate * max_seconds].unsqueeze(0).to(device)
    peak = clip.abs().max()
    # Skip normalisation for (near-)silent input to avoid dividing by ~0.
    if peak > 1e-6:
        clip = clip / peak
    return clip


def main():
    """Decode one clip with four backends, benchmark them and report SNRs.

    Steps: load and normalise up to 60 s of the fixture audio, encode it with
    the teacher VAE, decode the latents with the teacher and student in
    PyTorch and with both TensorRT engines, time each decoder, print SNR
    comparisons and write every rendition as a WAV into the results directory.

    Raises:
        RuntimeError: If CUDA is unavailable or a TensorRT launch fails.
        ValueError: If the fixture audio has the wrong sample rate or layout.
    """
    device = "cuda"
    sample_rate = 48000
    clip_seconds = 60
    warmup_trials = 3
    bench_trials = 20

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is required: every backend in this script runs on the GPU")

    out_dir = Path("research_program/vae_distillation/results/trt_verify_60s")
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load audio: 60s clip (or full file if shorter)
    clip = _load_normalized_clip("tests/fixtures/techno.wav", sample_rate, clip_seconds, device)

    # Load teacher VAE
    print("Loading teacher VAE...")
    from diffusers import AutoencoderOobleck
    vae = AutoencoderOobleck.from_pretrained("ACE-Step/Ace-Step1.5", subfolder="vae")
    vae = vae.to(device, dtype=torch.float32).eval()

    # Encode once; every decoder below consumes the same latents.
    print("Encoding...")
    with torch.no_grad():
        z = vae.encode(clip).latent_dist.sample()
    print(f"Latent shape: {list(z.shape)}")

    # Decode with teacher (PyTorch)
    print("Decoding with teacher (PyTorch)...")
    with torch.no_grad():
        teacher_audio = vae.decoder(z)

    # Load student (PyTorch)
    print("Loading student...")
    # NOTE(security): weights_only=False unpickles arbitrary objects; only use
    # with checkpoints from a trusted source.
    ckpt = torch.load("research_program/vae_distillation/checkpoints/student_step620000.pt",
                      map_location=device, weights_only=False)
    student = FastOobleckDecoder().to(device).eval()
    student.load_state_dict(ckpt["student_state_dict"])
    print("Decoding with student (PyTorch)...")
    with torch.no_grad():
        student_audio = student(z)

    # Load TRT engines
    from polygraphy.backend.common import bytes_from_path
    from polygraphy.backend.trt import engine_from_bytes
    from polygraphy import cuda as pg_cuda
    stream = pg_cuda.Stream()
    lat = z.float().contiguous()
    # Each execution context stores the raw device address of its output
    # buffer and is re-run during the benchmark, so the buffers must stay
    # referenced; otherwise TRT would write into memory PyTorch has reclaimed.
    trt_output_buffers = []

    def run_trt(ctx):
        """Launch one inference on the shared stream and wait for it to finish."""
        if not ctx.execute_async_v3(stream.ptr):
            raise RuntimeError("TensorRT execute_async_v3 failed to enqueue inference")
        stream.synchronize()

    def trt_decode(engine_path_str):
        """Load an engine, decode ``lat`` once; return (audio copy, engine, context)."""
        engine = engine_from_bytes(bytes_from_path(engine_path_str))
        ctx = engine.create_execution_context()
        ctx.set_input_shape("latents", tuple(lat.shape))
        ctx.set_tensor_address("latents", lat.data_ptr())
        out_shape = tuple(ctx.get_tensor_shape("audio"))
        buf = torch.empty(out_shape, dtype=torch.float32, device=device)
        ctx.set_tensor_address("audio", buf.data_ptr())
        trt_output_buffers.append(buf)
        # Make sure PyTorch has finished producing the latents before TRT
        # reads them from a different CUDA stream.
        torch.cuda.synchronize()
        run_trt(ctx)
        return buf.clone(), engine, ctx

    print("Decoding with teacher (TRT FP16)...")
    teacher_trt_path = "trt_engines/vae_decode_fp16_240s/vae_decode_fp16_240s.engine"
    teacher_trt_audio, teacher_trt_engine, teacher_trt_ctx = trt_decode(teacher_trt_path)
    print("Decoding with student (TRT FP16)...")
    student_trt_path = "trt_engines/dreamvae_decode_fp16_240s/dreamvae_decode_fp16_240s.engine"
    trt_audio, student_trt_engine, student_trt_ctx = trt_decode(student_trt_path)

    # Trim to common length
    min_len = min(teacher_audio.shape[-1], student_audio.shape[-1],
                  teacher_trt_audio.shape[-1], trt_audio.shape[-1], clip.shape[-1])
    teacher_audio = teacher_audio[..., :min_len]
    student_audio = student_audio[..., :min_len]
    teacher_trt_audio = teacher_trt_audio[..., :min_len]
    trt_audio = trt_audio[..., :min_len]
    original = clip[..., :min_len]

    # Speed benchmark
    duration_s = min_len / sample_rate
    print(f"\nSpeed benchmark ({duration_s:.1f}s audio, {bench_trials} trials)...")

    def torch_runner(model):
        """Wrap a PyTorch decoder as a blocking zero-argument callable."""
        def run():
            model(z)
            torch.cuda.synchronize()
        return run

    # Insertion order is the per-trial execution order.
    backends = {
        "teacher_pt": torch_runner(vae.decoder),
        "student_pt": torch_runner(student),
        "teacher_trt": lambda: run_trt(teacher_trt_ctx),
        "student_trt": lambda: run_trt(student_trt_ctx),
    }
    timings = {name: [] for name in backends}
    with torch.no_grad():
        # Warmup
        for _ in range(warmup_trials):
            for run in backends.values():
                run()
        for _ in range(bench_trials):
            for name, run in backends.items():
                torch.cuda.synchronize()
                # perf_counter is monotonic and high-resolution, unlike time.time().
                t0 = time.perf_counter()
                run()
                timings[name].append(time.perf_counter() - t0)

    avg_ms = {name: sum(samples) / len(samples) * 1000 for name, samples in timings.items()}
    avg_tp = avg_ms["teacher_pt"]
    avg_sp = avg_ms["student_pt"]
    avg_tt = avg_ms["teacher_trt"]
    avg_st = avg_ms["student_trt"]
    print(f" Teacher (PyTorch): {avg_tp:.1f} ms")
    print(f" Teacher (TRT FP16): {avg_tt:.1f} ms ({avg_tp/avg_tt:.2f}x vs teacher PT)")
    print(f" Student (PyTorch): {avg_sp:.1f} ms ({avg_tp/avg_sp:.2f}x vs teacher PT)")
    print(f" Student (TRT FP16): {avg_st:.1f} ms ({avg_tp/avg_st:.2f}x vs teacher PT)")

    # Comparisons
    print(f"\n{'='*60}")
    print("SNR comparisons (higher = closer match):")
    print(f"{'='*60}")
    print(f" Teacher (PT) vs original: {snr(original, teacher_audio):.1f} dB")
    print(f" Teacher (TRT) vs original: {snr(original, teacher_trt_audio):.1f} dB")
    print(f" Student (PT) vs original: {snr(original, student_audio):.1f} dB")
    print(f" Student (TRT) vs original: {snr(original, trt_audio):.1f} dB")
    print(f" Student (PT) vs teacher (PT): {snr(teacher_audio, student_audio):.1f} dB")
    print(f" Student (TRT) vs teacher (PT): {snr(teacher_audio, trt_audio):.1f} dB")
    print(f" Teacher TRT vs teacher PT: {snr(teacher_audio, teacher_trt_audio):.1f} dB")
    print(f" Student TRT vs student PT: {snr(student_audio, trt_audio):.1f} dB")
    print()

    # Save WAVs
    renditions = [
        ("original", original),
        ("teacher_pytorch", teacher_audio),
        ("teacher_trt_fp16", teacher_trt_audio),
        ("student_pytorch", student_audio),
        ("student_trt_fp16", trt_audio),
    ]
    for name, audio in renditions:
        path = out_dir / f"{name}.wav"
        # soundfile expects [samples, channels]; drop the batch dim and transpose.
        sf.write(str(path), audio[0].cpu().numpy().T, sample_rate)
        print(f" Saved: {path}")
    print(f"\nListen and compare in {out_dir}")
# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()