| import torch | |
| import torch.nn as nn | |
| from transformers import BertModel, BertConfig | |
class BertHierarchicalClassification(nn.Module):
    """BERT encoder with four parallel classification heads.

    The encoder's pooled output is passed through dropout and then into four
    independent linear heads (grade, domain, cluster, standard). Each head
    reads only the shared pooled representation; nothing in this module
    enforces consistency between the levels of the hierarchy.

    Args:
        config: A ``BertConfig`` (or compatible object) that, in addition to
            the standard BERT fields such as ``hidden_size``, defines the
            integer attributes ``num_grades``, ``num_domains``,
            ``num_clusters`` and ``num_standards``. If it also defines a
            non-``None`` ``classifier_dropout``, that probability is used for
            the dropout before the heads; otherwise 0.1 is used.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE: constructing BertModel from a config creates randomly
        # initialised weights; pretrained weights are only loaded via
        # BertModel.from_pretrained(...).
        self.bert = BertModel(config)
        hidden_size = config.hidden_size

        # Label-space sizes for each level; these are custom config fields.
        self.num_grades = config.num_grades
        self.num_domains = config.num_domains
        self.num_clusters = config.num_clusters
        self.num_standards = config.num_standards

        # One independent linear head per level, all fed the same pooled vector.
        self.grade_classifier = nn.Linear(hidden_size, self.num_grades)
        self.domain_classifier = nn.Linear(hidden_size, self.num_domains)
        self.cluster_classifier = nn.Linear(hidden_size, self.num_clusters)
        self.standard_classifier = nn.Linear(hidden_size, self.num_standards)

        # Honour an explicit classifier_dropout from the config; fall back to
        # the previously hard-coded 0.1 so existing configs behave the same.
        classifier_dropout = getattr(config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = 0.1
        self.dropout = nn.Dropout(classifier_dropout)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None):
        """Encode the inputs and return per-level classification logits.

        Args:
            input_ids: Token ids passed to the BERT encoder (presumably
                ``(batch_size, seq_len)`` -- confirm against the tokenizer).
            attention_mask: Optional mask forwarded to the encoder; ``None``
                lets BertModel attend to every position.
            token_type_ids: Optional segment ids forwarded to the encoder
                (e.g. for sentence-pair inputs); ``None`` keeps BertModel's
                default segment ids.

        Returns:
            Tuple ``(grade_logits, domain_logits, cluster_logits,
            standard_logits)`` of unnormalised scores, each with last
            dimension ``num_grades`` / ``num_domains`` / ``num_clusters`` /
            ``num_standards`` respectively.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        pooled_output = self.dropout(outputs.pooler_output)

        grade_logits = self.grade_classifier(pooled_output)
        domain_logits = self.domain_classifier(pooled_output)
        cluster_logits = self.cluster_classifier(pooled_output)
        standard_logits = self.standard_classifier(pooled_output)
        return grade_logits, domain_logits, cluster_logits, standard_logits