Spaces:
Sleeping
Sleeping
| import gc | |
| from typing import Dict, List, Any, Set | |
| import torch | |
| import gradio as gr | |
| from comfy import model_management | |
| from core.settings import ALL_MODEL_MAP, CHECKPOINT_DIR, LORA_DIR, DIFFUSION_MODELS_DIR, VAE_DIR, TEXT_ENCODERS_DIR | |
| from comfy_integration.nodes import checkpointloadersimple, LoraLoader | |
| from nodes import NODE_CLASS_MAPPINGS | |
| from utils.app_utils import get_value_at_index, _ensure_model_downloaded | |
class ModelManager:
    """Singleton cache of loaded models, keyed by their display name.

    Every call to ``ModelManager()`` returns the same object. Loaded model
    tuples are kept in ``loaded_models`` so repeated requests for the same
    set of models do not reload them.
    """

    _instance = None
    # True while ``loaded_models`` holds at least one checkpoint that had
    # LoRAs merged into it. Class-level default so the attribute always exists.
    _loras_merged = False

    def __new__(cls, *args, **kwargs):
        """Return the shared instance, creating it on first use."""
        if cls._instance is None:
            # object.__new__ rejects extra arguments, so none are forwarded.
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        """Initialise the shared state once; later constructions are no-ops."""
        # __init__ runs on every ModelManager() call even though __new__
        # returns the same object, so guard against wiping the cache.
        if hasattr(self, 'initialized'):
            return
        self.loaded_models: Dict[str, Any] = {}
        self._loras_merged = False
        self.initialized = True
        print("β ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
        """Return the display names of all currently cached models."""
        return set(self.loaded_models)

    def _load_single_model(self, display_name: str, progress) -> Any:
        """Download (if needed) and load one model.

        Args:
            display_name: Key into ``ALL_MODEL_MAP``; its third field is the
                model type that selects the loader.
            progress: Forwarded unchanged to ``_ensure_model_downloaded``.

        Returns:
            Whatever the selected loader method returns, or ``(filename,)``
            for a model of type ``"LORA"``.

        Raises:
            ValueError: If no loader is configured for the model type.
        """
        print(f"--- [ModelManager] Loading model: '{display_name}' ---")
        filename = _ensure_model_downloaded(display_name, progress)
        _, _, model_type, _ = ALL_MODEL_MAP[display_name]

        if model_type == "LORA":
            print(f"--- [ModelManager] β '{display_name}' is a LoRA. It will be loaded dynamically. ---")
            return (filename,)

        # Factories are evaluated lazily so that only the loader node that is
        # actually needed gets looked up and instantiated.
        def _checkpoint():
            return checkpointloadersimple, "load_checkpoint", {"ckpt_name": filename}

        loader_factories = {
            "SDXL": _checkpoint,
            "SD1.5": _checkpoint,
            "UNET": lambda: (
                NODE_CLASS_MAPPINGS["UNETLoader"](),
                "load_unet",
                {"unet_name": filename, "weight_dtype": "default"},
            ),
            "VAE": lambda: (
                NODE_CLASS_MAPPINGS["VAELoader"](),
                "load_vae",
                {"vae_name": filename},
            ),
            "TEXT_ENCODER": lambda: (
                NODE_CLASS_MAPPINGS["CLIPLoader"](),
                "load_clip",
                {"clip_name": filename, "type": "wan", "device": "default"},
            ),
        }

        factory = loader_factories.get(model_type)
        if factory is None:
            raise ValueError(f"[ModelManager] No loader configured for model type '{model_type}'")

        loader_instance, method_name, kwargs = factory()
        loaded_tuple = getattr(loader_instance, method_name)(**kwargs)
        print(f"--- [ModelManager] β Successfully loaded '{display_name}' to CPU/RAM ---")
        return loaded_tuple

    def _apply_loras(self, checkpoint_tuple: Any, active_loras: List[Dict[str, Any]]) -> Any:
        """Merge every LoRA in ``active_loras`` into a loaded checkpoint tuple.

        The first two elements of ``checkpoint_tuple`` are passed to
        ``LoraLoader.load_lora`` as ``model`` and ``clip``; the third element
        is carried over untouched.
        """
        print(f"--- [ModelManager] Applying {len(active_loras)} LoRAs on CPU... ---")
        lora_loader = LoraLoader()
        patched_model, patched_clip = checkpoint_tuple[0], checkpoint_tuple[1]
        for lora_info in active_loras:
            patched_model, patched_clip = lora_loader.load_lora(
                model=patched_model,
                clip=patched_clip,
                lora_name=lora_info["lora_name"],
                strength_model=lora_info["strength_model"],
                strength_clip=lora_info["strength_clip"],
            )
        print("--- [ModelManager] β All LoRAs merged into the model on CPU. ---")
        return (patched_model, patched_clip, checkpoint_tuple[2])

    def move_models_to_gpu(self, required_models: List[str]):
        """Hand the cached SDXL / SD1.5 checkpoints to ``load_models_gpu``.

        Names that are not cached, and cached models of any other type, are
        ignored.
        """
        print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
        models_to_load_gpu = []
        for name in required_models:
            if name not in self.loaded_models:
                continue
            _, _, model_type, _ = ALL_MODEL_MAP[name]
            if model_type in ("SDXL", "SD1.5"):
                models_to_load_gpu.append(get_value_at_index(self.loaded_models[name], 0))

        if models_to_load_gpu:
            model_management.load_models_gpu(models_to_load_gpu)
            print("--- [ModelManager] β Models successfully moved to GPU. ---")
        else:
            print("--- [ModelManager] β οΈ No checkpoint models found to move to GPU. ---")

    def ensure_models_downloaded(self, required_models: List[str], progress):
        """Make sure every model in ``required_models`` has been downloaded.

        Args:
            required_models: Display names to check.
            progress: Optional callable invoked as ``progress(fraction, desc=...)``;
                skipped when it is not callable (e.g. ``None``).

        Raises:
            gr.Error: If ``_ensure_model_downloaded`` fails for any model.
        """
        print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")
        total = len(required_models)
        for i, display_name in enumerate(required_models):
            if callable(progress):
                progress(i / total, desc=f"Checking file: {display_name}")
            try:
                _ensure_model_downloaded(display_name, progress)
            except Exception as e:
                raise gr.Error(f"Failed to download model '{display_name}'. Reason: {e}") from e
        print("--- [ModelManager] β All required models are present on disk. ---")

    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
        """Bring the cache in line with ``required_models`` and return them.

        Everything is unloaded and reloaded when any cached model is no longer
        required, when LoRAs are requested, or when the cache still holds
        models that had LoRAs merged by an earlier call (otherwise a request
        without LoRAs would silently get the patched models back).

        Args:
            required_models: Display names that must be loaded.
            active_loras: Dicts with ``lora_name``, ``strength_model`` and
                ``strength_clip``; ``None`` is treated as an empty list.
            progress: Optional callable invoked as ``progress(fraction, desc=...)``;
                skipped when it is not callable.

        Returns:
            Mapping of each required display name to its loaded data.

        Raises:
            gr.Error: If loading a model or merging LoRAs into it fails.
        """
        active_loras = active_loras or []
        required_set = set(required_models)
        current_set = set(self.loaded_models)
        models_to_unload = current_set - required_set

        loras_changed = bool(active_loras) or self._loras_merged
        if models_to_unload or loras_changed:
            if models_to_unload:
                print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
            else:
                models_to_unload = current_set & required_set
                print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")
            model_management.unload_all_models()
            self.loaded_models.clear()
            self._loras_merged = False
            gc.collect()
            torch.cuda.empty_cache()
            print("--- [ModelManager] All models unloaded to free RAM. ---")

        # De-duplicate while keeping the caller's order so the load sequence
        # (and the progress messages) are deterministic.
        models_to_load = [
            name for name in dict.fromkeys(required_models)
            if name not in self.loaded_models
        ]

        if models_to_load:
            print(f"--- [ModelManager] Models to load: {models_to_load} ---")
            total = len(models_to_load)
            for i, display_name in enumerate(models_to_load):
                if callable(progress):
                    progress(i / total, desc=f"Loading model: {display_name}")
                try:
                    loaded_model_data = self._load_single_model(display_name, progress)
                    if active_loras and ALL_MODEL_MAP[display_name][2] in ("SDXL", "SD1.5"):
                        loaded_model_data = self._apply_loras(loaded_model_data, active_loras)
                        self._loras_merged = True
                    self.loaded_models[display_name] = loaded_model_data
                except Exception as e:
                    raise gr.Error(f"Failed to load model or apply LoRA '{display_name}'. Reason: {e}") from e
        else:
            print("--- [ModelManager] All required models are already loaded. ---")

        return {name: self.loaded_models[name] for name in required_models}


# Shared module-level instance; ModelManager() elsewhere returns this same object.
model_manager = ModelManager()