import spaces
from typing import Any
from typing import Callable
from typing import ParamSpec
import torch
from torch.utils._pytree import tree_map

P = ParamSpec("P")

# Flux's image-token sequence length varies with output resolution, so it is
# exported as a dynamic dimension rather than a fixed size.
TRANSFORMER_HIDDEN_DIM = torch.export.Dim("hidden", min=4096, max=8212)

# Specific to Flux. More about this is available in
# https://huggingface.co/blog/zerogpu-aoti
TRANSFORMER_DYNAMIC_SHAPES = {
    "hidden_states": {1: TRANSFORMER_HIDDEN_DIM},
    "img_ids": {0: TRANSFORMER_HIDDEN_DIM},
}
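
# As a worked example of the min bound above: a 1024x1024 output packs into
# (1024 / 16) ** 2 = 4096 image tokens (the VAE downsamples by 8 and the
# transformer patchifies 2x2, for 16x total).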

# Inductor options forwarded to the AoT compile step; tuned for diffusion
# transformer workloads.
INDUCTOR_CONFIGS = {
    "conv_1x1_as_mm": True,
    "epilogue_fusion": False,
    "coordinate_descent_tuning": True,
    "coordinate_descent_check_all_directions": True,
    # "max_autotune": True, # not very helpful.
    "triton.cudagraphs": True,
}


def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
    """Capture the transformer's real inputs from one pipeline call, then AoT-compile it."""

    def f():
        # Intercept the transformer call during a regular pipeline run to
        # record the exact args/kwargs it receives.
        with spaces.aoti_capture(pipeline.transformer) as call:
            pipeline(*args, **kwargs)

        print("Inputs captured.")
        # Treat every captured kwarg as static (None), then override the
        # Flux-specific entries with their dynamic sequence dimension.
        dynamic_shapes = tree_map(lambda v: None, call.kwargs)
        dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES

        exported = torch.export.export(
            mod=pipeline.transformer,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        print("Export done.")
        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

    print(f"{pipeline.transformer.device=}")
    compiled_transformer = f()
    print("Compilation done.")
    return compiled_transformer
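

# A minimal usage sketch, assuming this module runs in a ZeroGPU Space with a
# Flux pipeline: compile the transformer once at startup, then swap the
# compiled artifact in with spaces.aoti_apply. The model id, prompt, and step
# count below are illustrative, not prescribed by this module.
if __name__ == "__main__":
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    compiled = compile_transformer(pipe, "a photo of a forest", num_inference_steps=4)
    # Route the pipeline's transformer calls through the compiled model.
    spaces.aoti_apply(compiled, pipe.transformer)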