Spaces:
Sleeping
Sleeping
| import os | |
| import random | |
| import threading | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Set, Tuple | |
| import csv | |
| import gradio as gr | |
| import pandas as pd | |
# ----------------------------
# Configuration and constants
# ----------------------------
# Root folder with one subfolder per method; the subfolders contain identically
# named video files so the same clip can be compared across methods.
COMMON_VIDEOS_DIR = Path(__file__).resolve().with_name("common_videos")
# CSV file for persistent votes. HF Spaces persistent storage (HF_DATA_DIR,
# default "/data") is tried first, then this script's own folder.
def _resolve_votes_csv() -> Path:
    """Return ``votes.csv`` inside the first writable candidate directory.

    Each candidate directory is created if needed and probed by writing a
    small ``.write_test`` file, which is then deleted (a failed delete is
    ignored). Any failure while creating or writing moves on to the next
    candidate. If no candidate works, the script's folder is returned
    without probing.
    """
    here = Path(__file__).resolve().parent
    for folder in (Path(os.getenv("HF_DATA_DIR", "/data")), here):
        probe = folder / ".write_test"
        try:
            folder.mkdir(parents=True, exist_ok=True)
            with open(probe, "w") as handle:
                handle.write("ok")
        except Exception:
            continue
        try:
            probe.unlink()
        except Exception:
            pass
        return folder / "votes.csv"
    return here / "votes.csv"


VOTES_CSV = _resolve_votes_csv()
# Methods
# Folder name (under COMMON_VIDEOS_DIR) holding the reference / ground-truth videos.
GROUND_TRUTH = "used_videos"
# Previously published methods; they are only paired with NEW_METHODS.
OLD_METHODS = [
    "liveportrait", "controltalk", "lia", "hallo2",
    "echomimic_acc", "dimitra", "sadtalker", "wav2lip",
]
# Methods under evaluation; paired with each other, with OLD_METHODS and with GROUND_TRUTH.
NEW_METHODS = [
    "fom_gen", "xportrait", "mcnet", "emoportrait",
    "dagan", "liax", "omniavatar", "real3d",
]
# Study parameters
VOTES_PER_PAIR = 23
# Number of required unordered pairs: NEW vs NEW + NEW vs OLD + NEW vs GT,
# i.e. 28 + 64 + 8 = 100 for the lists in this block. Derived from the lists
# instead of hard-coding 100 so the stop threshold follows any change to them.
# Assumes all method names (and GROUND_TRUTH) are distinct.
_NUM_REQUIRED_PAIRS = (
    len(NEW_METHODS) * (len(NEW_METHODS) - 1) // 2
    + len(NEW_METHODS) * len(OLD_METHODS)
    + len(NEW_METHODS)
)
STOP_TOTAL = _NUM_REQUIRED_PAIRS * VOTES_PER_PAIR  # 100 pairs * 23 votes each
# Allowed video extensions (matched against the lowercased file suffix)
VIDEO_EXTS = {".mp4", ".mov", ".webm", ".avi", ".mkv"}
# Thread lock for safe CSV writes
_write_lock = threading.Lock()
# ----------------------------
# Global in-memory state
# ----------------------------
# These start empty and are populated by initialize().
# method name -> set of video filenames found in its folder (e.g. "abc.mp4")
METHOD_VIDEOS: Dict[str, Set[str]] = dict()
# every unordered pair the study needs, each a lexicographically sorted (m1, m2) tuple
ALL_REQUIRED_PAIRS: List[Tuple[str, str]] = list()
# sorted (m1, m2) pair -> number of votes recorded for it
PAIR_COUNTS: Dict[Tuple[str, str], int] = dict()
# number of votes recorded so far
TOTAL_VOTES: int = 0
# ----------------------------
# Utility functions
# ----------------------------
def ensure_votes_csv() -> pd.DataFrame:
    """Ensure votes.csv exists with the expected header and return its contents.

    - Missing file: an empty CSV with the expected columns is written.
    - Missing columns: they are added (filled with None) and the file is
      rewritten with exactly the expected columns, in canonical order.
    - Unparseable file: it is moved aside to a uniquely named backup and a
      fresh empty CSV is written. If the move fails, the file is left
      untouched (never overwritten) and an empty DataFrame is returned.

    All file writes happen under ``_write_lock`` so this cannot interleave
    with ``append_vote`` (e.g. an export request arriving during a vote).

    Returns:
        DataFrame with the expected columns, in canonical order.
    """
    columns = ["method1", "method2", "video_name", "winner", "timestamp"]
    with _write_lock:
        if not VOTES_CSV.exists():
            df = pd.DataFrame(columns=columns)
            df.to_csv(VOTES_CSV, index=False)
            return df

        try:
            df = pd.read_csv(VOTES_CSV)
        except Exception:
            # Corrupted/unreadable: keep it for inspection under a unique name
            # (a fixed name would overwrite any earlier backup).
            stamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")
            backup = VOTES_CSV.with_name(f"{VOTES_CSV.stem}.{stamp}.bak.csv")
            try:
                VOTES_CSV.replace(backup)
            except OSError:
                # Could not move it aside; overwriting would destroy the only
                # copy of the votes, so leave the file as-is.
                return pd.DataFrame(columns=columns)
            df = pd.DataFrame(columns=columns)
            df.to_csv(VOTES_CSV, index=False)
            return df

        # Normalize columns if needed (extra columns are dropped on rewrite).
        missing = [c for c in columns if c not in df.columns]
        if missing:
            for c in missing:
                df[c] = None
            df = df[columns]
            df.to_csv(VOTES_CSV, index=False)
        return df[columns]
def scan_method_videos() -> Dict[str, Set[str]]:
    """Index the video files available for every method.

    Covers GROUND_TRUTH, OLD_METHODS and NEW_METHODS, looking in
    ``COMMON_VIDEOS_DIR/<method>``. Only regular files whose lowercased suffix
    is in VIDEO_EXTS are kept. A method whose folder is missing or is not a
    directory maps to an empty set.
    """
    def _videos_in(folder: Path) -> Set[str]:
        if not (folder.exists() and folder.is_dir()):
            return set()
        return {
            entry.name
            for entry in folder.iterdir()
            if entry.is_file() and entry.suffix.lower() in VIDEO_EXTS
        }

    return {
        method: _videos_in(COMMON_VIDEOS_DIR / method)
        for method in [GROUND_TRUTH, *OLD_METHODS, *NEW_METHODS]
    }
def generate_required_pairs() -> List[Tuple[str, str]]:
    """Return every unordered method pair the study must collect votes for.

    Three groups are combined: NEW vs NEW (each unordered pair once),
    NEW vs OLD, and NEW vs GROUND_TRUTH. Each pair is a lexicographically
    sorted tuple; duplicates collapse and the result is returned sorted.
    With 8 new and 8 old methods this is 28 + 64 + 8 = 100 pairs.
    """
    def canon(x: str, y: str) -> Tuple[str, str]:
        first, second = sorted((x, y))
        return (first, second)

    combos: Set[Tuple[str, str]] = {
        canon(a, b)
        for idx, a in enumerate(NEW_METHODS)
        for b in NEW_METHODS[idx + 1:]
    }
    combos |= {canon(n, o) for n in NEW_METHODS for o in OLD_METHODS}
    combos |= {canon(n, GROUND_TRUTH) for n in NEW_METHODS}
    return sorted(combos)
def rebuild_counts_from_csv(df: pd.DataFrame) -> Tuple[Dict[Tuple[str, str], int], int]:
    """Recompute per-pair vote counts and the study vote total from stored votes.

    Every pair in ALL_REQUIRED_PAIRS starts at 0. Each row's
    (method1, method2) is stringified and sorted into canonical order. Rows
    whose pair is not one of ALL_REQUIRED_PAIRS are ignored entirely and do
    not count toward the returned total.
    """
    counts: Dict[Tuple[str, str], int] = dict.fromkeys(ALL_REQUIRED_PAIRS, 0)
    if df is None or df.empty:
        return counts, 0
    total = 0
    for first, second in zip(df["method1"], df["method2"]):
        key = tuple(sorted((str(first), str(second))))
        if key not in counts:
            continue
        counts[key] += 1
        total += 1
    return counts, total
def select_next_pair() -> Optional[Tuple[str, str]]:
    """Pick the next pair to show: an under-quota pair with the fewest votes.

    Only pairs with fewer than VOTES_PER_PAIR votes are eligible. Among them,
    pairs whose two methods share at least one video file (per METHOD_VIDEOS)
    are preferred. A pair with no common video can never be displayed or
    voted on. It would therefore keep the minimum count and, once other pairs
    have votes, be selected on every call, stalling the study. Ties on the
    minimum count are broken uniformly at random.

    Returns:
        A canonical (sorted) pair, or None when every pair has reached its quota.
    """
    under_quota = [p for p in ALL_REQUIRED_PAIRS if PAIR_COUNTS.get(p, 0) < VOTES_PER_PAIR]
    if not under_quota:
        return None
    playable = [
        p for p in under_quota
        if METHOD_VIDEOS.get(p[0], set()) & METHOD_VIDEOS.get(p[1], set())
    ]
    # Fall back to unplayable pairs only when nothing else is left, so the
    # caller still gets a pair and can report that no common video exists.
    pool = playable or under_quota
    min_count = min(PAIR_COUNTS.get(p, 0) for p in pool)
    candidates = [p for p in pool if PAIR_COUNTS.get(p, 0) == min_count]
    return random.choice(candidates)
def pick_video_for_pair(m1: str, m2: str) -> Optional[str]:
    """Return a random video filename available for both methods, or None.

    The shared filenames are sorted before sampling. Set iteration order for
    strings depends on per-process hash randomization, so sampling straight
    from the set would make the choice irreproducible even with a fixed
    ``random.seed``.
    """
    common = sorted(METHOD_VIDEOS.get(m1, set()) & METHOD_VIDEOS.get(m2, set()))
    if not common:
        return None
    return random.choice(common)
def video_path(method: str, filename: str) -> str:
    """Return, as a string, the path of ``filename`` inside the method's folder."""
    method_dir = COMMON_VIDEOS_DIR / method
    return str(method_dir / filename)
def progress_text() -> str:
    """Return the progress label; the displayed count never exceeds STOP_TOTAL."""
    shown = TOTAL_VOTES if TOTAL_VOTES < STOP_TOTAL else STOP_TOTAL
    return f"Votes Collected: {shown} / {STOP_TOTAL}"
def _blank_display(status: str, disable: bool) -> dict:
    """Build a UI payload with no videos, the given status, and an empty vote state."""
    return {
        "left_src": None,
        "right_src": None,
        "status": status,
        "progress": progress_text(),
        "state": {
            "method_left": None,
            "method_right": None,
            "video_name": None,
            "pair": None,
        },
        "disable": disable,
    }


def prepare_next_display():
    """Compute the next pair, randomize sides, and return the UI payload.

    Returns a dict with keys:
      - ``left_src`` / ``right_src``: video paths (None when nothing to show),
      - ``status``: message for the status line,
      - ``progress``: label from progress_text(),
      - ``state``: method_left / method_right / video_name / pair (canonical
        sorted tuple) describing the comparison being shown,
      - ``disable``: True when the vote buttons should be disabled.
    """
    # Stop condition on total votes
    if TOTAL_VOTES >= STOP_TOTAL:
        return _blank_display("Study Complete. Thank you!", disable=True)

    pair = select_next_pair()
    if pair is None:
        # Every required pair has reached its per-pair quota.
        return _blank_display("Study Complete. Thank you!", disable=True)

    m1, m2 = pair
    # The common-file lookup is deterministic for a given pair (a set
    # intersection), so re-querying the same pair can never succeed; report
    # the problem and keep the buttons enabled so the user can move on.
    filename = pick_video_for_pair(m1, m2)
    if filename is None:
        return _blank_display(
            "No common video found for selected pair. Please try again.",
            disable=False,
        )

    # Randomize left/right placement to avoid position bias.
    left_m, right_m = (m1, m2) if random.random() < 0.5 else (m2, m1)
    return {
        "left_src": video_path(left_m, filename),
        "right_src": video_path(right_m, filename),
        "status": "",
        "progress": progress_text(),
        "state": {
            "method_left": left_m,
            "method_right": right_m,
            "video_name": filename,
            "pair": tuple(sorted((m1, m2))),
        },
        "disable": False,
    }
def append_vote(method1: str, method2: str, video_name: str, winner: str):
    """Append one vote to votes.csv and update the in-memory counters.

    Args:
        method1: First compared method (written to the CSV as given).
        method2: Second compared method (written to the CSV as given).
        video_name: Filename of the clip shown for both methods.
        winner: Winning method name, or another label (on_vote uses "equal").

    The row is flushed and fsync'ed before returning. The file write and both
    counter updates happen together under ``_write_lock``: ``TOTAL_VOTES += 1``
    on a module global is not atomic, so updating it outside the lock could
    lose increments when Gradio handles votes on several threads.
    """
    global TOTAL_VOTES
    from datetime import timezone

    # Naive UTC ISO-8601 timestamp, same format as the deprecated
    # datetime.utcnow().isoformat().
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    row = {
        "method1": method1,
        "method2": method2,
        "video_name": video_name,
        "winner": winner,
        "timestamp": ts,
    }
    # Counter key uses canonical sorted order, matching rebuild_counts_from_csv.
    pair = tuple(sorted((method1, method2)))
    with _write_lock:
        VOTES_CSV.parent.mkdir(parents=True, exist_ok=True)
        need_header = (not VOTES_CSV.exists()) or (VOTES_CSV.stat().st_size == 0)
        with open(VOTES_CSV, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=["method1", "method2", "video_name", "winner", "timestamp"])
            if need_header:
                writer.writeheader()
            writer.writerow(row)
            f.flush()
            # fsync so the vote is on disk even if the process dies right after.
            os.fsync(f.fileno())
        # Update memory while still holding the lock.
        PAIR_COUNTS[pair] = PAIR_COUNTS.get(pair, 0) + 1
        TOTAL_VOTES += 1
# ----------------------------
# Gradio callback functions
# ----------------------------
def on_load():
    """Populate the page with the first comparison when a session opens.

    Returns the values for the wired outputs, in order: left video, right
    video, progress label, status text, session state, then interactivity
    updates for the left and right vote buttons.
    """
    payload = prepare_next_display()
    buttons_enabled = not payload["disable"]
    shown = tuple(
        payload[key] for key in ("left_src", "right_src", "progress", "status", "state")
    )
    return shown + (
        gr.update(interactive=buttons_enabled),
        gr.update(interactive=buttons_enabled),
    )
def on_vote(choice: str, state: dict):
    """Record the user's choice for the comparison on screen, then load the next.

    Args:
        choice: "left" or "right"; any other value is recorded as "equal".
        state: Session dict with method_left, method_right, video_name and
            pair (canonical sorted tuple) for the comparison being shown.

    When the state is empty or lacks a pair/video name (e.g. the study was
    complete or nothing was displayed), no vote is stored and the next
    comparison is simply loaded. Returns the same 7 outputs as on_load.
    """
    if state and state.get("pair") and state.get("video_name"):
        winners = {"left": state["method_left"], "right": state["method_right"]}
        clip = state["video_name"]
        canonical = state["pair"]
        # Persist with the canonical pair order.
        append_vote(canonical[0], canonical[1], clip, winners.get(choice, "equal"))

    payload = prepare_next_display()
    buttons_enabled = not payload["disable"]
    shown = tuple(
        payload[key] for key in ("left_src", "right_src", "progress", "status", "state")
    )
    return shown + (
        gr.update(interactive=buttons_enabled),
        gr.update(interactive=buttons_enabled),
    )
def export_votes():
    """Ensure votes.csv exists and is normalized, then return its path as a string.

    The returned path feeds the gr.File output, which offers it for download.
    """
    ensure_votes_csv()  # called for its side effect; the DataFrame is not needed
    csv_location = VOTES_CSV
    return str(csv_location)
# ----------------------------
# App initialization
# ----------------------------
def initialize():
    """Build all global state: video index, required pairs, and vote counts.

    Creates COMMON_VIDEOS_DIR when it does not exist, scans it per method,
    generates the required pairs, then loads votes.csv (creating it if
    absent) and rebuilds PAIR_COUNTS and TOTAL_VOTES from it.
    """
    global METHOD_VIDEOS, ALL_REQUIRED_PAIRS, PAIR_COUNTS, TOTAL_VOTES
    if not COMMON_VIDEOS_DIR.exists():
        COMMON_VIDEOS_DIR.mkdir(parents=True, exist_ok=True)
    METHOD_VIDEOS = scan_method_videos()
    # Must be assigned before rebuilding counts, which keys on these pairs.
    ALL_REQUIRED_PAIRS = generate_required_pairs()
    PAIR_COUNTS, TOTAL_VOTES = rebuild_counts_from_csv(ensure_votes_csv())


# Build state at import time so the UI starts from the persisted counts.
initialize()
# ----------------------------
# Gradio UI
# ----------------------------
def _vote_handler(side: str):
    """Return a one-argument Gradio callback that records a vote for ``side``."""
    return lambda s: on_vote(side, s)


# Client-side toggle: pause every <video> if any is playing, otherwise play all.
_TOGGLE_PLAYBACK_JS = """
() => {
    const vids = Array.from(document.querySelectorAll('video'));
    if (vids.length === 0) return;
    const anyPlaying = vids.some(v => !v.paused && !v.ended && v.readyState > 2);
    if (anyPlaying) {
        vids.forEach(v => v.pause());
    } else {
        vids.forEach(v => v.play());
    }
}
"""

with gr.Blocks(title="Scientific Video Comparison Study") as demo:
    gr.Markdown("Compare the two videos and vote. Randomized positions prevent bias.")
    gr.Markdown(f"Votes file: {VOTES_CSV}")

    with gr.Row():
        left_video = gr.Video(label="Left Video", autoplay=True, height=360)
        right_video = gr.Video(label="Right Video", autoplay=True, height=360)

    progress = gr.Markdown(progress_text())
    status = gr.Markdown("")

    # Per-session hidden state: method on each side and the shown video filename.
    state = gr.State(value=dict(method_left=None, method_right=None, video_name=None, pair=None))

    with gr.Row():
        btn_left = gr.Button("Left Video is Better", variant="primary")
        btn_right = gr.Button("Right Video is Better", variant="primary")
        btn_toggle = gr.Button("Play/Pause Both")

    # Export/download controls
    with gr.Row():
        btn_export = gr.Button("Refresh & Download votes.csv")
        votes_file = gr.File(label="votes.csv", interactive=False)

    # Outputs shared by page load and both vote buttons (order matches on_load/on_vote).
    vote_outputs = [left_video, right_video, progress, status, state, btn_left, btn_right]

    demo.load(fn=on_load, inputs=None, outputs=vote_outputs)
    btn_left.click(fn=_vote_handler("left"), inputs=[state], outputs=vote_outputs)
    btn_right.click(fn=_vote_handler("right"), inputs=[state], outputs=vote_outputs)
    btn_toggle.click(fn=None, inputs=None, outputs=None, js=_TOGGLE_PLAYBACK_JS)
    btn_export.click(fn=export_votes, inputs=None, outputs=[votes_file])
if __name__ == "__main__":
    # Local launch with the request queue enabled; pass server_name /
    # server_port to launch() to change where it listens.
    queued_app = demo.queue()
    queued_app.launch()