import gc
from typing import Dict, List, Any, Set

import torch
import gradio as gr

from comfy import model_management
from core.settings import ALL_MODEL_MAP, CHECKPOINT_DIR, LORA_DIR, DIFFUSION_MODELS_DIR, VAE_DIR, TEXT_ENCODERS_DIR
from comfy_integration.nodes import checkpointloadersimple, LoraLoader
from nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index, _ensure_model_downloaded


class ModelManager:
    """Singleton that downloads, caches, and moves models between CPU/RAM and GPU."""

    _instance = None

    def __new__(cls, *args, **kwargs):
        if not cls._instance:
            # object.__new__ must be called with the class only; forwarding *args/**kwargs
            # raises a TypeError on Python 3.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Guard against re-initialization: __init__ runs on every ModelManager() call,
        # even though __new__ always returns the same instance.
        if hasattr(self, 'initialized'):
            return
        self.loaded_models: Dict[str, Any] = {}
        self.initialized = True
        print("✅ ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
        return set(self.loaded_models.keys())

    def _load_single_model(self, display_name: str, progress) -> Any:
        """Download (if needed) and load one model into CPU/RAM, dispatching on its type."""
        print(f"--- [ModelManager] Loading model: '{display_name}' ---")
        filename = _ensure_model_downloaded(display_name, progress)
        _, _, model_type, _ = ALL_MODEL_MAP[display_name]

        # Map each model type to (loader instance, method name, kwargs).
        loader_map = {
            "SDXL": (checkpointloadersimple, "load_checkpoint", {"ckpt_name": filename}),
            "SD1.5": (checkpointloadersimple, "load_checkpoint", {"ckpt_name": filename}),
            "UNET": (NODE_CLASS_MAPPINGS["UNETLoader"](), "load_unet", {"unet_name": filename, "weight_dtype": "default"}),
            "VAE": (NODE_CLASS_MAPPINGS["VAELoader"](), "load_vae", {"vae_name": filename}),
            "TEXT_ENCODER": (NODE_CLASS_MAPPINGS["CLIPLoader"](), "load_clip", {"clip_name": filename, "type": "wan", "device": "default"}),
        }

        if model_type not in loader_map:
            if model_type == "LORA":
                # LoRAs are not loaded here; they are merged into the base model later.
                print(f"--- [ModelManager] ✅ '{display_name}' is a LoRA. It will be loaded dynamically. ---")
                return (filename,)
            raise ValueError(f"[ModelManager] No loader configured for model type '{model_type}'")

        loader_instance, method_name, kwargs = loader_map[model_type]
        load_method = getattr(loader_instance, method_name)
        loaded_tuple = load_method(**kwargs)
        print(f"--- [ModelManager] ✅ Successfully loaded '{display_name}' to CPU/RAM ---")
        return loaded_tuple

    def move_models_to_gpu(self, required_models: List[str]):
        """Push already-loaded checkpoint models (SDXL / SD1.5) onto the GPU."""
        print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
        models_to_load_gpu = []
        for name in required_models:
            if name in self.loaded_models:
                model_tuple = self.loaded_models[name]
                _, _, model_type, _ = ALL_MODEL_MAP[name]
                if model_type in ["SDXL", "SD1.5"]:
                    models_to_load_gpu.append(get_value_at_index(model_tuple, 0))

        if models_to_load_gpu:
            model_management.load_models_gpu(models_to_load_gpu)
            print("--- [ModelManager] ✅ Models successfully moved to GPU. ---")
        else:
            print("--- [ModelManager] ⚠️ No checkpoint models found to move to GPU. ---")

    def ensure_models_downloaded(self, required_models: List[str], progress):
        """Make sure every required model file exists on disk, downloading if necessary."""
        print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")
        for i, display_name in enumerate(required_models):
            if callable(progress):
                progress(i / len(required_models), desc=f"Checking file: {display_name}")
            try:
                _ensure_model_downloaded(display_name, progress)
            except Exception as e:
                raise gr.Error(f"Failed to download model '{display_name}'. Reason: {e}")
        print("--- [ModelManager] ✅ All required models are present on disk. ---")

    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Load the requested models into CPU/RAM, unloading stale ones and re-applying LoRAs as needed."""
        required_set = set(required_models)
        current_set = set(self.loaded_models.keys())
        loras_changed = len(active_loras) > 0 or len(current_set - required_set) > 0
        models_to_unload = current_set - required_set

        if models_to_unload or loras_changed:
            if models_to_unload:
                print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
            if loras_changed and not models_to_unload:
                # LoRAs are merged into the base weights, so a LoRA change forces a reload
                # of the base model(s) even when the required set itself did not change.
                models_to_unload = current_set.intersection(required_set)
                print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")

            model_management.unload_all_models()
            self.loaded_models.clear()
            gc.collect()
            torch.cuda.empty_cache()
            print("--- [ModelManager] All models unloaded to free RAM. ---")

        models_to_load = required_set if (models_to_unload or loras_changed) else required_set - current_set

        if models_to_load:
            print(f"--- [ModelManager] Models to load: {models_to_load} ---")
            for i, display_name in enumerate(models_to_load):
                progress(i / len(models_to_load), desc=f"Loading model: {display_name}")
                try:
                    loaded_model_data = self._load_single_model(display_name, progress)

                    # Merge any active LoRAs into checkpoint models while they are still on CPU.
                    if active_loras and ALL_MODEL_MAP[display_name][2] in ["SDXL", "SD1.5"]:
                        print(f"--- [ModelManager] Applying {len(active_loras)} LoRAs on CPU... ---")
                        lora_loader = LoraLoader()
                        patched_model, patched_clip = loaded_model_data[0], loaded_model_data[1]
                        for lora_info in active_loras:
                            patched_model, patched_clip = lora_loader.load_lora(
                                model=patched_model,
                                clip=patched_clip,
                                lora_name=lora_info["lora_name"],
                                strength_model=lora_info["strength_model"],
                                strength_clip=lora_info["strength_clip"],
                            )
                        loaded_model_data = (patched_model, patched_clip, loaded_model_data[2])
                        print("--- [ModelManager] ✅ All LoRAs merged into the model on CPU. ---")

                    self.loaded_models[display_name] = loaded_model_data
                except Exception as e:
                    raise gr.Error(f"Failed to load model or apply LoRA '{display_name}'. Reason: {e}")
        else:
            print("--- [ModelManager] All required models are already loaded. ---")

        return {name: self.loaded_models[name] for name in required_models}


# Module-level singleton; importing this name anywhere returns the same instance.
model_manager = ModelManager()
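

# Example usage (a minimal sketch, kept commented out so importing this module has no
# side effects). The display name and LoRA entry below are hypothetical and would need
# to exist as keys/files in ALL_MODEL_MAP and the LoRA directory to actually run:
#
# if __name__ == "__main__":
#     def _console_progress(fraction, desc=""):
#         # Stand-in for a gr.Progress callback; just prints progress to stdout.
#         print(f"[{fraction:.0%}] {desc}")
#
#     required = ["SDXL Base 1.0"]  # hypothetical display name
#     loras = [{"lora_name": "example.safetensors",  # hypothetical LoRA entry
#               "strength_model": 0.8, "strength_clip": 0.8}]
#
#     model_manager.ensure_models_downloaded(required, _console_progress)
#     loaded = model_manager.load_managed_models(required, loras, _console_progress)
#     model_manager.move_models_to_gpu(required)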