# Virtual-Cloths-TryOn / CatVTON_model.py

from typing import Union

import PIL.Image
import torch
from diffusers.utils.torch_utils import randn_tensor
from tqdm import tqdm

from ddpm import DDPMSampler
from utils import (
    check_inputs_maskfree,
    compute_vae_encodings,
    get_time_embedding,
    numpy_to_pil,
    prepare_image,
)


class CatVTONPix2PixPipeline:
    def __init__(
        self,
        weight_dtype=torch.float32,
        device='cpu',
        compile=False,
        skip_safety_check=True,
        use_tf32=True,
        models=None,
    ):
        self.device = device
        self.weight_dtype = weight_dtype
        self.skip_safety_check = skip_safety_check
        self.models = models or {}  # avoid sharing a mutable default dict between instances
        self.generator = torch.Generator(device=device)
        self.noise_scheduler = DDPMSampler(generator=self.generator)
        self.encoder = self.models.get('encoder', None)
        self.decoder = self.models.get('decoder', None)
        self.unet = self.models.get('diffusion', None)
        # Enable TF32 for faster matmuls on Ampere GPUs (A100 and RTX 30 series).
        if use_tf32:
            torch.set_float32_matmul_precision("high")
            torch.backends.cuda.matmul.allow_tf32 = True
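
    # The pipeline expects three pre-loaded modules in `models` (keys taken from
    # the lookups above):
    #   'encoder'   - VAE encoder mapping images to latents
    #   'decoder'   - VAE decoder mapping latents back to images
    #   'diffusion' - UNet predicting the noise residual from the
    #                 channel-concatenated latents and a time embedding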
    @torch.no_grad()
    def __call__(
        self,
        image: Union[PIL.Image.Image, torch.Tensor],
        condition_image: Union[PIL.Image.Image, torch.Tensor],
        num_inference_steps: int = 50,
        guidance_scale: float = 2.5,
        height: int = 1024,
        width: int = 768,
        generator=None,
        eta=1.0,
        **kwargs
    ):
        concat_dim = -1  # concatenate person and garment latents along the width axis

        # Validate inputs and convert them to normalized tensors
        image, condition_image = check_inputs_maskfree(image, condition_image, width, height)
        image = prepare_image(image).to(self.device, dtype=self.weight_dtype)
        condition_image = prepare_image(condition_image).to(self.device, dtype=self.weight_dtype)

        # Encode both images into VAE latent space
        image_latent = compute_vae_encodings(image, self.encoder)
        condition_latent = compute_vae_encodings(condition_image, self.encoder)
        del image, condition_image

        # Concatenate person and garment latents side by side
        condition_latent_concat = torch.cat([image_latent, condition_latent], dim=concat_dim)
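        # Shape check (assuming a standard SD VAE with 8x downsampling and 4 latent
        # channels): each latent is [B, 4, H/8, W/8], so the width-concatenated
        # tensor is [B, 4, H/8, 2*W/8].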

        # Draw the initial Gaussian noise in the concatenated latent shape
        latents = randn_tensor(
            condition_latent_concat.shape,
            generator=generator,
            device=condition_latent_concat.device,
            dtype=self.weight_dtype,
        )

        # Prepare timesteps and noise the latents to the first (largest) timestep
        self.noise_scheduler.set_inference_timesteps(num_inference_steps)
        timesteps = self.noise_scheduler.timesteps
        latents = self.noise_scheduler.add_noise(latents, timesteps[0])

        # Classifier-free guidance: the unconditional branch zeroes out the garment latent
        if do_classifier_free_guidance := (guidance_scale > 1.0):
            condition_latent_concat = torch.cat(
                [
                    torch.cat([image_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
                    condition_latent_concat,
                ]
            )
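
        # With guidance enabled, the conditioning batch is [uncond, cond]: row 0 pairs
        # the person latent with a zeroed garment latent, row 1 with the real garment.
        # Both rows go through the UNet in a single forward pass per step.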

        with tqdm(total=num_inference_steps) as progress_bar:
            for t in timesteps:
                # Duplicate the latents when doing classifier-free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                # Channel-concatenate the noisy latents with the conditioning latents
                p2p_latent_model_input = torch.cat([latent_model_input, condition_latent_concat], dim=1)
                # Predict the noise residual
                timestep = t.repeat(p2p_latent_model_input.shape[0])
                time_embedding = get_time_embedding(timestep).to(self.device, dtype=self.weight_dtype)
                noise_pred = self.unet(p2p_latent_model_input, time_embedding)
                # Perform guidance (there is no text conditioning here, only the garment latent)
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
                # Compute the previous noisy sample x_t -> x_{t-1}
                latents = self.noise_scheduler.step(t, latents, noise_pred)
                progress_bar.update()

        # Keep only the person half of the width-concatenated latents
        latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
        # Decode latents back to image space
        image = self.decoder(latents.to(self.device, dtype=self.weight_dtype))
        # Map from [-1, 1] to [0, 1]
        image = (image / 2 + 0.5).clamp(0, 1)
        # Cast to float32: negligible overhead and compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = numpy_to_pil(image)
        return image
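

# Minimal usage sketch (illustrative only). `load_encoder`, `load_decoder`, and
# `load_diffusion` are hypothetical placeholders for this repo's actual model
# loading code; the call signature and `models` keys match the pipeline above.
if __name__ == "__main__":
    from PIL import Image

    models = {
        'encoder': load_encoder(),      # hypothetical: returns the VAE encoder
        'decoder': load_decoder(),      # hypothetical: returns the VAE decoder
        'diffusion': load_diffusion(),  # hypothetical: returns the noise-prediction UNet
    }
    pipeline = CatVTONPix2PixPipeline(
        weight_dtype=torch.float32,
        device='cuda' if torch.cuda.is_available() else 'cpu',
        models=models,
    )
    person = Image.open('person.jpg').convert('RGB')
    garment = Image.open('garment.jpg').convert('RGB')
    result = pipeline(person, garment, num_inference_steps=50, guidance_scale=2.5)
    result[0].save('tryon_result.png')  # the pipeline returns a list of PIL images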