# Hub page header (not part of the script): danhtran2mind's picture
# Commit message: Upload 84 files
# Commit: 7946a9d verified
import os
import argparse
import yaml
from huggingface_hub import snapshot_download
def download_model_checkpoint(repo_id, local_dir, token=None, ignore_patterns=None):
    """
    Download a Hugging Face model checkpoint to a specified local directory.

    The target directory is created if needed. Only files matching
    ``*.safetensors``, ``*.ckpt``, ``*.json`` and ``*.txt`` are requested;
    ``ignore_patterns`` is forwarded to ``snapshot_download`` as well.

    Failures are reported on stdout rather than raised, so a caller that
    loops over several repositories can carry on with the next one. The
    return value tells the caller whether this particular download worked.

    Args:
        repo_id (str): The Hugging Face repository ID (e.g., 'stabilityai/stable-diffusion-2-1').
        local_dir (str): The local directory to store the downloaded checkpoint files.
        token (str, optional): Hugging Face API token for accessing private or gated repositories.
        ignore_patterns (list, optional): List of file patterns to exclude from downloading.

    Returns:
        bool: True if the download completed, False if any step raised.
    """
    try:
        os.makedirs(local_dir, exist_ok=True)
        snapshot_download(
            repo_id=repo_id,
            local_dir=local_dir,
            local_dir_use_symlinks=False,
            repo_type="model",
            token=token,
            allow_patterns=["*.safetensors", "*.ckpt", "*.json", "*.txt"],
            ignore_patterns=ignore_patterns,  # Exclude the files named in the config entry
        )
    except Exception as e:
        # Deliberately broad: this is a best-effort boundary, and one bad
        # repository must not abort the remaining downloads.
        print(f"Failed to download model checkpoint from {repo_id}: {str(e)}")
        return False
    else:
        print(f"Successfully downloaded model checkpoint from {repo_id} to {local_dir}")
        return True
def load_config(config_path):
    """
    Load and validate the YAML configuration file.

    The file is expected to contain a top-level list of mappings. Only the
    mappings whose ``platform`` value is exactly ``"HuggingFace"`` are
    returned; entries that are not mappings are skipped.

    Args:
        config_path (str): Path to the YAML configuration file.

    Returns:
        list: List of model configurations for HuggingFace platform.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        yaml.YAMLError: If the YAML file is invalid.
        ValueError: If the YAML file is empty, its top level is not a list,
            or it contains no valid HuggingFace entries.
    """
    # Explicit encoding so parsing does not depend on the platform default.
    with open(config_path, "r", encoding="utf-8") as file:
        config_data = yaml.safe_load(file)

    if not config_data:
        raise ValueError("The YAML configuration file is empty or invalid")

    # A mapping or scalar at the top level would otherwise fail later with an
    # AttributeError on ``item.get``; report it as a configuration error.
    if not isinstance(config_data, list):
        raise ValueError("The YAML configuration file must contain a list of entries")

    # Filter for HuggingFace platform entries, ignoring non-mapping items
    huggingface_configs = [
        item
        for item in config_data
        if isinstance(item, dict) and item.get("platform") == "HuggingFace"
    ]
    if not huggingface_configs:
        raise ValueError("No valid HuggingFace platform entries found in the YAML configuration")
    return huggingface_configs
def main(argv=None):
    """
    Parse arguments and download model checkpoints based on the YAML configuration.

    Entries missing ``model_id`` or ``local_dir`` are skipped with a message.
    Each remaining entry is handed to ``download_model_checkpoint`` together
    with the entry's ``no_download_path`` value as the ignore patterns.

    Args:
        argv (list, optional): Command-line arguments to parse. Defaults to
            None, in which case argparse reads ``sys.argv``.

    Returns:
        int: 0 if the configuration was loaded and every entry was processed,
        1 if the configuration could not be loaded or an unexpected error
        interrupted processing.
    """
    parser = argparse.ArgumentParser(description="Download Hugging Face model checkpoints specified in a YAML configuration file.")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/model_ckpts.yaml",
        help="Path to the YAML configuration file specifying model IDs and local directories."
    )
    parser.add_argument(
        "--token",
        type=str,
        default=None,
        help="Hugging Face API token for accessing private or gated repositories (optional)."
    )
    args = parser.parse_args(argv)

    try:
        # Load and validate the configuration
        model_configs = load_config(args.config)

        # Download each model checkpoint
        for config in model_configs:
            repo_id = config.get("model_id")
            local_dir = config.get("local_dir")
            ignore_patterns = config.get("no_download_path")  # Get the no_download_path field

            if not repo_id or not local_dir:
                print("Skipping invalid configuration entry: missing model_id or local_dir")
                continue

            print(f"Downloading model checkpoint from {repo_id} to {local_dir}...")
            download_model_checkpoint(repo_id, local_dir, args.token, ignore_patterns)
    except FileNotFoundError:
        print(f"Error: Configuration file '{args.config}' not found.")
        return 1
    except yaml.YAMLError as e:
        print(f"Error: Failed to parse YAML configuration file. Details: {str(e)}")
        return 1
    except ValueError as e:
        print(f"Error: {str(e)}")
        return 1
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
        return 1
    return 0


if __name__ == "__main__":
    # Propagate the status so shells and CI can detect a failed run.
    raise SystemExit(main())