from dataclasses import dataclass, field from typing import Optional import torch import torch.distributions as dists from torch.nn.utils.rnn import pad_sequence import numpy as np import math import sacrebleu from rouge import Rouge @dataclass class DiscreteDiffusionGeneratorArguments: max_iterations: int = field( default=10 ) mbr: int = field( default=1 ) length_beam: int = field( default=1 ) oracle_length: bool = field( default=False ) strategy: str = field( default="reparam-uncond-deterministic-cosine" ) argmax_decoding: bool = field( default=True ) bpe: str = field( default="sentencepiece" ) bleu_tokenize: str = field( default="13a" ) return_history: bool = field( default=False ) temperature: float = field( default=0.8 ) def topk_masking(scores, cutoff_len, stochastic=False, temp=1.0): """ scores: [b, n] cutoff_len: [b, 1] stochastic: bool, whether to add noise to select top_k or not returns: mask: [b, n], with 1 if the token is in top-k lowest scores, 0 otherwise """ if stochastic: gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8) _scores = scores + temp * gumbel_noise else: _scores = scores sorted_index = _scores.sort(-1)[0] cutoff = sorted_index.gather(dim=-1, index=cutoff_len) # + 1e-10 # cutoff_len = k -> select k + 1 tokens masking = _scores < cutoff try: assert (~(cutoff_len == 0).all()) | (~masking).all() except: import ipdb;ipdb.set_trace() return masking class MergeBLEU(object): def __call__(self, evalpreds): # if torch.distributed.get_rank() == 0: # import ipdb; ipdb.set_trace() # else: # import time; time.sleep(120) import inspect sys_stats, ref_stats = evalpreds[0], evalpreds[1] sys_stats = sys_stats.reshape(-1, 5).astype('long').sum(0).tolist() ref_stats = ref_stats.reshape(-1, 5).astype('long').sum(0).tolist() try: from sacrebleu.metrics import BLEU comp_bleu = BLEU.compute_bleu except ImportError: comp_bleu = sacrebleu.compute_bleu fn_sig = inspect.getfullargspec(comp_bleu)[0] if "smooth_method" in fn_sig: smooth = {"smooth_method": "exp"} else: smooth = {"smooth": "exp"} return { "bleu": comp_bleu( correct=sys_stats[:4], total=ref_stats[:4], sys_len=sys_stats[-1], ref_len=ref_stats[-1], **smooth ).score } class MergeRouge(object): def __call__(self, evalpreds): # if torch.distributed.get_rank() == 0: # import ipdb; ipdb.set_trace() # else: # import time; time.sleep(120) import inspect # sys avg_rouge, batch_size = evalpreds[0], evalpreds[1] rouge = (avg_rouge * batch_size).sum() / batch_size.sum() return { "rouge": rouge } class DiscreteDiffusionGenerator: def __init__(self, args, dictionary=None, tokenizer=None) -> None: self.args = args self.dictionary = dictionary self.tokenizer = tokenizer self.write_prediction = None assert (dictionary is not None) or (tokenizer is not None) assert (dictionary is None) ^ (tokenizer is None) self.retain_history = args.return_history if dictionary is not None: self.pad_id = dictionary.pad() self.bos_id = dictionary.bos() self.eos_id = dictionary.eos() self.mask_id = dictionary.mask_index else: self.pad_id = tokenizer.pad_token_id self.bos_id = tokenizer.bos_token_id self.eos_id = tokenizer.eos_token_id self.mask_id = tokenizer.mask_token_id self.rouge = Rouge(["rouge-l"]) def set_write_to(self, path): self.write_prediction = path def _reparam_decoding( self, output_tokens, output_scores, cur_tokens, cur_scores, decoding_strategy, xt_neq_x0, non_special_sym_mask, t, max_step, noise ): """ This function is used to perform reparameterized decoding. """ # output_tokens: [B, N] # output_scores: [B, N] # cur_tokens: [B, N] # cur_scores: [B, N] # xt_neq_x0: equivalent to not_b_t [B, N] # non_special_sym_mask: [B, N] # noise: either [B, N] or scalar (if using the mask noise) # decoding_strategy needs to take the form of "reparam---" _, condition, topk_mode, schedule = decoding_strategy.split("-") # first set the denoising rate according to the schedule if schedule == "linear": rate = 1 - t / max_step elif schedule == "cosine": rate = np.cos(t / max_step * np.pi * 0.5) else: raise NotImplementedError # compute the cutoff length for denoising top-k positions cutoff_len = ( non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate ).long() # set the scores of special symbols to a large value so that they will never be selected _scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0) # the top-k selection can be done in two ways: stochastic by injecting Gumbel noise or deterministic if topk_mode.startswith("stochastic"): noise_scale = float(topk_mode.replace("stochastic", "")) lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate) elif topk_mode == "deterministic": lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False) else: raise NotImplementedError # Various choices to generate v_t := [v1_t, v2_t]. # Note that # v1_t governs the outcomes of tokens where b_t = 1, # v2_t governs the outcomes of tokens where b_t = 0. # #### the `uncond` mode #### # In our reparameterized decoding, # both v1_t and v2_t can be fully determined by the current token scores . # #### the `cond` mode #### # However, we can also impose some conditional constraints on v1_t so that # the decoding can be performed in a more conservative manner. # For example, we can set v1_t = 0 only when # (the newly output tokens are the same as previous denoised results, AND # the current token score becomes lower, AND # the current token score is not in the top-k share among all tokens). if condition == "cond": not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask elif condition == "uncond": not_v1_t = lowest_k_mask else: raise NotImplementedError # for b_t = 0, the token is set to noise if it is in the lowest k scores. not_v2_t = lowest_k_mask masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t) if isinstance(noise, torch.Tensor): output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise]) elif isinstance(noise, (int, float)): output_tokens.masked_fill_(masked_to_noise, noise) else: raise NotImplementedError("noise should be either a tensor or a scalar") output_scores.masked_fill_(masked_to_noise, -math.inf) masked_to_x0 = xt_neq_x0 & ~not_v2_t output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0]) output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0]) # b_{t} = (b_{t+1} & u_t) | v_t # For convenience, save the NOT of b_t for the next iteration # NOT_b_{t} = (NOT_b_{t+1} | not_v1_t) & not_v2_t new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t return new_xt_neq_x0 def denoise_step(self, model, decoder_out, partial_masks): output_tokens = decoder_out.output_tokens output_scores = decoder_out.output_scores prev_step, cur_step = decoder_out.step, decoder_out.step + 1 max_step = decoder_out.max_step temperature = self.args.temperature # temperature = ( # -0.05 * (cur_step / (max_step - 1)) + 0.5 # if self.temperature_annealing # else self.temperature # ) # t = torch.LongTensor( # [(max_step - prev_step) * (model.num_diffusion_timesteps // max_step)] * output_tokens.size(0) # ).to(output_tokens) logits = model(output_tokens, partial_masks) logits[..., self.mask_id] = -math.inf scores = torch.log_softmax(logits, dim=-1) if self.args.strategy == "cmlm": # get the mask # , are ignored in this case since # they are not equal to unk. output_masks = output_tokens.eq(self.mask_id) unmask_prob = 1 / (max_step - prev_step) # where to unmask changes = torch.rand(output_tokens.shape, device=output_tokens.device) < unmask_prob # don't unmask somewhere already unmasked changes = torch.bitwise_and(changes, output_masks) if self.args.argmax_decoding: output_scores, new_tokens = scores.max(-1) else: new_tokens = dists.Categorical(logits=scores / temperature).sample() output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1) output_tokens[changes] = new_tokens[changes] elif self.args.strategy == "ar": output_masks = output_tokens.eq(self.mask_id) unmask_indices = (output_tokens.ne(self.mask_id) & output_tokens.ne(self.eos_id) & output_tokens.ne(self.pad_id)).sum(dim=-1) indices = torch.arange(output_tokens.size(-1)).expand(output_tokens.shape).to(output_masks.device) if self.args.argmax_decoding: output_scores, new_tokens = scores.max(-1) else: new_tokens = dists.Categorical(logits=scores / temperature).sample() output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1) output_tokens[unmask_indices[:, None]==indices] = new_tokens[unmask_indices[:, None]==indices] # output_tokens[changes] = new_tokens[changes] else: if self.args.argmax_decoding: cur_scores, cur_tokens = scores.max(-1) else: cur_tokens = dists.Categorical(logits=scores / temperature).sample() cur_scores = torch.gather(scores, -1, cur_tokens.unsqueeze(-1)).squeeze(-1) cur_scores = cur_scores.to(output_scores) output_masks = self._reparam_decoding( output_tokens=output_tokens, output_scores=output_scores, cur_tokens=cur_tokens, cur_scores=cur_scores, decoding_strategy=self.args.strategy, xt_neq_x0=decoder_out.output_masks, non_special_sym_mask=decoder_out.non_fixed_sym_masks, t=cur_step, max_step=max_step, noise=self.mask_id ) if self.retain_history: history = ([] if decoder_out.history is None else decoder_out.history) + [output_tokens.clone()] else: history = None # history = ( # decoder_out.history + [output_tokens.clone()] # if self.retain_history # else None # ) return decoder_out._replace( step=cur_step, output_tokens=output_tokens, output_scores=output_scores, output_masks=output_masks, history=history, ) def decode(self, seqs_tensors, preserve_special=False): seqs_tensors[seqs_tensors < 0] = self.pad_id if self.dictionary is not None: seqs = [ self.dictionary.string(seq, self.args.bpe).strip() for seq in seqs_tensors ] if not preserve_special: seqs = [seq.replace(self.dictionary.pad_word, '') for seq in seqs] else: seqs = self.tokenizer.batch_decode(seqs_tensors, skip_special_tokens=(not preserve_special)) return [seq.lower() for seq in seqs] def compute_bleu(self, hyps, refs): if isinstance(hyps, torch.Tensor): hyps = self.decode(hyps) if isinstance(refs, torch.Tensor): refs = self.decode(refs) return sacrebleu.corpus_bleu(hyps, [refs], tokenize=self.args.bleu_tokenize) def compute_rouge(self, hyps, refs): if isinstance(hyps, torch.Tensor): hyps = self.decode(hyps) if isinstance(refs, torch.Tensor): refs = self.decode(refs) return self.rouge.get_scores(hyps, [[ref] for ref in refs])['rouge-l']['f'] def stepwise_generate(self, model, inputs): src_tokens = inputs["net_input"]["src_tokens"] partial_masks = inputs["net_input"]["partial_masks"] # assert src_tokens.size(-1) < 514 # assert partial_masks.size(-1) < 514 # target = inputs["target"] raw_model = model.module if hasattr(model, "module") else model if "prefix_masks" in inputs["net_input"]: prefix_masks = inputs["net_input"]["prefix_masks"] else: prefix_masks = partial_masks # TODO: FIXME: to support general blockwise generation. partial_masks, prev_decoder_out = raw_model.initialize_decode_samples( src_tokens, partial_masks, prefix_masks, oracle_length=self.args.oracle_length, length_beam=self.args.length_beam, mbr=self.args.mbr ) prev_decoder_out = prev_decoder_out._replace( step=0, max_step=self.args.max_iterations ) for step in range(self.args.max_iterations): prev_decoder_out = self.denoise_step(model, prev_decoder_out, partial_masks) yield prev_decoder_out @torch.no_grad() def generate(self, model, inputs): src_tokens = inputs["net_input"]["src_tokens"] partial_masks = inputs["net_input"]["partial_masks"] # assert src_tokens.size(-1) < 514 # assert partial_masks.size(-1) < 514 # target = inputs["target"] # TODO: FIXME: to support general blockwise generation. if "prefix_masks" in inputs["net_input"]: prefix_masks = inputs["net_input"]["prefix_masks"] else: prefix_masks = partial_masks partial_masks, prev_decoder_out = model.initialize_decode_samples( src_tokens, partial_masks, prefix_masks, oracle_length=self.args.oracle_length, length_beam=self.args.length_beam, mbr=self.args.mbr ) prev_decoder_out = prev_decoder_out._replace( step=0, max_step=self.args.max_iterations ) for step in range(self.args.max_iterations): prev_decoder_out = self.denoise_step(model, prev_decoder_out, partial_masks) def finalized_hypos(tokens, scores, partial_mask, history=None): cutoff = ( tokens.ne(self.pad_id) & tokens.ne(self.bos_id) & tokens.ne(self.eos_id) & (~partial_mask) ) tokens = tokens[cutoff] if scores is None: score = None else: scores = scores[cutoff] score = scores.mean().item() ret_dict = { "tokens": tokens, "positional_scores": scores, "score": score, "alignment": None } if history is not None: ret_dict["history"] = [ finalized_hypos(history_tokens, None, partial_mask, history=None) for history_tokens in history ] return ret_dict def mbr_select(hyps): index = np.argmax(np.array( [self.rouge.get_scores([hyps[i]], [[hyps[j]]])['rouge-l']['f'] for j in range(len(hyps)) if i != j] ).mean() for i in range(len(hyps))) return hyps[index] def score_select(hyps): index = np.argmax([hyp["score"] for hyp in hyps]) return hyps[index] output_tokens, output_scores = prev_decoder_out.output_tokens, prev_decoder_out.output_scores if self.retain_history: full_history = prev_decoder_out.history histories = [[full_history[j][i] for j in range(self.args.max_iterations)] for i in range(output_tokens.size(0))] hyps = [] for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories): hyps.append(finalized_hypos(tokens, scores, partial_mask, history)) # hyps = [ # finalized_hypos(tokens, scores, partial_mask, history) # for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories) # ] else: hyps = [ finalized_hypos(tokens, scores, partial_mask, None) for tokens, scores, partial_mask in zip(output_tokens, output_scores, partial_masks) ] repeatition = self.args.mbr * self.args.length_beam if repeatition > 1: hyps = [score_select(hyps[i:i+repeatition])for i in range(0, len(hyps), repeatition)] # hyps = [mbr_select(hyps[i:i+repeatition])for i in range(0, len(hyps), repeatition)] finalized = pad_sequence([h["tokens"] for h in hyps ], batch_first=True, padding_value=self.pad_id) history = [[item["tokens"] for item in h["history"]] for h in hyps] if self.retain_history else None return finalized, history