#!/usr/bin/env python3
"""
Gradio Demo App for Patchioner Model - Trace-based Image Captioning

This demo allows users to:
1. Upload or select an image
2. Draw traces on the image using Gradio's ImageEditor
3. Generate captions for the traced regions using a pre-trained Patchioner model

Author: Generated for decap-dino project
"""

import os
import shutil
import time
import glob

try:
    import spaces
except ModuleNotFoundError:
    print("Warning: 'spaces' module not found, using mock decorator for local testing.")

    # Local testing: mock the `spaces.GPU` decorator so the app runs without the HF Spaces runtime.
    class spaces:
        def GPU(func):
            return func

import gradio as gr

USE_BBOX_ANNOTATOR = True
if not USE_BBOX_ANNOTATOR:
    from gradio_image_annotation import image_annotator as foo_image_annotator
else:
    from gradio_bbox_annotator.bbox_annotator import BBoxAnnotator

import torch
import yaml
import traceback
from pathlib import Path
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from typing import Any, List, Dict, Tuple

from patchioner import Patchioner

# Colors for the brush - orange, green, blue, magenta, yellow with ~60% opacity
colors = ["#ffa2009d", "#00ff0099", "#0000ff96", "#ff00ff97", "#ffa60099"]
color_index = 0

# Global variables to store the loaded model
loaded_model = None
model_config_path = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Default model configuration
DEFAULT_MODEL_CONFIG = "https://huggingface.co/Ruggero1912/Patch-ioner_talk2dino_decap_COCO_Captions"

# Example images directory
current_dir = os.path.dirname(__file__)
EXAMPLE_IMAGES_DIR = Path(os.path.join(current_dir, 'example-images')).resolve()
CONFIGS_DIR = Path(os.path.join(current_dir, 'configs')).resolve()
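# On-disk layout assumed by the helper functions below (an assumption based on the
# glob patterns they use, not something enforced elsewhere in this file):
#   example-images/   *.jpg / *.jpeg / *.png shown in the example gallery
#   configs/          *.yaml model configuration files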
def initialize_default_model() -> str:
    """Initialize the default model at startup."""
    global loaded_model, model_config_path
    try:
        # Look for the default config file
        default_config_path = CONFIGS_DIR / DEFAULT_MODEL_CONFIG
        if not default_config_path.exists():
            print("Default config file not found locally.")
            config = DEFAULT_MODEL_CONFIG  # Assume it's a URL or model identifier
            print(f"Attempting to load model as identifier: {config}")
        else:
            config = default_config_path

        print(f"Loading default model: {DEFAULT_MODEL_CONFIG}")

        # Load the model using the from_config class method
        model = Patchioner.from_config(config, device=device)
        model.eval()
        model.to(device)

        # Store the model globally
        loaded_model = model
        model_config_path = str(default_config_path)

        return f"✅ Default model loaded: {DEFAULT_MODEL_CONFIG} on {device}"
    except Exception as e:
        error_msg = f"❌ Error loading default model: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        return error_msg


def get_example_images(limit=None) -> List[str]:
    """Get the list of example images for the demo as file paths."""
    example_images = []
    if EXAMPLE_IMAGES_DIR.exists():
        for ext in ['*.jpg', '*.jpeg', '*.png']:
            example_images.extend(str(p) for p in EXAMPLE_IMAGES_DIR.glob(ext))
    if limit is not None:
        example_images = example_images[:limit]
    return example_images


def get_example_configs() -> List[str]:
    """Get the list of example config files."""
    example_configs = []
    if CONFIGS_DIR.exists():
        example_configs = [str(p) for p in CONFIGS_DIR.glob("*.yaml")]
    else:
        print(f"Warning: Configs directory {CONFIGS_DIR} does not exist.")
    return sorted(example_configs)
def cleanup_gradio_cache(max_folders: int = 100, gradio_temp_dir: str = "/tmp/gradio"):
    """
    Clean up old Gradio temporary folders to prevent disk space issues.

    Args:
        max_folders: Maximum number of cache folders to keep (default: 100)
        gradio_temp_dir: Path to the Gradio temporary directory (default: /tmp/gradio)
    """
    try:
        if not os.path.exists(gradio_temp_dir):
            return

        # Get all subdirectories in the gradio temp folder
        cache_dirs = []
        for item in os.listdir(gradio_temp_dir):
            item_path = os.path.join(gradio_temp_dir, item)
            if os.path.isdir(item_path):
                cache_dirs.append(item_path)

        # If we don't have too many folders, no cleanup is needed
        if len(cache_dirs) <= max_folders:
            return

        # Sort by modification time (oldest first)
        cache_dirs.sort(key=os.path.getmtime)

        # Calculate how many folders to delete
        folders_to_delete = len(cache_dirs) - max_folders
        folders_to_remove = cache_dirs[:folders_to_delete]

        # Delete the oldest folders
        deleted_count = 0
        for folder_path in folders_to_remove:
            try:
                shutil.rmtree(folder_path)
                deleted_count += 1
            except Exception as e:
                print(f"Warning: Could not delete cache folder {folder_path}: {e}")

        if deleted_count > 0:
            print(f"🧹 Cleaned up {deleted_count} old Gradio cache folders to save disk space")
    except Exception as e:
        print(f"Warning: Error during Gradio cache cleanup: {e}")
def load_model_from_config(config_file_path: str) -> str:
    """
    Load the Patchioner model from a config file.

    Args:
        config_file_path: Path to the YAML configuration file

    Returns:
        Status message about model loading
    """
    global loaded_model, model_config_path
    try:
        if not config_file_path or not os.path.exists(config_file_path):
            return "❌ Error: Config file path is empty or file does not exist."

        print(f"Loading model from config: {config_file_path}")

        # Load and parse the config
        with open(config_file_path, 'r') as f:
            config = yaml.safe_load(f)

        # Load the model using the from_config class method
        model = Patchioner.from_config(config, device=device)
        model.eval()
        model.to(device)

        # Store the model globally
        loaded_model = model
        model_config_path = config_file_path

        return f"✅ Model loaded successfully from {os.path.basename(config_file_path)} on {device}"
    except Exception as e:
        error_msg = f"❌ Error loading model: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        return error_msg
def process_image_trace_to_coordinates(image_editor_data) -> List[List[Dict[str, float]]]:
    """
    Convert Gradio ImageEditor trace data to the coordinate format expected by the model.

    The expected format is: [[{"x": float, "y": float, "t": float}, ...], ...]
    where coordinates are normalized to [0, 1] and t is a timestamp.

    Args:
        image_editor_data: Data from the Gradio ImageEditor component

    Returns:
        List of traces in the expected format
    """
    try:
        print(f"[DEBUG] process_image_trace_to_coordinates called")
        print(f"[DEBUG] image_editor_data type: {type(image_editor_data)}")

        if image_editor_data is None:
            print("[DEBUG] image_editor_data is None")
            return []

        if isinstance(image_editor_data, dict):
            print(f"[DEBUG] Available keys in image_editor_data: {list(image_editor_data.keys())}")

        # Check for different possible structures
        layers = None
        if isinstance(image_editor_data, dict):
            if 'layers' in image_editor_data:
                layers = image_editor_data['layers']
            elif 'composite' in image_editor_data:
                # Sometimes Gradio stores drawing data differently
                composite = image_editor_data['composite']
                if isinstance(composite, dict) and 'layers' in composite:
                    layers = composite['layers']

        if not layers:
            print("[DEBUG] No layers found in image_editor_data")
            return []

        traces = []
        print(f"[DEBUG] Processing {len(layers)} layers")

        # Process each drawing layer - they are PIL Images, not coordinate data
        for i, layer in enumerate(layers):
            print(f"[DEBUG] Processing layer {i}: {layer}")

            # Skip if the layer is not a PIL Image or is empty
            if not isinstance(layer, Image.Image):
                print(f"[DEBUG] Layer {i} is not a PIL Image")
                # Try to parse from a numpy array if possible
                if isinstance(layer, np.ndarray):
                    layer_array = layer
                    layer = Image.fromarray(layer)
                    print(f"[DEBUG] Layer {i} converted from numpy array to PIL Image")
                else:
                    continue
            else:
                # Convert the layer to a numpy array to find non-transparent pixels
                layer_array = np.array(layer)

            # Find non-transparent pixels (alpha > 0)
            if layer_array.shape[2] == 4:  # RGBA
                non_transparent = layer_array[:, :, 3] > 0
            else:  # RGB - assume any non-black pixel is drawn
                non_transparent = np.any(layer_array > 0, axis=2)

            # Get the coordinates of the drawn pixels
            y_coords, x_coords = np.where(non_transparent)

            if len(x_coords) == 0:
                print(f"[DEBUG] Layer {i} has no drawn pixels")
                continue

            print(f"[DEBUG] Layer {i} has {len(x_coords)} drawn pixels")

            # Convert pixel coordinates to the trace format
            trace_points = []
            img_height, img_width = layer_array.shape[:2]

            # Sample some points from the drawn pixels (to avoid too many points)
            num_points = min(len(x_coords), 100)  # Limit to 100 points max

            if num_points > 0:
                # Sample evenly spaced indices
                indices = np.linspace(0, len(x_coords) - 1, num_points, dtype=int)
                sampled_x = x_coords[indices]
                sampled_y = y_coords[indices]

                # Convert to normalized coordinates and create trace points
                for idx, (x, y) in enumerate(zip(sampled_x, sampled_y)):
                    # Normalize coordinates to [0, 1]
                    x_norm = float(x) / img_width if img_width > 0 else 0
                    y_norm = float(y) / img_height if img_height > 0 else 0

                    # Clamp to the [0, 1] range
                    x_norm = max(0, min(1, x_norm))
                    y_norm = max(0, min(1, y_norm))

                    # Add a timestamp (arbitrary progression)
                    t = idx * 0.1

                    trace_points.append({
                        "x": x_norm,
                        "y": y_norm,
                        "t": t
                    })

            if trace_points:
                traces.append(trace_points)

        return traces
    except Exception as e:
        print(f"Error processing image trace: {e}")
        print(traceback.format_exc())
        return []
def process_bounding_box_coordinates(annotator_data) -> List[List[float]]:
    """
    Convert Gradio image_annotator data to the bounding box format expected by the model.

    Args:
        annotator_data: Data from the Gradio image_annotator component

    Returns:
        List of bounding boxes in [x, y, width, height] format
    """
    try:
        print(f"[DEBUG] process_bounding_box_coordinates called")
        print(f"[DEBUG] annotator_data type: {type(annotator_data)}")
        #print(f"[DEBUG] annotator_data content: {annotator_data}")

        if annotator_data is None:
            print("[DEBUG] annotator_data is None")
            return []

        boxes = []

        # Handle the dictionary format from image_annotator
        if isinstance(annotator_data, dict):
            print(f"[DEBUG] Available keys in annotator_data: {list(annotator_data.keys())}")

            # Extract boxes from the 'boxes' key
            if 'boxes' in annotator_data and annotator_data['boxes']:
                for box in annotator_data['boxes']:
                    if isinstance(box, dict):
                        # Based on image_annotator.py, boxes have the format:
                        # {"xmin": x, "ymin": y, "xmax": x2, "ymax": y2, "label": ..., "color": ...}
                        xmin = box.get('xmin', 0)
                        ymin = box.get('ymin', 0)
                        xmax = box.get('xmax', 0)
                        ymax = box.get('ymax', 0)
                        width = xmax - xmin
                        height = ymax - ymin
                        # Convert to [x, y, width, height] format
                        boxes.append([xmin, ymin, width, height])
            else:
                print("[DEBUG] No 'boxes' key found or boxes list is empty")
        # Handle the tuple format from BBoxAnnotator
        elif isinstance(annotator_data, tuple) and len(annotator_data) == 2:
            print(f"[DEBUG] Tuple format detected with length {len(annotator_data)}")
            box_list = annotator_data[1]
            if isinstance(box_list, list):
                for box in box_list:
                    if isinstance(box, (list, tuple)) and len(box) >= 4:
                        # Assuming the box format is [left, top, right, bottom, label (optional)]
                        left = box[0]
                        top = box[1]
                        right = box[2]
                        bottom = box[3]
                        width = right - left
                        height = bottom - top
                        boxes.append([left, top, width, height])
            else:
                print("[DEBUG] Second element of tuple is not a list")

        print(f"[DEBUG] Found {len(boxes)} bounding boxes: {boxes}")
        return boxes
    except Exception as e:
        print(f"Error processing bounding box: {e}")
        print(traceback.format_exc())
        return []
def draw_traces_on_image(image: Image.Image, traces: List[List[Dict[str, float]]], captions: List[str], layers: List[Image.Image]) -> Image.Image:
    """
    Draw traces on the image with colored lines and caption text.

    Args:
        image: PIL Image to draw on
        traces: List of traces (each trace is a list of {x, y, t} dicts with normalized coords)
        captions: List of captions corresponding to each trace
        layers: Drawing layers from the ImageEditor, used directly as overlays when available

    Returns:
        PIL Image with traces and captions drawn on it
    """
    # Create a copy to draw on
    img_with_traces = image.copy().convert('RGBA')
    img_width, img_height = img_with_traces.size
    fontsize = int(min(img_width, img_height) / 30)  # Example: 1/30th of the smaller dimension
    print(f"[DEBUG] Computed fontsize: {fontsize}")

    # Create a transparent overlay for drawing traces with opacity
    overlay = Image.new('RGBA', img_with_traces.size, (255, 255, 255, 0))
    draw_overlay = ImageDraw.Draw(overlay)

    # Create a separate layer for text (no transparency)
    draw_final = ImageDraw.Draw(img_with_traces)

    # Try to load a font with a larger size
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", fontsize)
    except Exception:
        try:
            font = ImageFont.truetype("arial.ttf", fontsize)
        except Exception:
            font = ImageFont.load_default(fontsize)

    img_width, img_height = image.size

    for i, (trace, caption) in enumerate(zip(traces, captions)):
        # Get the color for this trace with an alpha channel
        color_hex = colors[i % len(colors)]
        # Convert the hex color to RGBA (with ~60% opacity for lines)
        color_rgba = tuple(int(color_hex[j:j+2], 16) for j in (1, 3, 5)) + (150,)  # 150/255 ≈ 60% opacity
        # Solid color for text
        color_rgb = tuple(int(color_hex[j:j+2], 16) for j in (1, 3, 5))

        if len(layers) > i:
            current_layer = layers[i]
            # current_layer is a PIL Image or numpy array; use it directly as the overlay
            if isinstance(current_layer, Image.Image):
                layer_rgba = current_layer.convert('RGBA').resize((img_width, img_height))
            elif isinstance(current_layer, np.ndarray):
                layer_rgba = Image.fromarray(current_layer).convert('RGBA').resize((img_width, img_height))
            #overlay = Image.alpha_composite(overlay, layer_image)
            #continue  # Skip drawing trace points if the layer is used

            # Recolor the layer: use the trace color where the layer is not transparent
            datas = layer_rgba.getdata()
            newData = []
            for item in datas:
                if item[3] > 0:  # If not transparent
                    newData.append(color_rgba)  # Use the trace color with alpha
                else:
                    newData.append((255, 255, 255, 0))  # Transparent
            layer_rgba.putdata(newData)
            overlay = Image.alpha_composite(overlay, layer_rgba)
            continue  # Skip drawing trace points if the layer is used
        else:
            # Convert normalized coordinates to pixel coordinates
            points = []
            for point in trace:
                x_pixel = int(point['x'] * img_width)
                y_pixel = int(point['y'] * img_height)
                points.append((x_pixel, y_pixel))

            # Draw the trace as connected lines with transparency
            #if len(points) > 1:
            #    draw_overlay.line(points, fill=color_rgba, width=8)

            # Draw circles at each point for visibility with transparency
            for point in points[::2]:  # Draw every 2nd point to avoid clutter
                draw_overlay.ellipse([point[0]-10, point[1]-10, point[0]+10, point[1]+10], fill=color_rgba)

    # Composite the transparent overlay onto the base image
    img_with_traces = Image.alpha_composite(img_with_traces, overlay)

    # Now draw the text on top (without transparency)
    draw_final = ImageDraw.Draw(img_with_traces)

    for i, (trace, caption) in enumerate(zip(traces, captions)):
        color_hex = colors[i % len(colors)]
        color_rgb = tuple(int(color_hex[j:j+2], 16) for j in (1, 3, 5))

        # Get the first point for text placement
        points = []
        for point in trace:
            x_pixel = int(point['x'] * img_width)
            y_pixel = int(point['y'] * img_height)
            points.append((x_pixel, y_pixel))

        # Draw the caption text near the first point of the trace
        if points:
            text_x, text_y = points[0]
            # Draw a text background for readability
            text_bbox = draw_final.textbbox((text_x, text_y), f"T{i+1}: {caption}", font=font)
            draw_final.rectangle(text_bbox, fill=(255, 255, 255, 230))
            draw_final.text((text_x, text_y), f"T{i+1}: {caption}", fill=color_rgb + (255,), font=font)

    # Convert back to RGB
    return img_with_traces.convert('RGB')
def draw_bboxes_on_image(image: Image.Image, bboxes: List[List[float]], captions: List[str]) -> Image.Image:
    """
    Draw bounding boxes on the image with colored rectangles and caption text.

    Args:
        image: PIL Image to draw on
        bboxes: List of bounding boxes in [x, y, width, height] format
        captions: List of captions corresponding to each bbox

    Returns:
        PIL Image with bboxes and captions drawn on it
    """
    # Create a copy to draw on
    img_with_bboxes = image.copy()
    draw = ImageDraw.Draw(img_with_bboxes)

    # Compute the fontsize depending on the image size
    img_width, img_height = image.size
    fontsize = int(min(img_width, img_height) / 30)  # Example: 1/30th of the smaller dimension
    print(f"[DEBUG] Computed fontsize: {fontsize}")

    # Try to load a font with a larger size
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", fontsize)
    except Exception:
        try:
            font = ImageFont.truetype("arial.ttf", fontsize)
        except Exception:
            font = ImageFont.load_default(fontsize)

    for i, (bbox, caption) in enumerate(zip(bboxes, captions)):
        # Get the color for this bbox (remove alpha for PIL)
        color_hex = colors[i % len(colors)]
        # Convert the hex color to RGB (ignoring alpha)
        color_rgb = tuple(int(color_hex[j:j+2], 16) for j in (1, 3, 5))

        # Extract the bbox coordinates
        x, y, w, h = bbox

        # Draw the bounding box
        draw.rectangle([x, y, x + w, y + h], outline=color_rgb, width=4)

        # Draw the caption text at the top-left corner of the bbox
        text_x, text_y = x, max(0, y - 25)  # Place the text above the box if possible

        # Draw a text background for readability
        text_bbox = draw.textbbox((text_x, text_y), f"{caption}", font=font)
        draw.rectangle(text_bbox, fill=(255, 255, 255, 200))
        draw.text((text_x, text_y), f"{caption}", fill=color_rgb, font=font)

    return img_with_bboxes
def generate_caption(mode, image_data) -> Tuple[str, Image.Image]:
    """
    Generate a caption for the image and traces/bboxes using the loaded model.

    Args:
        mode: Either "trace" or "bbox" mode
        image_data: Data from the Gradio ImageEditor or annotator component

    Returns:
        Tuple of (generated caption text, annotated image)
    """
    global loaded_model

    # Clean up old cache folders on each generation to keep disk usage under control
    cleanup_gradio_cache(max_folders=30)  # More aggressive cleanup during active use

    try:
        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        print(f"[{current_time}] generate_caption called with mode: {mode}")
        print(f"[DEBUG] image_data type: {type(image_data)}")
        print(f"[DEBUG] image_data content: {image_data}")

        if loaded_model is None:
            return "❌ Error: No model loaded. Please load a model first using the config file.", None

        # Handle the different input formats from Gradio components
        image = None

        if image_data is None:
            return "❌ Error: No image data provided.", None

        # Check if it's a PIL Image directly
        if isinstance(image_data, Image.Image):
            print("[DEBUG] Received PIL Image directly")
            image = image_data
        # Check if it's a dict (from the image_annotator component)
        elif isinstance(image_data, dict):
            print(f"[DEBUG] Received dict with keys: {list(image_data.keys())}")
            if 'image' in image_data:
                image_array = image_data['image']
                # Convert the numpy array to a PIL Image if needed
                if hasattr(image_array, 'shape') and len(image_array.shape) == 3:
                    print("[DEBUG] Converting numpy array to PIL Image")
                    image = Image.fromarray(image_array)
                else:
                    image = image_array
            elif 'background' in image_data:
                image_array = image_data['background']
                # Convert the numpy array to a PIL Image if needed
                if hasattr(image_array, 'shape') and len(image_array.shape) == 3:
                    print("[DEBUG] Converting numpy array to PIL Image")
                    image = Image.fromarray(image_array)
                else:
                    image = image_array
            else:
                return f"❌ Error: No image found in data. Available keys: {list(image_data.keys())}", None
        # Check for the tuple/list format (from the ImageEditor component)
        elif isinstance(image_data, (tuple, list)) and len(image_data) >= 1:
            print(f"[DEBUG] Received tuple/list with {len(image_data)} elements")
            image = image_data[0]  # The first element should be the image
            # image can be a path to the image or a PIL Image
            if isinstance(image, str):
                if os.path.exists(image):
                    print("[DEBUG] Loading image from file path")
                    image = Image.open(image)
                else:
                    print(f"❌ Error: Image path does not exist: {image}")
            if not isinstance(image, Image.Image):
                # Sometimes the structure might be different, search for a PIL Image
                for item in image_data:
                    if isinstance(item, Image.Image):
                        image = item
                        break
        else:
            return f"❌ Error: Unexpected data type: {type(image_data)}", None

        if image is None:
            return "❌ Error: Image is None.", None

        # Make sure we ended up with a PIL image
        if not isinstance(image, Image.Image):
            return "❌ Error: Invalid image format.", None

        # Convert the image to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        if mode == "trace":
            return generate_trace_caption(image_data, image)
        elif mode == "bbox":
            return generate_bbox_caption(image_data, image)
        else:
            return f"❌ Error: Unknown mode: {mode}", None
    except Exception as e:
        error_msg = f"❌ Error generating caption: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        return error_msg, None
def generate_trace_caption(image_data, image) -> Tuple[str, Image.Image]:
    """Generate a caption using traces."""
    global loaded_model
    loaded_model.to(device)
    try:
        # Process the traces
        print("[DEBUG] Processing traces...")
        traces = process_image_trace_to_coordinates(image_data)
        print(f"[DEBUG] Found {len(traces)} traces")

        if not traces:
            # For debugging, generate a plain image caption instead of failing
            print("[DEBUG] No traces found, generating an image caption instead")
            image_tensor = loaded_model.image_transforms(image).unsqueeze(0).to(device)
            with torch.no_grad():
                outputs = loaded_model(
                    image_tensor,
                    get_cls_capt=True,  # Get the class caption as a fallback
                    get_patch_capts=False,
                    get_avg_patch_capt=False
                )
            if 'cls_capt' in outputs:
                return f"📝 No traces drawn. Image caption: {outputs['cls_capt']}", image
            else:
                return "❌ Error: No traces detected. Please draw some traces on the image.", None

        print(f"Processing {len(traces)} traces")

        # Prepare the image tensor
        image_tensor = loaded_model.image_transforms(image).unsqueeze(0).to(device)

        # Generate the caption using the model
        with torch.no_grad():
            outputs = loaded_model(
                image_tensor,
                traces=traces,
                get_cls_capt=False,  # We want trace captions, not class captions
                get_patch_capts=False,
                get_avg_patch_capt=False
            )

        # Extract the trace captions
        if 'trace_capts' in outputs:
            captions = outputs['trace_capts']
            if isinstance(captions, list) and captions:
                captions = [cap.replace("<|startoftext|>", "").replace("<|endoftext|>", "") for cap in captions]

                # Draw the traces on the image
                annotated_image = draw_traces_on_image(image, traces, captions, layers=image_data.get('layers', []) if isinstance(image_data, dict) else [])

                # Join multiple captions if there are multiple traces
                if len(captions) == 1:
                    return f"Generated Caption: {captions[0]}", annotated_image
                else:
                    formatted_captions = []
                    for i, caption in enumerate(captions, 1):
                        formatted_captions.append(f"<span style=\"color:{colors[(i-1) % len(colors)]}\">Trace {i}: {caption}</span>")
                    return "Generated Captions:\n\n" + "\n\n".join(formatted_captions), annotated_image
            elif isinstance(captions, str):
                captions_list = [captions.replace("<|startoftext|>", "").replace("<|endoftext|>", "")]
                annotated_image = draw_traces_on_image(image, traces, captions_list, layers=image_data.get('layers', []) if isinstance(image_data, dict) else [])
                return f"Generated Caption: {captions_list[0]}", annotated_image
            else:
                return "❌ Error: No captions generated.", None
        else:
            return "❌ Error: Model did not return trace captions.", None
    except Exception as e:
        error_msg = f"❌ Error generating trace caption: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        return error_msg, None
def generate_bbox_caption(image_data, image) -> Tuple[str, Image.Image]:
    """Generate a caption using bounding boxes."""
    global loaded_model
    loaded_model.to(device)

    original_image_size = image.size  # (width, height)
    image_tensor = loaded_model.image_transforms(image).unsqueeze(0).to(device)
    transformed_image_size = image_tensor.shape[2:]  # (height, width)

    try:
        # Process the bounding boxes
        print("[DEBUG] Processing bounding boxes...")
        bboxes = process_bounding_box_coordinates(image_data)
        print(f"[DEBUG] Found {len(bboxes)} bounding boxes")

        if not bboxes:
            # For debugging, generate a plain image caption instead of failing
            print("[DEBUG] No bounding boxes found, generating an image caption instead")
            with torch.no_grad():
                outputs = loaded_model(
                    image_tensor,
                    get_cls_capt=True,  # Get the class caption as a fallback
                    get_patch_capts=False,
                    get_avg_patch_capt=False
                )
            if 'cls_capt' in outputs:
                return f"📝 No bounding boxes drawn. Image caption: {outputs['cls_capt']}", image
            else:
                return "❌ Error: No bounding boxes detected. Please draw some bounding boxes on the image.", None

        print(f"Processing {len(bboxes)} bounding boxes")

        # Scale the bboxes to the transformed image size
        scale_x = transformed_image_size[1] / original_image_size[0]
        scale_y = transformed_image_size[0] / original_image_size[1]
        scaled_bboxes = []
        for bbox in bboxes:
            x, y, w, h = bbox
            x = x * scale_x
            y = y * scale_y
            w = w * scale_x
            h = h * scale_y
            scaled_bboxes.append([x, y, w, h])

        bbox_tensor = torch.tensor([scaled_bboxes]).to(device)

        with torch.no_grad():
            outputs = loaded_model(
                image_tensor,
                bboxes=bbox_tensor,
                get_cls_capt=False,
                get_patch_capts=False,
                get_avg_patch_capt=False
            )

        if 'bbox_capts' in outputs:
            print(f"[DEBUG] bbox_capts content: {outputs['bbox_capts']}")
            captions = outputs['bbox_capts']
            if isinstance(captions, list) and captions:
                if isinstance(captions[0], list):
                    captions = captions[0]  # Unwrap the nested list if needed
                captions = [cap.replace("<|startoftext|>", "").replace("<|endoftext|>", "") for cap in captions]

                # Draw the bboxes on the image
                annotated_image = draw_bboxes_on_image(image, bboxes, captions)

                if len(captions) == 1:
                    return f"Generated Caption: {captions[0]}", annotated_image
                else:
                    formatted_captions = []
                    for i, caption in enumerate(captions, 1):
                        formatted_captions.append(f"<span style=\"color:{colors[(i-1) % len(colors)]}\">BBox {i}: {caption}</span>")
                    return "Generated Captions:\n\n" + "\n\n".join(formatted_captions), annotated_image
            elif isinstance(captions, str):
                captions_list = [captions.replace("<|startoftext|>", "").replace("<|endoftext|>", "")]
                annotated_image = draw_bboxes_on_image(image, bboxes, captions_list)
                return f"Generated Caption: {captions_list[0]}", annotated_image
            else:
                return "❌ Error: No captions generated.", None
        else:
            return "❌ Error: Model did not return bbox captions.", None
    except Exception as e:
        error_msg = f"❌ Error generating bbox caption: {str(e)}"
        print(error_msg)
        print(traceback.format_exc())
        return error_msg, None
# def change_layer(current_layer):
#     """Each time the button is pressed, change the brush color."""
#     global color_index
#     color_index = (color_index + 1) % len(colors)
#     return gr.update(elem_id="image_editor", brush=gr.Brush(default_size=10, colors=[colors[color_index]], color_mode="fixed"))


def resize_image_if_needed(editor_value, max_dim=1024):
    """
    Resize the background image if it exceeds max_dim, or return gr.skip()
    to prevent the change event from looping.
    """
    # Handle the no-image case
    if editor_value is None:
        print("No image present")
        return gr.skip()

    # If some layers were already drawn, do not resize (to avoid losing drawings)
    if 'layers' in editor_value and len(editor_value['layers']):
        print("Not resizing because layers are present")
        return gr.skip()

    background_image = editor_value.get('background')

    # Handle the missing-background case
    if background_image is None:
        print("No background image present")
        return gr.skip()

    width, height = background_image.size

    # Check whether resizing is necessary
    if width > max_dim or height > max_dim:
        # Calculate the new dimensions while preserving the aspect ratio
        if width > height:
            new_width = max_dim
            new_height = int(height * (max_dim / width))
        else:
            new_height = max_dim
            new_width = int(width * (max_dim / height))

        resized_image = background_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
        print(f"Resizing image from ({width}, {height}) to ({resized_image.size[0]}, {resized_image.size[1]})")

        # Create the new dictionary with the resized image
        new_editor_value = editor_value.copy()
        new_editor_value['background'] = resized_image
        new_editor_value['composite'] = resized_image

        # Return the new value (triggers an update and one more change event)
        return new_editor_value

    # If no resizing was needed, skip the update so the change event does not re-fire
    print("No resizing needed")
    return gr.skip()
def create_gradio_interface(model_config_name: str):
    """Create and configure the Gradio interface."""
    # Clean up old Gradio cache folders to prevent disk space issues
    cleanup_gradio_cache(max_folders=50)  # Keep only the 50 most recent cache folders

    # Get the example files
    example_images = get_example_images()
    example_configs = get_example_configs()

    custom_js = """
    <script>
    window.addEventListener("load", () => {
        // Hide the Crop, Erase, and Color buttons
        const cropBtn = document.querySelector('.image-editor__tool[title="Crop"]');
        const eraseBtn = document.querySelector('.image-editor__tool[title="Erase"]');
        const colorBtn = document.querySelector('.image-editor__tool[title="Color"]');
        [cropBtn, eraseBtn, colorBtn].forEach(btn => {
            console.log("Going to disable display for ", btn);
            if (btn) btn.style.display = "none";
        });
        // Optionally, select the Brush/Draw tool right away
        const brushBtn = document.querySelector('.image-editor__tool[title="Draw"]');
        console.log("Selecting brushbtn: ", brushBtn);
        if (brushBtn) brushBtn.click();
    });
    </script>
    """

    with gr.Blocks(
        title="Patchioner Trace Captioning Demo",
        theme=gr.themes.Soft(),
        css="""
        .gradio-container {
            /*max-width: 1200px !important;*/
        }
        """
    ) as demo:
        #gr.HTML(custom_js)  # inject custom JS

        gr.Markdown(f"""
# 🎯 Patchioner Trace Captioning Demo

This demo showcases the **Patchioner** model for generating image captions based on user-drawn traces or bounding boxes.
More details about the Patch-ioner framework can be found on the official [project webpage](https://paciosoft.com/Patch-ioner/).

Patch-ioner is a unified zero-shot captioning framework for describing arbitrary image regions.

## Instructions:
1. Choose between Trace or BBox mode
2. Upload an image or use one of the provided examples
3. Use the appropriate tool to mark areas of interest in the image
4. Click "Generate Caption" to get AI-generated descriptions

> Tip: Use the Layer tool to generate multiple captions for different traces.
        """)
        # Initialize the model status
        model_initialization_status = initialize_default_model()

        with gr.Row():
            gr.Markdown(f"**Model Status:** {model_initialization_status}")

        with gr.Column():
            gr.Markdown("#### 📷 Select from example images or upload your own:")
            if example_images:
                example_gallery = gr.Gallery(
                    value=example_images,
                    label="Example Images",
                    show_label=True,
                    elem_id="gallery",
                    columns=4,
                    rows=2,
                    object_fit="contain",
                    height="auto"
                )

            mode_selector = gr.Radio(
                choices=["trace", "bbox"],
                value="trace",
                label="📝 Captioning Mode",
                info="Choose between trace-based or bounding box-based captioning",
                visible=True
            )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🖼️ Image Editor")

                # Image editor for drawing traces (default)
                image_editor = gr.ImageEditor(
                    elem_id="image_editor",
                    label="Upload image and draw traces",
                    type="pil",
                    #crop_size=None,
                    brush=gr.Brush(default_size=10, colors=[colors[color_index]], color_mode="fixed"),  # orange with ~60% opacity
                    visible=True,
                    eraser=False,
                    #transforms=[],
                    height=600,
                    #layers=gr.LayerOptions(allow_additional_layers=True, disabled=True),
                )

                # Image annotator for bounding boxes (hidden by default)
                if not USE_BBOX_ANNOTATOR:
                    image_annotator = foo_image_annotator(  #gr.Image(
                        label="Upload image and draw bounding boxes",
                        visible=False,
                        #classes=["object"],
                        #type="bbox"
                        #tool="select"
                        height=600
                    )
                else:
                    image_annotator = BBoxAnnotator(
                        label="Upload image and draw bounding boxes",
                        visible=False,
                        show_label=True,
                        show_download_button=False,
                        interactive=True,
                        container=True,
                        categories=["area"]
                    )

            with gr.Column():
                gr.Markdown("### 🖼️ Annotated Image")
                output_image = gr.Image(
                    label="Annotated Image",
                    type="pil",
                    visible=True,
                    height=600
                )

        with gr.Row():
            generate_button = gr.Button("✨ Generate Caption", variant="primary", size="lg")

        with gr.Row():
            status_message = gr.TextArea(
                elem_id="status_message_textarea",
                placeholder="Status messages will appear here...",
                visible=True
            )

        with gr.Row():
            output_text = gr.Markdown(
                label="Generated Caption",
                value="Generated caption will appear here...",
                #lines=5,
                #max_lines=10,
                #interactive=False
            )
        # Event handlers
        def toggle_input_components(mode):
            """Toggle between the image editor and the annotator based on the mode."""
            if mode == "trace":
                return gr.update(visible=True), gr.update(visible=False)
            else:  # bbox mode
                return gr.update(visible=False), gr.update(visible=True)

        def load_example_image_to_both(evt: gr.SelectData):
            """Load the selected example image into both components."""
            if not USE_BBOX_ANNOTATOR:
                empty_annotated_format = {"image": None, "boxes": [], "orientation": 0}
            else:
                empty_annotated_format = (None, [])
            try:
                example_images = get_example_images()
                if evt.index < len(example_images):
                    selected_image_path = example_images[evt.index]
                    img = Image.open(selected_image_path).convert('RGB')
                    # For ImageEditor, return the PIL image directly
                    # For image_annotator, return the dict format expected by the component
                    if not USE_BBOX_ANNOTATOR:
                        annotated_format = {
                            "image": img,
                            "boxes": [],
                            "orientation": 0
                        }
                    else:
                        annotated_format = tuple((selected_image_path, []))
                    # Convert to a numpy array for the ImageEditor
                    img = np.array(img)
                    return img, annotated_format
                return None, empty_annotated_format
            except Exception as e:
                print(f"Error loading example image: {e}")
                return None, empty_annotated_format

        def generate_caption_wrapper(mode, image_editor_data, image_annotator_data):
            """Wrapper to call generate_caption with the appropriate data based on the mode."""
            if mode == "trace":
                return generate_caption(mode, image_editor_data)
            else:  # bbox mode
                return generate_caption(mode, image_annotator_data)

        def generate_with_feedback(mode, image_editor_data, image_annotator_data):
            """
            Wrapper that provides UI feedback during caption generation.
            Yields intermediate states to update the UI.
            """
            # First yield: show the processing status
            yield (
                "⏳ Processing your request...",
                gr.update(elem_id="status_message_textarea", value="🔄 Generating caption... Please wait.", visible=True),
                None
            )

            # Generate the caption
            caption_text, annotated_image = generate_caption_wrapper(mode, image_editor_data, image_annotator_data)

            # Final yield: show the results and clear the status
            yield (
                caption_text,
                gr.update(elem_id="status_message_textarea", value="", visible=True),
                annotated_image
            )

        # Connect the event handlers
        mode_selector.change(
            fn=toggle_input_components,
            inputs=mode_selector,
            outputs=[image_editor, image_annotator]
        )

        generate_button.click(
            fn=generate_with_feedback,
            inputs=[mode_selector, image_editor, image_annotator],
            outputs=[output_text, status_message, output_image]
        )

        if example_images:
            example_gallery.select(
                fn=load_example_image_to_both,
                outputs=[image_editor, image_annotator]
            )

        #image_editor.change(
        #    fn=resize_image_if_needed,
        #    inputs=[image_editor],
        #    outputs=[image_editor],
        #    # The queue=False means this runs immediately on the change event,
        #    # which is usually desired for immediate UI updates.
        #    #queue=False
        #)
| gr.Markdown(f""" | |
| ### π‘ Tips: | |
| - **Mode Selection**: Switch between trace and bounding box modes based on your needs | |
| - **Trace Mode**: Draw continuous lines over areas you want to describe | |
| - **BBox Mode**: Draw rectangular bounding boxes around objects of interest | |
| - **Multiple Areas**: Change Layer to create multiple traces/boxes for different objects to get individual captions | |
| ### π§ Technical Details: | |
| - **Trace Mode**: Converts drawings to normalized (x, y) coordinates | |
| - **BBox Mode**: Uses bounding box coordinates for region-specific captioning | |
| - **Processing**: Each trace/bbox is processed separately to generate corresponding captions. Aggregated region representations also attend to the global image context. | |
| ### Use the Patch-ioner framework for you projects | |
| - just use `pip install git+https://github.com/Ruggero1912/Patch-ioner` to install the Patch-ioner package | |
| - check the [official project webpage](https://paciosoft.com/Patch-ioner/) and the [GitHub repository](https://github.com/Ruggero1912/Patch-ioner) for more details | |
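
A minimal sketch of loading a model in your own code, mirroring how this demo calls the library (the config path below is a placeholder):

```python
from patchioner import Patchioner

model = Patchioner.from_config("path/to/config.yaml", device="cuda")
model.eval()
```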
| """) | |
| return demo | |
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Patchioner Trace Captioning Demo")
    parser.add_argument("--port", type=int, default=4141, help="Port to run the Gradio app on")
    parser.add_argument("--local", action="store_true", help="Run the app locally. If not set, the app will use the default Gradio sharing host and port.")
    args = parser.parse_args()

    print("Starting Patchioner Trace Captioning Demo...")
    print(f"Using device: {device}")
    print(f"Default model: {DEFAULT_MODEL_CONFIG}")
    print(f"Example images directory: {EXAMPLE_IMAGES_DIR}")
    print(f"Configs directory: {CONFIGS_DIR}")

    # Initial cleanup of old Gradio cache folders on startup
    print("🧹 Cleaning up old cache folders...")
    cleanup_gradio_cache(max_folders=20)  # Very aggressive cleanup on startup

    demo = create_gradio_interface(DEFAULT_MODEL_CONFIG)

    if not args.local:
        demo.launch()
    else:
        demo.launch(
            server_name="0.0.0.0",
            server_port=args.port,
            share=True,
            debug=True
        )
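
# Example invocations (assuming this file is saved as app.py):
#   python app.py                       -> launch with Gradio defaults (e.g. on Hugging Face Spaces)
#   python app.py --local --port 4141   -> bind to 0.0.0.0:4141 with a public share link and debug output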