myduy commited on
Commit
bf988ed
·
verified ·
1 Parent(s): 4d0639a

Update dd_generator

Browse files
Files changed (1) hide show
  1. dd_generator.py +465 -0
dd_generator.py ADDED
@@ -0,0 +1,465 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.distributions as dists
6
+ from torch.nn.utils.rnn import pad_sequence
7
+
8
+ import numpy as np
9
+
10
+ import math
11
+
12
+
13
+ import sacrebleu
14
+
15
+ from rouge import Rouge
16
+
17
+ @dataclass
18
+ class DiscreteDiffusionGeneratorArguments:
19
+ max_iterations: int = field(
20
+ default=10
21
+ )
22
+ mbr: int = field(
23
+ default=1
24
+ )
25
+ length_beam: int = field(
26
+ default=1
27
+ )
28
+ oracle_length: bool = field(
29
+ default=False
30
+ )
31
+ strategy: str = field(
32
+ default="reparam-uncond-deterministic-cosine"
33
+ )
34
+ argmax_decoding: bool = field(
35
+ default=True
36
+ )
37
+ bpe: str = field(
38
+ default="sentencepiece"
39
+ )
40
+ bleu_tokenize: str = field(
41
+ default="13a"
42
+ )
43
+ return_history: bool = field(
44
+ default=False
45
+ )
46
+ temperature: float = field(
47
+ default=0.8
48
+ )
49
+
50
+
51
+
52
+ def topk_masking(scores, cutoff_len, stochastic=False, temp=1.0):
53
+ """
54
+ scores: [b, n]
55
+ cutoff_len: [b, 1]
56
+ stochastic: bool, whether to add noise to select top_k or not
57
+ returns:
58
+ mask: [b, n], with 1 if the token is in top-k lowest scores, 0 otherwise
59
+ """
60
+ if stochastic:
61
+ gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8)
62
+ _scores = scores + temp * gumbel_noise
63
+ else:
64
+ _scores = scores
65
+ sorted_index = _scores.sort(-1)[0]
66
+ cutoff = sorted_index.gather(dim=-1, index=cutoff_len) # + 1e-10
67
+ # cutoff_len = k -> select k + 1 tokens
68
+ masking = _scores < cutoff
69
+ try:
70
+ assert (~(cutoff_len == 0).all()) | (~masking).all()
71
+ except:
72
+ import ipdb;ipdb.set_trace()
73
+ return masking
74
+
75
+
76
+ class MergeBLEU(object):
77
+ def __call__(self, evalpreds):
78
+ # if torch.distributed.get_rank() == 0:
79
+ # import ipdb; ipdb.set_trace()
80
+ # else:
81
+ # import time; time.sleep(120)
82
+ import inspect
83
+ sys_stats, ref_stats = evalpreds[0], evalpreds[1]
84
+
85
+ sys_stats = sys_stats.reshape(-1, 5).astype('long').sum(0).tolist()
86
+ ref_stats = ref_stats.reshape(-1, 5).astype('long').sum(0).tolist()
87
+ try:
88
+ from sacrebleu.metrics import BLEU
89
+ comp_bleu = BLEU.compute_bleu
90
+ except ImportError:
91
+ comp_bleu = sacrebleu.compute_bleu
92
+ fn_sig = inspect.getfullargspec(comp_bleu)[0]
93
+ if "smooth_method" in fn_sig:
94
+ smooth = {"smooth_method": "exp"}
95
+ else:
96
+ smooth = {"smooth": "exp"}
97
+ return {
98
+ "bleu": comp_bleu(
99
+ correct=sys_stats[:4],
100
+ total=ref_stats[:4],
101
+ sys_len=sys_stats[-1],
102
+ ref_len=ref_stats[-1],
103
+ **smooth
104
+ ).score
105
+ }
106
+
107
+ class MergeRouge(object):
108
+ def __call__(self, evalpreds):
109
+ # if torch.distributed.get_rank() == 0:
110
+ # import ipdb; ipdb.set_trace()
111
+ # else:
112
+ # import time; time.sleep(120)
113
+ import inspect
114
+ # sys
115
+ avg_rouge, batch_size = evalpreds[0], evalpreds[1]
116
+
117
+ rouge = (avg_rouge * batch_size).sum() / batch_size.sum()
118
+
119
+ return {
120
+ "rouge": rouge
121
+ }
122
+
123
+
124
+ class DiscreteDiffusionGenerator:
125
+ def __init__(self, args, dictionary=None, tokenizer=None) -> None:
126
+ self.args = args
127
+ self.dictionary = dictionary
128
+ self.tokenizer = tokenizer
129
+ self.write_prediction = None
130
+
131
+ assert (dictionary is not None) or (tokenizer is not None)
132
+ assert (dictionary is None) ^ (tokenizer is None)
133
+
134
+ self.retain_history = args.return_history
135
+
136
+ if dictionary is not None:
137
+ self.pad_id = dictionary.pad()
138
+ self.bos_id = dictionary.bos()
139
+ self.eos_id = dictionary.eos()
140
+ self.mask_id = dictionary.mask_index
141
+ else:
142
+ self.pad_id = tokenizer.pad_token_id
143
+ self.bos_id = tokenizer.bos_token_id
144
+ self.eos_id = tokenizer.eos_token_id
145
+ self.mask_id = tokenizer.mask_token_id
146
+
147
+ self.rouge = Rouge(["rouge-l"])
148
+
149
+ def set_write_to(self, path):
150
+ self.write_prediction = path
151
+
152
+ def _reparam_decoding(
153
+ self,
154
+ output_tokens,
155
+ output_scores,
156
+ cur_tokens,
157
+ cur_scores,
158
+ decoding_strategy,
159
+ xt_neq_x0,
160
+ non_special_sym_mask,
161
+ t,
162
+ max_step,
163
+ noise
164
+ ):
165
+ """
166
+ This function is used to perform reparameterized decoding.
167
+ """
168
+ # output_tokens: [B, N]
169
+ # output_scores: [B, N]
170
+ # cur_tokens: [B, N]
171
+ # cur_scores: [B, N]
172
+ # xt_neq_x0: equivalent to not_b_t [B, N]
173
+ # non_special_sym_mask: [B, N]
174
+ # noise: either [B, N] or scalar (if using the mask noise)
175
+
176
+ # decoding_strategy needs to take the form of "reparam-<conditioning>-<topk_mode>-<schedule>"
177
+ _, condition, topk_mode, schedule = decoding_strategy.split("-")
178
+
179
+ # first set the denoising rate according to the schedule
180
+ if schedule == "linear":
181
+ rate = 1 - t / max_step
182
+ elif schedule == "cosine":
183
+ rate = np.cos(t / max_step * np.pi * 0.5)
184
+ else:
185
+ raise NotImplementedError
186
+
187
+ # compute the cutoff length for denoising top-k positions
188
+ cutoff_len = (
189
+ non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate
190
+ ).long()
191
+ # set the scores of special symbols to a large value so that they will never be selected
192
+ _scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0)
193
+
194
+ # the top-k selection can be done in two ways: stochastic by injecting Gumbel noise or deterministic
195
+ if topk_mode.startswith("stochastic"):
196
+ noise_scale = float(topk_mode.replace("stochastic", ""))
197
+ lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
198
+ elif topk_mode == "deterministic":
199
+ lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
200
+ else:
201
+ raise NotImplementedError
202
+
203
+ # Various choices to generate v_t := [v1_t, v2_t].
204
+ # Note that
205
+ # v1_t governs the outcomes of tokens where b_t = 1,
206
+ # v2_t governs the outcomes of tokens where b_t = 0.
207
+
208
+ # #### the `uncond` mode ####
209
+ # In our reparameterized decoding,
210
+ # both v1_t and v2_t can be fully determined by the current token scores .
211
+
212
+ # #### the `cond` mode ####
213
+ # However, we can also impose some conditional constraints on v1_t so that
214
+ # the decoding can be performed in a more conservative manner.
215
+ # For example, we can set v1_t = 0 only when
216
+ # (the newly output tokens are the same as previous denoised results, AND
217
+ # the current token score becomes lower, AND
218
+ # the current token score is not in the top-k share among all tokens).
219
+ if condition == "cond":
220
+ not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask
221
+ elif condition == "uncond":
222
+ not_v1_t = lowest_k_mask
223
+ else:
224
+ raise NotImplementedError
225
+
226
+ # for b_t = 0, the token is set to noise if it is in the lowest k scores.
227
+ not_v2_t = lowest_k_mask
228
+
229
+ masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
230
+ if isinstance(noise, torch.Tensor):
231
+ output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
232
+ elif isinstance(noise, (int, float)):
233
+ output_tokens.masked_fill_(masked_to_noise, noise)
234
+ else:
235
+ raise NotImplementedError("noise should be either a tensor or a scalar")
236
+ output_scores.masked_fill_(masked_to_noise, -math.inf)
237
+
238
+ masked_to_x0 = xt_neq_x0 & ~not_v2_t
239
+ output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0])
240
+ output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0])
241
+ # b_{t} = (b_{t+1} & u_t) | v_t
242
+ # For convenience, save the NOT of b_t for the next iteration
243
+ # NOT_b_{t} = (NOT_b_{t+1} | not_v1_t) & not_v2_t
244
+ new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
245
+ return new_xt_neq_x0
246
+
247
+ def denoise_step(self, model, decoder_out, partial_masks):
248
+ output_tokens = decoder_out.output_tokens
249
+ output_scores = decoder_out.output_scores
250
+ prev_step, cur_step = decoder_out.step, decoder_out.step + 1
251
+ max_step = decoder_out.max_step
252
+ temperature = self.args.temperature
253
+ # temperature = (
254
+ # -0.05 * (cur_step / (max_step - 1)) + 0.5
255
+ # if self.temperature_annealing
256
+ # else self.temperature
257
+ # )
258
+
259
+ # t = torch.LongTensor(
260
+ # [(max_step - prev_step) * (model.num_diffusion_timesteps // max_step)] * output_tokens.size(0)
261
+ # ).to(output_tokens)
262
+ logits = model(output_tokens, partial_masks)
263
+
264
+ logits[..., self.mask_id] = -math.inf
265
+ scores = torch.log_softmax(logits, dim=-1)
266
+
267
+
268
+ if self.args.strategy == "cmlm":
269
+ # get the mask
270
+ # <bos>, <eos> are ignored in this case since
271
+ # they are not equal to unk.
272
+ output_masks = output_tokens.eq(self.mask_id)
273
+ unmask_prob = 1 / (max_step - prev_step)
274
+ # where to unmask
275
+ changes = torch.rand(output_tokens.shape, device=output_tokens.device) < unmask_prob
276
+ # don't unmask somewhere already unmasked
277
+ changes = torch.bitwise_and(changes, output_masks)
278
+
279
+ if self.args.argmax_decoding:
280
+ output_scores, new_tokens = scores.max(-1)
281
+ else:
282
+ new_tokens = dists.Categorical(logits=scores / temperature).sample()
283
+ output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
284
+ output_tokens[changes] = new_tokens[changes]
285
+ elif self.args.strategy == "ar":
286
+ output_masks = output_tokens.eq(self.mask_id)
287
+ unmask_indices = (output_tokens.ne(self.mask_id) & output_tokens.ne(self.eos_id) & output_tokens.ne(self.pad_id)).sum(dim=-1)
288
+ indices = torch.arange(output_tokens.size(-1)).expand(output_tokens.shape).to(output_masks.device)
289
+ if self.args.argmax_decoding:
290
+ output_scores, new_tokens = scores.max(-1)
291
+ else:
292
+ new_tokens = dists.Categorical(logits=scores / temperature).sample()
293
+ output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
294
+ output_tokens[unmask_indices[:, None]==indices] = new_tokens[unmask_indices[:, None]==indices]
295
+ # output_tokens[changes] = new_tokens[changes]
296
+ else:
297
+ if self.args.argmax_decoding:
298
+ cur_scores, cur_tokens = scores.max(-1)
299
+ else:
300
+ cur_tokens = dists.Categorical(logits=scores / temperature).sample()
301
+ cur_scores = torch.gather(scores, -1, cur_tokens.unsqueeze(-1)).squeeze(-1)
302
+ cur_scores = cur_scores.to(output_scores)
303
+
304
+ output_masks = self._reparam_decoding(
305
+ output_tokens=output_tokens,
306
+ output_scores=output_scores,
307
+ cur_tokens=cur_tokens,
308
+ cur_scores=cur_scores,
309
+ decoding_strategy=self.args.strategy,
310
+ xt_neq_x0=decoder_out.output_masks,
311
+ non_special_sym_mask=decoder_out.non_fixed_sym_masks,
312
+ t=cur_step,
313
+ max_step=max_step,
314
+ noise=self.mask_id
315
+ )
316
+ if self.retain_history:
317
+ history = ([] if decoder_out.history is None else decoder_out.history) + [output_tokens.clone()]
318
+ else:
319
+ history = None
320
+ # history = (
321
+ # decoder_out.history + [output_tokens.clone()]
322
+ # if self.retain_history
323
+ # else None
324
+ # )
325
+ return decoder_out._replace(
326
+ step=cur_step,
327
+ output_tokens=output_tokens,
328
+ output_scores=output_scores,
329
+ output_masks=output_masks,
330
+ history=history,
331
+ )
332
+
333
+
334
+ def decode(self, seqs_tensors, preserve_special=False):
335
+ seqs_tensors[seqs_tensors < 0] = self.pad_id
336
+ if self.dictionary is not None:
337
+ seqs = [
338
+ self.dictionary.string(seq, self.args.bpe).strip()
339
+ for seq in seqs_tensors
340
+ ]
341
+ if not preserve_special:
342
+ seqs = [seq.replace(self.dictionary.pad_word, '') for seq in seqs]
343
+ else:
344
+ seqs = self.tokenizer.batch_decode(seqs_tensors, skip_special_tokens=(not preserve_special))
345
+ return [seq.lower() for seq in seqs]
346
+
347
+ def compute_bleu(self, hyps, refs):
348
+ if isinstance(hyps, torch.Tensor):
349
+ hyps = self.decode(hyps)
350
+ if isinstance(refs, torch.Tensor):
351
+ refs = self.decode(refs)
352
+ return sacrebleu.corpus_bleu(hyps, [refs], tokenize=self.args.bleu_tokenize)
353
+
354
+ def compute_rouge(self, hyps, refs):
355
+ if isinstance(hyps, torch.Tensor):
356
+ hyps = self.decode(hyps)
357
+ if isinstance(refs, torch.Tensor):
358
+ refs = self.decode(refs)
359
+ return self.rouge.get_scores(hyps, [[ref] for ref in refs])['rouge-l']['f']
360
+
361
+ def stepwise_generate(self, model, inputs):
362
+ src_tokens = inputs["net_input"]["src_tokens"]
363
+ partial_masks = inputs["net_input"]["partial_masks"]
364
+ # assert src_tokens.size(-1) < 514
365
+ # assert partial_masks.size(-1) < 514
366
+ # target = inputs["target"]
367
+ raw_model = model.module if hasattr(model, "module") else model
368
+ if "prefix_masks" in inputs["net_input"]:
369
+ prefix_masks = inputs["net_input"]["prefix_masks"]
370
+ else:
371
+ prefix_masks = partial_masks
372
+ # TODO: FIXME: to support general blockwise generation.
373
+ partial_masks, prev_decoder_out = raw_model.initialize_decode_samples(
374
+ src_tokens, partial_masks, prefix_masks, oracle_length=self.args.oracle_length, length_beam=self.args.length_beam, mbr=self.args.mbr
375
+ )
376
+ prev_decoder_out = prev_decoder_out._replace(
377
+ step=0, max_step=self.args.max_iterations
378
+ )
379
+ for step in range(self.args.max_iterations):
380
+ prev_decoder_out = self.denoise_step(model, prev_decoder_out, partial_masks)
381
+ yield prev_decoder_out
382
+
383
+ @torch.no_grad()
384
+ def generate(self, model, inputs):
385
+ src_tokens = inputs["net_input"]["src_tokens"]
386
+ partial_masks = inputs["net_input"]["partial_masks"]
387
+ # assert src_tokens.size(-1) < 514
388
+ # assert partial_masks.size(-1) < 514
389
+ # target = inputs["target"]
390
+ # TODO: FIXME: to support general blockwise generation.
391
+ if "prefix_masks" in inputs["net_input"]:
392
+ prefix_masks = inputs["net_input"]["prefix_masks"]
393
+ else:
394
+ prefix_masks = partial_masks
395
+ partial_masks, prev_decoder_out = model.initialize_decode_samples(
396
+ src_tokens, partial_masks, prefix_masks, oracle_length=self.args.oracle_length, length_beam=self.args.length_beam, mbr=self.args.mbr
397
+ )
398
+ prev_decoder_out = prev_decoder_out._replace(
399
+ step=0, max_step=self.args.max_iterations
400
+ )
401
+
402
+ for step in range(self.args.max_iterations):
403
+ prev_decoder_out = self.denoise_step(model, prev_decoder_out, partial_masks)
404
+
405
+ def finalized_hypos(tokens, scores, partial_mask, history=None):
406
+ cutoff = (
407
+ tokens.ne(self.pad_id) &
408
+ tokens.ne(self.bos_id) &
409
+ tokens.ne(self.eos_id) &
410
+ (~partial_mask)
411
+ )
412
+ tokens = tokens[cutoff]
413
+ if scores is None:
414
+ score = None
415
+ else:
416
+ scores = scores[cutoff]
417
+ score = scores.mean().item()
418
+ ret_dict = {
419
+ "tokens": tokens,
420
+ "positional_scores": scores,
421
+ "score": score,
422
+ "alignment": None
423
+ }
424
+ if history is not None:
425
+ ret_dict["history"] = [
426
+ finalized_hypos(history_tokens, None, partial_mask, history=None)
427
+ for history_tokens in history
428
+ ]
429
+ return ret_dict
430
+
431
+ def mbr_select(hyps):
432
+ index = np.argmax(np.array(
433
+ [self.rouge.get_scores([hyps[i]], [[hyps[j]]])['rouge-l']['f']
434
+ for j in range(len(hyps)) if i != j]
435
+ ).mean() for i in range(len(hyps)))
436
+ return hyps[index]
437
+
438
+ def score_select(hyps):
439
+ index = np.argmax([hyp["score"] for hyp in hyps])
440
+ return hyps[index]
441
+
442
+ output_tokens, output_scores = prev_decoder_out.output_tokens, prev_decoder_out.output_scores
443
+ if self.retain_history:
444
+ full_history = prev_decoder_out.history
445
+ histories = [[full_history[j][i] for j in range(self.args.max_iterations)] for i in range(output_tokens.size(0))]
446
+ hyps = []
447
+ for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories):
448
+ hyps.append(finalized_hypos(tokens, scores, partial_mask, history))
449
+ # hyps = [
450
+ # finalized_hypos(tokens, scores, partial_mask, history)
451
+ # for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories)
452
+ # ]
453
+ else:
454
+ hyps = [
455
+ finalized_hypos(tokens, scores, partial_mask, None)
456
+ for tokens, scores, partial_mask in zip(output_tokens, output_scores, partial_masks)
457
+ ]
458
+ repeatition = self.args.mbr * self.args.length_beam
459
+ if repeatition > 1:
460
+ hyps = [score_select(hyps[i:i+repeatition])for i in range(0, len(hyps), repeatition)]
461
+ # hyps = [mbr_select(hyps[i:i+repeatition])for i in range(0, len(hyps), repeatition)]
462
+
463
+ finalized = pad_sequence([h["tokens"] for h in hyps ], batch_first=True, padding_value=self.pad_id)
464
+ history = [[item["tokens"] for item in h["history"]] for h in hyps] if self.retain_history else None
465
+ return finalized, history