from __future__ import annotations

from functools import partial
from math import ceil
import os

from accelerate.utils import DistributedDataParallelKwargs
from beartype.typing import Tuple, Callable, List
from einops import rearrange, repeat, reduce, pack
from gateloop_transformer import SimpleGateLoopLayer
from huggingface_hub import PyTorchModelHubMixin
import numpy as np
import trimesh
from tqdm import tqdm

import torch
from torch import nn, Tensor
from torch.nn import Module, ModuleList
import torch.nn.functional as F

from pytorch3d.loss import chamfer_distance
from pytorch3d.transforms import euler_angles_to_matrix

from x_transformers import Decoder
from x_transformers.x_transformers import LayerIntermediates
from x_transformers.autoregressive_wrapper import eval_decorator

from .michelangelo import ShapeConditioner as ShapeConditioner_miche
from .utils import (
    discretize,
    undiscretize,
    set_module_requires_grad_,
    default,
    exists,
    safe_cat,
    identity,
    is_tensor_empty,
)
from .utils.typing import Float, Int, Bool, typecheck
# constants

DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
    find_unused_parameters = True
)

SHAPE_CODE = {
    'CubeBevel': 0,
    'SphereSharp': 1,
    'CylinderSharp': 2,
}

BS_NAME = {
    0: 'CubeBevel',
    1: 'SphereSharp',
    2: 'CylinderSharp',
}
# FiLM block

class FiLM(Module):
    def __init__(self, dim, dim_out = None):
        super().__init__()
        dim_out = default(dim_out, dim)

        self.to_gamma = nn.Linear(dim, dim_out, bias = False)
        self.to_beta = nn.Linear(dim, dim_out)

        self.gamma_mult = nn.Parameter(torch.zeros(1,))
        self.beta_mult = nn.Parameter(torch.zeros(1,))

    def forward(self, x, cond):
        gamma, beta = self.to_gamma(cond), self.to_beta(cond)
        gamma, beta = tuple(rearrange(t, 'b d -> b 1 d') for t in (gamma, beta))

        # for initializing to identity
        gamma = (1 + self.gamma_mult * gamma.tanh())
        beta = beta.tanh() * self.beta_mult

        # classic FiLM
        return x * gamma + beta
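# Example (sketch, not part of the module): because `gamma_mult` and `beta_mult` start at
# zero, a freshly initialized FiLM block is the identity map, e.g.
#
#   film = FiLM(dim = 512)
#   x = torch.randn(2, 16, 512)      # (batch, tokens, dim)
#   cond = torch.randn(2, 512)       # (batch, dim) pooled shape condition
#   assert torch.allclose(film(x, cond), x)
#
# so the shape conditioning is blended in gradually as the multipliers are learned.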
# gateloop layers

class GateLoopBlock(Module):
    def __init__(
        self,
        dim,
        *,
        depth,
        use_heinsen = True
    ):
        super().__init__()
        self.gateloops = ModuleList([])

        for _ in range(depth):
            gateloop = SimpleGateLoopLayer(dim = dim, use_heinsen = use_heinsen)
            self.gateloops.append(gateloop)

    def forward(
        self,
        x,
        cache = None
    ):
        received_cache = exists(cache)

        if is_tensor_empty(x):
            return x, None

        if received_cache:
            prev, x = x[:, :-1], x[:, -1:]

        cache = default(cache, [])
        cache = iter(cache)

        new_caches = []
        for gateloop in self.gateloops:
            layer_cache = next(cache, None)
            out, new_cache = gateloop(x, cache = layer_cache, return_cache = True)
            new_caches.append(new_cache)
            x = x + out

        if received_cache:
            x = torch.cat((prev, x), dim = -2)

        return x, new_caches
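# Note (assumed usage, mirroring how the block is called further below): with a cache,
# only the newest token is pushed through the gateloop layers and the returned per-layer
# caches are fed back on the next step, e.g.
#
#   block = GateLoopBlock(dim = 512, depth = 2)
#   x = torch.randn(1, 10, 512)
#   out, caches = block(x)                            # prime on the full prefix
#   x = torch.cat((out, torch.randn(1, 1, 512)), dim = 1)
#   out, caches = block(x, cache = caches)            # only x[:, -1:] is recomputed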
def top_k_2(logits, frac_num_tokens=0.1, k=None):
    num_tokens = logits.shape[-1]

    k = default(k, ceil(frac_num_tokens * num_tokens))
    k = min(k, num_tokens)

    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(2, ind, val)
    return probs
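# Example (illustrative only): with k = 2, all but the two largest logits per row are pushed
# to -inf, so softmax sampling can only pick from the top-2 bins:
#
#   logits = torch.tensor([[[1.0, 3.0, 2.0, 0.5]]])   # (b=1, c=1, bins=4)
#   top_k_2(logits, k=2)                              # -> [[[-inf, 3.0, 2.0, -inf]]]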
def soft_argmax(labels):
    indices = torch.arange(labels.size(-1), dtype=labels.dtype, device=labels.device)
    soft_argmax = torch.sum(labels * indices, dim=-1)
    return soft_argmax
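# Example (illustrative only): `soft_argmax` is the probability-weighted bin index, so it
# stays differentiable unlike a hard argmax:
#
#   probs = torch.tensor([0.1, 0.2, 0.7])   # distribution over 3 bins
#   soft_argmax(probs)                      # -> 0.1*0 + 0.2*1 + 0.7*2 = 1.6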
class PrimitiveTransformerDiscrete(Module, PyTorchModelHubMixin):
    def __init__(
        self,
        *,
        num_discrete_scale = 128,
        continuous_range_scale: List[float, float] = [0, 1],
        dim_scale_embed = 64,
        num_discrete_rotation = 180,
        continuous_range_rotation: List[float, float] = [-180, 180],
        dim_rotation_embed = 64,
        num_discrete_translation = 128,
        continuous_range_translation: List[float, float] = [-1, 1],
        dim_translation_embed = 64,
        num_type = 3,
        dim_type_embed = 64,
        embed_order = 'ctrs',
        bin_smooth_blur_sigma = 0.4,
        dim: int | Tuple[int, int] = 512,
        flash_attn = True,
        attn_depth = 12,
        attn_dim_head = 64,
        attn_heads = 16,
        attn_kwargs: dict = dict(
            ff_glu = True,
            attn_num_mem_kv = 4
        ),
        max_primitive_len = 144,
        dropout = 0.,
        coarse_pre_gateloop_depth = 2,
        coarse_post_gateloop_depth = 0,
        coarse_adaptive_rmsnorm = False,
        gateloop_use_heinsen = False,
        pad_id = -1,
        num_sos_tokens = None,
        condition_on_shape = True,
        shape_cond_with_cross_attn = False,
        shape_cond_with_film = False,
        shape_cond_with_cat = False,
        shape_condition_model_type = 'michelangelo',
        shape_condition_len = 1,
        shape_condition_dim = None,
        cross_attn_num_mem_kv = 4, # needed for preventing nan when dropping out shape condition
        loss_weight: dict = dict(
            eos = 1.0,
            type = 1.0,
            scale = 1.0,
            rotation = 1.0,
            translation = 1.0,
            reconstruction = 1.0,
            scale_huber = 1.0,
            rotation_huber = 1.0,
            translation_huber = 1.0,
        ),
        bs_pc_dir=None,
    ):
        super().__init__()

        # feature embedding
        self.num_discrete_scale = num_discrete_scale
        self.continuous_range_scale = continuous_range_scale
        self.discretize_scale = partial(discretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
        self.undiscretize_scale = partial(undiscretize, num_discrete=num_discrete_scale, continuous_range=continuous_range_scale)
        self.scale_embed = nn.Embedding(num_discrete_scale, dim_scale_embed)

        self.num_discrete_rotation = num_discrete_rotation
        self.continuous_range_rotation = continuous_range_rotation
        self.discretize_rotation = partial(discretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
        self.undiscretize_rotation = partial(undiscretize, num_discrete=num_discrete_rotation, continuous_range=continuous_range_rotation)
        self.rotation_embed = nn.Embedding(num_discrete_rotation, dim_rotation_embed)

        self.num_discrete_translation = num_discrete_translation
        self.continuous_range_translation = continuous_range_translation
        self.discretize_translation = partial(discretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
        self.undiscretize_translation = partial(undiscretize, num_discrete=num_discrete_translation, continuous_range=continuous_range_translation)
        self.translation_embed = nn.Embedding(num_discrete_translation, dim_translation_embed)

        self.num_type = num_type
        self.type_embed = nn.Embedding(num_type, dim_type_embed)

        self.embed_order = embed_order
        self.bin_smooth_blur_sigma = bin_smooth_blur_sigma

        # initial dimension
        self.dim = dim
        init_dim = 3 * (dim_scale_embed + dim_rotation_embed + dim_translation_embed) + dim_type_embed

        # project into model dimension
        self.project_in = nn.Linear(init_dim, dim)

        num_sos_tokens = default(num_sos_tokens, 1 if not condition_on_shape or not shape_cond_with_film else 4)
        assert num_sos_tokens > 0
        self.num_sos_tokens = num_sos_tokens
        self.sos_token = nn.Parameter(torch.randn(num_sos_tokens, dim))

        # the transformer eos token
        self.eos_token = nn.Parameter(torch.randn(1, dim))

        self.emb_layernorm = nn.LayerNorm(dim)
        self.max_seq_len = max_primitive_len

        # shape condition
        self.condition_on_shape = condition_on_shape
        self.shape_cond_with_cross_attn = False
        self.shape_cond_with_cat = False
        self.shape_condition_model_type = ''
        self.conditioner = None
        dim_shape = None

        if condition_on_shape:
            assert shape_cond_with_cross_attn or shape_cond_with_film or shape_cond_with_cat
            self.shape_cond_with_cross_attn = shape_cond_with_cross_attn
            self.shape_cond_with_cat = shape_cond_with_cat
            self.shape_condition_model_type = shape_condition_model_type

            if 'michelangelo' in shape_condition_model_type:
                self.conditioner = ShapeConditioner_miche(dim_latent=shape_condition_dim)
                self.to_cond_dim = nn.Linear(self.conditioner.dim_model_out * 2, self.conditioner.dim_latent)
                self.to_cond_dim_head = nn.Linear(self.conditioner.dim_model_out, self.conditioner.dim_latent)
            else:
                raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')

            dim_shape = self.conditioner.dim_latent
            set_module_requires_grad_(self.conditioner, False)

        self.shape_coarse_film_cond = FiLM(dim_shape, dim) if shape_cond_with_film else identity

        self.coarse_gateloop_block = GateLoopBlock(dim, depth=coarse_pre_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_pre_gateloop_depth > 0 else None
        self.coarse_post_gateloop_block = GateLoopBlock(dim, depth=coarse_post_gateloop_depth, use_heinsen=gateloop_use_heinsen) if coarse_post_gateloop_depth > 0 else None

        self.coarse_adaptive_rmsnorm = coarse_adaptive_rmsnorm

        self.decoder = Decoder(
            dim=dim,
            depth=attn_depth,
            heads=attn_heads,
            attn_dim_head=attn_dim_head,
            attn_flash=flash_attn,
            attn_dropout=dropout,
            ff_dropout=dropout,
            use_adaptive_rmsnorm=coarse_adaptive_rmsnorm,
            dim_condition=dim_shape,
            cross_attend=self.shape_cond_with_cross_attn,
            cross_attn_dim_context=dim_shape,
            cross_attn_num_mem_kv=cross_attn_num_mem_kv,
            **attn_kwargs
        )

        # to logits
        self.to_eos_logits = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, 1)
        )

        self.to_type_logits = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(),
            nn.Linear(dim, num_type)
        )

        self.to_translation_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed, dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_translation)
        )

        self.to_rotation_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed + 3 * dim_translation_embed, dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_rotation)
        )

        self.to_scale_logits = nn.Sequential(
            nn.Linear(dim + dim_type_embed + 3 * (dim_translation_embed + dim_rotation_embed), dim),
            nn.ReLU(),
            nn.Linear(dim, 3 * num_discrete_scale)
        )

        self.pad_id = pad_id

        # template point clouds for the basic primitive types, used for chamfer-based reconstruction
        bs_pc_map = {}
        for bs_name, type_code in SHAPE_CODE.items():
            pc = trimesh.load(os.path.join(bs_pc_dir, f'SM_GR_BS_{bs_name}_001.ply'))
            bs_pc_map[type_code] = torch.from_numpy(np.asarray(pc.vertices)).float()
        bs_pc_list = []
        for i in range(len(bs_pc_map)):
            bs_pc_list.append(bs_pc_map[i])
        self.bs_pc = torch.stack(bs_pc_list, dim=0)

        self.rotation_matrix_align_coord = euler_angles_to_matrix(
            torch.Tensor([np.pi/2, 0, 0]), 'XYZ').unsqueeze(0).unsqueeze(0)
    @property
    def device(self):
        # exposed as a property, since `self.device` is read as an attribute throughout this module
        return next(self.parameters()).device
    def embed_pc(self, pc: Tensor):
        if 'michelangelo' in self.shape_condition_model_type:
            pc_head, pc_embed = self.conditioner(shape=pc)
            pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2).detach()
        else:
            raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')
        return pc_embed
    def recon_primitives(
        self,
        scale_logits: Float['b np 3 nd'],
        rotation_logits: Float['b np 3 nd'],
        translation_logits: Float['b np 3 nd'],
        type_logits: Int['b np nd'],
        primitive_mask: Bool['b np']
    ):
        recon_scale = self.undiscretize_scale(scale_logits.argmax(dim=-1))
        recon_scale = recon_scale.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))

        recon_rotation = self.undiscretize_rotation(rotation_logits.argmax(dim=-1))
        recon_rotation = recon_rotation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))

        recon_translation = self.undiscretize_translation(translation_logits.argmax(dim=-1))
        recon_translation = recon_translation.masked_fill(~primitive_mask.unsqueeze(-1), float('nan'))

        recon_type_code = type_logits.argmax(dim=-1)
        recon_type_code = recon_type_code.masked_fill(~primitive_mask, -1)

        return {
            'scale': recon_scale,
            'rotation': recon_rotation,
            'translation': recon_translation,
            'type_code': recon_type_code
        }
    def sample_primitives(
        self,
        scale: Float['b np 3 nd'],
        rotation: Float['b np 3 nd'],
        translation: Float['b np 3 nd'],
        type_code: Int['b np nd'],
        next_embed: Float['b 1 nd'],
        temperature: float = 1.,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict()
    ):
        def sample_func(logits):
            if logits.ndim == 4:
                enable_squeeze = True
                logits = logits.squeeze(1)
            else:
                enable_squeeze = False

            filtered_logits = filter_logits_fn(logits, **filter_kwargs)
            if temperature == 0.:
                sample = filtered_logits.argmax(dim=-1)
            else:
                probs = F.softmax(filtered_logits / temperature, dim=-1)
                sample = torch.zeros((probs.shape[0], probs.shape[1]), dtype=torch.long, device=probs.device)
                for b_i in range(probs.shape[0]):
                    sample[b_i] = torch.multinomial(probs[b_i], 1).squeeze()

            if enable_squeeze:
                sample = sample.unsqueeze(1)
            return sample

        # sample the attributes of the next primitive autoregressively within the step:
        # type -> translation -> rotation -> scale, feeding each sampled attribute's
        # embedding back in before predicting the next one

        next_type_logits = self.to_type_logits(next_embed)
        next_type_code = sample_func(next_type_logits)
        type_code_new, _ = pack([type_code, next_type_code], 'b *')

        type_embed = self.type_embed(next_type_code)
        next_embed_packed, _ = pack([next_embed, type_embed], 'b np *')

        next_translation_logits = rearrange(self.to_translation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_translation)
        next_discretize_translation = sample_func(next_translation_logits)
        next_translation = self.undiscretize_translation(next_discretize_translation)
        translation_new, _ = pack([translation, next_translation], 'b * nd')

        next_translation_embed = self.translation_embed(next_discretize_translation)
        next_embed_packed, _ = pack([next_embed_packed, next_translation_embed], 'b np *')

        next_rotation_logits = rearrange(self.to_rotation_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_rotation)
        next_discretize_rotation = sample_func(next_rotation_logits)
        next_rotation = self.undiscretize_rotation(next_discretize_rotation)
        rotation_new, _ = pack([rotation, next_rotation], 'b * nd')

        next_rotation_embed = self.rotation_embed(next_discretize_rotation)
        next_embed_packed, _ = pack([next_embed_packed, next_rotation_embed], 'b np *')

        next_scale_logits = rearrange(self.to_scale_logits(next_embed_packed), 'b np (c nd) -> b np c nd', nd=self.num_discrete_scale)
        next_discretize_scale = sample_func(next_scale_logits)
        next_scale = self.undiscretize_scale(next_discretize_scale)
        scale_new, _ = pack([scale, next_scale], 'b * nd')

        return (
            scale_new,
            rotation_new,
            translation_new,
            type_code_new
        )
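    # Sampling note (illustrative only): with `temperature == 0.` each attribute is taken
    # greedily via argmax over the filtered logits; otherwise bins are drawn from a softmax
    # over the top-k filtered logits, e.g. `filter_kwargs = dict(k = 50)` restricts every
    # draw to the 50 most likely bins.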
    def generate(
        self,
        batch_size: int | None = None,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict(),
        temperature: float = 1.,
        scale: Float['b np 3'] | None = None,
        rotation: Float['b np 3'] | None = None,
        translation: Float['b np 3'] | None = None,
        type_code: Int['b np'] | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        cache_kv = True,
        max_seq_len = None,
    ):
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
            assert not exists(batch_size)
            assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
            assert scale.shape[1] <= self.max_seq_len
            batch_size = scale.shape[0]

        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
            if exists(pc):
                pc_embed = self.embed_pc(pc)
            batch_size = default(batch_size, pc_embed.shape[0])

        batch_size = default(batch_size, 1)

        scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float64, device=self.device))
        type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))

        curr_length = scale.shape[1]

        cache = None
        eos_codes = None

        for i in tqdm(range(curr_length, max_seq_len)):
            can_eos = i != 0

            output = self.forward(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                pc_embed=pc_embed,
                return_loss=False,
                return_cache=cache_kv,
                append_eos=False,
                cache=cache
            )

            if cache_kv:
                next_embed, cache = output
            else:
                next_embed = output

            (
                scale,
                rotation,
                translation,
                type_code
            ) = self.sample_primitives(
                scale,
                rotation,
                translation,
                type_code,
                next_embed,
                temperature=temperature,
                filter_logits_fn=filter_logits_fn,
                filter_kwargs=filter_kwargs
            )

            next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
            next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
            eos_codes = safe_cat([eos_codes, next_eos_code], 1)
            if can_eos and eos_codes.any(dim=-1).all():
                break

        # mask out to padding anything after the first eos
        mask = eos_codes.float().cumsum(dim=-1) >= 1
        # prepend zeros for the curr_length primitives that were given as a prefix
        mask = torch.cat((torch.zeros((batch_size, curr_length), dtype=torch.bool, device=self.device), mask), dim=-1)

        type_code = type_code.masked_fill(mask, self.pad_id)
        scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
        rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
        translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)

        recon_primitives = {
            'scale': scale,
            'rotation': rotation,
            'translation': translation,
            'type_code': type_code
        }
        primitive_mask = ~eos_codes

        return recon_primitives, primitive_mask
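    # Usage sketch (hypothetical, assumes a constructed `model` such as the one sketched after
    # this class, and a point cloud in whatever layout the Michelangelo conditioner expects):
    #
    #   pc = load_point_cloud(...).to(model.device)              # hypothetical loader
    #   primitives, mask = model.generate(pc = pc, temperature = 0.5)
    #   primitives['scale'], primitives['rotation'], primitives['translation'], primitives['type_code']
    #
    # Entries after the first predicted eos are filled with `model.pad_id`, and `mask` is True
    # for the primitives that were actually generated.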
    def generate_w_recon_loss(
        self,
        batch_size: int | None = None,
        filter_logits_fn: Callable = top_k_2,
        filter_kwargs: dict = dict(),
        temperature: float = 1.,
        scale: Float['b np 3'] | None = None,
        rotation: Float['b np 3'] | None = None,
        translation: Float['b np 3'] | None = None,
        type_code: Int['b np'] | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        cache_kv = True,
        max_seq_len = None,
        single_directional = True,
    ):
        max_seq_len = default(max_seq_len, self.max_seq_len)

        if exists(scale) and exists(rotation) and exists(translation) and exists(type_code):
            assert not exists(batch_size)
            assert scale.shape[1] == rotation.shape[1] == translation.shape[1] == type_code.shape[1]
            assert scale.shape[1] <= self.max_seq_len
            batch_size = scale.shape[0]

        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
            if exists(pc):
                pc_embed = self.embed_pc(pc)
            batch_size = default(batch_size, pc_embed.shape[0])

        batch_size = default(batch_size, 1)
        assert batch_size == 1  # TODO: support any batch size

        scale = default(scale, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        rotation = default(rotation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        translation = default(translation, torch.empty((batch_size, 0, 3), dtype=torch.float32, device=self.device))
        type_code = default(type_code, torch.empty((batch_size, 0), dtype=torch.int64, device=self.device))

        curr_length = scale.shape[1]

        cache = None
        eos_codes = None
        last_recon_loss = 1

        for i in tqdm(range(curr_length, max_seq_len)):
            can_eos = i != 0

            output = self.forward(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                pc_embed=pc_embed,
                return_loss=False,
                return_cache=cache_kv,
                append_eos=False,
                cache=cache
            )

            if cache_kv:
                next_embed, cache = output
            else:
                next_embed = output

            (
                scale_new,
                rotation_new,
                translation_new,
                type_code_new
            ) = self.sample_primitives(
                scale,
                rotation,
                translation,
                type_code,
                next_embed,
                temperature=temperature,
                filter_logits_fn=filter_logits_fn,
                filter_kwargs=filter_kwargs
            )

            next_eos_logits = self.to_eos_logits(next_embed).squeeze(-1)
            next_eos_code = (F.sigmoid(next_eos_logits) > 0.5)
            eos_codes = safe_cat([eos_codes, next_eos_code], 1)
            if can_eos and eos_codes.any(dim=-1).all():
                scale, rotation, translation, type_code = (
                    scale_new, rotation_new, translation_new, type_code_new)
                break

            # accept the newly sampled primitive only if it improves the chamfer distance;
            # otherwise retry a few times and fall back to the best candidate seen
            recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc, single_directional)
            if recon_loss < last_recon_loss:
                last_recon_loss = recon_loss
                scale, rotation, translation, type_code = (
                    scale_new, rotation_new, translation_new, type_code_new)
            else:
                best_recon_loss = recon_loss
                best_primitives = dict(
                    scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)
                success_flag = False
                print(f'last_recon_loss:{last_recon_loss}, recon_loss:{recon_loss} -> to find better primitive')
                for try_i in range(5):
                    (
                        scale_new,
                        rotation_new,
                        translation_new,
                        type_code_new
                    ) = self.sample_primitives(
                        scale,
                        rotation,
                        translation,
                        type_code,
                        next_embed,
                        temperature=1.0,
                        filter_logits_fn=filter_logits_fn,
                        filter_kwargs=filter_kwargs
                    )
                    recon_loss = self.compute_chamfer_distance(scale_new, rotation_new, translation_new, type_code_new, ~eos_codes, pc)
                    print(f'[try_{try_i}] last_recon_loss:{last_recon_loss}, best_recon_loss:{best_recon_loss}, cur_recon_loss:{recon_loss}')
                    if recon_loss < last_recon_loss:
                        last_recon_loss = recon_loss
                        scale, rotation, translation, type_code = (
                            scale_new, rotation_new, translation_new, type_code_new)
                        success_flag = True
                        break
                    else:
                        if recon_loss < best_recon_loss:
                            best_recon_loss = recon_loss
                            best_primitives = dict(
                                scale=scale_new, rotation=rotation_new, translation=translation_new, type_code=type_code_new)
                if not success_flag:
                    last_recon_loss = best_recon_loss
                    scale, rotation, translation, type_code = (
                        best_primitives['scale'], best_primitives['rotation'], best_primitives['translation'], best_primitives['type_code'])
                    print(f'new_last_recon_loss:{last_recon_loss}')

        # mask out to padding anything after the first eos
        mask = eos_codes.float().cumsum(dim=-1) >= 1
        type_code = type_code.masked_fill(mask, self.pad_id)
        scale = scale.masked_fill(mask.unsqueeze(-1), self.pad_id)
        rotation = rotation.masked_fill(mask.unsqueeze(-1), self.pad_id)
        translation = translation.masked_fill(mask.unsqueeze(-1), self.pad_id)

        recon_primitives = {
            'scale': scale,
            'rotation': rotation,
            'translation': translation,
            'type_code': type_code
        }
        primitive_mask = ~eos_codes

        return recon_primitives, primitive_mask
    def encode(
        self,
        *,
        scale: Float['b np 3'],
        rotation: Float['b np 3'],
        translation: Float['b np 3'],
        type_code: Int['b np'],
        primitive_mask: Bool['b np'],
        return_primitives = False
    ):
        """
        einops:
        b - batch
        np - number of primitives
        c - coordinates (3)
        d - embed dim
        """

        # compute feature embedding
        discretize_scale = self.discretize_scale(scale)
        scale_embed = self.scale_embed(discretize_scale)
        scale_embed = rearrange(scale_embed, 'b np c d -> b np (c d)')

        discretize_rotation = self.discretize_rotation(rotation)
        rotation_embed = self.rotation_embed(discretize_rotation)
        rotation_embed = rearrange(rotation_embed, 'b np c d -> b np (c d)')

        discretize_translation = self.discretize_translation(translation)
        translation_embed = self.translation_embed(discretize_translation)
        translation_embed = rearrange(translation_embed, 'b np c d -> b np (c d)')

        type_embed = self.type_embed(type_code.masked_fill(~primitive_mask, 0))

        # combine all features and project into model dimension
        if self.embed_order == 'srtc':
            primitive_embed, _ = pack([scale_embed, rotation_embed, translation_embed, type_embed], 'b np *')
        else:
            primitive_embed, _ = pack([type_embed, translation_embed, rotation_embed, scale_embed], 'b np *')
        primitive_embed = self.project_in(primitive_embed)
        primitive_embed = primitive_embed.masked_fill(~primitive_mask.unsqueeze(-1), 0.)

        if not return_primitives:
            return primitive_embed

        primitive_embed_unpacked = {
            'scale': scale_embed,
            'rotation': rotation_embed,
            'translation': translation_embed,
            'type_code': type_embed
        }
        primitives_gt = {
            'scale': discretize_scale,
            'rotation': discretize_rotation,
            'translation': discretize_translation,
            'type_code': type_code
        }
        return primitive_embed, primitive_embed_unpacked, primitives_gt
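    # Shape note (derived from the defaults above, for orientation only): each attribute is
    # embedded per axis and flattened, so before `project_in` the packed feature width is
    # 3 * (dim_scale_embed + dim_rotation_embed + dim_translation_embed) + dim_type_embed,
    # i.e. 3 * (64 + 64 + 64) + 64 = 640 with the default hyperparameters, which `project_in`
    # then maps to the model dimension (512 by default).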
    def compute_chamfer_distance(
        self,
        scale_pred: Float['b np 3'],
        rotation_pred: Float['b np 3'],
        translation_pred: Float['b np 3'],
        type_pred: Int['b np'],
        primitive_mask: Bool['b np'],
        pc: Tensor,  # b, num_points, c
        single_directional = True
    ):
        scale_pred = scale_pred.float()
        rotation_pred = rotation_pred.float()
        translation_pred = translation_pred.float()

        # instantiate each predicted primitive from its template point cloud, then transform it
        pc_pred = apply_transformation(self.bs_pc.to(type_pred.device)[type_pred], scale_pred, torch.deg2rad(rotation_pred), translation_pred)
        pc_pred = torch.matmul(pc_pred, self.rotation_matrix_align_coord.to(type_pred.device))
        pc_pred_flat = rearrange(pc_pred, 'b np p c -> b (np p) c')
        pc_pred_sampled = random_sample_pc(pc_pred_flat, primitive_mask.sum(dim=-1, keepdim=True), n_points=self.bs_pc.shape[1])

        if single_directional:
            recon_loss, _ = chamfer_distance(pc[:, :, :3].float(), pc_pred_sampled.float(), single_directional=True)  # single directional
        else:
            recon_loss, _ = chamfer_distance(pc_pred_sampled.float(), pc[:, :, :3].float())
        return recon_loss
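    # Minimal sketch of the underlying pytorch3d call (assuming a pytorch3d version that
    # supports `single_directional`, as the method above requires):
    #
    #   from pytorch3d.loss import chamfer_distance
    #   x = torch.randn(1, 2048, 3)   # target surface points
    #   y = torch.randn(1, 2048, 3)   # points sampled from the predicted primitives
    #   loss, _ = chamfer_distance(x, y, single_directional=True)   # mean distance from x to y only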
    def forward(
        self,
        *,
        scale: Float['b np 3'],
        rotation: Float['b np 3'],
        translation: Float['b np 3'],
        type_code: Int['b np'],
        loss_reduction: str = 'mean',
        return_cache = False,
        append_eos = True,
        cache: LayerIntermediates | None = None,
        pc: Tensor | None = None,
        pc_embed: Tensor | None = None,
        **kwargs
    ):
        primitive_mask = reduce(scale != self.pad_id, 'b np 3 -> b np', 'all')

        if scale.shape[1] > 0:
            codes, primitives_embeds, primitives_gt = self.encode(
                scale=scale,
                rotation=rotation,
                translation=translation,
                type_code=type_code,
                primitive_mask=primitive_mask,
                return_primitives=True
            )
        else:
            codes = torch.empty((scale.shape[0], 0, self.dim), dtype=torch.float32, device=self.device)

        # handle shape conditions
        attn_context_kwargs = dict()
        if self.condition_on_shape:
            assert exists(pc) ^ exists(pc_embed), '`pc` or `pc_embed` must be passed in'
            if exists(pc):
                if 'michelangelo' in self.shape_condition_model_type:
                    pc_head, pc_embed = self.conditioner(shape=pc)
                    pc_embed = torch.cat([self.to_cond_dim_head(pc_head), self.to_cond_dim(pc_embed)], dim=-2)
                else:
                    raise ValueError(f'unknown shape_condition_model_type {self.shape_condition_model_type}')

            assert pc_embed.shape[0] == codes.shape[0], 'batch size of point cloud is not equal to the batch size of the primitive codes'
            pooled_pc_embed = pc_embed.mean(dim=1)  # (b, shape_condition_dim)

            if self.shape_cond_with_cross_attn:
                attn_context_kwargs = dict(
                    context=pc_embed
                )

            if self.coarse_adaptive_rmsnorm:
                attn_context_kwargs.update(
                    condition=pooled_pc_embed
                )

        batch, seq_len, _ = codes.shape  # (b, np, dim)
        device = codes.device
        assert seq_len <= self.max_seq_len, f'received codes of length {seq_len} but needs to be less than or equal to set max_seq_len {self.max_seq_len}'

        if append_eos:
            assert exists(codes)
            code_lens = primitive_mask.sum(dim=-1)
            codes = pad_tensor(codes)

            batch_arange = torch.arange(batch, device=device)
            batch_arange = rearrange(batch_arange, '... -> ... 1')
            code_lens = rearrange(code_lens, '... -> ... 1')
            codes[batch_arange, code_lens] = self.eos_token  # (b, np+1, dim)

        primitive_codes = codes  # (b, np, dim)
        primitive_codes_len = primitive_codes.shape[-2]

        (
            coarse_cache,
            coarse_gateloop_cache,
            coarse_post_gateloop_cache,
        ) = cache if exists(cache) else ((None,) * 3)

        if not exists(cache):
            sos = repeat(self.sos_token, 'n d -> b n d', b=batch)
            if self.shape_cond_with_cat:
                sos, _ = pack([pc_embed, sos], 'b * d')
            primitive_codes, packed_sos_shape = pack([sos, primitive_codes], 'b * d')  # (b, n_sos+np, dim)

        # condition primitive codes with shape if needed
        if self.condition_on_shape:
            primitive_codes = self.shape_coarse_film_cond(primitive_codes, pooled_pc_embed)

        # attention on primitive codes (coarse)
        if exists(self.coarse_gateloop_block):
            primitive_codes, coarse_gateloop_cache = self.coarse_gateloop_block(primitive_codes, cache=coarse_gateloop_cache)

        attended_primitive_codes, coarse_cache = self.decoder(  # (b, n_sos+np, dim)
            primitive_codes,
            cache=coarse_cache,
            return_hiddens=True,
            **attn_context_kwargs
        )

        if exists(self.coarse_post_gateloop_block):
            primitive_codes, coarse_post_gateloop_cache = self.coarse_post_gateloop_block(primitive_codes, cache=coarse_post_gateloop_cache)

        embed = attended_primitive_codes[:, -(primitive_codes_len + 1):]  # (b, np+1, dim)

        if not return_cache:
            return embed[:, -1:]

        next_cache = (
            coarse_cache,
            coarse_gateloop_cache,
            coarse_post_gateloop_cache
        )
        return embed[:, -1:], next_cache
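# Construction sketch (hypothetical values, for illustration only): the model needs a
# directory of template point clouds (`SM_GR_BS_<name>_001.ply` per entry in SHAPE_CODE)
# and, when `condition_on_shape` is True, at least one shape-conditioning pathway enabled:
#
#   model = PrimitiveTransformerDiscrete(
#       dim = 512,
#       shape_cond_with_cross_attn = True,
#       shape_condition_dim = 768,          # assumed latent size for the Michelangelo conditioner
#       bs_pc_dir = './basic_shapes',       # assumed path to the template .ply files
#   )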
def pad_tensor(tensor):
    if tensor.dim() == 3:
        bs, seq_len, dim = tensor.shape
        padding = torch.zeros((bs, 1, dim), dtype=tensor.dtype, device=tensor.device)
    elif tensor.dim() == 2:
        bs, seq_len = tensor.shape
        padding = torch.zeros((bs, 1), dtype=tensor.dtype, device=tensor.device)
    else:
        raise ValueError('Unsupported tensor shape: {}'.format(tensor.shape))
    return torch.cat([tensor, padding], dim=1)
def apply_transformation(pc, scale, rotation_vector, translation):
    bs, np, num_points, _ = pc.shape
    scaled_pc = pc * scale.unsqueeze(2)
    rotation_matrix = euler_angles_to_matrix(rotation_vector.view(-1, 3), 'XYZ').view(bs, np, 3, 3)  # euler tmp
    rotated_pc = torch.einsum('bnij,bnpj->bnpi', rotation_matrix, scaled_pc)
    transformed_pc = rotated_pc + translation.unsqueeze(2)
    return transformed_pc
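# Worked example (illustrative only): transforms are applied as scale, then rotation, then
# translation, per primitive. For a single point [1, 0, 0], scale [2, 1, 1], a 90-degree
# rotation about Z (pi/2 radians, 'XYZ' order) and translation [0, 0, 1]:
#
#   pc = torch.tensor([[[[1., 0., 0.]]]])                        # (bs=1, np=1, p=1, 3)
#   out = apply_transformation(
#       pc,
#       scale = torch.tensor([[[2., 1., 1.]]]),
#       rotation_vector = torch.tensor([[[0., 0., torch.pi / 2]]]),
#       translation = torch.tensor([[[0., 0., 1.]]]),
#   )
#   # out ~= [[[[0., 2., 1.]]]] : scaling maps x to 2, the Z-rotation sends it to +y, then z += 1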
def random_sample_pc(pc, max_lens, n_points=10000):
    bs = max_lens.shape[0]
    max_len = max_lens.max().item() * n_points

    random_values = torch.rand(bs, max_len, device=max_lens.device)
    mask = torch.arange(max_len).expand(bs, max_len).to(max_lens.device) < (max_lens * n_points)
    masked_random_values = random_values * mask.float()

    _, indices = torch.topk(masked_random_values, n_points, dim=1)
    return pc[torch.arange(bs).unsqueeze(1), indices]
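# Usage note (illustrative only): `max_lens` holds the per-batch count of valid primitives, so
# the mask keeps only points belonging to those primitives before drawing `n_points` of them:
#
#   pc = torch.randn(1, 4 * 10000, 3)                        # 4 primitives x 10000 template points
#   sampled = random_sample_pc(pc, torch.tensor([[2]]), n_points=10000)
#   sampled.shape                                            # -> torch.Size([1, 10000, 3]), drawn from the first 2 primitives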