import os
import random
import subprocess
from pathlib import Path

import gradio as gr
import spaces
import torch
import smplx
import numpy as np

from website import CREDITS, WEB_source, WEB_target, WEBSITE
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings

# DO NOT initialize CUDA here
DEFAULT_TEXT = "do it slower"

os.environ['PYOPENGL_PLATFORM'] = 'egl'
os.environ['LD_LIBRARY_PATH'] = (
    '/usr/lib/x86_64-linux-gnu:/usr/lib/x86_64-linux-gnu/nvidia/current:'
    + os.environ.get('LD_LIBRARY_PATH', '')
)

# Optional debugging: list the EGL libraries visible to the dynamic linker
try:
    result = subprocess.run(['ldconfig', '-p'], capture_output=True, text=True)
    egl_libs = [line for line in result.stdout.split('\n') if 'EGL' in line]
    print("Available EGL libraries:", egl_libs)
except Exception as e:
    print(f"Error finding libraries: {e}")

# Example videos
example_videos = [
    "./examples/000652_0_120.mp4",  # Replace with actual video paths
    "./examples/000652_0_120.mp4",
    "./examples/000652_0_120.mp4",
    "./examples/000652_0_120.mp4",
]

# Example dataset keys
example_keys = [
    "000091",  # Replace with actual dataset keys
    "000091",
    "000091",
    "000091",
]

# Example edit texts
example_texts = [
    "need to use the opposite leg",  # Replace with actual edit texts
    "need to use the opposite leg2",
    "need to use the opposite leg3",
    "need to use the opposite leg4",
]

# Pre-built example video components (currently unused by the UI below)
example_video_outputs = [
    gr.Video(label=f"Example {i + 1}", value=example_videos[i]) for i in range(4)
]


class MotionEditor:
    def __init__(self):
        # Don't initialize any CUDA components in __init__
        self.is_initialized = False
        self.MFIX_p = download_motionfix() + '/motionfix'
        self.SOURCE_MOTS_p = download_embeddings() + '/embeddings'
        self.MFIX_DATASET_DICT = download_motionfix_dataset()
        self.model_ckpt_path = download_models()
        self.model_config_feats = download_model_config()

    @spaces.GPU
    def initialize_if_needed(self):
        """Initialize models only when needed, within a GPU-decorated function."""
        if self.is_initialized:
            return
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available")
        print(f"Current CUDA device: {torch.cuda.current_device()}")
        print(f"CUDA device name: {torch.cuda.get_device_name(0)}")

        # Check total and available memory
        total_memory = torch.cuda.get_device_properties(0).total_memory
        reserved_memory = torch.cuda.memory_reserved(0)
        allocated_memory = torch.cuda.memory_allocated(0)
        print(f"Total GPU Memory: {total_memory / 1e9} GB")
        print(f"Reserved Memory: {reserved_memory / 1e9} GB")
        print(f"Allocated Memory: {allocated_memory / 1e9} GB")

        # Deferred imports so CUDA-touching modules load only on the GPU worker
        from normalization import Normalizer
        from diffusion import create_diffusion
        from text_encoder import ClipTextEncoder
        from tmed_denoiser import TMED_denoiser

        # Initialize components
        self.device = torch.device('cuda')
        self.normalizer = Normalizer()
        self.text_encoder = ClipTextEncoder()

        # Load models and configs
        model_ckpt = self.model_ckpt_path
        self.infeats = self.model_config_feats
        checkpoint = torch.load(model_ckpt, map_location=self.device)
        checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}

        # Setup denoiser
        self.tmed_denoiser = TMED_denoiser().to(self.device)
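        # Checkpoint keys were saved with a 'denoiser.' prefix (stripped above),
        # so strict=False tolerates any remaining non-matching entries.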
        self.tmed_denoiser.load_state_dict(checkpoint, strict=False)
        self.tmed_denoiser.eval()

        # Setup diffusion
        self.diffusion = create_diffusion(
            timestep_respacing=None,
            learn_sigma=False,
            sigma_small=True,
            diffusion_steps=300,
            noise_schedule='squaredcos_cap_v2',
            predict_xstart=True,
        )

        # Setup SMPL body model
        smpl_models_path = str(Path(get_smpl_models()))
        self.body_model = smplx.SMPLHLayer(
            f"{smpl_models_path}/smplh",
            model_type='smplh',
            gender='neutral',
            ext='npz',
        )
        self.is_initialized = True

    @spaces.GPU(duration=360)
    def process_motion(self, input_text, key_to_use):
        """Main processing function, GPU-decorated."""
        self.initialize_if_needed()

        # Load dataset sample
        ds_sample = self.MFIX_DATASET_DICT[key_to_use]

        # Process features
        data_dict = self.process_features(ds_sample)
        source_motion_norm, target_motion_norm = self.normalize_motions(data_dict)
        source_motion = self.denormalize_motion(source_motion_norm)

        # Generate edited motion
        edited_motion = self.generate_edited_motion(
            input_text, source_motion_norm, target_motion_norm
        )

        # Render result
        return self.render_result(edited_motion, source_motion)

    def process_features(self, ds_sample):
        """Extract source/target features - called from within a GPU-decorated function."""
        from feature_extractor import FEAT_GET_METHODS
        data_dict = {}
        for feat in self.infeats:
            data_dict[f'{feat}_source'] = FEAT_GET_METHODS[feat](
                ds_sample['motion_source']
            )[None].to(self.device)
            data_dict[f'{feat}_target'] = FEAT_GET_METHODS[feat](
                ds_sample['motion_target']
            )[None].to(self.device)
        return data_dict

    def normalize_motions(self, data_dict):
        """Normalize motions - called from within a GPU-decorated function."""
        batch = self.normalizer.norm_and_cat(data_dict, self.infeats)
        return batch['source'], batch['target']

    def generate_edited_motion(self, input_text, source_motion, target_motion):
        """Generate edited motion - called from within a GPU-decorated function."""
        # Encode text: two empty strings (guidance branches), then the edit text
        texts_cond = [''] * 2 + [input_text]
        text_emb, text_mask = self.text_encoder(texts_cond)

        # Setup masks
        bsz = 1
        seqlen_src = source_motion.shape[0]
        seqlen_tgt = target_motion.shape[0]
        cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=torch.bool, device=self.device)
        mask_target = torch.ones((bsz, seqlen_tgt), dtype=torch.bool, device=self.device)

        # Generate diffusion output
        diff_out = self.tmed_denoiser._diffusion_reverse(
            text_emb.to(self.device),
            text_mask.to(self.device),
            source_motion,
            cond_motion_mask,
            mask_target,
            self.diffusion,
            init_vec=None,
            init_from='noise',
            gd_text=2.0,
            gd_motion=2.0,
            steps_num=300,
        )
        return self.denormalize_motion(diff_out)

    def denormalize_motion(self, diff_out):
        """Denormalize motion - called from within a GPU-decorated function."""
        from geometry_utils import diffout2motion
        return diffout2motion(diff_out.permute(1, 0, 2), self.normalizer).squeeze()

    def render_result(self, edited_motion, source_motion):
        """Render result - called from within a GPU-decorated function."""
        from body_renderer import get_render

        # Transform motions
        edited_motion_transformed = self.transform_motion(edited_motion)
        source_motion_transformed = self.transform_motion(source_motion)

        # Render video (remove any stale output first)
        if os.path.exists('./output_movie.mp4'):
            os.remove('./output_movie.mp4')

        return get_render(
            self.body_model,
            [edited_motion_transformed['trans'].detach().cpu(),
             source_motion_transformed['trans'].detach().cpu()],
            [edited_motion_transformed['rots_init'].detach().cpu(),
             source_motion_transformed['rots_init'].detach().cpu()],
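            # rots_init above is the global (root) orientation, rots_rest below the
            # remaining per-joint rotations; see transform_motion for the split.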
            [edited_motion_transformed['rots_rest'].detach().cpu(),
             source_motion_transformed['rots_rest'].detach().cpu()],
            output_path='./output_movie.mp4',
            text='',
            colors=['sky blue', 'red'],
        )

    def transform_motion(self, motion):
        """Transform motion - called from within a GPU-decorated function."""
        from transform3d import transform_body_pose, rotate_body_degrees

        # Split translation (first 3 dims) from 6D rotations, convert to axis-angle
        motion_aa = transform_body_pose(motion[:, 3:], '6d->aa')
        trans = motion[..., :3].detach().cpu()
        rots_aa = motion_aa.detach().cpu()

        # Rotate the whole body by pi radians for rendering
        rots_rotated, trans_rotated = rotate_body_degrees(
            transform_body_pose(rots_aa, 'aa->rot'),
            trans,
            offset=np.pi,
        )
        rots_rotated_aa = transform_body_pose(rots_rotated, 'rot->aa')
        return {
            'trans': trans_rotated,
            'rots_init': rots_rotated_aa[:, 0],
            'rots_rest': rots_rotated_aa[:, 1:],
        }


# Gradio Interface
def create_gradio_interface():
    editor = MotionEditor()

    @spaces.GPU
    def process_and_show_video(input_text, random_key_state):
        return editor.process_motion(input_text, random_key_state)

    def random_source_motion(set_to_pick):
        from dataset_utils import load_motionfix
        mfix_train, mfix_test = load_motionfix(editor.MFIX_p)
        current_set = {
            'all': mfix_test | mfix_train,
            'train': mfix_train,
            'test': mfix_test,
        }[set_to_pick]
        random_key = random.choice(list(current_set.keys()))
        motion = current_set[random_key]['motion_a']
        text_annot = current_set[random_key]['annotation']
        # Three return values matching the click handler's outputs:
        # retrieved video, state key, and edit-text prefill
        return gr.update(value=motion, visible=True), random_key, text_annot

    def clear():
        return ""

    # Gradio UI
    with gr.Blocks(css=CUSTOM_CSS) as demo:
        gr.HTML(WEBSITE)
        random_key_state = gr.State()

        with gr.Row():
            with gr.Column(scale=5):
                gr.HTML(WEB_source)
                with gr.Row():
                    random_button = gr.Button("Random", scale=0)
                    # clear_button_retrieval = gr.Button("Clear", scale=0)

                # suggested_edit_text = gr.Textbox(
                #     placeholder="Texts likely to edit the motion:",
                #     label="Suggested Edit Text",
                #     value=''
                # )
                set_to_pick = gr.Radio(
                    ['all', 'train', 'test'],
                    value='all',
                    label="Set to pick from",
                )
                retrieved_video_output = gr.Video(
                    label="Retrieved Motion",
                    height=360,
                    width=480,
                    visible=False,  # Initially hidden
                )

                # Example videos grid with buttons
                gr.Markdown("### Examples")
                with gr.Row():
                    # First example
                    with gr.Column():
                        gr.Video(value=example_videos[0], height=180, width=240, label="Example 1")
                        example_button1 = gr.Button("Select Example 1", size='sm', elem_classes=["fit-text"])
                    # Second example
                    with gr.Column():
                        gr.Video(value=example_videos[1], height=180, width=240, label="Example 2")
                        example_button2 = gr.Button("Select Example 2", elem_classes=["fit-text"])
                with gr.Row():
                    # Third example
                    with gr.Column():
                        gr.Video(value=example_videos[2], height=180, width=240, label="Example 3")
                        example_button3 = gr.Button("Select Example 3", elem_classes=["fit-text"])
                    # Fourth example
                    with gr.Column():
                        gr.Video(value=example_videos[3], height=180, width=240, label="Example 4")
                        example_button4 = gr.Button("Select Example 4", elem_classes=["fit-text"])

            with gr.Column(scale=5):
                gr.HTML(WEB_target)
                with gr.Row():
                    clear_button_edit = gr.Button("Clear", scale=0)
                    edit_button = gr.Button("Edit", scale=0)
                input_text = gr.Textbox(
                    placeholder="Type the edit text you want:",
                    label="Input Text",
                    value=DEFAULT_TEXT,
                )
                video_output = gr.Video(
                    label="Generated Video",
                    height=360,
                    width=480,
                )

        # Event handlers
        edit_button.click(
            process_and_show_video,
            inputs=[input_text, random_key_state],
            outputs=video_output,
        )
        random_button.click(
            random_source_motion,
            inputs=set_to_pick,
            outputs=[
                retrieved_video_output,
                random_key_state,
                input_text,
            ],
        )
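        # Selecting an example mirrors random_source_motion: show the example video,
        # store its dataset key in state, and prefill the edit text.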
        def load_example(example_video, example_key, example_text):
            # Three return values matching the click handlers' outputs below
            return (
                gr.update(value=example_video, visible=True),  # video output
                example_key,   # random key state
                example_text,  # input text
            )

        # Wire the four example buttons; default arguments bind each
        # example's values at definition time
        example_buttons = [example_button1, example_button2, example_button3, example_button4]
        for button, video, key, text in zip(example_buttons, example_videos, example_keys, example_texts):
            button.click(
                fn=lambda v=video, k=key, t=text: load_example(v, k, t),
                inputs=None,
                outputs=[
                    retrieved_video_output,
                    random_key_state,
                    input_text,
                ],
            )

        clear_button_edit.click(clear, outputs=input_text)
        # clear_button_retrieval.click(clear, outputs=suggested_edit_text)

        gr.Markdown(CREDITS)

    return demo


# Constants
CUSTOM_CSS = """
.gradio-row {
    display: flex;
    gap: 20px;
}
.gradio-column {
    flex: 1;
}
.gradio-container {
    display: flex;
    flex-direction: column;
    gap: 10px;
}
.gradio-button-row {
    display: flex;
    gap: 10px;
}
.gradio-textbox-row {
    display: flex;
    gap: 10px;
    align-items: center;
}
.gradio-edit-row {
    gap: 10px;
    align-items: center;
}
.gradio-textbox-with-button {
    display: flex;
    align-items: center;
}
.gradio-textbox-with-button input {
    flex-grow: 1;
}
button.fit-text {
    width: auto;            /* Automatically adjusts to the text length */
    padding: 10px 20px;     /* Adjust padding for a better look */
    font-size: 14px;        /* Control font size */
    text-align: center;     /* Center the text */
    margin: 0 auto;         /* Center the button horizontally */
    display: inline-block;  /* Prevent it from stretching */
}
"""

if __name__ == "__main__":
    demo = create_gradio_interface()
    demo.launch(share=True)