import os import cv2 import gradio as gr import torch import numpy as np from PIL import Image, ImageDraw import spaces from huggingface_hub import hf_hub_download # ========================================== # 1. Global Settings & Variables # ========================================== MODEL_ID = "black-forest-labs/FLUX.1-dev" DEBLUR_LORA_PATH = "." DEBLUR_WEIGHT_NAME = "deblurNet.safetensors" BOKEH_LORA_DIR = "." BOKEH_WEIGHT_NAME = "bokehNet.safetensors" # Global variables pipe_flux = None depth_model = None depth_transform = None # ========================================== # 2. Depth Pro Loader # ========================================== class DepthProLoader: def load(self, device): print("🔄 Loading Depth Pro model...") try: global Condition, generate, seed_everything, FluxPipeline, depth_pro from Genfocus.pipeline.flux import Condition, generate, seed_everything, FluxPipeline import depth_pro from depth_pro.depth_pro import DEFAULT_MONODEPTH_CONFIG_DICT import copy WEIGHTS_REPO_ID = "nycu-cplab/Genfocus-Model" DEPTH_FILENAME = "checkpoints/depth_pro.pt" checkpoint_path = hf_hub_download( repo_id=WEIGHTS_REPO_ID, filename=DEPTH_FILENAME, repo_type="model" ) cfg = copy.deepcopy(DEFAULT_MONODEPTH_CONFIG_DICT) cfg.checkpoint_uri = checkpoint_path try: create_fn = depth_pro.create_model_and_transforms except AttributeError: from depth_pro.depth_pro import create_model_and_transforms create_fn = create_model_and_transforms model, transform = create_fn( config=cfg, device=device, precision=torch.float32 ) model.eval() print(f"✅ Depth Pro loaded on {device}.") return model, transform except Exception as e: print(f"❌ Failed to load Depth Pro: {e}") raise e # ========================================== # 3. Helper Functions # ========================================== def resize_and_crop_to_16(img: Image.Image) -> Image.Image: """ 1. Resize the longer side to 512, maintaining aspect ratio. 2. Crop the dimensions to be multiples of 16. """ w, h = img.size target = 512 # 1. Resize longer side to 512 if w >= h: scale = target / w else: scale = target / h new_w = int(w * scale) new_h = int(h * scale) img = img.resize((new_w, new_h), Image.LANCZOS) # 2. Crop to multiples of 16 final_w = (new_w // 16) * 16 final_h = (new_h // 16) * 16 # Center crop calculation left = (new_w - final_w) // 2 top = (new_h - final_h) // 2 right = left + final_w bottom = top + final_h img = img.crop((left, top, right, bottom)) return img def switch_lora_on_gpu(pipe, target_mode): print(f"🔄 Switching LoRA to [{target_mode}]...") pipe.unload_lora_weights() if target_mode == "deblur": pipe.load_lora_weights(DEBLUR_LORA_PATH, weight_name=DEBLUR_WEIGHT_NAME, adapter_name="deblurring") pipe.set_adapters(["deblurring"]) elif target_mode == "bokeh": pipe.load_lora_weights(BOKEH_LORA_DIR, weight_name=BOKEH_WEIGHT_NAME, adapter_name="bokeh") pipe.set_adapters(["bokeh"]) def preprocess_input_image(raw_img): """ Always enforces resizing to 512 (long edge) and cropping to 16x. """ if raw_img is None: return None, None print(f"🔄 Preprocessing Input... Enforcing Resize.") # Always resize and crop final_input = resize_and_crop_to_16(raw_img) return final_input, final_input def draw_red_dot_on_preview(clean_img, evt: gr.SelectData): if clean_img is None: return None, None img_copy = clean_img.copy() draw = ImageDraw.Draw(img_copy) x, y = evt.index r = 8 draw.ellipse((x-r, y-r, x+r, y+r), outline="red", width=2) draw.line((x-r, y, x+r, y), fill="red", width=2) draw.line((x, y-r, x, y+r), fill="red", width=2) return img_copy, evt.index # ========================================== # 4. Main Pipeline # ========================================== @spaces.GPU(duration=120) def run_genfocus_pipeline(clean_input, click_coords, K_value): global pipe_flux, depth_model, depth_transform device = "cuda" if clean_input is None: raise gr.Error("Please complete Step 1 (Upload Image) first.") W_dyn, H_dyn = clean_input.size print(f"📏 Processing Image Size: {W_dyn}x{H_dyn}") if pipe_flux is None: print("🚀 Loading FLUX to GPU (First Run)...") from Genfocus.pipeline.flux import FluxPipeline pipe_flux = FluxPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, token=os.getenv("HF_TOKEN") ).to(device) else: try: _ = pipe_flux.device.type pipe_flux.to(device) except Exception: print("⚠️ GPU Context changed, reloading FLUX...") from Genfocus.pipeline.flux import FluxPipeline pipe_flux = FluxPipeline.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, token=os.getenv("HF_TOKEN") ).to(device) # --- Load Depth Pro --- depth_loader = DepthProLoader() if depth_model is None: depth_model, depth_transform = depth_loader.load(device=device) else: try: depth_model = depth_model.to(device) except Exception: print("⚠️ GPU Context changed, reloading Depth Pro...") depth_model, depth_transform = depth_loader.load(device=device) from Genfocus.pipeline.flux import Condition, generate, seed_everything print("⚡ Running Inference...") # STAGE 1: DEBLUR switch_lora_on_gpu(pipe_flux, "deblur") condition_0_img = Image.new("RGB", (W_dyn, H_dyn), (0, 0, 0)) cond0 = Condition(condition_0_img, "deblurring", [0, 32], 1.0) cond1 = Condition(clean_input, "deblurring", [0, 0], 1.0) seed_everything(42) deblurred_img = generate( pipe_flux, height=H_dyn, width=W_dyn, prompt="a sharp photo with everything in focus", conditions=[cond0, cond1] ).images[0] if K_value == 0: return deblurred_img # STAGE 2: BOKEH if click_coords is None: click_coords = [W_dyn // 2, H_dyn // 2] # Depth Estimation img_t = depth_transform(deblurred_img).to(device) with torch.no_grad(): pred = depth_model.infer(img_t, f_px=None) depth_map = pred["depth"].cpu().numpy().squeeze() safe_depth = np.where(depth_map > 0.0, depth_map, np.finfo(np.float32).max) disp_orig = 1.0 / safe_depth # Resize disp to match current image dimensions disp = cv2.resize(disp_orig, (W_dyn, H_dyn), interpolation=cv2.INTER_LINEAR) # Defocus Map tx, ty = click_coords tx = min(max(int(tx), 0), W_dyn - 1) ty = min(max(int(ty), 0), H_dyn - 1) disp_focus = float(disp[ty, tx]) dmf = disp - np.float32(disp_focus) defocus_abs = np.abs(K_value * dmf) MAX_COC = 100.0 defocus_t = torch.from_numpy(defocus_abs).unsqueeze(0).float() cond_map = (defocus_t / MAX_COC).clamp(0, 1).repeat(3,1,1).unsqueeze(0) # Generate New Latents seed_everything(42) gen = torch.Generator(device=pipe_flux.device).manual_seed(1234) current_latents, _ = pipe_flux.prepare_latents( batch_size=1, num_channels_latents=16, height=H_dyn, width=W_dyn, dtype=pipe_flux.dtype, device=pipe_flux.device, generator=gen, latents=None ) # Generate Bokeh switch_lora_on_gpu(pipe_flux, "bokeh") cond_img = Condition(deblurred_img, "bokeh") cond_dmf = Condition(cond_map, "bokeh", [0,0], 1.0, No_preprocess=True) seed_everything(42) gen = torch.Generator(device=pipe_flux.device).manual_seed(1234) with torch.no_grad(): res = generate( pipe_flux, height=H_dyn, width=W_dyn, prompt="an excellent photo with a large aperture", conditions=[cond_img, cond_dmf], guidance_scale=1.0, kv_cache=False, generator=gen, latents=current_latents, ) generated_bokeh = res.images[0] return generated_bokeh # ========================================== # 5. UI Setup # ========================================== css = """ #col-container { margin: 0 auto; max-width: 1400px; } #output_image { min-height: 400px; } """ base_path = os.getcwd() example_dir = os.path.join(base_path, "example") valid_examples = [] if os.path.exists(example_dir): files = os.listdir(example_dir) for f in files: if f.lower().endswith(('.jpg', '.jpeg', '.png')): valid_examples.append([os.path.join(example_dir, f)]) with gr.Blocks(css=css) as demo: clean_processed_state = gr.State(value=None) click_coords_state = gr.State(value=None) with gr.Column(elem_id="col-container"): gr.Markdown("# 📷 Genfocus Pipeline: Interactive Refocusing (HF Demo)") # --- Description & Guide --- gr.Markdown(""" ### 📖 User Guide **Generative Refocusing** supports two main applications: * **All-In-Focus (AIF) Estimation:** Set **K = 0**. The model will restore the AIF image from the blurry input. * **Refocusing:** 1. **Click** on the subject you want to bring into focus in the **Step 2** image preview. 2. Increase **K** (Blur Strength) to generate realistic bokeh effects based on the scene's depth. > ⚠️ **Preprocessing Note:** Due to resource constraints in this demo, input images are **automatically resized** (longer edge = 512px). > If you wish to perform inference at the **original resolution**, please refer to our **[GitHub Code](https://github.com/rayray9999/Genfocus)** to run it locally. """) with gr.Row(): # --- Top Row: Inputs & Controls --- # [Step 1: Upload] with gr.Column(scale=1): gr.Markdown("### Step 1: Upload Image") gr.Markdown("Click an example or upload your own image.") input_raw = gr.Image(label="Raw Input Image", type="pil") if valid_examples: gr.Examples(examples=valid_examples, inputs=input_raw, label="Examples (Click to Load)") # [Step 2: Focus & Run] with gr.Column(scale=1): gr.Markdown("### Step 2: Set Focus & K") gr.Markdown("The image below shows the actual input for the model. **Click on the image** to set the focus point.") focus_preview_img = gr.Image(label="Model Input (Processed) - Click Here", type="pil", interactive=False) with gr.Row(): click_status = gr.Textbox(label="Selected Coordinates", value="Center (Default)", interactive=False, scale=1) k_slider = gr.Slider(minimum=0, maximum=50, value=20, step=1, label="Blur Strength (K)", scale=2) run_btn = gr.Button("✨ Run Genfocus", variant="primary", scale=1) # --- Bottom Row: Output --- with gr.Row(): with gr.Column(): gr.Markdown("### Result") output_img = gr.Image(label="Final Output", type="pil", interactive=False, elem_id="output_image") # ==================== Event Handling ==================== # 1. Update Preview (Removed resize_chk) update_trigger = [input_raw.change, input_raw.upload] for trigger in update_trigger: trigger( fn=preprocess_input_image, inputs=[input_raw], outputs=[focus_preview_img, clean_processed_state] ) # 2. Draw Red Dot on Click focus_preview_img.select( fn=draw_red_dot_on_preview, inputs=[clean_processed_state], outputs=[focus_preview_img, click_coords_state] ).then( fn=lambda x: f"x={x[0]}, y={x[1]}", inputs=[click_coords_state], outputs=[click_status] ) # 3. Run Pipeline run_btn.click( fn=run_genfocus_pipeline, inputs=[clean_processed_state, click_coords_state, k_slider], outputs=[output_img] ) if __name__ == "__main__": demo.launch()