Spaces:
Running
Running
| from gen_utils import extract_ckpt | |
| import hydra | |
| import os | |
| from hydra.utils import instantiate | |
| from gen_utils import read_config | |
| from model_utils import collate_x_dict | |
| import torch | |
| from tmr_model import TMR_textencoder | |
def load_model_from_cfg(cfg, ckpt_name="last", device="cuda", eval_mode=True):
    """Instantiate ``cfg.model`` and load its per-module weights from a run dir.

    Weights are read from ``<cfg.run_dir>/<ckpt_name>_weights/<module>.pt``.
    If that directory does not exist, it is first produced by
    ``extract_ckpt(run_dir, ckpt_name)``. Each ``.pt`` file is loaded into the
    model attribute whose name equals the file stem (e.g. ``motion_encoder``,
    ``text_encoder``, ``text_decoder``). Non-``.pt`` files and files with no
    matching attribute on the model are skipped.

    Args:
        cfg: Config exposing ``run_dir`` and a Hydra-instantiable ``model``.
        ckpt_name: Checkpoint name prefix (default ``"last"``).
        device: Device the returned model is moved to.
        eval_mode: If True, switch the model to eval mode before returning.

    Returns:
        The instantiated model with its weights loaded, on ``device``.
    """
    import src.prepare  # noqa -- unused name; imported only for its import-time effects

    run_dir = cfg.run_dir
    model = hydra.utils.instantiate(cfg.model)

    # Modules are loaded one by one (motion_encoder / text_encoder / text_decoder).
    pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
    if not os.path.exists(pt_path):
        extract_ckpt(run_dir, ckpt_name)

    # Sorted so the load order does not depend on filesystem listing order.
    for fname in sorted(os.listdir(pt_path)):
        module_name, ext = os.path.splitext(fname)
        if ext != ".pt":
            continue
        module = getattr(model, module_name, None)
        if module is None:
            continue
        module_path = os.path.join(pt_path, fname)
        # Map to CPU so checkpoints saved from a GPU also load on hosts without
        # that GPU; load_state_dict copies the tensors into the module's own
        # parameters, and the whole model is moved to `device` below.
        state_dict = torch.load(module_path, map_location="cpu")
        module.load_state_dict(state_dict)

    model = model.to(device)
    if eval_mode:
        model = model.eval()
    return model
| # def get_tmr_model(run_dir): | |
| # from gen_utils import read_config | |
| # cfg = read_config(run_dir+'/tmr') | |
| # import ipdb;ipdb.set_trace() | |
| # text_model = instantiate(cfg.data.text_to_token_emb, device='cuda') | |
| # model = load_model_from_cfg(cfg, 'last', eval_mode=True, device='cuda') | |
| # return text_model, model | |
def get_tmr_model(run_dir, device="cuda"):
    """Build the TMR text encoder and load its weights from ``run_dir``.

    The encoder is built with fixed hyper-parameters (256-d latent, 6 layers,
    4 heads, DistilBERT tokenizer/backbone path ``distilbert-base-uncased``).
    Weights come from ``<run_dir>/tmr/last_weights/text_encoder.pt``.

    Args:
        run_dir: Directory containing the ``tmr`` sub-directory.
        device: Device used both as ``map_location`` for the checkpoint and
            as the target of the returned model. Defaults to ``"cuda"``,
            matching the previous hard-coded behavior.

    Returns:
        The ``TMR_textencoder`` in eval mode, on ``device``.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)

    ckpt_path = os.path.join(run_dir, "tmr", "last_weights", "text_encoder.pt")
    state_dict = torch.load(ckpt_path, map_location=device)
    # Load values for the transformer only: strict=False ignores missing and
    # unexpected keys (shape mismatches on shared keys still raise).
    model.load_state_dict(state_dict, strict=False)
    return model.eval().to(device)