File size: 14,599 Bytes
50ba3db
 
f817b37
 
50ba3db
f817b37
1a77833
50ba3db
f817b37
 
50ba3db
f817b37
 
 
50ba3db
f817b37
 
1a77833
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50ba3db
f817b37
 
50ba3db
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a77833
 
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50ba3db
f817b37
 
 
 
 
 
 
 
 
 
 
50ba3db
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a77833
 
 
 
 
 
 
 
 
 
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50ba3db
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a77833
 
d784d44
1a77833
d784d44
e5fe4dc
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a77833
df2b7b4
f817b37
 
 
 
 
 
1a77833
f817b37
 
1ae97d3
 
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a77833
 
 
 
 
f817b37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1a77833
 
50ba3db
 
f817b37
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
import os
import random
import threading
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional, Set, Tuple
import csv

import gradio as gr
import pandas as pd

# ----------------------------
# Configuration and constants
# ----------------------------

# Root folder (next to this file) holding one subfolder per method. A clip can be
# shown for a pair of methods only when the same filename exists in both methods'
# subfolders (see scan_method_videos / pick_video_for_pair).
COMMON_VIDEOS_DIR: Path = Path(__file__).resolve().parent / "common_videos"

# CSV file for persistent votes (prefer HF Spaces persistent storage if available)
def _resolve_votes_csv() -> Path:
    candidates = [
        Path(os.getenv("HF_DATA_DIR", "/data")),
        Path(__file__).resolve().parent,
    ]
    for d in candidates:
        try:
            d.mkdir(parents=True, exist_ok=True)
            test = d / ".write_test"
            with open(test, "w") as f:
                f.write("ok")
            try:
                test.unlink()
            except Exception:
                pass
            return d / "votes.csv"
        except Exception:
            continue
    return Path(__file__).resolve().parent / "votes.csv"

# Resolved once at import time; ensure_votes_csv, append_vote and export_votes all
# read/write this single path.
VOTES_CSV: Path = _resolve_votes_csv()

# Methods
# Folder name of the real (ground-truth) videos under COMMON_VIDEOS_DIR.
GROUND_TRUTH = "used_videos"
# Previously evaluated methods; they are only compared against NEW_METHODS.
OLD_METHODS = [
    "liveportrait", "controltalk", "lia", "hallo2",
    "echomimic_acc", "dimitra", "sadtalker", "wav2lip",
]
# Methods under study; compared against each other, OLD_METHODS and GROUND_TRUTH.
NEW_METHODS = [
    "fom_gen", "xportrait", "mcnet", "emoportrait",
    "dagan", "liax", "omniavatar", "real3d",
]

# Study parameters
VOTES_PER_PAIR = 23
# Number of unordered pairs the study covers (mirrors generate_required_pairs):
# NEW vs NEW (n choose 2) + NEW vs OLD (n * o) + NEW vs GT (n).
# Assumes the three method lists are disjoint and free of duplicates.
NUM_REQUIRED_PAIRS = (
    len(NEW_METHODS) * (len(NEW_METHODS) - 1) // 2
    + len(NEW_METHODS) * len(OLD_METHODS)
    + len(NEW_METHODS)
)
# Global stop condition; with the lists above: 100 pairs * 23 votes = 2300.
# Derived rather than hard-coded so it stays correct if the method lists change.
STOP_TOTAL = NUM_REQUIRED_PAIRS * VOTES_PER_PAIR

# Allowed video extensions (compared case-insensitively by scan_method_videos)
VIDEO_EXTS = {".mp4", ".mov", ".webm", ".avi", ".mkv"}

# Thread lock for safe CSV writes
_write_lock = threading.Lock()

# ----------------------------
# Global in-memory state
# ----------------------------

# Mapping method -> set of available video filenames (e.g., "abc.mp4")
METHOD_VIDEOS: Dict[str, Set[str]] = {}

# All required unordered pairs as canonical tuples (m1, m2) sorted lexicographically
ALL_REQUIRED_PAIRS: List[Tuple[str, str]] = []

# Per-pair vote counts from CSV (key: (m1, m2) sorted tuple)
PAIR_COUNTS: Dict[Tuple[str, str], int] = {}

# Total votes recorded so far
TOTAL_VOTES: int = 0


# ----------------------------
# Utility functions
# ----------------------------

def ensure_votes_csv() -> pd.DataFrame:
    """Ensure votes.csv exists with the standard header, and return its contents.

    - Missing file: an empty CSV with the standard columns is written.
    - Existing file lacking some standard columns: those columns are added as
      empty values, the frame is restricted/reordered to the standard columns,
      and the file is rewritten.
    - Unreadable file (e.g. parse error or empty file): it is moved aside to
      ``<stem>.bak.csv`` (replacing any earlier backup) and a fresh empty CSV
      is written.

    Everything runs under ``_write_lock``, the same lock ``append_vote`` holds
    while appending. Without it, a rewrite built from a DataFrame read just
    before a concurrent append (e.g. the export button firing while another
    user votes) would overwrite the file and silently drop that vote.

    Returns:
        DataFrame with exactly the columns
        ``["method1", "method2", "video_name", "winner", "timestamp"]``.
    """
    columns = ["method1", "method2", "video_name", "winner", "timestamp"]
    with _write_lock:
        if not VOTES_CSV.exists():
            df = pd.DataFrame(columns=columns)
            df.to_csv(VOTES_CSV, index=False)
            return df
        try:
            df = pd.read_csv(VOTES_CSV)
            # Normalize columns if needed
            missing = [c for c in columns if c not in df.columns]
            if missing:
                for c in missing:
                    df[c] = None
                df = df[columns]
                df.to_csv(VOTES_CSV, index=False)
            return df[columns]
        except Exception:
            # If corrupted, back up and start fresh (best effort on the backup).
            backup = VOTES_CSV.with_suffix(".bak.csv")
            try:
                VOTES_CSV.replace(backup)
            except OSError:
                pass
            df = pd.DataFrame(columns=columns)
            df.to_csv(VOTES_CSV, index=False)
            return df


def scan_method_videos() -> Dict[str, Set[str]]:
    """Map every method (ground truth, old, new) to the video filenames it has.

    Looks in ``COMMON_VIDEOS_DIR/<method>/`` for regular files whose extension
    (case-insensitive) is in VIDEO_EXTS. A method whose folder is missing maps
    to an empty set. Only bare filenames are returned, not full paths.
    """

    def _videos_in(folder: Path) -> Set[str]:
        # is_dir() is False for missing paths too, covering both cases.
        if not folder.is_dir():
            return set()
        return {
            entry.name
            for entry in folder.iterdir()
            if entry.is_file() and entry.suffix.lower() in VIDEO_EXTS
        }

    every_method = [GROUND_TRUTH, *OLD_METHODS, *NEW_METHODS]
    return {method: _videos_in(COMMON_VIDEOS_DIR / method) for method in every_method}


def generate_required_pairs() -> List[Tuple[str, str]]:
    """Build the study's required unordered pairs, each as a sorted 2-tuple.

    Three families are included: every NEW method against every other NEW
    method, every NEW method against every OLD method, and every NEW method
    against GROUND_TRUTH. Pairs are de-duplicated and returned in sorted order
    (100 pairs with the configured 8 new / 8 old methods).
    """

    def canonical(x: str, y: str) -> Tuple[str, str]:
        low, high = sorted((x, y))
        return (low, high)

    new_vs_new = {
        canonical(first, second)
        for idx, first in enumerate(NEW_METHODS)
        for second in NEW_METHODS[idx + 1:]
    }
    new_vs_old = {canonical(new, old) for new in NEW_METHODS for old in OLD_METHODS}
    new_vs_gt = {canonical(new, GROUND_TRUTH) for new in NEW_METHODS}

    return sorted(new_vs_new | new_vs_old | new_vs_gt)


def rebuild_counts_from_csv(df: pd.DataFrame) -> Tuple[Dict[Tuple[str, str], int], int]:
    """Recount votes per required pair from the votes DataFrame.

    Each row's (method1, method2) is stringified and sorted, so the column order
    in the CSV does not matter. Rows whose pair is not in ALL_REQUIRED_PAIRS
    are ignored. Every required pair appears in the result, with 0 when unvoted.

    Returns:
        ``(counts, total)`` where ``total`` is the number of counted rows.
    """
    counts: Dict[Tuple[str, str], int] = dict.fromkeys(ALL_REQUIRED_PAIRS, 0)
    if df is None or df.empty:
        return counts, 0

    for first, second in zip(df["method1"], df["method2"]):
        key = tuple(sorted((str(first), str(second))))
        if key in counts:
            counts[key] += 1

    # Only required pairs were incremented, so their sum is the counted total.
    return counts, sum(counts.values())


def select_next_pair() -> Optional[Tuple[str, str]]:
    """Return a required pair with the fewest votes among those still under quota.

    Pairs missing from PAIR_COUNTS count as 0 votes; pairs already at
    VOTES_PER_PAIR are skipped. Ties are broken with ``random.choice``.
    Returns None when every required pair has reached its quota.
    """
    scored = [(PAIR_COUNTS.get(candidate, 0), candidate) for candidate in ALL_REQUIRED_PAIRS]
    open_pairs = [(votes, candidate) for votes, candidate in scored if votes < VOTES_PER_PAIR]
    if not open_pairs:
        return None

    fewest = min(votes for votes, _ in open_pairs)
    least_voted = [candidate for votes, candidate in open_pairs if votes == fewest]
    return random.choice(least_voted)


def pick_video_for_pair(m1: str, m2: str) -> Optional[str]:
    """Return a random filename present for both methods, or None if none is shared.

    Methods unknown to METHOD_VIDEOS are treated as having no videos.
    """
    shared = METHOD_VIDEOS.get(m1, set()) & METHOD_VIDEOS.get(m2, set())
    return random.choice(list(shared)) if shared else None


def video_path(method: str, filename: str) -> str:
    """Return ``COMMON_VIDEOS_DIR/<method>/<filename>`` as a string path."""
    full_path = COMMON_VIDEOS_DIR.joinpath(method, filename)
    return str(full_path)


def progress_text() -> str:
    """Return the progress line, e.g. ``"Votes Collected: 12 / 2300"``.

    The collected count is capped at STOP_TOTAL so the display never overshoots.
    """
    shown = TOTAL_VOTES if TOTAL_VOTES < STOP_TOTAL else STOP_TOTAL
    return f"Votes Collected: {shown} / {STOP_TOTAL}"


def _unavailable_payload(status: str, disable: bool) -> dict:
    """Build a UI payload that shows no videos and clears the vote state.

    Args:
        status: message shown in the status area.
        disable: whether the vote buttons should be disabled.
    """
    return {
        "left_src": None,
        "right_src": None,
        "status": status,
        "progress": progress_text(),
        "state": {
            "method_left": None,
            "method_right": None,
            "video_name": None,
            "pair": None,
        },
        "disable": disable,
    }


def _fallback_pair_and_video() -> Optional[Tuple[Tuple[str, str], str]]:
    """Pick the least-voted under-quota pair that shares at least one video.

    Used when the pair chosen by ``select_next_pair`` has no clip common to both
    methods. METHOD_VIDEOS is scanned only once, so retrying that pair can never
    succeed. Because its vote count never rises, it eventually becomes the unique
    least-voted pair and would then be chosen on every request, stalling the study.
    Ties are broken randomly.

    Returns:
        ``(pair, filename)``, or None if no under-quota pair has a common video.
    """
    playable = [
        pair
        for pair in ALL_REQUIRED_PAIRS
        if PAIR_COUNTS.get(pair, 0) < VOTES_PER_PAIR
        and METHOD_VIDEOS.get(pair[0], set()) & METHOD_VIDEOS.get(pair[1], set())
    ]
    if not playable:
        return None
    lowest = min(PAIR_COUNTS.get(pair, 0) for pair in playable)
    chosen = random.choice([pair for pair in playable if PAIR_COUNTS.get(pair, 0) == lowest])
    # Non-None: `chosen` was filtered to pairs with a non-empty common video set.
    filename = pick_video_for_pair(*chosen)
    return chosen, filename


def prepare_next_display():
    """Choose the next comparison, randomize sides, and return the UI payload.

    Returns a dict with keys:
        left_src / right_src: video file paths (None when nothing is shown).
        status: status-area message ("" when a comparison is shown).
        progress: current progress line from progress_text().
        state: dict with method_left, method_right, video_name and the sorted
            pair; stored in the session and read back by on_vote.
        disable: True when the vote buttons should be disabled (study complete).

    If the pair picked by select_next_pair has no clip common to both methods,
    another under-quota pair that does is used instead.
    """
    # Stop condition on total votes
    if TOTAL_VOTES >= STOP_TOTAL:
        return _unavailable_payload("Study Complete. Thank you!", True)

    pair = select_next_pair()
    if pair is None:
        # Every required pair has reached its per-pair quota.
        return _unavailable_payload("Study Complete. Thank you!", True)

    filename = pick_video_for_pair(*pair)
    if filename is None:
        # Retrying the same pair is pointless: video availability is fixed at startup.
        fallback = _fallback_pair_and_video()
        if fallback is None:
            return _unavailable_payload(
                "No common video found for selected pair. Please try again.", False
            )
        pair, filename = fallback

    m1, m2 = pair
    # Randomize left/right to avoid position bias.
    if random.random() < 0.5:
        left_m, right_m = m1, m2
    else:
        left_m, right_m = m2, m1

    return {
        "left_src": video_path(left_m, filename),
        "right_src": video_path(right_m, filename),
        "status": "",
        "progress": progress_text(),
        "state": {
            "method_left": left_m,
            "method_right": right_m,
            "video_name": filename,
            "pair": tuple(sorted((m1, m2))),
        },
        "disable": False,
    }


def append_vote(method1: str, method2: str, video_name: str, winner: str):
    """Append one vote to votes.csv and update the in-memory counters.

    The row is written under ``_write_lock`` with an explicit flush + fsync, so
    the file on disk is current after every vote. The header is written only
    when the file is missing or empty.

    The in-memory count is keyed by the *sorted* ``(method1, method2)`` tuple.
    That matches how ``rebuild_counts_from_csv`` keys counts on startup, so
    callers may pass the methods in either order without the live count
    diverging from what a restart would rebuild.

    Args:
        method1, method2: the two compared methods (written to the CSV as given).
        video_name: filename of the clip shown for both methods.
        winner: winning method name, or another label such as "equal".
    """
    global TOTAL_VOTES
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12; derive the same naive-UTC
    # ISO string from an aware timestamp so the CSV timestamp format is unchanged.
    ts = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    row = {
        "method1": method1,
        "method2": method2,
        "video_name": video_name,
        "winner": winner,
        "timestamp": ts,
    }
    pair_key = tuple(sorted((method1, method2)))
    with _write_lock:
        # Robust append with immediate flush/fsync so the file is always up-to-date
        VOTES_CSV.parent.mkdir(parents=True, exist_ok=True)
        need_header = (not VOTES_CSV.exists()) or (VOTES_CSV.stat().st_size == 0)
        with open(VOTES_CSV, "a", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=["method1", "method2", "video_name", "winner", "timestamp"])
            if need_header:
                writer.writeheader()
            writer.writerow(row)
            f.flush()
            os.fsync(f.fileno())
        # Update memory (under the lock so counters match the file)
        PAIR_COUNTS[pair_key] = PAIR_COUNTS.get(pair_key, 0) + 1
        TOTAL_VOTES += 1


# ----------------------------
# Gradio callback functions
# ----------------------------

def on_load():
    """Populate the UI with the first comparison when the page loads.

    Returns the 7 values wired to the load event: left video, right video,
    progress text, status text, session state, and an interactivity update for
    each of the two vote buttons.
    """
    payload = prepare_next_display()
    buttons_enabled = not payload["disable"]
    media = (payload["left_src"], payload["right_src"])
    texts = (payload["progress"], payload["status"])
    return (
        *media,
        *texts,
        payload["state"],
        gr.update(interactive=buttons_enabled),
        gr.update(interactive=buttons_enabled),
    )


def on_vote(choice: str, state: dict):
    """Record the vote for the comparison on screen, then serve the next one.

    ``choice`` is "left" or "right"; any other value is recorded as "equal".
    If ``state`` does not describe a displayed comparison (empty, or missing
    the pair or video name, e.g. after the study finished), nothing is
    recorded and the display is just refreshed.

    Returns the 7 values wired to the vote buttons: left video, right video,
    progress text, status text, new session state, and an interactivity
    update for each vote button.
    """
    displayed = bool(state) and bool(state.get("pair")) and bool(state.get("video_name"))
    if displayed:
        side_to_method = {"left": state["method_left"], "right": state["method_right"]}
        winner = side_to_method.get(choice, "equal")
        # The stored pair is already sorted, so the CSV gets a canonical order.
        canonical = state["pair"]
        append_vote(canonical[0], canonical[1], state["video_name"], winner)

    payload = prepare_next_display()
    buttons_on = not payload["disable"]
    return (
        payload["left_src"],
        payload["right_src"],
        payload["progress"],
        payload["status"],
        payload["state"],
        gr.update(interactive=buttons_on),
        gr.update(interactive=buttons_on),
    )


def export_votes():
    """Create/repair votes.csv if needed and return its path for the download widget."""
    # Called only for its side effects; the returned DataFrame is not needed here.
    ensure_votes_csv()
    csv_path = str(VOTES_CSV)
    return csv_path

# ----------------------------
# App initialization
# ----------------------------

def initialize():
    """Populate module state: video inventory, required pairs and vote counts.

    Creates COMMON_VIDEOS_DIR if absent, scans it, and builds
    ALL_REQUIRED_PAIRS. It then loads (creating if needed) votes.csv and rebuilds
    PAIR_COUNTS / TOTAL_VOTES from it. The pairs must be built before the
    counts, because the rebuild only counts required pairs.
    """
    global METHOD_VIDEOS, ALL_REQUIRED_PAIRS, PAIR_COUNTS, TOTAL_VOTES

    if not COMMON_VIDEOS_DIR.exists():
        COMMON_VIDEOS_DIR.mkdir(parents=True, exist_ok=True)

    METHOD_VIDEOS = scan_method_videos()
    ALL_REQUIRED_PAIRS = generate_required_pairs()
    PAIR_COUNTS, TOTAL_VOTES = rebuild_counts_from_csv(ensure_votes_csv())


# Initialize on import: scan the video folders, build the required pairs and
# rebuild vote counts from votes.csv before the UI below is built (the UI reads
# progress_text() at construction time).
initialize()

# ----------------------------
# Gradio UI
# ----------------------------

# Layout: two side-by-side players, progress/status text, two vote buttons, a
# client-side play/pause toggle, and an export row. The load and vote events all
# write the same seven outputs, in the order on_load/on_vote return them.
with gr.Blocks(title="Scientific Video Comparison Study") as demo:
    gr.Markdown("Compare the two videos and vote. Randomized positions prevent bias.")
    # NOTE(review): this shows the server-side filesystem path of the votes file to
    # every participant — confirm that is intended.
    gr.Markdown(f"Votes file: {VOTES_CSV}")

    with gr.Row():
        left_video = gr.Video(label="Left Video", autoplay=True, height=360)
        right_video = gr.Video(label="Right Video", autoplay=True, height=360)

    # Initial progress text is computed once at build time; callbacks refresh it.
    progress = gr.Markdown(progress_text())
    status = gr.Markdown("")

    # Hidden state storing current assignment and video filename: which method is
    # on each side, the shared clip, and the sorted pair that on_vote records.
    state = gr.State(value={
        "method_left": None,
        "method_right": None,
        "video_name": None,
        "pair": None,
    })

    with gr.Row():
        btn_left = gr.Button("Left Video is Better", variant="primary")
        btn_right = gr.Button("Right Video is Better", variant="primary")
        btn_toggle = gr.Button("Play/Pause Both")

    # Add export/download controls
    with gr.Row():
        btn_export = gr.Button("Refresh & Download votes.csv")
        votes_file = gr.File(label="votes.csv", interactive=False)

    # Wire events
    # Page load: serve the first comparison and set the vote buttons' interactivity.
    demo.load(
        fn=on_load,
        inputs=None,
        outputs=[left_video, right_video, progress, status, state, btn_left, btn_right],
    )
    # Each vote button records a vote for its side, then loads the next comparison.
    btn_left.click(
        fn=lambda s: on_vote("left", s),
        inputs=[state],
        outputs=[left_video, right_video, progress, status, state, btn_left, btn_right],
    )
    btn_right.click(
        fn=lambda s: on_vote("right", s),
        inputs=[state],
        outputs=[left_video, right_video, progress, status, state, btn_left, btn_right],
    )
    # Client-side only (fn=None): if any <video> on the page is playing, pause them
    # all; otherwise start them all. No server round-trip.
    btn_toggle.click(
        fn=None,
        inputs=None,
        outputs=None,
        js="""
() => {
  const vids = Array.from(document.querySelectorAll('video'));
  if (vids.length === 0) return;
  const anyPlaying = vids.some(v => !v.paused && !v.ended && v.readyState > 2);
  if (anyPlaying) {
    vids.forEach(v => v.pause());
  } else {
    vids.forEach(v => v.play());
  }
}
"""
    )
    # Export votes.csv: ensure the file exists/is well-formed, then offer it for download.
    btn_export.click(fn=export_votes, inputs=None, outputs=[votes_file])

if __name__ == "__main__":
    # Launch locally; adjust server_name/port as needed.
    # queue() enables Gradio's request queue for the event handlers before launch.
    demo.queue().launch()