# ImageGen-SD15 / core/pipelines/controlnet_preprocessor.py
# Uploaded by RioShiina using huggingface_hub (commit d6c92ff, verified).
# Original file size: 7.02 kB.
from typing import Dict, Any, List
import imageio
import tempfile
import numpy as np
import torch
import gradio as gr
from PIL import Image
import spaces
from .base_pipeline import BasePipeline
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
from nodes import NODE_DISPLAY_NAME_MAPPINGS
from utils.app_utils import get_value_at_index
# Lookup from preprocessor name to node class name, read by
# ControlNetPreprocessorPipeline._gpu_logic. It starts unset; _gpu_logic
# raises RuntimeError until something rebinds it (its error message names
# `build_reverse_map` as the initializer).
REVERSE_DISPLAY_NAME_MAP = None

# Preprocessor names that are executed in-process without requesting
# ZeroGPU. Entries are matched verbatim against `preprocessor_name`, so the
# spelling "Color Pallete" must be kept exactly as-is.
CPU_ONLY_PREPROCESSORS = {
    "Binary Lines",
    "Canny Edge",
    "Color Pallete",
    "Content Shuffle",
    "Fake Scribble Lines (aka scribble_hed)",
    "Image Intensity",
    "Image Luminance",
    "Inpaint Preprocessor",
    "PyraCanny",
    "Scribble Lines",
    "Scribble XDoG Lines",
    "Standard Lineart",
    "Tile",
}
def run_node_by_function_name(node_instance: Any, **kwargs) -> Any:
    """Invoke the method named by the node class's ``FUNCTION`` attribute.

    Args:
        node_instance: An instantiated node object.
        **kwargs: Keyword arguments forwarded unchanged to that method.

    Returns:
        Whatever the named method returns.

    Raises:
        AttributeError: If the class has no (or a falsy) ``FUNCTION``, or the
            attribute it names is missing or not callable on the instance.
    """
    cls = type(node_instance)
    entry_point = getattr(cls, 'FUNCTION', None)
    if entry_point:
        method = getattr(node_instance, entry_point, None)
        if callable(method):
            return method(**kwargs)
        raise AttributeError(f"Method '{entry_point}' not found or not callable on node '{cls.__name__}'.")
    raise AttributeError(f"Node class '{cls.__name__}' is missing the required 'FUNCTION' attribute.")
class ControlNetPreprocessorPipeline(BasePipeline):
    """Apply a ComfyUI ControlNet preprocessor node to an image or to every video frame.

    Preprocessors listed in ``CPU_ONLY_PREPROCESSORS`` run directly in this
    process; all others are dispatched through ``self._execute_gpu_logic``,
    which requests a ZeroGPU allocation.
    """

    def get_required_models(self, **kwargs) -> List[str]:
        """Return the models this pipeline needs up front (none)."""
        return []

    def _gpu_logic(
        self, pil_images: List[Image.Image], preprocessor_name: str, model_name: str,
        params: Dict[str, Any], progress=gr.Progress(track_tqdm=True)
    ) -> List[Image.Image]:
        """Run the named preprocessor node over each frame, in order.

        Despite the name, this is also the code path used for CPU-only
        preprocessors; ``run`` decides whether to wrap it in a GPU request.

        Args:
            pil_images: Frames to process.
            preprocessor_name: Name resolved to a node class through
                ``REVERSE_DISPLAY_NAME_MAP``.
            model_name: Forwarded to the node as ``ckpt_name``.
            params: Extra keyword arguments for the node's FUNCTION method.
            progress: Gradio progress tracker, advanced once per frame.

        Returns:
            One processed PIL image per input frame.

        Raises:
            RuntimeError: If ``REVERSE_DISPLAY_NAME_MAP`` was never initialized.
            ValueError: If the preprocessor does not resolve to a known node class.
        """
        if REVERSE_DISPLAY_NAME_MAP is None:
            raise RuntimeError("REVERSE_DISPLAY_NAME_MAP has not been initialized. `build_reverse_map` must be called on startup.")
        class_name = REVERSE_DISPLAY_NAME_MAP.get(preprocessor_name)
        if not class_name or class_name not in NODE_CLASS_MAPPINGS:
            raise ValueError(f"Preprocessor '{preprocessor_name}' not found.")
        preprocessor_instance = NODE_CLASS_MAPPINGS[class_name]()
        # The selected checkpoint always wins over a same-named entry in params.
        call_args = {**params, 'ckpt_name': model_name}

        processed_pil_images = []
        total_frames = len(pil_images)
        for i, frame_pil in enumerate(pil_images):
            progress(i / total_frames, desc=f"Processing frame {i+1}/{total_frames} with {preprocessor_name}...")
            # Grayscale/palette frames would yield a tensor with no channel axis
            # and RGBA a 4th channel; normalize to 3-channel RGB first.
            if isinstance(frame_pil, Image.Image) and frame_pil.mode != "RGB":
                frame_pil = frame_pil.convert("RGB")
            # (1, H, W, 3) float32 scaled to [0, 1].
            frame_tensor = torch.from_numpy(np.array(frame_pil).astype(np.float32) / 255.0).unsqueeze(0)
            # Axes are (batch, height, width, channels); the old code indexed
            # shape[2]/shape[3] (width/channels) and so used the width, not the
            # longer side. NOTE(review): confirm max side is the intended resolution.
            height, width = frame_tensor.shape[1], frame_tensor.shape[2]
            # One merged dict: a 'resolution' in params overrides the default
            # instead of raising "multiple values for keyword argument";
            # 'image' is always the current frame.
            node_kwargs = {'resolution': max(height, width), **call_args, 'image': frame_tensor}
            result_tuple = run_node_by_function_name(preprocessor_instance, **node_kwargs)
            processed_tensor = get_value_at_index(result_tuple, 0)
            processed_np = (processed_tensor.squeeze(0).cpu().numpy().clip(0, 1) * 255.0).astype(np.uint8)
            processed_pil_images.append(Image.fromarray(processed_np))
        return processed_pil_images

    @staticmethod
    def _load_frames(input_type, image_input, video_input):
        """Return ``(pil_images, is_video, fps)`` for the selected input.

        Raises:
            gr.Error: If the required input is missing, the video cannot be
                read, or ``input_type`` is neither "Image" nor "Video".
        """
        if input_type == "Image":
            if image_input is None:
                raise gr.Error("Please provide an input image.")
            return [image_input], False, 30
        if input_type == "Video":
            if video_input is None:
                raise gr.Error("Please provide an input video.")
            from contextlib import closing
            try:
                # closing() guarantees the reader is released even if decoding fails.
                with closing(imageio.get_reader(video_input)) as video_reader:
                    fps = video_reader.get_meta_data().get('fps', 30)
                    frames = [Image.fromarray(frame) for frame in video_reader]
            except Exception as e:
                raise gr.Error(f"Failed to read video file: {e}") from e
            return frames, True, fps
        raise gr.Error("Invalid input type selected.")

    @staticmethod
    def _collect_params(preprocessor_name, args):
        """Map the positional UI values in ``args`` to the preprocessor's parameter names.

        Parameters are ordered sliders (INT/FLOAT), then dropdowns (list-typed),
        then checkboxes (BOOLEAN); ``args`` is assumed to arrive in that same
        order (presumably how the UI builds its widgets -- verify against caller).
        Extra trailing values are ignored.

        Raises:
            RuntimeError: If the parameter map was never built.
        """
        from utils import app_utils
        if app_utils.PREPROCESSOR_PARAMETER_MAP is None:
            raise RuntimeError("Preprocessor parameter map is not built. Check startup logs.")
        params_config = app_utils.PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])
        sliders_params = [p for p in params_config if p['type'] in ["INT", "FLOAT"]]
        dropdown_params = [p for p in params_config if isinstance(p['type'], list)]
        checkbox_params = [p for p in params_config if p['type'] == "BOOLEAN"]
        ordered_params_config = sliders_params + dropdown_params + checkbox_params
        return dict(zip((p['name'] for p in ordered_params_config), args))

    def run(self, input_type, image_input, video_input, preprocessor_name, model_name, zero_gpu_duration, *args, progress=gr.Progress(track_tqdm=True)):
        """Preprocess an image or video with the selected ControlNet preprocessor.

        Args:
            input_type: "Image" or "Video".
            image_input: Input image, used when ``input_type == "Image"``.
            video_input: Video path/URI for ``imageio.get_reader``, used when
                ``input_type == "Video"``.
            preprocessor_name: Name of the preprocessor to run.
            model_name: Checkpoint forwarded to the node as ``ckpt_name``.
            zero_gpu_duration: Requested GPU duration, forwarded to
                ``_execute_gpu_logic`` (which is given a default of 60).
            *args: Positional UI parameter values (see ``_collect_params``).
            progress: Gradio progress tracker.

        Returns:
            ``[video_path]`` for video input, otherwise the list of processed images.

        Raises:
            gr.Error: On invalid input or any failure while preprocessing.
            RuntimeError: If the preprocessor parameter map was never built.
        """
        progress(0, desc="Reading input file...")
        pil_images, is_video, fps = self._load_frames(input_type, image_input, video_input)
        if not pil_images:
            raise gr.Error("Could not extract any frames from the input.")
        provided_params = self._collect_params(preprocessor_name, args)

        on_gpu = preprocessor_name not in CPU_ONLY_PREPROCESSORS
        if on_gpu:
            print(f"--- '{preprocessor_name}' requires GPU, requesting ZeroGPU. ---")
        else:
            print(f"--- Running '{preprocessor_name}' on CPU, no ZeroGPU requested. ---")
        try:
            if on_gpu:
                processed_pil_images = self._execute_gpu_logic(
                    self._gpu_logic,
                    duration=zero_gpu_duration,
                    default_duration=60,
                    task_name=f"Preprocessor '{preprocessor_name}'",
                    pil_images=pil_images,
                    preprocessor_name=preprocessor_name,
                    model_name=model_name,
                    params=provided_params,
                    progress=progress
                )
            else:
                processed_pil_images = self._gpu_logic(pil_images, preprocessor_name, model_name, provided_params, progress=progress)
        except Exception as e:
            import traceback
            traceback.print_exc()
            device = "GPU" if on_gpu else "CPU"
            raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on {device}: {e}") from e

        if not processed_pil_images:
            raise gr.Error("Processing returned no frames.")
        progress(0.9, desc="Finalizing output...")
        if is_video:
            frames_np = [np.array(img) for img in processed_pil_images]
            frames_tensor = torch.from_numpy(np.stack(frames_np)).to(torch.float32) / 255.0
            video_path = self._encode_video_from_frames(frames_tensor, fps, progress)
            return [video_path]
        progress(1.0, desc="Done!")
        return processed_pil_images