# ImageGen-Illstrious / core/workflow_assembler.py
# Uploaded by RioShiina using huggingface_hub (commit 5b29993, verified).
import yaml
import os
import importlib
from copy import deepcopy
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
class WorkflowAssembler:
    """Build a node-graph workflow dict from a YAML recipe.

    A recipe (plus any partials it lists under ``imports``) supplies ``nodes``,
    ``connections``, ``ui_map`` and optional ``dynamic_*`` chain sections.
    Dynamic sections are expanded at assembly time by injector functions that
    are registered from ``yaml/injectors.yaml``.
    """

    def __init__(self, recipe_path, dynamic_values=None):
        """Load the injector config, then load and merge the recipe.

        Args:
            recipe_path: Path to the top-level recipe YAML file. Its directory is
                the default search directory for imported partials.
            dynamic_values: Optional mapping of placeholder names to values. Each
                ``{{ name }}`` in the recipe text and in import paths is replaced
                with ``str(value)``. Entries whose value is ``None`` are not
                substituted.
        """
        self.base_path = os.path.dirname(recipe_path)
        self.node_counter = 0
        self.workflow = {}
        self.node_map = {}
        # Must run before the recipe is loaded: merging only keeps the
        # dynamic_* sections that appear in self.injector_order.
        self._load_injector_config()
        self.recipe = self._load_and_merge_recipe(os.path.basename(recipe_path), dynamic_values or {})

    def _load_injector_config(self):
        """Populate ``injector_order`` and ``global_injectors`` from injectors.yaml.

        The file is read from ``<project_root>/yaml/injectors.yaml``. The project
        root is the directory two levels above this module. Each entry under
        ``injector_definitions`` names a ``module``. That module's ``inject``
        function is registered under the entry's key.

        If ``injector_order`` is missing or empty, definition order is used. An
        injector whose module fails to import is skipped on its own. If the file
        cannot be read or parsed, both attributes are left empty and dynamic
        chains are disabled; no exception is raised.
        """
        self.injector_order = []
        self.global_injectors = {}
        injectors_path = None
        try:
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            injectors_path = os.path.join(project_root, 'yaml', 'injectors.yaml')
            with open(injectors_path, 'r', encoding='utf-8') as f:
                # safe_load returns None for an empty document.
                injector_config = yaml.safe_load(f) or {}
            # `or` also guards keys that are present but explicitly null.
            definitions = injector_config.get("injector_definitions") or {}
            injector_order = injector_config.get("injector_order") or []
            global_injectors = {}
            for chain_type, config in definitions.items():
                module_path = (config or {}).get("module")
                if not module_path:
                    print(f"Warning: Injector '{chain_type}' in injectors.yaml is missing 'module' path.")
                    continue
                try:
                    module = importlib.import_module(module_path)
                except Exception as e:
                    # Any import-time failure (ImportError, SyntaxError, errors from
                    # module-level code) disables only this injector, not all of them.
                    print(f"❌ Error importing module '{module_path}' for injector '{chain_type}': {e}")
                    continue
                if not hasattr(module, 'inject'):
                    print(f"⚠️ Warning: Module '{module_path}' for injector '{chain_type}' does not have an 'inject' function.")
                    continue
                global_injectors[chain_type] = module.inject
                print(f"✅ Successfully registered global injector: {chain_type} from {module_path}")
            if not injector_order:
                print("⚠️ Warning: 'injector_order' is not defined in injectors.yaml. Using definition order.")
                injector_order = list(definitions.keys())
        except FileNotFoundError:
            print(f"❌ FATAL: Could not find injectors.yaml at {injectors_path}. Dynamic chains will not work.")
            return
        except Exception as e:
            print(f"❌ FATAL: Could not load or parse injectors.yaml. Dynamic chains will not work. Error: {e}")
            return
        # Assign only after everything succeeded, so a failure midway never
        # leaves a half-populated configuration behind.
        self.injector_order = injector_order
        self.global_injectors = global_injectors

    def _get_unique_id(self):
        """Return the next sequential node id as a string ("1", "2", ...)."""
        self.node_counter += 1
        return str(self.node_counter)

    def _get_node_template(self, class_type):
        """Return a fresh node dict for *class_type* with default inputs filled in.

        Every required and optional input reported by the class's
        ``INPUT_TYPES()`` appears under ``inputs``. Its value is the ``default``
        from the input's option dict, or ``None`` if there is none. An optional
        input with the same name as a required one wins.

        Raises:
            ValueError: if *class_type* is not in ``NODE_CLASS_MAPPINGS``.
        """
        if class_type not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Node class '{class_type}' not found. Ensure it's correctly imported in comfy_integration/nodes.py.")
        node_class = NODE_CLASS_MAPPINGS[class_type]
        input_types = node_class.INPUT_TYPES()
        template = {
            "inputs": {},
            "class_type": class_type,
            "_meta": {"title": getattr(node_class, 'NODE_NAME', class_type)},
        }
        all_inputs = {**(input_types.get('required') or {}), **(input_types.get('optional') or {})}
        for name, details in all_inputs.items():
            # details looks like (type,) or (type, {option: value, ...}).
            config = details[1] if len(details) > 1 and isinstance(details[1], dict) else {}
            template["inputs"][name] = config.get("default")
        return template

    @staticmethod
    def _render_placeholders(text, dynamic_values):
        """Replace each ``{{ key }}`` in *text* with ``str(value)``.

        Only the exact spacing ``{{ key }}`` is matched. Keys whose value is
        ``None`` are left unreplaced.
        """
        for key, value in dynamic_values.items():
            if value is not None:
                text = text.replace(f"{{{{ {key} }}}}", str(value))
        return text

    def _merge_recipe_into(self, merged_recipe, partial):
        """Merge the sections of recipe dict *partial* into *merged_recipe* in place.

        ``nodes``, ``ui_map`` and each dynamic_* section are dict-updated, so
        later entries win. ``connections`` are appended. Only dynamic_* keys
        listed in ``injector_order`` are merged. Missing or null sections are
        treated as empty.
        """
        merged_recipe['nodes'].update(partial.get('nodes') or {})
        merged_recipe['connections'].extend(partial.get('connections') or [])
        merged_recipe['ui_map'].update(partial.get('ui_map') or {})
        for key in self.injector_order:
            if key in partial and key.startswith('dynamic_'):
                merged_recipe[key].update(partial.get(key) or {})

    def _load_and_merge_recipe(self, recipe_filename, dynamic_values, search_context_dir=None):
        """Load *recipe_filename*, recursively merge its imports, and return the result.

        Args:
            recipe_filename: File name or relative path of the recipe to load.
            dynamic_values: Placeholder values applied to the file text and to
                each import path before it is parsed or resolved.
            search_context_dir: Directory to resolve *recipe_filename* against.
                Falls back to ``self.base_path`` when falsy.

        Returns:
            A dict with ``nodes``, ``connections``, ``ui_map`` and one key per
            dynamic_* entry in ``injector_order``. Imports are merged first, in
            listed order, then the recipe's own sections, so the recipe
            overrides its imports.

        Raises:
            FileNotFoundError: if the recipe itself does not exist. Missing
                imports are reported with a warning and skipped.
        """
        search_path = search_context_dir or self.base_path
        recipe_path_to_use = os.path.join(search_path, recipe_filename)
        if not os.path.exists(recipe_path_to_use):
            raise FileNotFoundError(f"Recipe file not found: {recipe_path_to_use}")
        with open(recipe_path_to_use, 'r', encoding='utf-8') as f:
            content = f.read()
        # safe_load returns None for an empty document.
        main_recipe = yaml.safe_load(self._render_placeholders(content, dynamic_values)) or {}

        merged_recipe = {'nodes': {}, 'connections': [], 'ui_map': {}}
        for key in self.injector_order:
            if key.startswith('dynamic_'):
                merged_recipe[key] = {}

        # Imports are resolved relative to the importing file's directory.
        parent_recipe_dir = os.path.dirname(recipe_path_to_use)
        for import_path_template in main_recipe.get('imports') or []:
            import_path = self._render_placeholders(import_path_template, dynamic_values)
            try:
                imported_recipe = self._load_and_merge_recipe(import_path, dynamic_values, search_context_dir=parent_recipe_dir)
            except FileNotFoundError:
                print(f"Warning: Optional recipe partial '{import_path}' not found. Skipping.")
                continue
            self._merge_recipe_into(merged_recipe, imported_recipe)

        self._merge_recipe_into(merged_recipe, main_recipe)
        return merged_recipe

    def assemble(self, ui_values):
        """Build and return the workflow dict for the loaded recipe.

        Each call starts from an empty workflow, so repeated calls return
        independent results with ids starting at "1".

        Args:
            ui_values: Mapping of UI keys to values. Keys named in the recipe's
                ``ui_map`` set node inputs (``None`` values are ignored). Keys
                named in a dynamic_* section trigger that section's injector
                when their value is truthy.

        Returns:
            A dict mapping node id strings to node dicts with ``inputs``,
            ``class_type`` and ``_meta`` keys.
        """
        ui_values = ui_values or {}
        # Reset per-assembly state. Without this, a second call appends a
        # duplicate copy of every node and mutates the dict returned earlier.
        self.node_counter = 0
        self.workflow = {}
        self.node_map = {}
        self._instantiate_nodes()
        self._apply_ui_values(ui_values)
        self._wire_connections()
        self._run_injectors(ui_values)
        return self.workflow

    def _instantiate_nodes(self):
        """Create one workflow node per recipe node, applying its static ``params``."""
        for name, details in self.recipe['nodes'].items():
            node_data = deepcopy(self._get_node_template(details['class_type']))
            unique_id = self._get_unique_id()
            self.node_map[name] = unique_id
            for param, value in (details.get('params') or {}).items():
                # Params that the node class does not declare are ignored.
                if param in node_data['inputs']:
                    node_data['inputs'][param] = value
            self.workflow[unique_id] = node_data

    def _apply_ui_values(self, ui_values):
        """Copy non-None UI values into the node inputs named by ``ui_map``.

        Each ``ui_map`` target is ``"node_name:input_name"`` or a list of such
        strings. Targets that name an unknown node are skipped.
        """
        for ui_key, target in self.recipe.get('ui_map', {}).items():
            if ui_key not in ui_values or ui_values[ui_key] is None:
                continue
            target_list = target if isinstance(target, list) else [target]
            for t in target_list:
                target_name, target_param = t.split(':')
                if target_name in self.node_map:
                    self.workflow[self.node_map[target_name]]['inputs'][target_param] = ui_values[ui_key]

    def _wire_connections(self):
        """Link node outputs to inputs as ``[from_id, output_index]`` pairs.

        Each connection has ``from: "node:output_index"`` and
        ``to: "node:input_name"``. Connections that name an unknown node are
        skipped.
        """
        for conn in self.recipe.get('connections', []):
            from_name, from_output_idx = conn['from'].split(':')
            to_name, to_input_name = conn['to'].split(':')
            from_id = self.node_map.get(from_name)
            to_id = self.node_map.get(to_name)
            if from_id and to_id:
                self.workflow[to_id]['inputs'][to_input_name] = [from_id, int(from_output_idx)]

    def _run_injectors(self, ui_values):
        """Run the registered injector for each dynamic_* section, in injector_order.

        For every chain key in a section whose UI value is truthy, the injector
        is called as ``inject(self, chain_def, chain_items)``. It receives this
        assembler, so it can add or modify nodes.
        """
        print("--- [Assembler] Applying dynamic injectors ---")
        recipe_chain_types = {key for key in self.recipe if key.startswith('dynamic_')}
        processing_order = [key for key in self.injector_order if key in recipe_chain_types]
        for chain_type in processing_order:
            injector_func = self.global_injectors.get(chain_type)
            if not injector_func:
                print(f"⚠️ Warning: No injector registered for '{chain_type}'. Skipping.")
                continue
            for chain_key, chain_def in (self.recipe.get(chain_type) or {}).items():
                if chain_key in ui_values and ui_values[chain_key]:
                    print(f" -> Injecting '{chain_type}' for '{chain_key}'...")
                    chain_items = ui_values[chain_key]
                    injector_func(self, chain_def, chain_items)
        print("--- [Assembler] Finished applying injectors ---")