from calendar import EPOCH
from geometry_utils import diffout2motion
import gradio as gr
import spaces
import torch
import os
from pathlib import Path
import smplx
from body_renderer import get_render
import numpy as np
from download_deps import get_smpl_models, download_models, download_model_config
from download_deps import download_tmr, download_motionfix, download_motionfix_dataset
from download_deps import download_embeddings
from website import CREDITS, WEB_source, WEB_target, WEBSITE
# import cv2
# import moderngl
# ctx = moderngl.create_context(standalone=True)
# print(ctx)
# sdk_version: 5.5.0

access_token_smpl = os.environ.get('HF_SMPL_TOKEN')
os.environ["PYOPENGL_PLATFORM"] = "egl"

zero = torch.Tensor([0]).cuda()
print(zero.device)  # <-- 'cuda:0' 🤗

DEFAULT_TEXT = "do it slower"
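# Note (assumption): `import spaces` suggests this demo may target Hugging Face
# ZeroGPU hardware. There, any function that touches CUDA has to be wrapped with
# the `spaces.GPU` decorator, roughly:
#
#     @spaces.GPU(duration=120)
#     def show_video(input_text, key_to_use):
#         ...
#
# On a Space with a dedicated GPU, the module-level `.cuda()` call above works
# as-is and no decorator is needed.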
def greet(n):
    print(zero.device)  # <-- 'cuda:0' 🤗
    try:
        number = float(n)
    except ValueError:
        return "Invalid input. Please enter a number."
    return f"Hello {zero + number} Tensor"


def clear():
    return ""
def show_video(input_text, key_to_use):
    from normalization import Normalizer
    from diffusion import create_diffusion
    from text_encoder import ClipTextEncoder
    from tmed_denoiser import TMED_denoiser
    from feature_extractor import FEAT_GET_METHODS

    normalizer = Normalizer()
    model_ckpt = download_models()
    infeats = download_model_config()
    checkpoint = torch.load(model_ckpt)

    # motion_to_edit = download_motion_from_dataset(key_to_use)
    # ds_sample = joblib.load(motion_to_edit)
    ds_sample = MFIX_DATASET_DICT[key_to_use]

    data_dict_source = {f'{feat}_source': FEAT_GET_METHODS[feat](ds_sample['motion_source'])[None].to('cuda')
                        for feat in infeats}
    data_dict_target = {f'{feat}_target': FEAT_GET_METHODS[feat](ds_sample['motion_target'])[None].to('cuda')
                        for feat in infeats}
    full_batch = data_dict_source | data_dict_target
    in_batch = normalizer.norm_and_cat(full_batch, infeats)
    source_motion_norm = in_batch['source']
    target_motion_norm = in_batch['target']
    seqlen_src = source_motion_norm.shape[0]
    seqlen_tgt = target_motion_norm.shape[0]
    # import ipdb; ipdb.set_trace()
    checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
    tmed_denoiser = TMED_denoiser().to('cuda')
    tmed_denoiser.load_state_dict(checkpoint, strict=False)
    tmed_denoiser.eval()

    text_encoder = ClipTextEncoder()
    texts_cond = [input_text]
    diffusion_process = create_diffusion(timestep_respacing=None,
                                         learn_sigma=False, sigma_small=True,
                                         diffusion_steps=300,
                                         noise_schedule='squaredcos_cap_v2',
                                         predict_xstart=True)
    bsz = 1
    no_of_texts = len(texts_cond)
    # Prepend two empty strings so the denoiser sees three copies of the batch,
    # presumably [unconditional, source motion only, source motion + edit text],
    # matching the two guidance scales (gd_text, gd_motion) passed below.
    texts_cond = [''] * no_of_texts + texts_cond
    texts_cond = [''] * no_of_texts + texts_cond
    text_emb, text_mask = text_encoder(texts_cond)
    cond_emb_motion = source_motion_norm
    cond_motion_mask = torch.ones((bsz, seqlen_src),
                                  dtype=bool, device='cuda')
    mask_target = torch.ones((bsz, seqlen_tgt),
                             dtype=bool, device='cuda')
    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
                                                text_mask.to(cond_emb_motion.device),
                                                cond_emb_motion,
                                                cond_motion_mask,
                                                mask_target,
                                                diffusion_process,
                                                init_vec=None,
                                                init_from='noise',
                                                gd_text=2.0,
                                                gd_motion=2.0,
                                                steps_num=300)
    # The first 3 feature dims are the body translation, the rest are per-joint
    # rotations in 6D representation (converted to axis-angle below).
    edited_motion = diffout2motion(diff_out.permute(1, 0, 2), normalizer).squeeze()
    gt_source = diffout2motion(source_motion_norm.permute(1, 0, 2),
                               normalizer).squeeze()
    # import ipdb; ipdb.set_trace()
    # aitrenderer = get_renderer()
    # SMPL_LAYER = SMPLLayer(model_type='smplh', ext='npz', gender='neutral')
    # edited_mot_to_render = pack_to_render(rots=edited_motion[..., 3:],
    #                                       trans=edited_motion[..., :3])
    SMPL_MODELS_PATH = str(Path(get_smpl_models()))
    body_model = smplx.SMPLHLayer(f"{SMPL_MODELS_PATH}/smplh",
                                  model_type='smplh',
                                  gender='neutral', ext='npz')
    # run_smpl_fwd_vertices(body_model, body_transl, body_orient, body_pose,
    #                       edited_mot_to_render)
    from transform3d import transform_body_pose
    # import ipdb; ipdb.set_trace()
    edited_motion_aa = transform_body_pose(edited_motion[:, 3:], '6d->aa')
    gt_source_aa = transform_body_pose(gt_source[:, 3:], '6d->aa')
    if os.path.exists('./output_movie.mp4'):
        os.remove('./output_movie.mp4')

    from transform3d import rotate_body_degrees
    gen_motion_trans = edited_motion[..., :3].detach().cpu()
    gen_motion_rots_aa = edited_motion_aa.detach().cpu()
    source_motion_trans = gt_source[..., :3].detach().cpu()
    source_motion_rots_aa = gt_source_aa.detach().cpu()
    # Rotate both motions (rotation matrices and translations) by an offset of pi
    # before rendering.
    gen_rots_rotated, gen_trans_rotated = rotate_body_degrees(
        transform_body_pose(gen_motion_rots_aa, 'aa->rot'),
        gen_motion_trans, offset=np.pi)
    src_rots_rotated, src_trans_rotated = rotate_body_degrees(
        transform_body_pose(source_motion_rots_aa, 'aa->rot'),
        source_motion_trans, offset=np.pi)
    src_rots_rotated_aa = transform_body_pose(src_rots_rotated, 'rot->aa')
    gen_rots_rotated_aa = transform_body_pose(gen_rots_rotated, 'rot->aa')
    # Render the edited motion (sky blue) and the source motion (red) in one video;
    # joint 0 (global orientation) is passed separately from the body pose.
    fname = get_render(body_model,
                       [gen_trans_rotated, src_trans_rotated],
                       [gen_rots_rotated_aa[:, 0], src_rots_rotated_aa[:, 0]],
                       [gen_rots_rotated_aa[:, 1:], src_rots_rotated_aa[:, 1:]],
                       output_path='./output_movie.mp4',
                       text='', colors=['sky blue', 'red'])
    # fname = render_motion(AIT_RENDERER, [edited_mot_to_render],
    #                       f"movie_example--{str(xx)}",
    #                       pose_repr='aa',
    #                       color=[color_map['generated']],
    #                       smpl_layer=SMPL_LAYER)
    print(fname)
    print(os.path.abspath(fname))
    return fname

MFIX_p = download_motionfix() + '/motionfix'
SOURCE_MOTS_p = download_embeddings() + '/embeddings'
MFIX_DATASET_DICT = download_motionfix_dataset()
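# Minimal local smoke test for the editing path (a sketch: it assumes the
# downloads above succeeded, a CUDA device is available, and that any key of
# MFIX_DATASET_DICT indexes a valid MotionFix sample):
#
#     if __name__ == "__main__":
#         some_key = next(iter(MFIX_DATASET_DICT))
#         print(show_video("do it slower", some_key))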


def random_source_motion(set_to_pick):
    # import ipdb;ipdb.set_trace()
    import random

    mfix_train, mfix_test = load_motionfix(MFIX_p)
    if set_to_pick == 'all':
        current_set = mfix_test | mfix_train
    elif set_to_pick == 'train':
        current_set = mfix_train
    elif set_to_pick == 'test':
        current_set = mfix_test
    random_key = random.choice(list(current_set.keys()))
    curvid = current_set[random_key]['motion_a']
    text_annot = current_set[random_key]['annotation']
    return curvid, text_annot, random_key, text_annot


def retrieve_video(retrieve_text):
    from gen_utils import read_json

    tmr_text_encoder = get_tmr_model(download_tmr())
    # import ipdb;ipdb.set_trace()
    # text_encoded = tmr_text_encoder([retrieve_text])
    motion_embeds = torch.load(SOURCE_MOTS_p + '/source_motions_embeddings.pt')
    motion_keyids = np.array(read_json(SOURCE_MOTS_p + '/keyids_embeddings.json'))
    mfix_train, mfix_test = load_motionfix(MFIX_p)
    all_mots = mfix_test | mfix_train
    # Score every pre-computed source-motion embedding against the query text and
    # keep the best match.
    scores = tmr_text_encoder.compute_scores(retrieve_text, embs=motion_embeds)
    sorted_idxs = np.argsort(-scores)
    best_keyids = motion_keyids[sorted_idxs]
    # best_scores = scores[sorted_idxs]
    top_mot = best_keyids[0]
    curvid = all_mots[top_mot]['motion_a']
    text_annot = all_mots[top_mot]['annotation']
    return curvid, text_annot
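# `retrieve_video` returns only the single best match; a top-k variant is a small
# change (sketch; it assumes `scores` is a 1-D array aligned with `motion_keyids`):
#
#     top_k_keyids = motion_keyids[np.argsort(-scores)[:5]]
#     top_k_vids = [all_mots[k]['motion_a'] for k in top_k_keyids]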


with gr.Blocks(css="""
    .gradio-row {
        display: flex;
        gap: 20px;
    }
    .gradio-column {
        flex: 1;
    }
    .gradio-container {
        display: flex;
        flex-direction: column;
        gap: 10px;
    }
    .gradio-button-row {
        display: flex;
        gap: 10px;
    }
    .gradio-textbox-row {
        display: flex;
        gap: 10px;
        align-items: center;
    }
    .gradio-edit-row {
        gap: 10px;
        align-items: center;
    }
    .gradio-textbox-with-button {
        display: flex;
        align-items: center;
    }
    .gradio-textbox-with-button input {
        flex-grow: 1;
    }
""") as demo:
    gr.Markdown(WEBSITE)
    random_key_state = gr.State()
    with gr.Row(elem_id="gradio-row"):
        with gr.Column(scale=5, elem_id="gradio-column"):
            gr.Markdown(WEB_source)
            with gr.Row(elem_id="gradio-button-row"):
                # iterative_button = gr.Button("Iterative")
                # retrieve_button = gr.Button("TMRetrieve")
                random_button = gr.Button("Random")
            with gr.Row(elem_id="gradio-textbox-row"):
                with gr.Column(scale=5, elem_id="gradio-textbox-with-button"):
                    # retrieve_text = gr.Textbox(placeholder="Type the text for the motion you want to Retrieve:",
                    #                            show_label=True, label="Retrieval Text",
                    #                            value=DEFAULT_TEXT)
                    clear_button_retrieval = gr.Button("Clear", scale=0)
            with gr.Row(elem_id="gradio-textbox-row"):
                suggested_edit_text = gr.Textbox(placeholder="Texts likely to edit the motion:",
                                                 show_label=True, label="Suggested Edit Text",
                                                 value='')
            xxx = 'https://motion-editing.s3.eu-central-1.amazonaws.com/collection_wo_walks_runs/rendered_pairs/011327_120_240-002682_120_240.mp4'
            set_to_pick = gr.Radio(['all', 'train', 'test'],
                                   value='all',
                                   label="Set to pick from",
                                   info="The source motion is picked from the whole dataset or only from the train/test split.")
            # import ipdb; ipdb.set_trace()
            retrieved_video_output = gr.Video(label="Retrieved Motion",
                                              # value=xxx,
                                              height=360, width=480)
        with gr.Column(scale=5, elem_id="gradio-column"):
            gr.Markdown(WEB_target)
            with gr.Row(elem_id="gradio-edit-row"):
                clear_button_edit = gr.Button("Clear", scale=0)
                edit_button = gr.Button("Edit", scale=0)
            with gr.Row(elem_id="gradio-textbox-row"):
                input_text = gr.Textbox(placeholder="Type the edit text you want:",
                                        show_label=False, label="Input Text",
                                        value=DEFAULT_TEXT)
            video_output = gr.Video(label="Generated Video", height=360,
                                    width=480)
    def process_and_show_video(input_text, random_key_state):
        fname = show_video(input_text, random_key_state)
        return fname

    def process_and_retrieve_video(input_text):
        fname = retrieve_video(input_text)
        return fname

    from retrieval_loader import get_tmr_model
    from dataset_utils import load_motionfix

    edit_button.click(process_and_show_video,
                      inputs=[input_text, random_key_state],
                      outputs=video_output)
    # retrieve_button.click(process_and_retrieve_video, inputs=retrieve_text,
    #                       outputs=[retrieved_video_output, suggested_edit_text])
    random_button.click(random_source_motion, inputs=set_to_pick,
                        outputs=[retrieved_video_output,
                                 suggested_edit_text,
                                 random_key_state,
                                 input_text])
    print(random_key_state)
    clear_button_edit.click(clear, outputs=input_text)
    # clear_button_retrieval.click(clear, outputs=retrieve_text)
    gr.Markdown(CREDITS)

demo.launch(share=True)
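# Note (assumption): when this app is hosted on a Hugging Face Space, `share=True`
# has no effect (Gradio warns that sharing is not supported there); it only matters
# for local runs, where it opens a temporary public gradio.live link:
#
#     demo.launch()             # hosted Space
#     demo.launch(share=True)   # local run with a shareable link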