nguyenthanhasia committed · verified
Commit cc3319c · 1 Parent(s): e93661f

Upload modeling_paraformer.py with huggingface_hub

Files changed (1):
  modeling_paraformer.py  +311  -0
modeling_paraformer.py ADDED
@@ -0,0 +1,311 @@
"""
Paraformer model implementation for Hugging Face Transformers.

This module implements the Paraformer model for legal document retrieval,
based on the paper "Attentive Deep Neural Networks for Legal Document Retrieval".
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional

from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging

try:
    from .configuration_paraformer import ParaformerConfig
except ImportError:
    from configuration_paraformer import ParaformerConfig

logger = logging.get_logger(__name__)


def sparsemax(input_tensor, dim=-1):
    """
    Sparsemax activation function (Martins & Astudillo, 2016).

    Args:
        input_tensor: Input tensor
        dim: Dimension along which to apply sparsemax

    Returns:
        Sparsemax output tensor (a sparse probability distribution along `dim`)
    """
    # Sort input in descending order
    sorted_input, _ = torch.sort(input_tensor, dim=dim, descending=True)

    # Shifted cumulative sum of the sorted values
    input_cumsum = torch.cumsum(sorted_input, dim=dim) - 1

    # Range tensor [1, 2, ..., K], shaped to broadcast along `dim`
    k = torch.arange(1, input_tensor.size(dim) + 1, dtype=input_tensor.dtype, device=input_tensor.device)
    shape = [1] * input_tensor.dim()
    shape[dim] = -1
    k = k.view(shape)

    # Support set: positions where k * z_(k) > cumsum(z)_(k) - 1
    support = k * sorted_input > input_cumsum

    # Size of the support set per row (clamped so fully masked rows don't divide by zero)
    support_size = torch.sum(support.to(input_tensor.dtype), dim=dim, keepdim=True).clamp(min=1)

    # Threshold tau = (sum of supported values - 1) / |support|.
    # Zero out unsupported entries *before* summing, so that masked -inf scores
    # do not produce NaNs (-inf * 0 is NaN in floating point).
    supported_input = torch.where(support, sorted_input, torch.zeros_like(sorted_input))
    tau = (torch.sum(supported_input, dim=dim, keepdim=True) - 1) / support_size

    # Project onto the simplex: shift by tau and clamp at zero
    return torch.clamp(input_tensor - tau, min=0)
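
# Illustrative example (comments only, not executed on import): sparsemax drives
# low-scoring entries to exactly zero, while softmax keeps every entry positive:
#
#   >>> scores = torch.tensor([[2.0, 1.0, -1.0]])
#   >>> sparsemax(scores, dim=-1)
#   tensor([[1., 0., 0.]])
#   >>> F.softmax(scores, dim=-1)
#   tensor([[0.7054, 0.2595, 0.0351]])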


class ParaformerAttention(nn.Module):
    """
    Attention mechanism for the Paraformer model.

    This implements a general attention mechanism with optional sparsemax activation.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.use_sparsemax = config.use_sparsemax

        # Attention layers
        if config.attention_type == "general":
            self.attention_weights = nn.Linear(config.hidden_size, 1, bias=False)
        else:
            raise ValueError(f"Unsupported attention type: {config.attention_type}")

    def forward(self, query_embedding, sentence_embeddings, attention_mask=None):
        """
        Apply the attention mechanism.

        Args:
            query_embedding: Query embedding tensor [batch_size, hidden_size]
            sentence_embeddings: Sentence embeddings [batch_size, num_sentences, hidden_size]
            attention_mask: Mask for padding sentences [batch_size, num_sentences]

        Returns:
            attended_output: Weighted combination of sentence embeddings
            attention_weights: Attention weights for interpretability
        """
        batch_size, num_sentences, hidden_size = sentence_embeddings.shape

        # Expand the query embedding to match the sentence embeddings
        query_expanded = query_embedding.unsqueeze(1).expand(-1, num_sentences, -1)

        # Compute attention scores using general attention: element-wise product
        # of query and sentence embeddings, followed by a linear scoring layer
        combined = query_expanded * sentence_embeddings
        attention_scores = self.attention_weights(combined).squeeze(-1)  # [batch_size, num_sentences]

        # Mask out padding sentences if a mask is provided (cast to bool so
        # integer/float masks also work with the ~ operator)
        if attention_mask is not None:
            attention_scores = attention_scores.masked_fill(~attention_mask.bool(), float('-inf'))

        # Normalize with sparsemax (sparse weights) or softmax
        if self.use_sparsemax:
            attention_weights = sparsemax(attention_scores, dim=-1)
        else:
            attention_weights = F.softmax(attention_scores, dim=-1)

        # Weighted sum of the sentence embeddings
        attended_output = torch.sum(attention_weights.unsqueeze(-1) * sentence_embeddings, dim=1)

        return attended_output, attention_weights
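
# Shape sketch (illustrative): for a batch of 2 queries over articles padded to
# 5 sentences, with hidden size H:
#   query_embedding:     [2, H]
#   sentence_embeddings: [2, 5, H]
#   attended_output:     [2, H]
#   attention_weights:   [2, 5]  (rows sum to 1; sparsemax rows may contain exact zeros)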
126
+
127
+ class ParaformerModel(PreTrainedModel):
128
+ """
129
+ Paraformer model for legal document retrieval.
130
+
131
+ This model uses a hierarchical approach with attention mechanism to encode legal documents
132
+ and queries for relevance classification.
133
+ """
134
+
135
+ config_class = ParaformerConfig
136
+ base_model_prefix = "paraformer"
137
+ supports_gradient_checkpointing = True
138
+ _no_split_modules = ["ParaformerAttention"]
139
+
140
+ def __init__(self, config):
141
+ super().__init__(config)
142
+ self.config = config
143
+
144
+ # Don't initialize SentenceTransformer in __init__ to avoid meta tensor issues
145
+ self._sentence_encoder = None
146
+
147
+ # Attention mechanism
148
+ self.attention = ParaformerAttention(config)
149
+
150
+ # Classifier
151
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
152
+ self.dropout = nn.Dropout(config.dropout_prob)
153
+
154
+ # Initialize weights
155
+ self.post_init()
156
+
157
+ @property
158
+ def sentence_encoder(self):
159
+ """Lazy loading of SentenceTransformer to avoid meta tensor issues"""
160
+ if self._sentence_encoder is None:
161
+ from sentence_transformers import SentenceTransformer
162
+ self._sentence_encoder = SentenceTransformer(self.config.base_model_name)
163
+ return self._sentence_encoder
164
+
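
    # Note: config.hidden_size is expected to match the embedding dimension of
    # config.base_model_name (for example, 384 for "all-MiniLM-L6-v2"); the
    # specific base model is a property of the config, not fixed by this module.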

    def forward(
        self,
        query_texts: Optional[List[str]] = None,
        article_texts: Optional[List[List[str]]] = None,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ):
        """
        Forward pass of the Paraformer model.

        Args:
            query_texts: List of query strings
            article_texts: List of article sentence lists (one list of sentences per article)
            labels: Optional labels for training
            return_dict: Whether to return a dictionary

        Returns:
            Model outputs including logits and an optional loss
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if query_texts is None or article_texts is None:
            raise ValueError("Both query_texts and article_texts must be provided")

        batch_size = len(query_texts)
        device = next(self.parameters()).device

        # Encode queries
        query_embeddings = self.sentence_encoder.encode(
            query_texts,
            convert_to_tensor=True,
            device=device
        ).clone()  # Clone to avoid inference tensor issues

        # Process articles one at a time (articles may have different sentence counts)
        all_attended_outputs = []
        all_attention_weights = []

        for i, article in enumerate(article_texts):
            if not article:  # Handle empty articles
                attended_output = torch.zeros(self.config.hidden_size, device=device)
                attention_weights = torch.zeros(1, device=device)
            else:
                # Encode the article's sentences
                sentence_embeddings = self.sentence_encoder.encode(
                    article,
                    convert_to_tensor=True,
                    device=device
                ).clone()  # Clone to avoid inference tensor issues

                # Add a batch dimension if needed
                if sentence_embeddings.dim() == 2:
                    sentence_embeddings = sentence_embeddings.unsqueeze(0)

                # Apply attention between the query and the article's sentences
                attended_output, attention_weights = self.attention(
                    query_embeddings[i:i + 1],
                    sentence_embeddings
                )
                attended_output = attended_output.squeeze(0)
                attention_weights = attention_weights.squeeze(0)

            all_attended_outputs.append(attended_output)
            all_attention_weights.append(attention_weights)

        # Stack per-article outputs back into a batch
        attended_outputs = torch.stack(all_attended_outputs)

        # Apply dropout and the classifier head
        attended_outputs = self.dropout(attended_outputs)
        logits = self.classifier(attended_outputs)

        # Compute the loss if labels are provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits, all_attention_weights)
            return ((loss,) + output) if loss is not None else output

        # Attention weights can have a different length per article, so return
        # them as a tuple rather than stacking them into one tensor (stacking
        # would fail for variable sentence counts)
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=tuple(all_attention_weights) if all_attention_weights else None,
        )

    def get_relevance_score(self, query: str, article: List[str]) -> float:
        """
        Get the relevance score for a single query-article pair.

        Args:
            query: Query string
            article: List of article sentences

        Returns:
            Relevance score between 0 and 1
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                query_texts=[query],
                article_texts=[article],
                return_dict=True
            )

            probabilities = torch.softmax(outputs.logits, dim=-1)
            relevance_score = probabilities[0, 1].item()  # Probability of the "relevant" class

        return relevance_score

    def predict_relevance(self, query: str, article: List[str]) -> int:
        """
        Predict binary relevance for a single query-article pair.

        Args:
            query: Query string
            article: List of article sentences

        Returns:
            Binary prediction (0 = not relevant, 1 = relevant)
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(
                query_texts=[query],
                article_texts=[article],
                return_dict=True
            )

            prediction = torch.argmax(outputs.logits, dim=-1).item()

        return prediction

    def _init_weights(self, module):
        """Initialize the weights."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
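

# Example usage (a sketch in comments; assumes config.base_model_name is a valid
# sentence-transformers model whose embedding size equals config.hidden_size,
# and that ParaformerConfig accepts these keyword arguments):
#
#   config = ParaformerConfig(base_model_name="all-MiniLM-L6-v2", hidden_size=384)
#   model = ParaformerModel(config)
#   score = model.get_relevance_score(
#       "What is the notice period for terminating a lease?",
#       ["Article 12.", "Either party may terminate the lease with 30 days notice."],
#   )
#   label = model.predict_relevance(
#       "What is the notice period for terminating a lease?",
#       ["Article 12.", "Either party may terminate the lease with 30 days notice."],
#   )  # 0 = not relevant, 1 = relevant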