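"""Content moderation demo (image & video).

A minimal Gradio prototype: pick a model from the registry, choose categories,
tune a score threshold, and optionally redact flagged video segments with
FFmpeg.
"""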
import os

import gradio as gr
import pandas as pd
from PIL import Image
from model_registry import (
    ALL_CATEGORIES, DEFAULT_THRESHOLD, REGISTRY, get_model, NUDENET_ONLY
)
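# `model_registry` is assumed (per its usage in this file) to expose roughly:
#   REGISTRY: dict mapping model names to model factories/instances
#   get_model(name): returns a model with
#       .categories                      -> supported category names
#       .predict_image(pil, categories)  -> {category: score in [0, 1]}
#       .supports_selected_tags / .extra_selected_tags(pil, top_k)  (optional)
#   NUDENET_ONLY: model names that only score the "sexual" category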
from video_utils import (
    has_ffmpeg, probe_duration, extract_frames_ffmpeg, runs_from_indices,
    merge_seconds_union, redact_with_ffmpeg
)
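# `video_utils` helpers, as used below (signatures inferred from call sites):
#   probe_duration(path)                      -> duration in seconds, or None
#   extract_frames_ffmpeg(path, fps, out_dir) -> [(frame_path, frame_idx), ...]
#   runs_from_indices(sorted_idxs)            -> contiguous inclusive runs,
#       e.g. [1, 2, 3, 7, 8] -> [(1, 3), (7, 8)]
#   merge_seconds_union(idxs, fps, pad)       -> merged (start_s, end_s) intervals
#   redact_with_ffmpeg(src, intervals, dst)   -> writes a redacted copy of src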
# Optional Hugging Face Hub login (some models may be gated); a missing token
# or failed login is non-fatal for the demo.
try:
    from huggingface_hub import login

    tok = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
    if tok:
        login(tok)
except Exception:
    pass
APP_TITLE = "Content Moderation Demo (Image & Video)"
APP_DESC = """
Minimal prototype: image/video analysis, model & category selection, and threshold control.
"""
MODEL_NAMES = list(REGISTRY.keys())
IMG_EXAMPLES = [
    # [model, image_path, categories, threshold]
    ["clip-multilabel", "examples/gambling_alcohol.jpg", ALL_CATEGORIES, 0.50],
    ["wdeva02-multilabel", "examples/smoke_alcohol.jpg", ALL_CATEGORIES, 0.50],
    ["animetimm-multilabel", "examples/gambling_smoke_alcohol.jpg", ALL_CATEGORIES, 0.50],
]
# ---------- Shared ----------
def on_model_change(model_name):
    if model_name in NUDENET_ONLY:
        cats_state = gr.CheckboxGroup(choices=["sexual"], value=["sexual"], interactive=False, label="Categories")
    else:
        cats_state = gr.CheckboxGroup(choices=ALL_CATEGORIES, value=ALL_CATEGORIES, interactive=True, label="Categories")
    th = DEFAULT_THRESHOLD
    return cats_state, gr.Slider(minimum=0.0, maximum=1.0, value=th, step=0.01, label="Threshold")
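# NOTE: returning fresh component instances from a handler updates the
# corresponding output components in current Gradio; gr.update(...) with the
# changed props would work here as well.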
# ---------- Image ----------
def analyze_image(model_name, image, selected_categories, threshold):
    if image is None:
        return "No image.", None, gr.update(visible=False)
    pil = Image.fromarray(image) if not isinstance(image, Image.Image) else image
    model = get_model(model_name)
    # Only score categories the selected model actually supports.
    allowed = set(getattr(model, "categories", ALL_CATEGORIES))
    req = [c for c in selected_categories if c in allowed]
    if not req:
        return "No categories selected.", None, gr.update(visible=False)
    scores = model.predict_image(pil, req)
    verdict = "RISKY" if any(v >= threshold for v in scores.values()) else "SAFE"
    df = pd.DataFrame([{"category": k, "score": f"{(float(v)*100):.1f}%"} for k, v in sorted(scores.items())])
    # Some models expose the raw tags behind their category scores.
    if getattr(model, "supports_selected_tags", False):
        extra = model.extra_selected_tags(pil, top_k=15)
        txt = "\n".join(f"- {t}: {s:.3f}" for t, s in extra)
        return verdict, df, gr.update(visible=True, value=txt)
    return verdict, df, gr.update(visible=False)
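# Minimal non-UI sketch (model name and example path are taken from
# IMG_EXAMPLES above; any PIL-openable image works):
#   verdict, scores, _ = analyze_image(
#       "clip-multilabel", Image.open("examples/gambling_alcohol.jpg"),
#       ALL_CATEGORIES, DEFAULT_THRESHOLD)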
# ---------- Video ----------
def analyze_video(model_name, video_file, selected_categories, threshold, sampling_fps, redact):
    import tempfile, shutil  # `os` is already imported at module level
    if video_file is None:
        return pd.DataFrame([{"segment": "Error: No video."}]), gr.update(value=None)
    dur = probe_duration(video_file)
    if dur is not None and dur > 60.0:
        return pd.DataFrame([{"segment": "Error: Video too long (limit: 60s)."}]), gr.update(value=None)
    model = get_model(model_name)
    allowed = set(getattr(model, "categories", ALL_CATEGORIES))
    req = [c for c in selected_categories if c in allowed]
    if not req:
        return pd.DataFrame([{"segment": "Error: No categories selected."}]), gr.update(value=None)
    with tempfile.TemporaryDirectory() as td:
        try:
            frames = extract_frames_ffmpeg(video_file, sampling_fps, os.path.join(td, "frames"))
        except Exception:
            return pd.DataFrame([{"segment": "Error: FFmpeg not available or failed to extract frames."}]), gr.update(value=None)
        # Score every sampled frame; remember which frames crossed the threshold.
        all_hit_idx: list[int] = []
        frame_stats: dict[int, dict] = {}
        for fp, idx in frames:
            with Image.open(fp) as im:
                pil = im.convert("RGB")  # convert() fully loads, safe after close
            scores = model.predict_image(pil, req)
            over = {c: float(scores.get(c, 0.0)) for c in req if float(scores.get(c, 0.0)) >= threshold}
            if over:
                all_hit_idx.append(idx)
                peak_cat, peak_p = max(over.items(), key=lambda kv: kv[1])
                frame_stats[idx] = {"hits": over, "peak_cat": peak_cat, "peak_p": peak_p}
        if not all_hit_idx:
            return pd.DataFrame([{"segment": "(no hits)"}]), gr.update(value=None)
        # Collapse hit frames into contiguous runs, then report per-category
        # stats for each run. Frame i sampled at `sampling_fps` covers the
        # interval [i / fps, (i + 1) / fps), hence end = (b + 1) / fps below.
        union_runs = runs_from_indices(sorted(set(all_hit_idx)))
        rows = []
        for seg_id, (a, b) in enumerate(union_runs, start=1):
            cat_counts = {c: 0 for c in req}
            cat_maxp = {c: 0.0 for c in req}
            for i in range(a, b + 1):
                st = frame_stats.get(i)
                if not st:
                    continue
                for c, p in st["hits"].items():
                    cat_counts[c] += 1
                    if p > cat_maxp[c]:
                        cat_maxp[c] = p
            present = [c for c in req if cat_counts[c] > 0]
            present.sort(key=lambda c: (-cat_counts[c], -cat_maxp[c], c))
            for c in present:
                rows.append({
                    "seg": seg_id,
                    "start": round(a / sampling_fps, 3),
                    "end": round((b + 1) / sampling_fps, 3),
                    "category": c,
                    "max_p": round(cat_maxp[c], 3),
                })
        df = pd.DataFrame(rows).sort_values(["seg", "max_p"], ascending=[True, False]).reset_index(drop=True)
        out_video = gr.update(value=None)
        if redact and has_ffmpeg():
            intervals = merge_seconds_union(all_hit_idx, sampling_fps, pad=0.25)
            try:
                out_path = os.path.join(td, "redacted.mp4")
                redact_with_ffmpeg(video_file, intervals, out_path)
                # Copy out of the temp dir before it is deleted on context exit.
                final_out = os.path.join(os.getcwd(), "redacted_output.mp4")
                shutil.copyfile(out_path, final_out)
                out_video = gr.update(value=final_out)
            except Exception:
                out_video = gr.update(value=None)
        return df, out_video
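# Minimal non-UI sketch (assumes a short local clip "clip.mp4" and FFmpeg on
# PATH; "clip-multilabel" is one of the registered model names used above):
#   segments, _ = analyze_video("clip-multilabel", "clip.mp4",
#                               ALL_CATEGORIES, DEFAULT_THRESHOLD, 1.0, False)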
# ---------- UI ----------
with gr.Blocks(title=APP_TITLE, css=".wrap-row { gap: 16px; }") as demo:
    gr.Markdown(f"# {APP_TITLE}")
    gr.Markdown(APP_DESC)
    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row(elem_classes=["wrap-row"]):
                with gr.Column(scale=1, min_width=360):
                    model_dd = gr.Dropdown(label="Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
                    threshold = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD, step=0.01, label="Threshold")
                    categories = gr.CheckboxGroup(label="Categories", choices=ALL_CATEGORIES, value=ALL_CATEGORIES)
                    inp_img = gr.Image(type="pil", label="Upload Image")
                    btn = gr.Button("Analyze", variant="primary")
                with gr.Column(scale=1, min_width=360):
                    verdict = gr.Label(label="Verdict")
                    scores_df = gr.Dataframe(headers=["category", "score"], datatype="str",
                                             label="Scores", interactive=False)
                    extra_tags = gr.Textbox(label="Selected tags", visible=False, lines=12)
            model_dd.change(on_model_change, inputs=model_dd, outputs=[categories, threshold])
            btn.click(analyze_image, inputs=[model_dd, inp_img, categories, threshold],
                      outputs=[verdict, scores_df, extra_tags])
            gr.Examples(
                label="Try an example (Image)",
                examples=IMG_EXAMPLES,
                inputs=[model_dd, inp_img, categories, threshold],
                outputs=[verdict, scores_df, extra_tags],
                fn=analyze_image,
                run_on_click=True,
                cache_examples=False,
            )
with gr.Tab("Video"):
with gr.Row(elem_classes=["wrap-row"]):
with gr.Column(scale=1, min_width=360):
v_model = gr.Dropdown(label="Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
v_threshold = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD,
step=0.01, label="Threshold")
v_fps = gr.Slider(0.25, 5.0, value=1.0, step=0.25, label="Sampling FPS")
v_redact = gr.Checkbox(label="Redact scenes (requires FFmpeg)", value=False)
v_categories = gr.CheckboxGroup(label="Categories", choices=ALL_CATEGORIES, value=ALL_CATEGORIES)
v_input = gr.Video(label="Upload short video (≤ 60s)")
v_btn = gr.Button("Analyze Video", variant="primary")
with gr.Column(scale=1, min_width=360):
v_segments = gr.Dataframe(label="Segments", interactive=False)
v_out = gr.Video(label="Redacted Video")
v_model.change(on_model_change, inputs=v_model, outputs=[v_categories, v_threshold])
v_btn.click(analyze_video, inputs=[v_model, v_input, v_categories, v_threshold, v_fps, v_redact],
outputs=[v_segments, v_out])
if __name__ == "__main__":
    demo.launch()