atnikos committed
Commit 7d87cc1 · 1 Parent(s): d8530c7
Files changed (6)
  1. app.py +14 -8
  2. gen_utils.py +3 -1
  3. geometry_utils.py +14 -14
  4. model_utils.py +4 -4
  5. text_encoder.py +1 -1
  6. tmed_denoiser.py +1 -1
app.py CHANGED
@@ -7,6 +7,7 @@ import random
 zero = torch.Tensor([0]).cuda()
 print(zero.device) # <-- 'cpu' 🤔
 # Gül Varol
+DEFAULT_TEXT = "A person is "
 
 WEBSITE = """
 <div class="embed_hidden">
@@ -61,7 +62,8 @@ def download_models():
 with gr.Blocks() as demo:
     gr.Markdown(WEBSITE)
 
-    input_text = gr.Textbox(label="Input Text")
+    input_text = gr.Textbox(placeholder="Type the edit text you want:",
+                            show_label=True, label="Input Text", value=DEFAULT_TEXT)
     # output_text = gr.Textbox(label="Output Text")
 
     with gr.Row():
@@ -76,16 +78,18 @@ with gr.Blocks() as demo:
     from tmed_denoiser import TMED_denoiser
     model_ckpt = download_models()
     checkpoint = torch.load(model_ckpt)
-    print(checkpoint.keys())
+
     checkpoint = {k.replace('denoiser.', ''): v for k, v in checkpoint.items()}
-    tmed_denoiser = TMED_denoiser().load_state_dict(checkpoint, strict=False)
+    tmed_denoiser = TMED_denoiser().to('cuda')
+    tmed_denoiser.load_state_dict(checkpoint, strict=False)
+    tmed_denoiser.eval()
     text_encoder = ClipTextEncoder()
-    texts_cond = [input_text]
+    texts_cond = [input_text.value]
+
     diffusion_process = create_diffusion(timestep_respacing=None,
                                          learn_sigma=False, sigma_small=True,
                                          diffusion_steps=300,
                                          noise_schedule='squaredcos_cap_v2',
-                                         predict_type='sample',
                                          predict_xstart=True) # noise vs sample
     # uncond_tokens = [""] * len(texts_cond)
     # if self.condition == 'text':
@@ -97,6 +101,7 @@ with gr.Blocks() as demo:
     no_of_texts = len(texts_cond)
     texts_cond = ['']*no_of_texts + texts_cond
     texts_cond = ['']*no_of_texts + texts_cond
+    print(texts_cond)
     text_emb, text_mask = text_encoder(texts_cond)
 
     cond_emb_motion = torch.zeros(1, bsz,
@@ -107,8 +112,9 @@ with gr.Blocks() as demo:
     mask_target = torch.ones((1, bsz),
                              dtype=bool, device='cuda')
     # complete noise
-    diff_out = tmed_denoiser.diffusion_reverse(text_emb,
-                                               text_mask,
+    # import ipdb;ipdb.set_trace()
+    diff_out = tmed_denoiser._diffusion_reverse(text_emb.to(cond_emb_motion.device),
+                                                text_mask.to(cond_emb_motion.device),
                                                 cond_emb_motion,
                                                 cond_motion_mask,
                                                 mask_target,
@@ -118,7 +124,7 @@ with gr.Blocks() as demo:
                                                 gd_text=4.0,
                                                 gd_motion=2.0,
                                                 steps_num=300)
-    edited_motion = diffout2motion(diff_out)
+    edited_motion = diffout2motion(diff_out, normalizer)
     clear_button.click(clear, outputs=input_text)
     random_button.click(random_number, outputs=input_text)
 
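Review note on the denoiser-loading fix: `nn.Module.load_state_dict` returns a report of missing/unexpected keys, not the module, so the old one-liner left `tmed_denoiser` bound to that report instead of the network. A minimal sketch of the pitfall and of the corrected build/move/load/eval pattern (a toy `nn.Linear` stands in for `TMED_denoiser`):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
ret = model.load_state_dict(model.state_dict(), strict=False)
# `ret` is a (missing_keys, unexpected_keys) report, not the module,
# so `m = Model().load_state_dict(...)` silently discards the model.
print(ret.missing_keys, ret.unexpected_keys)  # [] []

# The corrected pattern mirrored from the diff: build, move, load, eval.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
denoiser = nn.Linear(4, 4).to(device)
denoiser.load_state_dict(model.state_dict(), strict=False)
denoiser.eval()
```

The doubled `texts_cond = ['']*no_of_texts + texts_cond` lines appear to prepend two batches of empty (unconditional) prompts, presumably matching the two guidance scales `gd_text=4.0` and `gd_motion=2.0` passed to `_diffusion_reverse`.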
gen_utils.py CHANGED
@@ -8,4 +8,6 @@ def cast_dict_to_tensors(d, device="cpu"):
     elif isinstance(d, torch.Tensor):
         return d.to(device)
     else:
-        return d
+        return d
+
+
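For context, the change here only restores the trailing `return d` and adds end-of-file newlines; the surrounding function is a recursive caster. A minimal sketch of the whole helper, assuming the dict branch from the function's name and its visible tail:

```python
import torch

def cast_dict_to_tensors(d, device="cpu"):
    # Recursively move tensors in a (possibly nested) dict to `device`;
    # non-tensor leaves pass through unchanged.
    if isinstance(d, dict):
        return {k: cast_dict_to_tensors(v, device) for k, v in d.items()}
    elif isinstance(d, torch.Tensor):
        return d.to(device)
    else:
        return d
```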
geometry_utils.py CHANGED
@@ -1,6 +1,7 @@
 import torch
+from transform3d import transform_body_pose, apply_rot_delta, remove_z_rot, get_z_rot, change_for
 
-def diffout2motion(diffout):
+def diffout2motion(diffout, normalizer):
 
     # - "body_transl_delta_pelv_xy_wo_z"
     # - "body_transl_z"
@@ -8,19 +9,19 @@ def diffout2motion(diffout):
     # - "body_orient_xy"
     # - "body_pose"
     # - "body_joints_local_wo_z_rot"
-    feats_unnorm = self.cat_inputs(self.unnorm_inputs(
-                                   self.uncat_inputs(diffout,
-                                                     self.input_feats_dims),
-                                   self.input_feats))[0]
+    feats_unnorm = normalizer.cat_inputs(normalizer.unnorm_inputs(
+                                         normalizer.uncat_inputs(diffout,
+                                                                 normalizer.input_feats_dims),
+                                         normalizer.input_feats))[0]
     # FIRST POSE FOR GENERATION & DELTAS FOR INTEGRATION
-    if "body_joints_local_wo_z_rot" in self.input_feats:
-        idx = self.input_feats.index("body_joints_local_wo_z_rot")
-        feats_unnorm = feats_unnorm[..., :-self.input_feats_dims[idx]]
+    if "body_joints_local_wo_z_rot" in normalizer.input_feats:
+        idx = normalizer.input_feats.index("body_joints_local_wo_z_rot")
+        feats_unnorm = feats_unnorm[..., :-normalizer.input_feats_dims[idx]]
 
     first_trans = torch.zeros(*diffout.shape[:-1], 3,
-                              device=self.device)[:, [0]]
-    if 'z_orient_delta' in self.input_feats:
-        first_orient_z = torch.eye(3, device=self.device).unsqueeze(0) # Now the shape is (1, 1, 3, 3)
+                              device='cuda')[:, [0]]
+    if 'z_orient_delta' in normalizer.input_feats:
+        first_orient_z = torch.eye(3, device='cuda').unsqueeze(0) # Now the shape is (1, 1, 3, 3)
         first_orient_z = first_orient_z.repeat(feats_unnorm.shape[0], 1, 1) # Now the shape is (B, 1, 3, 3)
         first_orient_z = transform_body_pose(first_orient_z, 'rot->6d')
 
@@ -28,7 +29,6 @@ def diffout2motion(diffout):
     # integrate z orient delta --> z component of orientation
     z_orient_delta = feats_unnorm[..., 9:15]
 
-    from src.tools.transforms3d import apply_rot_delta, remove_z_rot, get_z_rot, change_for
     prev_z = first_orient_z
     full_z_angle = [first_orient_z[:, None]]
     for i in range(1, z_orient_delta.shape[1]):
@@ -52,14 +52,14 @@ def diffout2motion(diffout):
     full_global_orient = transform_body_pose(full_global_orient_rotmat,
                                              'rot->6d')
 
-    first_trans = self.cat_inputs(self.unnorm_inputs(
+    first_trans = normalizer.cat_inputs(normalizer.unnorm_inputs(
                                         [first_trans],
                                         ['body_transl'])
                                   )[0]
 
     # apply deltas
     # get velocity in global c.f. and add it to the state position
-    assert 'body_transl_delta_pelv' in self.input_feats
+    assert 'body_transl_delta_pelv' in normalizer.input_feats
     pelvis_delta = feats_unnorm[..., :3]
     trans_vel_pelv = change_for(pelvis_delta[:, 1:],
                                 full_global_orient_rotmat[:, :-1],
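Review note: the substantive change is turning `diffout2motion` from a method into a free function, so every `self.*` feature/statistics access becomes `normalizer.*`, the transform imports move to module level, and `self.device` is replaced by a hardcoded `'cuda'` (which ties the function to GPU hosts). A sketch of the interface the function now expects from `normalizer`, inferred only from the calls visible in this diff; the normalization math itself is an assumption:

```python
import torch

class FeatureNormalizer:
    """Hypothetical stand-in exposing exactly what diffout2motion touches."""
    def __init__(self, input_feats, input_feats_dims, means, stds):
        self.input_feats = input_feats            # feature names, in order
        self.input_feats_dims = input_feats_dims  # per-feature widths
        self.means, self.stds = means, stds       # per-feature statistics

    def uncat_inputs(self, x, dims):
        # split the flat last axis back into per-feature tensors
        return list(torch.split(x, dims, dim=-1))

    def unnorm_inputs(self, feats, names):
        # invert per-feature standardization (assumed to be z-scoring)
        return [f * self.stds[n] + self.means[n] for f, n in zip(feats, names)]

    def cat_inputs(self, feats):
        # re-concatenate; callers index [0], so return a tuple
        return (torch.cat(feats, dim=-1),)
```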
model_utils.py CHANGED
@@ -16,7 +16,7 @@ class TimestepEmbedderMDM(nn.Module):
             nn.Linear(self.latent_dim, time_embed_dim),
             nn.SiLU(),
             nn.Linear(time_embed_dim, time_embed_dim),
-        )
+        ).to('cuda')
 
     def forward(self, timesteps):
         return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
@@ -34,11 +34,11 @@ class PositionalEncoding(nn.Module):
         self.negative = negative
 
         if negative:
-            pe = torch.zeros(2*max_len, d_model)
+            pe = torch.zeros(2*max_len, d_model, device='cuda')
             position = torch.arange(-max_len, max_len, dtype=torch.float).unsqueeze(1)
         else:
-            pe = torch.zeros(max_len, d_model)
-            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+            pe = torch.zeros(max_len, d_model, device='cuda')
+            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
 
         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
         pe[:, 0::2] = torch.sin(position * div_term)
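This fix pins the time-embedding MLP and the positional-encoding table to CUDA at construction time. A device-agnostic alternative worth considering is registering `pe` as a buffer so `.to('cuda')` moves it with the module; a sketch under that assumption (not the repo's code; assumes even `d_model`):

```python
import numpy as np
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)  # even dims: sine
        pe[:, 1::2] = torch.cos(position * div_term)  # odd dims: cosine
        self.register_buffer('pe', pe)  # moves with .to('cuda') / .cuda()
```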
text_encoder.py CHANGED
@@ -7,7 +7,7 @@ from torch import Tensor, nn
 class ClipTextEncoder(nn.Module):
     def __init__(
         self,
-        modelpath: str='deps/clip-vit-large-patch14', # clip-vit-base-patch32
+        modelpath: str='openai/clip-vit-large-patch14', # clip-vit-base-patch32
         finetune: bool = False,
         **kwargs
     ) -> None:
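The default `modelpath` now points at the Hugging Face Hub id instead of a local `deps/` checkout that does not exist inside the Space, so `from_pretrained` can download the weights on first run. A minimal sketch of how that id resolves (the repo's encoder wraps this differently):

```python
from transformers import AutoTokenizer, CLIPTextModel

modelpath = 'openai/clip-vit-large-patch14'  # Hub id; a local dir also works
tokenizer = AutoTokenizer.from_pretrained(modelpath)
text_model = CLIPTextModel.from_pretrained(modelpath)

tokens = tokenizer(["a person walks forward"], return_tensors="pt", padding=True)
emb = text_model(**tokens).last_hidden_state  # (1, seq_len, 768) for ViT-L/14
```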
tmed_denoiser.py CHANGED
@@ -83,7 +83,7 @@ class TMED_denoiser(nn.Module):
 
         # 1. time_embedding
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-        timesteps = timestep.expand(noised_motion.shape[1]).clone()
+        timesteps = timestep.expand(noised_motion.shape[1]).clone().to(noised_motion.device)
         time_emb = self.embed_timestep(timesteps).to(dtype=noised_motion.dtype)
         # make it S first
         # time_emb = self.time_embedding(time_emb).unsqueeze(0)
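The added `.to(noised_motion.device)` guards against a CPU-built `timestep` scalar meeting a CUDA motion tensor: `expand()` keeps the source tensor's device, so the subsequent embedding lookup would otherwise mix devices. A toy reproduction of the shape and device handling (names and shapes are assumptions):

```python
import torch

noised_motion = torch.randn(16, 2, 128)  # (seq, batch, feats); 'cuda' in the app
timestep = torch.tensor(299)             # 0-dim step index, often created on CPU

# broadcast the scalar step to the batch dimension, then align devices
timesteps = timestep.expand(noised_motion.shape[1]).clone().to(noised_motion.device)
print(timesteps.shape, timesteps.device)  # torch.Size([2]) cpu (cuda:0 in the app)
```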