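"""ControlNet preprocessor pipeline.

Runs a ComfyUI ControlNet preprocessor node over a single image or over the
frames of a video, requesting ZeroGPU only for preprocessors that need a GPU.
"""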
from typing import Dict, Any, List
import imageio
import tempfile
import traceback
import numpy as np
import torch
import gradio as gr
from PIL import Image
import spaces
from .base_pipeline import BasePipeline
from comfy_integration.nodes import NODE_CLASS_MAPPINGS
from nodes import NODE_DISPLAY_NAME_MAPPINGS
from utils.app_utils import get_value_at_index
REVERSE_DISPLAY_NAME_MAP = None
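# Preprocessors cheap enough to run on CPU; no ZeroGPU allocation is requested
# for these (see run()).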
CPU_ONLY_PREPROCESSORS = {
"Binary Lines", "Canny Edge", "Color Pallete", "Fake Scribble Lines (aka scribble_hed)",
"Image Intensity", "Image Luminance", "Inpaint Preprocessor", "PyraCanny", "Scribble Lines",
"Scribble XDoG Lines", "Standard Lineart", "Content Shuffle", "Tile"
}
def run_node_by_function_name(node_instance: Any, **kwargs) -> Any:
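    """Invoke a ComfyUI node through the method named in its FUNCTION attribute.

    ComfyUI nodes declare their entry point as a class-level FUNCTION string
    rather than a fixed method name, so the call is dispatched dynamically.
    """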
node_class = type(node_instance)
function_name = getattr(node_class, 'FUNCTION', None)
if not function_name:
raise AttributeError(f"Node class '{node_class.__name__}' is missing the required 'FUNCTION' attribute.")
execution_method = getattr(node_instance, function_name, None)
if not callable(execution_method):
raise AttributeError(f"Method '{function_name}' not found or not callable on node '{node_class.__name__}'.")
return execution_method(**kwargs)
class ControlNetPreprocessorPipeline(BasePipeline):
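    """Pipeline that applies a ComfyUI ControlNet preprocessor node to images or video frames."""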
def get_required_models(self, **kwargs) -> List[str]:
return []
def _gpu_logic(
self, pil_images: List[Image.Image], preprocessor_name: str, model_name: str,
params: Dict[str, Any], progress=gr.Progress(track_tqdm=True)
) -> List[Image.Image]:
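        """Apply the selected preprocessor node to each frame and return the processed PIL images."""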
global REVERSE_DISPLAY_NAME_MAP
if REVERSE_DISPLAY_NAME_MAP is None:
raise RuntimeError("REVERSE_DISPLAY_NAME_MAP has not been initialized. `build_reverse_map` must be called on startup.")
class_name = REVERSE_DISPLAY_NAME_MAP.get(preprocessor_name)
if not class_name or class_name not in NODE_CLASS_MAPPINGS:
raise ValueError(f"Preprocessor '{preprocessor_name}' not found.")
preprocessor_instance = NODE_CLASS_MAPPINGS[class_name]()
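        # Every call receives the selected model name as 'ckpt_name' alongside the user-supplied parameters.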
call_args = {**params, 'ckpt_name': model_name}
processed_pil_images = []
total_frames = len(pil_images)
for i, frame_pil in enumerate(pil_images):
progress(i / total_frames, desc=f"Processing frame {i+1}/{total_frames} with {preprocessor_name}...")
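            # Convert HWC uint8 pixel data to the (1, H, W, C) float tensor in [0, 1] that ComfyUI nodes expect.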
frame_tensor = torch.from_numpy(np.array(frame_pil).astype(np.float32) / 255.0).unsqueeze(0)
            # frame_tensor is (1, H, W, C); use the larger spatial dimension as the target resolution.
            resolution_arg = {'resolution': max(frame_tensor.shape[1], frame_tensor.shape[2])}
result_tuple = run_node_by_function_name(
preprocessor_instance,
image=frame_tensor,
**resolution_arg,
**call_args
)
processed_tensor = get_value_at_index(result_tuple, 0)
processed_np = (processed_tensor.squeeze(0).cpu().numpy().clip(0, 1) * 255.0).astype(np.uint8)
processed_pil_images.append(Image.fromarray(processed_np))
return processed_pil_images
    def run(
        self, input_type, image_input, video_input, preprocessor_name, model_name,
        zero_gpu_duration, *args, progress=gr.Progress(track_tqdm=True)
    ):
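        """Gradio entry point: read the input, resolve preprocessor parameters from *args,
        run the preprocessor on ZeroGPU or CPU, and return processed images or an encoded video path.
        """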
from utils import app_utils
pil_images, is_video, fps = [], False, 30
progress(0, desc="Reading input file...")
if input_type == "Image":
if image_input is None: raise gr.Error("Please provide an input image.")
pil_images = [image_input]
elif input_type == "Video":
if video_input is None: raise gr.Error("Please provide an input video.")
try:
video_reader = imageio.get_reader(video_input)
meta = video_reader.get_meta_data()
fps = meta.get('fps', 30)
pil_images = [Image.fromarray(frame) for frame in video_reader]
is_video = True
video_reader.close()
except Exception as e: raise gr.Error(f"Failed to read video file: {e}")
else:
raise gr.Error("Invalid input type selected.")
if not pil_images: raise gr.Error("Could not extract any frames from the input.")
if app_utils.PREPROCESSOR_PARAMETER_MAP is None:
raise RuntimeError("Preprocessor parameter map is not built. Check startup logs.")
params_config = app_utils.PREPROCESSOR_PARAMETER_MAP.get(preprocessor_name, [])
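        # *args carries the values of the dynamically built UI controls. Regroup the parameter
        # definitions (sliders, then dropdowns, then checkboxes) — this ordering is assumed to
        # mirror the UI's control layout — and zip names to the incoming values.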
sliders_params = [p for p in params_config if p['type'] in ["INT", "FLOAT"]]
dropdown_params = [p for p in params_config if isinstance(p['type'], list)]
checkbox_params = [p for p in params_config if p['type'] == "BOOLEAN"]
ordered_params_config = sliders_params + dropdown_params + checkbox_params
param_names = [p['name'] for p in ordered_params_config]
        provided_params = dict(zip(param_names, args))
if preprocessor_name not in CPU_ONLY_PREPROCESSORS:
print(f"--- '{preprocessor_name}' requires GPU, requesting ZeroGPU. ---")
try:
processed_pil_images = self._execute_gpu_logic(
self._gpu_logic,
duration=zero_gpu_duration,
default_duration=60,
task_name=f"Preprocessor '{preprocessor_name}'",
pil_images=pil_images,
preprocessor_name=preprocessor_name,
model_name=model_name,
params=provided_params,
progress=progress
)
except Exception as e:
                traceback.print_exc()
raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on GPU: {e}")
else:
print(f"--- Running '{preprocessor_name}' on CPU, no ZeroGPU requested. ---")
try:
processed_pil_images = self._gpu_logic(pil_images, preprocessor_name, model_name, provided_params, progress=progress)
except Exception as e:
                traceback.print_exc()
raise gr.Error(f"Failed to run preprocessor '{preprocessor_name}' on CPU: {e}")
if not processed_pil_images: raise gr.Error("Processing returned no frames.")
progress(0.9, desc="Finalizing output...")
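        # For video input, restack the processed frames into an (N, H, W, C) float tensor and re-encode at the original fps.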
if is_video:
frames_np = [np.array(img) for img in processed_pil_images]
frames_tensor = torch.from_numpy(np.stack(frames_np)).to(torch.float32) / 255.0
video_path = self._encode_video_from_frames(frames_tensor, fps, progress)
return [video_path]
else:
progress(1.0, desc="Done!")
            return processed_pil_images