""" Paraformer model implementation for Hugging Face Transformers. This module implements the Paraformer model for legal document retrieval, based on the paper "Attentive Deep Neural Networks for Legal Document Retrieval". """ import torch import torch.nn as nn import torch.nn.functional as F from typing import List, Optional, Union, Tuple from transformers.modeling_outputs import BaseModelOutput, SequenceClassifierOutput from transformers.modeling_utils import PreTrainedModel from transformers.utils import logging try: from .configuration_paraformer import ParaformerConfig except ImportError: from configuration_paraformer import ParaformerConfig logger = logging.get_logger(__name__) def sparsemax(input_tensor, dim=-1): """ Sparsemax activation function. Args: input_tensor: Input tensor dim: Dimension along which to apply sparsemax Returns: Sparsemax output tensor """ # Sort input in descending order sorted_input, _ = torch.sort(input_tensor, dim=dim, descending=True) # Compute cumulative sum input_cumsum = torch.cumsum(sorted_input, dim=dim) - 1 # Create range tensor k = torch.arange(1, input_tensor.size(dim) + 1, dtype=input_tensor.dtype, device=input_tensor.device) if dim != -1: shape = [1] * input_tensor.dim() shape[dim] = -1 k = k.view(shape) # Compute support support = k * sorted_input > input_cumsum # Find the largest k such that support[k] is True support_cumsum = torch.cumsum(support.float(), dim=dim) support_size = torch.sum(support.float(), dim=dim, keepdim=True) # Compute tau tau_cumsum = torch.cumsum(sorted_input * support.float(), dim=dim) tau = (tau_cumsum - 1) / support_size # Expand tau to match input shape if dim != -1: tau = tau.unsqueeze(dim) # Apply sparsemax output = torch.clamp(input_tensor - tau, min=0) return output class ParaformerAttention(nn.Module): """ Attention mechanism for Paraformer model. This implements a general attention mechanism with optional sparsemax activation. """ def __init__(self, config): super().__init__() self.config = config self.hidden_size = config.hidden_size self.use_sparsemax = config.use_sparsemax # Attention layers if config.attention_type == "general": self.attention_weights = nn.Linear(config.hidden_size, 1, bias=False) else: raise ValueError(f"Unsupported attention type: {config.attention_type}") def forward(self, query_embedding, sentence_embeddings, attention_mask=None): """ Apply attention mechanism. 
        Args:
            query_embedding: Query embedding tensor [batch_size, hidden_size]
            sentence_embeddings: Sentence embeddings [batch_size, num_sentences, hidden_size]
            attention_mask: Boolean mask for padding sentences [batch_size, num_sentences]

        Returns:
            attended_output: Weighted combination of sentence embeddings
            attention_weights: Attention weights for interpretability
        """
        batch_size, num_sentences, hidden_size = sentence_embeddings.shape

        # Expand the query embedding to match the sentence embeddings
        query_expanded = query_embedding.unsqueeze(1).expand(-1, num_sentences, -1)

        # General attention: element-wise interaction between query and sentences,
        # scored by a single linear layer
        combined = query_expanded * sentence_embeddings
        attention_scores = self.attention_weights(combined).squeeze(-1)  # [batch_size, num_sentences]

        # Mask out padding sentences if a mask is provided
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(~attention_mask.bool(), float('-inf'))

        # Normalize with sparsemax or softmax
        if self.use_sparsemax:
            attention_weights = sparsemax(attention_scores, dim=-1)
        else:
            attention_weights = F.softmax(attention_scores, dim=-1)

        # Weighted sum of the sentence embeddings
        attended_output = torch.sum(attention_weights.unsqueeze(-1) * sentence_embeddings, dim=1)

        return attended_output, attention_weights


class ParaformerModel(PreTrainedModel):
    """
    Paraformer model for legal document retrieval.

    This model uses a hierarchical approach with an attention mechanism to encode
    legal documents and queries for relevance classification.
    """

    config_class = ParaformerConfig
    base_model_prefix = "paraformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["ParaformerAttention"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Do not initialize the SentenceTransformer in __init__ to avoid meta tensor issues
        self._sentence_encoder = None

        # Attention mechanism
        self.attention = ParaformerAttention(config)

        # Classifier head
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.dropout = nn.Dropout(config.dropout_prob)

        # Initialize weights
        self.post_init()

    @property
    def sentence_encoder(self):
        """Lazily load the SentenceTransformer to avoid meta tensor issues."""
        if self._sentence_encoder is None:
            from sentence_transformers import SentenceTransformer
            self._sentence_encoder = SentenceTransformer(self.config.base_model_name)
        return self._sentence_encoder

    def forward(
        self,
        query_texts: Optional[List[str]] = None,
        article_texts: Optional[List[List[str]]] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """
        Forward pass of the Paraformer model.
        Args:
            query_texts: List of query strings
            article_texts: List of article sentence lists
            labels: Optional labels for training
            return_dict: Whether to return a dictionary

        Returns:
            Model outputs including logits and an optional loss
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if query_texts is None or article_texts is None:
            raise ValueError("Both query_texts and article_texts must be provided")

        batch_size = len(query_texts)
        device = next(self.parameters()).device

        # Encode queries (clone to avoid inference tensor issues)
        query_embeddings = self.sentence_encoder.encode(
            query_texts, convert_to_tensor=True, device=device
        ).clone()

        # Process each article independently, since articles may contain
        # different numbers of sentences
        all_attended_outputs = []
        all_attention_weights = []

        for i, article in enumerate(article_texts):
            if not article:
                # Handle empty articles
                attended_output = torch.zeros(self.config.hidden_size, device=device)
                attention_weights = torch.zeros(1, device=device)
            else:
                # Encode article sentences (clone to avoid inference tensor issues)
                sentence_embeddings = self.sentence_encoder.encode(
                    article, convert_to_tensor=True, device=device
                ).clone()

                # Add a batch dimension if needed
                if sentence_embeddings.dim() == 2:
                    sentence_embeddings = sentence_embeddings.unsqueeze(0)

                # Apply attention between the query and the article sentences
                attended_output, attention_weights = self.attention(
                    query_embeddings[i:i + 1], sentence_embeddings
                )
                attended_output = attended_output.squeeze(0)
                attention_weights = attention_weights.squeeze(0)

            all_attended_outputs.append(attended_output)
            all_attention_weights.append(attention_weights)

        # Stack per-article representations into a batch
        attended_outputs = torch.stack(all_attended_outputs)

        # Apply dropout and the classifier head
        attended_outputs = self.dropout(attended_outputs)
        logits = self.classifier(attended_outputs)

        # Compute loss if labels are provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits, all_attention_weights)
            return ((loss,) + output) if loss is not None else output

        # Attention weights can have a different length per article, so they are
        # returned as a tuple rather than stacked into a single tensor
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=tuple(all_attention_weights) if all_attention_weights else None,
        )

    def get_relevance_score(self, query: str, article: List[str]) -> float:
        """
        Get the relevance score for a single query-article pair.

        Args:
            query: Query string
            article: List of article sentences

        Returns:
            Relevance score between 0 and 1
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                query_texts=[query],
                article_texts=[article],
                return_dict=True,
            )
            probabilities = torch.softmax(outputs.logits, dim=-1)
            relevance_score = probabilities[0, 1].item()  # Probability of the "relevant" class

        return relevance_score

    def predict_relevance(self, query: str, article: List[str]) -> int:
        """
        Predict binary relevance for a single query-article pair.

        Args:
            query: Query string
            article: List of article sentences

        Returns:
            Binary prediction (0 = not relevant, 1 = relevant)
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                query_texts=[query],
                article_texts=[article],
                return_dict=True,
            )
            prediction = torch.argmax(outputs.logits, dim=-1).item()

        return prediction

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
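

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; runs when this file is executed
# directly). It first contrasts sparsemax with softmax, then scores one
# hypothetical query-article pair. The example texts are made up, the model
# below is randomly initialized, and a trained checkpoint would normally be
# loaded with ParaformerModel.from_pretrained(...) instead.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Sparsemax vs. softmax on the same scores: sparsemax zeroes out the
    # low-scoring entries, while softmax keeps every entry strictly positive.
    scores = torch.tensor([[2.0, 1.0, -1.0]])
    print("sparsemax:", sparsemax(scores, dim=-1))
    print("softmax:  ", F.softmax(scores, dim=-1))

    # Build a model from the default configuration; fields such as
    # base_model_name, hidden_size, use_sparsemax, and num_labels can be
    # overridden via keyword arguments (see configuration_paraformer.py),
    # as long as hidden_size matches the sentence encoder's embedding size.
    config = ParaformerConfig()
    model = ParaformerModel(config)

    query = "Is the tenant entitled to a refund of the security deposit?"
    article = [
        "The landlord shall return the security deposit within thirty days of termination.",
        "Deductions may be made only for damage beyond normal wear and tear.",
    ]
    print("relevance score:", model.get_relevance_score(query, article))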