"""Run Pose2VideoPipeline inference for every (model video, cloth image) pair in a config.

For each model video listed under ``model_video_paths``, the sibling ``agnostic``,
``agnostic_mask`` and ``densepose`` videos are read alongside it; for each cloth image
under ``cloth_image_paths``, the sibling ``cloth_mask`` image is read. The input video and
the generated video are concatenated and written with ``save_videos_grid(n_rows=2)`` under
``output/<YYYYmmdd>/<HHMM>--seed_<seed>-<W>x<H>/``.
"""
import argparse
import os
import sys
from datetime import datetime
from pathlib import Path

import torch
from diffusers import AutoencoderKL, DDIMScheduler
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms
from transformers import CLIPVisionModelWithProjection

from src.models.pose_guider import PoseGuider
from src.models.unet_2d_condition import UNet2DConditionModel
from src.models.unet_3d import UNet3DConditionModel
from src.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
from src.utils.util import get_fps, read_frames, save_videos_grid


def parse_args():
    """Parse the command-line options of the inference script."""
    parser = argparse.ArgumentParser()
    # NOTE(review): environment-specific absolute default; pass --config explicitly elsewhere.
    parser.add_argument(
        "--config",
        type=str,
        default="/mnt/lpai-dione/ssai/cvg/team/wjj/ViViD/configs/prompts/valid.yaml",
    )
    parser.add_argument("-W", type=int, default=384, help="frame width passed to Resize and the pipeline")
    parser.add_argument("-H", type=int, default=512, help="frame height passed to Resize and the pipeline")
    parser.add_argument(
        "-L",
        type=int,
        default=24,
        help="clip length (currently unused: all frames of each input video are processed)",
    )
    parser.add_argument("--seed", type=int, default=42, help="seed, overridden by the config's 'seed' key")
    parser.add_argument("--cfg", type=float, default=3.5, help="guidance scale passed to the pipeline")
    parser.add_argument("--steps", type=int, default=20, help="number of steps passed to the pipeline")
    parser.add_argument("--fps", type=int, help="output fps (defaults to the source video's fps)")
    return parser.parse_args()


def _swap_dir(path, old, new):
    """Return ``path`` with its innermost directory component equal to ``old`` renamed to ``new``.

    Only a whole directory component is renamed, so file names (or unrelated parent
    directories) that merely contain ``old`` are left untouched, e.g.
    ``data/cloth/cloth_01.jpg`` -> ``data/cloth_mask/cloth_01.jpg``.
    If no directory component equals ``old``, fall back to plain substring replacement.
    """
    parts = list(Path(path).parts)
    # parts[-1] is the file name; scan directories from the innermost outwards.
    for idx in range(len(parts) - 2, -1, -1):
        if parts[idx] == old:
            parts[idx] = new
            return str(Path(*parts))
    return path.replace(old, new)


def main():
    """Build the pipeline from the config, then generate and save one video per (video, cloth) pair."""
    args = parse_args()
    config = OmegaConf.load(args.config)

    # Any value other than "fp16" falls back to full precision.
    weight_dtype = torch.float16 if config.weight_dtype == "fp16" else torch.float32

    # ---- construct models ----
    vae = AutoencoderKL.from_pretrained(
        config.pretrained_vae_path,
    ).to("cuda", dtype=weight_dtype)

    reference_unet = UNet2DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        subfolder="unet",
        # 5 input channels; presumably latent channels plus an extra mask channel — TODO confirm.
        unet_additional_kwargs={
            "in_channels": 5,
        },
    ).to(dtype=weight_dtype, device="cuda")

    infer_config = OmegaConf.load(config.inference_config)
    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        config.pretrained_base_model_path,
        config.motion_module_path,
        subfolder="unet",
        unet_additional_kwargs=infer_config.unet_additional_kwargs,
    ).to(dtype=weight_dtype, device="cuda")

    pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
        dtype=weight_dtype, device="cuda"
    )

    image_enc = CLIPVisionModelWithProjection.from_pretrained(
        config.image_encoder_path
    ).to(dtype=weight_dtype, device="cuda")

    sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
    scheduler = DDIMScheduler(**sched_kwargs)

    # A "seed" key in the config takes precedence over --seed. torch.manual_seed returns the
    # global generator, which is shared by every pipeline call below.
    seed = config.get("seed", args.seed)
    generator = torch.manual_seed(seed)

    width, height = args.W, args.H
    steps = args.steps
    guidance_scale = args.cfg

    # ---- load trained weights ----
    # NOTE(review): torch.load unpickles arbitrary objects; only point these paths at trusted files.
    denoising_unet.load_state_dict(
        torch.load(config.denoising_unet_path, map_location="cpu"),
        strict=False,  # missing/unexpected keys are tolerated for this module only
    )
    reference_unet.load_state_dict(
        torch.load(config.reference_unet_path, map_location="cpu"),
    )
    pose_guider.load_state_dict(
        torch.load(config.pose_guider_path, map_location="cpu"),
    )

    pipe = Pose2VideoPipeline(
        vae=vae,
        image_encoder=image_enc,
        reference_unet=reference_unet,
        denoising_unet=denoising_unet,
        pose_guider=pose_guider,
        scheduler=scheduler,
    )
    pipe = pipe.to("cuda", dtype=weight_dtype)

    # ---- output directory ----
    # Take one timestamp so the date folder and the time prefix can never disagree.
    now = datetime.now()
    date_str = now.strftime("%Y%m%d")
    time_str = now.strftime("%H%M")
    save_dir_name = f"{time_str}--seed_{seed}-{args.W}x{args.H}"
    save_dir = Path(f"output/{date_str}/{save_dir_name}")
    save_dir.mkdir(exist_ok=True, parents=True)

    model_video_paths = config.model_video_paths
    cloth_image_paths = config.cloth_image_paths

    transform = transforms.Compose(
        [transforms.Resize((height, width)), transforms.ToTensor()]
    )

    for model_video_path in model_video_paths:
        src_fps = get_fps(model_video_path)
        model_name = Path(model_video_path).stem

        # Conditioning videos live in sibling directories with the same file name,
        # e.g. data/videos/upper1.mp4 -> data/agnostic/upper1.mp4.
        video_frames = read_frames(model_video_path)
        agnostic_frames = read_frames(_swap_dir(model_video_path, "videos", "agnostic"))
        agn_mask_frames = read_frames(_swap_dir(model_video_path, "videos", "agnostic_mask"))
        pose_frames = read_frames(_swap_dir(model_video_path, "videos", "densepose"))

        # Process the whole input video (the -L / config "L" value is not applied), but never
        # more frames than the shortest conditioning stream provides, so all inputs stay aligned.
        clip_length = min(
            len(video_frames), len(agnostic_frames), len(agn_mask_frames), len(pose_frames)
        )
        if clip_length == 0:
            print(f"Skipping {model_video_path}: it or one of its conditioning videos has no frames.")
            continue
        if clip_length < len(video_frames):
            print(
                f"Warning: truncating {model_video_path} from {len(video_frames)} to "
                f"{clip_length} frames to match its conditioning videos."
            )

        # (f, c, h, w) -> transpose -> (c, f, h, w) -> unsqueeze -> (1, c, f, h, w)
        video_tensor = (
            torch.stack([transform(frame) for frame in video_frames[:clip_length]], dim=0)
            .transpose(0, 1)
            .unsqueeze(0)
        )

        agnostic_list = list(agnostic_frames[:clip_length])
        agn_mask_list = list(agn_mask_frames[:clip_length])
        pose_list = list(pose_frames[:clip_length])

        for cloth_image_path in cloth_image_paths:
            cloth_name = Path(cloth_image_path).stem
            cloth_image_pil = Image.open(cloth_image_path).convert("RGB")
            # e.g. data/cloth/upper1.jpg -> data/cloth_mask/upper1.jpg
            cloth_mask_path = _swap_dir(cloth_image_path, "cloth", "cloth_mask")
            cloth_mask_pil = Image.open(cloth_mask_path).convert("RGB")

            pipeline_output = pipe(
                agnostic_list,
                agn_mask_list,
                cloth_image_pil,
                cloth_mask_pil,
                pose_list,
                width,
                height,
                clip_length,
                steps,
                guidance_scale,
                generator=generator,
            )
            # Concatenate input and generated videos along dim 0 (the axis added by
            # unsqueeze above) so save_videos_grid lays them out as two rows.
            video = torch.cat([video_tensor, pipeline_output.videos], dim=0)
            save_videos_grid(
                video,
                f"{save_dir}/{model_name}_{cloth_name}_{args.H}x{args.W}_{int(guidance_scale)}_{time_str}.mp4",
                n_rows=2,
                fps=src_fps if args.fps is None else args.fps,
            )


if __name__ == "__main__":
    main()