MariaFjodorowa committed · Commit 6345ddb · verified · 1 Parent(s): dc8da07

Upload folder using huggingface_hub
README.md ADDED
@@ -0,0 +1,162 @@
1
+ ---
2
+ language:
3
+ - en
4
+ - eng
5
+ inference: false
6
+ tags:
7
+ - T5
8
+ - t5
9
+ - HPLT
10
+ - encoder-decoder
11
+ - text2text-generation
12
+ license: apache-2.0
13
+ datasets:
14
+ - HPLT/HPLT3.0
15
+ ---
16
+
17
+ # HPLT v3.0 T5 for English
18
+
19
+ <img src="https://hplt-project.org/_next/static/media/logo-hplt.d5e16ca5.svg" width=12.5%>
20
+
21
+ This is one of the monolingual encoder-decoder language models trained for the third release of the [HPLT project](https://hplt-project.org/).
+ It is a text-to-text transformer trained with a denoising objective.
+ Our models follow the setup of [NorT5](https://aclanthology.org/2023.nodalida-1.61/).
24
+
25
+ We present monolingual NorT5-style models for 57 of the 198 languages in the [HPLT v3.0 dataset](https://hplt-project.org/datasets/v3.0).
26
+
27
+ All the HPLT encoder-decoder models use the same hyper-parameters, roughly following the T5-base setup:
28
+ - hidden size: 768
29
+ - attention heads: 12
30
+ - layers: 12
31
+ - vocabulary size: 32768
32
+
33
+ Every model uses its own tokenizer trained on language-specific HPLT data.
34
+
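+ The `config.json` shipped in this repository lists `vocab_size: 50000`; a quick, illustrative way to check the exact settings of a given checkpoint (repository id taken from the examples below) is:
+
+ ```python
+ from transformers import AutoConfig, AutoTokenizer
+
+ repo = "HPLT/hplt_t5_base_3_0_eng_Latn"
+ config = AutoConfig.from_pretrained(repo, trust_remote_code=True)
+ tokenizer = AutoTokenizer.from_pretrained(repo)
+
+ print(config.hidden_size, config.num_attention_heads, config.num_hidden_layers)
+ print(len(tokenizer))  # size of this model's language-specific vocabulary
+ ```
+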
35
+ The training code is available in the [HPLT-WP4 repository](https://github.com/hplt-project/HPLT-WP4).
36
+
37
+ ## Example usage
38
+
39
+ This model currently needs a custom wrapper from `modeling_nort5.py`, so you should load it with `trust_remote_code=True`.
40
+
41
+ ```bash
42
+ pip install transformers==4.46.1
43
+ ```
44
+
45
+ ```python
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+
+ model_path = 'HPLT/hplt_t5_base_3_0_eng_Latn'
+ model = AutoModelForSeq2SeqLM.from_pretrained(
+     model_path, trust_remote_code=True, use_safetensors=False,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+ # MASKED LANGUAGE MODELING
+ sentence = "The facial expression [MASK_1] has developed into a widespread cultural phenomenon."
+ encoding = tokenizer(sentence, return_tensors="pt")
+ mask_1 = tokenizer.convert_tokens_to_ids("[MASK_1]")
+ mask_2 = tokenizer.convert_tokens_to_ids("[MASK_2]")
+ output_tensor = model.generate(
+     encoding.input_ids,
+     decoder_start_token_id=mask_1,
+     eos_token_id=mask_2,
+ )
+ print(tokenizer.decode(output_tensor.squeeze(), skip_special_tokens=False))
+ # prints the predicted fill for the mask, wrapped as '[MASK_1] ... [MASK_2]'
+ ```
66
+
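+ To get just the predicted text rather than the sentinel-wrapped string, a small helper (illustrative, continuing from the snippet above) can split on the sentinel tokens:
+
+ ```python
+ generated = tokenizer.decode(output_tensor.squeeze(), skip_special_tokens=False)
+ predicted_span = generated.split("[MASK_1]")[-1].split("[MASK_2]")[0].strip()
+ print(predicted_span)
+ ```
+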
67
+ ## Intermediate checkpoints
68
+
69
+ We release 10 intermediate checkpoints for each model, saved every 3125 training steps in separate branches. The naming convention is `stepXXX`, for example `step18750`.
70
+
71
+ You can load a specific model revision with `transformers` using the argument `revision`:
72
+ ```python
73
+ model = AutoModelForSeq2SeqLM.from_pretrained("HPLT/hplt_t5_base_3_0_eng_Latn", revision="step21875", trust_remote_code=True)
74
+ ```
75
+
76
+ You can list all available revisions of a model with the following code:
77
+ ```python
78
+ from huggingface_hub import list_repo_refs
79
+ out = list_repo_refs("HPLT/hplt_t5_base_3_0_eng_Latn")
80
+ print([b.name for b in out.branches])
81
+ ```
82
+
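+ Combining the two snippets above, one possible sketch iterates over all intermediate checkpoints, e.g. to track how some behaviour develops during pre-training:
+
+ ```python
+ from huggingface_hub import list_repo_refs
+ from transformers import AutoModelForSeq2SeqLM
+
+ repo = "HPLT/hplt_t5_base_3_0_eng_Latn"
+ for branch in list_repo_refs(repo).branches:
+     if not branch.name.startswith("step"):
+         continue  # skip 'main' and any other non-checkpoint branches
+     model = AutoModelForSeq2SeqLM.from_pretrained(
+         repo, revision=branch.name, trust_remote_code=True, use_safetensors=False,
+     )
+     # ... evaluate `model` for this training step here ...
+ ```
+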
83
+ ## Cite us
84
+
85
+ ```bibtex
86
+ @inproceedings{samuel-etal-2023-norbench,
87
+ title = "{N}or{B}ench {--} A Benchmark for {N}orwegian Language Models",
88
+ author = "Samuel, David and
89
+ Kutuzov, Andrey and
90
+ Touileb, Samia and
91
+ Velldal, Erik and
92
+ {\O}vrelid, Lilja and
93
+ R{\o}nningstad, Egil and
94
+ Sigdel, Elina and
95
+ Palatkina, Anna",
96
+ editor = {Alum{\"a}e, Tanel and
97
+ Fishel, Mark},
98
+ booktitle = "Proceedings of the 24th Nordic Conference on Computational Linguistics (NoDaLiDa)",
99
+ month = may,
100
+ year = "2023",
101
+ address = "T{\'o}rshavn, Faroe Islands",
102
+ publisher = "University of Tartu Library",
103
+ url = "https://aclanthology.org/2023.nodalida-1.61/",
104
+ pages = "618--633",
105
+ abstract = "We present NorBench: a streamlined suite of NLP tasks and probes for evaluating Norwegian language models (LMs) on standardized data splits and evaluation metrics. We also introduce a range of new Norwegian language models (both encoder and encoder-decoder based). Finally, we compare and analyze their performance, along with other existing LMs, across the different benchmark tests of NorBench."
106
+ }
107
+ ```
108
+
109
+ ```bibtex
110
+ @inproceedings{burchell-etal-2025-expanded,
111
+ title = "An Expanded Massive Multilingual Dataset for High-Performance Language Technologies ({HPLT})",
112
+ author = {Burchell, Laurie and
113
+ de Gibert, Ona and
114
+ Arefyev, Nikolay and
115
+ Aulamo, Mikko and
116
+ Ba{\~n}{\'o}n, Marta and
117
+ Chen, Pinzhen and
118
+ Fedorova, Mariia and
119
+ Guillou, Liane and
120
+ Haddow, Barry and
121
+ Haji{\v{c}}, Jan and
122
+ Helcl, Jind{\v{r}}ich and
123
+ Henriksson, Erik and
124
+ Klimaszewski, Mateusz and
125
+ Komulainen, Ville and
126
+ Kutuzov, Andrey and
127
+ Kyt{\"o}niemi, Joona and
128
+ Laippala, Veronika and
129
+ M{\ae}hlum, Petter and
130
+ Malik, Bhavitvya and
131
+ Mehryary, Farrokh and
132
+ Mikhailov, Vladislav and
133
+ Moghe, Nikita and
134
+ Myntti, Amanda and
135
+ O{'}Brien, Dayy{\'a}n and
136
+ Oepen, Stephan and
137
+ Pal, Proyag and
138
+ Piha, Jousia and
139
+ Pyysalo, Sampo and
140
+ Ram{\'i}rez-S{\'a}nchez, Gema and
141
+ Samuel, David and
142
+ Stepachev, Pavel and
143
+ Tiedemann, J{\"o}rg and
144
+ Vari{\v{s}}, Du{\v{s}}an and
145
+ Vojt{\v{e}}chov{\'a}, Tereza and
146
+ Zaragoza-Bernabeu, Jaume},
147
+ editor = "Che, Wanxiang and
148
+ Nabende, Joyce and
149
+ Shutova, Ekaterina and
150
+ Pilehvar, Mohammad Taher",
151
+ booktitle = "Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
152
+ month = jul,
153
+ year = "2025",
154
+ address = "Vienna, Austria",
155
+ publisher = "Association for Computational Linguistics",
156
+ url = "https://aclanthology.org/2025.acl-long.854/",
157
+ doi = "10.18653/v1/2025.acl-long.854",
158
+ pages = "17452--17485",
159
+ ISBN = "979-8-89176-251-0",
160
+ abstract = "Training state-of-the-art large language models requires vast amounts of clean and diverse textual data. However, building suitable multilingual datasets remains a challenge. In this work, we present HPLT v2, a collection of high-quality multilingual monolingual and parallel corpora, extending prior work of the HPLT project. The monolingual portion of the data contains 8T tokens covering 193 languages, while the parallel data contains 380M sentence pairs covering 51 languages. We document the entire data pipeline and release the code to reproduce it. We provide extensive analysis of the quality and characteristics of our data. Finally, we evaluate the performance of language models and machine translation systems trained on HPLT v2, demonstrating its value."
161
+ }
162
+ ```
__init__.py ADDED
File without changes
config.json ADDED
@@ -0,0 +1,33 @@
1
+ {
2
+ "architectures": [
3
+ "NorT5ForConditionalGeneration"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_nort5.NorT5Config",
7
+ "AutoModel": "modeling_nort5.NorT5Model",
8
+ "AutoModelForSeq2SeqLM": "modeling_nort5.NorT5ForConditionalGeneration",
9
+ "AutoModelForConditionalGeneration": "modeling_nort5.NorT5ForConditionalGeneration"
10
+ },
11
+ "attention_probs_dropout_prob": 0.0,
12
+ "bos_token_id": 5,
13
+ "cls_token_id": 2,
14
+ "eos_token_id": 6,
15
+ "hidden_dropout_prob": 0.0,
16
+ "hidden_size": 768,
17
+ "initializer_range": 0.02,
18
+ "intermediate_size": 2048,
19
+ "layer_norm_eps": 1e-07,
20
+ "max_position_embeddings": 512,
21
+ "num_attention_heads": 12,
22
+ "num_hidden_layers": 12,
23
+ "output_all_encoded_layers": true,
24
+ "pad_token_id": 0,
25
+ "position_bucket_size": 32,
26
+ "sep_token_id": 3,
27
+ "torch_dtype": "float32",
28
+ "vocab_size": 50000,
29
+ "max_length": 512,
30
+ "max_new_tokens": 256,
31
+ "is_encoder_decoder": true
32
+ }
33
+
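The `auto_map` block above is what routes the `Auto*` classes to the custom NorT5 code when `trust_remote_code=True` is passed. A minimal sketch (repository id as in the README) to confirm the routing:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("HPLT/hplt_t5_base_3_0_eng_Latn", trust_remote_code=True)
print(type(config).__name__)                  # NorT5Config, resolved via the auto_map above
print(config.vocab_size, config.hidden_size)  # 50000 768, matching this file
```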
configuration_nort5.py ADDED
@@ -0,0 +1,45 @@
1
+ from transformers.configuration_utils import PretrainedConfig
2
+
3
+
4
+ class NorT5Config(PretrainedConfig):
5
+ """Configuration class to store the configuration of a `NorT5`.
6
+ """
7
+ def __init__(
8
+ self,
9
+ vocab_size=50000,
10
+ attention_probs_dropout_prob=0.0,
11
+ hidden_dropout_prob=0.0,
12
+ hidden_size=768,
13
+ intermediate_size=2048,
14
+ max_position_embeddings=512,
15
+ position_bucket_size=32,
16
+ num_attention_heads=12,
17
+ num_hidden_layers=12,
18
+ layer_norm_eps=1.0e-7,
19
+ output_all_encoded_layers=True,
20
+ pad_token_id=0,
21
+ cls_token_id=2,
22
+ sep_token_id=3,
23
+ bos_token_id=5,
24
+ eos_token_id=6,
25
+ **kwargs,
26
+ ):
27
+ super().__init__(**kwargs)
28
+
29
+ self.vocab_size = vocab_size
30
+ self.hidden_size = hidden_size
31
+ self.num_hidden_layers = num_hidden_layers
32
+ self.num_attention_heads = num_attention_heads
33
+ self.intermediate_size = intermediate_size
34
+ self.hidden_dropout_prob = hidden_dropout_prob
35
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
36
+ self.max_position_embeddings = max_position_embeddings
37
+ self.output_all_encoded_layers = output_all_encoded_layers
38
+ self.position_bucket_size = position_bucket_size
39
+ self.layer_norm_eps = layer_norm_eps
40
+ self.pad_token_id = pad_token_id
41
+ self.cls_token_id = cls_token_id
42
+ self.sep_token_id = sep_token_id
43
+ self.bos_token_id = bos_token_id
44
+ self.eos_token_id = eos_token_id
45
+
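Since `configuration_nort5.py` has no relative imports, the config class can also be used on its own; a minimal sketch, assuming it is run from a local clone of this repository so that the file is importable:

```python
from configuration_nort5 import NorT5Config  # assumes configuration_nort5.py is in the working directory

config = NorT5Config()           # defaults mirror config.json above
print(config.vocab_size)         # 50000
print(config.hidden_size)        # 768
print(config.to_json_string())   # serializes like any transformers PretrainedConfig
```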
generation_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_from_model_config": true,
3
+ "decoder_start_token_id": 5,
4
+ "eos_token_id": 6,
5
+ "pad_token_id": 0
6
+ }
7
+
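With these defaults, a plain `model.generate(...)` call starts decoding from token id 5 (the configured `decoder_start_token_id`) and stops at id 6 (`eos_token_id`); the masked-language-modelling example in the README overrides both with the sentinel ids instead. A short sketch of the default behaviour, assuming `model` and `tokenizer` are loaded as in the README:

```python
# assumes `model` and `tokenizer` from the README usage example
encoding = tokenizer("A short test input.", return_tensors="pt")
output = model.generate(encoding.input_ids, max_new_tokens=20)
print(tokenizer.decode(output.squeeze(), skip_special_tokens=True))
```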
modeling_nort5.py ADDED
@@ -0,0 +1,712 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from transformers.pytorch_utils import softmax_backward_data
8
+ from torch.utils import checkpoint
9
+
10
+ from .configuration_nort5 import NorT5Config
11
+ from transformers.modeling_utils import PreTrainedModel
12
+ from transformers.activations import gelu_new
13
+ from transformers.modeling_outputs import (
14
+ Seq2SeqModelOutput, Seq2SeqLMOutput, BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
15
+ )
16
+
17
+
18
+ class Encoder(nn.Module):
19
+ def __init__(self, config, activation_checkpointing=False):
20
+ super().__init__()
21
+ self.main_input_name = "input_ids"
22
+
23
+ self.relative_embedding = RelativeEmbedding(config)
24
+ self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
25
+
26
+ for i, layer in enumerate(self.layers):
27
+ layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
28
+ layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
29
+
30
+ self.activation_checkpointing = activation_checkpointing
31
+
32
+ def forward(self, hidden_states, attention_mask):
33
+ relative_embedding = self.relative_embedding()
34
+ hidden_states, attention_probs = [hidden_states], []
35
+
36
+ for layer in self.layers:
37
+ if self.activation_checkpointing:
38
+ hidden_state, attention_p = checkpoint.checkpoint(layer, hidden_states[-1], attention_mask, relative_embedding)
39
+ else:
40
+ hidden_state, attention_p = layer(hidden_states[-1], attention_mask, relative_embedding)
41
+
42
+ hidden_states.append(hidden_state)
43
+ attention_probs.append(attention_p)
44
+
45
+ return hidden_states, attention_probs
46
+
47
+
48
+ class Decoder(nn.Module):
49
+ def __init__(self, config, activation_checkpointing=False):
50
+ super().__init__()
51
+ self.self_relative_embedding = RelativeEmbedding(config)
52
+ self.cross_relative_embedding = RelativeEmbedding(config)
53
+ self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
54
+
55
+ for i, layer in enumerate(self.layers):
56
+ layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
57
+ layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
58
+
59
+ self.activation_checkpointing = activation_checkpointing
60
+
61
+ def forward(self, x, encoder_output, encoder_padding_mask, past_key_values=None):
62
+ self_relative_embedding = self.self_relative_embedding()
63
+ cross_relative_embedding = self.cross_relative_embedding()
64
+
65
+ if past_key_values is None:
66
+ autoreg_mask = torch.triu(
67
+ torch.full((x.size(0), x.size(0)), True, device=x.device),
68
+ diagonal=1
69
+ )
70
+ else:
71
+ autoreg_mask = None
72
+
73
+ # initialize past_key_values with `None` if past does not exist
74
+ if past_key_values is None:
75
+ past_key_values = [None] * len(self.layers)
76
+
77
+ hidden_states, self_attention_probs, cross_attention_probs, key_value_states = [x], [], [], []
78
+ for layer, past_key_value in zip(self.layers, past_key_values):
79
+ if self.activation_checkpointing:
80
+ hidden_state, self_attention_p, cross_attention_p, key_value_state = checkpoint.checkpoint(layer, hidden_states[-1], autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=None)
81
+ else:
82
+ hidden_state, self_attention_p, cross_attention_p, key_value_state = layer(hidden_states[-1], autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=past_key_value)
83
+
84
+ hidden_states.append(hidden_state)
85
+ self_attention_probs.append(self_attention_p)
86
+ cross_attention_probs.append(cross_attention_p)
87
+ key_value_states.append(key_value_state)
88
+
89
+ return hidden_states, self_attention_probs, cross_attention_probs, key_value_states
90
+
91
+
92
+ class MaskClassifier(nn.Module):
93
+ def __init__(self, config):
94
+ super().__init__()
95
+ self.nonlinearity = nn.Sequential(
96
+ nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
97
+ nn.Dropout(config.hidden_dropout_prob),
98
+ nn.Linear(config.hidden_size, config.vocab_size)
99
+ )
100
+ self.initialize(config.hidden_size)
101
+
102
+ def initialize(self, hidden_size):
103
+ std = math.sqrt(2.0 / (5.0 * hidden_size))
104
+ nn.init.trunc_normal_(self.nonlinearity[-1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
105
+ self.nonlinearity[-1].bias.data.zero_()
106
+
107
+ def forward(self, x):
108
+ x = self.nonlinearity(x)
109
+ return x
110
+
111
+
112
+ class EncoderLayer(nn.Module):
113
+ def __init__(self, config):
114
+ super().__init__()
115
+ self.attention = Attention(config, is_cross_attention=False)
116
+ self.mlp = FeedForward(config)
117
+
118
+ def forward(self, x, padding_mask, relative_embedding):
119
+ attention_output, attention_probs, _ = self.attention(x, x, padding_mask, relative_embedding)
120
+ x = x + attention_output
121
+ x = x + self.mlp(x)
122
+ return x, attention_probs
123
+
124
+
125
+ class DecoderLayer(nn.Module):
126
+ def __init__(self, config):
127
+ super().__init__()
128
+ self.self_attention = Attention(config, is_cross_attention=False)
129
+ self.cross_attention = Attention(config, is_cross_attention=True)
130
+ self.mlp = FeedForward(config)
131
+
132
+ def forward(self, x, autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=None):
133
+ query_offset = 0
134
+ if past_key_value is not None:
135
+ self_attn_past_key_value = past_key_value[:2]
136
+ cross_attn_past_key_value = past_key_value[2:]
137
+ query_offset = self_attn_past_key_value[0].size(2)
138
+ else:
139
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
140
+
141
+ x_, self_attention_probs, self_key_value_state = self.self_attention(x, x, autoreg_mask, self_relative_embedding, past_key_value=self_attn_past_key_value, query_offset=query_offset)
142
+ x = x + x_
143
+ x_, cross_attention_probs, cross_key_value_state = self.cross_attention(x, encoder_output, encoder_padding_mask, cross_relative_embedding, past_key_value=cross_attn_past_key_value, query_offset=query_offset)
144
+ x = x + x_
145
+ x = x + self.mlp(x)
146
+
147
+ return x, self_attention_probs, cross_attention_probs, self_key_value_state + cross_key_value_state
148
+
149
+
150
+ class GeGLU(nn.Module):
151
+ def forward(self, x):
152
+ x, gate = x.chunk(2, dim=-1)
153
+ x = x * gelu_new(gate)
154
+ return x
155
+
156
+
157
+ class FeedForward(nn.Module):
158
+ def __init__(self, config):
159
+ super().__init__()
160
+ self.mlp = nn.Sequential(
161
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False),
162
+ nn.Linear(config.hidden_size, 2*config.intermediate_size, bias=False),
163
+ GeGLU(),
164
+ nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False),
165
+ nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
166
+ nn.Dropout(config.hidden_dropout_prob)
167
+ )
168
+ self.initialize(config.hidden_size)
169
+
170
+ def initialize(self, hidden_size):
171
+ std = math.sqrt(2.0 / (5.0 * hidden_size))
172
+ nn.init.trunc_normal_(self.mlp[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
173
+ nn.init.trunc_normal_(self.mlp[-2].weight, mean=0.0, std=std, a=-2*std, b=2*std)
174
+
175
+ def forward(self, x):
176
+ return self.mlp(x)
177
+
178
+
179
+ class MaskedSoftmax(torch.autograd.Function):
180
+ @staticmethod
181
+ def forward(self, x, mask, dim):
182
+ self.dim = dim
183
+ if mask is not None:
184
+ x.masked_fill_(mask, float('-inf'))
185
+ x = torch.softmax(x, self.dim)
186
+ if mask is not None:
187
+ x.masked_fill_(mask, 0.0)
188
+ self.save_for_backward(x)
189
+ return x
190
+
191
+ @staticmethod
192
+ def backward(self, grad_output):
193
+ output, = self.saved_tensors
194
+ input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
195
+ return input_grad, None, None
196
+
197
+
198
+ class Attention(nn.Module):
199
+ def __init__(self, config, is_cross_attention=False):
200
+ super().__init__()
201
+
202
+ self.config = config
203
+ self.is_cross_attention = is_cross_attention
204
+
205
+ if config.hidden_size % config.num_attention_heads != 0:
206
+ raise ValueError(f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}")
207
+
208
+ self.hidden_size = config.hidden_size
209
+ self.num_heads = config.num_attention_heads
210
+ self.head_size = config.hidden_size // config.num_attention_heads
211
+
212
+ self.in_proj_q = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
213
+ self.in_proj_k = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
214
+ self.in_proj_v = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
215
+ self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
216
+
217
+ self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
218
+ self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
219
+
220
+ position_indices = torch.arange(512, dtype=torch.long).unsqueeze(1) \
221
+ - torch.arange(512, dtype=torch.long).unsqueeze(0)
222
+ position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
223
+ position_indices = config.position_bucket_size - 1 + position_indices
224
+ self.register_buffer("position_indices", position_indices, persistent=False)
225
+
226
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
227
+ self.scale = 1.0 / math.sqrt(3 * self.head_size)
228
+ self.initialize()
229
+
230
+ def make_log_bucket_position(self, relative_pos, bucket_size, max_position):
231
+ sign = torch.sign(relative_pos)
232
+ mid = bucket_size // 2
233
+ abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos).clamp(max=max_position - 1))
234
+ log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position-1) / mid) * (mid - 1)).int() + mid
235
+ bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
236
+ return bucket_pos
237
+
238
+ def initialize(self):
239
+ std = math.sqrt(2.0 / (5.0 * self.hidden_size))
240
+ nn.init.trunc_normal_(self.in_proj_q.weight, mean=0.0, std=std, a=-2*std, b=2*std)
241
+ nn.init.trunc_normal_(self.in_proj_k.weight, mean=0.0, std=std, a=-2*std, b=2*std)
242
+ nn.init.trunc_normal_(self.in_proj_v.weight, mean=0.0, std=std, a=-2*std, b=2*std)
243
+ nn.init.trunc_normal_(self.out_proj.weight, mean=0.0, std=std, a=-2*std, b=2*std)
244
+ self.in_proj_q.bias.data.zero_()
245
+ self.in_proj_k.bias.data.zero_()
246
+ self.in_proj_v.bias.data.zero_()
247
+ self.out_proj.bias.data.zero_()
248
+
249
+ def forward(self, q, kv, attention_mask, relative_embedding, past_key_value=None, query_offset=0):
250
+ key_len, batch_size, _ = kv.size()
251
+ query_len, _, _ = q.size()
252
+
253
+ if not self.is_cross_attention or past_key_value is None or past_key_value[0].size(1) != kv.size(0):
254
+ kv = self.pre_layer_norm(kv)
255
+ key = self.in_proj_k(kv) # shape: [T, B, D]
256
+ value = self.in_proj_v(kv) # shape: [T, B, D]
257
+ key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1) # shape: [BxH, T, D]
258
+ value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1) # shape: [BxH, T, D]
259
+
260
+ if past_key_value is not None:
261
+ if not self.is_cross_attention:
262
+ key = torch.cat([past_key_value[0].flatten(0, 1), key], dim=1)
263
+ value = torch.cat([past_key_value[1].flatten(0, 1), value], dim=1)
264
+ key_len = key.size(1)
265
+ elif past_key_value[0].size(1) == kv.size(0):
266
+ key = past_key_value[0].flatten(0, 1)
267
+ value = past_key_value[1].flatten(0, 1)
268
+
269
+ if self.position_indices.size(0) < max(query_len, key_len):
270
+ position_indices = torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(1) \
271
+ - torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
272
+ position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
273
+ position_indices = self.config.position_bucket_size - 1 + position_indices
274
+ self.register_buffer("position_indices", position_indices.to(q.device), persistent=False)
275
+
276
+ q = self.pre_layer_norm(q)
277
+ query = self.in_proj_q(q) # shape: [T, B, D]
278
+ query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
279
+
280
+ attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
281
+
282
+ query_pos = self.in_proj_q(self.dropout(relative_embedding)) # shape: [2T-1, D]
283
+ query_pos = query_pos.view(-1, self.num_heads, self.head_size) # shape: [2T-1, H, D]
284
+ key_pos = self.in_proj_k(self.dropout(relative_embedding)) # shape: [2T-1, D]
285
+ key_pos = key_pos.view(-1, self.num_heads, self.head_size) # shape: [2T-1, H, D]
286
+
287
+ query_ = query.view(batch_size, self.num_heads, query_len, self.head_size)
288
+ key_ = key.view(batch_size, self.num_heads, key_len, self.head_size)
289
+
290
+ attention_c_p = torch.einsum("bhqd,khd->bhqk", query_, key_pos.squeeze(1) * self.scale)
291
+ attention_p_c = torch.einsum("bhkd,qhd->bhqk", key_ * self.scale, query_pos.squeeze(1))
292
+ position_indices = self.position_indices[query_offset:query_offset+query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
293
+ attention_c_p = attention_c_p.gather(3, position_indices)
294
+ attention_p_c = attention_p_c.gather(2, position_indices)
295
+
296
+ attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
297
+ attention_scores.add_(attention_c_p)
298
+ attention_scores.add_(attention_p_c)
299
+
300
+ attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
301
+
302
+ attention_probs = self.dropout(attention_probs)
303
+ context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D]
304
+ context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size) # shape: [Q, B, H*D]
305
+ context = self.out_proj(context)
306
+ context = self.post_layer_norm(context)
307
+ context = self.dropout(context)
308
+
309
+ key = key.detach().unflatten(0, (-1, self.num_heads))
310
+ value = value.detach().unflatten(0, (-1, self.num_heads))
311
+
312
+ return context, attention_probs.detach(), (key, value)
313
+
314
+
315
+ class WordEmbedding(nn.Module):
316
+ def __init__(self, config):
317
+ super().__init__()
318
+ self.hidden_size = config.hidden_size
319
+
320
+ self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
321
+ self.word_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
322
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
323
+
324
+ self.initialize()
325
+
326
+ def initialize(self):
327
+ std = math.sqrt(2.0 / (5.0 * self.hidden_size))
328
+ nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
329
+
330
+ def forward(self, input_ids):
331
+ return self.dropout(self.word_layer_norm(self.word_embedding(input_ids)))
332
+
333
+
334
+ class RelativeEmbedding(nn.Module):
335
+ def __init__(self, config):
336
+ super().__init__()
337
+ self.relative_embedding = nn.Parameter(torch.empty(2 * config.position_bucket_size - 1, config.hidden_size))
338
+ self.relative_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
339
+
340
+ self.initialize(config.hidden_size)
341
+
342
+ def initialize(self, hidden_size):
343
+ std = math.sqrt(2.0 / (5.0 * hidden_size))
344
+ nn.init.trunc_normal_(self.relative_embedding, mean=0.0, std=std, a=-2*std, b=2*std)
345
+
346
+ def forward(self):
347
+ return self.relative_layer_norm(self.relative_embedding)
348
+
349
+
350
+ #
351
+ # HuggingFace wrappers
352
+ #
353
+
354
+ class NorT5PreTrainedModel(PreTrainedModel):
355
+ config_class = NorT5Config
356
+ base_model_prefix = "norT5"
357
+ supports_gradient_checkpointing = True
358
+
359
+ def _set_gradient_checkpointing(self, module, value=False):
360
+ if isinstance(module, Encoder):
361
+ module.activation_checkpointing = value
362
+
363
+ def _init_weights(self, module):
364
+ pass # everything is already initialized
365
+
366
+
367
+ class NorT5Model(NorT5PreTrainedModel):
368
+ def __init__(self, config, add_lm_layer=False, add_decoder=True):
369
+ super().__init__(config)
370
+ self.config = config
371
+
372
+ self.cls_token_id = config.cls_token_id
373
+ self.sep_token_id = config.sep_token_id
374
+ self.bos_token_id = config.bos_token_id
375
+ self.eos_token_id = config.eos_token_id
376
+ self.pad_token_id = config.pad_token_id
377
+
378
+ self.embedding = WordEmbedding(config)
379
+ self.encoder = Encoder(config, activation_checkpointing=False)
380
+ self.decoder = Decoder(config, activation_checkpointing=False) if add_decoder else None
381
+ self.classifier = MaskClassifier(config) if add_lm_layer else None
382
+
383
+ def get_input_embeddings(self):
384
+ return self.embedding.word_embedding
385
+
386
+ def set_input_embeddings(self, value):
387
+ self.embedding.word_embedding = value
388
+
389
+ def get_encoder(self):
390
+ class EncoderWrapper:
391
+ def __call__(cls, *args, **kwargs):
392
+ return cls.forward(*args, **kwargs)
393
+
394
+ def forward(
395
+ cls,
396
+ input_ids: Optional[torch.Tensor] = None,
397
+ attention_mask: Optional[torch.Tensor] = None,
398
+ output_hidden_states: Optional[bool] = None,
399
+ output_attentions: Optional[bool] = None,
400
+ return_dict: Optional[bool] = None,
401
+ ):
402
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
403
+
404
+ return self.get_encoder_output(
405
+ input_ids, attention_mask, output_hidden_states, output_attentions, return_dict=return_dict
406
+ )
407
+ return EncoderWrapper()
408
+
409
+ def get_decoder(self):
410
+ return self.get_decoder_output
411
+
412
+ def set_decoder_special_tokens(self, target_id):
413
+ target_id.masked_fill_(target_id == self.cls_token_id, self.bos_token_id)
414
+ target_id.masked_fill_(target_id == self.sep_token_id, self.eos_token_id)
415
+ return target_id
416
+
417
+ def _shift_right(self, input_ids):
418
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
419
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
420
+ shifted_input_ids[..., 0] = self.bos_token_id
421
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
422
+
423
+ return shifted_input_ids
424
+
425
+ def get_encoder_output(
426
+ self,
427
+ input_ids: torch.Tensor = None,
428
+ attention_mask: Optional[torch.Tensor] = None,
429
+ output_hidden_states: Optional[bool] = None,
430
+ output_attentions: Optional[bool] = None,
431
+ return_dict = False
432
+ ):
433
+ if input_ids is not None:
434
+ input_shape = input_ids.size()
435
+ else:
436
+ raise ValueError("You have to specify input_ids")
437
+
438
+ batch_size, seq_length = input_shape
439
+ device = input_ids.device
440
+
441
+ if attention_mask is None:
442
+ attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
443
+ else:
444
+ attention_mask = ~attention_mask.bool()
445
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
446
+
447
+ static_embeddings = self.embedding(input_ids.t())
448
+ contextualized_embeddings, attention_probs = self.encoder(static_embeddings, attention_mask)
449
+ contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
450
+ last_layer = contextualized_embeddings[-1]
451
+ contextualized_embeddings = [contextualized_embeddings[0]] + [
452
+ contextualized_embeddings[i] - contextualized_embeddings[i - 1]
453
+ for i in range(1, len(contextualized_embeddings))
454
+ ]
455
+
456
+ if not return_dict:
457
+ return (
458
+ last_layer,
459
+ *([contextualized_embeddings] if output_hidden_states else []),
460
+ *([attention_probs] if output_attentions else [])
461
+ )
462
+
463
+ return BaseModelOutput(
464
+ last_hidden_state=last_layer,
465
+ hidden_states=contextualized_embeddings if output_hidden_states else None,
466
+ attentions=attention_probs if output_attentions else None
467
+ )
468
+
469
+ def get_decoder_output(
470
+ self,
471
+ target_ids: torch.Tensor = None,
472
+ encoder_output: torch.Tensor = None,
473
+ attention_mask: Optional[torch.Tensor] = None,
474
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
475
+ use_cache: Optional[bool] = None,
476
+ output_hidden_states: Optional[bool] = None,
477
+ output_attentions: Optional[bool] = None,
478
+ return_dict = False
479
+ ):
480
+ batch_size, seq_length, _ = encoder_output.shape
481
+ device = target_ids.device
482
+
483
+ if attention_mask is None:
484
+ attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
485
+ else:
486
+ attention_mask = ~attention_mask.bool()
487
+ attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
488
+
489
+ hidden_states, self_attention_p, cross_attention_p, key_value_states = self.decoder(
490
+ self.embedding(target_ids.t()),
491
+ encoder_output.transpose(0, 1),
492
+ attention_mask,
493
+ past_key_values
494
+ )
495
+
496
+ hidden_states = [e.transpose(0, 1) for e in hidden_states]
497
+ last_layer = hidden_states[-1]
498
+ hidden_states = [hidden_states[0]] + [
499
+ hidden_states[i] - hidden_states[i - 1]
500
+ for i in range(1, len(hidden_states))
501
+ ]
502
+
503
+ if not return_dict:
504
+ return (
505
+ last_layer,
506
+ *([key_value_states] if use_cache else []),
507
+ *([hidden_states] if output_hidden_states else []),
508
+ *([self_attention_p] if output_attentions else []),
509
+ *([cross_attention_p] if output_attentions else []),
510
+ )
511
+
512
+ return BaseModelOutputWithPastAndCrossAttentions(
513
+ last_hidden_state=last_layer,
514
+ past_key_values=key_value_states if use_cache else None,
515
+ hidden_states=hidden_states if output_hidden_states else None,
516
+ attentions=self_attention_p if output_attentions else None,
517
+ cross_attentions=cross_attention_p if output_attentions else None
518
+ )
519
+
520
+
521
+ def forward(
522
+ self,
523
+ input_ids: Optional[torch.LongTensor] = None,
524
+ attention_mask: Optional[torch.FloatTensor] = None,
525
+ decoder_input_ids: Optional[torch.LongTensor] = None,
526
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
527
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
528
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
529
+ use_cache: Optional[bool] = None,
530
+ output_attentions: Optional[bool] = None,
531
+ output_hidden_states: Optional[bool] = None,
532
+ return_dict: Optional[bool] = None
533
+ ):
534
+
535
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
536
+
537
+ decoder_input_ids = self.set_decoder_special_tokens(decoder_input_ids)
538
+
539
+ if encoder_outputs is None:
540
+ encoder_outputs = self.get_encoder_output(
541
+ input_ids, attention_mask, output_hidden_states, output_attentions, return_dict
542
+ )
543
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
544
+ encoder_outputs = BaseModelOutput(
545
+ last_hidden_state=encoder_outputs[0],
546
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
547
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
548
+ )
549
+
550
+ decoder_outputs = self.get_decoder_output(
551
+ decoder_input_ids, encoder_outputs[0], attention_mask, past_key_values, use_cache, output_hidden_states, output_attentions, return_dict
552
+ )
553
+
554
+ if not return_dict:
555
+ return decoder_outputs + encoder_outputs
556
+
557
+ return Seq2SeqModelOutput(
558
+ last_hidden_state=decoder_outputs.last_hidden_state,
559
+ past_key_values=decoder_outputs.past_key_values,
560
+ decoder_hidden_states=decoder_outputs.hidden_states,
561
+ decoder_attentions=decoder_outputs.attentions,
562
+ cross_attentions=decoder_outputs.cross_attentions,
563
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
564
+ encoder_hidden_states=encoder_outputs.hidden_states,
565
+ encoder_attentions=encoder_outputs.attentions,
566
+ )
567
+
568
+
569
+ class NorT5ForConditionalGeneration(NorT5Model):
570
+
571
+ def __init__(self, config):
572
+ super().__init__(config, add_lm_layer=True)
573
+
574
+ def forward(
575
+ self,
576
+ input_ids: Optional[torch.LongTensor] = None,
577
+ attention_mask: Optional[torch.FloatTensor] = None,
578
+ decoder_input_ids: Optional[torch.LongTensor] = None,
579
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
580
+ head_mask: Optional[torch.FloatTensor] = None,
581
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
582
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
583
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
584
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
585
+ inputs_embeds: Optional[torch.FloatTensor] = None,
586
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
587
+ labels: Optional[torch.LongTensor] = None,
588
+ use_cache: Optional[bool] = None,
589
+ output_attentions: Optional[bool] = None,
590
+ output_hidden_states: Optional[bool] = None,
591
+ return_dict: Optional[bool] = None,
592
+ token_type_ids: Optional[torch.LongTensor] = None, # for compatibility
593
+ ):
594
+ use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", False)
595
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
596
+
597
+ if encoder_outputs is None:
598
+ encoder_outputs = self.get_encoder_output(
599
+ input_ids, attention_mask, output_hidden_states, output_attentions, return_dict
600
+ )
601
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
602
+ encoder_outputs = BaseModelOutput(
603
+ last_hidden_state=encoder_outputs[0],
604
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
605
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
606
+ )
607
+
608
+ if labels is not None:
609
+ labels = self.set_decoder_special_tokens(labels)
610
+
611
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
612
+ decoder_input_ids = self._shift_right(labels)
613
+ elif decoder_input_ids is not None:
614
+ decoder_input_ids = self.set_decoder_special_tokens(decoder_input_ids)
615
+
616
+ decoder_outputs = self.get_decoder_output(
617
+ decoder_input_ids, encoder_outputs[0], attention_mask, past_key_values, use_cache, output_hidden_states, output_attentions, return_dict
618
+ )
619
+ lm_logits = self.classifier(decoder_outputs[0])
620
+
621
+ loss = None
622
+ if labels is not None:
623
+ labels.masked_fill_(labels == self.pad_token_id, -100)
624
+ loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
625
+ loss = loss_fct(lm_logits.flatten(0, 1), labels.flatten())
626
+
627
+ if not return_dict:
628
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
629
+ return ((loss,) + output) if loss is not None else output
630
+
631
+ return Seq2SeqLMOutput(
632
+ loss=loss,
633
+ logits=lm_logits,
634
+ past_key_values=decoder_outputs.past_key_values,
635
+ decoder_hidden_states=decoder_outputs.hidden_states,
636
+ decoder_attentions=decoder_outputs.attentions,
637
+ cross_attentions=decoder_outputs.cross_attentions,
638
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
639
+ encoder_hidden_states=encoder_outputs.hidden_states,
640
+ encoder_attentions=encoder_outputs.attentions,
641
+ )
642
+
643
+ def prepare_inputs_for_generation(
644
+ self,
645
+ input_ids,
646
+ past_key_values=None,
647
+ attention_mask=None,
648
+ head_mask=None,
649
+ decoder_head_mask=None,
650
+ cross_attn_head_mask=None,
651
+ use_cache=None,
652
+ encoder_outputs=None,
653
+ **kwargs,
654
+ ):
655
+ if past_key_values is not None:
656
+ input_ids = input_ids[:, -1:]
657
+
658
+ return {
659
+ "decoder_input_ids": input_ids,
660
+ "past_key_values": past_key_values,
661
+ "encoder_outputs": encoder_outputs,
662
+ "attention_mask": attention_mask,
663
+ "head_mask": head_mask,
664
+ "decoder_head_mask": decoder_head_mask,
665
+ "cross_attn_head_mask": cross_attn_head_mask,
666
+ "use_cache": use_cache,
667
+ }
668
+
669
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
670
+ return self._shift_right(labels)
671
+
672
+ def _reorder_cache(self, past_key_values, beam_idx):
673
+ # if decoder past is not included in output
674
+ # speedy decoding is disabled and no need to reorder
675
+ if past_key_values is None:
676
+ print("You might want to consider setting `use_cache=True` to speed up decoding")
677
+ return past_key_values
678
+
679
+ reordered_decoder_past = ()
680
+ for layer_past_states in past_key_values:
681
+ # get the correct batch idx from layer past batch dim
682
+ # batch dim of `past` is at 2nd position
683
+ reordered_layer_past_states = ()
684
+ for layer_past_state in layer_past_states:
685
+ # need to set correct `past` for each of the four key / value states
686
+ layer_past_state = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
687
+ reordered_layer_past_states = reordered_layer_past_states + (layer_past_state,)
688
+
689
+ assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
690
+ assert len(reordered_layer_past_states) == len(layer_past_states)
691
+
692
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
693
+ return reordered_decoder_past
694
+
695
+
696
+ class NorT5Encoder(NorT5Model):
697
+ def __init__(self, config):
698
+ super().__init__(config, add_lm_layer=False, add_decoder=False)  # encoder-only wrapper: no LM head and no decoder needed
699
+
700
+ def forward(
701
+ self,
702
+ input_ids: Optional[torch.Tensor] = None,
703
+ attention_mask: Optional[torch.Tensor] = None,
704
+ output_hidden_states: Optional[bool] = None,
705
+ output_attentions: Optional[bool] = None,
706
+ return_dict: Optional[bool] = None,
707
+ ):
708
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
709
+
710
+ return self.get_encoder_output(
711
+ input_ids, attention_mask, output_hidden_states, output_attentions, return_dict=return_dict
712
+ )
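As a quick sanity check that the wrapper above is what actually gets loaded (a sketch; repository id as in the README):

```python
from transformers import AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained(
    "HPLT/hplt_t5_base_3_0_eng_Latn", trust_remote_code=True, use_safetensors=False,
)
print(type(model).__name__)                                  # NorT5ForConditionalGeneration
print(len(model.encoder.layers), len(model.decoder.layers))  # 12 12
print(sum(p.numel() for p in model.parameters()))            # total parameter count
```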
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8fa497ded7a57915cb00740cb43fd4170031dcddddab7ed148214de850f707c
3
+ size 1177063530
special_tokens_map.json ADDED
@@ -0,0 +1 @@
1
+ {"bos_token": "[BOS]", "eos_token": "[EOS]", "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,4 @@
1
+ {
2
+ "tokenizer_class": "PreTrainedTokenizerFast"
3
+ }
4
+