Spaces:
Sleeping
Sleeping
| from typing import Dict, Any, List | |
| import imageio | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| import spaces | |
| from .base_pipeline import BasePipeline | |
| from comfy_integration.nodes import NODE_CLASS_MAPPINGS | |
| from nodes import NODE_DISPLAY_NAME_MAPPINGS | |
| from utils.app_utils import get_value_at_index | |
# Lookup from preprocessor display name to node class name (used to index
# NODE_CLASS_MAPPINGS). It starts out unset; `build_reverse_map` is expected
# to populate it on startup.
REVERSE_DISPLAY_NAME_MAP = None

# Preprocessors that are run in-process without requesting ZeroGPU time.
# Entries are compared against the same `preprocessor_name` display strings
# used for the node lookup, so spelling must match exactly (e.g. "Color Pallete").
CPU_ONLY_PREPROCESSORS = {
    "Binary Lines",
    "Canny Edge",
    "Color Pallete",
    "Content Shuffle",
    "Fake Scribble Lines (aka scribble_hed)",
    "Image Intensity",
    "Image Luminance",
    "Inpaint Preprocessor",
    "PyraCanny",
    "Scribble Lines",
    "Scribble XDoG Lines",
    "Standard Lineart",
    "Tile",
}
def run_node_by_function_name(node_instance: Any, **kwargs) -> Any:
    """Call a node through the method named by its class-level ``FUNCTION`` attribute.

    Args:
        node_instance: An instance whose class defines ``FUNCTION`` as the
            name of the method to invoke.
        **kwargs: Forwarded unchanged to that method.

    Returns:
        Whatever the named method returns.

    Raises:
        AttributeError: The class has no (or an empty) ``FUNCTION`` attribute,
            or the named attribute is missing / not callable on the instance.
    """
    cls = type(node_instance)
    method_name = getattr(cls, 'FUNCTION', None)
    if not method_name:
        raise AttributeError(f"Node class '{cls.__name__}' is missing the required 'FUNCTION' attribute.")

    method = getattr(node_instance, method_name, None)
    if callable(method):
        return method(**kwargs)
    raise AttributeError(f"Method '{method_name}' not found or not callable on node '{cls.__name__}'.")
class ControlNetPreprocessorPipeline(BasePipeline):
    """Run a ControlNet preprocessor node over one image or every frame of a video.

    The preprocessor is selected by display name and resolved to a node class
    through ``REVERSE_DISPLAY_NAME_MAP`` and ``NODE_CLASS_MAPPINGS``. Names in
    ``CPU_ONLY_PREPROCESSORS`` are run directly; all others are routed through
    ``self._execute_gpu_logic`` (inherited from ``BasePipeline``) to request ZeroGPU.
    """

    def get_required_models(self, **kwargs) -> List[str]:
        """Return the models this pipeline declares up front (none)."""
        return []

    @staticmethod
    def _frame_to_tensor(frame) -> torch.Tensor:
        """Convert one frame to a float32 tensor scaled by 1/255 with a leading batch axis."""
        if isinstance(frame, Image.Image):
            # Normalise RGBA / palette / grayscale images to 3 channels so the
            # tensor is always (1, H, W, 3) for PIL input; a 2-D grayscale array
            # would otherwise break the shape indexing below.
            frame = frame.convert("RGB")
        return torch.from_numpy(np.array(frame).astype(np.float32) / 255.0).unsqueeze(0)

    def _gpu_logic(
        self, pil_images: List[Image.Image], preprocessor_name: str, model_name: str,
        params: Dict[str, Any], progress=gr.Progress(track_tqdm=True)
    ) -> List[Image.Image]:
        """Apply the named preprocessor to each frame and return the results as PIL images.

        Despite the name, ``run`` also calls this directly for CPU-only preprocessors.

        Args:
            pil_images: Frames to process, in order.
            preprocessor_name: Display name of the preprocessor node.
            model_name: Passed to the node as ``ckpt_name`` (overrides any
                ``ckpt_name`` in ``params``).
            params: Extra keyword arguments for the node. A ``resolution``
                entry here overrides the per-frame default.
            progress: Gradio progress tracker, updated once per frame.

        Returns:
            One processed PIL image per input frame (empty for empty input).

        Raises:
            RuntimeError: ``REVERSE_DISPLAY_NAME_MAP`` has not been built yet.
            ValueError: ``preprocessor_name`` does not resolve to a known node class.
        """
        if REVERSE_DISPLAY_NAME_MAP is None:
            raise RuntimeError("REVERSE_DISPLAY_NAME_MAP has not been initialized. `build_reverse_map` must be called on startup.")
        class_name = REVERSE_DISPLAY_NAME_MAP.get(preprocessor_name)
        if not class_name or class_name not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Preprocessor '{preprocessor_name}' not found.")
        preprocessor_instance = NODE_CLASS_MAPPINGS[class_name]()
        call_args = {**params, 'ckpt_name': model_name}

        processed_pil_images = []
        total_frames = len(pil_images)
        for i, frame in enumerate(pil_images):
            progress(i / total_frames, desc=f"Processing frame {i+1}/{total_frames} with {preprocessor_name}...")
            frame_tensor = self._frame_to_tensor(frame)
            # frame_tensor is (1, H, W, C); default the resolution to the longer
            # spatial side. Merging (rather than passing two ** dicts) lets a
            # caller-supplied 'resolution' win instead of raising TypeError.
            height, width = frame_tensor.shape[1], frame_tensor.shape[2]
            node_kwargs = {'resolution': max(height, width), **call_args}
            result_tuple = run_node_by_function_name(
                preprocessor_instance,
                image=frame_tensor,
                **node_kwargs
            )
            processed_tensor = get_value_at_index(result_tuple, 0)
            processed_np = (processed_tensor.squeeze(0).cpu().numpy().clip(0, 1) * 255.0).astype(np.uint8)
            processed_pil_images.append(Image.fromarray(processed_np))
        return processed_pil_images

    @staticmethod
    def _read_input_frames(input_type, image_input, video_input):
        """Collect the frames to process from the selected input.

        Returns:
            ``(frames, is_video, fps)``. ``fps`` is 30 for images and for
            videos whose metadata has no ``fps`` entry.

        Raises:
            gr.Error: Missing input, unreadable video, or unknown ``input_type``.
        """
        if input_type == "Image":
            if image_input is None:
                raise gr.Error("Please provide an input image.")
            return [image_input], False, 30
        if input_type == "Video":
            if video_input is None:
                raise gr.Error("Please provide an input video.")
            try:
                video_reader = imageio.get_reader(video_input)
                try:
                    fps = video_reader.get_meta_data().get('fps', 30)
                    frames = [Image.fromarray(frame) for frame in video_reader]
                finally:
                    # Release the decoder even when metadata or frame decoding fails.
                    video_reader.close()
            except Exception as e:
                raise gr.Error(f"Failed to read video file: {e}") from e
            return frames, True, fps
        raise gr.Error("Invalid input type selected.")

    @staticmethod
    def _collect_provided_params(preprocessor_name: str, args) -> Dict[str, Any]:
        """Pair positional widget values in ``args`` with the preprocessor's parameter names.

        Names are ordered sliders (INT/FLOAT), then dropdowns (list-typed),
        then checkboxes (BOOLEAN); ``args`` is presumably built by the UI in
        that same order — confirm against the UI code.

        Raises:
            RuntimeError: ``app_utils.PREPROCESSOR_PARAMETER_MAP`` is not built.
        """
        from utils import app_utils

        if app_utils.PREPROCESSOR_PARAMETER_MAP is None:
            raise RuntimeError("Preprocessor parameter map is not built. Check startup logs.")
        params_config = app_utils.PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])
        sliders_params = [p for p in params_config if p['type'] in ["INT", "FLOAT"]]
        dropdown_params = [p for p in params_config if isinstance(p['type'], list)]
        checkbox_params = [p for p in params_config if p['type'] == "BOOLEAN"]
        ordered_params_config = sliders_params + dropdown_params + checkbox_params
        # zip stops at the shorter side: surplus values are ignored, and
        # parameters without a value are simply not passed to the node.
        return {p['name']: value for p, value in zip(ordered_params_config, args)}

    def run(self, input_type, image_input, video_input, preprocessor_name, model_name, zero_gpu_duration, *args, progress=gr.Progress(track_tqdm=True)):
        """Preprocess an image or video with the selected ControlNet preprocessor.

        Args:
            input_type: "Image" or "Video".
            image_input: The input image when ``input_type == "Image"``.
            video_input: Path/URI accepted by ``imageio.get_reader`` when
                ``input_type == "Video"``.
            preprocessor_name: Display name of the preprocessor.
            model_name: Forwarded to the node as ``ckpt_name``.
            zero_gpu_duration: Requested duration passed to ``_execute_gpu_logic``
                for GPU preprocessors.
            *args: Positional widget values for the preprocessor's parameters.
            progress: Gradio progress tracker.

        Returns:
            For images, the list of processed PIL images. For videos, a
            one-element list holding the result of ``_encode_video_from_frames``.

        Raises:
            gr.Error: Invalid/missing input, read failure, processing failure,
                or no output frames.
            RuntimeError: The preprocessor parameter map has not been built.
        """
        import traceback

        progress(0, desc="Reading input file...")
        pil_images, is_video, fps = self._read_input_frames(input_type, image_input, video_input)
        if not pil_images:
            raise gr.Error("Could not extract any frames from the input.")

        provided_params = self._collect_provided_params(preprocessor_name, args)

        use_gpu = preprocessor_name not in CPU_ONLY_PREPROCESSORS
        device = "GPU" if use_gpu else "CPU"
        if use_gpu:
            print(f"--- '{preprocessor_name}' requires GPU, requesting ZeroGPU. ---")
        else:
            print(f"--- Running '{preprocessor_name}' on CPU, no ZeroGPU requested. ---")
        try:
            if use_gpu:
                processed_pil_images = self._execute_gpu_logic(
                    self._gpu_logic,
                    duration=zero_gpu_duration,
                    default_duration=60,
                    task_name=f"Preprocessor '{preprocessor_name}'",
                    pil_images=pil_images,
                    preprocessor_name=preprocessor_name,
                    model_name=model_name,
                    params=provided_params,
                    progress=progress
                )
            else:
                processed_pil_images = self._gpu_logic(pil_images, preprocessor_name, model_name, provided_params, progress=progress)
        except Exception as e:
            traceback.print_exc()
            raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on {device}: {e}") from e

        if not processed_pil_images:
            raise gr.Error("Processing returned no frames.")
        progress(0.9, desc="Finalizing output...")
        if is_video:
            frames_np = np.stack([np.array(img) for img in processed_pil_images])
            frames_tensor = torch.from_numpy(frames_np).to(torch.float32) / 255.0
            video_path = self._encode_video_from_frames(frames_tensor, fps, progress)
            return [video_path]
        progress(1.0, desc="Done!")
        return processed_pil_images