import os
import random
import shutil
import torch
import gradio as gr
from PIL import Image, ImageChops
from typing import List, Dict, Any
from collections import defaultdict, deque
import numpy as np

from .base_pipeline import BasePipeline
from core.settings import *
from comfy_integration.nodes import *
from utils.app_utils import (
    get_value_at_index,
    sanitize_prompt,
    get_lora_path,
    get_embedding_path,
    ensure_controlnet_model_downloaded,
    ensure_ipadapter_models_downloaded,
    sanitize_filename,
)
from core.workflow_assembler import WorkflowAssembler


class SdImagePipeline(BasePipeline):
    def get_required_models(self, model_display_name: str, **kwargs) -> List[str]:
        return [model_display_name]
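
    # Kahn's algorithm (BFS) topological sort over the workflow graph. A workflow maps
    # node IDs to {"class_type": ..., "inputs": {...}} dicts in ComfyUI's API format,
    # where a link to another node's output is encoded as a two-element list
    # ["<source_node_id>", <output_index>]. An illustrative (not literal) entry:
    #   "3": {"class_type": "KSampler", "inputs": {"model": ["4", 0], "seed": 42}}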
    def _topological_sort(self, workflow: Dict[str, Any]) -> List[str]:
        graph = defaultdict(list)
        in_degree = {node_id: 0 for node_id in workflow}
        for node_id, node_info in workflow.items():
            for input_value in node_info.get('inputs', {}).values():
                if isinstance(input_value, list) and len(input_value) == 2 and isinstance(input_value[0], str):
                    source_node_id = input_value[0]
                    if source_node_id in workflow:
                        graph[source_node_id].append(node_id)
                        in_degree[node_id] += 1
        queue = deque([node_id for node_id, degree in in_degree.items() if degree == 0])
        sorted_nodes = []
        while queue:
            current_node_id = queue.popleft()
            sorted_nodes.append(current_node_id)
            for neighbor_node_id in graph[current_node_id]:
                in_degree[neighbor_node_id] -= 1
                if in_degree[neighbor_node_id] == 0:
                    queue.append(neighbor_node_id)
        if len(sorted_nodes) != len(workflow):
            raise RuntimeError("Workflow contains a cycle and cannot be executed.")
        return sorted_nodes
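
    # In-process workflow executor: instantiates each node class in topological order,
    # resolves link inputs from previously computed outputs, and returns the decoded
    # image tensor that feeds the final 'SaveImage' node.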
    def _execute_workflow(self, workflow: Dict[str, Any], initial_objects: Dict[str, Any]):
        with torch.no_grad():
            computed_outputs = initial_objects
            try:
                sorted_node_ids = self._topological_sort(workflow)
                print(f"--- [Workflow Executor] Execution order: {sorted_node_ids}")
            except RuntimeError as e:
                print("--- [Workflow Executor] ERROR: Failed to sort workflow. Dumping graph details. ---")
                for node_id, node_info in workflow.items():
                    print(f" Node {node_id} ({node_info['class_type']}):")
                    for input_name, input_value in node_info['inputs'].items():
                        if isinstance(input_value, list) and len(input_value) == 2 and isinstance(input_value[0], str):
                            print(f" - {input_name} <- [{input_value[0]}, {input_value[1]}]")
                raise e
            for node_id in sorted_node_ids:
                if node_id in computed_outputs:
                    continue
                node_info = workflow[node_id]
                class_type = node_info['class_type']
                node_class = NODE_CLASS_MAPPINGS.get(class_type)
                if node_class is None:
                    raise RuntimeError(f"Could not find node class '{class_type}'. Is it imported in comfy_integration/nodes.py?")
                node_instance = node_class()
                kwargs = {}
                for param_name, param_value in node_info['inputs'].items():
                    if isinstance(param_value, list) and len(param_value) == 2 and isinstance(param_value[0], str):
                        source_node_id, output_index = param_value
                        if source_node_id not in computed_outputs:
                            raise RuntimeError(f"Workflow integrity error: Output of node {source_node_id} needed for {node_id} but not yet computed.")
                        source_output_tuple = computed_outputs[source_node_id]
                        kwargs[param_name] = get_value_at_index(source_output_tuple, output_index)
                    else:
                        kwargs[param_name] = param_value
                function_name = getattr(node_class, 'FUNCTION')
                execution_method = getattr(node_instance, function_name)
                result = execution_method(**kwargs)
                computed_outputs[node_id] = result
            final_node_id = None
            for node_id in reversed(sorted_node_ids):
                if workflow[node_id]['class_type'] == 'SaveImage':
                    final_node_id = node_id
                    break
            if not final_node_id:
                raise RuntimeError("Workflow does not contain a 'SaveImage' node as the output.")
            save_image_inputs = workflow[final_node_id]['inputs']
            image_source_node_id, image_source_index = save_image_inputs['images']
            return get_value_at_index(computed_outputs[image_source_node_id], image_source_index)
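
    # GPU phase: moves the preloaded models onto the allocated device, injects the
    # checkpoint tuple as the precomputed output of the 'ckpt_loader' node, executes
    # the workflow, and converts the resulting tensor batch to PIL images carrying
    # A1111-style 'parameters' metadata.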
    def _gpu_logic(self, ui_inputs: Dict, loras_string: str, required_models_for_gpu: List[str],
                   workflow: Dict[str, Any], assembler: WorkflowAssembler,
                   progress=gr.Progress(track_tqdm=True)):
        model_display_name = ui_inputs['model_display_name']
        progress(0.1, desc="Moving models to GPU...")
        self.model_manager.move_models_to_gpu(required_models_for_gpu)
        progress(0.4, desc="Executing workflow...")
        loaded_model_tuple = self.model_manager.loaded_models[model_display_name]
        ckpt_loader_node_id = assembler.node_map.get("ckpt_loader")
        if not ckpt_loader_node_id:
            raise RuntimeError("Workflow is missing the 'ckpt_loader' node required for model injection.")
        initial_objects = {ckpt_loader_node_id: loaded_model_tuple}
        decoded_images_tensor = self._execute_workflow(workflow, initial_objects=initial_objects)
        output_images = []
        start_seed = ui_inputs['seed'] if ui_inputs['seed'] != -1 else random.randint(0, 2**64 - 1)
        for i in range(decoded_images_tensor.shape[0]):
            img_tensor = decoded_images_tensor[i]
            pil_image = Image.fromarray((img_tensor.cpu().numpy() * 255.0).astype("uint8"))
            current_seed = start_seed + i
            width_for_meta = ui_inputs.get('width', 'N/A')
            height_for_meta = ui_inputs.get('height', 'N/A')
            params_string = f"{ui_inputs['positive_prompt']}\nNegative prompt: {ui_inputs['negative_prompt']}\n"
            params_string += (
                f"Steps: {ui_inputs['num_inference_steps']}, Sampler: {ui_inputs['sampler']}, "
                f"Scheduler: {ui_inputs['scheduler']}, CFG scale: {ui_inputs['guidance_scale']}, "
                f"Seed: {current_seed}, Size: {width_for_meta}x{height_for_meta}, "
                f"Base Model: {model_display_name}"
            )
            if ui_inputs['task_type'] != 'txt2img': params_string += f", Denoise: {ui_inputs['denoise']}"
            if loras_string: params_string += f", {loras_string}"
            pil_image.info = {'parameters': params_string.strip()}
            output_images.append(pil_image)
        return output_images
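
    # CPU-side orchestration: sanitizes prompts, downloads/stages all required assets
    # (checkpoint, LoRAs, embeddings, ControlNets, IP-Adapters, VAE), assembles the
    # ComfyUI workflow from the unified recipe, then hands off to _gpu_logic under a
    # ZeroGPU allocation and cleans up temp files afterwards.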
    def run(self, ui_inputs: Dict, progress):
        progress(0, desc="Preparing models...")
        task_type = ui_inputs['task_type']
        ui_inputs['positive_prompt'] = sanitize_prompt(ui_inputs.get('positive_prompt', ''))
        ui_inputs['negative_prompt'] = sanitize_prompt(ui_inputs.get('negative_prompt', ''))
        required_models = self.get_required_models(model_display_name=ui_inputs['model_display_name'])
        self.model_manager.ensure_models_downloaded(required_models, progress=progress)
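
        # lora_data arrives from the UI as a flat list of repeating 4-tuples:
        # [source, lora_id, scale, file, source, lora_id, scale, file, ...],
        # hence the stride-4 slices below.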
        lora_data = ui_inputs.get('lora_data', [])
        active_loras_for_gpu, active_loras_for_meta = [], []
        sources, ids, scales, files = lora_data[0::4], lora_data[1::4], lora_data[2::4], lora_data[3::4]
        for source, lora_id, scale, _ in zip(sources, ids, scales, files):
            if scale > 0 and lora_id and lora_id.strip():
                lora_filename = None
                if source == "File":
                    lora_filename = sanitize_filename(lora_id)
                elif source == "Civitai":
                    local_path, status = get_lora_path(source, lora_id, ui_inputs['civitai_api_key'], progress)
                    if local_path: lora_filename = os.path.basename(local_path)
                    else: raise gr.Error(f"Failed to prepare LoRA {lora_id}: {status}")
                if lora_filename:
                    active_loras_for_gpu.append({"lora_name": lora_filename, "strength_model": scale, "strength_clip": scale})
                    active_loras_for_meta.append(f"{source} {lora_id}:{scale}")
        progress(0.1, desc="Loading models into RAM...")
        self.model_manager.load_managed_models(required_models, active_loras=active_loras_for_gpu, progress=progress)
        ui_inputs['denoise'] = 1.0
        if task_type == 'img2img': ui_inputs['denoise'] = ui_inputs.get('img2img_denoise', 0.7)
        elif task_type == 'hires_fix': ui_inputs['denoise'] = ui_inputs.get('hires_denoise', 0.55)
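
        # Stage task-specific input images as temp PNGs inside the input directory;
        # downstream workflow nodes reference them by basename.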
        temp_files_to_clean = []
        if not os.path.exists(INPUT_DIR): os.makedirs(INPUT_DIR)
        if task_type == 'img2img':
            input_image_pil = ui_inputs.get('img2img_image')
            if input_image_pil:
                temp_file_path = os.path.join(INPUT_DIR, f"temp_input_{random.randint(1000, 9999)}.png")
                input_image_pil.save(temp_file_path, "PNG")
                ui_inputs['input_image'] = os.path.basename(temp_file_path)
                temp_files_to_clean.append(temp_file_path)
                ui_inputs['width'] = input_image_pil.width
                ui_inputs['height'] = input_image_pil.height
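
        # For inpainting, the editor mask arrives as RGBA layers: their alpha channels
        # are merged, inverted, and baked into the composite image's alpha channel
        # (downstream nodes are expected to recover the inpaint mask from that alpha).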
        elif task_type == 'inpaint':
            inpaint_dict = ui_inputs.get('inpaint_image_dict')
            if not inpaint_dict or not inpaint_dict.get('background') or not inpaint_dict.get('layers'):
                raise gr.Error("Inpainting requires an input image and a drawn mask.")
            background_img = inpaint_dict['background'].convert("RGBA")
            composite_mask_pil = Image.new('L', background_img.size, 0)
            for layer in inpaint_dict['layers']:
                if layer:
                    layer_alpha = layer.split()[-1]
                    composite_mask_pil = ImageChops.lighter(composite_mask_pil, layer_alpha)
            inverted_mask_alpha = Image.fromarray(255 - np.array(composite_mask_pil), mode='L')
            r, g, b, _ = background_img.split()
            composite_image_with_mask = Image.merge('RGBA', [r, g, b, inverted_mask_alpha])
            temp_file_path = os.path.join(INPUT_DIR, f"temp_inpaint_composite_{random.randint(1000, 9999)}.png")
            composite_image_with_mask.save(temp_file_path, "PNG")
            ui_inputs['inpaint_image'] = os.path.basename(temp_file_path)
            temp_files_to_clean.append(temp_file_path)
            ui_inputs.pop('inpaint_mask', None)
        elif task_type == 'outpaint':
            input_image_pil = ui_inputs.get('outpaint_image')
            if input_image_pil:
                temp_file_path = os.path.join(INPUT_DIR, f"temp_input_{random.randint(1000, 9999)}.png")
                input_image_pil.save(temp_file_path, "PNG")
                ui_inputs['input_image'] = os.path.basename(temp_file_path)
                temp_files_to_clean.append(temp_file_path)
        elif task_type == 'hires_fix':
            input_image_pil = ui_inputs.get('hires_image')
            if input_image_pil:
                temp_file_path = os.path.join(INPUT_DIR, f"temp_input_{random.randint(1000, 9999)}.png")
                input_image_pil.save(temp_file_path, "PNG")
                ui_inputs['input_image'] = os.path.basename(temp_file_path)
                temp_files_to_clean.append(temp_file_path)
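
        # Textual-inversion embeddings: stage the files, then reference them in the
        # positive prompt via the "embedding:<filename>" syntax.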
        embedding_data = ui_inputs.get('embedding_data', [])
        embedding_filenames = []
        if embedding_data:
            emb_sources, emb_ids, emb_files = embedding_data[0::3], embedding_data[1::3], embedding_data[2::3]
            for source, emb_id, _ in zip(emb_sources, emb_ids, emb_files):
                if emb_id and emb_id.strip():
                    emb_filename = None
                    if source == "File":
                        emb_filename = sanitize_filename(emb_id)
                    elif source == "Civitai":
                        local_path, status = get_embedding_path(source, emb_id, ui_inputs['civitai_api_key'], progress)
                        if local_path: emb_filename = os.path.basename(local_path)
                        else: raise gr.Error(f"Failed to prepare Embedding {emb_id}: {status}")
                    if emb_filename:
                        embedding_filenames.append(emb_filename)
        if embedding_filenames:
            embedding_prompt_text = " ".join([f"embedding:{f}" for f in embedding_filenames])
            if ui_inputs['positive_prompt']:
                ui_inputs['positive_prompt'] = f"{ui_inputs['positive_prompt']}, {embedding_prompt_text}"
            else:
                ui_inputs['positive_prompt'] = embedding_prompt_text
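
        # controlnet_data is a flat list of repeating 5-tuples per ControlNet unit:
        # [image, <unused here>, <unused here>, strength, model_filepath, ...],
        # hence the stride-5 slices below.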
        controlnet_data = ui_inputs.get('controlnet_data', [])
        active_controlnets = []
        cn_images, _, _, cn_strengths, cn_filepaths = [controlnet_data[i::5] for i in range(5)]
        for i in range(len(cn_images)):
            if cn_images[i] and cn_strengths[i] > 0 and cn_filepaths[i] and cn_filepaths[i] != "None":
                ensure_controlnet_model_downloaded(cn_filepaths[i], progress)
                if not os.path.exists(INPUT_DIR): os.makedirs(INPUT_DIR)
                cn_temp_path = os.path.join(INPUT_DIR, f"temp_cn_{i}_{random.randint(1000, 9999)}.png")
                cn_images[i].save(cn_temp_path, "PNG")
                temp_files_to_clean.append(cn_temp_path)
                active_controlnets.append({
                    "image": os.path.basename(cn_temp_path), "strength": cn_strengths[i],
                    "start_percent": 0.0, "end_percent": 1.0, "control_net_name": cn_filepaths[i]
                })
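
        # ipadapter_data packs three per-unit column blocks (images, weights,
        # lora_strengths) followed by 5 trailing global settings, so the number of
        # units is (len - 5) // 3.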
        ipadapter_data = ui_inputs.get('ipadapter_data', [])
        active_ipadapters = []
        if ipadapter_data:
            num_ipa_units = (len(ipadapter_data) - 5) // 3
            final_preset, final_weight, final_lora_strength, final_embeds_scaling, final_combine_method = ipadapter_data[-5:]
            ipa_images, ipa_weights, ipa_lora_strengths = [ipadapter_data[i*num_ipa_units:(i+1)*num_ipa_units] for i in range(3)]
            all_presets_to_download = set()
            for i in range(num_ipa_units):
                if ipa_images[i] and ipa_weights[i] > 0 and final_preset:
                    all_presets_to_download.add(final_preset)
                    if not os.path.exists(INPUT_DIR): os.makedirs(INPUT_DIR)
                    ipa_temp_path = os.path.join(INPUT_DIR, f"temp_ipa_{i}_{random.randint(1000, 9999)}.png")
                    ipa_images[i].save(ipa_temp_path, "PNG")
                    temp_files_to_clean.append(ipa_temp_path)
                    active_ipadapters.append({
                        "image": os.path.basename(ipa_temp_path), "preset": final_preset,
                        "weight": ipa_weights[i], "lora_strength": ipa_lora_strengths[i]
                    })
            if active_ipadapters and final_preset:
                all_presets_to_download.add(final_preset)
            for preset in all_presets_to_download:
                ensure_ipadapter_models_downloaded(preset, progress)
            if active_ipadapters:
                active_ipadapters.append({
                    'is_final_settings': True, 'model_type': 'sd15', 'final_preset': final_preset,
                    'final_weight': final_weight, 'final_lora_strength': final_lora_strength,
                    'final_embeds_scaling': final_embeds_scaling, 'final_combine_method': final_combine_method
                })
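
        # Optional VAE override: resolve a local file or a Civitai download and pass
        # its filename through to the post-assembly patch below.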
        from utils.app_utils import get_vae_path
        vae_source = ui_inputs.get('vae_source')
        vae_id = ui_inputs.get('vae_id')
        vae_file = ui_inputs.get('vae_file')
        vae_name_override = None
        if vae_source and vae_source != "None":
            if vae_source == "File":
                vae_name_override = sanitize_filename(vae_id)
            elif vae_source == "Civitai" and vae_id and vae_id.strip():
                local_path, status = get_vae_path(vae_source, vae_id, ui_inputs.get('civitai_api_key'), progress)
                if local_path: vae_name_override = os.path.basename(local_path)
                else: raise gr.Error(f"Failed to prepare VAE {vae_id}: {status}")
        if vae_name_override:
            ui_inputs['vae_name'] = vae_name_override
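
        # Regional prompting: conditioning_data is six equal-length column blocks
        # (prompts, widths, heights, x, y, strengths), one entry per region.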
        conditioning_data = ui_inputs.get('conditioning_data', [])
        active_conditioning = []
        if conditioning_data:
            num_units = len(conditioning_data) // 6
            prompts = conditioning_data[0*num_units : 1*num_units]
            widths = conditioning_data[1*num_units : 2*num_units]
            heights = conditioning_data[2*num_units : 3*num_units]
            xs = conditioning_data[3*num_units : 4*num_units]
            ys = conditioning_data[4*num_units : 5*num_units]
            strengths = conditioning_data[5*num_units : 6*num_units]
            for i in range(num_units):
                if prompts[i] and prompts[i].strip():
                    active_conditioning.append({
                        "prompt": prompts[i],
                        "width": int(widths[i]),
                        "height": int(heights[i]),
                        "x": int(xs[i]),
                        "y": int(ys[i]),
                        "strength": float(strengths[i])
                    })
        loras_string = f"LoRAs: [{', '.join(active_loras_for_meta)}]" if active_loras_for_meta else ""
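
        # Assemble the workflow graph from the unified YAML recipe; the recipe is
        # parameterized by task_type and model_type via dynamic_values.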
        progress(0.8, desc="Assembling workflow...")
        if ui_inputs.get('seed') == -1:
            ui_inputs['seed'] = random.randint(0, 2**32 - 1)
        dynamic_values = {'task_type': ui_inputs['task_type'], 'model_type': "sd15"}
        recipe_path = os.path.join(os.path.dirname(__file__), "workflow_recipes", "sd_unified_recipe.yaml")
        assembler = WorkflowAssembler(recipe_path, dynamic_values=dynamic_values)
        workflow_inputs = {
            "positive_prompt": ui_inputs['positive_prompt'], "negative_prompt": ui_inputs['negative_prompt'],
            "seed": ui_inputs['seed'], "steps": ui_inputs['num_inference_steps'], "cfg": ui_inputs['guidance_scale'],
            "sampler_name": ui_inputs['sampler'], "scheduler": ui_inputs['scheduler'],
            "batch_size": ui_inputs['batch_size'],
            "clip_skip": -int(ui_inputs['clip_skip']),
            "denoise": ui_inputs['denoise'],
            "input_image": ui_inputs.get('input_image'),
            "inpaint_image": ui_inputs.get('inpaint_image'),
            "inpaint_mask": ui_inputs.get('inpaint_mask'),
            "left": ui_inputs.get('outpaint_left'), "top": ui_inputs.get('outpaint_top'),
            "right": ui_inputs.get('outpaint_right'), "bottom": ui_inputs.get('outpaint_bottom'),
            "hires_upscaler": ui_inputs.get('hires_upscaler'), "hires_scale_by": ui_inputs.get('hires_scale_by'),
            "model_name": ALL_MODEL_MAP[ui_inputs['model_display_name']][1],
            "vae_name": ui_inputs.get('vae_name'),
            "controlnet_chain": active_controlnets,
            "ipadapter_chain": active_ipadapters,
            "conditioning_chain": active_conditioning,
        }
        if task_type == 'txt2img':
            workflow_inputs['width'] = ui_inputs['width']
            workflow_inputs['height'] = ui_inputs['height']
        workflow = assembler.assemble(workflow_inputs)
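
        # Post-assembly patch: when a VAE override is set, append a VAELoader node and
        # rewire the recipe's vae_decode/vae_encode nodes to read from it instead of
        # the checkpoint's bundled VAE.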
        if workflow_inputs.get("vae_name"):
            print("--- [Workflow Patch] VAE override provided. Adding VAELoader and rewiring connections. ---")
            vae_loader_id = assembler._get_unique_id()
            vae_loader_node = assembler._get_node_template("VAELoader")
            vae_loader_node['inputs']['vae_name'] = workflow_inputs["vae_name"]
            workflow[vae_loader_id] = vae_loader_node
            vae_decode_id = assembler.node_map.get("vae_decode")
            if vae_decode_id and vae_decode_id in workflow:
                workflow[vae_decode_id]['inputs']['vae'] = [vae_loader_id, 0]
                print(f" - Rewired 'vae_decode' (ID: {vae_decode_id}) to use new VAELoader.")
            vae_encode_id = assembler.node_map.get("vae_encode")
            if vae_encode_id and vae_encode_id in workflow:
                workflow[vae_encode_id]['inputs']['vae'] = [vae_loader_id, 0]
                print(f" - Rewired 'vae_encode' (ID: {vae_encode_id}) to use new VAELoader.")
        else:
            print("--- [Workflow Info] No VAE override. Using VAE from checkpoint. ---")
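
        # Hand off to the GPU phase under a ZeroGPU allocation; staged temp inputs are
        # removed even if generation fails.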
        progress(1.0, desc="All models ready. Requesting GPU for generation...")
        try:
            results = self._execute_gpu_logic(
                self._gpu_logic,
                duration=ui_inputs['zero_gpu_duration'],
                default_duration=60,
                task_name=f"ImageGen ({task_type})",
                ui_inputs=ui_inputs,
                loras_string=loras_string,
                required_models_for_gpu=required_models,
                workflow=workflow,
                assembler=assembler,
                progress=progress
            )
        finally:
            for temp_file in temp_files_to_clean:
                if temp_file and os.path.exists(temp_file):
                    os.remove(temp_file)
                    print(f"✅ Cleaned up temp file: {temp_file}")
        return results