Spaces: Running on Zero
| import yaml | |
| import os | |
| from collections import OrderedDict | |
# Model / asset folders. These are plain relative paths; presumably they are
# resolved against the process working directory by the code that uses them
# (NOTE(review): confirm with callers).
CHECKPOINT_DIR = "models/checkpoints"
LORA_DIR = "models/loras"
EMBEDDING_DIR = "models/embeddings"
CONTROLNET_DIR = "models/controlnet"
DIFFUSION_MODELS_DIR = "models/diffusion_models"
VAE_DIR = "models/vae"
TEXT_ENCODERS_DIR = "models/text_encoders"
INPUT_DIR = "input"
OUTPUT_DIR = "output"

# Project root: the parent of the directory that contains this module.
_this_module_dir = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(_this_module_dir)

# Absolute locations of the YAML configuration files, all kept in
# <project root>/yaml/.
_YAML_DIR = os.path.join(_PROJECT_ROOT, 'yaml')
_MODEL_LIST_PATH = os.path.join(_YAML_DIR, 'model_list.yaml')
_FILE_LIST_PATH = os.path.join(_YAML_DIR, 'file_list.yaml')
_IPADAPTER_LIST_PATH = os.path.join(_YAML_DIR, 'ipadapter.yaml')
_CONSTANTS_PATH = os.path.join(_YAML_DIR, 'constants.yaml')
def load_constants_from_yaml(filepath=_CONSTANTS_PATH):
    """Load the constants YAML file and return its top-level mapping.

    Missing, empty, or non-mapping files all yield ``{}`` so callers can
    always use ``.get(key, default)`` on the result and fall back to their
    own defaults.

    Args:
        filepath: Path to the constants YAML file.

    Returns:
        dict: The parsed top-level mapping, or ``{}`` when the file is
        missing, empty, or does not contain a mapping.

    Raises:
        OSError / yaml.YAMLError: if the file exists but cannot be read or
        parsed (propagated from ``open`` / ``yaml.safe_load``).
    """
    if not os.path.exists(filepath):
        print(f"Warning: Constants file not found at {filepath}. Using fallback values.")
        return {}
    with open(filepath, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    if data is None:
        # An empty YAML document parses to None; treat it as "no overrides".
        return {}
    if not isinstance(data, dict):
        # Callers call .get() on the result, so a list/scalar would crash them.
        print(
            f"Warning: Constants file at {filepath} does not contain a mapping "
            f"(got {type(data).__name__}). Using fallback values."
        )
        return {}
    return data
def load_file_download_map(filepath=_FILE_LIST_PATH):
    """Build a ``filename -> download info`` map from the file-list YAML.

    The YAML is expected to have a top-level ``file`` mapping whose values are
    lists of entry mappings, e.g.::

        file:
          <category>:
            - filename: ...
              repo_id: ...

    Every entry that is a mapping with a ``filename`` key is included. The
    entry dict gets a ``category`` key set to the name of the list it came
    from. If two entries share a filename, the later one wins. Category
    values that are not lists, and list items that are not mappings or lack
    ``filename``, are skipped.

    Args:
        filepath: Path to the file-list YAML file.

    Returns:
        dict: ``filename -> entry dict``. Returns ``{}`` for an empty file or
        when the ``file`` key is missing or empty.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
        ValueError: if the top level, or the value under ``file``, is not a
            mapping.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The file list (for downloads) was not found at: {filepath}")
    with open(filepath, 'r', encoding='utf-8') as f:
        file_list_data = yaml.safe_load(f)
    if file_list_data is None:
        # Empty YAML document: nothing to download.
        return {}
    if not isinstance(file_list_data, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {filepath}, "
            f"got {type(file_list_data).__name__}"
        )
    # `file:` with no value parses to None; treat it like a missing key.
    categories = file_list_data.get('file') or {}
    if not isinstance(categories, dict):
        raise ValueError(
            f"Expected 'file' in {filepath} to be a mapping of categories, "
            f"got {type(categories).__name__}"
        )
    download_info_map = {}
    for category, files in categories.items():
        if not isinstance(files, list):
            continue
        for file_info in files:
            # Guard against bare strings: `'filename' in "..."` would be a
            # substring test, and item assignment on a str raises TypeError.
            if isinstance(file_info, dict) and 'filename' in file_info:
                file_info['category'] = category
                download_info_map[file_info['filename']] = file_info
    return download_info_map
def load_models_from_yaml(model_list_filepath=_MODEL_LIST_PATH, download_map=None):
    """Build ``display name -> model tuple`` maps from the model-list YAML.

    Only the top-level categories listed in ``category_map_names`` (currently
    just ``"Checkpoint"``) are read; other keys and non-list category values
    are ignored. Each entry must be a mapping with ``display_name`` and
    ``path`` keys. Malformed entries are skipped with a warning instead of
    aborting the whole load.

    Args:
        model_list_filepath: Path to the model-list YAML file.
        download_map: ``filename -> download info`` mapping, as produced by
            ``load_file_download_map``. It is used to look up each model's
            ``repo_id``. Required.

    Returns:
        dict with keys ``"MODEL_MAP_CHECKPOINT"`` and ``"ALL_MODEL_MAP"``.
        Each value is an ``OrderedDict`` (in YAML order) mapping display name
        to ``(repo_id, filename, "SDXL", None)``. ``repo_id`` is ``''`` when
        the filename has no download entry.

    Raises:
        FileNotFoundError: if ``model_list_filepath`` does not exist.
        ValueError: if ``download_map`` is None, or the YAML's top level is
            not a mapping.
    """
    if not os.path.exists(model_list_filepath):
        raise FileNotFoundError(f"The model list file was not found at: {model_list_filepath}")
    if download_map is None:
        raise ValueError("download_map must be provided to load_models_from_yaml")
    with open(model_list_filepath, 'r', encoding='utf-8') as f:
        model_data = yaml.safe_load(f)
    if model_data is None:
        # Empty YAML document: no models.
        model_data = {}
    elif not isinstance(model_data, dict):
        raise ValueError(
            f"Expected a mapping at the top level of {model_list_filepath}, "
            f"got {type(model_data).__name__}"
        )
    model_maps = {
        "MODEL_MAP_CHECKPOINT": OrderedDict(),
        "ALL_MODEL_MAP": OrderedDict(),
    }
    # YAML category name -> key in model_maps that receives its entries.
    category_map_names = {
        "Checkpoint": "MODEL_MAP_CHECKPOINT",
    }
    for category, models in model_data.items():
        map_name = category_map_names.get(category)
        if map_name is None or not isinstance(models, list):
            continue
        for model in models:
            if (
                not isinstance(model, dict)
                or 'display_name' not in model
                or 'path' not in model
            ):
                print(
                    f"Warning: Skipping malformed '{category}' entry in "
                    f"{model_list_filepath}: {model!r}"
                )
                continue
            display_name = model['display_name']
            filename = model['path']
            # The download map is keyed by filename; unknown files get repo_id ''.
            repo_id = download_map.get(filename, {}).get('repo_id', '')
            # The model-type tag (3rd element) is fixed to "SDXL" for every
            # entry. The 4th element is always None here; its meaning is
            # defined by consumers not visible in this file.
            model_tuple = (repo_id, filename, "SDXL", None)
            model_maps[map_name][display_name] = model_tuple
            model_maps["ALL_MODEL_MAP"][display_name] = model_tuple
    return model_maps
# Load the model catalogue at import time. Any failure (missing or unparsable
# YAML, bad structure, loader bug) is reported, and every map falls back to
# empty so the module still imports.
try:
    ALL_FILE_DOWNLOAD_MAP = load_file_download_map()
    loaded_maps = load_models_from_yaml(download_map=ALL_FILE_DOWNLOAD_MAP)
    MODEL_MAP_CHECKPOINT = loaded_maps["MODEL_MAP_CHECKPOINT"]
    ALL_MODEL_MAP = loaded_maps["ALL_MODEL_MAP"]
    # display name -> model-type tag (the third element of each model tuple).
    MODEL_TYPE_MAP = {k: v[2] for k, v in ALL_MODEL_MAP.items()}
except Exception as e:
    import traceback
    print(f"FATAL: Could not load model configuration from YAML. Error: {e}")
    # Show where it failed; str(e) alone hides bugs inside the loaders.
    traceback.print_exc()
    ALL_FILE_DOWNLOAD_MAP = {}
    # Same container type as the success path, so callers see consistent types.
    MODEL_MAP_CHECKPOINT, ALL_MODEL_MAP = OrderedDict(), OrderedDict()
    MODEL_TYPE_MAP = {}
# UI limits and lookup tables read from constants.yaml. Each setting uses the
# literal default below when the file cannot be loaded or the key is absent.
# Only the load itself is guarded, so the defaults live in exactly one place.
try:
    _constants = load_constants_from_yaml()
except Exception as e:
    print(f"FATAL: Could not load constants from YAML. Error: {e}")
    _constants = {}
if _constants is None:
    # An empty constants.yaml parses to None; treat it as "no overrides".
    _constants = {}
elif not isinstance(_constants, dict):
    print(
        "FATAL: Could not load constants from YAML. Error: expected a mapping "
        f"at the top level, got {type(_constants).__name__}"
    )
    _constants = {}
MAX_LORAS = _constants.get('MAX_LORAS', 5)
MAX_EMBEDDINGS = _constants.get('MAX_EMBEDDINGS', 5)
MAX_CONDITIONINGS = _constants.get('MAX_CONDITIONINGS', 10)
MAX_CONTROLNETS = _constants.get('MAX_CONTROLNETS', 5)
MAX_IPADAPTERS = _constants.get('MAX_IPADAPTERS', 5)
LORA_SOURCE_CHOICES = _constants.get('LORA_SOURCE_CHOICES', ["Civitai", "Custom URL", "File"])
RESOLUTION_MAP = _constants.get('RESOLUTION_MAP', {})
SAMPLER_MAP = _constants.get('SAMPLER_MAP', {})
# Default negative-prompt text (presumably used when the user supplies none;
# NOTE(review): confirm against callers). "(...:1.2)" appears to be
# prompt-weighting syntax interpreted downstream.
DEFAULT_NEGATIVE_PROMPT = (
    "monochrome, (low quality, worst quality:1.2), 3d, "
    "watermark, signature, ugly, poorly drawn,"
)