# unreal_explain_gradio.py
import gradio as gr
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import traceback
import torch
import torch.nn.functional as F
import numpy as np
import io
import base64
import cv2
# ---------- Configuration ----------
# If any of your Hugging Face models are private, set HF_TOKEN = "<YOUR_TOKEN>"
HF_TOKEN = None # or "hf_xxx" if needed
models = [
("Ateeqq/ai-vs-human-image-detector", "ateeq"),
("umm-maybe/AI-image-detector", "umm_maybe"),
("dima806/ai_vs_human_generated_image_detection", "dimma"),
]
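# Note: the first model in this list that loads successfully provides the verdict shown in
# the UI; the other pipelines are still run, but their results are not yet aggregated.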
# ---------- Helper functions for explainability ----------
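# Returns the last torch.nn.Conv2d found via named_modules() (registration order).
# Note that ViT-style models also register a Conv2d for patch embedding, which is why the
# loading code further down checks for ViT architectures before falling back to this search.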
def find_last_conv(module):
    last = None
    for name, m in module.named_modules():
        if isinstance(m, torch.nn.Conv2d):
            last = m
    return last
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        # register hooks
        target_layer.register_forward_hook(self._save_activation)
        # backward hook API differs by torch version; prefer the non-deprecated full hook
        if hasattr(target_layer, "register_full_backward_hook"):
            target_layer.register_full_backward_hook(self._save_gradient)
        else:
            target_layer.register_backward_hook(self._save_gradient)

    def _save_activation(self, module, input, output):
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        # grad_output is a tuple; the first element is the gradient w.r.t. the layer output
        self.gradients = grad_output[0].detach()

    def __call__(self, input_tensor, class_idx=None):
        self.activations = None
        self.gradients = None
        # forward pass (add a batch dimension)
        logits = self.model(input_tensor.unsqueeze(0))
        # transformers models return ModelOutput objects; plain nn.Modules return tensors
        if hasattr(logits, "logits"):
            logits_tensor = logits.logits
        else:
            logits_tensor = logits
        if class_idx is None:
            class_idx = int(torch.argmax(logits_tensor, dim=1).item())
        # backward pass on the chosen class score
        self.model.zero_grad()
        score = logits_tensor[0, class_idx]
        score.backward(retain_graph=False)
        # channel weights: global-average-pool the gradients over the spatial dims
        pooled_grads = torch.mean(self.gradients[0], dim=(1, 2))  # (C,)
        activ = self.activations[0].cpu()
        for i in range(activ.shape[0]):
            activ[i, :, :] *= pooled_grads[i].cpu()
        heatmap = torch.sum(activ, dim=0).cpu().numpy()
        heatmap = np.maximum(heatmap, 0)
        heatmap = heatmap - np.min(heatmap)
        heatmap = heatmap / (np.max(heatmap) + 1e-8)
        return heatmap, int(class_idx), logits_tensor
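# Usage sketch (assuming a CNN classifier `raw_model` and `pixel_values` produced by its
# image processor) -- illustrative only, not executed here:
#   cam = GradCAM(raw_model, find_last_conv(raw_model))
#   heatmap, class_idx, logits = cam(pixel_values[0])  # heatmap values lie in [0, 1]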
def overlay_heatmap_on_pil(orig_pil, heatmap, alpha=0.45):
    orig = np.array(orig_pil.convert("RGB"))
    heatmap_resized = cv2.resize(heatmap, (orig.shape[1], orig.shape[0]))
    heatmap_u8 = np.uint8(255 * heatmap_resized)
    colored = cv2.applyColorMap(heatmap_u8, cv2.COLORMAP_JET)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
    overlay = np.uint8(orig * (1 - alpha) + colored * alpha)
    return Image.fromarray(overlay)
# Attention rollout for ViT-style models (Abnar & Zuidema, 2020)
def attention_rollout_from_attentions(attentions):
    """
    attentions: tuple/list of tensors, each of shape (batch, heads, seq, seq)
    returns the token-to-token rollout matrix of shape (seq, seq)
    """
    # Average over heads, add the identity for the residual connections,
    # re-normalize the rows, then multiply the per-layer matrices together.
    result = None
    for attn in attentions:
        # attn shape (batch, heads, seq, seq)
        a = attn[0].mean(0).detach().cpu().numpy()  # (seq, seq)
        a = np.maximum(a, 0)
        a = a + np.eye(a.shape[0])  # account for the residual (skip) connection
        a = a / (a.sum(-1, keepdims=True) + 1e-8)
        if result is None:
            result = a
        else:
            result = a @ result
    return result
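# Rollout intuition: with head-averaged, identity-augmented attention A_l per layer,
# the rollout is R = A_L @ A_(L-1) @ ... @ A_1 (rows re-normalized), so R[0, 1:] estimates
# how strongly the CLS token ultimately attends to each image patch.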
def vit_attention_heatmap(processor, model, image: Image.Image):
    # preprocess
    inputs = processor(images=image, return_tensors="pt")
    # run the model with output_attentions=True (no gradients are needed here)
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    if not hasattr(outputs, "attentions") or outputs.attentions is None:
        return None
    rollout = attention_rollout_from_attentions(outputs.attentions)
    # rollout has shape (seq, seq); the first token is CLS, so take its attention to the patches
    cls_attention = rollout[0, 1:]  # drop the CLS->CLS entry
    # map patch attention back to an image-space heatmap:
    # derive the patch grid shape from the processor / model config
    try:
        config = model.config
        if hasattr(config, "image_size"):
            image_size = config.image_size
        else:
            image_size = processor.size.get("shortest_edge", 224) if hasattr(processor, "size") else 224
        patch_size = config.patch_size if hasattr(config, "patch_size") else 16
    except Exception:
        image_size = 224
        patch_size = 16
    if isinstance(image_size, (list, tuple)):
        image_size = image_size[0]
    grid_size = int(image_size // patch_size)
    # if the token count doesn't match the expected grid, fall back to the nearest square
    if cls_attention.shape[0] != grid_size * grid_size:
        grid_size = int(np.sqrt(cls_attention.shape[0]))
        cls_attention = cls_attention[: grid_size * grid_size]
    heatmap = cls_attention.reshape(grid_size, grid_size)
    heatmap = heatmap - heatmap.min()
    heatmap = heatmap / (heatmap.max() + 1e-8)
    return heatmap
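# Usage sketch (assuming `processor` and `raw_model` come from a ViT checkpoint that
# returns attentions) -- illustrative only, not executed here:
#   heatmap = vit_attention_heatmap(processor, raw_model, pil_image)
#   overlay = overlay_heatmap_on_pil(pil_image, heatmap) if heatmap is not None else pil_image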
# ---------- Load pipelines and also the underlying models/processors ----------
pipes = []      # list of (model_id, pipeline)
hf_models = {}  # model_id -> {"processor", "model", "explain_type", "helper"}
for model_id, short in models:
    try:
        p = pipeline("image-classification", model=model_id, use_auth_token=HF_TOKEN)
        pipes.append((model_id, p))
        print(f"Loaded pipeline {model_id}")
    except Exception as e:
        print(f"Error loading pipeline for {model_id}: {e}")
    # try to load processor + raw model for explainability
    try:
        processor = AutoImageProcessor.from_pretrained(model_id, use_auth_token=HF_TOKEN)
    except Exception:
        # older transformers versions: fall back to AutoFeatureExtractor
        try:
            from transformers import AutoFeatureExtractor
            processor = AutoFeatureExtractor.from_pretrained(model_id, use_auth_token=HF_TOKEN)
        except Exception:
            processor = None
    try:
        raw_model = AutoModelForImageClassification.from_pretrained(model_id, use_auth_token=HF_TOKEN)
        raw_model.eval()
        # check for a ViT-style architecture first: ViT patch embeddings are Conv2d layers,
        # so the Grad-CAM conv search below would otherwise pick them up by mistake
        cfg = raw_model.config
        if getattr(cfg, "architectures", None) and any("ViT" in a or "VisionTransformer" in a for a in cfg.architectures):
            explain_type = "vit"
            explain_helper = None
            print(f"{model_id} -> ViT | will use attention rollout")
        else:
            # otherwise try to find a backbone / base model and its last conv layer for Grad-CAM
            base = None
            for candidate in ("base_model", "backbone", "model", "vit", "resnet", "conv_stem"):
                if hasattr(raw_model, candidate):
                    base = getattr(raw_model, candidate)
                    break
            if base is None:
                base = raw_model
            last_conv = find_last_conv(base)
            if last_conv is not None:
                explain_type = "gradcam"
                explain_helper = GradCAM(raw_model, last_conv)
                print(f"{model_id} -> Grad-CAM available")
            else:
                # fallback: no explainability
                explain_type = "none"
                explain_helper = None
                print(f"{model_id} -> No explainability (no convs and not ViT)")
    except Exception as e:
        print(f"Couldn't load raw HF model for {model_id}: {e}")
        raw_model = None
        processor = None
        explain_type = "none"
        explain_helper = None
    hf_models[model_id] = {
        "processor": processor,
        "model": raw_model,
        "explain_type": explain_type,
        "helper": explain_helper,
    }
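# After this loop, each entry looks like (illustrative):
#   hf_models["Ateeqq/ai-vs-human-image-detector"] = {
#       "processor": <image processor or None>, "model": <classification model or None>,
#       "explain_type": "gradcam" | "vit" | "none", "helper": <GradCAM instance or None>,
#   }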
# ---------- Original predict function, updated to produce an explainability overlay ----------
def predict_image_with_explain(image: Image.Image, alpha=0.45):
    try:
        # run all pipelines; only the first result currently drives the verdict shown in the UI
        results = []
        for model_id, pipe in pipes:
            try:
                res = pipe(image)[0]
                results.append((model_id, res))
            except Exception:
                results.append((model_id, {"label": "error", "score": 0.0}))
        # pick the first result for the main verdict (as before)
        final_model_id, final_res = results[0]
        label = final_res.get("label", "").lower()
        score = final_res.get("score", 0.0) * 100
        if "ai" in label or "fake" in label:
            verdict = f"🧠 AI-Generated ({score:.1f}% confidence)"
            color = "#007BFF"
        else:
            verdict = f"🧍 Human-Made ({score:.1f}% confidence)"
            color = "#4CAF50"
        # try to compute an explainability overlay from the corresponding HF model, if available
        explain_entry = hf_models.get(final_model_id)
        overlay_data_uri = None
        explain_reason = None
        if explain_entry and explain_entry["explain_type"] == "gradcam" and explain_entry["helper"] is not None:
            try:
                # preprocess: use the processor if present, otherwise fall back to torchvision transforms
                proc = explain_entry["processor"]
                if proc is not None:
                    inputs = proc(images=image, return_tensors="pt")
                    input_tensor = inputs["pixel_values"][0]
                else:
                    # fallback: resize + normalize with the usual ImageNet statistics
                    from torchvision import transforms
                    pre = transforms.Compose([
                        transforms.Resize((224, 224)),
                        transforms.ToTensor(),
                        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                    ])
                    input_tensor = pre(image)
                grad_helper = explain_entry["helper"]
                heatmap, class_idx, logits = grad_helper(input_tensor)
                # overlay the heatmap on the original image and encode it as a data URI
                overlay_img = overlay_heatmap_on_pil(image, heatmap, alpha=alpha)
                buf = io.BytesIO()
                overlay_img.save(buf, format="PNG")
                overlay_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
                overlay_data_uri = "data:image/png;base64," + overlay_b64
                explain_reason = "Grad-CAM heatmap (activations)"
            except Exception as e:
                traceback.print_exc()
                explain_reason = f"Grad-CAM failed: {e}"
        elif explain_entry and explain_entry["explain_type"] == "vit" and explain_entry["model"] is not None:
            try:
                proc = explain_entry["processor"]
                raw_model = explain_entry["model"]
                heatmap = vit_attention_heatmap(proc, raw_model, image)
                if heatmap is not None:
                    overlay_img = overlay_heatmap_on_pil(image, heatmap, alpha=alpha)
                    buf = io.BytesIO()
                    overlay_img.save(buf, format="PNG")
                    overlay_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
                    overlay_data_uri = "data:image/png;base64," + overlay_b64
                    explain_reason = "ViT attention rollout heatmap"
            except Exception as e:
                traceback.print_exc()
                explain_reason = f"ViT rollout failed: {e}"
        # build the HTML verdict box
        html = f"""
        <div class='result-box' style="
            background: linear-gradient(135deg, {color}33, #1a1a1a);
            border: 2px solid {color};
            border-radius: 15px;
            padding: 20px;
            text-align: center;
            color: white;
            font-size: 18px;
            font-weight: 600;
            box-shadow: 0 0 20px {color}55;
            animation: fadeIn 0.6s ease-in-out;
        ">
            {verdict}
            <div style="font-size:12px; margin-top:8px; font-weight:400; opacity:0.9;">
                Model: <b>{final_model_id}</b> — Score by model: {score:.1f}%
            </div>
        </div>
        """
        return {
            "html": html,
            "overlay": overlay_data_uri,
            "explain_reason": explain_reason or ""
        }
    except Exception as e:
        traceback.print_exc()
        return {"html": f"<div style='color:red;'>Error analyzing image: {str(e)}</div>", "overlay": None, "explain_reason": ""}
# ---------- Gradio UI ----------
css = """
body, .gradio-container {
font-family: 'Poppins', sans-serif !important;
background: transparent !important;
}
h1 {
text-align: center;
font-weight: 700;
color: #007BFF;
margin-bottom: 10px;
}
.gr-button-primary {
background-color: #007BFF !important;
color: white !important;
font-weight: 600;
border-radius: 10px;
height: 45px;
}
.gr-button-secondary {
background-color: #dc3545 !important;
color: white !important;
border-radius: 10px;
height: 45px;
}
#pulse-loader {
width: 100%;
height: 4px;
background: linear-gradient(90deg, #007BFF, #00C3FF);
animation: pulse 1.2s infinite ease-in-out;
border-radius: 2px;
box-shadow: 0 0 10px #007BFF;
}
@keyframes pulse {
0% { transform: scaleX(0.1); opacity: 0.6; }
50% { transform: scaleX(1); opacity: 1; }
100% { transform: scaleX(0.1); opacity: 0.6; }
}
@keyframes fadeIn {
from { opacity: 0; transform: scale(0.95); }
to { opacity: 1; transform: scale(1); }
}
"""
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("<h1>🔍 AI Image Detector w/ Explainability</h1>")
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload an image")
            analyze_button = gr.Button("Analyze", variant="primary")
            clear_button = gr.Button("Clear", variant="secondary")
            loader = gr.HTML("")
            gr.Markdown("Overlay opacity:")
            opacity = gr.Slider(minimum=0, maximum=1, value=0.6, step=0.05)
        with gr.Column(scale=1):
            # show the original image or the heatmap overlay
            image_display = gr.Image(type="pil", label="Original / Overlay", interactive=False)
            output_html = gr.HTML(label="Result")
            explanation_text = gr.Textbox(label="Explainability", interactive=False)

    def analyze(img, op):
        if img is None:
            yield (None, "<div style='color:red;'>Please upload an image first!</div>", "")
            return
        loader_html = "<div id='pulse-loader'></div>"
        # show the loader while the analysis runs
        yield (None, loader_html, "")
        # run the analysis, passing the slider value through as the overlay opacity
        out = predict_image_with_explain(img, alpha=op)
        # decode the overlay data URI back into a PIL image for the gr.Image output
        overlay_uri = out.get("overlay")
        if overlay_uri:
            header, b64 = overlay_uri.split(",", 1)
            overlay_bytes = base64.b64decode(b64)
            overlay_img = Image.open(io.BytesIO(overlay_bytes)).convert("RGB")
        else:
            overlay_img = img  # fallback: show the original image
        explain_reason = out.get("explain_reason", "")
        html = out.get("html", "")
        # yield overlay image, verdict HTML, explanation string
        yield (overlay_img, html, explain_reason)

    analyze_button.click(analyze, inputs=[image_input, opacity], outputs=[image_display, output_html, explanation_text])
    clear_button.click(lambda: (None, "", ""), outputs=[image_display, output_html, explanation_text])

# Note: analyze() is a generator that streams an intermediate loader state; on older
# Gradio releases this may require calling demo.queue() before launch().
demo.launch()