from typing import Dict, Any, List
import imageio
import tempfile
import traceback
import numpy as np
import torch
import gradio as gr
from PIL import Image
import spaces

from .base_pipeline import BasePipeline
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
from nodes import NODE_DISPLAY_NAME_MAPPINGS
from utils.app_utils import get_value_at_index

# Maps a preprocessor display name to its node class name (a key of NODE_CLASS_MAPPINGS).
# Stays None until populated at startup (the error in `_gpu_logic` names `build_reverse_map`,
# which is not defined in this view).
REVERSE_DISPLAY_NAME_MAP = None

# Preprocessors in this set are run directly in-process; all others request ZeroGPU.
# Entries are matched against display names verbatim (including the "Color Pallete" spelling).
CPU_ONLY_PREPROCESSORS = {
    "Binary Lines",
    "Canny Edge",
    "Color Pallete",
    "Fake Scribble Lines (aka scribble_hed)",
    "Image Intensity",
    "Image Luminance",
    "Inpaint Preprocessor",
    "PyraCanny",
    "Scribble Lines",
    "Scribble XDoG Lines",
    "Standard Lineart",
    "Content Shuffle",
    "Tile",
}


def run_node_by_function_name(node_instance: Any, **kwargs) -> Any:
    """Invoke a node through the method named by its class-level ``FUNCTION`` attribute.

    Args:
        node_instance: An instantiated node object.
        **kwargs: Keyword arguments forwarded unchanged to the node's entry method.

    Returns:
        Whatever the node's entry method returns.

    Raises:
        AttributeError: If the node class has no (or an empty) ``FUNCTION`` attribute,
            or the method it names is missing or not callable.
    """
    node_class = type(node_instance)
    function_name = getattr(node_class, 'FUNCTION', None)
    if not function_name:
        raise AttributeError(f"Node class '{node_class.__name__}' is missing the required 'FUNCTION' attribute.")

    execution_method = getattr(node_instance, function_name, None)
    if not callable(execution_method):
        raise AttributeError(f"Method '{function_name}' not found or not callable on node '{node_class.__name__}'.")

    return execution_method(**kwargs)


class ControlNetPreprocessorPipeline(BasePipeline):
    """Run a ControlNet preprocessor node over a single image or every frame of a video."""

    def get_required_models(self, **kwargs) -> List[str]:
        """Return the models this pipeline needs up front; it declares none."""
        return []

    def _gpu_logic(
        self,
        pil_images: List[Image.Image],
        preprocessor_name: str,
        model_name: str,
        params: Dict[str, Any],
        progress=gr.Progress(track_tqdm=True),
    ) -> List[Image.Image]:
        """Apply the named preprocessor to each frame and return the processed frames.

        Despite the name, this is also called directly (off-GPU) for CPU-only preprocessors.

        Args:
            pil_images: Frames to process, in order.
            preprocessor_name: Display name, resolved via ``REVERSE_DISPLAY_NAME_MAP``.
            model_name: Passed to the node as ``ckpt_name``.
            params: Extra node keyword arguments collected from the UI. A ``resolution``
                entry here overrides the per-frame default.
            progress: Gradio progress reporter.

        Returns:
            One PIL image per input frame.

        Raises:
            RuntimeError: If ``REVERSE_DISPLAY_NAME_MAP`` has not been initialized.
            ValueError: If the preprocessor name cannot be resolved to a node class.
        """
        if REVERSE_DISPLAY_NAME_MAP is None:
            raise RuntimeError(
                "REVERSE_DISPLAY_NAME_MAP has not been initialized. "
                "`build_reverse_map` must be called on startup."
            )

        class_name = REVERSE_DISPLAY_NAME_MAP.get(preprocessor_name)
        if not class_name or class_name not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Preprocessor '{preprocessor_name}' not found.")

        preprocessor_instance = NODE_CLASS_MAPPINGS[class_name]()
        call_args = {**params, 'ckpt_name': model_name}

        processed_pil_images = []
        total_frames = len(pil_images)
        for i, frame_pil in enumerate(pil_images):
            progress(i / total_frames, desc=f"Processing frame {i+1}/{total_frames} with {preprocessor_name}...")

            # Force 3-channel RGB so grayscale/RGBA/palette inputs all yield an
            # (H, W, 3) array; the node then receives a (1, H, W, 3) float tensor in [0, 1].
            frame_np = np.array(frame_pil.convert("RGB")).astype(np.float32) / 255.0
            frame_tensor = torch.from_numpy(frame_np).unsqueeze(0)

            # Default resolution is the longer spatial side (axes 1 and 2 of the
            # (1, H, W, C) tensor). Merging rather than passing two ** expansions lets
            # an explicit `resolution` in params override it instead of raising TypeError.
            height, width = frame_tensor.shape[1], frame_tensor.shape[2]
            node_kwargs = {'resolution': max(height, width), **call_args}

            result_tuple = run_node_by_function_name(preprocessor_instance, image=frame_tensor, **node_kwargs)

            processed_tensor = get_value_at_index(result_tuple, 0)
            processed_np = (processed_tensor.squeeze(0).cpu().numpy().clip(0, 1) * 255.0).astype(np.uint8)
            processed_pil_images.append(Image.fromarray(processed_np))

        return processed_pil_images

    @staticmethod
    def _read_video_frames(video_path):
        """Decode every frame of ``video_path``; return ``(frames, fps)``, defaulting fps to 30.

        The reader is always closed, even if decoding fails part-way.
        """
        reader = imageio.get_reader(video_path)
        try:
            fps = reader.get_meta_data().get('fps', 30)
            frames = [Image.fromarray(frame) for frame in reader]
        finally:
            reader.close()
        return frames, fps

    def run(self, input_type, image_input, video_input, preprocessor_name, model_name, zero_gpu_duration, *args, progress=gr.Progress(track_tqdm=True)):
        """Gradio entry point: preprocess an image or video with the chosen preprocessor.

        Args:
            input_type: "Image" or "Video"; selects which of the two inputs is used.
            image_input: PIL image, used when ``input_type == "Image"``.
            video_input: Path/URI readable by ``imageio.get_reader``, used for "Video".
            preprocessor_name: Display name of the preprocessor to run.
            model_name: Checkpoint name forwarded to the node as ``ckpt_name``.
            zero_gpu_duration: Requested ZeroGPU duration (ignored for CPU-only preprocessors).
            *args: Values of the preprocessor's UI controls. They are matched to parameter
                names in order: sliders (INT/FLOAT), then dropdowns, then checkboxes.
            progress: Gradio progress reporter.

        Returns:
            ``[video_path]`` for video input, otherwise the list of processed PIL images.

        Raises:
            gr.Error: On missing or unreadable input, or when the preprocessor fails.
            RuntimeError: If the preprocessor parameter map was never built.
        """
        # Imported as a module (not the attribute) so the value set at startup is seen.
        from utils import app_utils

        pil_images, is_video, fps = [], False, 30
        progress(0, desc="Reading input file...")

        if input_type == "Image":
            if image_input is None:
                raise gr.Error("Please provide an input image.")
            pil_images = [image_input]
        elif input_type == "Video":
            if video_input is None:
                raise gr.Error("Please provide an input video.")
            try:
                pil_images, fps = self._read_video_frames(video_input)
            except Exception as e:
                raise gr.Error(f"Failed to read video file: {e}") from e
            is_video = True
        else:
            raise gr.Error("Invalid input type selected.")

        if not pil_images:
            raise gr.Error("Could not extract any frames from the input.")

        if app_utils.PREPROCESSOR_PARAMETER_MAP is None:
            raise RuntimeError("Preprocessor parameter map is not built. Check startup logs.")

        params_config = app_utils.PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])
        # Presumably mirrors the order the UI passes its controls in *args — confirm
        # against the UI builder if that layout changes.
        sliders_params = [p for p in params_config if p['type'] in ["INT", "FLOAT"]]
        dropdown_params = [p for p in params_config if isinstance(p['type'], list)]
        checkbox_params = [p for p in params_config if p['type'] == "BOOLEAN"]
        ordered_params_config = sliders_params + dropdown_params + checkbox_params
        param_names = [p['name'] for p in ordered_params_config]
        # Extra trailing args are ignored; a missing arg leaves that parameter to the node default.
        provided_params = dict(zip(param_names, args))

        if preprocessor_name not in CPU_ONLY_PREPROCESSORS:
            print(f"--- '{preprocessor_name}' requires GPU, requesting ZeroGPU. ---")
            try:
                processed_pil_images = self._execute_gpu_logic(
                    self._gpu_logic,
                    duration=zero_gpu_duration,
                    default_duration=60,
                    task_name=f"Preprocessor '{preprocessor_name}'",
                    pil_images=pil_images,
                    preprocessor_name=preprocessor_name,
                    model_name=model_name,
                    params=provided_params,
                    progress=progress,
                )
            except Exception as e:
                traceback.print_exc()
                raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on GPU: {e}") from e
        else:
            print(f"--- Running '{preprocessor_name}' on CPU, no ZeroGPU requested. ---")
            try:
                processed_pil_images = self._gpu_logic(
                    pil_images, preprocessor_name, model_name, provided_params, progress=progress
                )
            except Exception as e:
                traceback.print_exc()
                raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on CPU: {e}") from e

        if not processed_pil_images:
            raise gr.Error("Processing returned no frames.")

        progress(0.9, desc="Finalizing output...")

        if is_video:
            # Stack into an (N, H, W, C) float tensor in [0, 1] for the base-class encoder.
            frames_np = [np.array(img) for img in processed_pil_images]
            frames_tensor = torch.from_numpy(np.stack(frames_np)).to(torch.float32) / 255.0
            video_path = self._encode_video_from_frames(frames_tensor, fps, progress)
            return [video_path]

        progress(1.0, desc="Done!")
        return processed_pil_images