Commit 3f66ae1 · Parent: 5a81bcf
Fix handling of missing flash-attn

Files changed: wan/modules/model.py (+3 -3)
wan/modules/model.py CHANGED

@@ -6,7 +6,7 @@ import torch.nn as nn
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.models.modeling_utils import ModelMixin
 
-from .attention import flash_attention
+from .attention import attention
 
 __all__ = ["WanModel"]
 
@@ -141,7 +141,7 @@ class WanSelfAttention(nn.Module):
 
         q, k, v = qkv_fn(x)
 
-        x = flash_attention(
+        x = attention(
             q=rope_apply(q, grid_sizes, freqs),
             k=rope_apply(k, grid_sizes, freqs),
             v=v,
@@ -172,7 +172,7 @@ class WanCrossAttention(WanSelfAttention):
 
         v = self.v(context).view(b, -1, n, d)
 
         # compute attention
-        x = flash_attention(q, k, v, k_lens=context_lens)
+        x = attention(q, k, v, k_lens=context_lens)
 
         # output
         x = x.flatten(2)
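The change swaps the direct flash_attention calls for the module-level attention dispatcher, so model.py no longer fails outright when the optional flash-attn package is not installed. As a minimal sketch of what such a dispatcher can look like (not the repository's actual wan/modules/attention.py, which may differ; the [batch, seq, heads, dim] layout, the k_lens masking, and the availability flag are all assumptions), it tries flash-attn first and falls back to PyTorch's scaled_dot_product_attention:

import torch

# Optional dependency: detect flash-attn once at import time.
try:
    from flash_attn import flash_attn_func
    FLASH_ATTN_AVAILABLE = True
except ImportError:
    FLASH_ATTN_AVAILABLE = False


def attention(q, k, v, k_lens=None):
    # Sketch only: the real wan/modules/attention.py packs variable-length
    # sequences for flash-attn; here k_lens is honored only on the fallback
    # path, and the [batch, seq, heads, dim] tensor layout is an assumption.
    if FLASH_ATTN_AVAILABLE and q.is_cuda and k_lens is None:
        # flash_attn_func expects fp16/bf16 [B, L, H, D] tensors on CUDA.
        return flash_attn_func(q.half(), k.half(), v.half()).to(q.dtype)

    # Fallback: PyTorch's fused SDPA, which expects [B, H, L, D].
    mask = None
    if k_lens is not None:
        # True where a key position is within the sample's valid length.
        pos = torch.arange(k.size(1), device=k.device)
        mask = (pos[None, :] < k_lens[:, None])[:, None, None, :]
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2),
        attn_mask=mask,
    )
    return out.transpose(1, 2)

With a wrapper along these lines, the three call sites in model.py stay identical whether or not flash-attn is importable, which is the failure mode the commit title describes.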