import os
import torch.nn.functional as F
import torch
import sys
import cv2
import numpy as np
from PIL import Image
import json
# New imports for the diffuser pipeline
from src.pipeline_flux_kontext_control import FluxKontextControlPipeline
from src.transformer_flux import FluxTransformer2DModel
import tempfile
from safetensors.torch import load_file, save_file
_original_load_lora_weights = FluxKontextControlPipeline.load_lora_weights


def _patched_load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
    """Automatically convert mixed-format LoRA weights and add the 'transformer.' prefix."""
    weight_name = kwargs.get("weight_name", "pytorch_lora_weights.safetensors")
    if isinstance(pretrained_model_name_or_path_or_dict, str):
        if os.path.isdir(pretrained_model_name_or_path_or_dict):
            lora_file = os.path.join(pretrained_model_name_or_path_or_dict, weight_name)
        else:
            lora_file = pretrained_model_name_or_path_or_dict
        if os.path.exists(lora_file):
            state_dict = load_file(lora_file)
            # Check whether the keys need format conversion or a 'transformer.' prefix
            needs_format_conversion = any('lora_A.weight' in k or 'lora_B.weight' in k for k in state_dict.keys())
            needs_prefix = not any(k.startswith('transformer.') for k in state_dict.keys())
            if needs_format_conversion or needs_prefix:
                print(f"🔄 Processing LoRA: {lora_file}")
                if needs_format_conversion:
                    print(" - Converting PEFT format to diffusers format")
                if needs_prefix:
                    print(" - Adding 'transformer.' prefix to keys")
                converted_state = {}
                converted_count = 0
                for key, value in state_dict.items():
                    new_key = key
                    # Step 1: convert PEFT-style key names to diffusers-style key names
                    if 'lora_A.weight' in new_key:
                        new_key = new_key.replace('lora_A.weight', 'lora.down.weight')
                        converted_count += 1
                    elif 'lora_B.weight' in new_key:
                        new_key = new_key.replace('lora_B.weight', 'lora.up.weight')
                        converted_count += 1
                    # Step 2: add the 'transformer.' prefix if it is not already present
                    if not new_key.startswith('transformer.'):
                        new_key = f'transformer.{new_key}'
                    converted_state[new_key] = value
                if needs_format_conversion:
                    print(f" ✅ Converted {converted_count} PEFT keys")
                print(f" ✅ Total keys: {len(converted_state)}")
                with tempfile.TemporaryDirectory() as temp_dir:
                    temp_file = os.path.join(temp_dir, weight_name)
                    save_file(converted_state, temp_file)
                    return _original_load_lora_weights(self, temp_dir, **kwargs)
            else:
                print(f"✅ LoRA already in correct format: {lora_file}")
    # No conversion needed; fall back to the original method
    return _original_load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs)


# Apply the monkey patch
FluxKontextControlPipeline.load_lora_weights = _patched_load_lora_weights
print("✅ Monkey patch applied to FluxKontextControlPipeline.load_lora_weights")
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
sys.path.append(os.path.abspath(os.path.join(current_dir, '..')))
sys.path.append(os.path.abspath(os.path.join(current_dir, '..', '..', 'comfy_extras')))

from train.src.condition.edge_extraction import InformativeDetector, HEDDetector
from utils_node import BlendInpaint, JoinImageWithAlpha, GrowMask, InvertMask, ColorDetector

TEST_MODE = False
class KontextEditModel():
    def __init__(self, base_model_path="black-forest-labs/FLUX.1-Kontext-dev", device="cuda",
                 aux_lora_dir="models/v2_ckpt", easycontrol_base_dir="models/v2_ckpt",
                 aux_lora_weight_name="puzzle_lora.safetensors",
                 aux_lora_weight=1.0):
        # Keep necessary preprocessors
        self.mask_processor = GrowMask()
        self.scribble_processor = HEDDetector.from_pretrained()
        self.lineart_processor = InformativeDetector.from_pretrained()
        self.color_processor = ColorDetector()
        self.blender = BlendInpaint()
        # Initialize the new pipeline (Kontext version)
        self.device = device
        self.pipe = FluxKontextControlPipeline.from_pretrained(base_model_path, torch_dtype=torch.bfloat16)
        transformer = FluxTransformer2DModel.from_pretrained(
            base_model_path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
            device=self.device
        )
        self.pipe.transformer = transformer
        self.pipe.to(self.device, dtype=torch.bfloat16)
        control_lora_config = {
            "local": {
                "path": os.path.join(easycontrol_base_dir, "local_lora.safetensors"),
                "lora_weights": [1.0],
                "cond_size": 512,
            },
            "removal": {
                "path": os.path.join(easycontrol_base_dir, "removal_lora.safetensors"),
                "lora_weights": [1.0],
                "cond_size": 512,
            },
            "edge": {
                "path": os.path.join(easycontrol_base_dir, "edge_lora.safetensors"),
                "lora_weights": [1.0],
                "cond_size": 512,
            },
            "color": {
                "path": os.path.join(easycontrol_base_dir, "color_lora.safetensors"),
                "lora_weights": [1.0],
                "cond_size": 512,
            },
        }
        self.pipe.load_control_loras(control_lora_config)
        # Aux LoRA for foreground mode
        self.aux_lora_weight_name = aux_lora_weight_name
        self.aux_lora_dir = aux_lora_dir
        self.aux_lora_weight = aux_lora_weight
        self.aux_adapter_name = "aux"
        aux_path = os.path.join(self.aux_lora_dir, self.aux_lora_weight_name)
        if os.path.isfile(aux_path):
            self.pipe.load_lora_weights(aux_path, adapter_name=self.aux_adapter_name)
            print(f"Loaded aux LoRA: {aux_path}")
            # Ensure the aux LoRA is disabled by default; it is enabled only in foreground_edit
            self._disable_aux_lora()
        else:
            print(f"Aux LoRA not found at {aux_path}; foreground mode will run without it.")
        # gamma is now applied inside the pipeline based on control_dict
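    # Checkpoint layout assumed by the default arguments above (summary only, not enforced):
    #   models/v2_ckpt/local_lora.safetensors    - "local" control LoRA
    #   models/v2_ckpt/removal_lora.safetensors  - "removal" control LoRA
    #   models/v2_ckpt/edge_lora.safetensors     - "edge" control LoRA
    #   models/v2_ckpt/color_lora.safetensors    - "color" control LoRA
    #   models/v2_ckpt/puzzle_lora.safetensors   - aux LoRA used only in foreground mode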
    def _tensor_to_pil(self, tensor_image):
        # Converts a ComfyUI-style tensor [1, H, W, 3] to a PIL Image
        return Image.fromarray(np.clip(255. * tensor_image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))

    def _pil_to_tensor(self, pil_image):
        # Converts a PIL image to a ComfyUI-style tensor [1, H, W, 3]
        return torch.from_numpy(np.array(pil_image).astype(np.float32) / 255.0).unsqueeze(0)

    def clear_cache(self):
        # Reset attention state cached by the control attention processors
        for name, attn_processor in self.pipe.transformer.attn_processors.items():
            if hasattr(attn_processor, 'bank_kv'):
                attn_processor.bank_kv.clear()
            if hasattr(attn_processor, 'bank_attn'):
                attn_processor.bank_attn = None

    def _enable_aux_lora(self):
        self.pipe.enable_lora()
        self.pipe.set_adapters([self.aux_adapter_name], adapter_weights=[self.aux_lora_weight])
        print(f"Enabled aux LoRA '{self.aux_adapter_name}' with weight {self.aux_lora_weight}")

    def _disable_aux_lora(self):
        self.pipe.disable_lora()
        print("Disabled aux LoRA")

    def _expand_mask(self, mask_tensor: torch.Tensor, expand: int = 0) -> torch.Tensor:
        if expand <= 0:
            return mask_tensor
        expanded = self.mask_processor.expand_mask(mask_tensor, expand=expand, tapered_corners=True)[0]
        return expanded

    def _tensor_mask_to_pil3(self, mask_tensor: torch.Tensor) -> Image.Image:
        # Converts a mask tensor to a 3-channel PIL image
        mask_01 = torch.clamp(mask_tensor, 0.0, 1.0)
        if mask_01.ndim == 3 and mask_01.shape[-1] == 3:
            mask_01 = mask_01[..., 0]
        if mask_01.ndim == 3 and mask_01.shape[0] == 1:
            mask_01 = mask_01[0]
        pil = self._tensor_to_pil(mask_01.unsqueeze(-1).repeat(1, 1, 3))
        return pil

    def _apply_black_mask(self, image_tensor: torch.Tensor, binary_mask: torch.Tensor) -> Image.Image:
        # image_tensor: [1, H, W, 3] in [0, 1]
        # binary_mask: [H, W] or [1, H, W], 1 = masked (white) area
        if binary_mask.ndim == 3:
            binary_mask = binary_mask[0]
        mask_bool = (binary_mask > 0.5)
        img = image_tensor.clone()
        img[0][mask_bool] = 0.0
        return self._tensor_to_pil(img)
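    # Tensor conventions used throughout this class (ComfyUI-style):
    #   images: float tensors of shape [1, H, W, 3] with values in [0, 1]
    #   masks:  float tensors of shape [1, H, W] with 1 marking the region to edit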
    def edge_edit(self,
                  image, colored_image, positive_prompt,
                  base_mask, add_mask, remove_mask,
                  fine_edge,
                  edge_strength, color_strength,
                  seed, steps, cfg):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        # Prepare mask and original image
        original_image_tensor = image.clone()
        original_mask = base_mask
        original_mask = self._expand_mask(original_mask, expand=25)
        image_pil = self._tensor_to_pil(image)
        # image_pil.save("image_pil.png")
        control_dict = {}
        lineart_output = None
        # Determine control type: color or edge
        if not torch.equal(image, colored_image):
            print("Apply color control")
            colored_image_pil = self._tensor_to_pil(colored_image)
            # Create color block condition
            color_image_np = np.array(colored_image_pil)
            downsampled = cv2.resize(color_image_np, (32, 32), interpolation=cv2.INTER_AREA)
            upsampled = cv2.resize(downsampled, (256, 256), interpolation=cv2.INTER_NEAREST)
            color_block = Image.fromarray(upsampled)
            control_dict = {
                "type": "color",
                "spatial_images": [color_block],
                "gammas": [color_strength]
            }
        else:
            print("Apply edge control")
            if fine_edge == "enable":
                lineart_image = self.lineart_processor(np.array(self._tensor_to_pil(image.cpu().squeeze())), detect_resolution=1024, style="contour", output_type="pil")
                lineart_output = self._pil_to_tensor(lineart_image)
            else:
                scribble_image = self.scribble_processor(np.array(self._tensor_to_pil(image.cpu().squeeze())), safe=True, resolution=512, output_type="pil")
                lineart_output = self._pil_to_tensor(scribble_image)
            if lineart_output is None:
                raise ValueError("Preprocessor failed to generate lineart.")
            # Apply user sketches to the lineart
            add_mask_resized = F.interpolate(add_mask.unsqueeze(0).float(), size=(lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0)
            remove_mask_resized = F.interpolate(remove_mask.unsqueeze(0).float(), size=(lineart_output.shape[1], lineart_output.shape[2]), mode='nearest').squeeze(0)
            bool_add_mask_resized = (add_mask_resized > 0.5)
            bool_remove_mask_resized = (remove_mask_resized > 0.5)
            lineart_output[bool_remove_mask_resized] = 0.0
            lineart_output[bool_add_mask_resized] = 1.0
            control_dict = {
                "type": "edge",
                "spatial_images": [self._tensor_to_pil(lineart_output)],
                "gammas": [edge_strength]
            }
        # Prepare debug/output images
        colored_image_np = np.array(self._tensor_to_pil(colored_image))
        debug_image = lineart_output if lineart_output is not None else self.color_processor(colored_image_np, detect_resolution=1024, output_type="pil")
        # Run inference
        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=128,
            control_dict=control_dict,
        ).images[0]
        self.clear_cache()
        # result_pil.save("result_pil.png")
        result_tensor = self._pil_to_tensor(result_pil)
        # final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
        final_image = result_tensor
        return (final_image, debug_image, original_mask)
    def object_removal(self,
                       image, positive_prompt,
                       remove_mask,
                       local_strength,
                       seed, steps, cfg):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        original_image_tensor = image.clone()
        original_mask = self._expand_mask(remove_mask, expand=10)
        image_pil = self._tensor_to_pil(image)
        # image_pil.save("image_pil.png")
        # Prepare spatial image: original masked to black in the remove area
        spatial_pil = self._apply_black_mask(image, original_mask)
        # spatial_pil.save("spatial_pil.png")
        # Note: the mask is not passed to the pipeline; we use it only for blending
        control_dict = {
            "type": "removal",
            "spatial_images": [spatial_pil],
            "gammas": [local_strength]
        }
        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            control_dict=control_dict,
        ).images[0]
        self.clear_cache()
        result_tensor = self._pil_to_tensor(result_pil)
        final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
        # final_image = result_tensor
        return (final_image, self._pil_to_tensor(spatial_pil), original_mask)
    def local_edit(self,
                   image, positive_prompt, fill_mask, local_strength,
                   seed, steps, cfg):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        original_image_tensor = image.clone()
        original_mask = self._expand_mask(fill_mask, expand=10)
        image_pil = self._tensor_to_pil(image)
        # image_pil.save("image_pil.png")
        spatial_pil = self._apply_black_mask(image, original_mask)
        # spatial_pil.save("spatial_pil.png")
        control_dict = {
            "type": "local",
            "spatial_images": [spatial_pil],
            "gammas": [local_strength]
        }
        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=128,
            control_dict=control_dict,
        ).images[0]
        self.clear_cache()
        result_tensor = self._pil_to_tensor(result_pil)
        final_image = self.blender.blend_inpaint(result_tensor, original_image_tensor, original_mask, kernel=10, sigma=10)[0]
        # final_image = result_tensor
        return (final_image, self._pil_to_tensor(spatial_pil), original_mask)
    def foreground_edit(self,
                        merged_image, positive_prompt,
                        add_prop_mask, fill_mask, fix_perspective, grow_size,
                        seed, steps, cfg):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        edit_mask = torch.clamp(self._expand_mask(add_prop_mask, expand=grow_size) + fill_mask, 0.0, 1.0)
        final_mask = self._expand_mask(edit_mask, expand=25)
        if fix_perspective == "enable":
            positive_prompt = positive_prompt + " Fix the perspective if necessary."
        # Prepare the edited input image: pixels inside edit_mask but outside add_prop_mask are set to white
        img = merged_image.clone()
        base_mask = (edit_mask > 0.5)
        add_only = (add_prop_mask <= 0.5) & base_mask  # [1, H, W] bool
        add_only_3 = add_only.squeeze(0).unsqueeze(-1).expand(-1, -1, img.shape[-1])  # [H, W, 3]
        img[0] = torch.where(add_only_3, torch.ones_like(img[0]), img[0])
        image_pil = self._tensor_to_pil(img)
        # image_pil.save("image_pil.png")
        # Enable the aux LoRA only for foreground editing
        self._enable_aux_lora()
        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=128,
            control_dict=None,
        ).images[0]
        # Disable the aux LoRA afterwards
        self._disable_aux_lora()
        self.clear_cache()
        final_image = self._pil_to_tensor(result_pil)
        # final_image = self.blender.blend_inpaint(final_image, img, final_mask, kernel=10, sigma=10)[0]
        return (final_image, self._pil_to_tensor(image_pil), edit_mask)
    def kontext_edit(self,
                     image, positive_prompt,
                     seed, steps, cfg):
        generator = torch.Generator(device=self.device).manual_seed(seed)
        image_pil = self._tensor_to_pil(image)
        result_pil = self.pipe(
            prompt=positive_prompt,
            image=image_pil,
            height=image_pil.height,
            width=image_pil.width,
            guidance_scale=cfg,
            num_inference_steps=steps,
            generator=generator,
            max_sequence_length=128,
            control_dict=None,
        ).images[0]
        final_image = self._pil_to_tensor(result_pil)
        mask = torch.zeros((1, final_image.shape[1], final_image.shape[2]), dtype=torch.float32, device=final_image.device)
        return (final_image, image, mask)
    def process(self, image, colored_image,
                merged_image, positive_prompt,
                total_mask, add_mask, remove_mask, add_prop_mask, fill_mask,
                fine_edge, fix_perspective, edge_strength, color_strength, local_strength, grow_size,
                seed, steps, cfg, flag="precise_edit"):
        if flag == "foreground":
            return self.foreground_edit(merged_image, positive_prompt, add_prop_mask, fill_mask, fix_perspective, grow_size, seed, steps, cfg)
        elif flag == "local":
            return self.local_edit(image, positive_prompt, fill_mask, local_strength, seed, steps, cfg)
        elif flag == "removal":
            return self.object_removal(image, positive_prompt, remove_mask, local_strength, seed, steps, cfg)
        elif flag == "precise_edit":
            return self.edge_edit(
                image, colored_image, positive_prompt,
                total_mask, add_mask, remove_mask,
                fine_edge,
                edge_strength, color_strength,
                seed, steps, cfg
            )
        elif flag == "kontext":
            return self.kontext_edit(image, positive_prompt, seed, steps, cfg)
        else:
            raise ValueError("Invalid Editing Type: {}".format(flag))
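

# --- Minimal usage sketch (illustrative only) --------------------------------
# Assumes the default checkpoints referenced above are present and a CUDA device
# is available. The blank tensors and parameter values below are hypothetical
# placeholders for real ComfyUI-style [1, H, W, 3] images and [1, H, W] masks.
if __name__ == "__main__":
    model = KontextEditModel()
    h, w = 512, 512
    image = torch.zeros((1, h, w, 3), dtype=torch.float32)  # placeholder input image
    mask = torch.zeros((1, h, w), dtype=torch.float32)      # placeholder (empty) mask
    result, debug, out_mask = model.process(
        image=image, colored_image=image, merged_image=image,
        positive_prompt="a photo of a cat",
        total_mask=mask, add_mask=mask, remove_mask=mask,
        add_prop_mask=mask, fill_mask=mask,
        fine_edge="disable", fix_perspective="disable",
        edge_strength=1.0, color_strength=1.0, local_strength=1.0, grow_size=25,
        seed=42, steps=28, cfg=2.5,
        flag="kontext",
    )
    print("Output image tensor shape:", tuple(result.shape))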