import torch
import torch.nn as nn
import torch.nn.functional as F

from model_utils import TimestepEmbedderMDM, PositionalEncoding


class TMED_denoiser(nn.Module):

    def __init__(self,
                 nfeats: int = 207,
                 condition: str = "text",
                 latent_dim: int = 512,
                 ff_size: int = 1024,
                 num_layers: int = 8,
                 num_heads: int = 4,
                 dropout: float = 0.1,
                 activation: str = "gelu",
                 text_encoded_dim: int = 768,
                 pred_delta_motion: bool = False,
                 use_sep: bool = True,
                 motion_condition: str = 'source',
                 **kwargs) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        self.pred_delta_motion = pred_delta_motion
        self.text_encoded_dim = text_encoded_dim
        self.condition = condition
        self.feat_comb_coeff = nn.Parameter(torch.tensor([1.0]))
        # feature <-> latent projections for source and target motion
        self.pose_proj_in_source = nn.Linear(nfeats, self.latent_dim)
        self.pose_proj_in_target = nn.Linear(nfeats, self.latent_dim)
        self.pose_proj_out = nn.Linear(self.latent_dim, nfeats)
        self.first_pose_proj = nn.Linear(self.latent_dim, nfeats)
        self.motion_condition = motion_condition

        # embedding projections
        if self.condition in ["text", "text_uncond"]:
            # timestep embedding, produced directly at latent_dim
            self.embed_timestep = TimestepEmbedderMDM(self.latent_dim)
            # project text embeddings to latent_dim when the dims differ
            if text_encoded_dim != self.latent_dim:
                self.emb_proj = nn.Linear(text_encoded_dim, self.latent_dim)
        else:
            raise TypeError(f"condition type {self.condition} not supported")

        self.use_sep = use_sep
        self.query_pos = PositionalEncoding(self.latent_dim, dropout)
        self.mem_pos = PositionalEncoding(self.latent_dim, dropout)
        if self.motion_condition == "source" and self.use_sep:
            # learnable separator token placed between source and target frames
            self.sep_token = nn.Parameter(torch.randn(1, self.latent_dim))

        # plain torch transformer encoder (sequence-first layout)
        encoder_layer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
                                                   nhead=num_heads,
                                                   dim_feedforward=ff_size,
                                                   dropout=dropout,
                                                   activation=activation)
        self.encoder = nn.TransformerEncoder(encoder_layer,
                                             num_layers=num_layers)
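    # Illustrative sketch of the token stream assembled in forward() below,
    # sequence-first, with the default latent_dim=512 (lengths T/S/F are
    # example values, not fixed by the model):
    #
    #   time embedding    1 x bs x 512
    #   text embedding    T x bs x 512
    #   source frames     S x bs x 512   (only when motion_embeds is given)
    #   sep token         1 x bs x 512   (only when use_sep=True)
    #   noised target     F x bs x 512
    #
    # The encoder consumes the concatenation of these along the sequence axis.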
    def forward(self,
                noised_motion,
                timestep,
                in_motion_mask,
                text_embeds,
                condition_mask,
                motion_embeds=None,
                lengths=None,
                **kwargs):
        # 0. dimension matching: make the motion sequence-first
        # [frames, bs, nfeats] <= [bs, frames, nfeats]
        bs = noised_motion.shape[0]
        noised_motion = noised_motion.permute(1, 0, 2)
        motion_in_mask = in_motion_mask

        # 1. time embedding
        # broadcast the timestep to the batch dimension in a way that is
        # compatible with ONNX/Core ML export
        timesteps = timestep.expand(noised_motion.shape[1]).clone()
        time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)

        # 2. condition embeddings
        if self.condition in ["text", "text_uncond"]:
            # make the text embeddings sequence-first
            text_embeds = text_embeds.permute(1, 0, 2)
            if self.text_encoded_dim != self.latent_dim:
                # [1 or 2, bs, latent_dim] <= [1 or 2, bs, text_encoded_dim]
                text_emb_latent = self.emb_proj(text_embeds)
            else:
                text_emb_latent = text_embeds
            emb_latent = torch.cat((time_emb, text_emb_latent), 0)

            if motion_embeds is not None:
                # all-zero frames mark padded / null source motion; keep them
                # zero after projection so they stay uninformative
                zeroes_mask = (motion_embeds == 0).all(dim=-1)
                if motion_embeds.shape[-1] != self.latent_dim:
                    motion_embeds_proj = self.pose_proj_in_source(motion_embeds)
                    motion_embeds_proj[zeroes_mask] = 0
                else:
                    motion_embeds_proj = motion_embeds
        else:
            raise TypeError(f"condition type {self.condition} not supported")

        # 3. assemble the encoder input sequence
        proj_noised_motion = self.pose_proj_in_target(noised_motion)
        if motion_embeds is None:
            xseq = torch.cat((emb_latent, proj_noised_motion), dim=0)
        elif self.use_sep:
            sep_token_batch = self.sep_token.repeat(bs, 1)
            xseq = torch.cat((emb_latent, motion_embeds_proj,
                              sep_token_batch[None], proj_noised_motion), dim=0)
        else:
            xseq = torch.cat((emb_latent, motion_embeds_proj,
                              proj_noised_motion), dim=0)
        xseq = self.query_pos(xseq)

        # 4. build the padding mask (True = attend; flipped for torch below)
        time_token_mask = torch.ones((bs, time_emb.shape[0]),
                                     dtype=torch.bool, device=xseq.device)
        if motion_embeds is None:
            aug_mask = torch.cat((time_token_mask,
                                  condition_mask[:, :text_emb_latent.shape[0]],
                                  motion_in_mask), 1)
        elif self.use_sep:
            sep_token_mask = torch.ones((bs, self.sep_token.shape[0]),
                                        dtype=torch.bool, device=xseq.device)
            aug_mask = torch.cat((time_token_mask,
                                  condition_mask[:, :text_emb_latent.shape[0]],
                                  condition_mask[:, text_emb_latent.shape[0]:],
                                  sep_token_mask,
                                  motion_in_mask), 1)
        else:
            aug_mask = torch.cat((time_token_mask,
                                  condition_mask[:, :text_emb_latent.shape[0]],
                                  condition_mask[:, text_emb_latent.shape[0]:],
                                  motion_in_mask), 1)

        # 5. transformer pass; torch expects True = ignore, hence the negation
        tokens = self.encoder(xseq, src_key_padding_mask=~aug_mask)

        # keep only the target-frame tokens
        denoised_motion_proj = tokens[emb_latent.shape[0]:]
        if motion_embeds is not None:
            useful_tokens = motion_embeds_proj.shape[0] + (1 if self.use_sep else 0)
            denoised_motion_proj = denoised_motion_proj[useful_tokens:]

        denoised_motion = self.pose_proj_out(denoised_motion_proj)

        if self.pred_delta_motion and motion_embeds is not None:
            # predict a delta on top of the source motion; zero-pad the source
            # along the time axis if the target is longer
            tgt_size = len(denoised_motion)
            if len(denoised_motion) > len(motion_embeds):
                pad_for_src = tgt_size - len(motion_embeds)
                motion_embeds = F.pad(motion_embeds, (0, 0, 0, 0, 0, pad_for_src))
            denoised_motion = denoised_motion + motion_embeds[:tgt_size]
        # zero out the padded frames
        denoised_motion[~motion_in_mask.T] = 0

        # 6. back to batch-first: [bs, frames, nfeats] <= [frames, bs, nfeats]
        denoised_motion = denoised_motion.permute(1, 0, 2)
        return denoised_motion
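    # Classifier-free guidance, as used below: the batch is replicated so a
    # single forward pass evaluates every guidance branch at once,
    #   text only:       [uncond | text]                      (2 replicas)
    #   text + motion:   [uncond | motion | text + motion]    (3 replicas)
    # and the 3-way combination is
    #   eps = e_u + g_m * (e_m - e_u) + g_tm * (e_tm - e_m).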
    def forward_with_guidance(self,
                              noised_motion,
                              timestep,
                              in_motion_mask,
                              text_embeds,
                              condition_mask,
                              guidance_motion,
                              guidance_text_n_motion,
                              motion_embeds=None,
                              lengths=None,
                              inpaint_dict=None,
                              max_steps=None,
                              prob_way='3way',
                              **kwargs):
        # Two cases: text-only conditioning runs two replicas (uncond | text);
        # source-motion conditioning runs three (uncond | motion | text+motion).
        if max_steps is not None:
            # linearly anneal the guidance scales over the reverse trajectory,
            # never dropping below 1 (i.e. below the conditional prediction)
            curr_ts = timestep[0].item()
            guidance_motion = max(1, guidance_motion * 2 * curr_ts / max_steps)
            guidance_text_n_motion = max(1, guidance_text_n_motion * 2 * curr_ts / max_steps)

        if motion_embeds is None:
            half = noised_motion[: len(noised_motion) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.forward(combined, timestep,
                                     in_motion_mask=in_motion_mask,
                                     text_embeds=text_embeds,
                                     condition_mask=condition_mask,
                                     motion_embeds=motion_embeds,
                                     lengths=lengths)
            uncond_eps, cond_eps_text = torch.split(model_out,
                                                    len(model_out) // 2, dim=0)
            if inpaint_dict is not None:
                # make it B x S x features
                source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
                if source_mot.shape[1] >= uncond_eps.shape[1]:
                    source_mot = source_mot[:, :uncond_eps.shape[1]]
                else:
                    # pad the source on the time dimension
                    pad = uncond_eps.shape[1] - source_mot.shape[1]
                    source_mot = F.pad(source_mot, (0, 0, 0, pad), 'constant', 0)
                mot_len = source_mot.shape[1]
                # broadcast the feature mask over all frames
                mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1, mot_len, 1)
                # NOTE: this branch keeps the model prediction where the mask is
                # True, the opposite convention from the three-replica branch below
                uncond_eps = uncond_eps * mask_src_parts + source_mot * (~mask_src_parts)
                cond_eps_text = cond_eps_text * mask_src_parts + source_mot * (~mask_src_parts)
            half_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
        else:
            third = noised_motion[: len(noised_motion) // 3]
            combined = torch.cat([third, third, third], dim=0)
            model_out = self.forward(combined, timestep,
                                     in_motion_mask=in_motion_mask,
                                     text_embeds=text_embeds,
                                     condition_mask=condition_mask,
                                     motion_embeds=motion_embeds,
                                     lengths=lengths)
            # split the replicated batch into its three guidance branches
            uncond_eps, cond_eps_motion, cond_eps_text_n_motion = torch.split(
                model_out, len(model_out) // 3, dim=0)
            if inpaint_dict is not None:
                source_mot = inpaint_dict['start_motion'].permute(1, 0, 2)
                if source_mot.shape[1] >= uncond_eps.shape[1]:
                    source_mot = source_mot[:, :uncond_eps.shape[1]]
                else:
                    # pad the source on the time dimension
                    pad = uncond_eps.shape[1] - source_mot.shape[1]
                    source_mot = F.pad(source_mot, (0, 0, 0, pad), 'constant', 0)
                mot_len = source_mot.shape[1]
                # broadcast the feature mask over all frames
                mask_src_parts = inpaint_dict['mask'].unsqueeze(1).repeat(1, mot_len, 1)
                uncond_eps = uncond_eps * (~mask_src_parts) + source_mot * mask_src_parts
                cond_eps_motion = cond_eps_motion * (~mask_src_parts) + source_mot * mask_src_parts
                cond_eps_text_n_motion = cond_eps_text_n_motion * (~mask_src_parts) + source_mot * mask_src_parts
            if prob_way == '3way':
                third_eps = uncond_eps + guidance_motion * (cond_eps_motion - uncond_eps) \
                    + guidance_text_n_motion * (cond_eps_text_n_motion - cond_eps_motion)
            elif prob_way == '2way':
                third_eps = uncond_eps + guidance_text_n_motion * (cond_eps_text_n_motion - uncond_eps)
            else:
                raise ValueError(f"prob_way {prob_way} not supported")
            eps = torch.cat([third_eps, third_eps, third_eps], dim=0)
        return eps
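    # Worked example of the linear annealing above (illustrative numbers):
    # with g = 4 and max_steps = 1000, g(t) = max(1, 4 * 2t / 1000), so
    #   t = 1000 -> 8.0,  t = 500 -> 4.0,  t = 250 -> 2.0,  t <= 125 -> 1.0;
    # guidance is strongest early in sampling and decays to 1 near t = 0.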
    def _diffusion_reverse(self, text_embeds, text_masks_from_enc,
                           motion_embeds, cond_motion_masks,
                           inp_motion_mask, diff_process,
                           init_vec=None,
                           init_from='noise',
                           gd_text=None, gd_motion=None,
                           mode='full_cond',
                           return_init_noise=False,
                           steps_num=None,
                           inpaint_dict=None,
                           use_linear=False,
                           prob_way='3way'):
        # guidance_scale_text: 7.5
        # guidance_scale_motion: 1.5
        # initialise the latents
        bsz = inp_motion_mask.shape[0]
        assert mode in ['full_cond', 'text_cond', 'mot_cond']
        assert inp_motion_mask is not None
        if init_vec is None:
            # 207 matches the default nfeats of the denoiser
            initial_latents = torch.randn(
                (bsz, inp_motion_mask.shape[1], 207),
                device=inp_motion_mask.device,
                dtype=torch.float,
            )
        else:
            initial_latents = init_vec

        max_motion_len = cond_motion_masks.shape[1]
        text_masks = text_masks_from_enc.clone()
        # the unconditional replica attends to no source-motion tokens
        nomotion_mask = torch.zeros(bsz, max_motion_len,
                                    dtype=torch.bool,
                                    device=inp_motion_mask.device)
        motion_masks = torch.cat([nomotion_mask, cond_motion_masks,
                                  cond_motion_masks], dim=0)
        aug_mask = torch.cat([text_masks, motion_masks], dim=1)

        # set up classifier-free guidance: replicate the latents once per branch
        if motion_embeds is not None:
            z = torch.cat([initial_latents, initial_latents, initial_latents], 0)
        else:
            z = torch.cat([initial_latents, initial_latents], 0)

        max_steps_diff = diff_process.num_timesteps if use_linear else None
        if motion_embeds is not None:
            model_kwargs = dict(
                in_motion_mask=torch.cat([inp_motion_mask, inp_motion_mask,
                                          inp_motion_mask], 0),
                text_embeds=text_embeds,
                condition_mask=aug_mask,
                # zero source motion for the unconditional replica
                motion_embeds=torch.cat([torch.zeros_like(motion_embeds),
                                         motion_embeds, motion_embeds], 1),
                guidance_motion=gd_motion,
                guidance_text_n_motion=gd_text,
                inpaint_dict=inpaint_dict,
                max_steps=max_steps_diff,
                prob_way=prob_way)
        else:
            model_kwargs = dict(
                in_motion_mask=torch.cat([inp_motion_mask, inp_motion_mask], 0),
                text_embeds=text_embeds,
                condition_mask=aug_mask,
                motion_embeds=None,
                guidance_motion=gd_motion,
                guidance_text_n_motion=gd_text,
                inpaint_dict=inpaint_dict,
                max_steps=max_steps_diff)

        # run the reverse diffusion loop
        samples = diff_process.p_sample_loop(self.forward_with_guidance,
                                             z.shape, z,
                                             clip_denoised=False,
                                             model_kwargs=model_kwargs,
                                             progress=True,
                                             device=initial_latents.device)
        # keep only the fully conditioned replica
        if motion_embeds is not None:
            _, _, samples = samples.chunk(3, dim=0)
        else:
            _, samples = samples.chunk(2, dim=0)
        final_diffout = samples.permute(1, 0, 2)
        if return_init_noise:
            return initial_latents, final_diffout
        return final_diffout
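

if __name__ == "__main__":
    # Minimal smoke-test sketch for the text-only path. It assumes model_utils
    # provides TimestepEmbedderMDM returning a (1, bs, latent_dim) sequence-first
    # embedding and a PositionalEncoding accepting sequence-first input, as the
    # methods above already rely on; the sizes below are illustrative only.
    model = TMED_denoiser(nfeats=207, latent_dim=512, text_encoded_dim=768)
    bs, frames, text_len = 2, 30, 8
    noised = torch.randn(bs, frames, 207)          # batch-first noised motion
    t = torch.randint(0, 1000, (bs,))              # one timestep per sample
    motion_mask = torch.ones(bs, frames, dtype=torch.bool)
    text = torch.randn(bs, text_len, 768)          # stand-in text embeddings
    cond_mask = torch.ones(bs, text_len, dtype=torch.bool)
    out = model(noised, t, motion_mask, text, cond_mask)
    print(out.shape)  # expected: (bs, frames, 207)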