import os

import gradio as gr
import pandas as pd
from PIL import Image

from model_registry import (
    ALL_CATEGORIES,
    DEFAULT_THRESHOLD,
    NUDENET_ONLY,
    REGISTRY,
    get_model,
)
from video_utils import (
    extract_frames_ffmpeg,
    has_ffmpeg,
    merge_seconds_union,
    probe_duration,
    redact_with_ffmpeg,
    runs_from_indices,
)

# Optionally authenticate with the Hugging Face Hub (needed for gated models).
try:
    from huggingface_hub import login

    tok = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
    if tok:
        login(tok)
except Exception:
    pass

APP_TITLE = "Content Moderation Demo (Image & Video)"
APP_DESC = """
Minimal prototype: image/video analysis, model & category selection, and threshold control.
"""

MODEL_NAMES = list(REGISTRY.keys())


# ---------- Shared ----------

def on_model_change(model_name):
    """Restrict category choices for NudeNet-only models and reset the threshold."""
    if model_name in NUDENET_ONLY:
        cats_state = gr.CheckboxGroup(
            choices=["sexual"], value=["sexual"], interactive=False, label="Categories"
        )
    else:
        cats_state = gr.CheckboxGroup(
            choices=ALL_CATEGORIES, value=ALL_CATEGORIES, interactive=True, label="Categories"
        )
    return cats_state, gr.Slider(
        minimum=0.0, maximum=1.0, value=DEFAULT_THRESHOLD, step=0.01, label="Threshold"
    )


# ---------- Image ----------

def analyze_image(model_name, image, selected_categories, threshold):
    if image is None:
        return "No image.", None, gr.update(visible=False)
    # The component is configured with type="pil", but accept numpy arrays too.
    pil = image if isinstance(image, Image.Image) else Image.fromarray(image)
    model = get_model(model_name)
    # Only request categories the selected model actually supports.
    allowed = set(getattr(model, "categories", ALL_CATEGORIES))
    req = [c for c in selected_categories if c in allowed]
    if not req:
        return "No categories selected.", None, gr.update(visible=False)
    scores = model.predict_image(pil, req)
    verdict = "RISKY" if any(v >= threshold for v in scores.values()) else "SAFE"
    df = pd.DataFrame(
        [{"category": k, "score": f"{float(v) * 100:.1f}%"} for k, v in sorted(scores.items())]
    )
    if getattr(model, "supports_selected_tags", False):
        extra = model.extra_selected_tags(pil, top_k=15)
        txt = "\n".join(f"- {t}: {s:.3f}" for t, s in extra)
        return verdict, df, gr.update(visible=True, value=txt)
    return verdict, df, gr.update(visible=False)


# ---------- Video ----------

def analyze_video(model_name, video_file, selected_categories, threshold, sampling_fps, redact):
    import shutil
    import tempfile

    if video_file is None:
        return pd.DataFrame([{"segment": "Error: No video."}]), gr.update(value=None)
    dur = probe_duration(video_file)
    if dur is not None and dur > 60.0:
        return pd.DataFrame([{"segment": "Error: Video too long (limit: 60s)."}]), gr.update(value=None)
    model = get_model(model_name)
    allowed = set(getattr(model, "categories", ALL_CATEGORIES))
    req = [c for c in selected_categories if c in allowed]
    if not req:
        return pd.DataFrame([{"segment": "Error: No categories selected."}]), gr.update(value=None)
    with tempfile.TemporaryDirectory() as td:
        try:
            frames = extract_frames_ffmpeg(video_file, sampling_fps, os.path.join(td, "frames"))
        except Exception:
            return (
                pd.DataFrame([{"segment": "Error: FFmpeg not available or failed to extract frames."}]),
                gr.update(value=None),
            )
        # Score each sampled frame; record per-frame hits at or above the threshold.
        all_hit_idx: list[int] = []
        frame_stats: dict[int, dict] = {}
        for fp, idx in frames:
            with Image.open(fp) as im:
                pil = im.convert("RGB")
            scores = model.predict_image(pil, req)
            over = {
                c: float(scores.get(c, 0.0))
                for c in req
                if float(scores.get(c, 0.0)) >= threshold
            }
            if over:
                all_hit_idx.append(idx)
                peak_cat, peak_p = max(over.items(), key=lambda kv: kv[1])
                frame_stats[idx] = {"hits": over, "peak_cat": peak_cat, "peak_p": peak_p}
        if not all_hit_idx:
            return pd.DataFrame([{"segment": "(no hits)"}]), gr.update(value=None)
        # Group consecutive hit-frame indices into contiguous runs (segments).
        union_runs = runs_from_indices(sorted(set(all_hit_idx)))
        # Aggregate per-category hit counts and peak scores within each segment.
        rows = []
        for seg_id, (a, b) in enumerate(union_runs, start=1):
            cat_counts = {c: 0 for c in req}
            cat_maxp = {c: 0.0 for c in req}
            for i in range(a, b + 1):
                st = frame_stats.get(i)
                if not st:
                    continue
                for c, p in st["hits"].items():
                    cat_counts[c] += 1
                    if p > cat_maxp[c]:
                        cat_maxp[c] = p
            # Report categories ordered by hit count, then peak score, then name.
            present = [c for c in req if cat_counts[c] > 0]
            present.sort(key=lambda c: (-cat_counts[c], -cat_maxp[c], c))
            for c in present:
                rows.append({
                    "seg": seg_id,
                    "start": round(a / sampling_fps, 3),
                    "end": round((b + 1) / sampling_fps, 3),
                    "category": c,
                    "max_p": round(cat_maxp[c], 3),
                })
        df = pd.DataFrame(rows).sort_values(["seg", "max_p"], ascending=[True, False]).reset_index(drop=True)

        out_video = gr.update(value=None)
        if redact and has_ffmpeg():
            intervals = merge_seconds_union(all_hit_idx, sampling_fps, pad=0.25)
            try:
                out_path = os.path.join(td, "redacted.mp4")
                redact_with_ffmpeg(video_file, intervals, out_path)
                # Copy the result out of the temporary directory (deleted on exit)
                # to a unique path so Gradio can still serve it; a fixed filename
                # in the working directory would collide across concurrent requests.
                fd, final_out = tempfile.mkstemp(prefix="redacted_", suffix=".mp4")
                os.close(fd)
                shutil.copyfile(out_path, final_out)
                out_video = gr.update(value=final_out)
            except Exception:
                out_video = gr.update(value=None)
        return df, out_video


# ---------- UI ----------

with gr.Blocks(title=APP_TITLE, css=".wrap-row { gap: 16px; }") as demo:
    gr.Markdown(f"# {APP_TITLE}")
    gr.Markdown(APP_DESC)
    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row(elem_classes=["wrap-row"]):
                with gr.Column(scale=1, min_width=360):
                    model_dd = gr.Dropdown(label="Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
                    threshold = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD, step=0.01, label="Threshold")
                    categories = gr.CheckboxGroup(label="Categories", choices=ALL_CATEGORIES, value=ALL_CATEGORIES)
                    inp_img = gr.Image(type="pil", label="Upload Image")
                    btn = gr.Button("Analyze", variant="primary")
                with gr.Column(scale=1, min_width=360):
                    verdict = gr.Label(label="Verdict")
                    scores_df = gr.Dataframe(headers=["category", "score"], datatype="str", label="Scores", interactive=False)
                    extra_tags = gr.Textbox(label="Selected tags", visible=False, lines=12)
            model_dd.change(on_model_change, inputs=model_dd, outputs=[categories, threshold])
            btn.click(
                analyze_image,
                inputs=[model_dd, inp_img, categories, threshold],
                outputs=[verdict, scores_df, extra_tags],
            )
        with gr.Tab("Video"):
            with gr.Row(elem_classes=["wrap-row"]):
                with gr.Column(scale=1, min_width=360):
                    v_model = gr.Dropdown(label="Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
                    v_threshold = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD, step=0.01, label="Threshold")
                    v_fps = gr.Slider(0.25, 5.0, value=1.0, step=0.25, label="Sampling FPS")
                    v_redact = gr.Checkbox(label="Redact scenes (requires FFmpeg)", value=False)
                    v_categories = gr.CheckboxGroup(label="Categories", choices=ALL_CATEGORIES, value=ALL_CATEGORIES)
                    v_input = gr.Video(label="Upload short video (≤ 60s)")
                    v_btn = gr.Button("Analyze Video", variant="primary")
                with gr.Column(scale=1, min_width=360):
                    v_segments = gr.Dataframe(label="Segments", interactive=False)
                    v_out = gr.Video(label="Redacted Video")
            v_model.change(on_model_change, inputs=v_model, outputs=[v_categories, v_threshold])
            v_btn.click(
                analyze_video,
                inputs=[v_model, v_input, v_categories, v_threshold, v_fps, v_redact],
                outputs=[v_segments, v_out],
            )

if __name__ == "__main__":
    demo.launch()
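
# Usage sketch (assumptions: this script is saved as app.py alongside the
# model_registry.py and video_utils.py modules it imports, FFmpeg is on PATH
# for the video tab, and HF_TOKEN is only required for gated Hub models):
#
#   pip install gradio pandas pillow huggingface_hub
#   HF_TOKEN=hf_xxx python app.py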