from geometry_utils import diffout2motion
import gradio as gr
import spaces
import torch
import os
from pathlib import Path
import smplx
from body_renderer import get_render
import numpy as np
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings
from retrieval_loader import get_tmr_model
from dataset_utils import load_motionfix
from website import CREDITS, WEB_source, WEB_target, WEBSITE
# import cv2
# import moderngl
# ctx = moderngl.create_context(standalone=True)
# print(ctx)
# sdk_version: 5.5.0

access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
os.environ["PYOPENGL_PLATFORM"] = "egl"

zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cuda:0' 🤗

DEFAULT_TEXT = "do it slower "


@spaces.GPU
def greet(n):
    print(zero.device)  # <-- 'cuda:0' 🤗
    try:
        number = float(n)
    except ValueError:
        return "Invalid input. Please enter a number."
    return f"Hello {zero + number} Tensor"


def clear():
    return ""


def show_video(input_text, key_to_use):
    from normalization import Normalizer
    normalizer = Normalizer()
    from diffusion import create_diffusion
    from text_encoder import ClipTextEncoder
    from tmed_denoiser import TMED_denoiser

    model_ckpt = download_models()
    infeats = download_model_config()
    checkpoint = torch.load(model_ckpt)
    # motion_to_edit = download_motion_from_dataset(key_to_use)
    # ds_sample = joblib.load(motion_to_edit)
    ds_sample = MFIX_DATASET_DICT[key_to_use]

    from feature_extractor import FEAT_GET_METHODS
    data_dict_source = {f'{feat}_source': FEAT_GET_METHODS[feat](ds_sample['motion_source'])[None].to('cuda')
                        for feat in infeats}
    data_dict_target = {f'{feat}_target': FEAT_GET_METHODS[feat](ds_sample['motion_target'])[None].to('cuda')
                        for feat in infeats}
    full_batch = data_dict_source | data_dict_target
    in_batch = normalizer.norm_and_cat(full_batch, infeats)
    source_motion_norm = in_batch['source']
    target_motion_norm = in_batch['target']
    seqlen_tgt = source_motion_norm.shape[0]
    seqlen_src = target_motion_norm.shape[0]

    # import ipdb; ipdb.set_trace()
    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
    tmed_denoiser = TMED_denoiser().to('cuda')
    tmed_denoiser.load_state_dict(checkpoint, strict=False)
    tmed_denoiser.eval()
    text_encoder = ClipTextEncoder()
    texts_cond = [input_text]
    diffusion_process = create_diffusion(timestep_respacing=None,
                                         learn_sigma=False, sigma_small=True,
                                         diffusion_steps=300,
                                         noise_schedule='squaredcos_cap_v2',
                                         predict_xstart=True)

    bsz = 1
    no_of_texts = len(texts_cond)
    # Empty prompts are prepended twice, presumably one unconditional copy per
    # guidance signal (text and source motion) for classifier-free guidance.
    texts_cond = [''] * no_of_texts + texts_cond
    texts_cond = [''] * no_of_texts + texts_cond
    text_emb, text_mask = text_encoder(texts_cond)

    cond_emb_motion = source_motion_norm
    cond_motion_mask = torch.ones((bsz, seqlen_src), dtype=bool, device='cuda')
    mask_target = torch.ones((bsz, seqlen_tgt), dtype=bool, device='cuda')
    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
                                                text_mask.to(cond_emb_motion.device),
                                                cond_emb_motion,
                                                cond_motion_mask,
                                                mask_target,
                                                diffusion_process,
                                                init_vec=None,
                                                init_from='noise',
                                                gd_text=2.0, gd_motion=2.0,
                                                steps_num=300)
    edited_motion = diffout2motion(diff_out.permute(1, 0, 2), normalizer).squeeze()
    gt_source = diffout2motion(source_motion_norm.permute(1, 0, 2), normalizer).squeeze()
    # import ipdb; ipdb.set_trace()
    # aitrenderer = get_renderer()
    # SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
    # edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
    #                                       trans=edited_motion[..., :3])
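    # Note: the slicing below assumes the flattened motion features start with
    # the 3D root translation, followed by per-joint rotations in 6D form
    # (hence the '6d->aa' conversions before rendering).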
    SMPL_MODELS_PATH = str(Path(get_smpl_models()))
    body_model = smplx.SMPLHLayer(f"{SMPL_MODELS_PATH}/smplh", model_type='smplh',
                                  gender='neutral', ext='npz')
    # run_smpl_fwd_vertices(body_model, body_transl, body_orient, body_pose,
    # edited_mot_to_render
    from transform3d import transform_body_pose
    # import ipdb; ipdb.set_trace()
    edited_motion_aa = transform_body_pose(edited_motion[:, 3:], '6d->aa')
    gt_source_aa = transform_body_pose(gt_source[:, 3:], '6d->aa')

    if os.path.exists('./output_movie.mp4'):
        os.remove('./output_movie.mp4')

    from transform3d import rotate_body_degrees
    gen_motion_trans = edited_motion[..., :3].detach().cpu()
    gen_motion_rots_aa = edited_motion_aa.detach().cpu()
    source_motion_trans = gt_source[..., :3].detach().cpu()
    source_motion_rots_aa = gt_source_aa.detach().cpu()

    gen_rots_rotated, gen_trans_rotated = rotate_body_degrees(
        transform_body_pose(gen_motion_rots_aa, 'aa->rot'),
        gen_motion_trans, offset=np.pi)
    src_rots_rotated, src_trans_rotated = rotate_body_degrees(
        transform_body_pose(source_motion_rots_aa, 'aa->rot'),
        source_motion_trans, offset=np.pi)
    src_rots_rotated_aa = transform_body_pose(src_rots_rotated, 'rot->aa')
    gen_rots_rotated_aa = transform_body_pose(gen_rots_rotated, 'rot->aa')

    fname = get_render(body_model,
                       [gen_trans_rotated, src_trans_rotated],
                       [gen_rots_rotated_aa[:, 0], src_rots_rotated_aa[:, 0]],
                       [gen_rots_rotated_aa[:, 1:], src_rots_rotated_aa[:, 1:]],
                       output_path='./output_movie.mp4',
                       text='',
                       colors=['sky blue', 'red'])
    # fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
    #                       f"movie_example--{str(xx)}",
    #                       pose_repr='aa',
    #                       color=[color_map['generated']],
    #                       smpl_layer=SMPL_LAYER)
    print(fname)
    print(os.path.abspath(fname))
    return fname


MFIX_p = download_motionfix() + '/motionfix'
SOURCE_MOTS_p = download_embeddings() + '/embeddings'
MFIX_DATASET_DICT = download_motionfix_dataset()
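# MFIX_DATASET_DICT maps a MotionFix sample key to its raw data; show_video()
# reads the 'motion_source' and 'motion_target' entries of each sample from it.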


def random_source_motion(set_to_pick):
    # import ipdb; ipdb.set_trace()
    mfix_train, mfix_test = load_motionfix(MFIX_p)
    if set_to_pick == 'all':
        current_set = mfix_test | mfix_train
    elif set_to_pick == 'train':
        current_set = mfix_train
    elif set_to_pick == 'test':
        current_set = mfix_test
    import random
    random_key = random.choice(list(current_set.keys()))
    curvid = current_set[random_key]['motion_a']
    text_annot = current_set[random_key]['annotation']
    return curvid, text_annot, random_key, text_annot


def retrieve_video(retrieve_text):
    tmr_text_encoder = get_tmr_model(download_tmr())
    # import ipdb; ipdb.set_trace()
    # text_encoded = tmr_text_encoder([retrieve_text])
    from gen_utils import read_json
    motion_embeds = torch.load(SOURCE_MOTS_p + '/source_motions_embeddings.pt')
    motion_keyids = np.array(read_json(SOURCE_MOTS_p + '/keyids_embeddings.json'))
    mfix_train, mfix_test = load_motionfix(MFIX_p)
    all_mots = mfix_test | mfix_train
    scores = tmr_text_encoder.compute_scores(retrieve_text, embs=motion_embeds)
    sorted_idxs = np.argsort(-scores)
    best_keyids = motion_keyids[sorted_idxs]
    # best_scores = scores[sorted_idxs]
    top_mot = best_keyids[0]
    curvid = all_mots[top_mot]['motion_a']
    text_annot = all_mots[top_mot]['annotation']
    return curvid, text_annot


with gr.Blocks(css="""
    .gradio-row { display: flex; gap: 20px; }
    .gradio-column { flex: 1; }
    .gradio-container { display: flex; flex-direction: column; gap: 10px; }
    .gradio-button-row { display: flex; gap: 10px; }
    .gradio-textbox-row { display: flex; gap: 10px; align-items: center; }
    .gradio-edit-row { gap: 10px; align-items: center; }
    .gradio-textbox-with-button { display: flex; align-items: center; }
    .gradio-textbox-with-button input { flex-grow: 1; }
""") as demo:
    gr.Markdown(WEBSITE)
    random_key_state = gr.State()
    with gr.Row(elem_classes="gradio-row"):
        with gr.Column(scale=5, elem_classes="gradio-column"):
            gr.Markdown(WEB_source)
            with gr.Row(elem_classes="gradio-button-row"):
                # iterative_button = gr.Button("Iterative")
                # retrieve_button = gr.Button("TMRetrieve")
                random_button = gr.Button("Random")
            with gr.Row(elem_classes="gradio-textbox-row"):
                with gr.Column(scale=5, elem_classes="gradio-textbox-with-button"):
                    # retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to retrieve:",
                    #                            show_label=True, label="Retrieval Text",
                    #                            value=DEFAULT_TEXT)
                    clear_button_retrieval = gr.Button("Clear", scale=0)
            with gr.Row(elem_classes="gradio-textbox-row"):
                suggested_edit_text = gr.Textbox(placeholder="Texts likely to edit the motion:",
                                                 show_label=True, label="Suggested Edit Text",
                                                 value='')
            xxx = 'https://motion-editing.s3.eu-central-1.amazonaws.com/collection_wo_walks_runs/rendered_pairs/011327_120_240-002682_120_240.mp4'
            set_to_pick = gr.Radio(['all', 'train', 'test'], value='all',
                                   label="Set to pick from",
                                   info="The motion will be picked from the whole dataset, or only from the train or test split.")
            # import ipdb; ipdb.set_trace()
            retrieved_video_output = gr.Video(label="Retrieved Motion",
                                              # value=xxx,
                                              height=360, width=480)
        with gr.Column(scale=5, elem_classes="gradio-column"):
            gr.Markdown(WEB_target)
            with gr.Row(elem_classes="gradio-edit-row"):
                clear_button_edit = gr.Button("Clear", scale=0)
                edit_button = gr.Button("Edit", scale=0)
            with gr.Row(elem_classes="gradio-textbox-row"):
                input_text = gr.Textbox(placeholder="Type the edit text you want:",
                                        show_label=False, label="Input Text",
                                        value=DEFAULT_TEXT)
            video_output = gr.Video(label="Generated Video",
                                    height=360, width=480)

    def process_and_show_video(input_text, random_key_state):
        fname = show_video(input_text, random_key_state)
        return fname

    def process_and_retrieve_video(input_text):
        fname = retrieve_video(input_text)
        return fname

    edit_button.click(process_and_show_video,
                      inputs=[input_text, random_key_state],
                      outputs=video_output)
    # retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text,
    #                       outputs=[retrieved_video_output, suggested_edit_text])
    random_button.click(random_source_motion, inputs=set_to_pick,
                        outputs=[retrieved_video_output, suggested_edit_text,
                                 random_key_state, input_text])
    print(random_key_state)
    clear_button_edit.click(clear, outputs=input_text)
    # clear_button_retrieval.click(clear, outputs=retrieve_text)
    gr.Markdown(CREDITS)

demo.launch(share=True)
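
# Headless sanity check (a sketch only; assumes this file is saved as app.py,
# the asset downloads above have completed, and '<sample_key>' is replaced with
# a real key from MFIX_DATASET_DICT):
#   from app import show_video
#   show_video("do it slower", "<sample_key>")   # writes ./output_movie.mp4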