import argparse
import copy
import logging
import math
import os
import os.path as osp
import random
import time
import warnings
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory

import diffusers
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPVisionModelWithProjection

from src.dataset.dance_video import HumanDanceVideoDataset
from src.models.mutual_self_attention import ReferenceAttentionControl
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import (
    delete_additional_ckpt,
    get_fps,
    import_filename,
    read_frames,
    save_videos_grid,
    seed_everything,
)

warnings.filterwarnings("ignore")

check_min_version("0.10.0.dev0")

logger = get_logger(__name__, log_level="INFO")
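

# Net bundles the reference UNet, denoising UNet and pose guider into one
# nn.Module so a single object can be passed to accelerator.prepare. Its
# forward() runs the reference UNet (unless uncond_fwd) to populate the
# reference-attention writer, copies those features into the reader, then
# predicts noise with the denoising UNet conditioned on the pose features.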
class Net(nn.Module):
    def __init__(
        self,
        reference_unet: UNet2DConditionModel,
        denoising_unet: UNet3DConditionModel,
        pose_guider: PoseGuider,
        reference_control_writer,
        reference_control_reader,
    ):
        super().__init__()
        self.reference_unet = reference_unet
        self.denoising_unet = denoising_unet
        self.pose_guider = pose_guider
        self.reference_control_writer = reference_control_writer
        self.reference_control_reader = reference_control_reader

    def forward(
        self,
        noisy_latents,
        timesteps,
        ref_image_latents,
        clip_image_embeds,
        pose_img,
        uncond_fwd: bool = False,
    ):
        pose_cond_tensor = pose_img.to(device="cuda")
        pose_fea = self.pose_guider(pose_cond_tensor)

        if not uncond_fwd:
            ref_timesteps = torch.zeros_like(timesteps)
            self.reference_unet(
                ref_image_latents,
                ref_timesteps,
                encoder_hidden_states=clip_image_embeds,
                return_dict=False,
            )
            self.reference_control_reader.update(self.reference_control_writer)

        model_pred = self.denoising_unet(
            noisy_latents,
            timesteps,
            pose_cond_fea=pose_fea,
            encoder_hidden_states=clip_image_embeds,
        ).sample

        return model_pred
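

# compute_snr is used for Min-SNR loss weighting: when cfg.snr_gamma > 0 the
# per-timestep MSE in the training loop below is rescaled by
# min(SNR, snr_gamma) / SNR.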
def compute_snr(noise_scheduler, timesteps):
    """
    Computes SNR as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod
    sqrt_alphas_cumprod = alphas_cumprod**0.5
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
        timesteps
    ].float()
    while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
    alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
        device=timesteps.device
    )[timesteps].float()
    while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
        sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
    sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

    snr = (alpha / sigma) ** 2
    return snr
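

# Runs the Pose2VideoPipeline on a hardcoded ViViD validation sample (source
# video plus the agnostic video, agnostic mask and densepose derived from its
# path, together with a cloth image), using an fp16 copy of the denoising
# UNet, and writes comparison grids under vividfuxian_motion/<date>/<time>/.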
def log_validation(
    vae,
    image_enc,
    net,
    scheduler,
    accelerator,
    width,
    height,
    global_step,
    clip_length=24,
    generator=None,
):
    logger.info("Running validation... ")

    ori_net = accelerator.unwrap_model(net)
    reference_unet = ori_net.reference_unet
    denoising_unet = ori_net.denoising_unet
    pose_guider = ori_net.pose_guider

    if generator is None:
        generator = torch.manual_seed(42)
    tmp_denoising_unet = copy.deepcopy(denoising_unet)
    tmp_denoising_unet = tmp_denoising_unet.to(dtype=torch.float16)

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=tmp_denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to(accelerator.device)

    date_str = datetime.now().strftime("%Y%m%d")
    time_str = datetime.now().strftime("%H%M")
    save_dir_name = f"{time_str}"
    save_dir = Path(f"vividfuxian_motion/{date_str}/{save_dir_name}")
    save_dir.mkdir(exist_ok=True, parents=True)

    model_video_paths = [
        "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/dataset/ViViD/dresses/videos/803128_detail.mp4"
    ]
    cloth_image_paths = [
        "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/dataset/ViViD/dresses/images/1060638_in_xl.jpg"
    ]
    transform = transforms.Compose(
        [transforms.Resize((height, width)), transforms.ToTensor()]
    )
    for model_image_path in model_video_paths:
        src_fps = get_fps(model_image_path)

        model_name = Path(model_image_path).stem
        agnostic_path = model_image_path.replace("videos", "agnostic")
        agn_mask_path = model_image_path.replace("videos", "agnostic_mask")
        densepose_path = model_image_path.replace("videos", "densepose")

        video_tensor_list = []
        video_images = read_frames(model_image_path)
        for vid_image_pil in video_images[:clip_length]:
            video_tensor_list.append(transform(vid_image_pil))

        video_tensor = torch.stack(video_tensor_list, dim=0)
        video_tensor = video_tensor.transpose(0, 1)

        agnostic_list = []
        agnostic_images = read_frames(agnostic_path)
        for agnostic_image_pil in agnostic_images[:clip_length]:
            agnostic_list.append(agnostic_image_pil)

        agn_mask_list = []
        agn_mask_images = read_frames(agn_mask_path)
        for agn_mask_image_pil in agn_mask_images[:clip_length]:
            agn_mask_list.append(agn_mask_image_pil)

        pose_list = []
        pose_images = read_frames(densepose_path)
        for pose_image_pil in pose_images[:clip_length]:
            pose_list.append(pose_image_pil)

        video_tensor = video_tensor.unsqueeze(0)

        for cloth_image_path in cloth_image_paths:
            cloth_name = Path(cloth_image_path).stem
            cloth_image_pil = Image.open(cloth_image_path).convert("RGB")

            cloth_mask_path = cloth_image_path.replace("cloth", "cloth_mask")
            cloth_mask_pil = Image.open(cloth_mask_path).convert("RGB")

            pipeline_output = pipe(
                agnostic_list,
                agn_mask_list,
                cloth_image_pil,
                cloth_mask_pil,
                pose_list,
                width,
                height,
                clip_length,
                20,
                3.5,
                generator=generator,
            )
            video = pipeline_output.videos

            video = torch.cat([video_tensor, video], dim=0)
            save_videos_grid(
                video,
                f"{save_dir}/{global_step:06d}-{model_name}_{cloth_name}.mp4",
                n_rows=2,
                fps=src_fps,
            )

    del tmp_denoising_unet
    del pipe
    torch.cuda.empty_cache()

    return video
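

# Stage-2 style training: the VAE, CLIP image encoder, reference UNet,
# denoising UNet and pose guider are loaded (the UNets and pose guider from
# stage-1 checkpoints) and frozen; only the motion_modules parameters of the
# denoising UNet are optimized.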
def main(cfg):
    kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
        mixed_precision=cfg.solver.mixed_precision,
        log_with="mlflow",
        project_dir="./mlruns",
        kwargs_handlers=[kwargs],
    )

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if cfg.seed is not None:
        seed_everything(cfg.seed)

    exp_name = cfg.exp_name
    save_dir = f"{cfg.output_dir}/{exp_name}"
    if accelerator.is_main_process:
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

    inference_config_path = "./configs/inference/inference.yaml"
    infer_config = OmegaConf.load(inference_config_path)

    if cfg.weight_dtype == "fp16":
        weight_dtype = torch.float16
    elif cfg.weight_dtype == "bf16":
        weight_dtype = torch.bfloat16
    elif cfg.weight_dtype == "fp32":
        weight_dtype = torch.float32
    else:
        raise ValueError(
            f"Unsupported weight dtype for training: {cfg.weight_dtype}"
        )

    sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
    if cfg.enable_zero_snr:
        sched_kwargs.update(
            rescale_betas_zero_snr=True,
            timestep_spacing="trailing",
            prediction_type="v_prediction",
        )
    val_noise_scheduler = DDIMScheduler(**sched_kwargs)
    sched_kwargs.update({"beta_schedule": "scaled_linear"})
    train_noise_scheduler = DDIMScheduler(**sched_kwargs)

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        cfg.image_encoder_path,
    ).to(dtype=weight_dtype, device="cuda")
    vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
        "cuda", dtype=weight_dtype
    )
    reference_unet = UNet2DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        subfolder="unet",
        unet_additional_kwargs={
            "in_channels": 5,
        },
    ).to(device="cuda", dtype=weight_dtype)

    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        cfg.base_model_path,
        cfg.mm_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(
            infer_config.unet_additional_kwargs
        ),
    ).to(device="cuda")

    pose_guider = PoseGuider(
        conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
    ).to(device="cuda", dtype=weight_dtype)

    stage1_ckpt_dir = cfg.stage1_ckpt_dir
    stage1_ckpt_step = cfg.stage1_ckpt_step
    denoising_unet.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"denoising_unet-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    reference_unet.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"reference_unet-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )
    pose_guider.load_state_dict(
        torch.load(
            os.path.join(stage1_ckpt_dir, f"pose_guider-{stage1_ckpt_step}.pth"),
            map_location="cpu",
        ),
        strict=False,
    )

    vae.requires_grad_(False)
    image_enc.requires_grad_(False)
    reference_unet.requires_grad_(False)
    denoising_unet.requires_grad_(False)
    pose_guider.requires_grad_(False)

    for name, module in denoising_unet.named_modules():
        if "motion_modules" in name:
            for params in module.parameters():
                params.requires_grad = True

    reference_control_writer = ReferenceAttentionControl(
        reference_unet,
        do_classifier_free_guidance=False,
        mode="write",
        fusion_blocks="full",
    )
    reference_control_reader = ReferenceAttentionControl(
        denoising_unet,
        do_classifier_free_guidance=False,
        mode="read",
        fusion_blocks="full",
    )

    net = Net(
        reference_unet,
        denoising_unet,
        pose_guider,
        reference_control_writer,
        reference_control_reader,
    )

    if cfg.solver.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            reference_unet.enable_xformers_memory_efficient_attention()
            denoising_unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError(
                "xformers is not available. Make sure it is installed correctly"
            )

    if cfg.solver.gradient_checkpointing:
        reference_unet.enable_gradient_checkpointing()
        denoising_unet.enable_gradient_checkpointing()

    if cfg.solver.scale_lr:
        learning_rate = (
            cfg.solver.learning_rate
            * cfg.solver.gradient_accumulation_steps
            * cfg.data.train_bs
            * accelerator.num_processes
        )
    else:
        learning_rate = cfg.solver.learning_rate

    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    trainable_params = list(filter(lambda p: p.requires_grad, net.parameters()))
    logger.info(f"Total trainable params {len(trainable_params)}")
    optimizer = optimizer_cls(
        trainable_params,
        lr=learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )

    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps
        * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps
        * cfg.solver.gradient_accumulation_steps,
    )

    train_dataset = HumanDanceVideoDataset(
        width=cfg.data.train_width,
        height=cfg.data.train_height,
        n_sample_frames=cfg.data.n_sample_frames,
        sample_rate=cfg.data.sample_rate,
        img_scale=(1.0, 1.0),
        data_meta_paths=cfg.data.meta_paths,
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
    )

    (
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        net,
        optimizer,
        train_dataloader,
        lr_scheduler,
    )

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / cfg.solver.gradient_accumulation_steps
    )
    num_train_epochs = math.ceil(
        cfg.solver.max_train_steps / num_update_steps_per_epoch
    )

    if accelerator.is_main_process:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        accelerator.init_trackers(
            exp_name,
            init_kwargs={"mlflow": {"run_name": run_time}},
        )
        mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")

    total_batch_size = (
        cfg.data.train_bs
        * accelerator.num_processes
        * cfg.solver.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
    logger.info(
        f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(
        f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
    )
    logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
    global_step = 0
    first_epoch = 0

    if cfg.resume_from_checkpoint:
        if cfg.resume_from_checkpoint != "latest":
            resume_dir = cfg.resume_from_checkpoint
        else:
            resume_dir = save_dir

        dirs = os.listdir(resume_dir)
        dirs = [d for d in dirs if d.startswith("checkpoint")]
        dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
        path = dirs[-1]
        accelerator.load_state(os.path.join(resume_dir, path))
        accelerator.print(f"Resuming from checkpoint {path}")
        global_step = int(path.split("-")[1])

        first_epoch = global_step // num_update_steps_per_epoch
        resume_step = global_step % num_update_steps_per_epoch

    progress_bar = tqdm(
        range(global_step, cfg.solver.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, num_train_epochs):
        train_loss = 0.0
        t_data_start = time.time()
        for step, batch in enumerate(train_dataloader):
            t_data = time.time() - t_data_start
            with accelerator.accumulate(net):
                pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
                masked_pixel_values = batch["pixel_values_vid_agnostic"].to(
                    weight_dtype
                )

                mask_of_pixel_values = batch["pixel_values_vid_agnostic_mask"].to(
                    weight_dtype
                )[:, :, 0:1, :, :]
                mask_of_pixel_values = mask_of_pixel_values.transpose(1, 2)
                with torch.no_grad():
                    video_length = pixel_values_vid.shape[1]

                    pixel_values_vid = rearrange(
                        pixel_values_vid, "b f c h w -> (b f) c h w"
                    )
                    latents = vae.encode(pixel_values_vid).latent_dist.sample()
                    latents = rearrange(
                        latents, "(b f) c h w -> b c f h w", f=video_length
                    )
                    latents = latents * 0.18215

                    masked_pixel_values = rearrange(
                        masked_pixel_values, "b f c h w -> (b f) c h w"
                    )
                    masked_latents = vae.encode(
                        masked_pixel_values
                    ).latent_dist.sample()
                    masked_latents = rearrange(
                        masked_latents, "(b f) c h w -> b c f h w", f=video_length
                    )
                    masked_latents = masked_latents * 0.18215
                    # Downsample the agnostic mask to the latent resolution;
                    # the temporal size follows the encoded clip length
                    # (previously hardcoded to 24 frames).
                    mask_of_latents = torch.nn.functional.interpolate(
                        mask_of_pixel_values,
                        size=(
                            video_length,
                            mask_of_pixel_values.shape[-2] // 8,
                            mask_of_pixel_values.shape[-1] // 8,
                        ),
                    )

                noise = torch.randn_like(latents)
                if cfg.noise_offset > 0:
                    noise += cfg.noise_offset * torch.randn(
                        (latents.shape[0], latents.shape[1], 1, 1, 1),
                        device=latents.device,
                    )
                bsz = latents.shape[0]

                timesteps = torch.randint(
                    0,
                    train_noise_scheduler.num_train_timesteps,
                    (bsz,),
                    device=latents.device,
                )
                timesteps = timesteps.long()

                pixel_values_pose = batch["pixel_values_pose"]
                pixel_values_pose = pixel_values_pose.transpose(1, 2)

                uncond_fwd = random.random() < cfg.uncond_ratio
                clip_image_list = []
                ref_image_list = []
                cloth_mask_list = []
                for batch_idx, (ref_img, cloth_mask, clip_img) in enumerate(
                    zip(
                        batch["pixel_cloth"],
                        batch["pixel_cloth_mask"],
                        batch["clip_ref_img"],
                    )
                ):
                    if uncond_fwd:
                        clip_image_list.append(torch.zeros_like(clip_img))
                    else:
                        clip_image_list.append(clip_img)
                    ref_image_list.append(ref_img)
                    cloth_mask_list.append(cloth_mask)

                with torch.no_grad():
                    ref_img = torch.stack(ref_image_list, dim=0).to(
                        dtype=vae.dtype, device=vae.device
                    )
                    ref_image_latents = vae.encode(ref_img).latent_dist.sample()
                    ref_image_latents = ref_image_latents * 0.18215

                    cloth_mask = torch.stack(cloth_mask_list, dim=0).to(
                        dtype=vae.dtype, device=vae.device
                    )
                    cloth_mask = cloth_mask[:, 0:1, :, :]
                    cloth_mask = torch.nn.functional.interpolate(
                        cloth_mask,
                        size=(cloth_mask.shape[-2] // 8, cloth_mask.shape[-1] // 8),
                    )

                    clip_img = torch.stack(clip_image_list, dim=0).to(
                        dtype=image_enc.dtype, device=image_enc.device
                    )
                    clip_img = clip_img.to(device="cuda", dtype=weight_dtype)
                    clip_image_embeds = image_enc(
                        clip_img.to("cuda", dtype=weight_dtype)
                    ).image_embeds
                    clip_image_embeds = clip_image_embeds.unsqueeze(1)

                noisy_latents = train_noise_scheduler.add_noise(
                    latents, noise, timesteps
                )

                if train_noise_scheduler.prediction_type == "epsilon":
                    target = noise
                elif train_noise_scheduler.prediction_type == "v_prediction":
                    target = train_noise_scheduler.get_velocity(
                        latents, noise, timesteps
                    )
                else:
                    raise ValueError(
                        f"Unknown prediction type {train_noise_scheduler.prediction_type}"
                    )
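
                # Channel-wise conditioning: the denoising UNet receives the
                # noisy video latents concatenated with the agnostic-video
                # latents and the downsampled agnostic mask, while the
                # reference UNet receives the cloth latents concatenated with
                # the cloth mask (matching its in_channels=5 above).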
                model_pred = net(
                    torch.cat(
                        [noisy_latents, masked_latents, mask_of_latents], dim=1
                    ),
                    timesteps,
                    torch.cat([ref_image_latents, cloth_mask], dim=1),
                    clip_image_embeds,
                    pixel_values_pose,
                    uncond_fwd=uncond_fwd,
                )

                if cfg.snr_gamma == 0:
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="mean"
                    )
                else:
                    snr = compute_snr(train_noise_scheduler, timesteps)
                    if train_noise_scheduler.config.prediction_type == "v_prediction":
                        snr = snr + 1
                    mse_loss_weights = (
                        torch.stack(
                            [snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
                        ).min(dim=1)[0]
                        / snr
                    )
                    loss = F.mse_loss(
                        model_pred.float(), target.float(), reduction="none"
                    )
                    loss = (
                        loss.mean(dim=list(range(1, len(loss.shape))))
                        * mse_loss_weights
                    )
                    loss = loss.mean()

                avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
                train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(
                        trainable_params,
                        cfg.solver.max_grad_norm,
                    )
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                reference_control_reader.clear()
                reference_control_writer.clear()
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % cfg.val.validation_steps == 0:
                    if accelerator.is_main_process:
                        generator = torch.Generator(device=accelerator.device)
                        generator.manual_seed(cfg.seed)

                        log_validation(
                            vae=vae,
                            image_enc=image_enc,
                            net=net,
                            scheduler=val_noise_scheduler,
                            accelerator=accelerator,
                            width=cfg.data.train_width,
                            height=cfg.data.train_height,
                            global_step=global_step,
                            clip_length=cfg.data.n_sample_frames,
                            generator=generator,
                        )

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "td": f"{t_data:.2f}s",
            }
            t_data_start = time.time()
            progress_bar.set_postfix(**logs)

            if global_step >= cfg.solver.max_train_steps:
                break

        # Save the full training state each epoch so `resume_from_checkpoint`
        # can reload it (keeping at most one previous checkpoint directory),
        # then export the motion-module weights on their own.
        if accelerator.is_main_process:
            save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
            delete_additional_ckpt(save_dir, 1)
            accelerator.save_state(save_path)

            unwrap_net = accelerator.unwrap_model(net)
            save_checkpoint(
                unwrap_net.denoising_unet,
                save_dir,
                "motion_module",
                global_step,
                total_limit=3,
            )

    accelerator.wait_for_everyone()
    accelerator.end_training()
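

# Saves only the motion-module weights of the model as "<prefix>-<step>.pth",
# pruning the oldest files once total_limit is reached.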
def save_checkpoint(model, save_dir, prefix, ckpt_num, total_limit=None):
    save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")

    if total_limit is not None:
        checkpoints = os.listdir(save_dir)
        checkpoints = [d for d in checkpoints if d.startswith(prefix)]
        checkpoints = sorted(
            checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
        )

        if len(checkpoints) >= total_limit:
            num_to_remove = len(checkpoints) - total_limit + 1
            removing_checkpoints = checkpoints[0:num_to_remove]
            logger.info(
                f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
            )
            logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

            for removing_checkpoint in removing_checkpoints:
                removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
                os.remove(removing_checkpoint)

    mm_state_dict = OrderedDict()
    state_dict = model.state_dict()
    for key in state_dict:
        if "motion_module" in key:
            mm_state_dict[key] = state_dict[key]

    torch.save(mm_state_dict, save_path)
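

# Frame-by-frame VAE decoding of video latents back to pixel space; kept as a
# utility here (it is not called in the training loop above).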
def decode_latents(vae, latents):
    video_length = latents.shape[2]
    latents = 1 / 0.18215 * latents
    latents = rearrange(latents, "b c f h w -> (b f) c h w")

    video = []
    for frame_idx in tqdm(range(latents.shape[0])):
        video.append(vae.decode(latents[frame_idx : frame_idx + 1]).sample)
    video = torch.cat(video)
    video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
    video = (video / 2 + 0.5).clamp(0, 1)

    video = video.cpu().float().numpy()
    return video


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
    args = parser.parse_args()

    if args.config.endswith(".yaml"):
        config = OmegaConf.load(args.config)
    elif args.config.endswith(".py"):
        config = import_filename(args.config).cfg
    else:
        raise ValueError("Unsupported config file format; expected a .yaml or .py file")
    main(config)