update model imports to be compatible with transformers 4.56.0
Browse files- modeling_bert.py +8 -8
modeling_bert.py
CHANGED
|
@@ -48,12 +48,8 @@ from transformers.modeling_outputs import (
|
|
| 48 |
SequenceClassifierOutput,
|
| 49 |
TokenClassifierOutput,
|
| 50 |
)
|
| 51 |
-
from transformers.modeling_utils import
|
| 52 |
-
|
| 53 |
-
apply_chunking_to_forward,
|
| 54 |
-
find_pruneable_heads_and_indices,
|
| 55 |
-
prune_linear_layer,
|
| 56 |
-
)
|
| 57 |
from transformers.utils import logging
|
| 58 |
from transformers.models.bert.configuration_bert import BertConfig
|
| 59 |
|
|
@@ -1843,6 +1839,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 1843 |
|
| 1844 |
loss = None
|
| 1845 |
if labels is not None:
|
|
|
|
|
|
|
|
|
|
| 1846 |
if self.config.problem_type is None:
|
| 1847 |
if self.num_labels == 1:
|
| 1848 |
self.config.problem_type = "regression"
|
|
@@ -1850,7 +1849,7 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 1850 |
self.config.problem_type = "single_label_classification"
|
| 1851 |
else:
|
| 1852 |
self.config.problem_type = "multi_label_classification"
|
| 1853 |
-
|
| 1854 |
if self.config.problem_type == "regression":
|
| 1855 |
loss_fct = MSELoss()
|
| 1856 |
if self.num_labels == 1:
|
|
@@ -1858,6 +1857,8 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 1858 |
else:
|
| 1859 |
loss = loss_fct(logits, labels)
|
| 1860 |
elif self.config.problem_type == "single_label_classification":
|
|
|
|
|
|
|
| 1861 |
loss_fct = CrossEntropyLoss()
|
| 1862 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1863 |
elif self.config.problem_type == "multi_label_classification":
|
|
@@ -1987,7 +1988,6 @@ class BertForTokenClassification(BertPreTrainedModel):
|
|
| 1987 |
self.config = config
|
| 1988 |
if getattr(self.config, 'problem_type', None) is None:
|
| 1989 |
self.config.problem_type = 'single_label_classification'
|
| 1990 |
-
|
| 1991 |
self.bert = BertModel(config, add_pooling_layer=False)
|
| 1992 |
classifier_dropout = (
|
| 1993 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|
|
|
|
| 48 |
SequenceClassifierOutput,
|
| 49 |
TokenClassifierOutput,
|
| 50 |
)
|
| 51 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 52 |
+
from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
from transformers.utils import logging
|
| 54 |
from transformers.models.bert.configuration_bert import BertConfig
|
| 55 |
|
|
|
|
| 1839 |
|
| 1840 |
loss = None
|
| 1841 |
if labels is not None:
|
| 1842 |
+
# print (f"self.config.problem_type from init: {self.config.problem_type}")
|
| 1843 |
+
# print (f"self.num_labels from init: {self.num_labels}")
|
| 1844 |
+
# print (f"labels.dtype {labels.dtype}")
|
| 1845 |
if self.config.problem_type is None:
|
| 1846 |
if self.num_labels == 1:
|
| 1847 |
self.config.problem_type = "regression"
|
|
|
|
| 1849 |
self.config.problem_type = "single_label_classification"
|
| 1850 |
else:
|
| 1851 |
self.config.problem_type = "multi_label_classification"
|
| 1852 |
+
# print (f"self.config.problem_type from init: {self.config.problem_type}")
|
| 1853 |
if self.config.problem_type == "regression":
|
| 1854 |
loss_fct = MSELoss()
|
| 1855 |
if self.num_labels == 1:
|
|
|
|
| 1857 |
else:
|
| 1858 |
loss = loss_fct(logits, labels)
|
| 1859 |
elif self.config.problem_type == "single_label_classification":
|
| 1860 |
+
# print (logits)
|
| 1861 |
+
# print (labels)
|
| 1862 |
loss_fct = CrossEntropyLoss()
|
| 1863 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 1864 |
elif self.config.problem_type == "multi_label_classification":
|
|
|
|
| 1988 |
self.config = config
|
| 1989 |
if getattr(self.config, 'problem_type', None) is None:
|
| 1990 |
self.config.problem_type = 'single_label_classification'
|
|
|
|
| 1991 |
self.bert = BertModel(config, add_pooling_layer=False)
|
| 1992 |
classifier_dropout = (
|
| 1993 |
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
|