import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, BertConfig, BertModel

from encoder import SearchActivityEncoder

class EnhancedUserTower(nn.Module):
    def __init__(
        self,
        input_dim=52,  # total width of `encoding`; must exceed search_emb_dim when with_search=True
        hidden_dim=64,
        out_dim=32,
        search_emb_dim=32,
        num_heads=4,
        num_layers=2,
        with_search=True,
    ):
        super().__init__()
        config = BertConfig(
            hidden_size=hidden_dim,
            num_hidden_layers=num_layers,
            num_attention_heads=num_heads,
            intermediate_size=hidden_dim * 4,
        )
        self.with_search = with_search
        self.search_emb_dim = search_emb_dim
        # When search fusion is on, the last search_emb_dim columns of the
        # input are the search embedding; the rest are user features.
        self.user_input_dim = input_dim - search_emb_dim if with_search else input_dim
        self.feature_proj = nn.Linear(self.user_input_dim, hidden_dim)
        self.transformer = BertModel(config)
        self.pooler = nn.Linear(hidden_dim, out_dim)
        self.search_fusion = nn.Sequential(
            nn.Linear(out_dim + search_emb_dim, out_dim),
            nn.ReLU(),
        )
| def forward(self, encoding): | |
| if self.with_search: | |
| feat = encoding[:, : self.user_input_dim] | |
| search_emb = encoding[:, self.user_input_dim :] | |
| else: | |
| feat = encoding | |
| x = self.feature_proj(feat) | |
| out = self.transformer(inputs_embeds=x.unsqueeze(1)).last_hidden_state.squeeze( | |
| 1 | |
| ) | |
| user_emb = self.pooler(out) | |
| if self.with_search: | |
| user_emb = self.search_fusion(torch.cat([user_emb, search_emb], dim=-1)) | |
| return F.normalize(user_emb, dim=1) | |
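
# Usage sketch (illustrative; the 52 = 20 + 32 split is an assumption, not
# from the original code): with with_search=True the tower expects the user
# features and the search embedding concatenated along the feature axis.
def _example_user_tower():
    tower = EnhancedUserTower(input_dim=52, search_emb_dim=32, with_search=True)
    encoding = torch.randn(8, 52)  # 20 user features + 32-dim search embedding
    return tower(encoding)  # (8, 32), rows L2-normalized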

class TextTower(nn.Module):
    def __init__(
        self,
        model_name="huawei-noah/TinyBERT_General_6L_768D",
        proj_hidden=256,
        out_dim=128,
    ):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        dim = self.encoder.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(dim, proj_hidden), nn.ReLU(), nn.Linear(proj_hidden, out_dim)
        )

    def forward(self, input_ids, attention_mask):
        # CLS token pooling
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = outputs.last_hidden_state[:, 0]
        return self.proj(cls)
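
# Usage sketch (illustrative): pair the tower with its matching tokenizer;
# running this downloads the TinyBERT checkpoint from the Hugging Face Hub.
def _example_text_tower():
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("huawei-noah/TinyBERT_General_6L_768D")
    batch = tok(["cozy 2BR near the park"], padding=True, truncation=True, return_tensors="pt")
    tower = TextTower()
    return tower(batch["input_ids"], batch["attention_mask"])  # (1, 128)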

class StructuredTower(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, x):
        return self.mlp(x)

class MultiModalAttentionFusion(nn.Module):
    """
    Fuse text, structured, and review embeddings via self-attention.
    """

    def __init__(self, dim, fusion_dim, num_heads=4):
        super().__init__()
        self.mha = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, batch_first=True
        )
        self.proj = nn.Sequential(
            nn.Linear(dim, fusion_dim), nn.ReLU(), nn.Linear(fusion_dim, fusion_dim)
        )

    def forward(self, text_emb, struct_emb, review_emb):
        # Stack the modalities as a sequence of length 3.
        x = torch.stack([text_emb, struct_emb, review_emb], dim=1)
        attn_out, _ = self.mha(x, x, x)
        pooled = attn_out.mean(dim=1)
        return self.proj(pooled)
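
# Shape sketch (illustrative): three equal-width modality embeddings become a
# length-3 "sequence" that attends over itself before mean-pooling.
def _example_fusion():
    fusion = MultiModalAttentionFusion(dim=64, fusion_dim=64)
    text_emb, struct_emb, review_emb = (torch.randn(4, 64) for _ in range(3))
    return fusion(text_emb, struct_emb, review_emb)  # (4, 64)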

class PropertyTower(nn.Module):
    def __init__(
        self,
        struct_dim,
        text_hidden_dim,
        fusion_dim,
        out_dim,
        num_views=3,
        dropout=0.2,
        noise_std=0.01,
        review_model="huawei-noah/TinyBERT_General_6L_768D",
        review_proj=128,
        review_out=64,
    ):
        super().__init__()
        # Shared encoder for description text & reviews.
        self.base_text = AutoModel.from_pretrained(review_model)
        # Separate projection heads.
        self.text_proj = nn.Sequential(
            nn.Linear(self.base_text.config.hidden_size, text_hidden_dim),
            nn.ReLU(),
            nn.Linear(text_hidden_dim, fusion_dim),
        )
        self.review_proj = nn.Sequential(
            nn.Linear(self.base_text.config.hidden_size, review_proj),
            nn.ReLU(),
            nn.Linear(review_proj, fusion_dim),
        )
        self.structured_tower = StructuredTower(struct_dim, fusion_dim)
        self.fusion = MultiModalAttentionFusion(fusion_dim, fusion_dim)
        self.dropout = nn.Dropout(dropout)
        self.noise_std = noise_std
        self.views = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Linear(fusion_dim, fusion_dim // 2),
                    nn.ReLU(),
                    nn.Linear(fusion_dim // 2, out_dim),
                )
                for _ in range(num_views)
            ]
        )
        self.contrastive = nn.Linear(out_dim, out_dim)

    def _encode_reviews(self, review_seq):
        ids, mask = review_seq["input_ids"], review_seq["attention_mask"]
        B, R, L = ids.size()
        if R == 0:
            # No reviews: return zeros with the expected output shape (B, D),
            # where D is the width of review_proj's final layer.
            out_dim = self.review_proj[-1].out_features
            return torch.zeros(B, out_dim, device=ids.device)
        # Flatten to (B * R, L) so all reviews go through the encoder at once.
        flat_ids = ids.view(-1, L)
        flat_mask = mask.view(-1, L)
        # CLS embedding from the shared text encoder.
        out = self.base_text(
            input_ids=flat_ids, attention_mask=flat_mask
        ).last_hidden_state[:, 0]
        # Project, restore the (B, R, D) layout, and mean-pool over reviews.
        proj = self.review_proj(out).view(B, R, -1)
        return proj.mean(dim=1)

    def forward(self, struct_feat, text_seq, review_seq=None):
        # Property description: CLS embedding, then projection to fusion_dim.
        t = self.base_text(
            input_ids=text_seq["input_ids"], attention_mask=text_seq["attention_mask"]
        ).last_hidden_state[:, 0]
        text_emb = self.text_proj(t)
        # Structured features.
        s = self.structured_tower(struct_feat)
        # Reviews, with a zero-vector fallback when absent.
        if review_seq is not None:
            r = self._encode_reviews(review_seq)
        else:
            r = torch.zeros_like(text_emb)
        fused = self.fusion(text_emb, s, r)
        # Multi-view heads: independent dropout + Gaussian noise per view so
        # the views differ enough for a cross-view contrastive loss.
        vs = []
        for m in self.views:
            x = self.dropout(fused)
            x = x + torch.randn_like(x) * self.noise_std
            vs.append(m(x))
        avg = torch.stack(vs).mean(dim=0)
        return F.normalize(avg, dim=1), [
            F.normalize(self.contrastive(v), dim=1) for v in vs
        ]
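
# Usage sketch (illustrative; all dimensions are assumptions): text_seq is a
# tokenized batch; review_seq, when present, stacks R reviews per property as
# (B, R, L) input_ids / attention_mask tensors.
def _example_property_tower():
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("huawei-noah/TinyBERT_General_6L_768D")
    text_seq = tok(["sunny loft with balcony"], padding=True, truncation=True, return_tensors="pt")
    tower = PropertyTower(struct_dim=16, text_hidden_dim=256, fusion_dim=64, out_dim=32)
    struct_feat = torch.randn(1, 16)
    emb, views = tower(struct_feat, text_seq)  # emb: (1, 32); views: 3 x (1, 32)
    return emb, views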

class TwoTowerRec(nn.Module):
    def __init__(
        self,
        search_args: dict,
        user_input_dim: int,
        prop_args: dict,
    ):
        super().__init__()
        self.search_enc = SearchActivityEncoder(**search_args)
        # with_search=False: the search embedding is fused by concatenating it
        # onto the user features before the tower, so user_input_dim must be
        # the total width of that concatenation.
        self.user_tower = EnhancedUserTower(input_dim=user_input_dim, with_search=False)
        self.prop_tower = PropertyTower(**prop_args)

    def forward(self, user_feat, prop_text, prop_struct, review_text=None, search=None):
        if search is not None:
            s_emb = self.search_enc(search)
        else:
            # No search activity: zero-pad so the tower's input width matches
            # what feature_proj was built for.
            pad = self.user_tower.feature_proj.in_features - user_feat.size(1)
            s_emb = torch.zeros(user_feat.size(0), pad, device=user_feat.device)
        u_emb = self.user_tower(torch.cat([user_feat, s_emb], dim=-1))
        # PropertyTower.forward expects (struct_feat, text_seq, review_seq).
        p_emb, p_views = self.prop_tower(prop_struct, prop_text, review_text)
        return u_emb, p_emb, p_views
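
# Wiring sketch (illustrative; the search_args keys depend on the external
# SearchActivityEncoder in encoder.py, so they are left elided here):
#   model = TwoTowerRec(
#       search_args={...},  # forwarded verbatim to SearchActivityEncoder
#       user_input_dim=52,  # user-feature width + search-embedding width
#       prop_args=dict(struct_dim=16, text_hidden_dim=256, fusion_dim=64, out_dim=32),
#   )
#   u_emb, p_emb, p_views = model(user_feat, prop_text, prop_struct, review_text, search)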

def cross_view_contrastive_loss(views, temp=0.07):
    """Symmetric InfoNCE averaged over all pairs of views.

    views: list of [B, D] L2-normalized view embeddings; row b is treated as
    the positive match across views, all other rows in the batch as negatives.
    """
    loss = 0.0
    count = 0
    B = views[0].size(0)
    for i in range(len(views)):
        for j in range(i + 1, len(views)):
            l = torch.matmul(views[i], views[j].T) / temp
            tgt = torch.arange(B, device=l.device)
            loss += (F.cross_entropy(l, tgt) + F.cross_entropy(l.T, tgt)) / 2
            count += 1
    return loss / count
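
# Usage sketch (illustrative): feed the normalized per-view projections that
# PropertyTower returns; every pair of views contributes a symmetric InfoNCE.
def _example_cross_view_loss():
    views = [F.normalize(torch.randn(8, 32), dim=1) for _ in range(3)]
    return cross_view_contrastive_loss(views, temp=0.07)  # scalar tensor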

def multi_pos_info_nce(
    dist_matrix: torch.Tensor,
    label_matrix: torch.Tensor,
    w_neg: float = 1.5,
    w_unl: float = 0.3,
    tau: float = 0.07,
    eps: float = 1e-8,
) -> torch.Tensor:
    """
    dist_matrix: [U, P] distances (>= 0)
    label_matrix: [U, P] labels (1 = positive, -1 = negative, 0 = unlabeled)
    tau: temperature for the softmax
    w_neg, w_unl: weights for explicit negatives and unlabeled pairs
    returns: scalar loss
    """
    sim = -dist_matrix  # smaller distance -> higher similarity
    pos_mask = label_matrix == 1
    neg_mask = label_matrix == -1
    unl_mask = label_matrix == 0
    # Per-pair weights used in the denominator.
    W = torch.zeros_like(sim)
    W[pos_mask] = 1.0
    W[neg_mask] = w_neg
    W[unl_mask] = w_unl
    exp_sim = torch.exp(sim / tau)  # [U, P]
    weighted = W * exp_sim
    num = (weighted * pos_mask).sum(dim=1)  # [U]
    denom = weighted.sum(dim=1) + eps  # [U]
    valid = pos_mask.sum(dim=1) > 0  # [U] users with at least one positive
    if valid.sum() == 0:
        return torch.tensor(0.0, device=dist_matrix.device)
    loss_per_user = -torch.log(num[valid] / denom[valid])
    return loss_per_user.mean()
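
# Worked example (illustrative): 2 users x 3 properties. Each row's positives
# compete against weighted negatives/unlabeled entries in the denominator, so
# the loss shrinks as positives become the nearest items.
def _example_multi_pos_info_nce():
    dist = torch.tensor([[0.2, 1.5, 0.9],
                         [1.1, 0.3, 0.8]])
    labels = torch.tensor([[1, -1, 0],
                           [0, 1, -1]])
    return multi_pos_info_nce(dist, labels)  # scalar tensor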

def float_to_sign(tensor: torch.Tensor, low_thresh: float, high_thresh: float):
    """Threshold scores into {-1, 0, +1}: > high_thresh -> +1, < low_thresh -> -1."""
    result = torch.zeros_like(tensor)
    result[tensor > high_thresh] = 1
    result[tensor < low_thresh] = -1
    return result
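
# Example (illustrative) with thresholds 0.4 / 0.7:
#   float_to_sign(torch.tensor([0.1, 0.5, 0.9]), 0.4, 0.7)  ->  [-1., 0., 1.]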

def pairwise_positive_ranking_loss(dist_matrix, score_matrix, margin=0.1):
    """
    dist_matrix: Tensor [U, P], pairwise distances between users and items
    score_matrix: Tensor [U, P], scores or labels (higher = more relevant)
    Hinge loss over pairs of positives: the higher-scored item should sit
    closer than the lower-scored one by at least `margin`.
    """
    loss = 0.0
    num_users = dist_matrix.size(0)
    for u in range(num_users):
        pos_idx = (score_matrix[u] > 0).nonzero(as_tuple=True)[0]
        if pos_idx.numel() < 2:
            continue
        pos_scores = score_matrix[u, pos_idx]
        pos_dists = dist_matrix[u, pos_idx]
        for i in range(len(pos_idx)):
            for j in range(i + 1, len(pos_idx)):
                s_i, s_j = pos_scores[i], pos_scores[j]
                d_i, d_j = pos_dists[i], pos_dists[j]
                if s_i == s_j:
                    continue
                # If s_i > s_j we want d_i < d_j (and vice versa), so the
                # hinge must activate when the higher-scored item is farther.
                sign = torch.sign(s_i - s_j)
                loss += torch.relu(sign * (d_i - d_j) + margin)
    return loss / num_users
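
# Worked example (illustrative): one user, two positives. The higher-scored
# item (score 1.0) is already closer (0.2 < 0.6) by more than the margin, so
# the hinge is inactive and the loss is 0.
def _example_ranking_loss():
    dist = torch.tensor([[0.2, 0.6, 0.9]])
    score = torch.tensor([[1.0, 0.5, 0.0]])  # the 0-scored item is ignored
    return pairwise_positive_ranking_loss(dist, score, margin=0.1)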

class SoftContrastiveLoss(torch.nn.Module):
    def __init__(
        self,
        margin: float = 1.0,
        temp: float = 0.3,
        lambda_ortho: float = 0.1,
        low_thresh: float = 0.4,
        high_thresh: float = 0.7,
    ):
        super().__init__()
        self.margin = margin
        self.temp = temp
        self.lambda_ortho = lambda_ortho
        self.low_thresh = low_thresh
        self.high_thresh = high_thresh

    def forward(self, u_emb, p_emb, p_views, t, user_ids, prop_ids):
        # NOTE: p_views is not consumed here; the caller can add
        # cross_view_contrastive_loss(p_views) on top.
        # Build per-(user, property) matrices from the flat batch.
        uniq_u, inv_u = torch.unique(user_ids, return_inverse=True)
        uniq_p, inv_p = torch.unique(prop_ids, return_inverse=True)
        U, P = uniq_u.size(0), uniq_p.size(0)
        # Raw interaction scores.
        M = torch.zeros(U, P, device=user_ids.device)
        M[inv_u, inv_p] = t
        # Thresholded labels in {-1, 0, +1}.
        T = torch.zeros(U, P, device=user_ids.device)
        T[inv_u, inv_p] = float_to_sign(t, self.low_thresh, self.high_thresh)
        # Signed direction of each score relative to 0.5 (currently unused).
        dir_matrix = torch.zeros(U, P, device=user_ids.device)
        dir_matrix[inv_u, inv_p] = torch.sign(t - 0.5)
        # Euclidean distance for each observed (user, property) pair.
        dist = F.pairwise_distance(u_emb, p_emb)  # [batch]
        dist_matrix = torch.zeros(U, P, device=user_ids.device)
        dist_matrix[inv_u, inv_p] = dist
        info_nce_loss = multi_pos_info_nce(dist_matrix, T, tau=self.temp)
        hinge_loss = pairwise_positive_ranking_loss(dist_matrix, M, margin=self.margin)
        # Orthogonality regularizer pushing user and property embedding
        # dimensions toward decorrelation.
        ortho_loss = torch.mean(torch.abs(torch.matmul(u_emb.T, p_emb)))
        return info_nce_loss + hinge_loss + ortho_loss * self.lambda_ortho
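
# Minimal smoke test (illustrative): exercises only the loss stack with random
# tensors, so it runs without downloading any pretrained checkpoints.
if __name__ == "__main__":
    torch.manual_seed(0)
    u_emb = F.normalize(torch.randn(6, 32), dim=1)
    p_emb = F.normalize(torch.randn(6, 32), dim=1)
    p_views = [F.normalize(torch.randn(6, 32), dim=1) for _ in range(3)]
    t = torch.rand(6)  # soft interaction scores in [0, 1]
    user_ids = torch.tensor([0, 0, 1, 1, 2, 2])
    prop_ids = torch.tensor([0, 1, 2, 3, 4, 5])
    criterion = SoftContrastiveLoss()
    loss = criterion(u_emb, p_emb, p_views, t, user_ids, prop_ids)
    print("soft contrastive loss:", float(loss))
    print("cross-view loss:", float(cross_view_contrastive_loss(p_views)))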