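"""Diffusion-based virtual try-on inpainting pipeline.

`generate` masks the person image, VAE-encodes the masked person and the
garment (condition) image, concatenates the two along the width axis, runs a
DDPM denoising loop with optional classifier-free guidance, and decodes the
person half of the latent back into an image.
"""
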
import math
from typing import List, Union

import numpy as np
import PIL
import torch
from PIL import Image
from tqdm import tqdm

import load_model
from ddpm import DDPMSampler
from utils import check_inputs, prepare_image, prepare_mask_image

WIDTH = 512
HEIGHT = 512
LATENTS_WIDTH = WIDTH // 8
LATENTS_HEIGHT = HEIGHT // 8

def repaint_result(result, person_image, mask_image):
    result, person, mask = np.array(result), np.array(person_image), np.array(mask_image)
    # expand the mask to 3 channels & to 0~1
    mask = np.expand_dims(mask, axis=2)
    mask = mask / 255.0
    # mask for result, ~mask for person
    result_ = result * mask + person * (1 - mask)
    return Image.fromarray(result_.astype(np.uint8))

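# Hedged usage sketch for repaint_result: blend a generated result back onto the
# original person so pixels outside the mask stay untouched. All three images must
# share the same size; the names below are illustrative, not variables from this file.
#   blended = repaint_result(result_image, person_image, agnostic_mask)
#   blended.save("blended.png")
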
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.
    """
    if images.ndim == 3:
        images = images[None, ...]
    images = (images * 255).round().astype("uint8")
    if images.shape[-1] == 1:
        # special case for grayscale (single channel) images
        pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
    else:
        pil_images = [Image.fromarray(image) for image in images]
    return pil_images

def tensor_to_image(tensor: torch.Tensor):
    """
    Converts a torch tensor to a PIL Image.
    """
    assert tensor.dim() == 3, "Input tensor should be 3-dimensional."
    assert tensor.dtype == torch.float32, "Input tensor should be float32."
    assert (
        tensor.min() >= 0 and tensor.max() <= 1
    ), "Input tensor should be in range [0, 1]."
    tensor = tensor.cpu()
    tensor = tensor * 255
    tensor = tensor.permute(1, 2, 0)
    tensor = tensor.numpy().astype(np.uint8)
    image = Image.fromarray(tensor)
    return image

def concat_images(images: List[Image.Image], divider: int = 4, cols: int = 4):
    """
    Concatenates images into a grid with `cols` images per row and a
    `divider`-pixel gap between neighbouring images.
    """
    widths = [image.size[0] for image in images]
    heights = [image.size[1] for image in images]
    total_width = cols * max(widths)
    total_width += divider * (cols - 1)
    # `cols` images per row
    rows = math.ceil(len(images) / cols)
    total_height = max(heights) * rows
    # add divider between rows
    total_height += divider * (rows - 1)
    # all-black canvas
    concat_image = Image.new("RGB", (total_width, total_height), (0, 0, 0))
    x_offset = 0
    y_offset = 0
    for i, image in enumerate(images):
        concat_image.paste(image, (x_offset, y_offset))
        x_offset += image.size[0] + divider
        if (i + 1) % cols == 0:
            x_offset = 0
            y_offset += image.size[1] + divider
    return concat_image

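# For example, concat_images([person, garment, result], divider=4, cols=3) lays the
# three images out on a single row separated by 4-pixel black gaps (the names here
# are illustrative PIL images, not variables defined in this file).
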
def compute_vae_encodings(image_tensor, encoder, device):
    """Encode image using VAE encoder"""
    # Generate random noise for encoding
    encoder_noise = torch.randn(
        (image_tensor.shape[0], 4, image_tensor.shape[2] // 8, image_tensor.shape[3] // 8),
        device=device,
    )
    # Encode using the custom encoder
    latent = encoder(image_tensor, encoder_noise)
    return latent

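# Shape sketch (assuming the 8x-downsampling, 4-channel latent VAE implied by the
# noise shape above): a (1, 3, 512, 512) image tensor maps to a (1, 4, 64, 64)
# latent, with `encoder_noise` supplying the randomness for the encoder's sampling step.
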
def generate(
    image: Union[PIL.Image.Image, torch.Tensor],
    condition_image: Union[PIL.Image.Image, torch.Tensor],
    mask: Union[PIL.Image.Image, torch.Tensor],
    num_inference_steps: int = 50,
    guidance_scale: float = 2.5,
    height: int = 1024,
    width: int = 768,
    models={},
    sampler_name="ddpm",
    seed=None,
    device=None,
    idle_device=None,
    **kwargs,
):
    with torch.no_grad():
        if idle_device:
            to_idle = lambda x: x.to(idle_device)
        else:
            to_idle = lambda x: x
        # Initialize the random number generator according to the seed specified
        generator = torch.Generator(device=device)
        if seed is None:
            generator.seed()
        else:
            generator.manual_seed(seed)
        concat_dim = -1  # concatenate person and garment latents along the last (width) axis
        # Prepare inputs as tensors
        image, condition_image, mask = check_inputs(image, condition_image, mask, width, height)
        # print(f"Input image shape: {image.shape}, condition image shape: {condition_image.shape}, mask shape: {mask.shape}")
        image = prepare_image(image).to(device)
        condition_image = prepare_image(condition_image).to(device)
        mask = prepare_mask_image(mask).to(device)
        print(f"Prepared image shape: {image.shape}, condition image shape: {condition_image.shape}, mask shape: {mask.shape}")
        # Mask image
        masked_image = image * (mask < 0.5)
        print(f"Masked image shape: {masked_image.shape}")
        # VAE encoding
        encoder = models.get('encoder', None)
        if encoder is None:
            raise ValueError("Encoder model not found in models dictionary")
        encoder.to(device)
        masked_latent = compute_vae_encodings(masked_image, encoder, device)
        condition_latent = compute_vae_encodings(condition_image, encoder, device)
        to_idle(encoder)
        print(f"Masked latent shape: {masked_latent.shape}, condition latent shape: {condition_latent.shape}")
        # Concatenate latents
        masked_latent_concat = torch.cat([masked_latent, condition_latent], dim=concat_dim)
        print(f"Masked person latent + garment latent: {masked_latent_concat.shape}")
        mask_latent = torch.nn.functional.interpolate(mask, size=masked_latent.shape[-2:], mode="nearest")
        del image, mask, condition_image
        mask_latent_concat = torch.cat([mask_latent, torch.zeros_like(mask_latent)], dim=concat_dim)
        print(f"Mask latent concat shape: {mask_latent_concat.shape}")
        # Initialize latents
        latents = torch.randn(
            masked_latent_concat.shape,
            generator=generator,
            device=masked_latent_concat.device,
            dtype=masked_latent_concat.dtype,
        )
        print(f"Latents shape: {latents.shape}")
        # Prepare timesteps
        if sampler_name == "ddpm":
            sampler = DDPMSampler(generator)
            sampler.set_inference_timesteps(num_inference_steps)
        else:
            raise ValueError("Unknown sampler value %s. " % sampler_name)
        timesteps = sampler.timesteps
        # latents = sampler.add_noise(latents, timesteps[0])
        # Classifier-Free Guidance
        do_classifier_free_guidance = guidance_scale > 1.0
        if do_classifier_free_guidance:
            masked_latent_concat = torch.cat(
                [
                    torch.cat([masked_latent, torch.zeros_like(condition_latent)], dim=concat_dim),
                    masked_latent_concat,
                ]
            )
            mask_latent_concat = torch.cat([mask_latent_concat] * 2)
            print(f"Masked latent concat for classifier-free guidance: {masked_latent_concat.shape}, mask latent concat: {mask_latent_concat.shape}")
        # Denoising loop
        num_warmup_steps = 0  # For simple DDPM, no warmup needed
        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier-free guidance
                non_inpainting_latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
                # print(f"Non-inpainting latent model input shape: {non_inpainting_latent_model_input.shape}")
                # prepare the input for the inpainting model
                inpainting_latent_model_input = torch.cat([non_inpainting_latent_model_input, mask_latent_concat, masked_latent_concat], dim=1)
                # print(f"Inpainting latent model input shape: {inpainting_latent_model_input.shape}")
                # predict the noise residual
                diffusion = models.get('diffusion', None)
                if diffusion is None:
                    raise ValueError("Diffusion model not found in models dictionary")
                diffusion.to(device)
                # Create time embedding for the current timestep
                time_embedding = get_time_embedding(t.item()).to(device)
                # print(f"Time embedding shape: {time_embedding.shape}")
                if do_classifier_free_guidance:
                    time_embedding = torch.cat([time_embedding] * 2)
                noise_pred = diffusion(
                    inpainting_latent_model_input,
                    time_embedding,
                )
                to_idle(diffusion)
                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                # compute the previous noisy sample x_t -> x_t-1
                latents = sampler.step(t, latents, noise_pred)
                # Update progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps):
                    progress_bar.update()
        # Decode the final latents (keep only the person half of the concatenated latent)
        latents = latents.split(latents.shape[concat_dim] // 2, dim=concat_dim)[0]
        decoder = models.get('decoder', None)
        if decoder is None:
            raise ValueError("Decoder model not found in models dictionary")
        decoder.to(device)
        image = decoder(latents.to(device))
        # image = rescale(image, (-1, 1), (0, 255), clamp=True)
        image = (image / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        image = image.cpu().permute(0, 2, 3, 1).float().numpy()
        image = numpy_to_pil(image)
        to_idle(decoder)

    return image

def rescale(x, old_range, new_range, clamp=False):
    old_min, old_max = old_range
    new_min, new_max = new_range
    x -= old_min
    x *= (new_max - new_min) / (old_max - old_min)
    x += new_min
    if clamp:
        x = x.clamp(new_min, new_max)
    return x

def get_time_embedding(timestep):
    # Shape: (160,)
    freqs = torch.pow(10000, -torch.arange(start=0, end=160, dtype=torch.float32) / 160)
    # Shape: (1, 160)
    x = torch.tensor([timestep], dtype=torch.float32)[:, None] * freqs[None]
    # Shape: (1, 160 * 2) -> (1, 320)
    return torch.cat([torch.cos(x), torch.sin(x)], dim=-1)

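# Example: get_time_embedding(999) returns a (1, 320) tensor -- 160 cosine features
# followed by 160 sine features at geometrically spaced frequencies -- which is the
# time-embedding shape the diffusion model above is called with.
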
if __name__ == "__main__":
    # Example usage
    image = Image.open("person.jpg").convert("RGB")
    condition_image = Image.open("image.png").convert("RGB")
    mask = Image.open("agnostic_mask.png").convert("L")
    # Load models
    models = load_model.preload_models_from_standard_weights("sd-v1-5-inpainting.ckpt", device="cuda")
    # Generate image
    generated_image = generate(
        image=image,
        condition_image=condition_image,
        mask=mask,
        num_inference_steps=50,
        guidance_scale=2.5,
        width=WIDTH,
        height=HEIGHT,
        models=models,
        sampler_name="ddpm",
        seed=42,
        device="cuda",  # or "cpu"
    )
    generated_image[0].save("generated_image.png")
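
    # Optionally blend the generated result back onto the untouched person pixels.
    # A minimal sketch, assuming the person image and mask are resized to the same
    # WIDTH x HEIGHT as the generated output so the arrays align:
    # final_image = repaint_result(
    #     generated_image[0],
    #     image.resize((WIDTH, HEIGHT)),
    #     mask.resize((WIDTH, HEIGHT)),
    # )
    # final_image.save("generated_image_repainted.png")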