import math
import random
from collections import OrderedDict
from typing import Literal, Optional

import torch
from timm.layers import DropPath, Mlp
from timm.models.vision_transformer import LayerScale
from torch import nn
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


class Attention(nn.Module):
    """
    Multi-head cross-attention module with optional query/key normalization.

    :param dim: Total feature dimension.
    :param num_heads: Number of attention heads.
    :param qkv_bias: Whether to include bias terms in the linear projections.
    :param qk_norm: Whether to apply LayerNorm to per-head queries and keys.
    :param attn_drop: Dropout probability for attention weights.
    :param proj_drop: Dropout probability after the output projection.
    :param norm_layer: Normalization layer to use if qk_norm is True.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for multi-head cross-attention.

        :param x: Query tensor of shape (B, N1, dim).
        :param z: Key/value tensor of shape (B, N2, dim).
        :return: Attention output tensor of shape (B, N1, dim).
        """
        B, N1, C = x.shape
        B, N2, C = z.shape

        # Project queries from x and keys/values from z, then split into heads.
        q = self.q(x).reshape(B, N1, self.num_heads, self.head_dim).transpose(1, 2)
        kv = (
            self.kv(z)
            .reshape(B, N2, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        with sdpa_kernel([SDPBackend.MATH]):
            x = F.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                # Only apply attention dropout during training.
                dropout_p=self.attn_drop if self.training else 0.0,
                scale=self.scale,
            )

        x = x.transpose(1, 2).reshape(B, N1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class Block(nn.Module):
    """
    Transformer block combining cross-attention and an MLP with residual
    connections and optional LayerScale and DropPath.

    :param dim: Feature dimension.
    :param num_heads: Number of attention heads.
    :param mlp_ratio: Ratio for the hidden dimension of the MLP.
    :param qkv_bias: Whether to include bias in the QKV projections.
    :param qk_norm: Whether to normalize Q and K.
    :param proj_drop: Dropout probability after the output projection.
    :param attn_drop: Dropout probability for attention.
    :param init_values: Initial value for LayerScale (if None, LayerScale is Identity).
    :param drop_path: Drop rate for stochastic depth.
    :param act_layer: Activation layer for the MLP.
    :param norm_layer: Normalization layer.
    :param mlp_layer: MLP module class.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        self.x_norm = nn.LayerNorm(dim)
        self.z_norm = nn.LayerNorm(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for a transformer block.

        :param x: Input tensor of shape (B, N, dim).
        :param z: Conditioning tensor for cross-attention of shape (B, M, dim).
        :return: Output tensor of shape (B, N, dim) after attention and MLP.
        """
        x = x + self.drop_path1(self.ls1(self.attn(self.x_norm(x), self.z_norm(z))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x


class HistaugModel(nn.Module):
    """
    Hierarchical augmentation transformer model for embedding input features
    and augmentations.

    :param input_dim: Dimensionality of the raw input features.
    :param depth: Number of transformer blocks.
    :param num_heads: Number of attention heads.
    :param mlp_ratio: Ratio for hidden features in the MLP layers.
    :param use_transform_pos_embeddings: Whether to add sequence positional
        embeddings to the augmentation tokens.
    :param positional_encoding_type: Type of transform positional embeddings
        ('learnable' or 'sinusoidal').
    :param final_activation: Name of the activation layer for the final head.
    :param chunk_size: Number of chunks the input feature vector is split into.
    :param transforms: Dictionary containing augmentation parameter configurations.
    :param device: Device for tensors and buffers.
    :param kwargs: Additional unused keyword arguments.
    """

    def __init__(
        self,
        input_dim,
        depth,
        num_heads,
        mlp_ratio,
        use_transform_pos_embeddings=True,
        positional_encoding_type="learnable",
        final_activation="Identity",
        chunk_size=16,
        transforms=None,
        device=torch.device("cpu"),
        **kwargs,
    ):
        super().__init__()

        # Features embedding
        assert input_dim % chunk_size == 0, "input_dim must be divisible by chunk_size"
        self.input_dim = input_dim
        self.chunk_size = chunk_size
        self.transforms_parameters = transforms["parameters"]
        self.aug_param_names = sorted(self.transforms_parameters.keys())
        self.use_transform_pos_embeddings = use_transform_pos_embeddings
        self.positional_encoding_type = positional_encoding_type
        self.num_classes = 0
        self.num_features = 0

        self.embed_dim = self.input_dim // self.chunk_size
        self.chunk_pos_embeddings = self._get_sinusoidal_embeddings(
            self.chunk_size, self.embed_dim
        )
        self.register_buffer("chunk_pos_embeddings_buffer", self.chunk_pos_embeddings)

        if use_transform_pos_embeddings:
            if positional_encoding_type == "learnable":
                self.sequence_pos_embedding = nn.Embedding(
                    len(transforms["parameters"]), self.embed_dim
                )
            elif positional_encoding_type == "sinusoidal":
                sinusoidal_embeddings = self._get_sinusoidal_embeddings(
                    len(transforms["parameters"]), self.embed_dim
                )
                self.register_buffer("sequence_pos_embedding", sinusoidal_embeddings)
            else:
                raise ValueError(
                    f"Invalid positional_encoding_type: {positional_encoding_type}. "
                    "Choose 'learnable' or 'sinusoidal'."
                )
        else:
            print("Not using transform positional embeddings")

        self.transform_embeddings = self._get_transforms_embeddings(
            transforms["parameters"], self.embed_dim
        )
        self.features_embed = nn.Sequential(
            nn.Linear(input_dim, self.embed_dim), nn.LayerNorm(self.embed_dim)
        )

        self.blocks = nn.ModuleList(
            [
                Block(dim=self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(self.embed_dim)

        if hasattr(nn, final_activation):
            self.final_activation = getattr(nn, final_activation)()
        else:
            raise ValueError(f"Activation {final_activation} is not found in torch.nn")

        self.head = nn.Sequential(
            nn.Linear(input_dim, input_dim), self.final_activation
        )

    def _get_sinusoidal_embeddings(self, num_positions, embed_dim):
        """
        Create sinusoidal embeddings for positional encoding.

        :param num_positions: Number of positions to encode.
        :param embed_dim: Dimensionality of each embedding vector.
        :return: Tensor of shape (num_positions, embed_dim) containing positional encodings.
        """
        assert embed_dim % 2 == 0, "embed_dim must be even"
        position = torch.arange(0, num_positions, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2, dtype=torch.float)
            * (-math.log(10000.0) / embed_dim)
        )  # (embed_dim / 2,)
        pe = torch.zeros(num_positions, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe

    def _get_transforms_embeddings(self, transforms, embed_dim):
        """
        Create embedding modules for each augmentation parameter.

        :param transforms: Mapping of augmentation names to configuration.
        :param embed_dim: Dimensionality of the embeddings.
        :return: ModuleDict of embeddings for each augmentation type.
        """
        transform_embeddings = nn.ModuleDict()
        for aug_name in transforms:
            if aug_name in [
                "rotation",
                "h_flip",
                "v_flip",
                "gaussian_blur",
                "erosion",
                "dilation",
            ]:
                # Discrete transformations: binary codes, except rotation with four states
                transform_embeddings[aug_name] = nn.Embedding(
                    num_embeddings=2 if aug_name != "rotation" else 4,
                    embedding_dim=embed_dim,
                )
            elif aug_name in ["crop"]:
                # Discrete transformation with six possible crop codes
                transform_embeddings[aug_name] = nn.Embedding(
                    num_embeddings=6, embedding_dim=embed_dim
                )
            elif aug_name in [
                "brightness",
                "contrast",
                "saturation",
                "hed",
                "hue",
                "powerlaw",
            ]:
                # Continuous transformations: embed the scalar parameter with a small MLP
                transform_embeddings[aug_name] = nn.Sequential(
                    nn.Linear(1, embed_dim * 2),
                    nn.SiLU(),
                    nn.Linear(embed_dim * 2, embed_dim),
                )
            else:
                raise ValueError(
                    f"{aug_name} is not a valid augmentation parameter name"
                )
        return transform_embeddings

    def forward_aug_params_embed(self, aug_params):
        """
        Embed augmentation parameters and add positional embeddings if enabled.

        :param aug_params: OrderedDict mapping augmentation names to
            (value_tensor, position_tensor) pairs.
        :return: Tensor of shape (B, K, embed_dim) of embedded transform tokens.
        """
        z_transforms = []
        for aug_name, (aug_param, pos) in aug_params.items():
            if aug_name in [
                "rotation",
                "h_flip",
                "v_flip",
                "gaussian_blur",
                "erosion",
                "dilation",
                "crop",
            ]:
                z_transform = self.transform_embeddings[aug_name](aug_param)
            elif aug_name in [
                "brightness",
                "contrast",
                "saturation",
                "hue",
                "powerlaw",
                "hed",
            ]:
                z_transform = self.transform_embeddings[aug_name](
                    aug_param[..., None].float()
                )
            else:
                raise ValueError(
                    f"{aug_name} is not a valid augmentation parameter name"
                )

            # Add positional embedding if specified
            if self.use_transform_pos_embeddings:
                if self.positional_encoding_type == "learnable":
                    pos_index = torch.as_tensor(pos, device=aug_param.device)
                    pos_embedding = self.sequence_pos_embedding(pos_index)
                elif self.positional_encoding_type == "sinusoidal":
                    pos_embedding = self.sequence_pos_embedding[pos].to(
                        aug_param.device
                    )
                else:
                    raise ValueError(
                        f"Invalid positional_encoding_type: {self.positional_encoding_type}"
                    )
                z_transform_with_pos = z_transform + pos_embedding
                z_transforms.append(z_transform_with_pos)
            else:
                z_transforms.append(z_transform)

        # Stack the list of embeddings along a new dimension
        z_transforms = torch.stack(z_transforms, dim=1)
        return z_transforms

    def sample_aug_params(
        self,
        batch_size: int,
        device: torch.device = torch.device("cuda"),
        mode: Literal["instance_wise", "wsi_wise"] = "wsi_wise",
    ):
        """
        Sample random augmentation parameters and their relative positions.

        If a transform from the supported list is missing in self.aug_param_names,
        include it with zero values and append it at unique tail positions.
        """
        if mode not in ("instance_wise", "wsi_wise"):
            raise ValueError('mode must be "instance_wise" or "wsi_wise"')

        supported_aug_names = [
            "rotation",
            "crop",
            "h_flip",
            "v_flip",
            "gaussian_blur",
            "erosion",
            "dilation",
            "brightness",
            "contrast",
            "saturation",
            "hue",
            "powerlaw",
            "hed",
        ]
        canonical_names = sorted(self.transforms_parameters.keys())
        num_transforms = len(canonical_names)

        # Determine which supported transforms are missing from the current
        # configuration. Missing transforms are still included in
        # augmentation_parameters so that the downstream model sees a consistent
        # set: they get zero values (identity / no-op) and unique tail positions
        # after all configured transforms.
        missing_names = [n for n in supported_aug_names if n not in canonical_names]

        # Build permutation/positions for configured transforms only
        if mode == "instance_wise":
            permutation_matrix = (
                torch.stack(
                    [
                        torch.randperm(num_transforms, device=device)
                        for _ in range(batch_size)
                    ],
                    dim=0,
                )
                if num_transforms > 0
                else torch.empty((batch_size, 0), dtype=torch.long, device=device)
            )
        else:  # wsi_wise
            if num_transforms > 0:
                single_permutation = torch.randperm(num_transforms, device=device)
                permutation_matrix = single_permutation.unsqueeze(0).repeat(
                    batch_size, 1
                )
            else:
                permutation_matrix = torch.empty(
                    (batch_size, 0), dtype=torch.long, device=device
                )
        positions_matrix = (
            torch.argsort(permutation_matrix, dim=1)
            if num_transforms > 0
            else torch.empty((batch_size, 0), dtype=torch.long, device=device)
        )

        augmentation_parameters = OrderedDict()

        # --- sample configured transforms ---
        for transform_index, name in enumerate(canonical_names):
            config = self.transforms_parameters[name]

            if name == "rotation":
                probability = float(config)
                if mode == "instance_wise":
                    apply_mask = torch.rand(batch_size, device=device) < probability
                    random_angles = torch.randint(0, 4, (batch_size,), device=device)
                    random_angles[~apply_mask] = 0
                    value_tensor = random_angles
                else:
                    apply = random.random() < probability
                    angle = random.randint(1, 3) if apply else 0
                    value_tensor = torch.full(
                        (batch_size,), angle, dtype=torch.int64, device=device
                    )

            elif name == "crop":
                probability = float(config)
                if mode == "instance_wise":
                    apply_mask = torch.rand(batch_size, device=device) < probability
                    random_crops = torch.randint(0, 5, (batch_size,), device=device)
                    random_crops[~apply_mask] = 0
                    value_tensor = random_crops
                else:
                    apply = random.random() < probability
                    crop_code = random.randint(1, 4) if apply else 0
                    value_tensor = torch.full(
                        (batch_size,), crop_code, dtype=torch.int64, device=device
                    )

            elif name in ("h_flip", "v_flip", "gaussian_blur", "erosion", "dilation"):
                probability = float(config)
                if mode == "instance_wise":
                    value_tensor = (
                        torch.rand(batch_size, device=device) < probability
                    ).int()
                else:
                    bit = int(random.random() < probability)
                    value_tensor = torch.full(
                        (batch_size,), bit, dtype=torch.int32, device=device
                    )

            elif name in (
                "brightness",
                "contrast",
                "saturation",
                "hue",
                "powerlaw",
                "hed",
            ):
                lower_bound, upper_bound = map(float, config)
                if mode == "instance_wise":
                    value_tensor = torch.empty(batch_size, device=device).uniform_(
                        lower_bound, upper_bound
                    )
                else:
                    scalar_value = random.uniform(lower_bound, upper_bound)
                    value_tensor = torch.full(
                        (batch_size,), scalar_value, dtype=torch.float32, device=device
                    )

            else:
                raise ValueError(f"'{name}' is not a recognised augmentation name")

            position_tensor = positions_matrix[:, transform_index]
            augmentation_parameters[name] = (value_tensor, position_tensor)

        # --- append missing transforms as identity values with tail positions ---
        for i, name in enumerate(missing_names):
            if name in ("rotation", "crop"):
                zeros = torch.zeros(batch_size, dtype=torch.int64, device=device)
            elif name in ("h_flip", "v_flip", "gaussian_blur", "erosion", "dilation"):
                zeros = torch.zeros(batch_size, dtype=torch.int32, device=device)
            else:  # continuous
                zeros = torch.zeros(batch_size, dtype=torch.float32, device=device)
            tail_pos = num_transforms + i  # unique: K, K+1, ..., K+M-1
            pos = torch.full((batch_size,), tail_pos, dtype=torch.long, device=device)
            augmentation_parameters[name] = (zeros, pos)

        return augmentation_parameters
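
    # Sampling modes in brief (example values only):
    #   "wsi_wise":      one parameter set and one transform order are drawn and
    #                    broadcast to the whole batch,
    #                    e.g. rotation -> tensor([2, 2, 2, 2]) for batch_size=4.
    #   "instance_wise": parameters and the transform order are drawn per sample,
    #                    e.g. rotation -> tensor([0, 3, 1, 0]).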

    def forward(self, x, aug_params, **kwargs):
        """
        Forward pass: embed features, apply transformer blocks, and produce output.

        :param x: Input tensor of shape (B, input_dim).
        :param aug_params: Augmentation parameters from sample_aug_params.
        :return: Output tensor of shape (B, input_dim).
        """
        # Split the feature vector into chunk_size tokens of size embed_dim and
        # add the fixed sinusoidal chunk positional embeddings.
        x = x.view(x.shape[0], self.chunk_size, self.embed_dim)
        pos_embeddings = self.chunk_pos_embeddings_buffer.unsqueeze(0)
        x = x + pos_embeddings

        # Embed the augmentation parameters as conditioning tokens.
        z = self.forward_aug_params_embed(aug_params)

        for block in self.blocks:
            x = block(x, z)
        x = self.norm(x)

        # Flatten the chunk tokens back into a single feature vector and project.
        x = x.view(x.shape[0], 1, -1)
        x = self.head(x)
        x = x[:, 0, :]
        return x
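

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The transform probabilities and ranges
    # below are illustrative assumptions, not values shipped with the project;
    # listing all supported transforms keeps the learnable positional-embedding
    # table large enough for every sampled position.
    example_transforms = {
        "parameters": {
            "rotation": 0.5,
            "crop": 0.5,
            "h_flip": 0.5,
            "v_flip": 0.5,
            "gaussian_blur": 0.3,
            "erosion": 0.2,
            "dilation": 0.2,
            "brightness": [0.9, 1.1],
            "contrast": [0.9, 1.1],
            "saturation": [0.9, 1.1],
            "hue": [-0.05, 0.05],
            "powerlaw": [0.9, 1.1],
            "hed": [-0.05, 0.05],
        }
    }
    model = HistaugModel(
        input_dim=1024,
        depth=2,
        num_heads=4,
        mlp_ratio=4.0,
        chunk_size=16,
        transforms=example_transforms,
    )
    features = torch.randn(8, 1024)
    aug_params = model.sample_aug_params(
        batch_size=8, device=torch.device("cpu"), mode="wsi_wise"
    )
    output = model(features, aug_params)
    print(output.shape)  # expected: torch.Size([8, 1024])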