File size: 4,701 Bytes
5b29993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import yaml
import os
from collections import OrderedDict

# Relative folder locations for model weights and image I/O.
# NOTE(review): these are plain relative strings; whether callers resolve them
# against the working directory or the project root is decided outside here.
CHECKPOINT_DIR = "models/checkpoints"
LORA_DIR = "models/loras"
EMBEDDING_DIR = "models/embeddings"
CONTROLNET_DIR = "models/controlnet"
DIFFUSION_MODELS_DIR = "models/diffusion_models"
VAE_DIR = "models/vae"
TEXT_ENCODERS_DIR = "models/text_encoders"
INPUT_DIR = "input"
OUTPUT_DIR = "output"

# Project root: the parent of the directory that contains this module.
_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Absolute paths of the YAML configuration files under <project root>/yaml/.
(
    _MODEL_LIST_PATH,
    _FILE_LIST_PATH,
    _IPADAPTER_LIST_PATH,
    _CONSTANTS_PATH,
) = (
    os.path.join(_PROJECT_ROOT, 'yaml', _yaml_name)
    for _yaml_name in (
        'model_list.yaml',
        'file_list.yaml',
        'ipadapter.yaml',
        'constants.yaml',
    )
)


def load_constants_from_yaml(filepath=_CONSTANTS_PATH):
    """Load tunable constants from a YAML file.

    Args:
        filepath: Path to the constants YAML. Defaults to
            ``<project root>/yaml/constants.yaml``.

    Returns:
        The object parsed from the file. When the file is missing (a warning
        is printed) or the document is empty, an empty dict is returned so
        callers can always do ``.get(key, default)``. Any other top-level YAML
        type (e.g. a list) is returned unchanged; callers must validate it.
    """
    if not os.path.exists(filepath):
        print(f"Warning: Constants file not found at {filepath}. Using fallback values.")
        return {}
    with open(filepath, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)
    # safe_load yields None for an empty or comment-only document; treat that
    # like a missing file instead of handing None to callers that call .get().
    return {} if data is None else data

def load_file_download_map(filepath=_FILE_LIST_PATH):
    """Build a ``filename -> download info`` lookup from the file-list YAML.

    The YAML's top-level ``file`` key maps category names to lists of
    entries. Every entry that is a mapping with a ``filename`` key is indexed
    by that filename. Before indexing, the entry is tagged in place with
    ``'category'`` set to the name of the list it came from. Non-list
    categories and malformed entries are skipped.

    Args:
        filepath: Path to the file-list YAML. Defaults to
            ``<project root>/yaml/file_list.yaml``.

    Returns:
        dict mapping filename to its entry dict. When a filename appears more
        than once, the last occurrence wins.

    Raises:
        FileNotFoundError: if ``filepath`` does not exist.
    """
    if not os.path.exists(filepath):
        raise FileNotFoundError(f"The file list (for downloads) was not found at: {filepath}")

    with open(filepath, 'r', encoding='utf-8') as f:
        file_list_data = yaml.safe_load(f)

    # An empty document parses to None, and so does an empty ``file:`` key;
    # treat both as "no files" instead of failing with AttributeError.
    categories = (file_list_data or {}).get('file') or {}

    download_info_map = {}
    for category, files in categories.items():
        if not isinstance(files, list):
            continue
        for file_info in files:
            # Require a mapping: for a bare string entry, ``'filename' in s``
            # would be a substring test and the item assignment would raise.
            if isinstance(file_info, dict) and 'filename' in file_info:
                file_info['category'] = category
                download_info_map[file_info['filename']] = file_info
    return download_info_map


def load_models_from_yaml(model_list_filepath=_MODEL_LIST_PATH, download_map=None):
    """Build ``display_name -> model tuple`` maps from the model-list YAML.

    Only categories listed in ``category_map_names`` (currently just
    ``Checkpoint``) are loaded; other top-level keys are ignored. Each model
    entry must be a mapping with ``display_name`` and ``path``. Entries that
    are not are skipped with a printed warning. ``path`` is looked up in
    ``download_map`` to obtain the entry's ``repo_id``.

    Args:
        model_list_filepath: Path to the model-list YAML. Defaults to
            ``<project root>/yaml/model_list.yaml``.
        download_map: ``filename -> download info`` mapping, as produced by
            ``load_file_download_map``. Required.

    Returns:
        dict with keys ``"MODEL_MAP_CHECKPOINT"`` and ``"ALL_MODEL_MAP"``.
        Each value is an OrderedDict, in file order, mapping display name to
        ``(repo_id, filename, "SDXL", None)``. ``repo_id`` is ``''`` when the
        file has no download entry.

    Raises:
        FileNotFoundError: if the model-list file does not exist.
        ValueError: if ``download_map`` is None.
    """
    if not os.path.exists(model_list_filepath):
        raise FileNotFoundError(f"The model list file was not found at: {model_list_filepath}")
    if download_map is None:
        raise ValueError("download_map must be provided to load_models_from_yaml")

    with open(model_list_filepath, 'r', encoding='utf-8') as f:
        # An empty document parses to None; treat it as "no categories".
        model_data = yaml.safe_load(f) or {}

    model_maps = {
        "MODEL_MAP_CHECKPOINT": OrderedDict(),
        "ALL_MODEL_MAP": OrderedDict(),
    }
    # YAML category name -> key in ``model_maps``.
    category_map_names = {
        "Checkpoint": "MODEL_MAP_CHECKPOINT",
    }

    for category, models in model_data.items():
        map_name = category_map_names.get(category)
        if map_name is None or not isinstance(models, list):
            continue
        for model in models:
            # A single malformed entry would otherwise raise here and abort
            # loading every model; skip it and report instead.
            if not isinstance(model, dict) or 'display_name' not in model or 'path' not in model:
                print(f"Warning: Skipping malformed '{category}' entry in {model_list_filepath}: {model!r}")
                continue
            display_name = model['display_name']
            filename = model['path']

            repo_id = download_map.get(filename, {}).get('repo_id', '')

            # Element 2 is the model type; it and element 3 are fixed for
            # every entry loaded here. NOTE(review): meaning of the trailing
            # None is defined by consumers of this tuple — confirm.
            model_tuple = (repo_id, filename, "SDXL", None)
            model_maps[map_name][display_name] = model_tuple
            model_maps["ALL_MODEL_MAP"][display_name] = model_tuple

    return model_maps

# Load the download catalogue and the model maps at import time. Failures are
# printed rather than raised, so importing this module always succeeds; any
# map that could not be loaded stays empty.
ALL_FILE_DOWNLOAD_MAP = {}
MODEL_MAP_CHECKPOINT, ALL_MODEL_MAP = OrderedDict(), OrderedDict()

try:
    # The download map does not depend on the model list, so it is kept even
    # if loading the model list below fails.
    ALL_FILE_DOWNLOAD_MAP = load_file_download_map()
    loaded_maps = load_models_from_yaml(download_map=ALL_FILE_DOWNLOAD_MAP)
    # Assign both maps together so they can never be half-populated.
    MODEL_MAP_CHECKPOINT, ALL_MODEL_MAP = (
        loaded_maps["MODEL_MAP_CHECKPOINT"],
        loaded_maps["ALL_MODEL_MAP"],
    )
except Exception as e:
    print(f"FATAL: Could not load model configuration from YAML. Error: {e}")

# display name -> model type (element 2 of each model tuple).
MODEL_TYPE_MAP = {k: v[2] for k, v in ALL_MODEL_MAP.items()}


# Single source of truth for every constant's fallback value. It is used for
# keys missing from constants.yaml and when the file cannot be loaded at all.
_CONSTANT_DEFAULTS = {
    'MAX_LORAS': 5,
    'MAX_EMBEDDINGS': 5,
    'MAX_CONDITIONINGS': 10,
    'MAX_CONTROLNETS': 5,
    'MAX_IPADAPTERS': 5,
    'LORA_SOURCE_CHOICES': ["Civitai", "Custom URL", "File"],
    'RESOLUTION_MAP': {},
    'SAMPLER_MAP': {},
}

try:
    # An empty YAML document loads as None; treat it as "no overrides".
    _constants = load_constants_from_yaml() or {}
    if not isinstance(_constants, dict):
        raise TypeError(f"expected a mapping at the top level, got {type(_constants).__name__}")
except Exception as e:
    print(f"FATAL: Could not load constants from YAML. Error: {e}")
    _constants = {}

MAX_LORAS = _constants.get('MAX_LORAS', _CONSTANT_DEFAULTS['MAX_LORAS'])
MAX_EMBEDDINGS = _constants.get('MAX_EMBEDDINGS', _CONSTANT_DEFAULTS['MAX_EMBEDDINGS'])
MAX_CONDITIONINGS = _constants.get('MAX_CONDITIONINGS', _CONSTANT_DEFAULTS['MAX_CONDITIONINGS'])
MAX_CONTROLNETS = _constants.get('MAX_CONTROLNETS', _CONSTANT_DEFAULTS['MAX_CONTROLNETS'])
MAX_IPADAPTERS = _constants.get('MAX_IPADAPTERS', _CONSTANT_DEFAULTS['MAX_IPADAPTERS'])
LORA_SOURCE_CHOICES = _constants.get('LORA_SOURCE_CHOICES', _CONSTANT_DEFAULTS['LORA_SOURCE_CHOICES'])
RESOLUTION_MAP = _constants.get('RESOLUTION_MAP', _CONSTANT_DEFAULTS['RESOLUTION_MAP'])
SAMPLER_MAP = _constants.get('SAMPLER_MAP', _CONSTANT_DEFAULTS['SAMPLER_MAP'])


# Default negative prompt text. The trailing comma is part of the value.
# NOTE(review): "(...:1.2)" looks like prompt-weighting syntax for the
# downstream pipeline — confirm against the prompt parser.
DEFAULT_NEGATIVE_PROMPT = (
    "monochrome, (low quality, worst quality:1.2), 3d, watermark, "
    "signature, ugly, poorly drawn,"
)