# ImageGen-Illstrious/core/model_manager.py
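"""Singleton model manager.

Ensures model files are downloaded, loads them into CPU RAM through ComfyUI
loader nodes, merges active LoRA patches into checkpoints, and hands
checkpoints to ComfyUI's model_management for GPU placement.
"""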
import gc
from typing import Any, Dict, List, Set

import torch
import gradio as gr

from comfy import model_management
from core.settings import ALL_MODEL_MAP, CHECKPOINT_DIR, LORA_DIR, DIFFUSION_MODELS_DIR, VAE_DIR, TEXT_ENCODERS_DIR
from comfy_integration.nodes import checkpointloadersimple, LoraLoader
from nodes import NODE_CLASS_MAPPINGS
from utils.app_utils import get_value_at_index, _ensure_model_downloaded


class ModelManager:
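    """Process-wide singleton that owns the lifecycle of all loaded models."""
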
_instance = None

    def __new__(cls, *args, **kwargs):
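        # Classic singleton: reuse the cached instance. object.__new__
        # accepts no extra arguments in Python 3, so *args/**kwargs are
        # deliberately not forwarded.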
        if cls._instance is None:
            cls._instance = super().__new__(cls)
return cls._instance

    def __init__(self):
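        # ModelManager() always returns the same instance, and __init__ runs
        # on every call, so bail out if this instance is already set up.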
if hasattr(self, 'initialized'):
return
self.loaded_models: Dict[str, Any] = {}
self.initialized = True
print("βœ… ModelManager initialized.")

    def get_loaded_model_names(self) -> Set[str]:
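        """Return the display names of all models currently held in RAM."""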
return set(self.loaded_models.keys())

    def _load_single_model(self, display_name: str, progress) -> Any:
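        """Download 'display_name' if missing, then load it with the matching
        ComfyUI loader node.

        Returns the loader's output tuple; LoRA entries return a bare
        (filename,) and are applied to a base model later.
        """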
print(f"--- [ModelManager] Loading model: '{display_name}' ---")
filename = _ensure_model_downloaded(display_name, progress)
_, _, model_type, _ = ALL_MODEL_MAP[display_name]
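        # Map each model type to (loader node instance, method name, kwargs).
        # All loader nodes are instantiated up front, but only the entry
        # matching model_type is invoked.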
loader_map = {
"SDXL": (checkpointloadersimple, "load_checkpoint", {"ckpt_name": filename}),
"SD1.5": (checkpointloadersimple, "load_checkpoint", {"ckpt_name": filename}),
"UNET": (NODE_CLASS_MAPPINGS["UNETLoader"](), "load_unet", {"unet_name": filename, "weight_dtype": "default"}),
"VAE": (NODE_CLASS_MAPPINGS["VAELoader"](), "load_vae", {"vae_name": filename}),
"TEXT_ENCODER": (NODE_CLASS_MAPPINGS["CLIPLoader"](), "load_clip", {"clip_name": filename, "type": "wan", "device": "default"}),
}
        if model_type == "LORA":
            print(f"--- [ModelManager] βœ… '{display_name}' is a LoRA. It will be loaded dynamically. ---")
            return (filename,)
        if model_type not in loader_map:
            raise ValueError(f"[ModelManager] No loader configured for model type '{model_type}'")
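        # Invoke the selected loader; weights land in CPU RAM first and are
        # moved to the GPU later via move_models_to_gpu().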
loader_instance, method_name, kwargs = loader_map[model_type]
load_method = getattr(loader_instance, method_name)
loaded_tuple = load_method(**kwargs)
print(f"--- [ModelManager] βœ… Successfully loaded '{display_name}' to CPU/RAM ---")
return loaded_tuple

    def move_models_to_gpu(self, required_models: List[str]):
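        """Hand loaded checkpoint models to ComfyUI for placement on the GPU."""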
print(f"--- [ModelManager] Moving models to GPU: {required_models} ---")
models_to_load_gpu = []
for name in required_models:
if name in self.loaded_models:
model_tuple = self.loaded_models[name]
_, _, model_type, _ = ALL_MODEL_MAP[name]
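                # Only full checkpoints expose a model patcher at index 0 for
                # ComfyUI's load_models_gpu(); other component types are not
                # moved here.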
if model_type in ["SDXL", "SD1.5"]:
models_to_load_gpu.append(get_value_at_index(model_tuple, 0))
if models_to_load_gpu:
model_management.load_models_gpu(models_to_load_gpu)
print("--- [ModelManager] βœ… Models successfully moved to GPU. ---")
else:
print("--- [ModelManager] ⚠️ No checkpoint models found to move to GPU. ---")

    def ensure_models_downloaded(self, required_models: List[str], progress):
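        """Download any missing model files, reporting per-file progress."""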
print(f"--- [ModelManager] Ensuring models are downloaded: {required_models} ---")
for i, display_name in enumerate(required_models):
            if callable(progress):
progress(i / len(required_models), desc=f"Checking file: {display_name}")
try:
_ensure_model_downloaded(display_name, progress)
except Exception as e:
raise gr.Error(f"Failed to download model '{display_name}'. Reason: {e}")
print(f"--- [ModelManager] βœ… All required models are present on disk. ---")

    def load_managed_models(self, required_models: List[str], active_loras: List[Dict[str, Any]], progress) -> Dict[str, Any]:
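        """Ensure exactly the required models are resident in RAM.

        Unloads models that are no longer needed, reloads base models when
        the LoRA configuration changes, and returns a mapping of display
        name -> loaded tuple for every entry in required_models.
        """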
required_set = set(required_models)
current_set = set(self.loaded_models.keys())
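        # LoRAs are applied to base checkpoints at load time, so any request
        # with active LoRAs (or any model leaving the required set) forces a
        # reload; the previous LoRA configuration is not tracked or diffed.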
loras_changed = len(active_loras) > 0 or len(current_set - required_set) > 0
models_to_unload = current_set - required_set
if models_to_unload or loras_changed:
if models_to_unload:
print(f"--- [ModelManager] Models to unload: {models_to_unload} ---")
if loras_changed and not models_to_unload:
models_to_unload = current_set.intersection(required_set)
print(f"--- [ModelManager] LoRA configuration changed. Reloading base model(s): {models_to_unload} ---")
model_management.unload_all_models()
self.loaded_models.clear()
gc.collect()
torch.cuda.empty_cache()
print("--- [ModelManager] All models unloaded to free RAM. ---")
models_to_load = required_set if (models_to_unload or loras_changed) else required_set - current_set
if models_to_load:
print(f"--- [ModelManager] Models to load: {models_to_load} ---")
for i, display_name in enumerate(models_to_load):
progress(i / len(models_to_load), desc=f"Loading model: {display_name}")
try:
loaded_model_data = self._load_single_model(display_name, progress)
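                    # Index 2 of an ALL_MODEL_MAP entry is the model type;
                    # only full checkpoints (MODEL + CLIP) can take LoRAs.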
if active_loras and ALL_MODEL_MAP[display_name][2] in ["SDXL", "SD1.5"]:
print(f"--- [ModelManager] Applying {len(active_loras)} LoRAs on CPU... ---")
lora_loader = LoraLoader()
patched_model, patched_clip = loaded_model_data[0], loaded_model_data[1]
for lora_info in active_loras:
patched_model, patched_clip = lora_loader.load_lora(
model=patched_model,
clip=patched_clip,
lora_name=lora_info["lora_name"],
strength_model=lora_info["strength_model"],
strength_clip=lora_info["strength_clip"]
)
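                        # Keep the original VAE (index 2) alongside the
                        # patched MODEL and CLIP.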
loaded_model_data = (patched_model, patched_clip, loaded_model_data[2])
print(f"--- [ModelManager] βœ… All LoRAs merged into the model on CPU. ---")
self.loaded_models[display_name] = loaded_model_data
except Exception as e:
raise gr.Error(f"Failed to load model or apply LoRA '{display_name}'. Reason: {e}")
else:
print(f"--- [ModelManager] All required models are already loaded. ---")
return {name: self.loaded_models[name] for name in required_models}
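

# Module-level singleton shared by the rest of the app.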
model_manager = ModelManager()
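
# Usage sketch (illustration only): the checkpoint name, LoRA entry, and
# progress callback below are hypothetical examples, not values defined in
# this module.
#
#   loaded = model_manager.load_managed_models(
#       required_models=["Example-SDXL-Checkpoint"],
#       active_loras=[{"lora_name": "example_style.safetensors",
#                      "strength_model": 0.8, "strength_clip": 0.8}],
#       progress=lambda frac, desc="": print(f"{frac:.0%} {desc}"),
#   )
#   model_manager.move_models_to_gpu(["Example-SDXL-Checkpoint"])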