File size: 1,316 Bytes
ee02270
 
4b0fe46
ee02270
 
 
 
4b0fe46
3625a6b
ee02270
 
 
 
 
 
 
4b0fe46
ee02270
 
 
 
 
 
 
 
4b0fe46
 
ee02270
 
 
 
3625a6b
ee02270
 
4b0fe46
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from io import BytesIO
from huggingface_hub import create_repo, upload_file
import tempfile
import os

# Default destination filename in the Hub repo for the uploaded compiled-graph
# archive; used when the caller does not pass `path_in_repo`.
DEFAULT_ARCHIVE_FILENAME = "archived_graph.pt2"


def _push_compiled_graph_to_hub(archive: BytesIO, repo_id, **kwargs) -> "str | Exception":
    """Upload an in-memory compiled-graph archive to a Hugging Face Hub repo.

    The target repository is created if it does not already exist
    (``exist_ok=True``). The archive's full contents are then uploaded as a
    single file.

    Args:
        archive: In-memory archive; its entire buffer (``getvalue()``) is
            uploaded regardless of the current stream position.
        repo_id: Target repository id. It is normalized through
            ``create_repo`` before the upload.
        **kwargs: Optional settings listed below. Any other keys are ignored.
            commit_message: Commit message (default ``"Uploaded from spaces."``).
            private: Visibility used if the repo has to be created
                (default ``False``).
            path_in_repo: Destination filename (default
                ``DEFAULT_ARCHIVE_FILENAME``). Only its basename is used, so
                any directory components are dropped.
            token: Hub access token. If omitted, ``None`` is passed through
                and huggingface_hub uses its locally stored credentials.

    Returns:
        The commit URL on success. If the upload itself fails, the error is
        printed and the exception object is returned rather than raised.

    Raises:
        NotImplementedError: If ``archive`` is not a ``BytesIO``.
        Exception: Errors from ``create_repo`` are not caught and propagate.
    """
    if not isinstance(archive, BytesIO):
        raise NotImplementedError("Incorrect type of `archive` provided.")

    commit_message = kwargs.pop("commit_message", "Uploaded from spaces.")
    private = kwargs.pop("private", False)
    path_in_repo = kwargs.pop("path_in_repo", DEFAULT_ARCHIVE_FILENAME)
    # Optional: a missing token must not raise KeyError; with None the Hub
    # client falls back to the locally saved login.
    token = kwargs.pop("token", None)
    repo_id = create_repo(repo_id, private=private, exist_ok=True, token=token).repo_id

    try:
        # `upload_file` accepts raw bytes, so there is no need to round-trip a
        # (potentially large) archive through a temporary file on disk.
        info = upload_file(
            repo_id=repo_id,
            path_or_fileobj=archive.getvalue(),
            path_in_repo=os.path.basename(path_in_repo),
            commit_message=commit_message,
            token=token,
        )
    except Exception as e:
        # Best-effort upload: report the failure and hand the error back to
        # the caller instead of raising.
        print(f"File couldn't be pushed to the Hub with the following error: {e}.")
        return e
    return info.commit_url