sayakpaul HF Staff commited on
Commit
6684d62
·
1 Parent(s): 4b0fe46
Files changed (1) hide show
  1. app.py +7 -1
app.py CHANGED
@@ -5,6 +5,7 @@ from diffusers import DiffusionPipeline
5
  from optimization import compile_transformer
6
  from hub_utils import _push_compiled_graph_to_hub
7
  from huggingface_hub import whoami
 
8
 
9
  # --- Model Loading ---
10
  dtype = torch.bfloat16
@@ -14,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
14
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
15
 
16
 
17
- @spaces.GPU(duration=120)
18
  def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
19
  if not filename.endswith(".pt2"):
20
  raise NotImplementedError("The filename must end with a `.pt2` extension.")
@@ -24,7 +25,12 @@ def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progr
24
  _ = whoami(oauth_token.token)
25
 
26
  # --- Ahead-of-time compilation ---
 
27
  compiled_transformer = compile_transformer(pipe, prompt="prompt")
 
 
 
 
28
 
29
  token = oauth_token.token
30
  out = _push_compiled_graph_to_hub(
 
5
  from optimization import compile_transformer
6
  from hub_utils import _push_compiled_graph_to_hub
7
  from huggingface_hub import whoami
8
+ import time
9
 
10
  # --- Model Loading ---
11
  dtype = torch.bfloat16
 
15
  pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
16
 
17
 
18
+ @spaces.GPU
19
  def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
20
  if not filename.endswith(".pt2"):
21
  raise NotImplementedError("The filename must end with a `.pt2` extension.")
 
25
  _ = whoami(oauth_token.token)
26
 
27
  # --- Ahead-of-time compilation ---
28
+ start = time.perf_counter()
29
  compiled_transformer = compile_transformer(pipe, prompt="prompt")
30
+ if torch.cuda.is_available():
31
+ torch.cuda.synchronize()
32
+ end = time.perf_counter()
33
+ print(f"Compilation took: {end - start} seconds.")
34
 
35
  token = oauth_token.token
36
  out = _push_compiled_graph_to_hub(