sayakpaul (HF Staff) committed on
Commit 7d0a465 · verified · 1 Parent(s): 3625a6b

Update fa3.py

Files changed (1)
  1. fa3.py +73 -25

fa3.py CHANGED
@@ -1,14 +1,67 @@
+"""
+Adapted from
+https://github.com/huggingface/flux-fast/blob/156281514e2725782ffab9431d4004840f7e3b4d/utils/pipeline_utils.py#L87
+"""
+
+import torch
+from typing import List, Optional
+import inspect
+
+
 import torch
 from kernels import get_kernel
 
 
 _flash_attn_func = get_kernel("kernels-community/vllm-flash-attn3").flash_attn_func
 
-
 @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
-def flash_attn_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
-    outputs, lse = _flash_attn_func(q, k, v)
-    return outputs
+def flash_attn_func(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    softmax_scale: Optional[float] = None,
+    causal: bool = False,
+    # probably wrong type for these 4
+    qv: Optional[float] = None,
+    q_descale: Optional[float] = None,
+    k_descale: Optional[float] = None,
+    v_descale: Optional[float] = None,
+    window_size: Optional[List[int]] = None,
+    sink_token_length: int = 0,
+    softcap: float = 0.0,
+    num_splits: int = 1,
+    # probably wrong type for this too
+    pack_gqa: Optional[float] = None,
+    deterministic: bool = False,
+    sm_margin: int = 0,
+) -> torch.Tensor:  # Tuple[torch.Tensor, torch.Tensor]:
+    if window_size is None:
+        window_size = (-1, -1)
+    else:
+        window_size = tuple(window_size)
+
+    sig = inspect.signature(_flash_attn_func)
+    accepted = set(sig.parameters)
+    all_kwargs = {
+        "softmax_scale": softmax_scale,
+        "causal": causal,
+        "qv": qv,
+        "q_descale": q_descale,
+        "k_descale": k_descale,
+        "v_descale": v_descale,
+        "window_size": window_size,
+        "sink_token_length": sink_token_length,
+        "softcap": softcap,
+        "num_splits": num_splits,
+        "pack_gqa": pack_gqa,
+        "deterministic": deterministic,
+        "sm_margin": sm_margin,
+    }
+    kwargs = {k: v for k, v in all_kwargs.items() if k in accepted}
+
+    outputs = _flash_attn_func(q, k, v, **kwargs)
+    return outputs[0]
+
 
 @flash_attn_func.register_fake
 def _(q, k, v, **kwargs):
@@ -16,26 +69,26 @@ def _(q, k, v, **kwargs):
     # 1. output: (batch, seq_len, num_heads, head_dim)
     # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
     meta_q = torch.empty_like(q).contiguous()
-    return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
+    return meta_q  # , q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
+
 
-# Copied FusedFluxAttnProcessor2_0 but using flash v3 instead of SDPA
-class FlashFusedFluxAttnProcessor3_0:
+class FlashFluxAttnProcessor3_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
     def __call__(
         self,
        attn,
         hidden_states: torch.FloatTensor,
-        encoder_hidden_states: torch.FloatTensor | None = None,
-        attention_mask: torch.FloatTensor | None = None,
-        image_rotary_emb: torch.Tensor | None = None,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.FloatTensor:
         batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
 
         # `sample` projections.
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
 
         inner_dim = key.shape[-1]
         head_dim = inner_dim // attn.heads
@@ -52,13 +105,9 @@ class FlashFusedFluxAttnProcessor3_0:
         # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
         # `context` projections.
         if encoder_hidden_states is not None:
-            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
-            split_size = encoder_qkv.shape[-1] // 3
-            (
-                encoder_hidden_states_query_proj,
-                encoder_hidden_states_key_proj,
-                encoder_hidden_states_value_proj,
-            ) = torch.split(encoder_qkv, split_size, dim=-1)
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
 
             encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
                 batch_size, -1, attn.heads, head_dim
@@ -87,10 +136,9 @@ class FlashFusedFluxAttnProcessor3_0:
             key = apply_rotary_emb(key, image_rotary_emb)
 
         # NB: transposes are necessary to match expected SDPA input shape
-        hidden_states = flash_attn_func(
-            query.transpose(1, 2),
-            key.transpose(1, 2),
-            value.transpose(1, 2))[0].transpose(1, 2)
+        hidden_states = flash_attn_func(query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2))[
+            0
+        ].transpose(1, 2)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -109,4 +157,4 @@ class FlashFusedFluxAttnProcessor3_0:
 
             return hidden_states, encoder_hidden_states
         else:
-            return hidden_states
+            return hidden_states
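
For context (not part of this commit): the reason the kernel is wrapped in `torch.library.custom_op` with a `register_fake` is so that `torch.compile` can trace `flash_attn_func` against the fake (meta) implementation without executing the CUDA kernel. A minimal sketch of exercising that, assuming the file is importable as a module named `fa3`, an FA3-capable CUDA GPU, and purely illustrative tensor shapes:

```python
# Sketch only: the module name `fa3` and the shapes below are assumptions.
import torch

from fa3 import flash_attn_func

# (batch, seq_len, num_heads, head_dim) in bfloat16 on CUDA
q = torch.randn(1, 4096, 24, 128, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# fullgraph tracing is possible because the custom op has a registered fake
compiled = torch.compile(flash_attn_func, fullgraph=True)
out = compiled(q, k, v)
print(out.shape)  # same layout as q: (1, 4096, 24, 128)
```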
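Likewise, a hedged usage sketch for the new `FlashFluxAttnProcessor3_0` (again not part of the commit): it assumes a diffusers `FluxPipeline` whose transformer exposes `set_attn_processor`, and uses `black-forest-labs/FLUX.1-dev` only as an example checkpoint.

```python
# Sketch only: the pipeline, checkpoint, and `set_attn_processor` call describe an
# assumed diffusers setup around fa3.py, not anything defined in this commit.
import torch
from diffusers import FluxPipeline

from fa3 import FlashFluxAttnProcessor3_0

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route the Flux transformer's attention modules through the FA3 processor.
pipe.transformer.set_attn_processor(FlashFluxAttnProcessor3_0())

image = pipe("a cat holding a sign that says hello world", num_inference_steps=28).images[0]
image.save("flux_fa3.png")
```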