Add QA head (#17)
add qa head (9289c921b309897dce6f5cefbde1741012bdfa96)
- modeling_eurobert.py +98 -1
modeling_eurobert.py
@@ -30,7 +30,7 @@ from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, StaticCache
 from transformers.modeling_attn_mask_utils import AttentionMaskConverter
 from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
-from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, MaskedLMOutput, QuestionAnsweringModelOutput, SequenceClassifierOutput, TokenClassifierOutput
 from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from transformers.processing_utils import Unpack
@@ -951,10 +951,107 @@ class EuroBertForTokenClassification(EuroBertPreTrainedModel):
         )
 
 
+@add_start_docstrings(
+    """
+    The EuroBert Model with a span classification head on top for extractive question-answering tasks
+    like SQuAD (a linear layer on top of the hidden-states output to compute span start logits
+    and span end logits).
+    """,
+    EUROBERT_START_DOCSTRING,
+)
+class EuroBertForQuestionAnswering(EuroBertPreTrainedModel):
+    def __init__(self, config: EuroBertConfig):
+        super().__init__(config)
+        self.num_labels = config.num_labels
+        self.model = EuroBertModel(config)
+
+        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
+        self.post_init()
+
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
+    def set_input_embeddings(self, value):
+        self.model.embed_tokens = value
+
+    @add_start_docstrings_to_model_forward(EUROBERT_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
+        r"""
+        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the start of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Labels for position (index) of the end of the labelled span for computing the token classification loss.
+            Positions are clamped to the length of the sequence (`sequence_length`). Positions outside of the sequence
+            are not taken into account for computing the loss.
+        """
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        outputs = self.model(
+            input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+
+        logits = self.qa_outputs(sequence_output)
+        start_logits, end_logits = logits.split(1, dim=-1)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()
+
+        total_loss = None
+        if start_positions is not None and end_positions is not None:
+            # If we are on multi-GPU, split adds a dimension
+            if len(start_positions.size()) > 1:
+                start_positions = start_positions.squeeze(-1)
+            if len(end_positions.size()) > 1:
+                end_positions = end_positions.squeeze(-1)
+            # Sometimes the start/end positions are outside our model inputs; we ignore these terms
+            ignored_index = start_logits.size(1)
+            start_positions = start_positions.clamp(0, ignored_index)
+            end_positions = end_positions.clamp(0, ignored_index)
+
+            loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
+            start_loss = loss_fct(start_logits, start_positions)
+            end_loss = loss_fct(end_logits, end_positions)
+            total_loss = (start_loss + end_loss) / 2
+
+        if not return_dict:
+            output = (start_logits, end_logits) + outputs[2:]
+            return ((total_loss,) + output) if total_loss is not None else output
+
+        return QuestionAnsweringModelOutput(
+            loss=total_loss,
+            start_logits=start_logits,
+            end_logits=end_logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+
 __all__ = [
     "EuroBertPreTrainedModel",
     "EuroBertModel",
     "EuroBertForMaskedLM",
     "EuroBertForSequenceClassification",
     "EuroBertForTokenClassification",
+    "EuroBertForQuestionAnswering",
 ]
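For reference, a minimal usage sketch of the new head (not part of the commit). It assumes a Hub checkpoint such as EuroBERT/EuroBERT-210m whose auto_map exposes EuroBertForQuestionAnswering to AutoModelForQuestionAnswering under trust_remote_code; the checkpoint id is illustrative, and until the head is fine-tuned on a QA dataset the freshly initialized qa_outputs layer will not produce meaningful spans.

import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer

# Assumption: this checkpoint id and its auto_map entry for question answering are
# illustrative; use whichever EuroBERT repo actually exposes the class.
model_id = "EuroBERT/EuroBERT-210m"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForQuestionAnswering.from_pretrained(model_id, trust_remote_code=True)

question = "Where is the Eiffel Tower located?"
context = "The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars in Paris, France."
inputs = tokenizer(question, context, return_tensors="pt")

with torch.no_grad():
    outputs = model(**inputs)

# qa_outputs yields one start and one end logit per token; the most likely span is the
# argmax of each, decoded back to text.
start = outputs.start_logits.argmax(dim=-1).item()
end = outputs.end_logits.argmax(dim=-1).item()
print(tokenizer.decode(inputs["input_ids"][0, start : end + 1]))

For fine-tuning, passing start_positions and end_positions makes forward return the average of the start and end cross-entropy losses, with labels clamped to the sequence length (ignored_index) so answers that fall outside the tokenized window do not contribute to the loss.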