import hashlib
import os
import re
import shutil
from pathlib import Path
from typing import Sequence, Mapping, Any, Union, Set

import requests
import gradio as gr
from huggingface_hub import hf_hub_download, constants as hf_constants
import torch
import numpy as np
from PIL import Image, ImageChops

# NOTE(review): names such as LORA_DIR, VAE_DIR, EMBEDDING_DIR, CHECKPOINT_DIR,
# CONTROLNET_DIR, ALL_MODEL_MAP and ALL_FILE_DOWNLOAD_MAP are used below but not
# defined in this module; presumably they come from this star import.
from core.settings import *

# Upper bound (GiB) for the Hugging Face hub cache, enforced by enforce_disk_limit().
DISK_LIMIT_GB = 120
MODELS_ROOT_DIR = "ComfyUI/models"

# Lazily-populated module caches, filled by build_preprocessor_model_map(),
# build_preprocessor_parameter_map() and load_ipadapter_presets() respectively.
PREPROCESSOR_MODEL_MAP = None
PREPROCESSOR_PARAMETER_MAP = None
IPADAPTER_PRESETS = None

_NEGATIVE_PROMPT_PREFIX = "Negative prompt:"


def save_uploaded_file_with_hash(file_obj: gr.File, target_dir: str) -> str:
    """Copy an uploaded file into *target_dir* under a content-addressed name.

    The stored filename is the SHA-256 hex digest of the file contents plus the
    original extension (lower-cased), so identical uploads are stored only once.

    Args:
        file_obj: The upload. Either an object exposing its on-disk path as
            ``.name`` or a plain path (``str`` / ``os.PathLike``).
        target_dir: Destination directory; created if it does not exist.

    Returns:
        The hashed filename (not the full path), or ``""`` if no file was given.
    """
    if not file_obj:
        return ""
    # Accept a plain path as well as an object carrying the path in ``.name``.
    if isinstance(file_obj, (str, os.PathLike)):
        temp_path = os.fspath(file_obj)
    else:
        temp_path = file_obj.name

    sha256 = hashlib.sha256()
    with open(temp_path, 'rb') as f:
        # Hash in 64 KiB chunks so large model files are not loaded into memory.
        for block in iter(lambda: f.read(65536), b''):
            sha256.update(block)
    file_hash = sha256.hexdigest()

    _, extension = os.path.splitext(temp_path)
    hashed_filename = f"{file_hash}{extension.lower()}"
    dest_path = os.path.join(target_dir, hashed_filename)

    os.makedirs(target_dir, exist_ok=True)
    if not os.path.exists(dest_path):
        shutil.copy(temp_path, dest_path)
        print(f"✅ Saved uploaded file as: {dest_path}")
    else:
        print(f"ℹ️ File already exists (deduplicated): {dest_path}")
    return hashed_filename


def bytes_to_gb(byte_size: int | None) -> float:
    """Convert a byte count to GiB (1024**3), rounded to 2 decimals; ``None``/0 give 0.0."""
    if not byte_size:
        return 0.0
    return round(byte_size / (1024 ** 3), 2)


def get_directory_size(path: str) -> int:
    """Return the total size in bytes of regular files under *path* (symlinks excluded).

    Returns 0 if *path* does not exist. Files that cannot be inspected are
    reported and skipped, so one unreadable entry does not truncate the total.
    """
    total_size = 0
    if not os.path.exists(path):
        return 0
    for dirpath, _, filenames in os.walk(path):
        for f in filenames:
            fp = os.path.join(dirpath, f)
            try:
                if os.path.isfile(fp) and not os.path.islink(fp):
                    total_size += os.path.getsize(fp)
            except OSError as e:
                print(f"Warning: Could not access {fp} to calculate size: {e}")
    return total_size


def enforce_disk_limit():
    """Prune the Hugging Face hub cache until it fits within ``DISK_LIMIT_GB``.

    Sums the sizes of regular (non-symlink) files in ``HF_HUB_CACHE``, ignoring
    ``*.incomplete`` and ``*.lock`` files. If the total exceeds the limit, files
    are deleted in ascending ``os.path.getctime`` order until it no longer does.

    NOTE(review): on Unix ``getctime`` is the inode *change* time, not creation
    time — confirm that this ordering is the intended eviction policy.

    Deleting cache files can leave dangling symlinks in the model directories.
    ``_ensure_model_downloaded`` and ``ensure_controlnet_model_downloaded``
    detect and remove those before re-downloading.

    Any unexpected error is reported and swallowed (best-effort housekeeping).
    """
    disk_limit_bytes = DISK_LIMIT_GB * (1024 ** 3)
    cache_dir = hf_constants.HF_HUB_CACHE
    if not os.path.exists(cache_dir):
        return
    print(f"--- [Storage Manager] Checking disk usage in '{cache_dir}' (Limit: {DISK_LIMIT_GB} GB) ---")
    try:
        all_files = []
        current_size_bytes = 0
        for dirpath, _, filenames in os.walk(cache_dir):
            for f in filenames:
                # Skip in-progress downloads and lock files.
                if f.endswith((".incomplete", ".lock")):
                    continue
                file_path = os.path.join(dirpath, f)
                if os.path.isfile(file_path) and not os.path.islink(file_path):
                    try:
                        file_size = os.path.getsize(file_path)
                        creation_time = os.path.getctime(file_path)
                    except OSError:
                        continue
                    all_files.append((creation_time, file_path, file_size))
                    current_size_bytes += file_size

        print(f"--- [Storage Manager] Current usage: {bytes_to_gb(current_size_bytes)} GB ---")
        if current_size_bytes > disk_limit_bytes:
            print("--- [Storage Manager] Usage exceeds limit. Starting cleanup... ---")
            all_files.sort(key=lambda entry: entry[0])
            # Oldest first; a single pass replaces the former O(n^2) pop(0) loop.
            for _, oldest_file_path, oldest_file_size in all_files:
                if current_size_bytes <= disk_limit_bytes:
                    break
                try:
                    os.remove(oldest_file_path)
                    current_size_bytes -= oldest_file_size
                    print(f"--- [Storage Manager] Deleted oldest file: {os.path.basename(oldest_file_path)} ({bytes_to_gb(oldest_file_size)} GB freed) ---")
                except OSError as e:
                    print(f"--- [Storage Manager] Error deleting file {oldest_file_path}: {e} ---")
            print(f"--- [Storage Manager] Cleanup finished. New usage: {bytes_to_gb(current_size_bytes)} GB ---")
        else:
            print("--- [Storage Manager] Disk usage is within the limit. No action needed. ---")
    except Exception as e:
        print(f"--- [Storage Manager] An unexpected error occurred: {e} ---")


def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
    """Return ``obj[index]``, falling back to ``obj["result"][index]``.

    Returns ``None`` when neither lookup succeeds.
    """
    try:
        return obj[index]
    except (KeyError, IndexError):
        pass
    try:
        return obj["result"][index]
    except (KeyError, IndexError, TypeError):
        # TypeError: *obj* is a plain sequence (not indexable by "result"), or
        # obj["result"] is not subscriptable.
        return None


def sanitize_prompt(prompt: str) -> str:
    """Drop non-printable characters (keeping newlines and tabs); non-str input gives ``""``."""
    if not isinstance(prompt, str):
        return ""
    return "".join(char for char in prompt if char.isprintable() or char in ('\n', '\t'))


def sanitize_id(input_id: str) -> str:
    """Keep only the ASCII digits of *input_id*; non-str input gives ``""``."""
    if not isinstance(input_id, str):
        return ""
    return re.sub(r'[^0-9]', '', input_id)


def sanitize_url(url: str) -> str:
    """Return the stripped *url* if it is a plausible http(s) URL.

    Raises:
        ValueError: if *url* is not a string or does not match the pattern.
    """
    if not isinstance(url, str):
        raise ValueError("URL must be a string.")
    url = url.strip()
    if not re.match(r'^https?://[^\s/$.?#].[^\s]*$', url):
        raise ValueError("Invalid URL format or scheme. Only HTTP and HTTPS are allowed.")
    return url


def sanitize_filename(filename: str) -> str:
    """Make *filename* safe to join onto a directory path.

    Removes ``..`` sequences, replaces any character other than word chars,
    ``.`` and ``-`` with ``_``, and strips leading path separators.
    Non-str input gives ``""``.
    """
    if not isinstance(filename, str):
        return ""
    sanitized = filename.replace('..', '')
    sanitized = re.sub(r'[^\w\.\-]', '_', sanitized)
    return sanitized.lstrip('/\\')


def get_civitai_file_info(version_id: str) -> dict | None:
    """Fetch the file entry to download for a Civitai model version.

    Prefers the first file of type ``Model`` whose name ends in .safetensors,
    .pt or .bin; otherwise falls back to the first listed file.

    Returns:
        The file's metadata dict, or ``None`` if there are no files or the
        request/parse fails (the failure is reported, not raised).
    """
    api_url = f"https://civitai.com/api/v1/model-versions/{version_id}"
    try:
        response = requests.get(api_url, timeout=10)
        response.raise_for_status()
        data = response.json()
        files = data.get('files') or []
        for file_data in files:
            name = file_data.get('name') or ''
            if file_data.get('type') == 'Model' and name.endswith(('.safetensors', '.pt', '.bin')):
                return file_data
        return files[0] if files else None
    except Exception as e:
        # Best-effort lookup: callers treat None as "no download link".
        print(f"⚠️ Could not fetch Civitai file info for version {version_id}: {e}")
        return None


def download_file(url: str, save_path: str, api_key: str | None = None, progress=None, desc: str = "") -> str:
    """Stream *url* to *save_path*, reporting progress, and return a status string.

    Runs ``enforce_disk_limit()`` first. Does nothing if *save_path* already
    exists. On any failure the partially written file is removed.

    Args:
        url: Source URL.
        save_path: Destination file path; its parent directory is created if needed.
        api_key: Optional bearer token sent in the ``Authorization`` header.
        progress: Optional callable ``progress(fraction, desc=...)``.
        desc: Description passed to *progress*.

    Returns:
        One of ``"File already exists: <name>"``, ``"Successfully downloaded: <name>"``
        or ``"Download failed for <name>: <error>"``.
    """
    enforce_disk_limit()
    if os.path.exists(save_path):
        return f"File already exists: {os.path.basename(save_path)}"
    headers = {'Authorization': f'Bearer {api_key}'} if api_key and api_key.strip() else {}
    try:
        if progress:
            progress(0, desc=desc)
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        # The context manager releases the streamed connection on every path.
        with requests.get(url, stream=True, headers=headers, timeout=15) as response:
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            downloaded = 0
            with open(save_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    if progress and total_size > 0:
                        downloaded += len(chunk)
                        progress(downloaded / total_size, desc=desc)
        return f"Successfully downloaded: {os.path.basename(save_path)}"
    except Exception as e:
        if os.path.exists(save_path):
            os.remove(save_path)
        return f"Download failed for {os.path.basename(save_path)}: {e}"


def _download_civitai_asset(file_info: dict | None, local_path: str, api_key: str, source_name: str, progress) -> tuple[str | None, str]:
    """Shared tail of the get_*_path helpers: reuse *local_path* or download *file_info* to it."""
    if os.path.exists(local_path):
        return local_path, "File already exists."
    if not file_info or not file_info.get('downloadUrl'):
        return None, f"Could not get download link for {source_name}."
    status = download_file(file_info['downloadUrl'], local_path, api_key, progress=progress, desc=f"Downloading {source_name}")
    return (local_path, status) if "Successfully" in status else (None, status)


def _civitai_file_ext(file_info: dict | None, default: str = ".safetensors") -> str:
    """Return the remote file's own extension if it is .pt/.bin, else *default*."""
    name = (file_info or {}).get('name') or ''
    if name.lower().endswith(('.pt', '.bin')):
        return os.path.splitext(name)[1]
    return default


def _get_civitai_typed_path(source: str, id_or_url: str, civitai_key: str, progress, target_dir: str, label: str) -> tuple[str | None, str]:
    """Resolve/download a Civitai asset whose saved extension follows the remote file name."""
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    if source != "Civitai":
        return None, "Invalid source."
    version_id = sanitize_id(id_or_url)
    if not version_id:
        return None, "Invalid Civitai ID. Must be numeric."
    # Query the API before the local check: the file extension depends on the remote name.
    file_info = get_civitai_file_info(version_id)
    filename = sanitize_filename(f"civitai_{version_id}{_civitai_file_ext(file_info)}")
    local_path = os.path.join(target_dir, filename)
    return _download_civitai_asset(file_info, local_path, civitai_key, f"{label} Civitai ID {version_id}", progress)


def get_lora_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Return ``(local_path, status)`` for a Civitai LoRA, downloading it into ``LORA_DIR`` if needed.

    Only ``source == "Civitai"`` is supported; on failure ``local_path`` is ``None``.
    """
    if not id_or_url or not id_or_url.strip():
        return None, "No ID/URL provided."
    if source != "Civitai":
        return None, "Invalid source."
    version_id = sanitize_id(id_or_url)
    if not version_id:
        return None, "Invalid Civitai ID provided. Must be numeric."
    filename = sanitize_filename(f"civitai_{version_id}.safetensors")
    local_path = os.path.join(LORA_DIR, filename)
    # The filename does not depend on the API response, so skip the network call when cached.
    if os.path.exists(local_path):
        return local_path, "File already exists."
    file_info = get_civitai_file_info(version_id)
    return _download_civitai_asset(file_info, local_path, civitai_key, f"Civitai ID {version_id}", progress)


def get_embedding_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Return ``(local_path, status)`` for a Civitai embedding, downloading it into ``EMBEDDING_DIR`` if needed."""
    return _get_civitai_typed_path(source, id_or_url, civitai_key, progress, EMBEDDING_DIR, "Embedding")


def get_vae_path(source: str, id_or_url: str, civitai_key: str, progress) -> tuple[str | None, str]:
    """Return ``(local_path, status)`` for a Civitai VAE, downloading it into ``VAE_DIR`` if needed."""
    return _get_civitai_typed_path(source, id_or_url, civitai_key, progress, VAE_DIR, "VAE")


def _ensure_model_downloaded(display_name: str, progress=gr.Progress()):
    """Make sure the file for *display_name* (from ``ALL_MODEL_MAP``) is in its model directory.

    The destination directory is chosen by the entry's model type. Files whose
    ``ALL_FILE_DOWNLOAD_MAP`` source is ``"hf"`` are fetched into the HF cache
    and symlinked; ``"civitai"`` files are downloaded directly. A dangling
    symlink at the destination is removed first.

    Returns:
        The destination base filename.

    Raises:
        ValueError: unknown display name or model type.
        gr.Error: no download info for the file, or the download/link failed.
    """
    if display_name not in ALL_MODEL_MAP:
        raise ValueError(f"Model '{display_name}' not found in configuration.")
    _, repo_filename, model_type, _ = ALL_MODEL_MAP[display_name]
    type_to_dir_map = {
        "SDXL": CHECKPOINT_DIR,
        "SD1.5": CHECKPOINT_DIR,
        "UNET": DIFFUSION_MODELS_DIR,
        "VAE": VAE_DIR,
        "TEXT_ENCODER": TEXT_ENCODERS_DIR,
        "LORA": LORA_DIR,
        "IPADAPTER": os.path.join(os.path.dirname(LORA_DIR), "ipadapter"),
        "CLIP_VISION": os.path.join(os.path.dirname(LORA_DIR), "clip_vision"),
    }
    dest_dir = type_to_dir_map.get(model_type)
    if not dest_dir:
        raise ValueError(f"Unknown model type '{model_type}' for '{display_name}'.")

    base_filename = os.path.basename(repo_filename)
    dest_path = os.path.join(dest_dir, base_filename)

    if os.path.lexists(dest_path):
        if os.path.exists(dest_path):
            return base_filename
        # lexists but not exists => dangling symlink; remove it so it can be recreated.
        print(f"⚠️ Found and removed broken symlink: {dest_path}")
        os.remove(dest_path)

    download_info = ALL_FILE_DOWNLOAD_MAP.get(base_filename)
    if not download_info:
        raise gr.Error(f"Model '{base_filename}' not found in file_list.yaml. Cannot download.")

    source = download_info.get("source")
    try:
        progress(0, desc=f"Downloading: {base_filename}")
        if source == "hf":
            repo_id = download_info.get("repo_id")
            hf_filename = download_info.get("repository_file_path", base_filename)
            if not repo_id:
                raise ValueError(f"repo_id is missing for HF model '{base_filename}'")
            cached_path = hf_hub_download(repo_id=repo_id, filename=hf_filename)
            os.makedirs(dest_dir, exist_ok=True)
            os.symlink(cached_path, dest_path)
            print(f"✅ Symlinked '{cached_path}' to '{dest_path}'")
        elif source == "civitai":
            model_version_id = download_info.get("model_version_id")
            if not model_version_id:
                raise ValueError(f"model_version_id is missing for Civitai model '{base_filename}'")
            file_info = get_civitai_file_info(model_version_id)
            if not file_info or not file_info.get('downloadUrl'):
                raise ConnectionError(f"Could not get download URL for Civitai model version ID {model_version_id}")
            os.makedirs(dest_dir, exist_ok=True)
            status = download_file(
                file_info['downloadUrl'],
                dest_path,
                progress=progress,
                desc=f"Downloading: {base_filename}",
            )
            # download_file reports failure as "Download failed for ..." (lower-case "failed").
            if status.startswith("Download failed"):
                raise ConnectionError(status)
        else:
            raise NotImplementedError(f"Download source '{source}' is not implemented for '{base_filename}'")
        progress(1.0, desc=f"Downloaded: {base_filename}")
    except Exception as e:
        # Don't leave a partial file or half-made symlink behind.
        if os.path.lexists(dest_path):
            try:
                os.remove(dest_path)
            except OSError:
                pass
        raise gr.Error(f"Failed to download and link '{display_name}': {e}") from e
    return base_filename


def ensure_controlnet_model_downloaded(filename: str, progress):
    """Make sure ControlNet model *filename* is present in ``CONTROLNET_DIR``.

    No-op for an empty name or the literal ``"None"``. Uses the same
    ``ALL_FILE_DOWNLOAD_MAP`` sources as ``_ensure_model_downloaded`` (``"hf"`` is
    symlinked from the HF cache, ``"civitai"`` is downloaded directly).

    Raises:
        gr.Error: no download info for the file, or the download/link failed.
    """
    if not filename or filename == "None":
        return
    dest_path = os.path.join(CONTROLNET_DIR, filename)
    if os.path.exists(dest_path):
        return
    if os.path.lexists(dest_path):
        # Dangling symlink (e.g. its cache file was pruned); os.symlink would
        # otherwise fail with FileExistsError.
        print(f"⚠️ Found and removed broken symlink: {dest_path}")
        os.remove(dest_path)

    download_info = ALL_FILE_DOWNLOAD_MAP.get(filename)
    if not download_info:
        raise gr.Error(f"ControlNet model '{filename}' not found in configuration (file_list.yaml). Cannot download.")

    source = download_info.get("source")
    try:
        if source == "hf":
            repo_id = download_info.get("repo_id")
            repo_filename = download_info.get("repository_file_path", filename)
            if not repo_id:
                raise ValueError("repo_id is missing for Hugging Face download.")
            progress(0, desc=f"Downloading CN: {filename}")
            cached_path = hf_hub_download(repo_id=repo_id, filename=repo_filename)
            os.makedirs(CONTROLNET_DIR, exist_ok=True)
            os.symlink(cached_path, dest_path)
            print(f"✅ Symlinked ControlNet '{cached_path}' to '{dest_path}'")
            progress(1.0, desc=f"Downloaded CN: {filename}")
        elif source == "civitai":
            model_version_id = download_info.get("model_version_id")
            if not model_version_id:
                raise ValueError("model_version_id is missing for Civitai download.")
            file_info = get_civitai_file_info(model_version_id)
            if not file_info or not file_info.get('downloadUrl'):
                raise ConnectionError(f"Could not get download URL for Civitai model version ID {model_version_id}")
            os.makedirs(CONTROLNET_DIR, exist_ok=True)
            status = download_file(
                file_info['downloadUrl'],
                dest_path,
                progress=progress,
                desc=f"Downloading CN: {filename}",
            )
            # download_file reports failure as "Download failed for ..." (lower-case "failed").
            if status.startswith("Download failed"):
                raise ConnectionError(status)
        else:
            raise NotImplementedError(f"Download source '{source}' is not implemented for ControlNets.")
    except Exception as e:
        if os.path.lexists(dest_path):
            try:
                os.remove(dest_path)
            except OSError:
                pass
        raise gr.Error(f"Failed to download ControlNet model '{filename}': {e}") from e


def load_ipadapter_presets():
    """Load ``<project_root>/yaml/ipadapter.yaml`` into ``IPADAPTER_PRESETS`` (once).

    The YAML is expected to be a list of mappings, each with a ``preset_name``
    key. The result is keyed by that name. On any error the failure is reported
    and ``IPADAPTER_PRESETS`` is set to ``{}``.
    """
    global IPADAPTER_PRESETS
    if IPADAPTER_PRESETS is not None:
        return
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    ipadapter_list_path = os.path.join(project_root, 'yaml', 'ipadapter.yaml')
    try:
        # `yaml` had no visible import in this module; import PyYAML explicitly.
        import yaml
        with open(ipadapter_list_path, 'r', encoding='utf-8') as f:
            # safe_load returns None for an empty document; treat that as no presets.
            presets_list = yaml.safe_load(f) or []
        IPADAPTER_PRESETS = {item['preset_name']: item for item in presets_list}
        print("✅ IPAdapter presets loaded successfully.")
    except Exception as e:
        print(f"❌ FATAL: Could not load or parse ipadapter.yaml. IPAdapter will not work. Error: {e}")
        IPADAPTER_PRESETS = {}


def ensure_ipadapter_models_downloaded(preset_name: str, progress):
    """Best-effort download of the vision, IPAdapter and LoRA files named by an IPAdapter preset.

    Each file is registered in ``ALL_MODEL_MAP`` under a synthetic display name
    so ``_ensure_model_downloaded`` can place it. Per-file failures are reported,
    not raised.

    Raises:
        RuntimeError: if ``load_ipadapter_presets`` has not been called.
    """
    if not preset_name:
        return
    if IPADAPTER_PRESETS is None:
        raise RuntimeError("IPAdapter presets have not been loaded. `load_ipadapter_presets` must be called on startup.")
    preset_info = IPADAPTER_PRESETS.get(preset_name)
    if not preset_info:
        print(f"⚠️ Warning: IPAdapter preset '{preset_name}' not found in configuration. Skipping download.")
        return

    # A list of pairs rather than a dict keyed by filename, so equal filenames can't collide.
    model_files_to_check = (
        (preset_info.get('vision_model'), 'CLIP_VISION'),
        (preset_info.get('ipadapter_model'), 'IPADAPTER'),
        (preset_info.get('lora_model'), 'LORA'),
    )
    for filename, model_type in model_files_to_check:
        if not filename:
            continue
        temp_display_name = f"ipadapter_asset_{filename}"
        if temp_display_name not in ALL_MODEL_MAP:
            ALL_MODEL_MAP[temp_display_name] = (None, filename, model_type, None)
        try:
            _ensure_model_downloaded(temp_display_name, progress)
        except Exception as e:
            print(f"❌ Error ensuring download for IPAdapter asset '{filename}': {e}")


def _split_generation_parameters(params_text: str) -> tuple[str, str, list[str]]:
    """Split generation-parameter text into ``(prompt, negative_prompt, settings_lines)``.

    Line 0 is the prompt. If line 1 starts with ``"Negative prompt:"`` it holds
    the negative prompt and the settings start at line 2. Otherwise there is no
    negative prompt and the settings start at line 1.
    """
    lines = params_text.strip().split('\n')
    prompt = lines[0]
    if len(lines) > 1 and lines[1].startswith(_NEGATIVE_PROMPT_PREFIX):
        return prompt, lines[1][len(_NEGATIVE_PROMPT_PREFIX):].strip(), lines[2:]
    return prompt, "", lines[1:]


def parse_parameters(params_text: str) -> dict:
    """Parse generation-parameter text into a settings dict.

    Keys: prompt, negative_prompt, steps, sampler, scheduler, cfg_scale, seed,
    clip_skip, base_model, model_hash, width, height. Missing or unparsable
    values fall back to the defaults below.
    """
    data = {}
    prompt, negative_prompt, settings_lines = _split_generation_parameters(params_text)
    data['prompt'] = prompt
    data['negative_prompt'] = negative_prompt
    params_line = '\n'.join(settings_lines)

    def find_param(key, default, cast_type=str):
        # Value runs from "<key>: " to the next comma, newline or end of text.
        match = re.search(fr"\b{re.escape(key)}: ([^,]+?)(,|$|\n)", params_line)
        if not match:
            return default
        try:
            return cast_type(match.group(1).strip())
        except (TypeError, ValueError):
            return default

    data['steps'] = find_param("Steps", 28, int)
    data['sampler'] = find_param("Sampler", "euler", str)
    data['scheduler'] = find_param("Scheduler", "normal", str)
    data['cfg_scale'] = find_param("CFG scale", 7.5, float)
    data['seed'] = find_param("Seed", -1, int)
    data['clip_skip'] = find_param("Clip skip", 1, int)
    # Default to the first configured model (or "" if none are configured).
    data['base_model'] = find_param("Base Model", next(iter(ALL_MODEL_MAP), ""), str)
    data['model_hash'] = find_param("Model hash", None, str)

    size_match = re.search(r"Size: (\d+)x(\d+)", params_line)
    data['width'], data['height'] = (int(size_match.group(1)), int(size_match.group(2))) if size_match else (1024, 1024)
    return data


def get_png_info(image) -> tuple[str, str, str]:
    """Return ``(prompt, negative_prompt, other_params_text)`` from an image's ``parameters`` info.

    ``other_params_text`` is the settings text with each comma-separated item on
    its own line. Without metadata, returns two empty strings and a message.
    """
    if not image or not (params := image.info.get('parameters')):
        return "", "", "No metadata found in the image."
    parsed_data = parse_parameters(params)
    _, _, settings_lines = _split_generation_parameters(params)
    raw_params_list = '\n'.join(settings_lines).split(',')
    other_params_text = "\n".join(p.strip() for p in raw_params_list)
    return parsed_data.get('prompt', ''), parsed_data.get('negative_prompt', ''), other_params_text


def _guess_annotator_repo(model_filename: str) -> str:
    """Pick the HF repo a preprocessor model file is assumed to live in, by filename."""
    if "depth_anything" in model_filename and "v2" in model_filename:
        return "LiheYoung/Depth-Anything-V2"
    if "depth_anything" in model_filename:
        return "LiheYoung/Depth-Anything"
    if "diffusion_edge" in model_filename:
        return "hr16/Diffusion-Edge"
    return "lllyasviel/Annotators"


def build_preprocessor_model_map():
    """Build (once) a map from preprocessor display name to ``[(repo_id, filename), ...]``.

    Scans the comfyui_controlnet_aux node-wrapper sources with regexes.
    ``from_pretrained`` filenames are mapped to repos by ``_guess_annotator_repo``.
    Entries from a hard-coded table are added for the dwpose/densepose wrappers.

    Returns:
        The cached ``PREPROCESSOR_MODEL_MAP`` (``{}`` if the wrappers directory is missing).
    """
    global PREPROCESSOR_MODEL_MAP
    if PREPROCESSOR_MODEL_MAP is not None:
        return PREPROCESSOR_MODEL_MAP
    print("--- Building ControlNet Preprocessor model map ---")
    manual_map = {
        "dwpose": [
            ("yzd-v/DWPose", "yolox_l.onnx"),
            ("yzd-v/DWPose", "dw-ll_ucoco_384.onnx"),
            ("hr16/UnJIT-DWPose", "dw-ll_ucoco.onnx"),
            ("hr16/DWPose-TorchScript-BatchSize5", "dw-ll_ucoco_384_bs5.torchscript.pt"),
            ("hr16/DWPose-TorchScript-BatchSize5", "rtmpose-m_ap10k_256_bs5.torchscript.pt"),
            ("hr16/yolo-nas-fp16", "yolo_nas_l_fp16.onnx"),
            ("hr16/yolo-nas-fp16", "yolo_nas_m_fp16.onnx"),
            ("hr16/yolo-nas-fp16", "yolo_nas_s_fp16.onnx"),
        ],
        "densepose": [
            ("LayerNorm/DensePose-TorchScript-with-hint-image", "densepose_r50_fpn_dl.torchscript"),
            ("LayerNorm/DensePose-TorchScript-with-hint-image", "densepose_r101_fpn_dl.torchscript"),
        ],
    }
    temp_map = {}
    # NOTE(review): the imported name is unused here; the import is kept in case
    # importing ComfyUI's `nodes` module is required for its side effects — confirm.
    from nodes import NODE_DISPLAY_NAME_MAPPINGS  # noqa: F401
    wrappers_dir = Path("./custom_nodes/comfyui_controlnet_aux/node_wrappers/")
    if not wrappers_dir.exists():
        print("⚠️ ControlNet AUX wrappers directory not found. Cannot build model map.")
        PREPROCESSOR_MODEL_MAP = {}
        return PREPROCESSOR_MODEL_MAP

    for wrapper_file in wrappers_dir.glob("*.py"):
        if wrapper_file.name == "__init__.py":
            continue
        content = wrapper_file.read_text(encoding='utf-8')
        # NOTE(review): each match must start at "NODE_DISPLAY_NAME_MAPPINGS = {",
        # so only the first key/value pair of each mapping literal is captured.
        display_name_matches = re.findall(r'NODE_DISPLAY_NAME_MAPPINGS\s*=\s*{(?:.|\n)*?["\'](.*?)["\']\s*:\s*["\'](.*?)["\']', content)
        # Loop-invariant per file: hoisted out of the display-name loop.
        model_filenames = re.findall(r"from_pretrained\s*\(\s*(?:filename=)?\s*f?[\"']([^\"']+)[\"']", content)
        manual_entries = manual_map.get(wrapper_file.stem, [])
        for _, display_name in display_name_matches:
            models = temp_map.setdefault(display_name, [])
            models.extend(manual_entries)
            models.extend((_guess_annotator_repo(name), name) for name in model_filenames)

    PREPROCESSOR_MODEL_MAP = {name: sorted(set(models)) for name, models in temp_map.items() if models}
    print("✅ ControlNet Preprocessor model map built.")
    return PREPROCESSOR_MODEL_MAP


def build_preprocessor_parameter_map():
    """Build (once) a map from preprocessor display name to its tunable inputs.

    Inspects ``INPUT_TYPES()`` of every registered node class whose module is
    under ``comfyui_controlnet_aux.node_wrappers``. Each entry is a list of
    ``{"name", "type", "config"}`` dicts; image, resolution and pose_kps inputs
    are excluded.

    Returns:
        The cached ``PREPROCESSOR_PARAMETER_MAP``. Earlier versions returned
        ``None``; callers that ignore the result are unaffected.
    """
    global PREPROCESSOR_PARAMETER_MAP
    if PREPROCESSOR_PARAMETER_MAP is not None:
        return PREPROCESSOR_PARAMETER_MAP
    print("--- Building ControlNet Preprocessor parameter map ---")
    param_map = {}
    from nodes import NODE_CLASS_MAPPINGS, NODE_DISPLAY_NAME_MAPPINGS
    skipped_inputs = {'image', 'resolution', 'pose_kps'}
    for class_name, node_class in NODE_CLASS_MAPPINGS.items():
        if not hasattr(node_class, "INPUT_TYPES"):
            continue
        if hasattr(node_class, '__module__') and 'comfyui_controlnet_aux.node_wrappers' not in node_class.__module__:
            continue
        display_name = NODE_DISPLAY_NAME_MAPPINGS.get(class_name)
        if not display_name:
            continue
        try:
            input_types = node_class.INPUT_TYPES()
            all_inputs = {**input_types.get('required', {}), **input_types.get('optional', {})}
            params = []
            for name, details in all_inputs.items():
                if name in skipped_inputs:
                    continue
                if not isinstance(details, (list, tuple)) or not details:
                    continue
                param_type = details[0]
                param_config = details[1] if len(details) > 1 and isinstance(details[1], dict) else {}
                params.append({"name": name, "type": param_type, "config": param_config})
            if params:
                param_map[display_name] = params
        except Exception as e:
            print(f"⚠️ Could not parse parameters for {display_name}: {e}")
    PREPROCESSOR_PARAMETER_MAP = param_map
    print("✅ ControlNet Preprocessor parameter map built.")
    return PREPROCESSOR_PARAMETER_MAP


def print_welcome_message():
    """Print the author / license banner to stdout."""
    author_name = "RioShiina"
    project_url = "https://huggingface.co/RioShiina"
    border = "=" * 72
    message = (
        f"\n{border}\n\n"
        f" Thank you for using this project!\n\n"
        f" **Author:** {author_name}\n"
        f" **Find more from the author:** {project_url}\n\n"
        f" This project is open-source under the GNU General Public License v3.0 (GPL-3.0).\n"
        f" As it's built upon GPL-3.0 components (like ComfyUI), any modifications you\n"
        f" distribute must also be open-sourced under the same license.\n\n"
        f" Your respect for the principles of free software is greatly appreciated!\n\n"
        f"{border}\n"
    )
    print(message)