Update modeling_relik.py
Browse files- modeling_relik.py +24 -16
modeling_relik.py
CHANGED
|
@@ -233,9 +233,7 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
| 233 |
torch.permute(special_symbols_representation, (0, 2, 1)),
|
| 234 |
)
|
| 235 |
|
| 236 |
-
logits = self._mask_logits(
|
| 237 |
-
logits, (model_features_start == -100).all(2).long()
|
| 238 |
-
)
|
| 239 |
return logits
|
| 240 |
|
| 241 |
def forward(
|
|
@@ -280,7 +278,7 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
| 280 |
),
|
| 281 |
)
|
| 282 |
ned_start_predictions[ned_start_predictions > 0] = 1
|
| 283 |
-
ned_end_predictions[end_labels > 0] = 1
|
| 284 |
ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
|
| 285 |
|
| 286 |
else: # compute spans
|
|
@@ -310,14 +308,20 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
| 310 |
if ned_end_logits is not None:
|
| 311 |
ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
|
| 312 |
if not self.config.binary_end_logits:
|
| 313 |
-
ned_end_predictions = torch.argmax(
|
| 314 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 315 |
else:
|
| 316 |
ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
|
| 317 |
else:
|
| 318 |
ned_end_logits, ned_end_probabilities = None, None
|
| 319 |
-
ned_end_predictions = ned_start_predictions.new_zeros(
|
| 320 |
-
|
|
|
|
|
|
|
| 321 |
if not self.training:
|
| 322 |
# if len(ned_end_predictions.shape) < 2:
|
| 323 |
# print(ned_end_predictions)
|
|
@@ -344,12 +348,11 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
| 344 |
if (end_position > 0).sum() > 0:
|
| 345 |
ends_count = (end_position > 0).sum(1)
|
| 346 |
model_entity_start = torch.repeat_interleave(
|
| 347 |
-
|
| 348 |
-
|
| 349 |
model_entity_end = torch.repeat_interleave(
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
]
|
| 353 |
ents_count = torch.nn.utils.rnn.pad_sequence(
|
| 354 |
torch.split(ends_count, start_counts.tolist()),
|
| 355 |
batch_first=True,
|
|
@@ -379,7 +382,7 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
| 379 |
ed_predictions = torch.argmax(ed_probabilities, dim=-1)
|
| 380 |
else:
|
| 381 |
ed_logits, ed_probabilities, ed_predictions = (
|
| 382 |
-
None,
|
| 383 |
ned_start_predictions.new_zeros(batch_size, seq_len),
|
| 384 |
ned_start_predictions.new_zeros(batch_size),
|
| 385 |
)
|
|
@@ -429,8 +432,11 @@ class RelikReaderSpanModel(PreTrainedModel):
|
|
| 429 |
end_labels.view(-1),
|
| 430 |
)
|
| 431 |
else:
|
| 432 |
-
ned_end_loss = self.criterion(
|
| 433 |
-
|
|
|
|
|
|
|
|
|
|
| 434 |
# entity disambiguation loss
|
| 435 |
ed_loss = self.criterion(
|
| 436 |
ed_logits.view(-1, ed_logits.shape[-1]),
|
|
@@ -833,6 +839,8 @@ class RelikReaderREModel(PreTrainedModel):
|
|
| 833 |
start_counts = (start_position > 0).sum(1)
|
| 834 |
if (start_counts > 0).any():
|
| 835 |
ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
|
|
|
|
|
|
|
| 836 |
# limit to 30 predictions per document using start_counts, by setting all positions after the cumulative sum reaches 30 to 0
|
| 837 |
# if is_validation or is_prediction:
|
| 838 |
# ned_start_predictions[ned_start_predictions == 1] = start_counts
|
|
|
|
| 233 |
torch.permute(special_symbols_representation, (0, 2, 1)),
|
| 234 |
)
|
| 235 |
|
| 236 |
+
logits = self._mask_logits(logits, (model_features_start == -100).all(2).long())
|
|
|
|
|
|
|
| 237 |
return logits
|
| 238 |
|
| 239 |
def forward(
|
|
|
|
| 278 |
),
|
| 279 |
)
|
| 280 |
ned_start_predictions[ned_start_predictions > 0] = 1
|
| 281 |
+
ned_end_predictions[end_labels > 0] = 1
|
| 282 |
ned_end_predictions = ned_end_predictions[~(end_labels == -100).all(2)]
|
| 283 |
|
| 284 |
else: # compute spans
|
|
|
|
| 308 |
if ned_end_logits is not None:
|
| 309 |
ned_end_probabilities = torch.softmax(ned_end_logits, dim=-1)
|
| 310 |
if not self.config.binary_end_logits:
|
| 311 |
+
ned_end_predictions = torch.argmax(
|
| 312 |
+
ned_end_probabilities, dim=-1, keepdim=True
|
| 313 |
+
)
|
| 314 |
+
ned_end_predictions = torch.zeros_like(
|
| 315 |
+
ned_end_probabilities
|
| 316 |
+
).scatter_(1, ned_end_predictions, 1)
|
| 317 |
else:
|
| 318 |
ned_end_predictions = torch.argmax(ned_end_probabilities, dim=-1)
|
| 319 |
else:
|
| 320 |
ned_end_logits, ned_end_probabilities = None, None
|
| 321 |
+
ned_end_predictions = ned_start_predictions.new_zeros(
|
| 322 |
+
batch_size, seq_len
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
if not self.training:
|
| 326 |
# if len(ned_end_predictions.shape) < 2:
|
| 327 |
# print(ned_end_predictions)
|
|
|
|
| 348 |
if (end_position > 0).sum() > 0:
|
| 349 |
ends_count = (end_position > 0).sum(1)
|
| 350 |
model_entity_start = torch.repeat_interleave(
|
| 351 |
+
model_features[start_position > 0], ends_count, dim=0
|
| 352 |
+
)
|
| 353 |
model_entity_end = torch.repeat_interleave(
|
| 354 |
+
model_features, start_counts, dim=0
|
| 355 |
+
)[end_position > 0]
|
|
|
|
| 356 |
ents_count = torch.nn.utils.rnn.pad_sequence(
|
| 357 |
torch.split(ends_count, start_counts.tolist()),
|
| 358 |
batch_first=True,
|
|
|
|
| 382 |
ed_predictions = torch.argmax(ed_probabilities, dim=-1)
|
| 383 |
else:
|
| 384 |
ed_logits, ed_probabilities, ed_predictions = (
|
| 385 |
+
None,
|
| 386 |
ned_start_predictions.new_zeros(batch_size, seq_len),
|
| 387 |
ned_start_predictions.new_zeros(batch_size),
|
| 388 |
)
|
|
|
|
| 432 |
end_labels.view(-1),
|
| 433 |
)
|
| 434 |
else:
|
| 435 |
+
ned_end_loss = self.criterion(
|
| 436 |
+
ned_end_logits.reshape(-1, ned_end_logits.shape[-1]),
|
| 437 |
+
end_labels.reshape(-1).long(),
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
# entity disambiguation loss
|
| 441 |
ed_loss = self.criterion(
|
| 442 |
ed_logits.view(-1, ed_logits.shape[-1]),
|
|
|
|
| 839 |
start_counts = (start_position > 0).sum(1)
|
| 840 |
if (start_counts > 0).any():
|
| 841 |
ned_end_predictions = ned_end_predictions.split(start_counts.tolist())
|
| 842 |
+
else:
|
| 843 |
+
ned_end_predictions = [torch.empty(0, input_ids.shape[1], dtype=torch.int64) for _ in range(batch_size)]
|
| 844 |
# limit to 30 predictions per document using start_counts, by setting all positions after the cumulative sum reaches 30 to 0
|
| 845 |
# if is_validation or is_prediction:
|
| 846 |
# ned_start_predictions[ned_start_predictions == 1] = start_counts
|