vaibhavpandeyvpz committed
Commit 3f66ae1 · 1 Parent(s): 5a81bcf

Fix handling of missing flash-attn

Files changed (1)
  1. wan/modules/model.py +3 -3
wan/modules/model.py
@@ -6,7 +6,7 @@ import torch.nn as nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 
-from .attention import flash_attention
+from .attention import attention
 
 __all__ = ["WanModel"]
 
@@ -141,7 +141,7 @@ class WanSelfAttention(nn.Module):
 
         q, k, v = qkv_fn(x)
 
-        x = flash_attention(
+        x = attention(
             q=rope_apply(q, grid_sizes, freqs),
             k=rope_apply(k, grid_sizes, freqs),
             v=v,
@@ -172,7 +172,7 @@ class WanCrossAttention(WanSelfAttention):
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        x = attention(q, k, v, k_lens=context_lens)
 
         # output
         x = x.flatten(2)
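For context: swapping `flash_attention` for `attention` only fixes the missing flash-attn case if the helper in `wan/modules/attention.py` degrades gracefully when the flash-attn package is not installed. That file is not part of this diff, so the following is a minimal, purely illustrative sketch of such a fallback wrapper; the `FLASH_ATTN_AVAILABLE` flag, the simplified handling of `k_lens`, and the SDPA fallback are assumptions for illustration, not the repository's actual code.

# Hypothetical fallback wrapper; the real wan/modules/attention.py is not shown in this diff.
import torch

try:
    from flash_attn import flash_attn_func
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False


def attention(q, k, v, k_lens=None):
    """Use flash-attn when it is installed, otherwise fall back to PyTorch SDPA.

    q, k, v: [batch, seq_len, num_heads, head_dim] tensors.
    k_lens:  optional per-sample key lengths; ignored in this simplified fallback.
    """
    if FLASH_ATTN_AVAILABLE and q.is_cuda:
        # flash_attn_func expects half-precision (batch, seq, heads, dim) tensors.
        return flash_attn_func(q.half(), k.half(), v.half()).type_as(q)

    # torch's scaled_dot_product_attention expects (batch, heads, seq, dim),
    # so transpose around the call and transpose back afterwards.
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
    return out.transpose(1, 2)

With a wrapper along these lines, the import and the two call sites touched by this commit (WanSelfAttention and WanCrossAttention) keep working on machines where flash-attn cannot be installed.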