Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| from PIL import Image | |
| from model_registry import ( | |
| ALL_CATEGORIES, DEFAULT_THRESHOLD, REGISTRY, get_model, NUDENET_ONLY | |
| ) | |
| from video_utils import ( | |
| has_ffmpeg, probe_duration, extract_frames_ffmpeg, runs_from_indices, | |
| merge_seconds_union, redact_with_ffmpeg | |
| ) | |
| import os | |
# Optional, best-effort Hugging Face Hub login using a token from the
# environment (presumably so registry models can fetch gated/private
# weights — confirm against model_registry). The app must still start when
# no token is set, huggingface_hub is not installed, or the login fails.
import logging

try:
    from huggingface_hub import login
except ImportError:
    login = None  # hub client not installed: skip authentication entirely

tok = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_HUB_TOKEN")
if tok and login is not None:
    try:
        login(token=tok)
    except Exception:
        # Keep startup best-effort, but leave a trace instead of silently
        # continuing unauthenticated.
        logging.getLogger(__name__).warning(
            "Hugging Face Hub login failed; continuing without authentication",
            exc_info=True,
        )
APP_TITLE = "Content Moderation Demo (Image & Video)"

# Markdown rendered under the title; the leading/trailing newlines are kept
# deliberately so the rendered text is unchanged.
APP_DESC = (
    "\n"
    "Minimal prototype: image/video analysis, model & category selection, "
    "and threshold control.\n"
)

# Dropdown choices, in registry iteration order.
MODEL_NAMES = [name for name in REGISTRY]

# Rows for gr.Examples on the Image tab: [model, image_path, categories, threshold].
IMG_EXAMPLES = [
    [example_model, example_path, ALL_CATEGORIES, 0.50]
    for example_model, example_path in (
        ("clip-multilabel", "examples/gambling_alcohol.jpg"),
        ("wdeva02-multilabel", "examples/smoke_alcohol.jpg"),
        ("animetimm-multilabel", "examples/gambling_smoke_alcohol.jpg"),
    )
]
# ---------- Shared ----------
def on_model_change(model_name):
    """Rebuild the category picker and threshold slider for a newly chosen model.

    Models listed in ``NUDENET_ONLY`` get a single, locked ``"sexual"``
    category. Every other model gets all of ``ALL_CATEGORIES``, all pre-ticked
    and editable. The threshold slider is always reset to ``DEFAULT_THRESHOLD``.

    Args:
        model_name: the value selected in the model dropdown.

    Returns:
        A ``(CheckboxGroup, Slider)`` pair used as updates for the two outputs.
    """
    locked = model_name in NUDENET_ONLY
    options = ["sexual"] if locked else ALL_CATEGORIES
    category_picker = gr.CheckboxGroup(
        choices=options,
        value=options,
        interactive=not locked,
        label="Categories",
    )
    threshold_slider = gr.Slider(
        minimum=0.0,
        maximum=1.0,
        value=DEFAULT_THRESHOLD,
        step=0.01,
        label="Threshold",
    )
    return category_picker, threshold_slider
# ---------- Image ----------
def analyze_image(model_name, image, selected_categories, threshold):
    """Score a single image and build the Image-tab outputs.

    Args:
        model_name: registry key resolved with ``get_model``.
        image: a PIL image, or an array accepted by ``Image.fromarray``;
            ``None`` when nothing has been uploaded.
        selected_categories: categories ticked in the UI (may be empty or None).
            Entries the model does not support are dropped.
        threshold: any returned score ``>= threshold`` makes the verdict "RISKY".

    Returns:
        ``(verdict, scores_df, extra_tags_update)``:
        - verdict: "RISKY"/"SAFE", or a short message when there is nothing to analyze;
        - scores_df: one row per returned category with a percentage string, or None;
        - extra_tags_update: a ``gr.update`` that shows the model's top tags
          when the model sets ``supports_selected_tags``, otherwise hides the box.
    """
    if image is None:
        return "No image.", None, gr.update(visible=False)

    pil = image if isinstance(image, Image.Image) else Image.fromarray(image)
    # Uploads can be RGBA / grayscale / palette images; the video path always
    # feeds RGB frames, so normalize here for consistent model input.
    if pil.mode != "RGB":
        pil = pil.convert("RGB")

    model = get_model(model_name)
    allowed = set(getattr(model, "categories", ALL_CATEGORIES))
    # `or []` guards against the component sending None instead of an empty list.
    req = [c for c in (selected_categories or []) if c in allowed]
    if not req:
        return "No categories selected.", None, gr.update(visible=False)

    scores = model.predict_image(pil, req)
    verdict = "RISKY" if any(float(v) >= threshold for v in scores.values()) else "SAFE"
    df = pd.DataFrame(
        [{"category": k, "score": f"{(float(v) * 100):.1f}%"} for k, v in sorted(scores.items())]
    )

    if not getattr(model, "supports_selected_tags", False):
        return verdict, df, gr.update(visible=False)

    extra = model.extra_selected_tags(pil, top_k=15)
    txt = "\n".join(f"- {t}: {s:.3f}" for t, s in extra)
    return verdict, df, gr.update(visible=True, value=txt)
# ---------- Video ----------
def _segment_table(union_runs, frame_stats, req, sampling_fps):
    """Summarize hit runs as one row per (segment, category) that fired in it.

    Args:
        union_runs: ``(first_idx, last_idx)`` pairs of sampled-frame indices,
            as returned by ``runs_from_indices``.
        frame_stats: sampled-frame index -> ``{"hits": {category: prob}, ...}``
            for frames where at least one category reached the threshold.
        req: requested categories (used as the set of possible columns).
        sampling_fps: frames sampled per second, used to convert indices to seconds.

    Returns:
        DataFrame with columns ``seg, start, end, category, max_p``, sorted by
        segment id and then by descending peak probability.
    """
    rows = []
    for seg_id, (a, b) in enumerate(union_runs, start=1):
        cat_counts = {c: 0 for c in req}
        cat_maxp = {c: 0.0 for c in req}
        for i in range(a, b + 1):
            st = frame_stats.get(i)
            if not st:
                continue
            for c, p in st["hits"].items():
                cat_counts[c] += 1
                cat_maxp[c] = max(cat_maxp[c], p)
        present = [c for c in req if cat_counts[c] > 0]
        present.sort(key=lambda c: (-cat_counts[c], -cat_maxp[c], c))
        for c in present:
            rows.append({
                "seg": seg_id,
                "start": round(a / sampling_fps, 3),
                # A run's last sampled frame covers one more sampling interval.
                "end": round((b + 1) / sampling_fps, 3),
                "category": c,
                "max_p": round(cat_maxp[c], 3),
            })
    return (
        pd.DataFrame(rows)
        .sort_values(["seg", "max_p"], ascending=[True, False])
        .reset_index(drop=True)
    )


def analyze_video(model_name, video_file, selected_categories, threshold, sampling_fps, redact):
    """Sample frames from a short video, flag risky spans, and optionally redact them.

    Args:
        model_name: registry key resolved with ``get_model``.
        video_file: path of the uploaded video, or None.
        selected_categories: categories ticked in the UI (may be empty or None);
            entries the model does not support are dropped.
        threshold: a frame counts as a hit for a category when its score is
            ``>= threshold``.
        sampling_fps: frames extracted per second of video.
        redact: when true and FFmpeg is available, also produce a redacted copy.

    Returns:
        ``(segments_df, video_update)``. On error or when nothing is found,
        segments_df is a one-row frame with a "segment" message and the video
        update clears the player.
    """
    import logging
    import os
    import shutil
    import tempfile

    if video_file is None:
        return pd.DataFrame([{"segment": "Error: No video."}]), gr.update(value=None)
    dur = probe_duration(video_file)
    if dur is not None and dur > 60.0:
        return pd.DataFrame([{"segment": "Error: Video too long (limit: 60s)."}]), gr.update(value=None)

    model = get_model(model_name)
    allowed = set(getattr(model, "categories", ALL_CATEGORIES))
    # `or []` guards against the component sending None instead of an empty list.
    req = [c for c in (selected_categories or []) if c in allowed]
    if not req:
        return pd.DataFrame([{"segment": "Error: No categories selected."}]), gr.update(value=None)

    with tempfile.TemporaryDirectory() as td:
        try:
            frames = extract_frames_ffmpeg(video_file, sampling_fps, os.path.join(td, "frames"))
        except Exception:
            return (
                pd.DataFrame([{"segment": "Error: FFmpeg not available or failed to extract frames."}]),
                gr.update(value=None),
            )

        all_hit_idx: list[int] = []
        frame_stats: dict[int, dict] = {}
        for fp, idx in frames:
            with Image.open(fp) as im:
                pil = im.convert("RGB")
            scores = model.predict_image(pil, req)
            probs = {c: float(scores.get(c, 0.0)) for c in req}
            over = {c: p for c, p in probs.items() if p >= threshold}
            if over:
                all_hit_idx.append(idx)
                peak_cat, peak_p = max(over.items(), key=lambda kv: kv[1])
                frame_stats[idx] = {"hits": over, "peak_cat": peak_cat, "peak_p": peak_p}

        if not all_hit_idx:
            return pd.DataFrame([{"segment": "(no hits)"}]), gr.update(value=None)

        union_runs = runs_from_indices(sorted(set(all_hit_idx)))
        df = _segment_table(union_runs, frame_stats, req, sampling_fps)

        out_video = gr.update(value=None)
        if redact and has_ffmpeg():
            intervals = merge_seconds_union(all_hit_idx, sampling_fps, pad=0.25)
            try:
                out_path = os.path.join(td, "redacted.mp4")
                redact_with_ffmpeg(video_file, intervals, out_path)
                # `td` is deleted when this `with` exits, so the result must be
                # copied out. Use a fresh per-call directory rather than a fixed
                # path in the CWD so that concurrent requests cannot overwrite
                # each other's output. The file name is kept for downloads.
                out_dir = tempfile.mkdtemp(prefix="redacted_")
                final_out = os.path.join(out_dir, "redacted_output.mp4")
                shutil.copyfile(out_path, final_out)
                out_video = gr.update(value=final_out)
            except Exception:
                # Redaction is optional: keep the analysis result, but record why it failed.
                logging.getLogger(__name__).exception("Video redaction failed")
                out_video = gr.update(value=None)
    return df, out_video
# ---------- UI ----------
# Component creation order inside each context manager defines the on-screen
# layout, so statements here are order-sensitive.
with gr.Blocks(title=APP_TITLE, css=".wrap-row { gap: 16px; }") as demo:
    gr.Markdown(f"# {APP_TITLE}")
    gr.Markdown(APP_DESC)
    with gr.Tabs():
        with gr.Tab("Image"):
            with gr.Row(elem_classes=["wrap-row"]):
                # Left column: inputs.
                with gr.Column(scale=1, min_width=360):
                    model_dd = gr.Dropdown(label="Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
                    threshold = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD, step=0.01, label="Threshold")
                    categories = gr.CheckboxGroup(label="Categories", choices=ALL_CATEGORIES, value=ALL_CATEGORIES)
                    inp_img = gr.Image(type="pil", label="Upload Image")
                    btn = gr.Button("Analyze", variant="primary")
                # Right column: results.
                with gr.Column(scale=1, min_width=360):
                    verdict = gr.Label(label="Verdict")
                    scores_df = gr.Dataframe(headers=["category", "score"], datatype="str",
                                             label="Scores", interactive=False)
                    # Hidden until a model that supports selected tags is used.
                    extra_tags = gr.Textbox(label="Selected tags", visible=False, lines=12)
            # `.input` fires only on user interaction. `.change` would also fire
            # when gr.Examples writes a model name into the dropdown, resetting
            # the example's threshold/categories to defaults and racing the
            # example's own analyze_image run.
            model_dd.input(on_model_change, inputs=model_dd, outputs=[categories, threshold])
            btn.click(analyze_image, inputs=[model_dd, inp_img, categories, threshold],
                      outputs=[verdict, scores_df, extra_tags])
            gr.Examples(
                label="Try an example (Image)",
                examples=IMG_EXAMPLES,
                inputs=[model_dd, inp_img, categories, threshold],
                outputs=[verdict, scores_df, extra_tags],
                fn=analyze_image,
                run_on_click=True,
                cache_examples=False,
            )
        with gr.Tab("Video"):
            with gr.Row(elem_classes=["wrap-row"]):
                # Left column: inputs.
                with gr.Column(scale=1, min_width=360):
                    v_model = gr.Dropdown(label="Model", choices=MODEL_NAMES, value=MODEL_NAMES[0])
                    v_threshold = gr.Slider(0.0, 1.0, value=DEFAULT_THRESHOLD,
                                            step=0.01, label="Threshold")
                    v_fps = gr.Slider(0.25, 5.0, value=1.0, step=0.25, label="Sampling FPS")
                    v_redact = gr.Checkbox(label="Redact scenes (requires FFmpeg)", value=False)
                    v_categories = gr.CheckboxGroup(label="Categories", choices=ALL_CATEGORIES, value=ALL_CATEGORIES)
                    v_input = gr.Video(label="Upload short video (≤ 60s)")
                    v_btn = gr.Button("Analyze Video", variant="primary")
                # Right column: results.
                with gr.Column(scale=1, min_width=360):
                    v_segments = gr.Dataframe(label="Segments", interactive=False)
                    v_out = gr.Video(label="Redacted Video")
            # User-driven only, for consistency with the Image tab.
            v_model.input(on_model_change, inputs=v_model, outputs=[v_categories, v_threshold])
            v_btn.click(analyze_video, inputs=[v_model, v_input, v_categories, v_threshold, v_fps, v_redact],
                        outputs=[v_segments, v_out])

if __name__ == "__main__":
    demo.launch()