Global CLS attention
#13
by Markus28 · opened

Files changed:
- configuration_bert.py  +2 -0
- modeling_bert.py  +8 -3
configuration_bert.py  CHANGED

@@ -129,6 +129,7 @@ class JinaBertConfig(PretrainedConfig):
         feed_forward_type="original",
         emb_pooler=None,
         attn_implementation='torch',
+        cls_bias=None,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -151,6 +152,7 @@ class JinaBertConfig(PretrainedConfig):
         self.feed_forward_type = feed_forward_type
         self.emb_pooler = emb_pooler
         self.attn_implementation = attn_implementation
+        self.cls_bias = cls_bias

 class JinaBertOnnxConfig(OnnxConfig):
     @property
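For reference, a minimal sketch of how the new option would be set. The direct import of `JinaBertConfig` from `configuration_bert` is assumed here purely for illustration (in practice the config is normally loaded through `transformers` with `trust_remote_code=True`); `cls_bias=0.0` gives the CLS token a constant attention bias, while the default `cls_bias=None` keeps the previous behaviour:

```python
# Illustrative only: exercising the new `cls_bias` option added in this PR.
from configuration_bert import JinaBertConfig  # assumed local import path

config = JinaBertConfig(
    attn_implementation='torch',
    cls_bias=0.0,  # value written into the CLS row/column of the ALiBi tensor
)
print(config.cls_bias)  # 0.0 -> CLS token gets a constant (global) attention bias
```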
modeling_bert.py  CHANGED

@@ -701,12 +701,12 @@ class JinaBertEncoder(nn.Module):
         self.num_attention_heads = config.num_attention_heads
         self.register_buffer(
             "alibi",
-            self.rebuild_alibi_tensor(size=config.max_position_embeddings),
+            self.rebuild_alibi_tensor(size=config.max_position_embeddings, cls_bias=config.cls_bias),
             persistent=False,
         )

     def rebuild_alibi_tensor(
-        self, size: int, device: Optional[Union[torch.device, str]] = None
+        self, size: int, device: Optional[Union[torch.device, str]] = None, cls_bias=None
     ):
         # Alibi
         # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)
@@ -747,6 +747,10 @@ class JinaBertEncoder(nn.Module):
         alibi = alibi.unsqueeze(0)
         assert alibi.shape == torch.Size([1, n_heads, size, size])

+        if cls_bias is not None:
+            alibi[:, :, 0, :] = cls_bias
+            alibi[:, :, :, 0] = cls_bias
+
         self._current_alibi_size = size
         return alibi

@@ -778,7 +782,8 @@ class JinaBertEncoder(nn.Module):
             )
             self.register_buffer(
                 "alibi",
-                self.rebuild_alibi_tensor(size=seqlen, device=hidden_states.device).to(
+                self.rebuild_alibi_tensor(size=seqlen, cls_bias=self.config.cls_bias,
+                                          device=hidden_states.device).to(
                     hidden_states.dtype
                 ),
                 persistent=False,
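To make the effect of the modeling change concrete, here is a simplified, self-contained sketch of an ALiBi-style bias tensor with the CLS override applied. The helper name `build_alibi`, the slope formula, and the toy sizes are assumptions for illustration, not code from this PR; only the final `cls_bias` override of row 0 and column 0 mirrors the patched `rebuild_alibi_tensor`:

```python
# Simplified sketch of an ALiBi-style bias with the CLS override (illustrative only;
# the real implementation lives in rebuild_alibi_tensor in modeling_bert.py).
import torch


def build_alibi(size: int, n_heads: int, cls_bias=None) -> torch.Tensor:
    # Per-head slopes, roughly following the ALiBi geometric sequence.
    slopes = torch.tensor([2 ** (-8.0 * (i + 1) / n_heads) for i in range(n_heads)])
    # Symmetric relative-distance matrix |i - j| of shape (size, size).
    pos = torch.arange(size)
    distance = (pos[None, :] - pos[:, None]).abs()
    # (1, n_heads, size, size), matching the buffer shape asserted in the PR.
    alibi = -slopes[None, :, None, None] * distance[None, None, :, :].float()
    if cls_bias is not None:
        # Replace the distance-dependent penalty for position 0 (the CLS token)
        # with a constant bias in both directions.
        alibi[:, :, 0, :] = cls_bias
        alibi[:, :, :, 0] = cls_bias
    return alibi


bias = build_alibi(size=6, n_heads=4, cls_bias=0.0)
print(bias.shape)     # torch.Size([1, 4, 6, 6])
print(bias[0, 0, 0])  # all zeros: no distance penalty to or from the CLS position
print(bias[0, 0, 1])  # distance-based penalties elsewhere, except column 0
```

With `cls_bias=0.0`, the first row and column of the bias tensor become constant, so attention to and from the CLS token is no longer penalized by distance, which is what gives the CLS token the global attention the PR title refers to.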