# ================================================================================================================
from typing import Optional, Tuple

import numpy as np
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, MSELoss
from torch import Tensor
import copy
from dataclasses import dataclass

from transformers.activations import ACT2FN
from transformers.file_utils import ModelOutput
from transformers.models.bert.modeling_bert import (
    BertAttention,
    BertEmbeddings,
    BertEncoder,
    BertIntermediate,
    BertLayer,
    BertModel,
    BertOutput,
    BertPooler,
    BertPreTrainedModel,
)

import logging

logger = logging.getLogger(__name__)


def use_experts(layer_idx):
    return True


def process_ffn(model):
    if model.config.model_type == "bert":
        inner_model = model.bert
    else:
        raise ValueError("Model type not recognized.")

    for i in range(model.config.num_hidden_layers):
        model_layer = inner_model.encoder.layer[i]


class FeedForward(nn.Module):
    def __init__(self, config, intermediate_size, dropout):
        nn.Module.__init__(self)

        # first layer
        self.fc1 = nn.Linear(config.hidden_size, intermediate_size)
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        # second layer
        self.fc2 = nn.Linear(intermediate_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: Tensor):
        input_tensor = hidden_states
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
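
def _example_feed_forward():
    # Illustrative usage sketch, not part of the original module. FeedForward only
    # reads `hidden_size`, `hidden_act`, and `layer_norm_eps` from the config, so a
    # stock BertConfig is assumed here; all sizes below are arbitrary.
    from transformers import BertConfig

    config = BertConfig(hidden_size=64)
    ffn = FeedForward(config, intermediate_size=256, dropout=0.1)
    hidden = torch.randn(2, 8, config.hidden_size)
    out = ffn(hidden)
    # the residual connection and LayerNorm keep the input shape unchanged
    assert out.shape == hidden.shape
    return out
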
@dataclass
class MoEModelOutput(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    gate_loss: torch.FloatTensor = None


@dataclass
class MoEModelOutputWithPooling(ModelOutput):
    last_hidden_state: torch.FloatTensor = None
    pooler_output: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    gate_loss: torch.FloatTensor = None


# ================================================================================================================
class MoELayer(nn.Module):
    def __init__(self, hidden_size, num_experts, expert, route_method, vocab_size, hash_list):
        nn.Module.__init__(self)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([copy.deepcopy(expert) for i in range(num_experts)])
        self.route_method = route_method
        if route_method in ["gate-token", "gate-sentence"]:
            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
        elif route_method == "hash-random":
            self.hash_list = self._random_hash_list(vocab_size)
        elif route_method == "hash-balance":
            self.hash_list = self._balance_hash_list(hash_list)
        else:
            raise KeyError("Routing method not supported.")

    def _random_hash_list(self, vocab_size):
        hash_list = torch.randint(low=0, high=self.num_experts, size=(vocab_size,))
        return hash_list

    def _balance_hash_list(self, hash_list):
        with open(hash_list, "rb") as file:
            result = pickle.load(file)
        result = torch.tensor(result, dtype=torch.int64)
        return result

    def _forward_gate_token(self, x):
        bsz, seq_len, dim = x.size()

        x = x.view(-1, dim)
        logits_gate = self.gate(x)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        order = gate.argsort(0)
        num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_tokens.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_tokens.tolist(), dim=0)  # a list of length self.num_experts

        # compute the load balancing loss
        P = prob_gate.mean(0)
        temp = num_tokens.float()
        f = temp / temp.sum(0, keepdim=True)
        balance_loss = self.num_experts * torch.sum(P * f)

        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        prob_gate = prob_gate[order]
        prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)

        def forward_expert(input_x, prob_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            input_x = input_x * prob_x
            return input_x

        x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
        x = torch.vstack(x)
        x = x[order.argsort(0)]  # restore original order
        x = x.view(bsz, seq_len, dim)

        return x, balance_loss, gate_load

    def _forward_gate_sentence(self, x, attention_mask):
        x_masked = x * attention_mask.unsqueeze(-1)
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
        logits_gate = self.gate(x_average)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        order = gate.argsort(0)
        num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_sentences.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_sentences.tolist(), dim=0)  # a list of length self.num_experts

        # compute the load balancing loss
        P = prob_gate.mean(0)
        temp = num_sentences.float()
        f = temp / temp.sum(0, keepdim=True)
        balance_loss = self.num_experts * torch.sum(P * f)

        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
        prob_gate = prob_gate[order]
        prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)

        def forward_expert(input_x, prob_x, expert_idx):
            input_x = self.experts[expert_idx].forward(input_x)
            input_x = input_x * prob_x.unsqueeze(-1)
            return input_x

        result = []
        for i in range(self.num_experts):
            if x[i].size(0) > 0:
                result.append(forward_expert(x[i], prob_gate[i], i))
        result = torch.vstack(result)
        result = result[order.argsort(0)]  # restore original order

        return result, balance_loss, gate_load

    def _forward_sentence_single_expert(self, x, attention_mask):
        x_masked = x * attention_mask.unsqueeze(-1)
        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
        logits_gate = self.gate(x_average)
        prob_gate = F.softmax(logits_gate, dim=-1)
        gate = torch.argmax(prob_gate, dim=-1)

        gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        x = self.experts[gate.cpu().item()].forward(x)
        return x, 0.0, gate_load

    def _forward_hash(self, x, input_ids):
        bsz, seq_len, dim = x.size()

        x = x.view(-1, dim)
        self.hash_list = self.hash_list.to(x.device)
        gate = self.hash_list[input_ids.view(-1)]

        order = gate.argsort(0)
        num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
        gate_load = num_tokens.clone()
        x = x[order]  # reorder according to expert number
        x = x.split(num_tokens.tolist(), dim=0)  # a list of length self.num_experts

        x = [self.experts[i].forward(x[i]) for i in range(self.num_experts)]
        x = torch.vstack(x)
        x = x[order.argsort(0)]  # restore original order
        x = x.view(bsz, seq_len, dim)

        return x, 0.0, gate_load

    def forward(self, x, input_ids, attention_mask):
        if self.route_method == "gate-token":
            x, balance_loss, gate_load = self._forward_gate_token(x)
        elif self.route_method == "gate-sentence":
            if x.size(0) == 1:
                x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
            else:
                x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
        elif self.route_method in ["hash-random", "hash-balance"]:
            x, balance_loss, gate_load = self._forward_hash(x, input_ids)
        else:
            raise KeyError("Routing method not supported.")

        return x, balance_loss, gate_load
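
def _example_moe_layer_routing():
    # Illustrative usage sketch, not part of the original module. It builds a
    # token-level gated MoELayer from a FeedForward expert; a stock BertConfig is
    # assumed and all sizes are arbitrary. "gate-token" routing ignores the
    # input_ids / attention_mask arguments, so None is passed for both.
    from transformers import BertConfig

    config = BertConfig(hidden_size=64)
    expert = FeedForward(config, intermediate_size=128, dropout=0.1)
    moe = MoELayer(
        hidden_size=config.hidden_size,
        num_experts=4,
        expert=expert,
        route_method="gate-token",
        vocab_size=config.vocab_size,
        hash_list=None,
    )
    hidden = torch.randn(2, 8, config.hidden_size)
    out, balance_loss, gate_load = moe(hidden, input_ids=None, attention_mask=None)
    # out keeps the input shape; gate_load counts tokens per expert;
    # balance_loss is the load-balancing term surfaced downstream as gate_loss
    assert out.shape == hidden.shape
    return balance_loss, gate_load
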
# ================================================================================================================
def symmetric_KL_loss(p, q):
    """ symmetric KL-divergence 1/2*(KL(p||q)+KL(q||p)) """
    p, q = p.float(), q.float()
    loss = (p - q) * (torch.log(p) - torch.log(q))
    return 0.5 * loss.sum()


def softmax(x):
    return F.softmax(x, dim=-1, dtype=torch.float32)


class MoEBertLayer(BertLayer):
    def __init__(self, config, layer_idx=-100):
        nn.Module.__init__(self)
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
            self.crossattention = BertAttention(config)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        # construct experts
        self.use_experts = use_experts(layer_idx)
        dropout = config.moebert_expert_dropout if self.use_experts else config.hidden_dropout_prob
        if self.use_experts:
            ffn = FeedForward(config, config.moebert_expert_dim, dropout)
            self.experts = MoELayer(
                hidden_size=config.hidden_size,
                expert=ffn,
                num_experts=config.moebert_expert_num,
                route_method=config.moebert_route_method,
                vocab_size=config.vocab_size,
                hash_list=config.moebert_route_hash_list,
            )
        else:
            self.experts = FeedForward(config, config.intermediate_size, dropout)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        expert_input_ids=None,
        expert_attention_mask=None,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]

        # if decoder, the last output is tuple of self-attn cache
        if self.is_decoder:
            outputs = self_attention_outputs[1:-1]
            present_key_value = self_attention_outputs[-1]
        else:
            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights

        cross_attn_present_key_value = None
        if self.is_decoder and encoder_hidden_states is not None:
            assert hasattr(
                self, "crossattention"
            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"

            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                cross_attn_past_key_value,
                output_attentions,
            )
            attention_output = cross_attention_outputs[0]
            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights

            # add cross-attn cache to positions 3,4 of present_key_value tuple
            cross_attn_present_key_value = cross_attention_outputs[-1]
            present_key_value = present_key_value + cross_attn_present_key_value

        layer_output = self.feed_forward(attention_output, expert_input_ids, expert_attention_mask)
        outputs = (layer_output,) + outputs

        # if decoder, return the attn key/values as the last output
        if self.is_decoder:
            outputs = outputs + (present_key_value,)

        return outputs

    def feed_forward(self, attention_output, expert_input_ids, expert_attention_mask):
        if not self.use_experts:
            layer_output = self.experts(attention_output)
            return layer_output, 0.0

        layer_output, gate_loss, gate_load = self.experts(
            attention_output, expert_input_ids, expert_attention_mask
        )
        return layer_output, gate_loss
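
def _example_moebert_layer():
    # Illustrative sketch, not part of the original module. MoEBertLayer reads the
    # extra `moebert_*` attributes from the config; here they are attached to a
    # stock BertConfig by hand (in practice they come from the training setup).
    # All sizes are arbitrary, and "hash-random" routing is chosen so that no
    # precomputed hash list is needed.
    from transformers import BertConfig

    config = BertConfig(hidden_size=64, num_attention_heads=4, intermediate_size=128)
    config.moebert_expert_num = 4
    config.moebert_expert_dim = 128
    config.moebert_expert_dropout = 0.1
    config.moebert_route_method = "hash-random"
    config.moebert_route_hash_list = None  # only read for "hash-balance"

    layer = MoEBertLayer(config, layer_idx=0)
    hidden = torch.randn(2, 8, config.hidden_size)
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    outputs = layer(hidden, expert_input_ids=input_ids)
    layer_output, gate_loss = outputs[0]  # feed_forward returns (hidden_states, gate_loss)
    return layer_output, gate_loss
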
class MoEBertEncoder(BertEncoder):
    def __init__(self, config):
        nn.Module.__init__(self)
        self.config = config
        self.layer = nn.ModuleList([MoEBertLayer(config, i) for i in range(config.num_hidden_layers)])

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
        expert_input_ids=None,
        expert_attention_mask=None,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        next_decoder_cache = () if use_cache else None
        gate_loss = 0.0
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                # NOTE: the checkpointed call below does not forward expert_input_ids /
                # expert_attention_mask, so only routing methods that ignore them
                # (e.g. "gate-token") can be combined with gradient checkpointing.
                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
                        "`use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, past_key_value, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                    expert_input_ids,
                    expert_attention_mask,
                )

            # each layer returns ((hidden_states, gate_loss), ...) from feed_forward
            hidden_states = layer_outputs[0][0]
            gate_loss = gate_loss + layer_outputs[0][1]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return MoEModelOutput(
            last_hidden_state=hidden_states,
            past_key_values=next_decoder_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
            gate_loss=gate_loss,
        )
class MoEBertModel(BertModel):
    def __init__(self, config, add_pooling_layer=True):
        BertModel.__init__(self, config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = MoEBertEncoder(config)

        self.pooler = BertPooler(config) if add_pooling_layer else None

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        expert_input_ids=None,
        expert_attention_mask=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # past_key_values_length
        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            expert_input_ids=expert_input_ids,
            expert_attention_mask=expert_attention_mask,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return MoEModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
            cross_attentions=encoder_outputs.cross_attentions,
            gate_loss=encoder_outputs.gate_loss,
        )
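
if __name__ == "__main__":
    # Illustrative end-to-end smoke test, not part of the original module. It assumes
    # a stock BertConfig extended by hand with the `moebert_*` attributes read by
    # MoEBertLayer; in practice these come from the training script. The sizes are
    # arbitrary and kept small so the example runs quickly on CPU.
    from transformers import BertConfig

    config = BertConfig(
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
    )
    config.moebert_expert_num = 4
    config.moebert_expert_dim = 128
    config.moebert_expert_dropout = 0.1
    config.moebert_route_method = "gate-token"
    config.moebert_route_hash_list = None  # only read for "hash-balance"

    model = MoEBertModel(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        expert_input_ids=input_ids,
        expert_attention_mask=attention_mask,
    )
    # gate_loss is the load-balancing loss accumulated over all MoE layers
    print(outputs.last_hidden_state.shape, outputs.gate_loss)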