File size: 3,020 Bytes
e8d5a56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from huggingface_hub import HfApi, snapshot_download
import os
import torch
import argparse

def download_checkpoint(repo_id, save_path, repo_type="model"):
    """

    Download a model checkpoint from Hugging Face Hub to the specified local directory.

    

    Args:

        repo_id (str): The repository ID on Hugging Face Hub

        save_path (str): Local directory path to save the checkpoint

        repo_type (str): Type of repository (default: "model")

    """
    # Initialize Hugging Face API
    api = HfApi()
    
    # Create the directory if it doesn't exist
    os.makedirs(save_path, exist_ok=True)
    
    # Download the checkpoint
    print(f"Downloading {repo_id} to {save_path}...")
    snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=save_path)
    print(f"Successfully downloaded {repo_id}")

def main(args):    
    # Define checkpoint configurations
    checkpoints = [
        {
            "repo_id": args.repo_id,
            "save_path": args.save_path,
            "repo_type": args.repo_type
        }
    ]
    
    # Add LoRA checkpoint if provided
    if args.lora_repo_id and args.lora_save_path:
        checkpoints.append({
            "repo_id": args.lora_repo_id,
            "save_path": args.lora_save_path,
            "repo_type": args.lora_repo_type
        })
    
    # Download each checkpoint
    for checkpoint in checkpoints:
        download_checkpoint(
            repo_id=checkpoint["repo_id"],
            save_path=checkpoint["save_path"],
            repo_type=checkpoint["repo_type"]
        )

if __name__ == "__main__":
    # Set up argument parser
    parser = argparse.ArgumentParser(description="Download model checkpoints from Hugging Face Hub")
    parser.add_argument(
        "--repo_id",
        type=str,
        default="cerspense/zeroscope_v2_576w",
        help="Hugging Face repository ID for the checkpoint"
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="./ckpts/zeroscope_v2_576w",
        help="Local directory to save the checkpoint"
    )
    parser.add_argument(
        "--repo_type",
        type=str,
        default="model",
        help="Type of repository (e.g., model, dataset)"
    )
    parser.add_argument(
        "--lora_repo_id",
        type=str,
        default="danhtran2mind/zeroscope_v2_576w-Ghibli-LoRA",
        help="Hugging Face repository ID for the LoRA checkpoint"
    )
    parser.add_argument(
        "--lora_save_path",
        type=str,
        default="./ckpts/zeroscope_v2_576w-Ghibli-LoRA",
        help="Local directory to save the LoRA checkpoint"
    )
    parser.add_argument(
        "--lora_repo_type",
        type=str,
        default="model",
        help="Type of repository for the LoRA checkpoint (e.g., model, dataset)"
    )
    
    # Parse arguments
    args = parser.parse_args()
    
    # Call main with parsed arguments
    main(args)