from gen_utils import extract_ckpt
import hydra
import os
from hydra.utils import instantiate
from gen_utils import read_config
from model_utils import collate_x_dict
import torch
from tmr_model import TMR_textencoder


def load_model_from_cfg(cfg, ckpt_name="last", device="cuda", eval_mode=True):
    """Instantiate the model described by ``cfg`` and load its per-module weights.

    Weights are read from ``<cfg.run_dir>/<ckpt_name>_weights/``. Each ``*.pt``
    file there is loaded into the model attribute whose name matches the file
    stem (e.g. ``text_encoder.pt`` -> ``model.text_encoder``). If that directory
    does not exist yet, ``extract_ckpt(run_dir, ckpt_name)`` is called first to
    create it. ``*.pt`` files with no matching attribute are skipped.

    Args:
        cfg: Config exposing ``run_dir`` and a hydra-instantiable ``model`` node.
        ckpt_name: Checkpoint prefix; the weights directory is
            ``f"{ckpt_name}_weights"``.
        device: Device the model is moved to after loading.
        eval_mode: If True, the model is switched to ``eval()`` before returning.

    Returns:
        The instantiated model with weights loaded, on ``device``.
    """
    import src.prepare  # noqa: F401  presumably imported for its side effects — confirm

    run_dir = cfg.run_dir
    model = hydra.utils.instantiate(cfg.model)

    # Modules are loaded one by one (e.g. motion_encoder / text_encoder / text_decoder).
    pt_path = os.path.join(run_dir, f"{ckpt_name}_weights")
    if not os.path.exists(pt_path):
        extract_ckpt(run_dir, ckpt_name)

    # Sorted so the load order is deterministic across filesystems.
    for fname in sorted(os.listdir(pt_path)):
        module_name, ext = os.path.splitext(fname)
        if ext != ".pt":
            continue
        module = getattr(model, module_name, None)
        if module is None:
            continue
        module_path = os.path.join(pt_path, fname)
        # Load onto CPU: load_state_dict copies the tensors into the module's
        # existing parameters anyway, and this avoids failing on hosts that lack
        # the device the checkpoint was saved from.
        state_dict = torch.load(module_path, map_location="cpu")
        module.load_state_dict(state_dict)

    model = model.to(device)
    if eval_mode:
        model = model.eval()
    return model


def get_tmr_model(run_dir, device="cuda"):
    """Build the TMR text encoder and load its weights from ``run_dir``.

    Weights are read from ``<run_dir>/tmr/last_weights/text_encoder.pt``.

    Args:
        run_dir: Run directory containing the ``tmr`` sub-directory.
        device: Device used to map the checkpoint and to host the returned model
            (defaults to ``"cuda"``, matching the previous hard-coded value).

    Returns:
        A ``TMR_textencoder`` in eval mode on ``device``.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)

    weights_path = os.path.join(run_dir, "tmr", "last_weights", "text_encoder.pt")
    state_dict = torch.load(weights_path, map_location=device)
    # strict=False: only the keys present in the checkpoint are loaded; missing or
    # unexpected keys are ignored. NOTE(review): the original had a stray
    # "unit_motion_embs" string here, possibly naming a key expected to differ
    # between checkpoint and model — confirm.
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model.to(device)