Spaces:
Sleeping
Sleeping
up
Browse files- app.py +25 -20
- hub_utils.py +6 -7
- optimization.py +14 -16
app.py
CHANGED
|
@@ -11,15 +11,14 @@ dtype = torch.bfloat16
|
|
| 11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
|
| 13 |
# Load the model pipeline
|
| 14 |
-
pipe = DiffusionPipeline.from_pretrained(
|
| 15 |
-
|
| 16 |
-
).to(device)
|
| 17 |
|
| 18 |
@spaces.GPU(duration=120)
|
| 19 |
-
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
|
| 20 |
if not filename.endswith(".pt2"):
|
| 21 |
raise NotImplementedError("The filename must end with a `.pt2` extension.")
|
| 22 |
-
|
| 23 |
# this will throw if token is invalid
|
| 24 |
try:
|
| 25 |
_ = whoami(oauth_token.token)
|
|
@@ -27,12 +26,9 @@ def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
|
|
| 27 |
# --- Ahead-of-time compilation ---
|
| 28 |
compiled_transformer = compile_transformer(pipe, prompt="prompt")
|
| 29 |
|
| 30 |
-
token = oauth_token.token
|
| 31 |
out = _push_compiled_graph_to_hub(
|
| 32 |
-
compiled_transformer.archive_file,
|
| 33 |
-
repo_id=repo_id,
|
| 34 |
-
token=token,
|
| 35 |
-
path_in_repo=filename
|
| 36 |
)
|
| 37 |
if not isinstance(out, str) and hasattr(out, "commit_url"):
|
| 38 |
commit_url = out.commit_url
|
|
@@ -40,9 +36,12 @@ def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken):
|
|
| 40 |
else:
|
| 41 |
return out
|
| 42 |
except Exception as e:
|
| 43 |
-
raise gr.Error(
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
|
|
|
| 46 |
#col-container {
|
| 47 |
margin: 0 auto;
|
| 48 |
max-width: 520px;
|
|
@@ -50,8 +49,12 @@ css="""
|
|
| 50 |
"""
|
| 51 |
with gr.Blocks(css=css) as demo:
|
| 52 |
with gr.Column(elem_id="col-container"):
|
| 53 |
-
gr.Markdown(
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
|
| 57 |
filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")
|
|
@@ -62,17 +65,19 @@ with gr.Blocks(css=css) as demo:
|
|
| 62 |
|
| 63 |
run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])
|
| 64 |
|
|
|
|
| 65 |
def swap_visibilty(profile: gr.OAuthProfile | None):
|
| 66 |
return gr.update(elem_classes=["main_ui_logged_in"]) if profile else gr.update(elem_classes=["main_ui_logged_out"])
|
| 67 |
-
|
| 68 |
-
|
|
|
|
| 69 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
|
| 70 |
-
|
| 71 |
with gr.Blocks(css=css_login) as demo_login:
|
| 72 |
gr.LoginButton()
|
| 73 |
with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
|
| 74 |
demo.render()
|
| 75 |
demo_login.load(fn=swap_visibilty, outputs=main_ui)
|
| 76 |
-
|
| 77 |
demo_login.queue()
|
| 78 |
-
demo_login.launch()
|
|
|
|
| 11 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 12 |
|
| 13 |
# Load the model pipeline
|
| 14 |
+
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", torch_dtype=dtype).to(device)
|
| 15 |
+
|
|
|
|
| 16 |
|
| 17 |
@spaces.GPU(duration=120)
|
| 18 |
+
def push_to_hub(repo_id, filename, oauth_token: gr.OAuthToken, progress=gr.Progress(track_tqdm=True)):
|
| 19 |
if not filename.endswith(".pt2"):
|
| 20 |
raise NotImplementedError("The filename must end with a `.pt2` extension.")
|
| 21 |
+
|
| 22 |
# this will throw if token is invalid
|
| 23 |
try:
|
| 24 |
_ = whoami(oauth_token.token)
|
|
|
|
| 26 |
# --- Ahead-of-time compilation ---
|
| 27 |
compiled_transformer = compile_transformer(pipe, prompt="prompt")
|
| 28 |
|
| 29 |
+
token = oauth_token.token
|
| 30 |
out = _push_compiled_graph_to_hub(
|
| 31 |
+
compiled_transformer.archive_file, repo_id=repo_id, token=token, path_in_repo=filename
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
| 33 |
if not isinstance(out, str) and hasattr(out, "commit_url"):
|
| 34 |
commit_url = out.commit_url
|
|
|
|
| 36 |
else:
|
| 37 |
return out
|
| 38 |
except Exception as e:
|
| 39 |
+
raise gr.Error(
|
| 40 |
+
f"""Oops, you forgot to login. Please use the loggin button on the top left to migrate your repo {e}"""
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
css = """
|
| 45 |
#col-container {
|
| 46 |
margin: 0 auto;
|
| 47 |
max-width: 520px;
|
|
|
|
| 49 |
"""
|
| 50 |
with gr.Blocks(css=css) as demo:
|
| 51 |
with gr.Column(elem_id="col-container"):
|
| 52 |
+
gr.Markdown(
|
| 53 |
+
"## Compile [Flux.1-Dev](https://hf.co/black-forest-labs/Flux.1-Dev) graph ahead of time & push to the Hub"
|
| 54 |
+
)
|
| 55 |
+
gr.Markdown(
|
| 56 |
+
"Enter a **repo_id** and **filename**. This repo automatically compiles the Flux.1-Dev model ahead of time. Read more about this in [this post](https://huggingface.co/blog/zerogpu-aoti)."
|
| 57 |
+
)
|
| 58 |
|
| 59 |
repo_id = gr.Textbox(label="repo_id", placeholder="e.g. sayakpaul/qwen-aot")
|
| 60 |
filename = gr.Textbox(label="filename", placeholder="e.g. compiled.pt2")
|
|
|
|
| 65 |
|
| 66 |
run.click(push_to_hub, inputs=[repo_id, filename], outputs=[markdown_out])
|
| 67 |
|
| 68 |
+
|
| 69 |
def swap_visibilty(profile: gr.OAuthProfile | None):
|
| 70 |
return gr.update(elem_classes=["main_ui_logged_in"]) if profile else gr.update(elem_classes=["main_ui_logged_out"])
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
css_login = """
|
| 74 |
.main_ui_logged_out{opacity: 0.3; pointer-events: none; margin: 0 auto; max-width: 520px}
|
| 75 |
+
"""
|
| 76 |
with gr.Blocks(css=css_login) as demo_login:
|
| 77 |
gr.LoginButton()
|
| 78 |
with gr.Column(elem_classes="main_ui_logged_out") as main_ui:
|
| 79 |
demo.render()
|
| 80 |
demo_login.load(fn=swap_visibilty, outputs=main_ui)
|
| 81 |
+
|
| 82 |
demo_login.queue()
|
| 83 |
+
demo_login.launch()
|
hub_utils.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
from io import BytesIO
|
| 2 |
from huggingface_hub import create_repo, upload_file
|
| 3 |
-
import tempfile
|
| 4 |
import os
|
| 5 |
|
| 6 |
DEFAULT_ARCHIVE_FILENAME = "archived_graph.pt2"
|
| 7 |
|
|
|
|
| 8 |
def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs):
|
| 9 |
if not isinstance(archive, BytesIO):
|
| 10 |
raise NotImplementedError("Incorrect type of `archive` provided.")
|
|
@@ -13,9 +14,7 @@ def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs):
|
|
| 13 |
private = kwargs.pop("private", False)
|
| 14 |
path_in_repo = kwargs.pop("path_in_repo", DEFAULT_ARCHIVE_FILENAME)
|
| 15 |
token = kwargs.pop("token")
|
| 16 |
-
repo_id = create_repo(
|
| 17 |
-
repo_id, private=private, exist_ok=True, token=token
|
| 18 |
-
).repo_id
|
| 19 |
|
| 20 |
with tempfile.TemporaryDirectory() as tmpdir:
|
| 21 |
output_path = os.path.join(tmpdir, os.path.basename(path_in_repo))
|
|
@@ -24,8 +23,8 @@ def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs):
|
|
| 24 |
|
| 25 |
try:
|
| 26 |
info = upload_file(
|
| 27 |
-
repo_id=repo_id,
|
| 28 |
-
path_or_fileobj=output_path,
|
| 29 |
path_in_repo=os.path.basename(path_in_repo),
|
| 30 |
commit_message=commit_message,
|
| 31 |
token=token,
|
|
@@ -33,4 +32,4 @@ def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs):
|
|
| 33 |
return info
|
| 34 |
except Exception as e:
|
| 35 |
print(f"File couldn't be pushed to the Hub with the following error: {e}.")
|
| 36 |
-
return e
|
|
|
|
| 1 |
from io import BytesIO
|
| 2 |
from huggingface_hub import create_repo, upload_file
|
| 3 |
+
import tempfile
|
| 4 |
import os
|
| 5 |
|
| 6 |
DEFAULT_ARCHIVE_FILENAME = "archived_graph.pt2"
|
| 7 |
|
| 8 |
+
|
| 9 |
def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs):
|
| 10 |
if not isinstance(archive, BytesIO):
|
| 11 |
raise NotImplementedError("Incorrect type of `archive` provided.")
|
|
|
|
| 14 |
private = kwargs.pop("private", False)
|
| 15 |
path_in_repo = kwargs.pop("path_in_repo", DEFAULT_ARCHIVE_FILENAME)
|
| 16 |
token = kwargs.pop("token")
|
| 17 |
+
repo_id = create_repo(repo_id, private=private, exist_ok=True, token=token).repo_id
|
|
|
|
|
|
|
| 18 |
|
| 19 |
with tempfile.TemporaryDirectory() as tmpdir:
|
| 20 |
output_path = os.path.join(tmpdir, os.path.basename(path_in_repo))
|
|
|
|
| 23 |
|
| 24 |
try:
|
| 25 |
info = upload_file(
|
| 26 |
+
repo_id=repo_id,
|
| 27 |
+
path_or_fileobj=output_path,
|
| 28 |
path_in_repo=os.path.basename(path_in_repo),
|
| 29 |
commit_message=commit_message,
|
| 30 |
token=token,
|
|
|
|
| 32 |
return info
|
| 33 |
except Exception as e:
|
| 34 |
print(f"File couldn't be pushed to the Hub with the following error: {e}.")
|
| 35 |
+
return e
|
optimization.py
CHANGED
|
@@ -5,26 +5,27 @@ import spaces
|
|
| 5 |
import torch
|
| 6 |
from torch.utils._pytree import tree_map
|
| 7 |
|
| 8 |
-
P = ParamSpec(
|
| 9 |
|
| 10 |
-
TRANSFORMER_HIDDEN_DIM = torch.export.Dim(
|
| 11 |
|
| 12 |
# Specific to Flux. More about this is available in
|
| 13 |
# https://huggingface.co/blog/zerogpu-aoti
|
| 14 |
TRANSFORMER_DYNAMIC_SHAPES = {
|
| 15 |
-
|
| 16 |
-
|
| 17 |
}
|
| 18 |
|
| 19 |
INDUCTOR_CONFIGS = {
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
}
|
| 27 |
|
|
|
|
| 28 |
def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
|
| 29 |
@spaces.GPU(duration=1500)
|
| 30 |
def f():
|
|
@@ -35,12 +36,9 @@ def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.k
|
|
| 35 |
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
|
| 36 |
|
| 37 |
exported = torch.export.export(
|
| 38 |
-
mod=pipeline.transformer,
|
| 39 |
-
args=call.args,
|
| 40 |
-
kwargs=call.kwargs,
|
| 41 |
-
dynamic_shapes=dynamic_shapes
|
| 42 |
)
|
| 43 |
return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
|
| 44 |
-
|
| 45 |
compiled_transformer = f()
|
| 46 |
-
return compiled_transformer
|
|
|
|
| 5 |
import torch
|
| 6 |
from torch.utils._pytree import tree_map
|
| 7 |
|
| 8 |
+
P = ParamSpec("P")
|
| 9 |
|
| 10 |
+
TRANSFORMER_HIDDEN_DIM = torch.export.Dim("hidden", min=4096, max=8212)
|
| 11 |
|
| 12 |
# Specific to Flux. More about this is available in
|
| 13 |
# https://huggingface.co/blog/zerogpu-aoti
|
| 14 |
TRANSFORMER_DYNAMIC_SHAPES = {
|
| 15 |
+
"hidden_states": {1: TRANSFORMER_HIDDEN_DIM},
|
| 16 |
+
"img_ids": {0: TRANSFORMER_HIDDEN_DIM},
|
| 17 |
}
|
| 18 |
|
| 19 |
INDUCTOR_CONFIGS = {
|
| 20 |
+
"conv_1x1_as_mm": True,
|
| 21 |
+
"epilogue_fusion": False,
|
| 22 |
+
"coordinate_descent_tuning": True,
|
| 23 |
+
"coordinate_descent_check_all_directions": True,
|
| 24 |
+
"max_autotune": True,
|
| 25 |
+
"triton.cudagraphs": True,
|
| 26 |
}
|
| 27 |
|
| 28 |
+
|
| 29 |
def compile_transformer(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kwargs):
|
| 30 |
@spaces.GPU(duration=1500)
|
| 31 |
def f():
|
|
|
|
| 36 |
dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
|
| 37 |
|
| 38 |
exported = torch.export.export(
|
| 39 |
+
mod=pipeline.transformer, args=call.args, kwargs=call.kwargs, dynamic_shapes=dynamic_shapes
|
|
|
|
|
|
|
|
|
|
| 40 |
)
|
| 41 |
return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
|
| 42 |
+
|
| 43 |
compiled_transformer = f()
|
| 44 |
+
return compiled_transformer
|