typeof committed on
Commit c55052c · verified · 1 parent: fa8c5ee

Upload folder using huggingface_hub

config.json ADDED
@@ -0,0 +1,36 @@
+{
+  "architectures": [
+    "MoEBertModel"
+  ],
+  "attention_probs_dropout_prob": 0.1,
+  "auto_map": {
+    "AutoConfig": "configuration_bert_moe.MoeBertConfig",
+    "AutoModel": "modeling_bert_moe.MoEBertModel"
+  },
+  "classifier_dropout": null,
+  "dtype": "float32",
+  "gradient_checkpointing": false,
+  "hidden_act": "gelu",
+  "hidden_dropout_prob": 0.1,
+  "hidden_size": 128,
+  "initializer_range": 0.02,
+  "intermediate_size": 128,
+  "layer_norm_eps": 1e-12,
+  "max_position_embeddings": 512,
+  "model_type": "bert_moe",
+  "moebert_expert_dim": 128,
+  "moebert_expert_dropout": 0.1,
+  "moebert_expert_num": 16,
+  "moebert_load_importance": null,
+  "moebert_route_hash_list": null,
+  "moebert_route_method": "gate-token",
+  "moebert_share_importance": 0.5,
+  "num_attention_heads": 2,
+  "num_hidden_layers": 2,
+  "pad_token_id": 0,
+  "position_embedding_type": "absolute",
+  "transformers_version": "4.57.3",
+  "type_vocab_size": 2,
+  "use_cache": true,
+  "vocab_size": 30522
+}
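Note: the `auto_map` entries (corrected above to point at the `modeling_bert_moe.py` file added in this commit) are what let the Auto classes resolve the custom code. A minimal loading sketch, not part of the commit; the repo id below is a placeholder, since the repository name is not shown on this page:

from transformers import AutoConfig, AutoModel

repo_id = "<user>/<repo>"  # placeholder: replace with the actual repository id

# trust_remote_code=True is required so that configuration_bert_moe.py and
# modeling_bert_moe.py from this repository are downloaded and executed.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)

print(type(model).__name__)                                    # MoEBertModel
print(config.moebert_route_method, config.moebert_expert_num)  # gate-token 16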
configuration_bert_moe.py ADDED
@@ -0,0 +1,29 @@
+from transformers.models.bert.configuration_bert import BertConfig
+
+class MoeBertConfig(BertConfig):
+    """
+    Extension of the BERT configuration that adds the mixture-of-experts routing and expert parameters.
+    """
+
+    model_type = "bert_moe"
+
+    def __init__(
+        self,
+        moebert_expert_num=16,
+        moebert_route_method="gate-token",
+        moebert_expert_dropout=0.1,
+        moebert_expert_dim=128,
+        moebert_route_hash_list=None,
+        moebert_share_importance=0.5,
+        moebert_load_importance=None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.moebert_expert_num = moebert_expert_num
+        self.moebert_route_method = moebert_route_method
+        self.moebert_expert_dropout = moebert_expert_dropout
+        self.moebert_expert_dim = moebert_expert_dim
+        self.moebert_route_hash_list = moebert_route_hash_list
+        self.moebert_share_importance = moebert_share_importance
+        self.moebert_load_importance = moebert_load_importance
+
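For reference, the configuration can also be built directly in Python; the values below mirror config.json, and `AutoConfig.register` is the standard transformers API for making `model_type="bert_moe"` resolvable locally. A sketch, not shipped in this commit, assuming `configuration_bert_moe.py` is on the import path:

from transformers import AutoConfig
from configuration_bert_moe import MoeBertConfig  # the module added above

config = MoeBertConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
    moebert_expert_num=16,
    moebert_expert_dim=128,
    moebert_expert_dropout=0.1,
    moebert_route_method="gate-token",  # also accepted: "gate-sentence", "hash-random", "hash-balance"
)

AutoConfig.register("bert_moe", MoeBertConfig)  # optional local registration
print(config.model_type)  # bert_moe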
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bf20feb98e9017638464385ae6c9cc161664b462c8b904838c1b07e62680e9c3
+size 21056872
modeling_bert_moe.py ADDED
@@ -0,0 +1,587 @@
+# ================================================================================================================
+
+from typing import Optional, Tuple
+
+import copy
+import logging
+import pickle
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint  # needed for the gradient-checkpointing path below
+from torch import Tensor
+from torch.nn import CrossEntropyLoss, MSELoss
+
+from dataclasses import dataclass
+
+from transformers.activations import ACT2FN
+from transformers.utils import ModelOutput
+from transformers.models.bert.modeling_bert import (
+    BertAttention,
+    BertEmbeddings,
+    BertEncoder,
+    BertIntermediate,
+    BertLayer,
+    BertModel,
+    BertOutput,
+    BertPooler,
+    BertPreTrainedModel,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def use_experts(layer_idx):
+    # Every transformer layer gets an MoE feed-forward block.
+    return True
+
+
+def process_ffn(model):
+    if model.config.model_type == "bert":
+        inner_model = model.bert
+    else:
+        raise ValueError("Model type not recognized.")
+
+    for i in range(model.config.num_hidden_layers):
+        model_layer = inner_model.encoder.layer[i]
+        # NOTE: the loop body is a stub as shipped; model_layer is fetched but never modified.
+
+
+class FeedForward(nn.Module):
+    def __init__(self, config, intermediate_size, dropout):
+        nn.Module.__init__(self)
+
+        # first layer
+        self.fc1 = nn.Linear(config.hidden_size, intermediate_size)
+        if isinstance(config.hidden_act, str):
+            self.intermediate_act_fn = ACT2FN[config.hidden_act]
+        else:
+            self.intermediate_act_fn = config.hidden_act
+
+        # second layer
+        self.fc2 = nn.Linear(intermediate_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, hidden_states: Tensor):
+        input_tensor = hidden_states
+        hidden_states = self.fc1(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        hidden_states = self.fc2(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+@dataclass
+class MoEModelOutput(ModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    gate_loss: torch.FloatTensor = None
+
+
+@dataclass
+class MoEModelOutputWithPooling(ModelOutput):
+    last_hidden_state: torch.FloatTensor = None
+    pooler_output: torch.FloatTensor = None
+    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
+    attentions: Optional[Tuple[torch.FloatTensor]] = None
+    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    gate_loss: torch.FloatTensor = None
+
+
+# ================================================================================================================
+
+
+class MoELayer(nn.Module):
+    def __init__(self, hidden_size, num_experts, expert, route_method, vocab_size, hash_list):
+        nn.Module.__init__(self)
+        self.num_experts = num_experts
+        self.experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(num_experts)])
+        self.route_method = route_method
+        if route_method in ["gate-token", "gate-sentence"]:
+            self.gate = nn.Linear(hidden_size, num_experts, bias=False).float()
+        elif route_method == "hash-random":
+            self.hash_list = self._random_hash_list(vocab_size)
+        elif route_method == "hash-balance":
+            self.hash_list = self._balance_hash_list(hash_list)
+        else:
+            raise KeyError("Routing method not supported.")
+
+    def _random_hash_list(self, vocab_size):
+        hash_list = torch.randint(low=0, high=self.num_experts, size=(vocab_size,))
+        return hash_list
+
+    def _balance_hash_list(self, hash_list):
+        with open(hash_list, "rb") as file:
+            result = pickle.load(file)
+        result = torch.tensor(result, dtype=torch.int64)
+        return result
+
+    def _forward_gate_token(self, x):
+        bsz, seq_len, dim = x.size()
+
+        x = x.view(-1, dim)
+        logits_gate = self.gate(x)
+        prob_gate = F.softmax(logits_gate, dim=-1)
+        gate = torch.argmax(prob_gate, dim=-1)
+
+        order = gate.argsort(0)
+        num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
+        gate_load = num_tokens.clone()
+        x = x[order]  # reorder according to expert number
+        x = x.split(num_tokens.tolist(), dim=0)  # a list of length self.num_experts
+
+        # compute the load balancing loss
+        P = prob_gate.mean(0)
+        temp = num_tokens.float()
+        f = temp / temp.sum(0, keepdim=True)
+        balance_loss = self.num_experts * torch.sum(P * f)
+
+        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
+        prob_gate = prob_gate[order]
+        prob_gate = prob_gate.split(num_tokens.tolist(), dim=0)
+
+        def forward_expert(input_x, prob_x, expert_idx):
+            input_x = self.experts[expert_idx].forward(input_x)
+            input_x = input_x * prob_x
+            return input_x
+
+        x = [forward_expert(x[i], prob_gate[i], i) for i in range(self.num_experts)]
+        x = torch.vstack(x)
+        x = x[order.argsort(0)]  # restore original order
+        x = x.view(bsz, seq_len, dim)
+
+        return x, balance_loss, gate_load
+
+    def _forward_gate_sentence(self, x, attention_mask):
+        x_masked = x * attention_mask.unsqueeze(-1)
+        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
+        logits_gate = self.gate(x_average)
+        prob_gate = F.softmax(logits_gate, dim=-1)
+        gate = torch.argmax(prob_gate, dim=-1)
+
+        order = gate.argsort(0)
+        num_sentences = F.one_hot(gate, self.num_experts).gt(0).sum(0)
+        gate_load = num_sentences.clone()
+        x = x[order]  # reorder according to expert number
+        x = x.split(num_sentences.tolist(), dim=0)  # a list of length self.num_experts
+
+        # compute the load balancing loss
+        P = prob_gate.mean(0)
+        temp = num_sentences.float()
+        f = temp / temp.sum(0, keepdim=True)
+        balance_loss = self.num_experts * torch.sum(P * f)
+
+        prob_gate = prob_gate.gather(dim=1, index=gate.unsqueeze(1))
+        prob_gate = prob_gate[order]
+        prob_gate = prob_gate.split(num_sentences.tolist(), dim=0)
+
+        def forward_expert(input_x, prob_x, expert_idx):
+            input_x = self.experts[expert_idx].forward(input_x)
+            input_x = input_x * prob_x.unsqueeze(-1)
+            return input_x
+
+        result = []
+        for i in range(self.num_experts):
+            if x[i].size(0) > 0:
+                result.append(forward_expert(x[i], prob_gate[i], i))
+        result = torch.vstack(result)
+        result = result[order.argsort(0)]  # restore original order
+
+        return result, balance_loss, gate_load
+
+    def _forward_sentence_single_expert(self, x, attention_mask):
+        x_masked = x * attention_mask.unsqueeze(-1)
+        x_average = x_masked.sum(1) / attention_mask.unsqueeze(-1).sum(1)
+        logits_gate = self.gate(x_average)
+        prob_gate = F.softmax(logits_gate, dim=-1)
+        gate = torch.argmax(prob_gate, dim=-1)
+
+        gate_load = F.one_hot(gate, self.num_experts).gt(0).sum(0)
+        x = self.experts[gate.cpu().item()].forward(x)
+        return x, 0.0, gate_load
+
+    def _forward_hash(self, x, input_ids):
+        bsz, seq_len, dim = x.size()
+
+        x = x.view(-1, dim)
+        self.hash_list = self.hash_list.to(x.device)
+        gate = self.hash_list[input_ids.view(-1)]
+
+        order = gate.argsort(0)
+        num_tokens = F.one_hot(gate, self.num_experts).gt(0).sum(0)
+        gate_load = num_tokens.clone()
+        x = x[order]  # reorder according to expert number
+        x = x.split(num_tokens.tolist(), dim=0)  # a list of length self.num_experts
+
+        x = [self.experts[i].forward(x[i]) for i in range(self.num_experts)]
+        x = torch.vstack(x)
+        x = x[order.argsort(0)]  # restore original order
+        x = x.view(bsz, seq_len, dim)
+
+        return x, 0.0, gate_load
+
+    def forward(self, x, input_ids, attention_mask):
+        if self.route_method == "gate-token":
+            x, balance_loss, gate_load = self._forward_gate_token(x)
+        elif self.route_method == "gate-sentence":
+            if x.size(0) == 1:
+                x, balance_loss, gate_load = self._forward_sentence_single_expert(x, attention_mask)
+            else:
+                x, balance_loss, gate_load = self._forward_gate_sentence(x, attention_mask)
+        elif self.route_method in ["hash-random", "hash-balance"]:
+            x, balance_loss, gate_load = self._forward_hash(x, input_ids)
+        else:
+            raise KeyError("Routing method not supported.")
+
+        return x, balance_loss, gate_load
+
+
+# ================================================================================================================
+
+
+def symmetric_KL_loss(p, q):
+    """Symmetric KL-divergence 1/2*(KL(p||q)+KL(q||p))."""
+    p, q = p.float(), q.float()
+    loss = (p - q) * (torch.log(p) - torch.log(q))
+    return 0.5 * loss.sum()
+
+
+def softmax(x):
+    return F.softmax(x, dim=-1, dtype=torch.float32)
+
+
+class MoEBertLayer(BertLayer):
+    def __init__(self, config, layer_idx=-100):
+        nn.Module.__init__(self)
+        self.chunk_size_feed_forward = config.chunk_size_feed_forward
+        self.seq_len_dim = 1
+        self.attention = BertAttention(config)
+        self.is_decoder = config.is_decoder
+        self.add_cross_attention = config.add_cross_attention
+        if self.add_cross_attention:
+            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
+            self.crossattention = BertAttention(config)
+        self.intermediate = BertIntermediate(config)
+        self.output = BertOutput(config)
+
+        # construct experts
+        self.use_experts = use_experts(layer_idx)
+        dropout = config.moebert_expert_dropout if self.use_experts else config.hidden_dropout_prob
+        if self.use_experts:
+            ffn = FeedForward(config, config.moebert_expert_dim, dropout)
+            self.experts = MoELayer(
+                hidden_size=config.hidden_size,
+                expert=ffn,
+                num_experts=config.moebert_expert_num,
+                route_method=config.moebert_route_method,
+                vocab_size=config.vocab_size,
+                hash_list=config.moebert_route_hash_list,
+            )
+        else:
+            self.experts = FeedForward(config, config.intermediate_size, dropout)
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_value=None,
+        output_attentions=False,
+        expert_input_ids=None,
+        expert_attention_mask=None,
+    ):
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        self_attention_outputs = self.attention(
+            hidden_states,
+            attention_mask,
+            head_mask,
+            output_attentions=output_attentions,
+            past_key_value=self_attn_past_key_value,
+        )
+        attention_output = self_attention_outputs[0]
+
+        # if decoder, the last output is a tuple of self-attn cache
+        if self.is_decoder:
+            outputs = self_attention_outputs[1:-1]
+            present_key_value = self_attention_outputs[-1]
+        else:
+            outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
+
+        cross_attn_present_key_value = None
+        if self.is_decoder and encoder_hidden_states is not None:
+            assert hasattr(
+                self, "crossattention"
+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+
+            # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
+            cross_attention_outputs = self.crossattention(
+                attention_output,
+                attention_mask,
+                head_mask,
+                encoder_hidden_states,
+                encoder_attention_mask,
+                cross_attn_past_key_value,
+                output_attentions,
+            )
+            attention_output = cross_attention_outputs[0]
+            outputs = outputs + cross_attention_outputs[1:-1]  # add cross attentions if we output attention weights
+
+            # add cross-attn cache to positions 3,4 of present_key_value tuple
+            cross_attn_present_key_value = cross_attention_outputs[-1]
+            present_key_value = present_key_value + cross_attn_present_key_value
+
+        # feed_forward returns a (layer_output, gate_loss) tuple so the encoder can accumulate the gate loss
+        layer_output = self.feed_forward(attention_output, expert_input_ids, expert_attention_mask)
+        outputs = (layer_output,) + outputs
+
+        # if decoder, return the attn key/values as the last output
+        if self.is_decoder:
+            outputs = outputs + (present_key_value,)
+
+        return outputs
+
+    def feed_forward(self, attention_output, expert_input_ids, expert_attention_mask):
+        if not self.use_experts:
+            layer_output = self.experts(attention_output)
+            return layer_output, 0.0
+
+        layer_output, gate_loss, gate_load = self.experts(
+            attention_output, expert_input_ids, expert_attention_mask
+        )
+        return layer_output, gate_loss
+
+
+class MoEBertEncoder(BertEncoder):
+    def __init__(self, config):
+        nn.Module.__init__(self)
+        self.config = config
+        self.layer = nn.ModuleList([MoEBertLayer(config, i) for i in range(config.num_hidden_layers)])
+
+    def forward(
+        self,
+        hidden_states,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=False,
+        output_hidden_states=False,
+        return_dict=True,
+        expert_input_ids=None,
+        expert_attention_mask=None,
+    ):
+        all_hidden_states = () if output_hidden_states else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
+
+        next_decoder_cache = () if use_cache else None
+        gate_loss = 0.0
+        for i, layer_module in enumerate(self.layer):
+            if output_hidden_states:
+                all_hidden_states = all_hidden_states + (hidden_states,)
+
+            layer_head_mask = head_mask[i] if head_mask is not None else None
+            past_key_value = past_key_values[i] if past_key_values is not None else None
+
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warning(
+                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
+                        "`use_cache=False`..."
+                    )
+                    use_cache = False
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, past_key_value, output_attentions)
+
+                    return custom_forward
+
+                # NOTE: expert_input_ids / expert_attention_mask are not forwarded on this path,
+                # so hash- and sentence-level routing receive None inputs under checkpointing.
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_states,
+                    attention_mask,
+                    layer_head_mask,
+                    encoder_hidden_states,
+                    encoder_attention_mask,
+                    past_key_value,
+                    output_attentions,
+                    expert_input_ids,
+                    expert_attention_mask,
+                )
+
+            # each layer returns ((hidden_states, gate_loss), ...)
+            hidden_states = layer_outputs[0][0]
+            gate_loss = gate_loss + layer_outputs[0][1]
+            if use_cache:
+                next_decoder_cache += (layer_outputs[-1],)
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        if not return_dict:
+            return tuple(
+                v
+                for v in [
+                    hidden_states,
+                    next_decoder_cache,
+                    all_hidden_states,
+                    all_self_attentions,
+                    all_cross_attentions,
+                ]
+                if v is not None
+            )
+        return MoEModelOutput(
+            last_hidden_state=hidden_states,
+            past_key_values=next_decoder_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+            gate_loss=gate_loss,
+        )
+
+
+class MoEBertModel(BertModel):
+    def __init__(self, config, add_pooling_layer=True):
+        BertModel.__init__(self, config)
+        self.config = config
+
+        self.embeddings = BertEmbeddings(config)
+        self.encoder = MoEBertEncoder(config)
+
+        self.pooler = BertPooler(config) if add_pooling_layer else None
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        token_type_ids=None,
+        position_ids=None,
+        head_mask=None,
+        inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        past_key_values=None,
+        use_cache=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        expert_input_ids=None,
+        expert_attention_mask=None,
+    ):
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+        if self.config.is_decoder:
+            use_cache = use_cache if use_cache is not None else self.config.use_cache
+        else:
+            use_cache = False
+
+        if input_ids is not None and inputs_embeds is not None:
+            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+        elif input_ids is not None:
+            input_shape = input_ids.size()
+            batch_size, seq_length = input_shape
+        elif inputs_embeds is not None:
+            input_shape = inputs_embeds.size()[:-1]
+            batch_size, seq_length = input_shape
+        else:
+            raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+        device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+        # past_key_values_length
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+
+        if attention_mask is None:
+            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
+
+        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
+        # ourselves in which case we just need to make it broadcastable to all heads.
+        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
+
+        # If a 2D or 3D attention mask is provided for the cross-attention
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.is_decoder and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_extended_attention_mask = None
+
+        # Prepare head mask if needed
+        # 1.0 in head_mask indicates we keep the head
+        # attention_probs has shape bsz x n_heads x N x N
+        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
+        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
+        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
+
+        embedding_output = self.embeddings(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            token_type_ids=token_type_ids,
+            inputs_embeds=inputs_embeds,
+            past_key_values_length=past_key_values_length,
+        )
+        encoder_outputs = self.encoder(
+            embedding_output,
+            attention_mask=extended_attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_extended_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            expert_input_ids=expert_input_ids,
+            expert_attention_mask=expert_attention_mask,
+        )
+        sequence_output = encoder_outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        if not return_dict:
+            return (sequence_output, pooled_output) + encoder_outputs[1:]
+
+        return MoEModelOutputWithPooling(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=encoder_outputs.past_key_values,
+            hidden_states=encoder_outputs.hidden_states,
+            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
+            gate_loss=encoder_outputs.gate_loss,
+        )
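A self-contained smoke test of the classes above, as a sketch only: it assumes the two Python files from this commit are importable and that the installed transformers version matches the one recorded in config.json; the tensor shapes and the smaller expert count are made up for illustration.

import torch

from configuration_bert_moe import MoeBertConfig
from modeling_bert_moe import MoEBertModel

config = MoeBertConfig(
    hidden_size=128,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=128,
    moebert_expert_num=4,               # smaller than the shipped 16, just for the test
    moebert_expert_dim=128,
    moebert_route_method="gate-token",
)
model = MoEBertModel(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))
attention_mask = torch.ones_like(input_ids)

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        expert_input_ids=input_ids,            # consumed only by hash-* routing
        expert_attention_mask=attention_mask,  # consumed only by gate-sentence routing
    )

print(outputs.last_hidden_state.shape)  # torch.Size([2, 16, 128])
print(outputs.gate_loss)                # load-balancing loss summed over the layers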
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
+{
+  "cls_token": {
+    "content": "[CLS]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "mask_token": {
+    "content": "[MASK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "pad_token": {
+    "content": "[PAD]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "sep_token": {
+    "content": "[SEP]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "[UNK]",
+    "lstrip": false,
+    "normalized": false,
+    "rstrip": false,
+    "single_word": false
+  }
+}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+{
+  "added_tokens_decoder": {
+    "0": {
+      "content": "[PAD]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "100": {
+      "content": "[UNK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "101": {
+      "content": "[CLS]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "102": {
+      "content": "[SEP]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    },
+    "103": {
+      "content": "[MASK]",
+      "lstrip": false,
+      "normalized": false,
+      "rstrip": false,
+      "single_word": false,
+      "special": true
+    }
+  },
+  "clean_up_tokenization_spaces": true,
+  "cls_token": "[CLS]",
+  "do_basic_tokenize": true,
+  "do_lower_case": true,
+  "extra_special_tokens": {},
+  "mask_token": "[MASK]",
+  "model_max_length": 512,
+  "never_split": null,
+  "pad_token": "[PAD]",
+  "sep_token": "[SEP]",
+  "strip_accents": null,
+  "tokenize_chinese_chars": true,
+  "tokenizer_class": "BertTokenizer",
+  "unk_token": "[UNK]"
+}
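The tokenizer shipped here is a standard uncased BERT WordPiece tokenizer ([PAD]=0, [UNK]=100, [CLS]=101, [SEP]=102, [MASK]=103). A quick sanity check, again with a placeholder repo id:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("<user>/<repo>")  # placeholder repo id

enc = tokenizer("Hello world!", return_tensors="pt")
print(enc["input_ids"][0].tolist())  # starts with 101 ([CLS]) and ends with 102 ([SEP])
print(tokenizer.decode(enc["input_ids"][0]))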
vocab.txt ADDED
The diff for this file is too large to render. See raw diff