import spaces
import gradio as gr
import time
import requests
import random
import gc
from types import SimpleNamespace

import torch
import flash_attn
from huggingface_hub import hf_hub_download

from imaginaire.lazy_config import LazyCall as L, LazyDict, instantiate
from wan2pt1_t2v_rcm_infer import inference, prepare_models, load_dit_model, WanModel

print("flash_attn version:", flash_attn.__version__)

# DiT configurations for the two Wan2.1 T2V model sizes.
WAN2PT1_1PT3B_T2V: LazyDict = L(WanModel)(
    dim=1536,
    eps=1e-06,
    ffn_dim=8960,
    freq_dim=256,
    in_dim=16,
    model_type="t2v",
    num_heads=12,
    num_layers=30,
    out_dim=16,
    text_len=512,
)
WAN2PT1_14B_T2V: LazyDict = L(WanModel)(
    dim=5120,
    eps=1e-06,
    ffn_dim=13824,
    freq_dim=256,
    in_dim=16,
    model_type="t2v",
    num_heads=40,
    num_layers=40,
    out_dim=16,
    text_len=512,
)
dit_configs = {"1.3B": WAN2PT1_1PT3B_T2V, "14B": WAN2PT1_14B_T2V}

# Download the rCM-distilled DiT checkpoint, the Wan2.1 VAE, and the UMT5 text encoder.
dit_path_14B_720p = hf_hub_download(
    repo_id="worstcoder/rcm-Wan",
    filename="rCM_Wan2.1_T2V_14B_720p.pt",
)
vae_path = hf_hub_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    filename="Wan2.1_VAE.pth",
)
text_encoder_path = hf_hub_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    filename="models_t5_umt5-xxl-enc-bf16.pth",
)

# Load the networks once at startup so every request reuses them.
net_14B_720p, tokenizer, t5_encoder = prepare_models(dit_path_14B_720p, vae_path, text_encoder_path)
print("Loaded models")
gc.collect()


def random_seed():
    return random.randint(0, 2**32 - 1)


@spaces.GPU(duration=360)
def generate_videos(prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed):
    if seed is None:
        seed = random.randint(0, 2**32 - 1)
    resolution = "480p" if "480p" in model_size else "720p"
    args = SimpleNamespace(
        prompt=prompt,
        model_size=model_size,
        num_steps=num_steps,
        num_samples=num_samples,
        sigma_max=sigma_max,
        num_frames=77,
        resolution=resolution,
        aspect_ratio=aspect_ratio,
        seed=seed,
    )
    with torch.no_grad():
        video_list = inference(args, net_14B_720p, tokenizer, t5_encoder)
    # Route the result to the output component that matches the chosen aspect ratio.
    if aspect_ratio == "16:9":
        return video_list, None
    else:
        return None, video_list


def update_num_samples(model_choice):
    # Offer more samples per request for the smaller / lower-resolution variants.
    if model_choice == "rCM-Wan2.1-T2V-1.3B-480p":
        options = [1, 2, 3, 4]
    elif model_choice == "rCM-Wan2.1-T2V-14B-480p":
        options = [1, 2, 3]
    else:
        options = [1, 2, 3]
    return gr.Dropdown(choices=options, value=options[0], label="num_samples")


def update_sigma_max(model_choice):
    if "480p" in model_choice:
        options = [80, 120, 200, 400, 800, 1600]
    else:
        options = [120, 200, 400, 800, 1600]
    return gr.Dropdown(choices=options, value=options[0], label="sigma_max")


with gr.Blocks() as demo:
    gr.Markdown("## rCM model for Wan")

    examples = [
        ["A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about."],
        ["A close-up shot captures a steaming hot pot brimming with vegetables and dumplings, set on a rustic wooden table. The camera focuses on the bubbling broth as a woman, dressed in a light, patterned blouse, reaches in with chopsticks to lift a tender leaf of cabbage from the simmering mixture. Steam rises around her as she leans back slightly, her warm smile reflecting satisfaction and joy. Her movements are smooth and deliberate, showcasing her comfort and familiarity with the dining process. The background includes a small bowl of dipping sauce and a clay pot, adding to the cozy, communal dining atmosphere."],
        ["A dynamic time-lapse video showing the rapidly moving scenery from the window of a speeding train. The camera captures various elements such as lush green fields, towering trees, quaint countryside houses, and distant mountain ranges passing by quickly. The train window frames the view, adding a sense of speed and motion as the landscape rushes past. The camera remains static but emphasizes the fast-paced movement outside. The overall atmosphere is serene yet exhilarating, capturing the essence of travel and exploration. Medium shot focusing on the train window and the rushing scenery beyond."],
    ]

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Row():
                prompt = gr.Textbox(label="Text prompt", placeholder="Text prompt for videos")
                model_size = gr.Radio(
                    ["rCM-Wan2.1-T2V-14B-720p"],
                    value="rCM-Wan2.1-T2V-14B-720p",
                    label="Model",
                )
            with gr.Row():
                num_samples = gr.Dropdown([1, 2], value=1, label="num_samples")
                aspect_ratio = gr.Radio(["16:9", "9:16"], value="16:9", label="aspect_ratio")
                sigma_max = gr.Dropdown([120, 200, 400, 800, 1600], value=120, label="sigma_max")
            with gr.Row():
                num_steps = gr.Slider(1, 4, value=4, step=1, label="num_steps")
                seed = gr.Number(label="seed", value=random_seed(), interactive=True)
            with gr.Row():
                regenerate_btn = gr.Button("New Seed")
                run_btn = gr.Button("Generate Videos")
            with gr.Row():
                gr.Examples(examples, inputs=[prompt], label="Example prompts")
        with gr.Column(scale=1):
            video_16_9 = gr.Video(label="Videos 16:9", width=832)
            video_9_16 = gr.Video(label="Videos 9:16", width=480, visible=False)

    def show_video(aspect):
        # Show only the video component that matches the selected aspect ratio.
        if aspect == "16:9":
            return gr.update(visible=True), gr.update(visible=False, value=None)
        else:
            return gr.update(visible=False, value=None), gr.update(visible=True)

    model_size.change(fn=update_num_samples, inputs=model_size, outputs=num_samples)
    model_size.change(fn=update_sigma_max, inputs=model_size, outputs=sigma_max)
    aspect_ratio.change(show_video, inputs=aspect_ratio, outputs=[video_16_9, video_9_16])
    regenerate_btn.click(fn=random_seed, outputs=seed)
    run_btn.click(
        fn=generate_videos,
        inputs=[prompt, model_size, num_samples, aspect_ratio, sigma_max, num_steps, seed],
        outputs=[video_16_9, video_9_16],
    )

demo.launch()
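
# ------------------------------------------------------------------------------
# Illustrative headless call (a sketch, not part of the app). It is left
# commented out because `demo.launch()` above blocks the main thread. The call
# mirrors how the "Generate Videos" button invokes `generate_videos`; the
# concrete return type (a file path vs. a list of paths) depends on
# `inference` in wan2pt1_t2v_rcm_infer and is assumed here.
#
# video_16_9_out, _ = generate_videos(
#     prompt="A stylish woman walks down a neon-lit Tokyo street at night.",
#     model_size="rCM-Wan2.1-T2V-14B-720p",
#     num_samples=1,
#     aspect_ratio="16:9",
#     sigma_max=120,
#     num_steps=4,
#     seed=42,
# )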