minimax-remover / test_minimax_remover.py
zibojia's picture
Upload 3 files
0841861 verified
raw
history blame
2.27 kB
import torch
import argparse
import numpy as np
import random
import torch.nn.functional as F
from PIL import Image
import torch.distributed as dist
from diffusers.utils import export_to_video
from omegaconf import OmegaConf
from einops import rearrange
from decord import VideoReader
from diffusers.models import AutoencoderKLWan
import scipy
from transformer_minimax_remover import Transformer3DModel
from einops import rearrange
from diffusers.schedulers import UniPCMultistepScheduler
from pipeline_minimax_remover import Minimax_Remover_Pipeline
# Seed used for the torch.Generator created in inference().
random_seed = 42
# Number of frames read from each input video and requested from the pipeline.
video_length = 81

device = torch.device("cuda:0")

# Load the three pipeline components from local checkpoint folders;
# the VAE and the transformer are loaded in float16.
vae = AutoencoderKLWan.from_pretrained(
    "./vae",
    torch_dtype=torch.float16,
)
transformer = Transformer3DModel.from_pretrained(
    "./transformer",
    torch_dtype=torch.float16,
)
scheduler = UniPCMultistepScheduler.from_pretrained("./scheduler")

# Assemble the pipeline and move it to the first CUDA device.
pipe = Minimax_Remover_Pipeline(
    transformer=transformer,
    vae=vae,
    scheduler=scheduler,
)
pipe.to("cuda:0")
def inference(pixel_values, masks, iterations=6, output_path="./output.mp4"):
    """Run the remover pipeline on one clip and export the result as a video.

    Args:
        pixel_values: Input frames, forwarded to the pipeline as ``images``.
        masks: Masks matching the frames, forwarded to the pipeline as
            ``masks``.
        iterations: Forwarded unchanged to the pipeline's ``iterations``
            argument.
        output_path: File the resulting video is written to. Defaults to
            ``"./output.mp4"``, the location that was previously hard-coded,
            so existing callers keep their behaviour.
    """
    # Generator seeded with the module-level random_seed and handed to the
    # pipeline, so the seed is fixed from run to run.
    generator = torch.Generator(device="cuda").manual_seed(random_seed)

    result = pipe(
        images=pixel_values,
        masks=masks,
        num_frames=video_length,
        height=480,
        width=832,
        num_inference_steps=12,
        generator=generator,
        iterations=iterations,
    )

    # Only the first entry of the pipeline output's ``frames`` is exported.
    video = result.frames[0]
    export_to_video(video, output_path)
def load_video(video_path):
    """Read the first ``video_length`` frames of a video as a scaled tensor.

    Args:
        video_path: Path of the video file opened with ``VideoReader``.

    Returns:
        A torch tensor holding the decoded frames divided by 127.5 and
        shifted by -1.0 (so 0..255 pixel values land in -1..1).
    """
    reader = VideoReader(video_path)
    # Frame indices 0 .. video_length - 1 (module-level frame count).
    frame_ids = list(range(video_length))
    frames = reader.get_batch(frame_ids).asnumpy()
    return torch.from_numpy(frames) / 127.5 - 1.0
def load_mask(mask_path):
    """Read the first ``video_length`` mask frames and binarise them.

    Args:
        mask_path: Path of the mask video opened with ``VideoReader``.

    Returns:
        A floating-point torch tensor whose last axis has size 1, holding
        1.0 where the source value is above 20 and 0.0 everywhere else.
    """
    reader = VideoReader(mask_path)
    # Frame indices 0 .. video_length - 1 (module-level frame count).
    frame_ids = list(range(video_length))
    raw = torch.from_numpy(reader.get_batch(frame_ids).asnumpy())

    # Keep only the first entry of the last axis
    # (presumably the first colour channel - TODO confirm).
    first_channel = raw[:, :, :, :1]

    # Values above 20 count as "masked". A single comparison yields the same
    # 0/1 result as setting them to 255, zeroing the rest and dividing by
    # 255.0. Assumes the decoded frames are an integer dtype, so the original
    # division promoted to the default float dtype.
    return (first_channel > 20).to(torch.get_default_dtype())
# Input clip and its mask video; both paths must refer to the same sample.
video_path = "../pexels_export/fast/3352673-hd_1280_720_30fps.mp4"
mask_path = "../pexels_export/fast/3352673-hd_1280_720_30fps_mask.mp4"

# Alternative sample (swap both lines in together to use it):
#   video_path = "../pexels_export/height/5720258-hd_1080_1920_24fps.mp4"
#   mask_path = "../pexels_export/height/5720258-hd_1080_1920_24fps_mask.mp4"

# Load frames and masks, then run the pipeline with the default iterations.
images = load_video(video_path)
masks = load_mask(mask_path)
inference(images, masks)