# VIVID / train_stage_2.py
import argparse
import copy
import logging
import math
import os
import os.path as osp
import random
import time
import warnings
from collections import OrderedDict
from datetime import datetime
from pathlib import Path
from tempfile import TemporaryDirectory
import diffusers
import mlflow
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import DistributedDataParallelKwargs
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPVisionModelWithProjection
from src.dataset.dance_video import HumanDanceVideoDataset
from src.models.mutual_self_attention import ReferenceAttentionControl
from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import (
    delete_additional_ckpt,
    get_fps,
    import_filename,
    read_frames,
    save_videos_grid,
    seed_everything,
)
warnings.filterwarnings("ignore")
# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.10.0.dev0")
logger = get_logger(__name__, log_level="INFO")
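# Net bundles the reference UNet, denoising UNet, and pose guider into one
# module so Accelerate can wrap them together. On a conditional forward pass,
# the reference UNet runs once at timestep 0 on the reference latents and
# "writes" its self-attention features; ReferenceAttentionControl then lets the
# denoising UNet "read" those features via mutual self-attention.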
class Net(nn.Module):
def __init__(
self,
reference_unet: UNet2DConditionModel,
denoising_unet: UNet3DConditionModel,
pose_guider: PoseGuider,
reference_control_writer,
reference_control_reader,
):
super().__init__()
self.reference_unet = reference_unet
self.denoising_unet = denoising_unet
self.pose_guider = pose_guider
self.reference_control_writer = reference_control_writer
self.reference_control_reader = reference_control_reader
def forward(
self,
noisy_latents,
timesteps,
ref_image_latents,
clip_image_embeds,
pose_img,
uncond_fwd: bool = False,
):
pose_cond_tensor = pose_img.to(device="cuda")
pose_fea = self.pose_guider(pose_cond_tensor)
if not uncond_fwd:
ref_timesteps = torch.zeros_like(timesteps)
self.reference_unet(
ref_image_latents,
ref_timesteps,
encoder_hidden_states=clip_image_embeds,
return_dict=False,
)
self.reference_control_reader.update(self.reference_control_writer)
model_pred = self.denoising_unet(
noisy_latents,
timesteps,
pose_cond_fea=pose_fea,
encoder_hidden_states=clip_image_embeds,
).sample
return model_pred
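# Per-timestep signal-to-noise ratio of the scheduler, used below for the
# optional Min-SNR loss weighting (cfg.snr_gamma).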
def compute_snr(noise_scheduler, timesteps):
"""
Computes SNR as per
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
timesteps
].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
device=timesteps.device
)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR.
snr = (alpha / sigma) ** 2
return snr
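# Validation: builds a temporary fp16 Pose2VideoPipeline from the unwrapped
# training modules and renders a fixed try-on clip so progress can be
# inspected alongside the source video.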
def log_validation(
vae,
image_enc,
net,
scheduler,
accelerator,
width,
height,
global_step,
clip_length=24,
generator=None,
):
logger.info("Running validation... ")
ori_net = accelerator.unwrap_model(net)
reference_unet = ori_net.reference_unet
denoising_unet = ori_net.denoising_unet
pose_guider = ori_net.pose_guider
if generator is None:
generator = torch.manual_seed(42)
tmp_denoising_unet = copy.deepcopy(denoising_unet)
tmp_denoising_unet = tmp_denoising_unet.to(dtype=torch.float16)
pipe = Pose2VideoPipeline(
vae=vae,
image_encoder=image_enc,
reference_unet=reference_unet,
denoising_unet=tmp_denoising_unet,
pose_guider=pose_guider,
scheduler=scheduler,
)
pipe = pipe.to(accelerator.device)
date_str = datetime.now().strftime("%Y%m%d")
time_str = datetime.now().strftime("%H%M")
    save_dir = Path(f"vividfuxian_motion/{date_str}/{time_str}")
    save_dir.mkdir(exist_ok=True, parents=True)
    model_video_paths = [
        "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/dataset/ViViD/dresses/videos/803128_detail.mp4"
    ]
    cloth_image_paths = [
        "/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/dataset/ViViD/dresses/images/1060638_in_xl.jpg"
    ]
transform = transforms.Compose(
[transforms.Resize((height, width)), transforms.ToTensor()]
)
for model_image_path in model_video_paths:
src_fps = get_fps(model_image_path)
model_name = Path(model_image_path).stem
        agnostic_path = model_image_path.replace("videos", "agnostic")
        agn_mask_path = model_image_path.replace("videos", "agnostic_mask")
        densepose_path = model_image_path.replace("videos", "densepose")
        video_tensor_list = []
        video_images = read_frames(model_image_path)
        for vid_image_pil in video_images[:clip_length]:
            video_tensor_list.append(transform(vid_image_pil))
        video_tensor = torch.stack(video_tensor_list, dim=0)  # (f, c, h, w)
        video_tensor = video_tensor.transpose(0, 1)  # (c, f, h, w)
        agnostic_list = read_frames(agnostic_path)[:clip_length]
        agn_mask_list = read_frames(agn_mask_path)[:clip_length]
        pose_list = read_frames(densepose_path)[:clip_length]
video_tensor = video_tensor.unsqueeze(0)
for cloth_image_path in cloth_image_paths:
cloth_name = Path(cloth_image_path).stem
cloth_image_pil = Image.open(cloth_image_path).convert("RGB")
            cloth_mask_path = cloth_image_path.replace("cloth", "cloth_mask")
            cloth_mask_pil = Image.open(cloth_mask_path).convert("RGB")
pipeline_output = pipe(
agnostic_list,
agn_mask_list,
cloth_image_pil,
cloth_mask_pil,
pose_list,
width,
height,
clip_length,
                20,   # num_inference_steps (assumed from positional order)
                3.5,  # guidance_scale (assumed from positional order)
generator=generator,
)
video = pipeline_output.videos
            video = torch.cat([video_tensor, video], dim=0)
save_videos_grid(
video,
f"{save_dir}/{global_step:06d}-{model_name}_{cloth_name}.mp4",
n_rows=2,
fps=src_fps,
)
del tmp_denoising_unet
del pipe
torch.cuda.empty_cache()
return video
def main(cfg):
kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
accelerator = Accelerator(
gradient_accumulation_steps=cfg.solver.gradient_accumulation_steps,
mixed_precision=cfg.solver.mixed_precision,
log_with="mlflow",
project_dir="./mlruns",
kwargs_handlers=[kwargs],
)
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if cfg.seed is not None:
seed_everything(cfg.seed)
exp_name = cfg.exp_name
save_dir = f"{cfg.output_dir}/{exp_name}"
if accelerator.is_main_process:
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# inference_config_path = "./configs/inference/inference_v2.yaml"
inference_config_path = "./configs/inference/inference.yaml"
infer_config = OmegaConf.load(inference_config_path)
if cfg.weight_dtype == "fp16":
weight_dtype = torch.float16
elif cfg.weight_dtype == "bf16":
weight_dtype = torch.bfloat16
elif cfg.weight_dtype == "fp32":
weight_dtype = torch.float32
else:
raise ValueError(
f"Do not support weight dtype: {cfg.weight_dtype} during training"
)
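    # When cfg.enable_zero_snr is set, the scheduler is configured for
    # zero-terminal-SNR sampling (rescaled betas, trailing timestep spacing,
    # v-prediction), following "Common Diffusion Noise Schedules and Sample
    # Steps are Flawed" (Lin et al.).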
sched_kwargs = OmegaConf.to_container(cfg.noise_scheduler_kwargs)
if cfg.enable_zero_snr:
sched_kwargs.update(
rescale_betas_zero_snr=True,
timestep_spacing="trailing",
prediction_type="v_prediction",
)
val_noise_scheduler = DDIMScheduler(**sched_kwargs)
sched_kwargs.update({"beta_schedule": "scaled_linear"})
train_noise_scheduler = DDIMScheduler(**sched_kwargs)
image_enc = CLIPVisionModelWithProjection.from_pretrained(
cfg.image_encoder_path,
).to(dtype=weight_dtype, device="cuda")
vae = AutoencoderKL.from_pretrained(cfg.vae_model_path).to(
"cuda", dtype=weight_dtype
)
reference_unet = UNet2DConditionModel.from_pretrained_2d(
cfg.base_model_path,
subfolder="unet",
        unet_additional_kwargs={
            "in_channels": 5,  # 4 cloth latent channels + 1 cloth-mask channel
        },
).to(device="cuda", dtype=weight_dtype)
denoising_unet = UNet3DConditionModel.from_pretrained_2d(
cfg.base_model_path,
cfg.mm_path,
subfolder="unet",
unet_additional_kwargs=OmegaConf.to_container(
infer_config.unet_additional_kwargs
),
    ).to(device="cuda")  # kept in full precision; its motion modules are trained
pose_guider = PoseGuider(
conditioning_embedding_channels=320, block_out_channels=(16, 32, 96, 256)
).to(device="cuda", dtype=weight_dtype)
stage1_ckpt_dir = cfg.stage1_ckpt_dir
stage1_ckpt_step = cfg.stage1_ckpt_step
denoising_unet.load_state_dict(
torch.load(
os.path.join(stage1_ckpt_dir, f"denoising_unet-{stage1_ckpt_step}.pth"),
map_location="cpu",
),
strict=False,
)
reference_unet.load_state_dict(
torch.load(
os.path.join(stage1_ckpt_dir, f"reference_unet-{stage1_ckpt_step}.pth"),
map_location="cpu",
),
strict=False,
)
pose_guider.load_state_dict(
torch.load(
os.path.join(stage1_ckpt_dir, f"pose_guider-{stage1_ckpt_step}.pth"),
map_location="cpu",
),
strict=False,
)
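    # Stage 2 trains only the temporal motion modules; everything initialized
    # from stage 1 (VAE, CLIP image encoder, both UNets, pose guider) is frozen.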
# Freeze
vae.requires_grad_(False)
image_enc.requires_grad_(False)
reference_unet.requires_grad_(False)
denoising_unet.requires_grad_(False)
pose_guider.requires_grad_(False)
# Set motion module learnable
for name, module in denoising_unet.named_modules():
if "motion_modules" in name:
for params in module.parameters():
params.requires_grad = True
reference_control_writer = ReferenceAttentionControl(
reference_unet,
do_classifier_free_guidance=False,
mode="write",
fusion_blocks="full",
)
reference_control_reader = ReferenceAttentionControl(
denoising_unet,
do_classifier_free_guidance=False,
mode="read",
fusion_blocks="full",
)
net = Net(
reference_unet,
denoising_unet,
pose_guider,
reference_control_writer,
reference_control_reader,
)
if cfg.solver.enable_xformers_memory_efficient_attention:
if is_xformers_available():
reference_unet.enable_xformers_memory_efficient_attention()
denoising_unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError(
"xformers is not available. Make sure it is installed correctly"
)
if cfg.solver.gradient_checkpointing:
reference_unet.enable_gradient_checkpointing()
denoising_unet.enable_gradient_checkpointing()
if cfg.solver.scale_lr:
learning_rate = (
cfg.solver.learning_rate
* cfg.solver.gradient_accumulation_steps
* cfg.data.train_bs
* accelerator.num_processes
)
else:
learning_rate = cfg.solver.learning_rate
# Initialize the optimizer
if cfg.solver.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
trainable_params = list(filter(lambda p: p.requires_grad, net.parameters()))
logger.info(f"Total trainable params {len(trainable_params)}")
optimizer = optimizer_cls(
trainable_params,
lr=learning_rate,
betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
weight_decay=cfg.solver.adam_weight_decay,
eps=cfg.solver.adam_epsilon,
)
# Scheduler
lr_scheduler = get_scheduler(
cfg.solver.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.solver.lr_warmup_steps
* cfg.solver.gradient_accumulation_steps,
num_training_steps=cfg.solver.max_train_steps
* cfg.solver.gradient_accumulation_steps,
)
train_dataset = HumanDanceVideoDataset(
width=cfg.data.train_width,
height=cfg.data.train_height,
n_sample_frames=cfg.data.n_sample_frames,
sample_rate=cfg.data.sample_rate,
img_scale=(1.0, 1.0),
data_meta_paths=cfg.data.meta_paths,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=cfg.data.train_bs, shuffle=True, num_workers=4
)
# Prepare everything with our `accelerator`.
(
net,
optimizer,
train_dataloader,
lr_scheduler,
) = accelerator.prepare(
net,
optimizer,
train_dataloader,
lr_scheduler,
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / cfg.solver.gradient_accumulation_steps
)
# Afterwards we recalculate our number of training epochs
num_train_epochs = math.ceil(
cfg.solver.max_train_steps / num_update_steps_per_epoch
)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
run_time = datetime.now().strftime("%Y%m%d-%H%M")
accelerator.init_trackers(
exp_name,
init_kwargs={"mlflow": {"run_name": run_time}},
)
# dump config file
mlflow.log_dict(OmegaConf.to_container(cfg), "config.yaml")
# Train!
total_batch_size = (
cfg.data.train_bs
* accelerator.num_processes
* cfg.solver.gradient_accumulation_steps
)
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {cfg.data.train_bs}")
logger.info(
f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
)
logger.info(
f" Gradient Accumulation steps = {cfg.solver.gradient_accumulation_steps}"
)
logger.info(f" Total optimization steps = {cfg.solver.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if cfg.resume_from_checkpoint:
if cfg.resume_from_checkpoint != "latest":
resume_dir = cfg.resume_from_checkpoint
else:
resume_dir = save_dir
# Get the most recent checkpoint
dirs = os.listdir(resume_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1]
accelerator.load_state(os.path.join(resume_dir, path))
accelerator.print(f"Resuming from checkpoint {path}")
global_step = int(path.split("-")[1])
first_epoch = global_step // num_update_steps_per_epoch
resume_step = global_step % num_update_steps_per_epoch
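        # Note: resume_step locates the position within the epoch but is not used
        # to skip batches, so a resumed epoch starts again from its first batch.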
# Only show the progress bar once on each machine.
progress_bar = tqdm(
range(global_step, cfg.solver.max_train_steps),
disable=not accelerator.is_local_main_process,
)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, num_train_epochs):
train_loss = 0.0
t_data_start = time.time()
for step, batch in enumerate(train_dataloader):
t_data = time.time() - t_data_start
with accelerator.accumulate(net):
# Convert videos to latent space
pixel_values_vid = batch["pixel_values_vid"].to(weight_dtype)
masked_pixel_values = batch["pixel_values_vid_agnostic"].to(weight_dtype)
                mask_of_pixel_values = batch["pixel_values_vid_agnostic_mask"].to(
                    weight_dtype
                )[:, :, 0:1, :, :]  # keep a single mask channel
                mask_of_pixel_values = mask_of_pixel_values.transpose(
                    1, 2
                )  # (b, f, c, h, w) -> (b, c, f, h, w)
with torch.no_grad():
video_length = pixel_values_vid.shape[1]
pixel_values_vid = rearrange(
pixel_values_vid, "b f c h w -> (b f) c h w"
)
latents = vae.encode(pixel_values_vid).latent_dist.sample()
latents = rearrange(
latents, "(b f) c h w -> b c f h w", f=video_length
)
latents = latents * 0.18215
masked_pixel_values = rearrange(
masked_pixel_values, "b f c h w -> (b f) c h w"
)
masked_latents = vae.encode(masked_pixel_values).latent_dist.sample()
masked_latents = rearrange(
masked_latents, "(b f) c h w -> b c f h w", f=video_length
)
masked_latents = masked_latents * 0.18215
                    # Downsample the agnostic mask to latent resolution; the
                    # temporal size must match the latents' frame dimension.
                    mask_of_latents = F.interpolate(
                        mask_of_pixel_values,
                        size=(
                            video_length,
                            mask_of_pixel_values.shape[-2] // 8,
                            mask_of_pixel_values.shape[-1] // 8,
                        ),
                    )
noise = torch.randn_like(latents)
if cfg.noise_offset > 0:
noise += cfg.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1, 1),
device=latents.device,
)
bsz = latents.shape[0]
# Sample a random timestep for each video
timesteps = torch.randint(
0,
train_noise_scheduler.num_train_timesteps,
(bsz,),
device=latents.device,
)
timesteps = timesteps.long()
pixel_values_pose = batch["pixel_values_pose"] # (bs, f, c, H, W)
pixel_values_pose = pixel_values_pose.transpose(
1, 2
) # (bs, c, f, H, W)
uncond_fwd = random.random() < cfg.uncond_ratio
clip_image_list = []
ref_image_list = []
cloth_mask_list = []
for batch_idx, (ref_img, cloth_mask, clip_img) in enumerate(
zip(
batch["pixel_cloth"],
batch["pixel_cloth_mask"],
batch["clip_ref_img"],
)
):
if uncond_fwd:
clip_image_list.append(torch.zeros_like(clip_img))
else:
clip_image_list.append(clip_img)
ref_image_list.append(ref_img)
cloth_mask_list.append(cloth_mask)
with torch.no_grad():
ref_img = torch.stack(ref_image_list, dim=0).to(
dtype=vae.dtype, device=vae.device
)
ref_image_latents = vae.encode(
ref_img
).latent_dist.sample() # (bs, d, 64, 64)
ref_image_latents = ref_image_latents * 0.18215
cloth_mask = torch.stack(cloth_mask_list, dim=0).to(
dtype=vae.dtype, device=vae.device
)
                    cloth_mask = cloth_mask[:, 0:1, :, :]  # keep a single mask channel
                    cloth_mask = F.interpolate(
                        cloth_mask,
                        size=(cloth_mask.shape[-2] // 8, cloth_mask.shape[-1] // 8),
                    )
                    clip_img = torch.stack(clip_image_list, dim=0).to(
                        dtype=image_enc.dtype, device=image_enc.device
                    )
                    clip_image_embeds = image_enc(clip_img).image_embeds
clip_image_embeds = clip_image_embeds.unsqueeze(1) # (bs, 1, d)
# add noise
noisy_latents = train_noise_scheduler.add_noise(
latents, noise, timesteps
)
# Get the target for loss depending on the prediction type
if train_noise_scheduler.prediction_type == "epsilon":
target = noise
elif train_noise_scheduler.prediction_type == "v_prediction":
target = train_noise_scheduler.get_velocity(
latents, noise, timesteps
)
else:
raise ValueError(
f"Unknown prediction type {train_noise_scheduler.prediction_type}"
)
                # ---- Forward ----
                # Denoising UNet input: 4 noisy latent channels + 4 agnostic
                # (masked) latent channels + 1 mask channel.
                # Reference UNet input: 4 cloth latent channels + 1 cloth-mask
                # channel, matching its in_channels=5 above.
                model_pred = net(
                    torch.cat(
                        [noisy_latents, masked_latents, mask_of_latents], dim=1
                    ),
                    timesteps,
                    torch.cat([ref_image_latents, cloth_mask], dim=1),
                    clip_image_embeds,
                    pixel_values_pose,
                    uncond_fwd=uncond_fwd,
                )
if cfg.snr_gamma == 0:
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="mean"
)
else:
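                    # Min-SNR weighting (Hang et al., 2023): scale each sample's
                    # loss by min(SNR, snr_gamma) / SNR to down-weight easy
                    # (high-SNR) timesteps.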
snr = compute_snr(train_noise_scheduler, timesteps)
if train_noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective requires that we add one to SNR values before we divide by them.
snr = snr + 1
mse_loss_weights = (
torch.stack(
[snr, cfg.snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]
/ snr
)
loss = F.mse_loss(
model_pred.float(), target.float(), reduction="none"
)
loss = (
loss.mean(dim=list(range(1, len(loss.shape))))
* mse_loss_weights
)
loss = loss.mean()
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(cfg.data.train_bs)).mean()
train_loss += avg_loss.item() / cfg.solver.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(
trainable_params,
cfg.solver.max_grad_norm,
)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if accelerator.sync_gradients:
reference_control_reader.clear()
reference_control_writer.clear()
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if global_step % cfg.val.validation_steps == 0:
if accelerator.is_main_process:
generator = torch.Generator(device=accelerator.device)
generator.manual_seed(cfg.seed)
log_validation(
vae=vae,
image_enc=image_enc,
net=net,
scheduler=val_noise_scheduler,
accelerator=accelerator,
width=cfg.data.train_width,
height=cfg.data.train_height,
global_step=global_step,
clip_length=cfg.data.n_sample_frames,
generator=generator,
)
# for sample_id, sample_dict in enumerate(sample_dicts):
# sample_name = sample_dict["name"]
# vid = sample_dict["vid"]
# with TemporaryDirectory() as temp_dir:
# out_file = Path(
# f"{temp_dir}/{global_step:06d}-{sample_name}.gif"
# )
# save_videos_grid(vid, out_file, n_rows=2)
# mlflow.log_artifact(out_file)
logs = {
"step_loss": loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0],
"td": f"{t_data:.2f}s",
}
t_data_start = time.time()
progress_bar.set_postfix(**logs)
if global_step >= cfg.solver.max_train_steps:
break
# save model after each epoch
if accelerator.is_main_process:
save_path = os.path.join(save_dir, f"checkpoint-{global_step}")
delete_additional_ckpt(save_dir, 1)
# accelerator.save_state(save_path)
# save motion module only
unwrap_net = accelerator.unwrap_model(net)
save_checkpoint(
unwrap_net.denoising_unet,
save_dir,
"motion_module",
global_step,
total_limit=3,
)
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
accelerator.end_training()
def save_checkpoint(model, save_dir, prefix, ckpt_num, total_limit=None):
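    """Save only the motion-module weights of `model` to
    `{save_dir}/{prefix}-{ckpt_num}.pth`, removing the oldest checkpoints with
    the same prefix once `total_limit` would be exceeded."""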
save_path = osp.join(save_dir, f"{prefix}-{ckpt_num}.pth")
if total_limit is not None:
checkpoints = os.listdir(save_dir)
checkpoints = [d for d in checkpoints if d.startswith(prefix)]
checkpoints = sorted(
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
)
if len(checkpoints) >= total_limit:
num_to_remove = len(checkpoints) - total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(save_dir, removing_checkpoint)
os.remove(removing_checkpoint)
mm_state_dict = OrderedDict()
state_dict = model.state_dict()
for key in state_dict:
if "motion_module" in key:
mm_state_dict[key] = state_dict[key]
torch.save(mm_state_dict, save_path)
def decode_latents(vae, latents):
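    """Decode video latents frame by frame to bound memory use; returns a
    float32 numpy array of shape (b, c, f, h, w) with values in [0, 1]."""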
video_length = latents.shape[2]
latents = 1 / 0.18215 * latents
latents = rearrange(latents, "b c f h w -> (b f) c h w")
# video = self.vae.decode(latents).sample
video = []
for frame_idx in tqdm(range(latents.shape[0])):
video.append(vae.decode(latents[frame_idx : frame_idx + 1]).sample)
video = torch.cat(video)
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
video = (video / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
video = video.cpu().float().numpy()
return video
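# Example launch (invocation is illustrative; the config path matches the
# default below):
#   accelerate launch train_stage_2.py --config ./configs/training/stage2.yaml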
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default="./configs/training/stage2.yaml")
args = parser.parse_args()
    if args.config.endswith(".yaml"):
        config = OmegaConf.load(args.config)
    elif args.config.endswith(".py"):
        config = import_filename(args.config).cfg
    else:
        raise ValueError("Unsupported config file format: expected .yaml or .py")
main(config)