import os
from pathlib import Path
import gradio as gr
import spaces
import torch
import smplx
import numpy as np
from website import CREDITS, WEB_source, WEB_target, WEBSITE
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings
import random

# DO NOT initialize CUDA here: it must only happen inside the GPU-decorated
# processing path (see MotionEditor.initialize_if_needed).
DEFAULT_TEXT = "do it slower"

os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ['LD_LIBRARY_PATH'] = '/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/nvidia/current:' + os.environ.get('LD_LIBRARY_PATH', '')
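# Rendering setup note: PYOPENGL_PLATFORM='egl' makes the offscreen renderer create
# its OpenGL context through EGL, so no X display is needed on the headless Space;
# the extra LD_LIBRARY_PATH entries point the loader at the NVIDIA EGL driver libraries.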
# Optional debugging
import subprocess
try:
    result = subprocess.run(['ldconfig', '-p'], capture_output=True, text=True)
    egl_libs = [line for line in result.stdout.split('\n') if 'EGL' in line]
    print("Available EGL libraries:", egl_libs)
except Exception as e:
    print(f"Error finding libraries: {e}")
# Example videos
example_videos = [
    "./examples/1919.mp4",
    "./examples/5376.mp4",
    "./examples/1259.mp4",
    "./examples/3686.mp4",
    "./examples/1289.mp4",
    "./examples/1893.mp4",
    "./examples/3262.mp4",
    "./examples/6117.mp4",
    "./examples/1031.mp4",
    "./examples/6247.mp4",
]

# Dataset keys for the example videos
example_keys = [
    "001919",
    "005376",
    "001259",
    "003686",
    "001289",
    "001893",
    "003262",
    "006117",
    "001031",
    "006247",
]

# Suggested edit texts for the example videos
example_texts = [
    "mirror",
    "move in a smaller circle",
    "less deep",
    "turn back faster",
    "cross your legs",
    "step to the right",
    "start sitting down a bit later",
    "start a bit later, hold elbow lower at the end",
    "extend the arm further back and catch higher",
    "hold right arm higher",
]
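# Note: example_videos, example_keys, and example_texts are aligned by index;
# the "Select Ex. N" buttons below pass the N-th entry of each list together.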
example_video_outputs = [gr.Video(label=f"Example {i+1}",
                                  value=example_videos[i])
                         for i in range(4)]


class MotionEditor:
    def __init__(self):
        # Don't initialize any CUDA components in __init__
        self.is_initialized = False
        self.MFIX_p = download_motionfix() + '/motionfix'
        # self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
        self.MFIX_DATASET_DICT = download_motionfix_dataset()
        self.model_ckpt_path = download_models("899_bs128_zipped")  # small_model_zipped_last/last_zipped
        self.model_cfg = download_model_config('bs_128_conf')  # small_model_config / big_model_config
        self.model_config_feats = self.model_cfg.model.input_feats

    def initialize_if_needed(self):
        """Initialize models only when needed, within a GPU-decorated function"""
        if self.is_initialized:
            return
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        print(f"Current CUDA device: {torch.cuda.current_device()}")
        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")
        # Check total and available memory
        total_memory = torch.cuda.get_device_properties(0).total_memory
        reserved_memory = torch.cuda.memory_reserved(0)
        allocated_memory = torch.cuda.memory_allocated(0)
        print(f"Total GPU Memory: {total_memory / 1e9} GB")
        print(f"Reserved Memory: {reserved_memory / 1e9} GB")
        print(f"Allocated Memory: {allocated_memory / 1e9} GB")

        from normalization import Normalizer
        from diffusion import create_diffusion
        from text_encoder import ClipTextEncoder
        from tmed_denoiser import TMED_denoiser

        # Initialize components
        self.device = torch.device('cuda')
        self.normalizer = Normalizer()
        self.text_encoder = ClipTextEncoder()

        # Load models and configs
        model_ckpt = self.model_ckpt_path
        self.infeats = self.model_config_feats
        checkpoint = torch.load(model_ckpt, map_location=self.device)
        checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}

        # Setup denoiser
        self.tmed_denoiser = TMED_denoiser(latent_dim=self.model_cfg.model.latent_dim,
                                           num_layers=8,
                                           ff_size=1024,
                                           num_heads=4).to(self.device)
        self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
        self.tmed_denoiser.eval()

        # Setup diffusion
        self.diffusion = create_diffusion(
            timestep_respacing=None,
            learn_sigma=False,
            sigma_small=True,
            diffusion_steps=self.model_cfg.model.diff_params.num_train_timesteps,
            noise_schedule='squaredcos_cap_v2',
            predict_xstart=True
        )
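        # Note (assuming the local `diffusion` module follows the standard
        # guided-diffusion-style create_diffusion conventions, which these
        # keyword names suggest): predict_xstart=True means the denoiser
        # predicts the clean sample x0 rather than the added noise,
        # learn_sigma=False keeps the output variance fixed, and
        # 'squaredcos_cap_v2' is the cosine noise schedule.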
        # Setup SMPL model
        smpl_models_path = str(Path(get_smpl_models()))
        self.body_model = smplx.SMPLHLayer(
            f"{smpl_models_path}/smplh",
            model_type='smplh',
            gender='neutral',
            ext='npz'
        )
        self.is_initialized = True
    # Acquire a GPU for the duration of this call (HF Spaces `spaces` package);
    # all CUDA work stays inside this path.
    @spaces.GPU
    def process_motion(self, input_text, key_to_use):
        """Main processing function, GPU-decorated"""
        self.initialize_if_needed()
        # import ipdb; ipdb.set_trace()
        # Load dataset sample
        ds_sample = self.MFIX_DATASET_DICT[key_to_use]

        # Process features
        data_dict = self.process_features(ds_sample)
        source_motion_norm, target_motion_norm = self.normalize_motions(data_dict)
        source_motion = self.denormalize_motion(source_motion_norm)

        # Generate edited motion
        edited_motion = self.generate_edited_motion(
            input_text,
            source_motion_norm,
            target_motion_norm
        )

        # Render result
        return self.render_result(edited_motion, source_motion)
    def process_features(self, ds_sample):
        """Process features - called from within GPU-decorated function"""
        from feature_extractor import FEAT_GET_METHODS
        data_dict = {}
        for feat in self.infeats:
            data_dict[f'{feat}_source'] = FEAT_GET_METHODS[feat](
                ds_sample['motion_source']
            )[None].to(self.device)
            data_dict[f'{feat}_target'] = FEAT_GET_METHODS[feat](
                ds_sample['motion_target']
            )[None].to(self.device)
        return data_dict

    def normalize_motions(self, data_dict):
        """Normalize motions - called from within GPU-decorated function"""
        batch = self.normalizer.norm_and_cat(data_dict, self.infeats)
        return batch['source'], batch['target']

    def generate_edited_motion(self, input_text, source_motion, target_motion):
        """Generate edited motion - called from within GPU-decorated function"""
        # Encode text
        texts_cond = [''] * 2 + [input_text]
        text_emb, text_mask = self.text_encoder(texts_cond)
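        # The three prompts (two empty strings plus the edit text) presumably
        # correspond to the unconditional, motion-only, and fully conditioned
        # branches of two-level classifier-free guidance; gd_text / gd_motion
        # below are the respective guidance scales.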
        # Setup masks
        bsz = 1
        seqlen_src = source_motion.shape[0]
        seqlen_tgt = target_motion.shape[0]
        cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device=self.device)
        mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device=self.device)

        # Generate diffusion output
        diff_out = self.tmed_denoiser._diffusion_reverse(
            text_emb.to(self.device),
            text_mask.to(self.device),
            source_motion,
            cond_motion_mask,
            mask_target,
            self.diffusion,
            init_vec=None,
            init_from='noise',
            gd_text=2.0,
            gd_motion=3.0,
            steps_num=self.model_cfg.model.diff_params.num_train_timesteps
        )
        return self.denormalize_motion(diff_out)
    def denormalize_motion(self, diff_out):
        """Denormalize motion - called from within GPU-decorated function"""
        from geometry_utils import diffout2motion
        # import ipdb; ipdb.set_trace()
        return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()

    def render_result(self, edited_motion, source_motion):
        """Render result - called from within GPU-decorated function"""
        from body_renderer import get_render
        from transform3d import transform_body_pose, rotate_body_degrees

        # Transform motions
        edited_motion_transformed = self.transform_motion(edited_motion)
        source_motion_transformed = self.transform_motion(source_motion)

        # Render video
        if os.path.exists('./output_movie.mp4'):
            os.remove('./output_movie.mp4')
        # import ipdb; ipdb.set_trace()
        return get_render(
            self.body_model,
            [edited_motion_transformed['trans'].detach().cpu(),
             source_motion_transformed['trans'].detach().cpu()],
            [edited_motion_transformed['rots_init'].detach().cpu(),
             source_motion_transformed['rots_init'].detach().cpu()],
            [edited_motion_transformed['rots_rest'].detach().cpu(),
             source_motion_transformed['rots_rest'].detach().cpu()],
            output_path='./output_movie.mp4',
            text='',
            colors=['sky blue', 'red']
        )

    def transform_motion(self, motion):
        """Transform motion - called from within GPU-decorated function"""
        from transform3d import transform_body_pose, rotate_body_degrees
        motion_aa = transform_body_pose(motion[:, 3:], '6d->aa')
        trans = motion[..., :3].detach().cpu()
        rots_aa = motion_aa.detach().cpu()
        rots_rotated, trans_rotated = rotate_body_degrees(
            transform_body_pose(rots_aa, 'aa->rot'),
            trans,
            offset=np.pi
        )
        rots_rotated_aa = transform_body_pose(rots_rotated, 'rot->aa')
        return {
            'trans': trans_rotated,
            'rots_init': rots_rotated_aa[:, 0],
            'rots_rest': rots_rotated_aa[:, 1:]
        }
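
# Usage sketch (not executed here): the Gradio callbacks below drive the editor
# roughly like this, with "001919" taken from example_keys and DEFAULT_TEXT as
# an example edit instruction:
#     editor = MotionEditor()
#     video_path = editor.process_motion(DEFAULT_TEXT, "001919")
#     # renders './output_movie.mp4' with the edited body in sky blue and the
#     # source body in red (see render_result above)
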
# Gradio Interface
def create_gradio_interface():
    editor = MotionEditor()

    def process_and_show_video(input_text, random_key_state):
        return editor.process_motion(input_text, random_key_state)

    def random_source_motion(set_to_pick):
        from dataset_utils import load_motionfix
        mfix_train, mfix_test = load_motionfix(editor.MFIX_p)
        current_set = {
            'all': mfix_test | mfix_train,
            'train': mfix_train,
            'test': mfix_test
        }[set_to_pick]
        random_key = random.choice(list(current_set.keys()))
        motion = current_set[random_key]['motion_a']
        text_annot = current_set[random_key]['annotation']
        # should add one more text_annot
        return gr.update(value=motion,
                         visible=True), random_key, text_annot

    def clear():
        return ""

    # Gradio UI
    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.HTML(WEBSITE)
        random_key_state = gr.State()
        with gr.Row():
            with gr.Column(scale=5):
                gr.HTML(WEB_source)
                with gr.Row():
                    random_button = gr.Button("Random", scale=0)
                    # clear_button_retrieval = gr.Button("Clear", scale=0)
                # Example videos grid with buttons
                # suggested_edit_text = gr.Textbox(
                #     placeholder="Texts likely to edit the motion:",
                #     label="Suggested Edit Text",
                #     value=''
                # )
                set_to_pick = gr.Radio(
                    ['all', 'train', 'test'],
                    value='all',
                    label="Set to pick from"
                )
                retrieved_video_output = gr.Video(
                    label="Retrieved Motion",
                    height=360,
                    width=480,
                    visible=False  # Initially hidden
                )
                gr.HTML("""<div class="embed_hidden" style="text-align: center;">
                        <h1>Examples</h1></div>""")
                with gr.Row():
                    # First example
                    with gr.Column():
                        gr.Video(value=example_videos[0],
                                 height=180, width=240,
                                 label="Example 1")
                        example_button1 = gr.Button("Select Ex. 1",
                                                    elem_classes=["fit-text"])
                    # Second example
                    with gr.Column():
                        gr.Video(value=example_videos[1],
                                 height=180, width=240,
                                 label="Example 2")
                        example_button2 = gr.Button("Select Ex. 2",
                                                    elem_classes=["fit-text"])
                with gr.Row():
                    # Third example
                    with gr.Column():
                        gr.Video(value=example_videos[2],
                                 height=180, width=240,
                                 label="Example 3")
                        example_button3 = gr.Button("Select Ex. 3",
                                                    elem_classes=["fit-text"])
                    # Fourth example
                    with gr.Column():
                        gr.Video(value=example_videos[3],
                                 height=180, width=240,
                                 label="Example 4")
                        example_button4 = gr.Button("Select Ex. 4",
                                                    elem_classes=["fit-text"])
            with gr.Column(scale=5):
                gr.HTML(WEB_target)
                with gr.Row():
                    clear_button_edit = gr.Button("Clear", scale=0)
                    edit_button = gr.Button("Edit", scale=0)
                input_text = gr.Textbox(
                    placeholder="Type the edit text you want:",
                    label="Input Text",
                    value=DEFAULT_TEXT
                )
                video_output = gr.Video(
                    label="Generated Video",
                    height=360,
                    width=480
                )
        # Event handlers
        edit_button.click(
            process_and_show_video,
            inputs=[input_text, random_key_state],
            outputs=video_output
        )

        random_button.click(
            random_source_motion,
            inputs=set_to_pick,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )

        # def load_example_video(example_path):
        #     # motion = current_set[random_key]['motion_a']
        #     # text_annot = current_set[random_key]['annotation']
        #     import ipdb; ipdb.set_trace()
        #     return gr.update(value=example_path, visible=True)

        def load_example(example_video, example_key, example_text):
            # Update all outputs
            return (
                gr.update(value=example_video, visible=True),  # Update video output
                # example_text,  # Update suggested edit text
                example_key,  # Update random key state
                example_text  # Update input text
            )

        example_button1.click(
            fn=lambda: load_example(example_videos[0], example_keys[0], example_texts[0]),
            inputs=None,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )
        example_button2.click(
            fn=lambda: load_example(example_videos[1], example_keys[1], example_texts[1]),
            inputs=None,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )
        example_button3.click(
            fn=lambda: load_example(example_videos[2], example_keys[2], example_texts[2]),
            inputs=None,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )
        example_button4.click(
            fn=lambda: load_example(example_videos[3], example_keys[3], example_texts[3]),
            inputs=None,
            outputs=[
                retrieved_video_output,
                # suggested_edit_text,
                random_key_state,
                input_text
            ]
        )

        clear_button_edit.click(clear, outputs=input_text)
        # clear_button_retrieval.click(clear, outputs=suggested_edit_text)

        gr.Markdown(CREDITS)

    return demo
# Constants
CUSTOM_CSS = """
.gradio-row { display: flex; gap: 20px; }
.gradio-column { flex: 1; }
.gradio-container { display: flex; flex-direction: column; gap: 10px; }
.gradio-button-row { display: flex; gap: 10px; }
.gradio-textbox-row { display: flex; gap: 10px; align-items: center; }
.gradio-edit-row { gap: 10px; align-items: center; }
.gradio-textbox-with-button { display: flex; align-items: center; }
.gradio-textbox-with-button input { flex-grow: 1; }
button.fit-text {
    width: auto;            /* Automatically adjusts to the text length */
    padding: 10px 20px;     /* Adjust padding for a better look */
    font-size: 12px;        /* Control font size */
    text-align: center;     /* Center the text */
    margin: 0 auto;         /* Center the button horizontally */
    display: inline-block;  /* Prevent it from stretching */
}
"""

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(share=True)