import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoModel class DeBERTaLSTMClassifier(nn.Module): def __init__(self, hidden_dim=128, num_labels=2): super().__init__() self.deberta = AutoModel.from_pretrained("microsoft/deberta-base") # Đóng băng DeBERTa for param in self.deberta.parameters(): param.requires_grad = False self.lstm = nn.LSTM( input_size=self.deberta.config.hidden_size, hidden_size=hidden_dim, batch_first=True, bidirectional=True ) # Lớp Attention: chuyển đổi hidden state thành điểm số quan trọng (score) self.attention = nn.Linear(hidden_dim * 2, 1) self.fc = nn.Linear(hidden_dim * 2, num_labels) def forward(self, input_ids, attention_mask, return_attention=False): # 1. DeBERTa with torch.no_grad(): outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask, output_attentions=True) # 2. LSTM lstm_out, _ = self.lstm(outputs.last_hidden_state) # [batch, seq_len, hidden*2] # 3. Tính Attention (Luôn luôn thực hiện) # Tính score chưa qua softmax attn_scores = self.attention(lstm_out).squeeze(-1) # [batch, seq_len] # Masking chuẩn: Gán giá trị rất nhỏ (-inf) cho các vị trí padding trước khi Softmax # Để đảm bảo padding có attention weight = 0 tuyệt đối mask = attention_mask.float() attn_scores = attn_scores.masked_fill(mask == 0, -1e9) # Softmax để ra weights attn_weights = F.softmax(attn_scores, dim=-1) # [batch, seq_len] # Tính Context Vector (Weighted Sum) # [batch, seq_len, 1] * [batch, seq_len, hidden*2] -> sum -> [batch, hidden*2] context_vector = torch.sum(attn_weights.unsqueeze(-1) * lstm_out, dim=1) # 4. Classification logits = self.fc(context_vector) # 5. Return tùy theo yêu cầu if return_attention: return logits, attn_weights, outputs.attentions else: return logits