# Source: Hugging Face Space file (author: cbensimon, HF Staff)
# Commit: c72054d — "fp8e4m3 (disable aoti)"
"""
"""
from typing import Any
from typing import Callable
from typing import ParamSpec
import spaces
import torch
from fa3 import FlashFusedFluxAttnProcessor3_0
P = ParamSpec('P')
INDUCTOR_CONFIGS = {
'conv_1x1_as_mm': True,
'epilogue_fusion': False,
'coordinate_descent_tuning': True,
'coordinate_descent_check_all_directions': True,
'max_autotune': True,
'triton.cudagraphs': True,
}
def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
@spaces.GPU(duration=1500)
def compile_transformer():
with spaces.aoti_capture(pipeline.transformer) as call:
pipeline(*args, **kwargs)
exported = torch.export.export(
mod=pipeline.transformer,
args=call.args,
kwargs=call.kwargs,
)
return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
pipeline.transformer.fuse_qkv_projections()
pipeline.transformer.set_attn_processor(FlashFusedFluxAttnProcessor3_0())
# spaces.aoti_apply(compile_transformer(), pipeline.transformer)