Upload folder using huggingface_hub
Browse files- README.md +162 -0
- __init__.py +0 -0
- config.json +33 -0
- configuration_nort5.py +45 -0
- generation_config.json +7 -0
- modeling_nort5.py +712 -0
- pytorch_model.bin +3 -0
- special_tokens_map.json +1 -0
- tokenizer.json +0 -0
- tokenizer_config.json +4 -0
README.md
ADDED
|
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language:
|
| 3 |
+
- en
|
| 4 |
+
- eng
|
| 5 |
+
inference: false
|
| 6 |
+
tags:
|
| 7 |
+
- T5
|
| 8 |
+
- t5
|
| 9 |
+
- HPLT
|
| 10 |
+
- encoder-decoder
|
| 11 |
+
- text2text-generation
|
| 12 |
+
license: apache-2.0
|
| 13 |
+
datasets:
|
| 14 |
+
- HPLT/HPLT3.0
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
# HPLT v3.0 T5 for English
|
| 18 |
+
|
| 19 |
+
<img src="https://hplt-project.org/_next/static/media/logo-hplt.d5e16ca5.svg" width=12.5%>
|
| 20 |
+
|
| 21 |
+
This is one of the encoder-decoder monolingual language models trained as a third release by the [HPLT project](https://hplt-project.org/).
|
| 22 |
+
It is a text-to-text transformer trained with a denoising objective. Our
|
| 23 |
+
models follow the setup of [NorT5](https://aclanthology.org/2023.nodalida-1.61/).
|
| 24 |
+
|
| 25 |
+
We present monolingual NorT5 models for 57 languages out of 198 total in the [HPLT v3.0 dataset](https://hplt-project.org/datasets/v3.0).
|
| 26 |
+
|
| 27 |
+
All the HPLT encoder-decoder models use the same hyper-parameters, roughly following the T5-base setup:
|
| 28 |
+
- hidden size: 768
|
| 29 |
+
- attention heads: 12
|
| 30 |
+
- layers: 12
|
| 31 |
+
- vocabulary size: 32768
|
| 32 |
+
|
| 33 |
+
Every model uses its own tokenizer trained on language-specific HPLT data.
|
| 34 |
+
|
| 35 |
+
[The training code](https://github.com/hplt-project/HPLT-WP4).
|
| 36 |
+
|
| 37 |
+
## Example usage
|
| 38 |
+
|
| 39 |
+
This model currently needs a custom wrapper from `modeling_nort5.py`, you should therefore load the model with `trust_remote_code=True`.
|
| 40 |
+
|
| 41 |
+
```
|
| 42 |
+
pip install transformers==4.46.1
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
```python
|
| 46 |
+
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
| 47 |
+
|
| 48 |
+
model_path = 'HPLT/hplt_t5_base_3_0_nob_Latn'
|
| 49 |
+
model = AutoModelForSeq2SeqLM.from_pretrained(
|
| 50 |
+
model_path, trust_remote_code=True, use_safetensors=False,
|
| 51 |
+
)
|
| 52 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
| 53 |
+
# MASKED LANGUAGE MODELING
|
| 54 |
+
sentence = "Ansiktsuttrykket [MASK_1] har utviklet seg til et utbredt kulturelt fenomen."
|
| 55 |
+
encoding = tokenizer(sentence, return_tensors="pt")
|
| 56 |
+
mask_1 = tokenizer.convert_tokens_to_ids("[MASK_1]")
|
| 57 |
+
mask_2 = tokenizer.convert_tokens_to_ids("[MASK_2]")
|
| 58 |
+
output_tensor = model.generate(
|
| 59 |
+
encoding.input_ids,
|
| 60 |
+
decoder_start_token_id=mask_1,
|
| 61 |
+
eos_token_id=mask_2,
|
| 62 |
+
)
|
| 63 |
+
print(tokenizer.decode(output_tensor.squeeze(), skip_special_tokens=False))
|
| 64 |
+
# should output: '[MASK_1]«The Great Gatsby»[MASK_2]'
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## Intermediate checkpoints
|
| 68 |
+
|
| 69 |
+
We are releasing 10 intermediate checkpoints for each model at intervals of every 3125 training steps in separate branches. The naming convention is `stepXXX`: for example, `step18750`.
|
| 70 |
+
|
| 71 |
+
You can load a specific model revision with `transformers` using the argument `revision`:
|
| 72 |
+
```python
|
| 73 |
+
model = AutoModelForSeq2SeqLM.from_pretrained("HPLT/hplt_t5_base_3_0_eng_Latn", revision="step21875", trust_remote_code=True)
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
You can access all the revisions for the models with the following code:
|
| 77 |
+
```python
|
| 78 |
+
from huggingface_hub import list_repo_refs
|
| 79 |
+
out = list_repo_refs("HPLT/hplt_t5_base_3_0_eng_Latn")
|
| 80 |
+
print([b.name for b in out.branches])
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
## Cite us
|
| 84 |
+
|
| 85 |
+
```bibtex
|
| 86 |
+
@inproceedings{samuel-etal-2023-norbench,
|
| 87 |
+
title = "{N}or{B}ench {--} A Benchmark for {N}orwegian Language Models",
|
| 88 |
+
author = "Samuel, David and
|
| 89 |
+
Kutuzov, Andrey and
|
| 90 |
+
Touileb, Samia and
|
| 91 |
+
Velldal, Erik and
|
| 92 |
+
{\O}vrelid, Lilja and
|
| 93 |
+
R{\o}nningstad, Egil and
|
| 94 |
+
Sigdel, Elina and
|
| 95 |
+
Palatkina, Anna",
|
| 96 |
+
editor = {Alum{\"a}e, Tanel and
|
| 97 |
+
Fishel, Mark},
|
| 98 |
+
booktitle = "Proceedings of the 24th Nordic Conference on Computational Linguistics (NoDaLiDa)",
|
| 99 |
+
month = may,
|
| 100 |
+
year = "2023",
|
| 101 |
+
address = "T{\'o}rshavn, Faroe Islands",
|
| 102 |
+
publisher = "University of Tartu Library",
|
| 103 |
+
url = "https://aclanthology.org/2023.nodalida-1.61/",
|
| 104 |
+
pages = "618--633",
|
| 105 |
+
abstract = "We present NorBench: a streamlined suite of NLP tasks and probes for evaluating Norwegian language models (LMs) on standardized data splits and evaluation metrics. We also introduce a range of new Norwegian language models (both encoder and encoder-decoder based). Finally, we compare and analyze their performance, along with other existing LMs, across the different benchmark tests of NorBench."
|
| 106 |
+
}
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
```bibtex
|
| 110 |
+
@inproceedings{burchell-etal-2025-expanded,
|
| 111 |
+
title = "An Expanded Massive Multilingual Dataset for High-Performance Language Technologies ({HPLT})",
|
| 112 |
+
author = {Burchell, Laurie and
|
| 113 |
+
de Gibert, Ona and
|
| 114 |
+
Arefyev, Nikolay and
|
| 115 |
+
Aulamo, Mikko and
|
| 116 |
+
Ba{\~n}{\'o}n, Marta and
|
| 117 |
+
Chen, Pinzhen and
|
| 118 |
+
Fedorova, Mariia and
|
| 119 |
+
Guillou, Liane and
|
| 120 |
+
Haddow, Barry and
|
| 121 |
+
Haji{\v{c}}, Jan and
|
| 122 |
+
Helcl, Jind{\v{r}}ich and
|
| 123 |
+
Henriksson, Erik and
|
| 124 |
+
Klimaszewski, Mateusz and
|
| 125 |
+
Komulainen, Ville and
|
| 126 |
+
Kutuzov, Andrey and
|
| 127 |
+
Kyt{\"o}niemi, Joona and
|
| 128 |
+
Laippala, Veronika and
|
| 129 |
+
M{\ae}hlum, Petter and
|
| 130 |
+
Malik, Bhavitvya and
|
| 131 |
+
Mehryary, Farrokh and
|
| 132 |
+
Mikhailov, Vladislav and
|
| 133 |
+
Moghe, Nikita and
|
| 134 |
+
Myntti, Amanda and
|
| 135 |
+
O{'}Brien, Dayy{\'a}n and
|
| 136 |
+
Oepen, Stephan and
|
| 137 |
+
Pal, Proyag and
|
| 138 |
+
Piha, Jousia and
|
| 139 |
+
Pyysalo, Sampo and
|
| 140 |
+
Ram{\'i}rez-S{\'a}nchez, Gema and
|
| 141 |
+
Samuel, David and
|
| 142 |
+
Stepachev, Pavel and
|
| 143 |
+
Tiedemann, J{\"o}rg and
|
| 144 |
+
Vari{\v{s}}, Du{\v{s}}an and
|
| 145 |
+
Vojt{\v{e}}chov{\'a}, Tereza and
|
| 146 |
+
Zaragoza-Bernabeu, Jaume},
|
| 147 |
+
editor = "Che, Wanxiang and
|
| 148 |
+
Nabende, Joyce and
|
| 149 |
+
Shutova, Ekaterina and
|
| 150 |
+
Pilehvar, Mohammad Taher",
|
| 151 |
+
booktitle = "Proceedings of the 63rd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
|
| 152 |
+
month = jul,
|
| 153 |
+
year = "2025",
|
| 154 |
+
address = "Vienna, Austria",
|
| 155 |
+
publisher = "Association for Computational Linguistics",
|
| 156 |
+
url = "https://aclanthology.org/2025.acl-long.854/",
|
| 157 |
+
doi = "10.18653/v1/2025.acl-long.854",
|
| 158 |
+
pages = "17452--17485",
|
| 159 |
+
ISBN = "979-8-89176-251-0",
|
| 160 |
+
abstract = "Training state-of-the-art large language models requires vast amounts of clean and diverse textual data. However, building suitable multilingual datasets remains a challenge. In this work, we present HPLT v2, a collection of high-quality multilingual monolingual and parallel corpora, extending prior work of the HPLT project. The monolingual portion of the data contains 8T tokens covering 193 languages, while the parallel data contains 380M sentence pairs covering 51 languages. We document the entire data pipeline and release the code to reproduce it. We provide extensive analysis of the quality and characteristics of our data. Finally, we evaluate the performance of language models and machine translation systems trained on HPLT v2, demonstrating its value."
|
| 161 |
+
}
|
| 162 |
+
```
|
__init__.py
ADDED
|
File without changes
|
config.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"NorT5ForConditionalGeneration"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_nort5.NorT5Config",
|
| 7 |
+
"AutoModel": "modeling_nort5.NorT5Model",
|
| 8 |
+
"AutoModelForSeq2SeqLM": "modeling_nort5.NorT5ForConditionalGeneration",
|
| 9 |
+
"AutoModelForConditionalGeneration": "modeling_nort5.NorT5ForConditionalGeneration"
|
| 10 |
+
},
|
| 11 |
+
"attention_probs_dropout_prob": 0.0,
|
| 12 |
+
"bos_token_id": 5,
|
| 13 |
+
"cls_token_id": 2,
|
| 14 |
+
"eos_token_id": 6,
|
| 15 |
+
"hidden_dropout_prob": 0.0,
|
| 16 |
+
"hidden_size": 768,
|
| 17 |
+
"initializer_range": 0.02,
|
| 18 |
+
"intermediate_size": 2048,
|
| 19 |
+
"layer_norm_eps": 1e-07,
|
| 20 |
+
"max_position_embeddings": 512,
|
| 21 |
+
"num_attention_heads": 12,
|
| 22 |
+
"num_hidden_layers": 12,
|
| 23 |
+
"output_all_encoded_layers": true,
|
| 24 |
+
"pad_token_id": 0,
|
| 25 |
+
"position_bucket_size": 32,
|
| 26 |
+
"sep_token_id": 3,
|
| 27 |
+
"torch_dtype": "float32",
|
| 28 |
+
"vocab_size": 50000,
|
| 29 |
+
"max_length": 512,
|
| 30 |
+
"max_new_tokens": 256,
|
| 31 |
+
"is_encoder_decoder": true
|
| 32 |
+
}
|
| 33 |
+
|
configuration_nort5.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers.configuration_utils import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class NorT5Config(PretrainedConfig):
|
| 5 |
+
"""Configuration class to store the configuration of a `NorT5`.
|
| 6 |
+
"""
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
vocab_size=50000,
|
| 10 |
+
attention_probs_dropout_prob=0.0,
|
| 11 |
+
hidden_dropout_prob=0.0,
|
| 12 |
+
hidden_size=768,
|
| 13 |
+
intermediate_size=2048,
|
| 14 |
+
max_position_embeddings=512,
|
| 15 |
+
position_bucket_size=32,
|
| 16 |
+
num_attention_heads=12,
|
| 17 |
+
num_hidden_layers=12,
|
| 18 |
+
layer_norm_eps=1.0e-7,
|
| 19 |
+
output_all_encoded_layers=True,
|
| 20 |
+
pad_token_id=0,
|
| 21 |
+
cls_token_id=2,
|
| 22 |
+
sep_token_id=3,
|
| 23 |
+
bos_token_id=5,
|
| 24 |
+
eos_token_id=6,
|
| 25 |
+
**kwargs,
|
| 26 |
+
):
|
| 27 |
+
super().__init__(**kwargs)
|
| 28 |
+
|
| 29 |
+
self.vocab_size = vocab_size
|
| 30 |
+
self.hidden_size = hidden_size
|
| 31 |
+
self.num_hidden_layers = num_hidden_layers
|
| 32 |
+
self.num_attention_heads = num_attention_heads
|
| 33 |
+
self.intermediate_size = intermediate_size
|
| 34 |
+
self.hidden_dropout_prob = hidden_dropout_prob
|
| 35 |
+
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
| 36 |
+
self.max_position_embeddings = max_position_embeddings
|
| 37 |
+
self.output_all_encoded_layers = output_all_encoded_layers
|
| 38 |
+
self.position_bucket_size = position_bucket_size
|
| 39 |
+
self.layer_norm_eps = layer_norm_eps
|
| 40 |
+
self.pad_token_id = pad_token_id
|
| 41 |
+
self.cls_token_id = cls_token_id
|
| 42 |
+
self.sep_token_id = sep_token_id
|
| 43 |
+
self.bos_token_id = bos_token_id
|
| 44 |
+
self.eos_token_id = eos_token_id
|
| 45 |
+
|
generation_config.json
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_from_model_config": true,
|
| 3 |
+
"decoder_start_token_id": 5,
|
| 4 |
+
"eos_token_id": 6,
|
| 5 |
+
"pad_token_id": 0
|
| 6 |
+
}
|
| 7 |
+
|
modeling_nort5.py
ADDED
|
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from transformers.pytorch_utils import softmax_backward_data
|
| 8 |
+
from torch.utils import checkpoint
|
| 9 |
+
|
| 10 |
+
from .configuration_nort5 import NorT5Config
|
| 11 |
+
from transformers.modeling_utils import PreTrainedModel
|
| 12 |
+
from transformers.activations import gelu_new
|
| 13 |
+
from transformers.modeling_outputs import (
|
| 14 |
+
Seq2SeqModelOutput, Seq2SeqLMOutput, BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Encoder(nn.Module):
|
| 19 |
+
def __init__(self, config, activation_checkpointing=False):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.main_input_name = "input_ids"
|
| 22 |
+
|
| 23 |
+
self.relative_embedding = RelativeEmbedding(config)
|
| 24 |
+
self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 25 |
+
|
| 26 |
+
for i, layer in enumerate(self.layers):
|
| 27 |
+
layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
|
| 28 |
+
layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
|
| 29 |
+
|
| 30 |
+
self.activation_checkpointing = activation_checkpointing
|
| 31 |
+
|
| 32 |
+
def forward(self, hidden_states, attention_mask):
|
| 33 |
+
relative_embedding = self.relative_embedding()
|
| 34 |
+
hidden_states, attention_probs = [hidden_states], []
|
| 35 |
+
|
| 36 |
+
for layer in self.layers:
|
| 37 |
+
if self.activation_checkpointing:
|
| 38 |
+
hidden_state, attention_p = checkpoint.checkpoint(layer, hidden_states[-1], attention_mask, relative_embedding)
|
| 39 |
+
else:
|
| 40 |
+
hidden_state, attention_p = layer(hidden_states[-1], attention_mask, relative_embedding)
|
| 41 |
+
|
| 42 |
+
hidden_states.append(hidden_state)
|
| 43 |
+
attention_probs.append(attention_p)
|
| 44 |
+
|
| 45 |
+
return hidden_states, attention_probs
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Decoder(nn.Module):
|
| 49 |
+
def __init__(self, config, activation_checkpointing=False):
|
| 50 |
+
super().__init__()
|
| 51 |
+
self.self_relative_embedding = RelativeEmbedding(config)
|
| 52 |
+
self.cross_relative_embedding = RelativeEmbedding(config)
|
| 53 |
+
self.layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
|
| 54 |
+
|
| 55 |
+
for i, layer in enumerate(self.layers):
|
| 56 |
+
layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
|
| 57 |
+
layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
|
| 58 |
+
|
| 59 |
+
self.activation_checkpointing = activation_checkpointing
|
| 60 |
+
|
| 61 |
+
def forward(self, x, encoder_output, encoder_padding_mask, past_key_values=None):
|
| 62 |
+
self_relative_embedding = self.self_relative_embedding()
|
| 63 |
+
cross_relative_embedding = self.cross_relative_embedding()
|
| 64 |
+
|
| 65 |
+
if past_key_values is None:
|
| 66 |
+
autoreg_mask = torch.triu(
|
| 67 |
+
torch.full((x.size(0), x.size(0)), True, device=x.device),
|
| 68 |
+
diagonal=1
|
| 69 |
+
)
|
| 70 |
+
else:
|
| 71 |
+
autoreg_mask = None
|
| 72 |
+
|
| 73 |
+
# initialize past_key_values with `None` if past does not exist
|
| 74 |
+
if past_key_values is None:
|
| 75 |
+
past_key_values = [None] * len(self.layers)
|
| 76 |
+
|
| 77 |
+
hidden_states, self_attention_probs, cross_attention_probs, key_value_states = [x], [], [], []
|
| 78 |
+
for layer, past_key_value in zip(self.layers, past_key_values):
|
| 79 |
+
if self.activation_checkpointing:
|
| 80 |
+
hidden_state, self_attention_p, cross_attention_p, key_value_state = checkpoint.checkpoint(layer, hidden_states[-1], autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=None)
|
| 81 |
+
else:
|
| 82 |
+
hidden_state, self_attention_p, cross_attention_p, key_value_state = layer(hidden_states[-1], autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=past_key_value)
|
| 83 |
+
|
| 84 |
+
hidden_states.append(hidden_state)
|
| 85 |
+
self_attention_probs.append(self_attention_p)
|
| 86 |
+
cross_attention_probs.append(cross_attention_p)
|
| 87 |
+
key_value_states.append(key_value_state)
|
| 88 |
+
|
| 89 |
+
return hidden_states, self_attention_probs, cross_attention_probs, key_value_states
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class MaskClassifier(nn.Module):
|
| 93 |
+
def __init__(self, config):
|
| 94 |
+
super().__init__()
|
| 95 |
+
self.nonlinearity = nn.Sequential(
|
| 96 |
+
nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False),
|
| 97 |
+
nn.Dropout(config.hidden_dropout_prob),
|
| 98 |
+
nn.Linear(config.hidden_size, config.vocab_size)
|
| 99 |
+
)
|
| 100 |
+
self.initialize(config.hidden_size)
|
| 101 |
+
|
| 102 |
+
def initialize(self, hidden_size):
|
| 103 |
+
std = math.sqrt(2.0 / (5.0 * hidden_size))
|
| 104 |
+
nn.init.trunc_normal_(self.nonlinearity[-1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 105 |
+
self.nonlinearity[-1].bias.data.zero_()
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
x = self.nonlinearity(x)
|
| 109 |
+
return x
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class EncoderLayer(nn.Module):
|
| 113 |
+
def __init__(self, config):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.attention = Attention(config, is_cross_attention=False)
|
| 116 |
+
self.mlp = FeedForward(config)
|
| 117 |
+
|
| 118 |
+
def forward(self, x, padding_mask, relative_embedding):
|
| 119 |
+
attention_output, attention_probs, _ = self.attention(x, x, padding_mask, relative_embedding)
|
| 120 |
+
x = x + attention_output
|
| 121 |
+
x = x + self.mlp(x)
|
| 122 |
+
return x, attention_probs
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class DecoderLayer(nn.Module):
|
| 126 |
+
def __init__(self, config):
|
| 127 |
+
super().__init__()
|
| 128 |
+
self.self_attention = Attention(config, is_cross_attention=False)
|
| 129 |
+
self.cross_attention = Attention(config, is_cross_attention=True)
|
| 130 |
+
self.mlp = FeedForward(config)
|
| 131 |
+
|
| 132 |
+
def forward(self, x, autoreg_mask, encoder_output, encoder_padding_mask, self_relative_embedding, cross_relative_embedding, past_key_value=None):
|
| 133 |
+
query_offset = 0
|
| 134 |
+
if past_key_value is not None:
|
| 135 |
+
self_attn_past_key_value = past_key_value[:2]
|
| 136 |
+
cross_attn_past_key_value = past_key_value[2:]
|
| 137 |
+
query_offset = self_attn_past_key_value[0].size(2)
|
| 138 |
+
else:
|
| 139 |
+
self_attn_past_key_value, cross_attn_past_key_value = None, None
|
| 140 |
+
|
| 141 |
+
x_, self_attention_probs, self_key_value_state = self.self_attention(x, x, autoreg_mask, self_relative_embedding, past_key_value=self_attn_past_key_value, query_offset=query_offset)
|
| 142 |
+
x = x + x_
|
| 143 |
+
x_, cross_attention_probs, cross_key_value_state = self.cross_attention(x, encoder_output, encoder_padding_mask, cross_relative_embedding, past_key_value=cross_attn_past_key_value, query_offset=query_offset)
|
| 144 |
+
x = x + x_
|
| 145 |
+
x = x + self.mlp(x)
|
| 146 |
+
|
| 147 |
+
return x, self_attention_probs, cross_attention_probs, self_key_value_state + cross_key_value_state
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class GeGLU(nn.Module):
|
| 151 |
+
def forward(self, x):
|
| 152 |
+
x, gate = x.chunk(2, dim=-1)
|
| 153 |
+
x = x * gelu_new(gate)
|
| 154 |
+
return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class FeedForward(nn.Module):
|
| 158 |
+
def __init__(self, config):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.mlp = nn.Sequential(
|
| 161 |
+
nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False),
|
| 162 |
+
nn.Linear(config.hidden_size, 2*config.intermediate_size, bias=False),
|
| 163 |
+
GeGLU(),
|
| 164 |
+
nn.LayerNorm(config.intermediate_size, eps=config.layer_norm_eps, elementwise_affine=False),
|
| 165 |
+
nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
|
| 166 |
+
nn.Dropout(config.hidden_dropout_prob)
|
| 167 |
+
)
|
| 168 |
+
self.initialize(config.hidden_size)
|
| 169 |
+
|
| 170 |
+
def initialize(self, hidden_size):
|
| 171 |
+
std = math.sqrt(2.0 / (5.0 * hidden_size))
|
| 172 |
+
nn.init.trunc_normal_(self.mlp[1].weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 173 |
+
nn.init.trunc_normal_(self.mlp[-2].weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 174 |
+
|
| 175 |
+
def forward(self, x):
|
| 176 |
+
return self.mlp(x)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class MaskedSoftmax(torch.autograd.Function):
|
| 180 |
+
@staticmethod
|
| 181 |
+
def forward(self, x, mask, dim):
|
| 182 |
+
self.dim = dim
|
| 183 |
+
if mask is not None:
|
| 184 |
+
x.masked_fill_(mask, float('-inf'))
|
| 185 |
+
x = torch.softmax(x, self.dim)
|
| 186 |
+
if mask is not None:
|
| 187 |
+
x.masked_fill_(mask, 0.0)
|
| 188 |
+
self.save_for_backward(x)
|
| 189 |
+
return x
|
| 190 |
+
|
| 191 |
+
@staticmethod
|
| 192 |
+
def backward(self, grad_output):
|
| 193 |
+
output, = self.saved_tensors
|
| 194 |
+
input_grad = softmax_backward_data(self, grad_output, output, self.dim, output)
|
| 195 |
+
return input_grad, None, None
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class Attention(nn.Module):
|
| 199 |
+
def __init__(self, config, is_cross_attention=False):
|
| 200 |
+
super().__init__()
|
| 201 |
+
|
| 202 |
+
self.config = config
|
| 203 |
+
self.is_cross_attention = is_cross_attention
|
| 204 |
+
|
| 205 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
| 206 |
+
raise ValueError(f"The hidden size {config.hidden_size} is not a multiple of the number of attention heads {config.num_attention_heads}")
|
| 207 |
+
|
| 208 |
+
self.hidden_size = config.hidden_size
|
| 209 |
+
self.num_heads = config.num_attention_heads
|
| 210 |
+
self.head_size = config.hidden_size // config.num_attention_heads
|
| 211 |
+
|
| 212 |
+
self.in_proj_q = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
|
| 213 |
+
self.in_proj_k = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
|
| 214 |
+
self.in_proj_v = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
|
| 215 |
+
self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True)
|
| 216 |
+
|
| 217 |
+
self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
|
| 218 |
+
self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
|
| 219 |
+
|
| 220 |
+
position_indices = torch.arange(512, dtype=torch.long).unsqueeze(1) \
|
| 221 |
+
- torch.arange(512, dtype=torch.long).unsqueeze(0)
|
| 222 |
+
position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, 512)
|
| 223 |
+
position_indices = config.position_bucket_size - 1 + position_indices
|
| 224 |
+
self.register_buffer("position_indices", position_indices, persistent=False)
|
| 225 |
+
|
| 226 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
| 227 |
+
self.scale = 1.0 / math.sqrt(3 * self.head_size)
|
| 228 |
+
self.initialize()
|
| 229 |
+
|
| 230 |
+
def make_log_bucket_position(self, relative_pos, bucket_size, max_position):
|
| 231 |
+
sign = torch.sign(relative_pos)
|
| 232 |
+
mid = bucket_size // 2
|
| 233 |
+
abs_pos = torch.where((relative_pos < mid) & (relative_pos > -mid), mid - 1, torch.abs(relative_pos).clamp(max=max_position - 1))
|
| 234 |
+
log_pos = torch.ceil(torch.log(abs_pos / mid) / math.log((max_position-1) / mid) * (mid - 1)).int() + mid
|
| 235 |
+
bucket_pos = torch.where(abs_pos <= mid, relative_pos, log_pos * sign).long()
|
| 236 |
+
return bucket_pos
|
| 237 |
+
|
| 238 |
+
def initialize(self):
|
| 239 |
+
std = math.sqrt(2.0 / (5.0 * self.hidden_size))
|
| 240 |
+
nn.init.trunc_normal_(self.in_proj_q.weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 241 |
+
nn.init.trunc_normal_(self.in_proj_k.weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 242 |
+
nn.init.trunc_normal_(self.in_proj_v.weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 243 |
+
nn.init.trunc_normal_(self.out_proj.weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 244 |
+
self.in_proj_q.bias.data.zero_()
|
| 245 |
+
self.in_proj_k.bias.data.zero_()
|
| 246 |
+
self.in_proj_v.bias.data.zero_()
|
| 247 |
+
self.out_proj.bias.data.zero_()
|
| 248 |
+
|
| 249 |
+
def forward(self, q, kv, attention_mask, relative_embedding, past_key_value=None, query_offset=0):
|
| 250 |
+
key_len, batch_size, _ = kv.size()
|
| 251 |
+
query_len, _, _ = q.size()
|
| 252 |
+
|
| 253 |
+
if not self.is_cross_attention or past_key_value is None or past_key_value[0].size(1) != kv.size(0):
|
| 254 |
+
kv = self.pre_layer_norm(kv)
|
| 255 |
+
key = self.in_proj_k(kv) # shape: [T, B, D]
|
| 256 |
+
value = self.in_proj_v(kv) # shape: [T, B, D]
|
| 257 |
+
key = key.reshape(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1) # shape: [BxH, T, D]
|
| 258 |
+
value = value.view(key_len, batch_size * self.num_heads, self.head_size).transpose(0, 1) # shape: [BxH, T, D]
|
| 259 |
+
|
| 260 |
+
if past_key_value is not None:
|
| 261 |
+
if not self.is_cross_attention:
|
| 262 |
+
key = torch.cat([past_key_value[0].flatten(0, 1), key], dim=1)
|
| 263 |
+
value = torch.cat([past_key_value[1].flatten(0, 1), value], dim=1)
|
| 264 |
+
key_len = key.size(1)
|
| 265 |
+
elif past_key_value[0].size(1) == kv.size(0):
|
| 266 |
+
key = past_key_value[0].flatten(0, 1)
|
| 267 |
+
value = past_key_value[1].flatten(0, 1)
|
| 268 |
+
|
| 269 |
+
if self.position_indices.size(0) < max(query_len, key_len):
|
| 270 |
+
position_indices = torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(1) \
|
| 271 |
+
- torch.arange(max(query_len, key_len), dtype=torch.long).unsqueeze(0)
|
| 272 |
+
position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
|
| 273 |
+
position_indices = self.config.position_bucket_size - 1 + position_indices
|
| 274 |
+
self.register_buffer("position_indices", position_indices.to(q.device), persistent=False)
|
| 275 |
+
|
| 276 |
+
q = self.pre_layer_norm(q)
|
| 277 |
+
query = self.in_proj_q(q) # shape: [T, B, D]
|
| 278 |
+
query = query.reshape(query_len, batch_size * self.num_heads, self.head_size).transpose(0, 1)
|
| 279 |
+
|
| 280 |
+
attention_scores = torch.bmm(query, key.transpose(1, 2) * self.scale)
|
| 281 |
+
|
| 282 |
+
query_pos = self.in_proj_q(self.dropout(relative_embedding)) # shape: [2T-1, D]
|
| 283 |
+
query_pos = query_pos.view(-1, self.num_heads, self.head_size) # shape: [2T-1, H, D]
|
| 284 |
+
key_pos = self.in_proj_k(self.dropout(relative_embedding)) # shape: [2T-1, D]
|
| 285 |
+
key_pos = key_pos.view(-1, self.num_heads, self.head_size) # shape: [2T-1, H, D]
|
| 286 |
+
|
| 287 |
+
query_ = query.view(batch_size, self.num_heads, query_len, self.head_size)
|
| 288 |
+
key_ = key.view(batch_size, self.num_heads, key_len, self.head_size)
|
| 289 |
+
|
| 290 |
+
attention_c_p = torch.einsum("bhqd,khd->bhqk", query_, key_pos.squeeze(1) * self.scale)
|
| 291 |
+
attention_p_c = torch.einsum("bhkd,qhd->bhqk", key_ * self.scale, query_pos.squeeze(1))
|
| 292 |
+
position_indices = self.position_indices[query_offset:query_offset+query_len, :key_len].expand(batch_size, self.num_heads, -1, -1)
|
| 293 |
+
attention_c_p = attention_c_p.gather(3, position_indices)
|
| 294 |
+
attention_p_c = attention_p_c.gather(2, position_indices)
|
| 295 |
+
|
| 296 |
+
attention_scores = attention_scores.view(batch_size, self.num_heads, query_len, key_len)
|
| 297 |
+
attention_scores.add_(attention_c_p)
|
| 298 |
+
attention_scores.add_(attention_p_c)
|
| 299 |
+
|
| 300 |
+
attention_probs = MaskedSoftmax.apply(attention_scores, attention_mask, -1)
|
| 301 |
+
|
| 302 |
+
attention_probs = self.dropout(attention_probs)
|
| 303 |
+
context = torch.bmm(attention_probs.flatten(0, 1), value) # shape: [B*H, Q, D]
|
| 304 |
+
context = context.transpose(0, 1).reshape(context.size(1), -1, self.hidden_size) # shape: [Q, B, H*D]
|
| 305 |
+
context = self.out_proj(context)
|
| 306 |
+
context = self.post_layer_norm(context)
|
| 307 |
+
context = self.dropout(context)
|
| 308 |
+
|
| 309 |
+
key = key.detach().unflatten(0, (-1, self.num_heads))
|
| 310 |
+
value = value.detach().unflatten(0, (-1, self.num_heads))
|
| 311 |
+
|
| 312 |
+
return context, attention_probs.detach(), (key, value)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
class WordEmbedding(nn.Module):
|
| 316 |
+
def __init__(self, config):
|
| 317 |
+
super().__init__()
|
| 318 |
+
self.hidden_size = config.hidden_size
|
| 319 |
+
|
| 320 |
+
self.word_embedding = nn.Embedding(config.vocab_size, config.hidden_size)
|
| 321 |
+
self.word_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=False)
|
| 322 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 323 |
+
|
| 324 |
+
self.initialize()
|
| 325 |
+
|
| 326 |
+
def initialize(self):
|
| 327 |
+
std = math.sqrt(2.0 / (5.0 * self.hidden_size))
|
| 328 |
+
nn.init.trunc_normal_(self.word_embedding.weight, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 329 |
+
|
| 330 |
+
def forward(self, input_ids):
|
| 331 |
+
return self.dropout(self.word_layer_norm(self.word_embedding(input_ids)))
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
class RelativeEmbedding(nn.Module):
|
| 335 |
+
def __init__(self, config):
|
| 336 |
+
super().__init__()
|
| 337 |
+
self.relative_embedding = nn.Parameter(torch.empty(2 * config.position_bucket_size - 1, config.hidden_size))
|
| 338 |
+
self.relative_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
| 339 |
+
|
| 340 |
+
self.initialize(config.hidden_size)
|
| 341 |
+
|
| 342 |
+
def initialize(self, hidden_size):
|
| 343 |
+
std = math.sqrt(2.0 / (5.0 * hidden_size))
|
| 344 |
+
nn.init.trunc_normal_(self.relative_embedding, mean=0.0, std=std, a=-2*std, b=2*std)
|
| 345 |
+
|
| 346 |
+
def forward(self):
|
| 347 |
+
return self.relative_layer_norm(self.relative_embedding)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
#
|
| 351 |
+
# HuggingFace wrappers
|
| 352 |
+
#
|
| 353 |
+
|
| 354 |
+
class NorT5PreTrainedModel(PreTrainedModel):
|
| 355 |
+
config_class = NorT5Config
|
| 356 |
+
base_model_prefix = "norT5"
|
| 357 |
+
supports_gradient_checkpointing = True
|
| 358 |
+
|
| 359 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 360 |
+
if isinstance(module, Encoder):
|
| 361 |
+
module.activation_checkpointing = value
|
| 362 |
+
|
| 363 |
+
def _init_weights(self, module):
|
| 364 |
+
pass # everything is already initialized
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
class NorT5Model(NorT5PreTrainedModel):
|
| 368 |
+
def __init__(self, config, add_lm_layer=False, add_decoder=True):
|
| 369 |
+
super().__init__(config)
|
| 370 |
+
self.config = config
|
| 371 |
+
|
| 372 |
+
self.cls_token_id = config.cls_token_id
|
| 373 |
+
self.sep_token_id = config.sep_token_id
|
| 374 |
+
self.bos_token_id = config.bos_token_id
|
| 375 |
+
self.eos_token_id = config.eos_token_id
|
| 376 |
+
self.pad_token_id = config.pad_token_id
|
| 377 |
+
|
| 378 |
+
self.embedding = WordEmbedding(config)
|
| 379 |
+
self.encoder = Encoder(config, activation_checkpointing=False)
|
| 380 |
+
self.decoder = Decoder(config, activation_checkpointing=False) if add_decoder else None
|
| 381 |
+
self.classifier = MaskClassifier(config) if add_lm_layer else None
|
| 382 |
+
|
| 383 |
+
def get_input_embeddings(self):
|
| 384 |
+
return self.embedding.word_embedding
|
| 385 |
+
|
| 386 |
+
def set_input_embeddings(self, value):
|
| 387 |
+
self.embedding.word_embedding = value
|
| 388 |
+
|
| 389 |
+
def get_encoder(self):
|
| 390 |
+
class EncoderWrapper:
|
| 391 |
+
def __call__(cls, *args, **kwargs):
|
| 392 |
+
return cls.forward(*args, **kwargs)
|
| 393 |
+
|
| 394 |
+
def forward(
|
| 395 |
+
cls,
|
| 396 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 397 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 398 |
+
output_hidden_states: Optional[bool] = None,
|
| 399 |
+
output_attentions: Optional[bool] = None,
|
| 400 |
+
return_dict: Optional[bool] = None,
|
| 401 |
+
):
|
| 402 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 403 |
+
|
| 404 |
+
return self.get_encoder_output(
|
| 405 |
+
input_ids, attention_mask, output_hidden_states, output_attentions, return_dict=return_dict
|
| 406 |
+
)
|
| 407 |
+
return EncoderWrapper()
|
| 408 |
+
|
| 409 |
+
def get_decoder(self):
|
| 410 |
+
return self.get_decoder_output
|
| 411 |
+
|
| 412 |
+
def set_decoder_special_tokens(self, target_id):
|
| 413 |
+
target_id.masked_fill_(target_id == self.cls_token_id, self.bos_token_id)
|
| 414 |
+
target_id.masked_fill_(target_id == self.sep_token_id, self.eos_token_id)
|
| 415 |
+
return target_id
|
| 416 |
+
|
| 417 |
+
def _shift_right(self, input_ids):
|
| 418 |
+
shifted_input_ids = input_ids.new_zeros(input_ids.shape)
|
| 419 |
+
shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
|
| 420 |
+
shifted_input_ids[..., 0] = self.bos_token_id
|
| 421 |
+
shifted_input_ids.masked_fill_(shifted_input_ids == -100, self.pad_token_id)
|
| 422 |
+
|
| 423 |
+
return shifted_input_ids
|
| 424 |
+
|
| 425 |
+
def get_encoder_output(
|
| 426 |
+
self,
|
| 427 |
+
input_ids: torch.Tensor = None,
|
| 428 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 429 |
+
output_hidden_states: Optional[bool] = None,
|
| 430 |
+
output_attentions: Optional[bool] = None,
|
| 431 |
+
return_dict = False
|
| 432 |
+
):
|
| 433 |
+
if input_ids is not None:
|
| 434 |
+
input_shape = input_ids.size()
|
| 435 |
+
else:
|
| 436 |
+
raise ValueError("You have to specify input_ids")
|
| 437 |
+
|
| 438 |
+
batch_size, seq_length = input_shape
|
| 439 |
+
device = input_ids.device
|
| 440 |
+
|
| 441 |
+
if attention_mask is None:
|
| 442 |
+
attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
|
| 443 |
+
else:
|
| 444 |
+
attention_mask = ~attention_mask.bool()
|
| 445 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 446 |
+
|
| 447 |
+
static_embeddings = self.embedding(input_ids.t())
|
| 448 |
+
contextualized_embeddings, attention_probs = self.encoder(static_embeddings, attention_mask)
|
| 449 |
+
contextualized_embeddings = [e.transpose(0, 1) for e in contextualized_embeddings]
|
| 450 |
+
last_layer = contextualized_embeddings[-1]
|
| 451 |
+
contextualized_embeddings = [contextualized_embeddings[0]] + [
|
| 452 |
+
contextualized_embeddings[i] - contextualized_embeddings[i - 1]
|
| 453 |
+
for i in range(1, len(contextualized_embeddings))
|
| 454 |
+
]
|
| 455 |
+
|
| 456 |
+
if not return_dict:
|
| 457 |
+
return (
|
| 458 |
+
last_layer,
|
| 459 |
+
*([contextualized_embeddings] if output_hidden_states else []),
|
| 460 |
+
*([attention_probs] if output_attentions else [])
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
return BaseModelOutput(
|
| 464 |
+
last_hidden_state=last_layer,
|
| 465 |
+
hidden_states=contextualized_embeddings if output_hidden_states else None,
|
| 466 |
+
attentions=attention_probs if output_attentions else None
|
| 467 |
+
)
|
| 468 |
+
|
| 469 |
+
def get_decoder_output(
|
| 470 |
+
self,
|
| 471 |
+
target_ids: torch.Tensor = None,
|
| 472 |
+
encoder_output: torch.Tensor = None,
|
| 473 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 474 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 475 |
+
use_cache: Optional[bool] = None,
|
| 476 |
+
output_hidden_states: Optional[bool] = None,
|
| 477 |
+
output_attentions: Optional[bool] = None,
|
| 478 |
+
return_dict = False
|
| 479 |
+
):
|
| 480 |
+
batch_size, seq_length, _ = encoder_output.shape
|
| 481 |
+
device = target_ids.device
|
| 482 |
+
|
| 483 |
+
if attention_mask is None:
|
| 484 |
+
attention_mask = torch.zeros(batch_size, seq_length, dtype=torch.bool, device=device)
|
| 485 |
+
else:
|
| 486 |
+
attention_mask = ~attention_mask.bool()
|
| 487 |
+
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
| 488 |
+
|
| 489 |
+
hidden_states, self_attention_p, cross_attention_p, key_value_states = self.decoder(
|
| 490 |
+
self.embedding(target_ids.t()),
|
| 491 |
+
encoder_output.transpose(0, 1),
|
| 492 |
+
attention_mask,
|
| 493 |
+
past_key_values
|
| 494 |
+
)
|
| 495 |
+
|
| 496 |
+
hidden_states = [e.transpose(0, 1) for e in hidden_states]
|
| 497 |
+
last_layer = hidden_states[-1]
|
| 498 |
+
hidden_states = [hidden_states[0]] + [
|
| 499 |
+
hidden_states[i] - hidden_states[i - 1]
|
| 500 |
+
for i in range(1, len(hidden_states))
|
| 501 |
+
]
|
| 502 |
+
|
| 503 |
+
if not return_dict:
|
| 504 |
+
return (
|
| 505 |
+
last_layer,
|
| 506 |
+
*([key_value_states] if use_cache else []),
|
| 507 |
+
*([hidden_states] if output_hidden_states else []),
|
| 508 |
+
*([self_attention_p] if output_attentions else []),
|
| 509 |
+
*([cross_attention_p] if output_attentions else []),
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 513 |
+
last_hidden_state=last_layer,
|
| 514 |
+
past_key_values=key_value_states if use_cache else None,
|
| 515 |
+
hidden_states=hidden_states if output_hidden_states else None,
|
| 516 |
+
attentions=self_attention_p if output_attentions else None,
|
| 517 |
+
cross_attentions=cross_attention_p if output_attentions else None
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
def forward(
|
| 522 |
+
self,
|
| 523 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 524 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 525 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 526 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
| 527 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 528 |
+
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 529 |
+
use_cache: Optional[bool] = None,
|
| 530 |
+
output_attentions: Optional[bool] = None,
|
| 531 |
+
output_hidden_states: Optional[bool] = None,
|
| 532 |
+
return_dict: Optional[bool] = None
|
| 533 |
+
):
|
| 534 |
+
|
| 535 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 536 |
+
|
| 537 |
+
decoder_input_ids = self.set_decoder_special_tokens(decoder_input_ids)
|
| 538 |
+
|
| 539 |
+
if encoder_outputs is None:
|
| 540 |
+
encoder_outputs = self.get_encoder_output(
|
| 541 |
+
input_ids, attention_mask, output_hidden_states, output_attentions, return_dict
|
| 542 |
+
)
|
| 543 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
| 544 |
+
encoder_outputs = BaseModelOutput(
|
| 545 |
+
last_hidden_state=encoder_outputs[0],
|
| 546 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
| 547 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
| 548 |
+
)
|
| 549 |
+
|
| 550 |
+
decoder_outputs = self.get_decoder_output(
|
| 551 |
+
decoder_input_ids, encoder_outputs[0], attention_mask, past_key_values, use_cache, output_hidden_states, output_attentions, return_dict
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
if not return_dict:
|
| 555 |
+
return decoder_outputs + encoder_outputs
|
| 556 |
+
|
| 557 |
+
return Seq2SeqModelOutput(
|
| 558 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
| 559 |
+
past_key_values=decoder_outputs.past_key_values,
|
| 560 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 561 |
+
decoder_attentions=decoder_outputs.attentions,
|
| 562 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
| 563 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
| 564 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 565 |
+
encoder_attentions=encoder_outputs.attentions,
|
| 566 |
+
)
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
class NorT5ForConditionalGeneration(NorT5Model):
|
| 570 |
+
|
| 571 |
+
def __init__(self, config):
|
| 572 |
+
super().__init__(config, add_lm_layer=True)
|
| 573 |
+
|
| 574 |
+
def forward(
|
| 575 |
+
self,
|
| 576 |
+
input_ids: Optional[torch.LongTensor] = None,
|
| 577 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 578 |
+
decoder_input_ids: Optional[torch.LongTensor] = None,
|
| 579 |
+
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
| 580 |
+
head_mask: Optional[torch.FloatTensor] = None,
|
| 581 |
+
decoder_head_mask: Optional[torch.FloatTensor] = None,
|
| 582 |
+
cross_attn_head_mask: Optional[torch.Tensor] = None,
|
| 583 |
+
encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 584 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
| 585 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 586 |
+
decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
|
| 587 |
+
labels: Optional[torch.LongTensor] = None,
|
| 588 |
+
use_cache: Optional[bool] = None,
|
| 589 |
+
output_attentions: Optional[bool] = None,
|
| 590 |
+
output_hidden_states: Optional[bool] = None,
|
| 591 |
+
return_dict: Optional[bool] = None,
|
| 592 |
+
token_type_ids: Optional[torch.LongTensor] = None, # for compatibility
|
| 593 |
+
):
|
| 594 |
+
use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", False)
|
| 595 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 596 |
+
|
| 597 |
+
if encoder_outputs is None:
|
| 598 |
+
encoder_outputs = self.get_encoder_output(
|
| 599 |
+
input_ids, attention_mask, output_hidden_states, output_attentions, return_dict
|
| 600 |
+
)
|
| 601 |
+
elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
|
| 602 |
+
encoder_outputs = BaseModelOutput(
|
| 603 |
+
last_hidden_state=encoder_outputs[0],
|
| 604 |
+
hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
|
| 605 |
+
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
if labels is not None:
|
| 609 |
+
labels = self.set_decoder_special_tokens(labels)
|
| 610 |
+
|
| 611 |
+
if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
|
| 612 |
+
decoder_input_ids = self._shift_right(labels)
|
| 613 |
+
elif decoder_input_ids is not None:
|
| 614 |
+
decoder_input_ids = self.set_decoder_special_tokens(decoder_input_ids)
|
| 615 |
+
|
| 616 |
+
decoder_outputs = self.get_decoder_output(
|
| 617 |
+
decoder_input_ids, encoder_outputs[0], attention_mask, past_key_values, use_cache, output_hidden_states, output_attentions, return_dict
|
| 618 |
+
)
|
| 619 |
+
lm_logits = self.classifier(decoder_outputs[0])
|
| 620 |
+
|
| 621 |
+
loss = None
|
| 622 |
+
if labels is not None:
|
| 623 |
+
labels.masked_fill_(labels == self.pad_token_id, -100)
|
| 624 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
|
| 625 |
+
loss = loss_fct(lm_logits.flatten(0, 1), labels.flatten())
|
| 626 |
+
|
| 627 |
+
if not return_dict:
|
| 628 |
+
output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
|
| 629 |
+
return ((loss,) + output) if loss is not None else output
|
| 630 |
+
|
| 631 |
+
return Seq2SeqLMOutput(
|
| 632 |
+
loss=loss,
|
| 633 |
+
logits=lm_logits,
|
| 634 |
+
past_key_values=decoder_outputs.past_key_values,
|
| 635 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
| 636 |
+
decoder_attentions=decoder_outputs.attentions,
|
| 637 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
| 638 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
| 639 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
| 640 |
+
encoder_attentions=encoder_outputs.attentions,
|
| 641 |
+
)
|
| 642 |
+
|
| 643 |
+
def prepare_inputs_for_generation(
|
| 644 |
+
self,
|
| 645 |
+
input_ids,
|
| 646 |
+
past_key_values=None,
|
| 647 |
+
attention_mask=None,
|
| 648 |
+
head_mask=None,
|
| 649 |
+
decoder_head_mask=None,
|
| 650 |
+
cross_attn_head_mask=None,
|
| 651 |
+
use_cache=None,
|
| 652 |
+
encoder_outputs=None,
|
| 653 |
+
**kwargs,
|
| 654 |
+
):
|
| 655 |
+
if past_key_values is not None:
|
| 656 |
+
input_ids = input_ids[:, -1:]
|
| 657 |
+
|
| 658 |
+
return {
|
| 659 |
+
"decoder_input_ids": input_ids,
|
| 660 |
+
"past_key_values": past_key_values,
|
| 661 |
+
"encoder_outputs": encoder_outputs,
|
| 662 |
+
"attention_mask": attention_mask,
|
| 663 |
+
"head_mask": head_mask,
|
| 664 |
+
"decoder_head_mask": decoder_head_mask,
|
| 665 |
+
"cross_attn_head_mask": cross_attn_head_mask,
|
| 666 |
+
"use_cache": use_cache,
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
|
| 670 |
+
return self._shift_right(labels)
|
| 671 |
+
|
| 672 |
+
def _reorder_cache(self, past_key_values, beam_idx):
|
| 673 |
+
# if decoder past is not included in output
|
| 674 |
+
# speedy decoding is disabled and no need to reorder
|
| 675 |
+
if past_key_values is None:
|
| 676 |
+
print("You might want to consider setting `use_cache=True` to speed up decoding")
|
| 677 |
+
return past_key_values
|
| 678 |
+
|
| 679 |
+
reordered_decoder_past = ()
|
| 680 |
+
for layer_past_states in past_key_values:
|
| 681 |
+
# get the correct batch idx from layer past batch dim
|
| 682 |
+
# batch dim of `past` is at 2nd position
|
| 683 |
+
reordered_layer_past_states = ()
|
| 684 |
+
for layer_past_state in layer_past_states:
|
| 685 |
+
# need to set correct `past` for each of the four key / value states
|
| 686 |
+
layer_past_state = layer_past_state.index_select(0, beam_idx.to(layer_past_state.device))
|
| 687 |
+
reordered_layer_past_states = reordered_layer_past_states + (layer_past_state,)
|
| 688 |
+
|
| 689 |
+
assert reordered_layer_past_states[0].shape == layer_past_states[0].shape
|
| 690 |
+
assert len(reordered_layer_past_states) == len(layer_past_states)
|
| 691 |
+
|
| 692 |
+
reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
|
| 693 |
+
return reordered_decoder_past
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
class NorT5Encoder(NorT5Model):
|
| 697 |
+
def __init__(self, config):
|
| 698 |
+
super().__init__(config, add_lm_layer=False, add_decoder=True)
|
| 699 |
+
|
| 700 |
+
def forward(
|
| 701 |
+
self,
|
| 702 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 703 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 704 |
+
output_hidden_states: Optional[bool] = None,
|
| 705 |
+
output_attentions: Optional[bool] = None,
|
| 706 |
+
return_dict: Optional[bool] = None,
|
| 707 |
+
):
|
| 708 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 709 |
+
|
| 710 |
+
return self.get_encoder_output(
|
| 711 |
+
input_ids, attention_mask, output_hidden_states, output_attentions, return_dict=return_dict
|
| 712 |
+
)
|
pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f8fa497ded7a57915cb00740cb43fd4170031dcddddab7ed148214de850f707c
|
| 3 |
+
size 1177063530
|
special_tokens_map.json
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
{"bos_token": "[BOS]", "eos_token": "[EOS]", "unk_token": "[UNK]", "sep_token": "[SEP]", "pad_token": "[PAD]", "cls_token": "[CLS]", "mask_token": "[MASK]"}
|
tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
tokenizer_config.json
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"tokenizer_class": "PreTrainedTokenizerFast"
|
| 3 |
+
}
|
| 4 |
+
|