"""Ahead-of-time (AOTInductor) compilation of a Flux transformer on ZeroGPU,
following https://huggingface.co/blog/zerogpu-aoti."""

from typing import Any, Callable, ParamSpec

import spaces
import torch
from torch.utils._pytree import tree_map

P = ParamSpec("P")

TRANSFORMER_HIDDEN_DIM = torch.export.Dim("hidden", min=4096, max=8212)

# Specific to Flux: dim 1 of `hidden_states` and dim 0 of `img_ids` are the
# image token sequence length, which varies with the requested output
# resolution. More about this is available in
# https://huggingface.co/blog/zerogpu-aoti
TRANSFORMER_DYNAMIC_SHAPES = {
    "hidden_states": {1: TRANSFORMER_HIDDEN_DIM},
    "img_ids": {0: TRANSFORMER_HIDDEN_DIM},
}

INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    # "max_autotune": True,  # not very helpful.
    "triton.cudagraphs": True,
}


def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    def f():
        # Run the pipeline once so that aoti_capture records the exact
        # args/kwargs the transformer is called with.
        with spaces.aoti_capture(pipeline.transformer) as call:
            pipeline(*args, **kwargs)
        print("Inputs captured.")

        # Mark every captured kwarg as static (None), then override the
        # Flux-specific dynamic dimensions declared above.
        dynamic_shapes = tree_map(lambda v: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        exported = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        print("Export done.")
        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

    print(f"{pipeline.transformer.device=}")
    compiled_transformer = f()
    print("Compilation done.")
    return compiled_transformer
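

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). The checkpoint
# id, prompt, and GPU duration below are assumptions. `spaces.aoti_apply`,
# which patches the compiled artifact into the live module, is the helper
# described in https://huggingface.co/blog/zerogpu-aoti. On ZeroGPU the
# capture/export call needs a GPU attached, hence the `spaces.GPU` wrapper.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from diffusers import DiffusionPipeline

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Compile once inside a GPU-attached context (duration is a guess).
    compiled = spaces.GPU(duration=1500)(compile_transformer)(
        pipeline, "an example prompt"
    )
    spaces.aoti_apply(compiled, pipeline.transformer)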