small fix
- README.md +1 -1
- modeling_lsg_distilbert.py +20 -7
README.md
CHANGED

```diff
@@ -6,7 +6,7 @@ tags:
 ---
 
 # LSG model
-**Transformers >= 4.
+**Transformers >= 4.36.1**\
 **This model relies on a custom modeling file, you need to add trust_remote_code=True**\
 **See [\#13467](https://github.com/huggingface/transformers/pull/13467)**
 
```
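Since loading this checkpoint requires `trust_remote_code=True`, here is a minimal loading sketch; the repo id below is a placeholder for this checkpoint's actual Hub path:

```python
from transformers import AutoModel, AutoTokenizer

# Placeholder repo id; substitute this checkpoint's actual Hub path.
repo_id = "user/lsg-distilbert-checkpoint"

# trust_remote_code=True is required because the architecture is defined in
# modeling_lsg_distilbert.py inside the repo, not in transformers itself.
tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
```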
modeling_lsg_distilbert.py
CHANGED

```diff
@@ -100,14 +100,22 @@ class LSGEmbeddings(Embeddings):
 
         self.block_size = config.block_size
 
-    def forward(self, input_ids,
+    def forward(self, input_ids: torch.Tensor, input_embeds: Optional[torch.Tensor] = None) -> torch.Tensor:
         """
         Parameters:
-            input_ids
+            input_ids (torch.Tensor):
+                torch.tensor(bs, max_seq_length) The token ids to embed.
+            input_embeds (*optional*, torch.Tensor):
+                The pre-computed word embeddings. Can only be passed if the input ids are `None`.
+
+
         Returns: torch.tensor(bs, max_seq_length, dim) The embedded tokens (plus position embeddings, no token_type
         embeddings)
         """
-
+        if input_ids is not None:
+            word_embeddings = self.word_embeddings(input_ids)  # (bs, max_seq_length, dim)
+
+        seq_length = word_embeddings.size(1)
 
         # Setting the position-ids to the registered buffer in constructor, it helps
         # when tracing the model without passing position-ids, solves
```
```diff
@@ -116,9 +124,8 @@ class LSGEmbeddings(Embeddings):
             position_ids = self.position_ids[:, :seq_length]
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)  # (max_seq_length)
-            position_ids = position_ids.unsqueeze(0).
-
-        word_embeddings = self.word_embeddings(input_ids) if input_ids is not None else inputs_embeds
+            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # (bs, max_seq_length)
+
         position_embeddings = self.position_embeddings(position_ids)  # (bs, max_seq_length, dim)
         word_embeddings = word_embeddings + position_embeddings  # (bs, max_seq_length, dim)
 
```
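Taken together, the two hunks above make `forward` compute word embeddings from `input_ids` when they are given, derive `seq_length` from the result, and expand the position ids to the batch. A standalone sketch of that embedding path; the sizes below are illustrative, not read from the model config:

```python
import torch
import torch.nn as nn

# Illustrative sizes, not taken from the model config.
vocab_size, max_position_embeddings, dim = 30522, 4096, 768
word_embeddings = nn.Embedding(vocab_size, dim)
position_embeddings = nn.Embedding(max_position_embeddings, dim)

input_ids = torch.randint(0, vocab_size, (2, 16))               # (bs, max_seq_length)
embeds = word_embeddings(input_ids)                             # (bs, max_seq_length, dim)
seq_length = embeds.size(1)

position_ids = torch.arange(seq_length, dtype=torch.long)       # (max_seq_length)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)   # (bs, max_seq_length)
embeds = embeds + position_embeddings(position_ids)             # (bs, max_seq_length, dim)
```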
```diff
@@ -853,6 +860,12 @@ class LSGDistilBertModel(LSGDistilBertPreTrainedModel, DistilBertModel):
         self.transformer = LSGTransformer(config)  # Encoder
         self.num_global_tokens = config.num_global_tokens
         # Initialize weights and apply final processing
+
+        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
+        if self._use_flash_attention_2:
+            logger.warning(
+                "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
+            )
         self.post_init()
 
 
```
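The new guard only warns; it does not reroute attention. A hedged sketch of how the flagged config value would end up set (placeholder repo id, and it is an assumption that transformers forwards the request to this custom class without extra support flags):

```python
from transformers import AutoConfig, AutoModel

repo_id = "user/lsg-distilbert-checkpoint"  # placeholder path

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
config._attn_implementation = "flash_attention_2"  # the value the new guard checks

# Loading with this config should log the flash-attention warning from the
# hunk above and keep LSG's own attention instead of failing.
model = AutoModel.from_pretrained(repo_id, config=config, trust_remote_code=True)
```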
```diff
@@ -952,4 +965,4 @@ try:
     str_to_class(value.split(".")[-1]).register_for_auto_class(key)
 except:
     warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
-    warn("Update to transformers >= 4.
+    warn("Update to transformers >= 4.36.1 to fix.")
```
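For context, `register_for_auto_class` is what lets `save_pretrained` carry the custom modeling file along with the checkpoint; when it is unavailable, the manual fallback the warning describes looks roughly like this (paths are hypothetical):

```python
import shutil

# Save the checkpoint as usual; without AutoRegister the custom code is not
# copied automatically.
model.save_pretrained("my-lsg-checkpoint")

# Manual fallback from the warning: ship the modeling file with the weights.
shutil.copy("modeling_lsg_distilbert.py", "my-lsg-checkpoint/")
```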