Spaces:
Sleeping
Sleeping
File size: 8,650 Bytes
5b29993 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
import yaml
import os
import importlib
from copy import deepcopy
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
class WorkflowAssembler:
    """Build a workflow graph (node id -> node dict) from a YAML recipe.

    A recipe may define ``nodes`` (name -> ``class_type`` plus optional
    ``params``), ``connections`` (``{"from": "node:output_index",
    "to": "node:input_name"}``), a ``ui_map`` (UI key -> ``"node:input"``
    target or a list of such targets), an ``imports`` list of partial
    recipes, and ``dynamic_*`` sections that are handed to injector plugins
    registered in ``yaml/injectors.yaml``.
    """

    def __init__(self, recipe_path, dynamic_values=None):
        """Load injector plugins and the merged recipe at *recipe_path*.

        Args:
            recipe_path: Path to the top-level recipe YAML file. Imports inside
                it are resolved relative to its directory.
            dynamic_values: Optional mapping whose entries replace
                ``{{ key }}`` placeholders in recipe text and import paths;
                entries whose value is ``None`` are skipped.

        Raises:
            FileNotFoundError: If the top-level recipe file does not exist.
        """
        self.base_path = os.path.dirname(recipe_path)
        self.node_counter = 0
        self.workflow = {}
        self.node_map = {}
        self._load_injector_config()
        self.recipe = self._load_and_merge_recipe(
            os.path.basename(recipe_path), dynamic_values or {}
        )

    def _load_injector_config(self):
        """Register injector plugins listed in ``<project_root>/yaml/injectors.yaml``.

        Sets ``self.injector_order`` (chain types in processing order; falls
        back to definition order when unset) and ``self.global_injectors``
        (chain type -> the plugin module's ``inject`` callable). This is
        best-effort: if the file is missing or cannot be parsed, both are left
        empty and dynamic chains are simply not applied.
        """
        project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        injectors_path = os.path.join(project_root, 'yaml', 'injectors.yaml')
        try:
            with open(injectors_path, 'r', encoding='utf-8') as f:
                # An empty file parses to None; treat it as "nothing configured"
                # rather than reporting a confusing NoneType error.
                injector_config = yaml.safe_load(f) or {}
            definitions = injector_config.get("injector_definitions") or {}
            self.injector_order = injector_config.get("injector_order", [])
            self.global_injectors = {}
            for chain_type, config in definitions.items():
                module_path = (config or {}).get("module")
                if not module_path:
                    print(f"Warning: Injector '{chain_type}' in injectors.yaml is missing 'module' path.")
                    continue
                try:
                    module = importlib.import_module(module_path)
                except Exception as e:
                    # Importing runs arbitrary plugin code (ImportError, SyntaxError, ...).
                    # Skip only this plugin instead of letting the outer handler
                    # discard every injector that was already registered.
                    print(f"❌ Error importing module '{module_path}' for injector '{chain_type}': {e}")
                    continue
                if hasattr(module, 'inject'):
                    self.global_injectors[chain_type] = module.inject
                    print(f"✅ Successfully registered global injector: {chain_type} from {module_path}")
                else:
                    print(f"⚠️ Warning: Module '{module_path}' for injector '{chain_type}' does not have an 'inject' function.")
            if not self.injector_order:
                print("⚠️ Warning: 'injector_order' is not defined in injectors.yaml. Using definition order.")
                self.injector_order = list(definitions.keys())
        except FileNotFoundError:
            print(f"❌ FATAL: Could not find injectors.yaml at {injectors_path}. Dynamic chains will not work.")
            self.injector_order = []
            self.global_injectors = {}
        except Exception as e:
            # Deliberate boundary: a bad config must not prevent assembling
            # workflows that use no dynamic chains.
            print(f"❌ FATAL: Could not load or parse injectors.yaml. Dynamic chains will not work. Error: {e}")
            self.injector_order = []
            self.global_injectors = {}

    def _get_unique_id(self):
        """Return the next node id as a string ("1", "2", ...)."""
        self.node_counter += 1
        return str(self.node_counter)

    def _get_node_template(self, class_type):
        """Return a fresh node dict for *class_type* with default input values.

        Every required and optional input from the class's ``INPUT_TYPES()``
        is present, set to the ``default`` of its options dict when one is
        given, otherwise ``None``.

        Raises:
            ValueError: If *class_type* is not in ``NODE_CLASS_MAPPINGS``.
        """
        if class_type not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Node class '{class_type}' not found. Ensure it's correctly imported in comfy_integration/nodes.py.")
        node_class = NODE_CLASS_MAPPINGS[class_type]
        input_types = node_class.INPUT_TYPES()
        all_inputs = {
            **(input_types.get('required') or {}),
            **(input_types.get('optional') or {}),
        }
        inputs = {}
        for name, details in all_inputs.items():
            # Presumably specs look like (type,) or (type, {options}); only an
            # options dict in position 1 can carry a default.
            options = details[1] if len(details) > 1 and isinstance(details[1], dict) else {}
            inputs[name] = options.get("default")
        return {
            "inputs": inputs,
            "class_type": class_type,
            "_meta": {"title": getattr(node_class, 'NODE_NAME', class_type)},
        }

    @staticmethod
    def _render_placeholders(text, dynamic_values):
        """Replace each ``{{ key }}`` in *text* with ``str(value)``, skipping ``None`` values."""
        for key, value in dynamic_values.items():
            if value is not None:
                text = text.replace(f"{{{{ {key} }}}}", str(value))
        return text

    def _merge_partial(self, merged_recipe, partial):
        """Merge one parsed recipe dict into *merged_recipe* in place.

        Later merges win for ``nodes``, ``ui_map`` and ``dynamic_*`` entries;
        ``connections`` are concatenated. Null sections are treated as empty.
        """
        merged_recipe['nodes'].update(partial.get('nodes') or {})
        merged_recipe['connections'].extend(partial.get('connections') or [])
        merged_recipe['ui_map'].update(partial.get('ui_map') or {})
        for key in self.injector_order:
            if key.startswith('dynamic_') and key in partial:
                merged_recipe[key].update(partial.get(key) or {})

    def _load_and_merge_recipe(self, recipe_filename, dynamic_values, search_context_dir=None):
        """Load *recipe_filename* and recursively merge the recipes it imports.

        ``{{ key }}`` placeholders are substituted from *dynamic_values*
        before YAML parsing. Imports are resolved relative to the importing
        recipe's directory and merged first, so the importing recipe's own
        entries override same-named imported ones. An import that does not
        exist is reported and skipped.

        Args:
            recipe_filename: File name or path of the recipe, joined onto the
                search directory.
            dynamic_values: Placeholder values; ``None`` values are not substituted.
            search_context_dir: Directory to resolve *recipe_filename* against;
                defaults to the top-level recipe's directory.

        Returns:
            dict with ``nodes``, ``connections``, ``ui_map`` and one dict for
            each ``dynamic_*`` chain type in ``self.injector_order``.

        Raises:
            FileNotFoundError: If *recipe_filename* itself does not exist.
        """
        search_path = search_context_dir or self.base_path
        recipe_path_to_use = os.path.join(search_path, recipe_filename)
        if not os.path.exists(recipe_path_to_use):
            raise FileNotFoundError(f"Recipe file not found: {recipe_path_to_use}")
        with open(recipe_path_to_use, 'r', encoding='utf-8') as f:
            content = f.read()
        # NOTE(review): values are substituted into the raw YAML text before
        # parsing, so a value containing YAML syntax (newlines, ': ', '#', ...)
        # can alter the recipe's structure. Only pass trusted values here.
        # An empty file parses to None; treat it as an empty recipe.
        main_recipe = yaml.safe_load(self._render_placeholders(content, dynamic_values)) or {}

        merged_recipe = {'nodes': {}, 'connections': [], 'ui_map': {}}
        for key in self.injector_order:
            if key.startswith('dynamic_'):
                merged_recipe[key] = {}

        parent_recipe_dir = os.path.dirname(recipe_path_to_use)
        for import_path_template in main_recipe.get('imports') or []:
            import_path = self._render_placeholders(import_path_template, dynamic_values)
            try:
                imported_recipe = self._load_and_merge_recipe(
                    import_path, dynamic_values, search_context_dir=parent_recipe_dir
                )
            except FileNotFoundError:
                print(f"Warning: Optional recipe partial '{import_path}' not found. Skipping.")
                continue
            self._merge_partial(merged_recipe, imported_recipe)

        # The importing recipe is merged last so its own entries take precedence.
        self._merge_partial(merged_recipe, main_recipe)
        return merged_recipe

    def assemble(self, ui_values):
        """Build and return the workflow graph for this recipe.

        Steps, in order:

        1. Instantiate every recipe node. Template defaults are overridden by
           any recipe ``params`` that name a known input.
        2. Apply ``ui_map`` values from *ui_values*.
        3. Wire ``connections`` as ``[from_id, output_index]`` links.
        4. Run registered injectors for each ``dynamic_*`` section, in
           ``injector_order``.

        Each call starts from an empty graph, so the assembler can be reused
        without corrupting the graph or mutating a previously returned workflow.

        Args:
            ui_values: Mapping of UI keys to values. ``None`` values are
                ignored for ``ui_map``; a falsy value skips that dynamic chain.

        Returns:
            dict mapping string node ids ("1", "2", ...) to node dicts.

        Raises:
            ValueError: If a recipe node's ``class_type`` is unknown.
        """
        # Reset per-assembly state. A fresh dict (rather than .clear()) keeps
        # workflows returned by earlier calls intact.
        self.node_counter = 0
        self.workflow = {}
        self.node_map = {}

        for name, details in self.recipe['nodes'].items():
            node_data = deepcopy(self._get_node_template(details['class_type']))
            unique_id = self._get_unique_id()
            self.node_map[name] = unique_id
            inputs = node_data['inputs']
            # `params:` may be present but null in YAML.
            for param, value in (details.get('params') or {}).items():
                # Only inputs declared by the node class can be set from params.
                if param in inputs:
                    inputs[param] = value
            self.workflow[unique_id] = node_data

        for ui_key, target in self.recipe.get('ui_map', {}).items():
            if ui_key not in ui_values or ui_values[ui_key] is None:
                continue
            value = ui_values[ui_key]
            target_list = target if isinstance(target, list) else [target]
            for t in target_list:
                target_name, target_param = t.split(':')
                target_id = self.node_map.get(target_name)
                # Targets naming nodes absent from this recipe are ignored.
                if target_id is not None:
                    self.workflow[target_id]['inputs'][target_param] = value

        for conn in self.recipe.get('connections', []):
            from_name, from_output_idx = conn['from'].split(':')
            to_name, to_input_name = conn['to'].split(':')
            from_id = self.node_map.get(from_name)
            to_id = self.node_map.get(to_name)
            # Connections touching nodes absent from this recipe are ignored.
            if from_id and to_id:
                self.workflow[to_id]['inputs'][to_input_name] = [from_id, int(from_output_idx)]

        print("--- [Assembler] Applying dynamic injectors ---")
        recipe_chain_types = {key for key in self.recipe if key.startswith('dynamic_')}
        processing_order = [key for key in self.injector_order if key in recipe_chain_types]
        for chain_type in processing_order:
            injector_func = self.global_injectors.get(chain_type)
            if not injector_func:
                continue
            for chain_key, chain_def in self.recipe.get(chain_type, {}).items():
                if chain_key in ui_values and ui_values[chain_key]:
                    print(f" -> Injecting '{chain_type}' for '{chain_key}'...")
                    chain_items = ui_values[chain_key]
                    # Injectors receive the assembler itself and may add nodes
                    # via _get_unique_id / workflow / node_map.
                    injector_func(self, chain_def, chain_items)
        print("--- [Assembler] Finished applying injectors ---")
        return self.workflow