# unreal_explain_gradio.py
import gradio as gr
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import traceback
import time
import threading
import torch
import torch.nn.functional as F
import numpy as np
import io
import base64
import cv2

# ---------- Configuration ----------
# If any of your Hugging Face models are private, set HF_TOKEN = "<YOUR_TOKEN>"
HF_TOKEN = None  # or "hf_xxx" if needed

models = [
    ("Ateeqq/ai-vs-human-image-detector", "ateeq"),
    ("umm-maybe/AI-image-detector", "umm_maybe"),
    ("dima806/ai_vs_human_generated_image_detection", "dimma"),
]
# ---------- Helper functions for explainability ----------
def find_last_conv(module):
    last = None
    for name, m in module.named_modules():
        if isinstance(m, torch.nn.Conv2d):
            last = m
    return last
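# Illustrative usage (not part of the app flow): for a CNN classifier, find_last_conv() walks
# named_modules() and keeps the deepest nn.Conv2d, which is the usual Grad-CAM target layer:
#   target_layer = find_last_conv(raw_model)   # raw_model: any torch.nn.Module; None if no convs exist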
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        # register hooks
        target_layer.register_forward_hook(self._save_activation)
        # the backward-hook API differs by torch version: prefer the non-deprecated
        # register_full_backward_hook (torch >= 1.8), fall back to register_backward_hook
        try:
            target_layer.register_full_backward_hook(self._save_gradient)
        except Exception:
            target_layer.register_backward_hook(self._save_gradient)
    def _save_activation(self, module, input, output):
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        # grad_output can be a tuple
        self.gradients = grad_output[0].detach()
    def __call__(self, input_tensor, class_idx=None):
        self.activations = None
        self.gradients = None
        # forward
        logits = self.model(input_tensor.unsqueeze(0))
        # transformers models return output objects; handle both cases
        if hasattr(logits, "logits"):
            logits_tensor = logits.logits
        else:
            logits_tensor = logits
        if class_idx is None:
            class_idx = int(torch.argmax(logits_tensor, dim=1).item())
        # backward
        self.model.zero_grad()
        score = logits_tensor[0, class_idx]
        score.backward(retain_graph=False)
        # compute channel weights by global-average-pooling the gradients
        pooled_grads = torch.mean(self.gradients[0], dim=(1, 2))  # shape (C,)
        activ = self.activations[0].cpu()
        for i in range(activ.shape[0]):
            activ[i, :, :] *= pooled_grads[i].cpu()
        heatmap = torch.sum(activ, dim=0).cpu().numpy()
        heatmap = np.maximum(heatmap, 0)
        heatmap = heatmap - np.min(heatmap)
        denom = (np.max(heatmap) + 1e-8)
        heatmap = heatmap / denom
        return heatmap, int(class_idx), logits_tensor
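# Minimal usage sketch for the class above (illustrative names; input_tensor is an un-batched
# (C, H, W) pixel tensor, e.g. processor(images=img, return_tensors="pt")["pixel_values"][0]):
#   cam = GradCAM(raw_model, find_last_conv(raw_model))
#   heatmap, class_idx, logits = cam(input_tensor)   # heatmap: (h, w) array scaled to [0, 1]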
def overlay_heatmap_on_pil(orig_pil, heatmap, alpha=0.45):
    orig = np.array(orig_pil.convert("RGB"))
    heatmap_resized = cv2.resize(heatmap, (orig.shape[1], orig.shape[0]))
    heatmap_u8 = np.uint8(255 * heatmap_resized)
    colored = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    overlay = np.uint8(orig * (1 - alpha) + colored * alpha)
    return Image.fromarray(overlay)
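# Note: alpha controls the blend between the original RGB image and the JET-colored heatmap
# (alpha=0 shows only the photo, alpha=1 only the heatmap); the UI opacity slider feeds this value.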
# Attention rollout for ViT-style models
def attention_rollout_from_attentions(attentions, discard_ratio=0.9):
    """
    attentions: tuple/list of tensors, each of shape (batch, heads, seq, seq)
    returns a token-to-token rollout matrix of shape (seq, seq)
    (discard_ratio is accepted for compatibility but not applied in this simplified rollout)
    """
    # Convert to numpy arrays, average over heads, then multiply the per-layer matrices together
    result = None
    for attn in attentions:
        # attn shape (batch, heads, seq, seq)
        a = attn[0].mean(0).detach().cpu().numpy()  # (seq, seq)
        # optionally remove low weights
        a = np.maximum(a, 0)
        a = a / (a.sum(-1, keepdims=True) + 1e-8)
        if result is None:
            result = a
        else:
            result = a @ result
    return result
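# Note: the reference attention-rollout formulation (Abnar & Zuidema, 2020) also mixes the identity
# matrix into each layer's attention (A' = 0.5*A + 0.5*I) to model residual connections before
# multiplying; the simplified version above multiplies the head-averaged, row-normalized maps directly.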
def vit_attention_heatmap(processor, model, image: Image.Image):
    # preprocess
    inputs = processor(images=image, return_tensors="pt")
    # call model with output_attentions=True (no gradients needed here)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    if not hasattr(outputs, "attentions") or outputs.attentions is None:
        return None
    rollout = attention_rollout_from_attentions(outputs.attentions)
    # rollout shape (seq, seq). The first token is CLS — we use CLS attention to the patches.
    cls_attention = rollout[0, 1:]  # skip the CLS->CLS entry
    # map patch attention to an image heatmap:
    # get image size and patch grid shape from the processor/model config
    try:
        config = model.config
        if hasattr(config, "image_size"):
            image_size = config.image_size
        else:
            image_size = processor.size.get("shortest_edge", 224) if hasattr(processor, "size") else 224
        patch_size = config.patch_size if hasattr(config, "patch_size") else 16
    except Exception:
        image_size = 224
        patch_size = 16
    grid_size = int(image_size // patch_size)
    # if the token count doesn't match the grid, fall back to the nearest square
    if cls_attention.shape[0] != grid_size * grid_size:
        n = int(np.sqrt(cls_attention.shape[0]))
        grid_size = n
        cls_attention = cls_attention[: n * n]  # trim extra special tokens so the reshape cannot fail
    heatmap = cls_attention.reshape(grid_size, grid_size)
    heatmap = heatmap - heatmap.min()
    heatmap = heatmap / (heatmap.max() + 1e-8)
    return heatmap
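# Minimal usage sketch (illustrative names; assumes a ViT-style checkpoint that exposes attentions):
#   proc = AutoImageProcessor.from_pretrained(model_id)
#   vit = AutoModelForImageClassification.from_pretrained(model_id)
#   hm = vit_attention_heatmap(proc, vit, pil_image)   # (grid, grid) array in [0, 1], or None
#   overlay = overlay_heatmap_on_pil(pil_image, hm) if hm is not None else pil_image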
# ---------- Load pipelines and also the underlying models/processors ----------
pipes = []      # (model_id, pipeline)
hf_models = {}  # model_id -> {"processor", "model", "explain_type", "helper"}

for model_id, short in models:
    try:
        p = pipeline("image-classification", model=model_id, use_auth_token=HF_TOKEN)
        pipes.append((model_id, p))
        print(f"Loaded pipeline {model_id}")
    except Exception as e:
        print(f"Error loading pipeline for {model_id}: {e}")

    # try to load processor + raw model for explainability
    try:
        processor = AutoImageProcessor.from_pretrained(model_id, use_auth_token=HF_TOKEN)
    except Exception:
        # older transformers versions: AutoFeatureExtractor fallback
        try:
            from transformers import AutoFeatureExtractor
            processor = AutoFeatureExtractor.from_pretrained(model_id, use_auth_token=HF_TOKEN)
        except Exception:
            processor = None
    try:
        raw_model = AutoModelForImageClassification.from_pretrained(model_id, use_auth_token=HF_TOKEN)
        raw_model.eval()
        # attempt to detect conv layers:
        # try to find a backbone / base model first
        base = None
        for candidate in ("base_model", "backbone", "model", "vit", "resnet", "conv_stem"):
            if hasattr(raw_model, candidate):
                base = getattr(raw_model, candidate)
                break
        if base is None:
            base = raw_model
        last_conv = find_last_conv(base)
        if last_conv is not None:
            explain_type = "gradcam"
            explain_helper = GradCAM(raw_model, last_conv)
            print(f"{model_id} -> Grad-CAM available")
        else:
            # try the transformer attention route:
            # check config for a ViT-style architecture
            cfg = raw_model.config
            if getattr(cfg, "architectures", None) and any("ViT" in a or "VisionTransformer" in a for a in cfg.architectures):
                explain_type = "vit"
                explain_helper = None
                print(f"{model_id} -> ViT | will use attention rollout")
            else:
                # fallback: no explainability
                explain_type = "none"
                explain_helper = None
                print(f"{model_id} -> No explainability (no convs and not ViT)")
    except Exception as e:
        print(f"Couldn't load raw hf model for {model_id}: {e}")
        raw_model = None
        processor = None
        explain_type = "none"
        explain_helper = None

    hf_models[model_id] = {
        "processor": processor,
        "model": raw_model,
        "explain_type": explain_type,
        "helper": explain_helper
    }
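# At this point each hf_models entry maps a model_id to
#   {"processor": ..., "model": ..., "explain_type": "gradcam" | "vit" | "none", "helper": GradCAM or None},
# which predict_image_with_explain() below uses to pick an explainability route per model.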
# ---------- original predict function updated to produce an overlay ----------
def predict_image_with_explain(image: Image.Image, alpha=0.45):
    try:
        # run all pipelines to get a consensus / first result for the UI
        results = []
        for model_id, pipe in pipes:
            try:
                res = pipe(image)[0]
                results.append((model_id, res))
            except Exception:
                results.append((model_id, {"label": "error", "score": 0.0}))
        # pick the first result for the main verdict (like before)
        final_model_id, final_res = results[0]
        label = final_res.get("label", "").lower()
        score = final_res.get("score", 0.0) * 100
        if "ai" in label or "fake" in label:
            verdict = f"🧠 AI-Generated ({score:.1f}% confidence)"
            color = "#007BFF"
        else:
            verdict = f"🧍 Human-Made ({score:.1f}% confidence)"
            color = "#4CAF50"
        # Try to compute an explainability overlay from the corresponding HF model if available
        explain_entry = hf_models.get(final_model_id)
        overlay_data_uri = None
        explain_reason = None
        if explain_entry and explain_entry["explain_type"] == "gradcam" and explain_entry["helper"] is not None:
            try:
                # preprocess: use the processor if present, else fall back to torchvision transforms
                proc = explain_entry["processor"]
                raw_model = explain_entry["model"]
                if proc is not None:
                    inputs = proc(images=image, return_tensors="pt")
                    input_tensor = inputs["pixel_values"][0] if "pixel_values" in inputs else inputs["input_tensor"][0]
                else:
                    # fallback: resize + normalize similar to common ImageNet-trained models
                    from torchvision import transforms
                    pre = transforms.Compose([
                        transforms.Resize((224, 224)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                    ])
                    input_tensor = pre(image)
                grad_helper = explain_entry["helper"]
                heatmap, class_idx, logits = grad_helper(input_tensor)
                # overlay (alpha comes from the UI opacity slider)
                overlay_img = overlay_heatmap_on_pil(image, heatmap, alpha=alpha)
                buf = io.BytesIO()
                overlay_img.save(buf, format="PNG")
                overlay_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
                overlay_data_uri = "data:image/png;base64," + overlay_b64
                explain_reason = "Grad-CAM heatmap (activations)"
            except Exception as e:
                traceback.print_exc()
                explain_reason = f"Grad-CAM failed: {e}"
        elif explain_entry and explain_entry["explain_type"] == "vit" and explain_entry["model"] is not None:
            try:
                proc = explain_entry["processor"]
                raw_model = explain_entry["model"]
                heatmap = vit_attention_heatmap(proc, raw_model, image)
                if heatmap is not None:
                    overlay_img = overlay_heatmap_on_pil(image, heatmap, alpha=alpha)
                    buf = io.BytesIO()
                    overlay_img.save(buf, format="PNG")
                    overlay_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
                    overlay_data_uri = "data:image/png;base64," + overlay_b64
                    explain_reason = "ViT attention rollout heatmap"
            except Exception as e:
                traceback.print_exc()
                explain_reason = f"ViT rollout failed: {e}"
        # Build HTML for the verdict box
        html = f"""
        <div class='result-box' style="
            background: linear-gradient(135deg, {color}33, #1a1a1a);
            border: 2px solid {color};
            border-radius: 15px;
            padding: 20px;
            text-align: center;
            color: white;
            font-size: 18px;
            font-weight: 600;
            box-shadow: 0 0 20px {color}55;
            animation: fadeIn 0.6s ease-in-out;
        ">
            {verdict}
            <div style="font-size:12px; margin-top:8px; font-weight:400; opacity:0.9;">
                Model: <b>{final_model_id}</b> — Score by model: {score:.1f}%
            </div>
        </div>
        """
        return {
            "html": html,
            "overlay": overlay_data_uri,
            "explain_reason": explain_reason or ""
        }
    except Exception as e:
        traceback.print_exc()
        return {"html": f"<div style='color:red;'>Error analyzing image: {str(e)}</div>", "overlay": None, "explain_reason": ""}
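# predict_image_with_explain() returns a plain dict that the Gradio callback below unpacks:
#   {"html": <verdict card markup>, "overlay": <PNG data URI or None>, "explain_reason": <short status string>}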
# ---------- Gradio UI ----------
css = """
body, .gradio-container {
    font-family: 'Poppins', sans-serif !important;
    background: transparent !important;
}
h1 {
    text-align: center;
    font-weight: 700;
    color: #007BFF;
    margin-bottom: 10px;
}
.gr-button-primary {
    background-color: #007BFF !important;
    color: white !important;
    font-weight: 600;
    border-radius: 10px;
    height: 45px;
}
.gr-button-secondary {
    background-color: #dc3545 !important;
    color: white !important;
    border-radius: 10px;
    height: 45px;
}
#pulse-loader {
    width: 100%;
    height: 4px;
    background: linear-gradient(90deg, #007BFF, #00C3FF);
    animation: pulse 1.2s infinite ease-in-out;
    border-radius: 2px;
    box-shadow: 0 0 10px #007BFF;
}
@keyframes pulse {
    0% { transform: scaleX(0.1); opacity: 0.6; }
    50% { transform: scaleX(1); opacity: 1; }
    100% { transform: scaleX(0.1); opacity: 0.6; }
}
@keyframes fadeIn {
    from { opacity: 0; transform: scale(0.95); }
    to { opacity: 1; transform: scale(1); }
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1>🔍 AI Image Detector w/ Explainability</h1>")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload an image")
            analyze_button = gr.Button("Analyze", variant="primary")
            clear_button = gr.Button("Clear", variant="secondary")
            loader = gr.HTML("")
            gr.Markdown("Opacity:")
            opacity = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05)
        with gr.Column(scale=1):
            # show either the original image or the heatmap overlay
            image_display = gr.Image(type="pil", label="Original / Overlay", interactive=False)
            output_html = gr.HTML(label="Result")
            explanation_text = gr.Textbox(label="Explainability", interactive=False)
    def analyze(img, op):
        if img is None:
            # generator callbacks must yield (not return) the error tuple for Gradio to render it
            yield (None, "<div style='color:red;'>Please upload an image first!</div>", "")
            return
        loader_html = "<div id='pulse-loader'></div>"
        # show the loader while the models run
        yield (None, loader_html, "")
        # run analysis; the opacity slider value is used as the heatmap overlay alpha
        out = predict_image_with_explain(img, alpha=op)
        # overlay image if available
        overlay_uri = out.get("overlay")
        if overlay_uri:
            # convert the data URI back to PIL for the gr.Image output
            header, b64 = overlay_uri.split(",", 1)
            overlay_bytes = base64.b64decode(b64)
            overlay_img = Image.open(io.BytesIO(overlay_bytes)).convert("RGB")
        else:
            overlay_img = img  # fallback: show the original image
        # explanation text
        explain_reason = out.get("explain_reason", "")
        html = out.get("html", "")
        # yield overlay image, html, explanation string
        yield (overlay_img, html, explain_reason)

    analyze_button.click(analyze, inputs=[image_input, opacity], outputs=[image_display, output_html, explanation_text])
    clear_button.click(lambda: (None, "", ""), outputs=[image_display, output_html, explanation_text])

demo.launch()