# File size: 7,018 Bytes
# d6c92ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
from typing import Dict, Any, List
import imageio
import tempfile
import numpy as np
import torch
import gradio as gr
from PIL import Image
import spaces

from .base_pipeline import BasePipeline
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
from nodes import NODE_DISPLAY_NAME_MAPPINGS
from utils.app_utils import get_value_at_index

# Lookup from preprocessor display name to node class name, consumed by
# ControlNetPreprocessorPipeline. It is None at import time and must be assigned
# during startup (by `build_reverse_map`, which is defined elsewhere).
REVERSE_DISPLAY_NAME_MAP = None

# Display names of preprocessors that run in-process without requesting ZeroGPU.
# Entries are compared verbatim against the selected preprocessor name, so the
# exact spelling matters (e.g. "Color Pallete" must stay as written).
CPU_ONLY_PREPROCESSORS = {
    "Binary Lines",
    "Canny Edge",
    "Color Pallete",
    "Content Shuffle",
    "Fake Scribble Lines (aka scribble_hed)",
    "Image Intensity",
    "Image Luminance",
    "Inpaint Preprocessor",
    "PyraCanny",
    "Scribble Lines",
    "Scribble XDoG Lines",
    "Standard Lineart",
    "Tile",
}

def run_node_by_function_name(node_instance: Any, **kwargs) -> Any:
    """Invoke a node through the method named by its class-level ``FUNCTION`` attribute.

    Args:
        node_instance: An instantiated node object whose class defines ``FUNCTION``.
        **kwargs: Forwarded unchanged as keyword arguments to that method.

    Returns:
        Whatever the named method returns.

    Raises:
        AttributeError: If the class has no (or an empty) ``FUNCTION`` attribute, or the
            instance has no callable attribute with that name.
    """
    cls = type(node_instance)
    entry_point = getattr(cls, 'FUNCTION', None)
    if entry_point:
        bound_method = getattr(node_instance, entry_point, None)
        if callable(bound_method):
            return bound_method(**kwargs)
        raise AttributeError(f"Method '{entry_point}' not found or not callable on node '{cls.__name__}'.")
    raise AttributeError(f"Node class '{cls.__name__}' is missing the required 'FUNCTION' attribute.")

class ControlNetPreprocessorPipeline(BasePipeline):
    """Apply a ComfyUI ControlNet preprocessor node to an image or to every frame of a video.

    Preprocessors whose display name is in ``CPU_ONLY_PREPROCESSORS`` run directly in this
    process. Every other preprocessor is dispatched through ``self._execute_gpu_logic``
    (inherited from ``BasePipeline``, not shown here), which requests ZeroGPU.
    """

    def get_required_models(self, **kwargs) -> List[str]:
        """Return the model files this pipeline needs up front (none)."""
        return []

    @staticmethod
    def _pil_to_image_tensor(frame_pil: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a float32 tensor of shape (1, H, W, 3) with values in [0, 1]."""
        # Force 3 channels. A grayscale ("L") frame would otherwise yield a rank-3 tensor
        # and an RGBA frame a 4-channel one, and neither matches the (1, H, W, 3)
        # layout that _gpu_logic unpacks.
        rgb = np.array(frame_pil.convert("RGB")).astype(np.float32) / 255.0
        return torch.from_numpy(rgb).unsqueeze(0)

    @staticmethod
    def _image_tensor_to_pil(image_tensor: torch.Tensor) -> Image.Image:
        """Convert a node output tensor back to an 8-bit PIL image.

        Assumes a leading batch dimension of size 1, which is dropped by ``squeeze(0)``.
        Values are clipped to [0, 1] before scaling.
        """
        scaled = image_tensor.squeeze(0).cpu().numpy().clip(0, 1) * 255.0
        # Round rather than truncate: float error can turn 200 into 199.99998,
        # which astype() alone would floor to 199.
        return Image.fromarray(np.rint(scaled).astype(np.uint8))

    def _gpu_logic(
        self, pil_images: List[Image.Image], preprocessor_name: str, model_name: str,
        params: Dict[str, Any], progress=gr.Progress(track_tqdm=True)
    ) -> List[Image.Image]:
        """Run the named preprocessor node over every frame, one frame at a time.

        Despite its name, this method also serves as the CPU code path. ``run`` calls it
        directly for CPU-only preprocessors.

        Args:
            pil_images: Frames to process.
            preprocessor_name: Display name. It is resolved to a node class via
                ``REVERSE_DISPLAY_NAME_MAP`` and ``NODE_CLASS_MAPPINGS``.
            model_name: Passed to the node as ``ckpt_name`` and overrides any
                ``ckpt_name`` in ``params``.
            params: Extra keyword arguments for the node. A ``resolution`` entry here
                overrides the default derived from the frame size.
            progress: Gradio progress tracker.

        Returns:
            One processed PIL image per input frame, in input order.

        Raises:
            RuntimeError: If ``REVERSE_DISPLAY_NAME_MAP`` has not been built yet.
            ValueError: If the display name does not resolve to a registered node class.
        """
        if REVERSE_DISPLAY_NAME_MAP is None:
            raise RuntimeError("REVERSE_DISPLAY_NAME_MAP has not been initialized. `build_reverse_map` must be called on startup.")

        class_name = REVERSE_DISPLAY_NAME_MAP.get(preprocessor_name)
        if not class_name or class_name not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Preprocessor '{preprocessor_name}' not found.")

        preprocessor_instance = NODE_CLASS_MAPPINGS[class_name]()
        call_args = {**params, 'ckpt_name': model_name}

        processed_pil_images = []
        total_frames = len(pil_images)
        for i, frame_pil in enumerate(pil_images):
            progress(i / total_frames, desc=f"Processing frame {i+1}/{total_frames} with {preprocessor_name}...")

            frame_tensor = self._pil_to_image_tensor(frame_pil)
            # The tensor is (batch, height, width, channels). Index the two spatial
            # axes; the old shape[2]/shape[3] pair compared width against the channel count.
            _, height, width, _ = frame_tensor.shape
            # Build a single kwargs dict so a 'resolution' (or 'image') key in params
            # cannot trigger "got multiple values for keyword argument". Caller-supplied
            # resolution wins over the longest-side default, and the frame always wins.
            node_kwargs = {'resolution': max(height, width), **call_args, 'image': frame_tensor}

            result_tuple = run_node_by_function_name(preprocessor_instance, **node_kwargs)
            processed_tensor = get_value_at_index(result_tuple, 0)
            processed_pil_images.append(self._image_tensor_to_pil(processed_tensor))

        return processed_pil_images

    @staticmethod
    def _read_input_frames(input_type, image_input, video_input):
        """Return ``(frames, is_video, fps)`` for the selected input type.

        For images, ``fps`` is the placeholder 30. For videos, it comes from the file
        metadata and falls back to 30 when the value is missing or falsy.

        Raises:
            gr.Error: If the input is missing, the video cannot be read, or the input
                type is unknown.
        """
        if input_type == "Image":
            if image_input is None:
                raise gr.Error("Please provide an input image.")
            return [image_input], False, 30

        if input_type == "Video":
            if video_input is None:
                raise gr.Error("Please provide an input video.")
            try:
                video_reader = imageio.get_reader(video_input)
                try:
                    fps = video_reader.get_meta_data().get('fps') or 30
                    frames = [Image.fromarray(frame) for frame in video_reader]
                finally:
                    # Release the decoder even if metadata or frame decoding fails.
                    video_reader.close()
            except Exception as e:
                raise gr.Error(f"Failed to read video file: {e}") from e
            return frames, True, fps

        raise gr.Error("Invalid input type selected.")

    @staticmethod
    def _collect_params(preprocessor_name, args):
        """Map the positional widget values in ``args`` to the preprocessor's parameter names.

        Parameter names are ordered as sliders (INT/FLOAT), then dropdowns (list-typed),
        then checkboxes (BOOLEAN). NOTE(review): this presumably mirrors the order in
        which the UI passes its components into ``run``'s ``*args``; confirm against
        the UI builder.

        Raises:
            RuntimeError: If ``app_utils.PREPROCESSOR_PARAMETER_MAP`` has not been built.
            gr.Error: If fewer values than parameters were supplied.
        """
        from utils import app_utils

        if app_utils.PREPROCESSOR_PARAMETER_MAP is None:
            raise RuntimeError("Preprocessor parameter map is not built. Check startup logs.")

        params_config = app_utils.PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])
        sliders_params = [p for p in params_config if p['type'] in ["INT", "FLOAT"]]
        dropdown_params = [p for p in params_config if isinstance(p['type'], list)]
        checkbox_params = [p for p in params_config if p['type'] == "BOOLEAN"]
        param_names = [p['name'] for p in sliders_params + dropdown_params + checkbox_params]

        if len(args) < len(param_names):
            raise gr.Error(
                f"Expected {len(param_names)} parameter values for '{preprocessor_name}', got {len(args)}."
            )
        return dict(zip(param_names, args))

    def run(self, input_type, image_input, video_input, preprocessor_name, model_name, zero_gpu_duration, *args, progress=gr.Progress(track_tqdm=True)):
        """Read the input, run the preprocessor on every frame, and package the result.

        Args:
            input_type: ``"Image"`` or ``"Video"``.
            image_input: PIL image used when ``input_type == "Image"``.
            video_input: Anything ``imageio.get_reader`` accepts, used for ``"Video"``.
            preprocessor_name: Display name of the preprocessor.
            model_name: Forwarded to the node as ``ckpt_name``.
            zero_gpu_duration: Passed as ``duration`` to ``_execute_gpu_logic``
                (``default_duration=60``). Its exact semantics live in ``BasePipeline``.
            *args: Positional widget values mapped to parameter names by ``_collect_params``.
            progress: Gradio progress tracker.

        Returns:
            For images, the list of processed PIL images. For videos, a one-element list
            holding whatever ``_encode_video_from_frames`` returns.

        Raises:
            gr.Error: On bad input or any failure while running the preprocessor.
            RuntimeError: If the preprocessor parameter map was never built.
        """
        import traceback

        progress(0, desc="Reading input file...")
        pil_images, is_video, fps = self._read_input_frames(input_type, image_input, video_input)
        if not pil_images:
            raise gr.Error("Could not extract any frames from the input.")

        provided_params = self._collect_params(preprocessor_name, args)

        if preprocessor_name not in CPU_ONLY_PREPROCESSORS:
            print(f"--- '{preprocessor_name}' requires GPU, requesting ZeroGPU. ---")
            try:
                processed_pil_images = self._execute_gpu_logic(
                    self._gpu_logic,
                    duration=zero_gpu_duration,
                    default_duration=60,
                    task_name=f"Preprocessor '{preprocessor_name}'",
                    pil_images=pil_images,
                    preprocessor_name=preprocessor_name,
                    model_name=model_name,
                    params=provided_params,
                    progress=progress
                )
            except Exception as e:
                traceback.print_exc()
                raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on GPU: {e}") from e
        else:
            print(f"--- Running '{preprocessor_name}' on CPU, no ZeroGPU requested. ---")
            try:
                processed_pil_images = self._gpu_logic(pil_images, preprocessor_name, model_name, provided_params, progress=progress)
            except Exception as e:
                traceback.print_exc()
                raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on CPU: {e}") from e

        if not processed_pil_images:
            raise gr.Error("Processing returned no frames.")

        progress(0.9, desc="Finalizing output...")
        if not is_video:
            progress(1.0, desc="Done!")
            return processed_pil_images

        frames_np = np.stack([np.array(img) for img in processed_pil_images])
        frames_tensor = torch.from_numpy(frames_np).to(torch.float32) / 255.0
        video_path = self._encode_video_from_frames(frames_tensor, fps, progress)
        return [video_path]