Upload BertForJointParsing.py

BertForJointParsing.py (CHANGED, +19 -18)
@@ -81,6 +81,7 @@ class BertForJointParsing(BertPreTrainedModel):
 
     def set_output_embeddings(self, new_embeddings):
         if self.lex is not None:
+…
             self.cls.predictions.decoder = new_embeddings
 
     def forward(
@@ -207,7 +208,7 @@ class BertForJointParsing(BertPreTrainedModel):
         inputs = {k:v.to(self.device) for k,v in inputs.items()}
         output = self.forward(**inputs, return_dict=True, compute_syntax_mst=compute_syntax_mst)
 
-        final_output = [dict(text=sentence, tokens=…
+        final_output = [dict(text=sentence, tokens=combine_token_wordpieces(ids, offsets, tokenizer)) for sentence, ids, offsets in zip(sentences, inputs['input_ids'], offset_mapping)]
         # Syntax logits: each sentence gets a dict(tree: List[dict(word,dep_head,dep_head_idx,dep_func)], root_idx: int)
         if output.syntax_logits is not None:
             for sent_idx,parsed in enumerate(syntax_parse_logits(inputs, sentences, tokenizer, output.syntax_logits)):
@@ -231,10 +232,10 @@ class BertForJointParsing(BertPreTrainedModel):
 
         # NER logits each sentence gets a list(tuple(word, ner))
         if output.ner_logits is not None:
-            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label…
+            for sent_idx,parsed in enumerate(ner_parse_logits(inputs, sentences, tokenizer, output.ner_logits, self.config.id2label)):
                 if per_token_ner:
                     merge_token_list(final_output[sent_idx]['tokens'], map(itemgetter(1), parsed), 'ner')
-                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(parsed)
+                final_output[sent_idx]['ner_entities'] = aggregate_ner_tokens(final_output[sent_idx], parsed)
 
         if output_style in ['ud', 'iahlt_ud']:
             final_output = convert_output_to_ud(final_output, style='htb' if output_style == 'ud' else 'iahlt')
@@ -245,36 +246,39 @@ class BertForJointParsing(BertPreTrainedModel):
 
 
 
-def aggregate_ner_tokens(…
+def aggregate_ner_tokens(final_output, parsed):
     entities = []
     prev = None
-    for …
+    for token_idx, (d, (word, pred)) in enumerate(zip(final_output['tokens'], parsed)):
         # O does nothing
         if pred == 'O': prev = None
         # B- || I-entity != prev (different entity or none)
         elif pred.startswith('B-') or pred[2:] != prev:
             prev = pred[2:]
-            entities.append([[word], prev, start, end])
+            entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
         else:
             entities[-1][0].append(word)
-            entities[-1][…
+            entities[-1][1]['end'] = d['offsets']['end']
+            entities[-1][1]['token_end'] = token_idx
 
-    return [dict(phrase=' '.join(words),…
+    return [dict(phrase=' '.join(words), **d) for words, d in entities]
 
 def merge_token_list(src, update, key):
     for token_src, token_update in zip(src, update):
         token_src[key] = token_update
 
-def combine_token_wordpieces(input_ids: torch.Tensor, tokenizer: BertTokenizerFast):
+def combine_token_wordpieces(input_ids: torch.Tensor, offset_mapping: torch.Tensor, tokenizer: BertTokenizerFast):
+    offset_mapping = offset_mapping.tolist()
     ret = []
-    for token in tokenizer.convert_ids_to_tokens(input_ids):
+    for token, offsets in zip(tokenizer.convert_ids_to_tokens(input_ids), offset_mapping):
         if token in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]: continue
         if token.startswith('##'):
-            ret[-1] += token[2:]
-…
+            ret[-1]['token'] += token[2:]
+            ret[-1]['offsets']['end'] = offsets[1]
+        else: ret.append(dict(token=token, offsets=dict(start=offsets[0], end=offsets[1])))
     return ret
 
-def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]…
+def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], tokenizer: BertTokenizerFast, logits: torch.Tensor, id2label: Dict[int, str]):
    input_ids = inputs['input_ids']
 
    predictions = torch.argmax(logits, dim=-1)
@@ -289,16 +293,13 @@ def ner_parse_logits(inputs: Dict[str, torch.Tensor], sentences: List[str], toke
 
             token = tokenizer._convert_id_to_token(token_id)
 
-            # get the offsets for this token
-            start_pos, end_pos = offset_mapping[batch_idx, tok_idx]
             # wordpieces should just be appended to the previous word
             # we modify the last token in ret
             # by discarding the original end position and replacing it with the new token's end position
             if token.startswith('##'):
-                ret[-1] = (ret[-1][0] + token[2:], ret[-1][1], ret[-1][2], end_pos.item())
                 continue
-…
-            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()]…
+            # for each token, we append a tuple containing: token, label, start position, end position
+            ret.append((token, id2label[predictions[batch_idx, tok_idx].item()]))
 
    return batch_ret
 
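For reference, below is a minimal standalone sketch of the offset-aware wordpiece merging that the updated combine_token_wordpieces performs. The helper name merge_wordpieces, the toy wordpiece list, and the character offsets are invented for illustration; special-token skipping and the tokenizer itself are omitted.

def merge_wordpieces(tokens, offsets):
    # '##' pieces extend the previous token's text and push its end offset forward;
    # any other piece starts a new token dict, as in the updated helper above.
    ret = []
    for token, (start, end) in zip(tokens, offsets):
        if token.startswith('##'):
            ret[-1]['token'] += token[2:]
            ret[-1]['offsets']['end'] = end
        else:
            ret.append(dict(token=token, offsets=dict(start=start, end=end)))
    return ret

print(merge_wordpieces(['embed', '##ding', 'layer'], [(0, 5), (5, 9), (10, 15)]))
# [{'token': 'embedding', 'offsets': {'start': 0, 'end': 9}},
#  {'token': 'layer', 'offsets': {'start': 10, 'end': 15}}]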
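Similarly, a small sketch of how the revised aggregate_ner_tokens turns per-token BIO tags plus the merged tokens into entity spans that carry both character offsets and token indices. The token dicts and NER tags below are made up for illustration; in the real code the inputs come from combine_token_wordpieces and ner_parse_logits.

def aggregate(tokens, parsed):
    # tokens: list of dict(token=..., offsets=dict(start=..., end=...))
    # parsed: list of (word, BIO tag) pairs, one per merged token
    entities, prev = [], None
    for token_idx, (d, (word, pred)) in enumerate(zip(tokens, parsed)):
        if pred == 'O':
            prev = None
        elif pred.startswith('B-') or pred[2:] != prev:
            # start a new entity, remembering where it begins in characters and tokens
            prev = pred[2:]
            entities.append([[word], dict(label=prev, start=d['offsets']['start'], end=d['offsets']['end'], token_start=token_idx, token_end=token_idx)])
        else:
            # continuation of the current entity: extend its end offsets
            entities[-1][0].append(word)
            entities[-1][1]['end'] = d['offsets']['end']
            entities[-1][1]['token_end'] = token_idx
    return [dict(phrase=' '.join(words), **d) for words, d in entities]

tokens = [dict(token='Tel', offsets=dict(start=0, end=3)),
          dict(token='Aviv', offsets=dict(start=4, end=8)),
          dict(token='sits', offsets=dict(start=9, end=13))]
parsed = [('Tel', 'B-GPE'), ('Aviv', 'I-GPE'), ('sits', 'O')]
print(aggregate(tokens, parsed))
# [{'phrase': 'Tel Aviv', 'label': 'GPE', 'start': 0, 'end': 8, 'token_start': 0, 'token_end': 1}]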