# Flux-Compiled-Graph / hub_utils.py
# Author: sayakpaul (HF Staff) — commit 3625a6b ("up")
import os
import tempfile
from io import BytesIO
from typing import Union

from huggingface_hub import create_repo, upload_file
# Filename used as the destination `path_in_repo` by `_push_compiled_graph_to_hub`
# when the caller does not supply one.
DEFAULT_ARCHIVE_FILENAME = "archived_graph.pt2"
def _push_compiled_graph_to_hub(
    archive: BytesIO, repo_id, **kwargs
) -> Union[str, Exception]:
    """Upload the contents of an in-memory archive to a Hugging Face Hub repo.

    The target repository is created first (``exist_ok=True``, so an existing
    repo is reused), then the archive bytes are committed to it.

    Args:
        archive: In-memory buffer holding the archive. Its entire contents
            (``archive.getvalue()``) are uploaded, independent of the buffer's
            current seek position.
        repo_id: Target repository id passed to ``create_repo``; the id it
            returns is the one used for the upload.
        **kwargs: Optional settings (any other keys are ignored):
            commit_message (str): defaults to ``"Uploaded from spaces."``.
            private (bool): forwarded to ``create_repo``; defaults to ``False``.
            path_in_repo (str): destination path inside the repo, used as-is
                (sub-directories are kept); defaults to
                ``DEFAULT_ARCHIVE_FILENAME``.
            token (str | None): Hub access token forwarded to both Hub calls;
                defaults to ``None`` (huggingface_hub then applies its own
                default authentication — see its docs).

    Returns:
        The commit URL on success. If the upload itself fails, the error is
        printed and the exception object is returned instead of raised, so
        callers can treat the push as best-effort.

    Raises:
        NotImplementedError: if ``archive`` is not an ``io.BytesIO``.
        Any exception raised by ``create_repo`` propagates unchanged.
    """
    if not isinstance(archive, BytesIO):
        raise NotImplementedError("Incorrect type of `archive` provided.")

    commit_message = kwargs.pop("commit_message", "Uploaded from spaces.")
    private = kwargs.pop("private", False)
    path_in_repo = kwargs.pop("path_in_repo", DEFAULT_ARCHIVE_FILENAME)
    # Optional: a missing token must not abort with KeyError.
    token = kwargs.pop("token", None)

    repo_id = create_repo(repo_id, private=private, exist_ok=True, token=token).repo_id

    try:
        # upload_file accepts raw bytes, so no temporary file on disk is needed.
        info = upload_file(
            repo_id=repo_id,
            path_or_fileobj=archive.getvalue(),
            path_in_repo=path_in_repo,
            commit_message=commit_message,
            token=token,
        )
    except Exception as e:
        # Deliberate best-effort boundary: report and hand the error back.
        print(f"File couldn't be pushed to the Hub with the following error: {e}.")
        return e
    return info.commit_url