File size: 8,650 Bytes
5b29993
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import yaml
import os
import importlib
from copy import deepcopy
from comfy_integration.nodes import NODE_CLASS_MAPPINGS

class WorkflowAssembler:
    def __init__(self, recipe_path, dynamic_values=None):
        self.base_path = os.path.dirname(recipe_path)
        self.node_counter = 0
        self.workflow = {}
        self.node_map = {}
        
        self._load_injector_config()
        
        self.recipe = self._load_and_merge_recipe(os.path.basename(recipe_path), dynamic_values or {})

    def _load_injector_config(self):
        try:
            project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            injectors_path = os.path.join(project_root, 'yaml', 'injectors.yaml')
            
            with open(injectors_path, 'r', encoding='utf-8') as f:
                injector_config = yaml.safe_load(f)

            definitions = injector_config.get("injector_definitions", {})
            self.injector_order = injector_config.get("injector_order", [])
            self.global_injectors = {}

            for chain_type, config in definitions.items():
                module_path = config.get("module")
                if not module_path:
                    print(f"Warning: Injector '{chain_type}' in injectors.yaml is missing 'module' path.")
                    continue
                try:
                    module = importlib.import_module(module_path)
                    if hasattr(module, 'inject'):
                        self.global_injectors[chain_type] = module.inject
                        print(f"✅ Successfully registered global injector: {chain_type} from {module_path}")
                    else:
                        print(f"⚠️ Warning: Module '{module_path}' for injector '{chain_type}' does not have an 'inject' function.")
                except ImportError as e:
                    print(f"❌ Error importing module '{module_path}' for injector '{chain_type}': {e}")
            
            if not self.injector_order:
                 print("⚠️ Warning: 'injector_order' is not defined in injectors.yaml. Using definition order.")
                 self.injector_order = list(definitions.keys())

        except FileNotFoundError:
            print(f"❌ FATAL: Could not find injectors.yaml at {injectors_path}. Dynamic chains will not work.")
            self.injector_order = []
            self.global_injectors = {}
        except Exception as e:
            print(f"❌ FATAL: Could not load or parse injectors.yaml. Dynamic chains will not work. Error: {e}")
            self.injector_order = []
            self.global_injectors = {}

    def _get_unique_id(self):
        self.node_counter += 1
        return str(self.node_counter)

    def _get_node_template(self, class_type):
        """Build a blank workflow entry for the node class ``class_type``.

        The node class is looked up in ``NODE_CLASS_MAPPINGS`` and its
        ``INPUT_TYPES()`` result is used to pre-populate the inputs: every
        name listed under ``required`` or ``optional`` gets the ``default``
        from its options dict (the second element of its spec, when that
        element is a dict) or ``None`` when no default is declared.

        Args:
            class_type: Key of the node class in ``NODE_CLASS_MAPPINGS``.

        Returns:
            A dict with the keys ``inputs``, ``class_type`` and ``_meta``;
            the ``_meta`` title is the class's ``NODE_NAME`` when it has
            one, otherwise ``class_type``.

        Raises:
            ValueError: If ``class_type`` is not in ``NODE_CLASS_MAPPINGS``.
        """
        if class_type not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Node class '{class_type}' not found. Ensure it's correctly imported in comfy_integration/nodes.py.")

        node_class = NODE_CLASS_MAPPINGS[class_type]
        declared = node_class.INPUT_TYPES()

        # Optional inputs are layered over required ones, so an optional
        # entry wins when a name appears in both groups.
        input_specs = dict(declared.get('required', {}))
        input_specs.update(declared.get('optional', {}))

        defaults = {}
        for input_name, spec in input_specs.items():
            has_options = len(spec) > 1 and isinstance(spec[1], dict)
            defaults[input_name] = spec[1].get("default") if has_options else None

        return {
            "inputs": defaults,
            "class_type": class_type,
            "_meta": {"title": getattr(node_class, 'NODE_NAME', class_type)},
        }

    def _load_and_merge_recipe(self, recipe_filename, dynamic_values, search_context_dir=None):
        search_path = search_context_dir or self.base_path
        recipe_path_to_use = os.path.join(search_path, recipe_filename)

        if not os.path.exists(recipe_path_to_use):
            raise FileNotFoundError(f"Recipe file not found: {recipe_path_to_use}")

        with open(recipe_path_to_use, 'r', encoding='utf-8') as f:
            content = f.read()

        for key, value in dynamic_values.items():
            if value is not None:
                content = content.replace(f"{{{{ {key} }}}}", str(value))
        
        main_recipe = yaml.safe_load(content)
        
        merged_recipe = {'nodes': {}, 'connections': [], 'ui_map': {}}
        for key in self.injector_order:
            if key.startswith('dynamic_'):
                merged_recipe[key] = {}

        parent_recipe_dir = os.path.dirname(recipe_path_to_use)
        for import_path_template in main_recipe.get('imports', []):
            import_path = import_path_template
            for key, value in dynamic_values.items():
                if value is not None:
                    import_path = import_path.replace(f"{{{{ {key} }}}}", str(value))
            
            try:
                imported_recipe = self._load_and_merge_recipe(import_path, dynamic_values, search_context_dir=parent_recipe_dir)
                merged_recipe['nodes'].update(imported_recipe.get('nodes', {}))
                merged_recipe['connections'].extend(imported_recipe.get('connections', []))
                merged_recipe['ui_map'].update(imported_recipe.get('ui_map', {}))
                for key in self.injector_order:
                    if key in imported_recipe and key.startswith('dynamic_'):
                        merged_recipe[key].update(imported_recipe.get(key, {}))
            except FileNotFoundError:
                print(f"Warning: Optional recipe partial '{import_path}' not found. Skipping.")

        merged_recipe['nodes'].update(main_recipe.get('nodes', {}))
        merged_recipe['connections'].extend(main_recipe.get('connections', []))
        merged_recipe['ui_map'].update(main_recipe.get('ui_map', {}))
        for key in self.injector_order:
            if key in main_recipe and key.startswith('dynamic_'):
                merged_recipe[key].update(main_recipe.get(key, {}))
        
        return merged_recipe

    def assemble(self, ui_values):
        for name, details in self.recipe['nodes'].items():
            class_type = details['class_type']
            template = self._get_node_template(class_type)
            node_data = deepcopy(template)
            
            unique_id = self._get_unique_id()
            self.node_map[name] = unique_id
            
            if 'params' in details:
                for param, value in details['params'].items():
                    if param in node_data['inputs']:
                        node_data['inputs'][param] = value
            
            self.workflow[unique_id] = node_data

        for ui_key, target in self.recipe.get('ui_map', {}).items():
            if ui_key in ui_values and ui_values[ui_key] is not None:
                target_list = target if isinstance(target, list) else [target]
                for t in target_list:
                    target_name, target_param = t.split(':')
                    if target_name in self.node_map:
                        self.workflow[self.node_map[target_name]]['inputs'][target_param] = ui_values[ui_key]

        for conn in self.recipe.get('connections', []):
            from_name, from_output_idx = conn['from'].split(':')
            to_name, to_input_name = conn['to'].split(':')
            
            from_id = self.node_map.get(from_name)
            to_id = self.node_map.get(to_name)
            
            if from_id and to_id:
                self.workflow[to_id]['inputs'][to_input_name] = [from_id, int(from_output_idx)]
        
        print("--- [Assembler] Applying dynamic injectors ---")
        recipe_chain_types = {key for key in self.recipe if key.startswith('dynamic_')}
        processing_order = [key for key in self.injector_order if key in recipe_chain_types]

        for chain_type in processing_order:
            injector_func = self.global_injectors.get(chain_type)
            if injector_func:
                for chain_key, chain_def in self.recipe.get(chain_type, {}).items():
                    if chain_key in ui_values and ui_values[chain_key]:
                        print(f"  -> Injecting '{chain_type}' for '{chain_key}'...")
                        chain_items = ui_values[chain_key]
                        injector_func(self, chain_def, chain_items)
        
        print("--- [Assembler] Finished applying injectors ---")
        
        return self.workflow