# mcqa_bert.py
# --------------------------------------------------
# Plain BertModel + single-unit classification head
# --------------------------------------------------
import torch
import torch.nn as nn

from transformers import BertModel
class MCQABERT(nn.Module):
    """BERT encoder with a single-logit scoring head for multiple-choice QA.

    Each sequence (typically one question/option pair) is encoded
    independently; the [CLS] representation is projected to one scalar
    score. Callers are expected to gather the per-option scores and
    normalize across options outside this module.
    """

    def __init__(self, ckpt: str = "bert-base-uncased"):
        super().__init__()
        # Pretrained encoder; its hidden size drives the head's input width.
        self.encoder = BertModel.from_pretrained(ckpt)
        self.head = nn.Linear(self.encoder.config.hidden_size, 1)

    # --------------------------------------------------
    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Score a batch of token sequences.

        Args:
            input_ids: (B, T) token ids.
            attention_mask: (B, T) mask, 1 for real tokens, 0 for padding.
            token_type_ids: optional (B, T) segment ids. For pair-encoded
                MCQA inputs these distinguish question from option; BERT
                treats ``None`` as all zeros, so omitting it preserves the
                previous behavior exactly.

        Returns:
            (B,) tensor of unnormalized per-sequence scores.
        """
        out = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True,
        )
        cls_vec = out.last_hidden_state[:, 0]  # [CLS] token embedding, (B, H)
        logits = self.head(cls_vec).squeeze(-1)  # (B, 1) -> (B)
        return logits