merve (HF Staff) and yonigozlan (HF Staff) committed
Commit 68adc67 · verified · 1 Parent(s): 7ec6608

Add support for multiple prompts (#2)

- add support for multi prompt (a00bf1798a9a3a6de9331ae637d4d739a1b63874)


Co-authored-by: Yoni Gozlan <[email protected]>

Files changed (1):
  1. app.py +273 -48
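
In short: the text tab now splits a comma-separated input into a list of prompts, registers them in a single add_text_prompt call, and uses the prompt_to_obj_ids mapping from post-processing to label and color each detection by the prompt that found it. A minimal sketch of that flow outside Gradio, mirroring the calls in app.py below (the checkpoint name and video frames are illustrative placeholders, not part of this commit):

import torch
from PIL import Image
from transformers import Sam3VideoModel, Sam3VideoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder checkpoint name; use whichever SAM3 checkpoint the Space loads.
model = Sam3VideoModel.from_pretrained("<sam3-checkpoint>").to(device, dtype=torch.bfloat16)
processor = Sam3VideoProcessor.from_pretrained("<sam3-checkpoint>")

# Placeholder frames; app.py passes frames loaded from the uploaded video.
raw_video = [Image.new("RGB", (640, 360)) for _ in range(8)]

# Frames stay on CPU, inference runs on the accelerator -- the device split this commit adopts.
session = processor.init_video_session(
    video=raw_video,
    inference_device=device,
    video_storage_device="cpu",
    processing_device="cpu",
    inference_state_device=device,
    dtype=torch.bfloat16,
)

# "person, bed, lamp" -> ["person", "bed", "lamp"]; the processor accepts the list directly.
prompt_texts = [p.strip() for p in "person, bed, lamp".split(",") if p.strip()]
session = processor.add_text_prompt(inference_session=session, text=prompt_texts)

with torch.no_grad():
    for model_outputs in model.propagate_in_video_iterator(inference_session=session):
        # Post-processing (as in app.py below) yields "object_ids", "masks", "scores",
        # and a per-prompt breakdown under "prompt_to_obj_ids".
        ...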
app.py CHANGED
@@ -1,4 +1,3 @@
-import os
 import colorsys
 import gc
 import os
@@ -10,8 +9,10 @@ import numpy as np
 import torch
 from gradio.themes import Soft
 from PIL import Image, ImageDraw, ImageFont
+
 from transformers import Sam3TrackerVideoModel, Sam3TrackerVideoProcessor, Sam3VideoModel, Sam3VideoProcessor
 
+
 def get_device_and_dtype() -> tuple[str, torch.dtype]:
     device = "cuda" if torch.cuda.is_available() else "cpu"
     dtype = torch.bfloat16
@@ -87,6 +88,23 @@ def pastel_color_for_object(obj_id: int) -> tuple[int, int, int]:
     return int(r_f * 255), int(g_f * 255), int(b_f * 255)
 
 
+def pastel_color_for_prompt(prompt_text: str) -> tuple[int, int, int]:
+    """Generate a consistent color for a prompt text using a deterministic hash."""
+    # Use a deterministic hash by summing character codes
+    # This ensures the same prompt always gets the same color
+    char_sum = sum(ord(c) for c in prompt_text)
+
+    # Use the sum to generate a hue that's well-distributed across the color spectrum
+    # Multiply by a large constant to spread values out
+    hue = ((char_sum * 2654435761) % 360) / 360.0
+
+    # Use pastel colors (lower saturation, high value)
+    saturation = 0.5
+    value = 0.95
+    r_f, g_f, b_f = colorsys.hsv_to_rgb(hue, saturation, value)
+    return int(r_f * 255), int(g_f * 255), int(b_f * 255)
+
+
 class AppState:
     def __init__(self):
         self.reset()
@@ -97,6 +115,7 @@ class AppState:
         self.video_fps: float | None = None
         self.masks_by_frame: dict[int, dict[int, np.ndarray]] = {}
         self.color_by_obj: dict[int, tuple[int, int, int]] = {}
+        self.color_by_prompt: dict[str, tuple[int, int, int]] = {}
         self.clicks_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int]]]] = {}
         self.boxes_by_frame_obj: dict[int, dict[int, list[tuple[int, int, int, int]]]] = {}
         self.text_prompts_by_frame_obj: dict[int, dict[int, str]] = {}
@@ -119,14 +138,13 @@ class AppState:
         return len(self.video_frames)
 
 
-
-
 def init_video_session(
     GLOBAL_STATE: gr.State, video: str | dict, active_tab: str = "point_box"
 ) -> tuple[AppState, int, int, Image.Image, str]:
     GLOBAL_STATE.video_frames = []
     GLOBAL_STATE.masks_by_frame = {}
     GLOBAL_STATE.color_by_obj = {}
+    GLOBAL_STATE.color_by_prompt = {}
     GLOBAL_STATE.text_prompts_by_frame_obj = {}
     GLOBAL_STATE.clicks_by_frame_obj = {}
     GLOBAL_STATE.boxes_by_frame_obj = {}
@@ -180,11 +198,10 @@ def init_video_session(
     GLOBAL_STATE.inference_session = processor.init_video_session(
         video=raw_video,
         inference_device=device,
-        video_storage_device=device,
-        processing_device=device,
+        video_storage_device="cpu",
+        processing_device="cpu",
         inference_state_device=device,
         dtype=dtype,
-        max_vision_features_cache_size=1,
     )
 
     first_frame = frames[0]
@@ -248,7 +265,36 @@ def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
 
     if text_prompts_by_obj and len(masks) > 0:
         draw = ImageDraw.Draw(out_img)
-        font = ImageFont.load_default()
+
+        # Calculate scale factor based on image size (reference: 720p height = 720)
+        img_width, img_height = out_img.size
+        reference_height = 720.0
+        scale_factor = img_height / reference_height
+
+        # Scale font size (base size ~13 pixels for default font, scale proportionally)
+        base_font_size = 13
+        font_size = max(10, int(base_font_size * scale_factor))
+
+        # Try to load a scalable font, fall back to default if not available
+        try:
+            # Try common system fonts
+            font_paths = [
+                "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf",
+                "/System/Library/Fonts/Helvetica.ttc",
+                "arial.ttf",
+            ]
+            font = None
+            for font_path in font_paths:
+                try:
+                    font = ImageFont.truetype(font_path, font_size)
+                    break
+                except (OSError, IOError):
+                    continue
+            if font is None:
+                # Fallback to default font
+                font = ImageFont.load_default()
+        except Exception:
+            font = ImageFont.load_default()
 
         for obj_id, text_prompt in text_prompts_by_obj.items():
             obj_mask = masks.get(obj_id)
@@ -261,15 +307,17 @@ def compose_frame(state: AppState, frame_idx: int) -> Image.Image:
             y_min, y_max = np.where(rows)[0][[0, -1]]
             x_min, x_max = np.where(cols)[0][[0, -1]]
             label_x = int(x_min)
-            label_y = int(y_min) - 20
-            label_y = max(5, label_y)
+            # Scale vertical offset and padding
+            vertical_offset = int(20 * scale_factor)
+            padding = max(2, int(4 * scale_factor))
+            label_y = int(y_min) - vertical_offset
+            label_y = max(int(5 * scale_factor), label_y)
 
             obj_color = state.color_by_obj.get(obj_id, (255, 255, 255))
 
             # Include object ID in the label
-            label_text = f"{text_prompt} (ID: {obj_id})"
+            label_text = f"{text_prompt} - ID {obj_id}"
             bbox = draw.textbbox((label_x, label_y), label_text, font=font)
-            padding = 4
             draw.rectangle(
                 [(bbox[0] - padding, bbox[1] - padding), (bbox[2] + padding, bbox[3] + padding)],
                 fill=obj_color,
@@ -292,8 +340,38 @@ def update_frame_display(state: AppState, frame_idx: int) -> Image.Image:
     return compose_frame(state, frame_idx)
 
 
+def _get_prompt_for_obj(state: AppState, obj_id: int) -> str | None:
+    """Get the prompt text associated with an object ID."""
+    # Priority 1: Check text_prompts_by_frame_obj (most reliable)
+    for frame_texts in state.text_prompts_by_frame_obj.values():
+        if obj_id in frame_texts:
+            return frame_texts[obj_id].strip()
+
+    # Priority 2: Check inference session mapping
+    if state.inference_session is not None:
+        if (
+            hasattr(state.inference_session, "obj_id_to_prompt_id")
+            and obj_id in state.inference_session.obj_id_to_prompt_id
+        ):
+            prompt_id = state.inference_session.obj_id_to_prompt_id[obj_id]
+            if hasattr(state.inference_session, "prompts") and prompt_id in state.inference_session.prompts:
+                return state.inference_session.prompts[prompt_id].strip()
+
+    return None
+
+
 def _ensure_color_for_obj(state: AppState, obj_id: int):
-    if obj_id not in state.color_by_obj:
+    """Assign color to object based on its prompt if available, otherwise use object ID."""
+    prompt_text = _get_prompt_for_obj(state, obj_id)
+
+    if prompt_text is not None:
+        # Ensure prompt has a color assigned
+        if prompt_text not in state.color_by_prompt:
+            state.color_by_prompt[prompt_text] = pastel_color_for_prompt(prompt_text)
+        # Always update to prompt-based color
+        state.color_by_obj[obj_id] = state.color_by_prompt[prompt_text]
+    elif obj_id not in state.color_by_obj:
+        # Fallback to object ID-based color (for point/box prompting mode)
         state.color_by_obj[obj_id] = pastel_color_for_object(obj_id)
 
 
@@ -414,21 +492,29 @@ def on_text_prompt(
     state: AppState,
     frame_idx: int,
     text_prompt: str,
-) -> tuple[Image.Image, str]:
+) -> tuple[Image.Image, str, str]:
     if state is None or state.inference_session is None:
-        return None, "Upload a video and enter text prompt."
+        return None, "Upload a video and enter text prompt.", "**Active prompts:** None"
 
     model = _GLOBAL_TEXT_VIDEO_MODEL
     processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
 
     if not text_prompt or not text_prompt.strip():
-        return update_frame_display(state, int(frame_idx)), "Please enter a text prompt."
+        active_prompts = _get_active_prompts_display(state)
+        return update_frame_display(state, int(frame_idx)), "Please enter a text prompt.", active_prompts
 
     frame_idx = int(np.clip(frame_idx, 0, len(state.video_frames) - 1))
 
+    # Parse comma-separated prompts or single prompt
+    prompt_texts = [p.strip() for p in text_prompt.split(",") if p.strip()]
+    if not prompt_texts:
+        active_prompts = _get_active_prompts_display(state)
+        return update_frame_display(state, int(frame_idx)), "Please enter a valid text prompt.", active_prompts
+
+    # Add text prompt(s) - supports both single string and list of strings
     state.inference_session = processor.add_text_prompt(
         inference_session=state.inference_session,
-        text=text_prompt.strip(),
+        text=prompt_texts,  # Pass as list to add multiple at once
     )
 
     masks_for_frame = state.masks_by_frame.setdefault(frame_idx, {})
@@ -436,6 +522,8 @@ def on_text_prompt(
 
     num_objects = 0
    detected_obj_ids = []
+    prompt_to_obj_ids_summary = {}
+
    with torch.no_grad():
        for model_outputs in model.propagate_in_video_iterator(
            inference_session=state.inference_session,
@@ -452,6 +540,15 @@ def on_text_prompt(
            object_ids = processed_outputs["object_ids"]
            masks = processed_outputs["masks"]
            scores = processed_outputs["scores"]
+            prompt_to_obj_ids = processed_outputs.get("prompt_to_obj_ids", {})
+
+            # Update prompt_to_obj_ids summary for status message
+            for prompt, obj_ids in prompt_to_obj_ids.items():
+                if prompt not in prompt_to_obj_ids_summary:
+                    prompt_to_obj_ids_summary[prompt] = []
+                prompt_to_obj_ids_summary[prompt].extend(
+                    [int(oid) for oid in obj_ids if int(oid) not in prompt_to_obj_ids_summary[prompt]]
+                )
 
            num_objects = len(object_ids)
            if num_objects > 0:
@@ -463,22 +560,54 @@ def on_text_prompt(
                for mask_idx in sorted_indices:
                    current_obj_id = int(object_ids[mask_idx].item())
                    detected_obj_ids.append(current_obj_id)
-                    _ensure_color_for_obj(state, current_obj_id)
                    mask_2d = masks[mask_idx].float().cpu().numpy()
                    if mask_2d.ndim == 3:
                        mask_2d = mask_2d.squeeze()
                    mask_2d = (mask_2d > 0.0).astype(np.float32)
                    masks_for_frame[current_obj_id] = mask_2d
-                    frame_texts[current_obj_id] = text_prompt.strip()
+
+                    # Find which prompt detected this object
+                    detected_prompt = None
+                    for prompt, obj_ids in prompt_to_obj_ids.items():
+                        if current_obj_id in obj_ids:
+                            detected_prompt = prompt
+                            break
+
+                    # Store prompt and assign color
+                    if detected_prompt:
+                        frame_texts[current_obj_id] = detected_prompt.strip()
+                        _ensure_color_for_obj(state, current_obj_id)
 
    state.composited_frames.pop(frame_idx, None)
 
+    # Build status message with prompt breakdown
    if detected_obj_ids:
-        obj_ids_str = ", ".join(map(str, detected_obj_ids))
-        status = f"Processed text prompt '{text_prompt.strip()}' on frame {frame_idx}. Found {num_objects} object(s) with IDs: {obj_ids_str}."
+        status_parts = [f"Processed text prompt(s) on frame {frame_idx}. Found {num_objects} object(s):"]
+        for prompt, obj_ids in prompt_to_obj_ids_summary.items():
+            if obj_ids:
+                obj_ids_str = ", ".join(map(str, sorted(obj_ids)))
+                status_parts.append(f"  • '{prompt}': {len(obj_ids)} object(s) (IDs: {obj_ids_str})")
+        status = "\n".join(status_parts)
    else:
-        status = f"Processed text prompt '{text_prompt.strip()}' on frame {frame_idx}. No objects detected."
-    return update_frame_display(state, int(frame_idx)), status
+        prompts_str = ", ".join([f"'{p}'" for p in prompt_texts])
+        status = f"Processed text prompt(s) {prompts_str} on frame {frame_idx}. No objects detected."
+
+    active_prompts = _get_active_prompts_display(state)
+    return update_frame_display(state, int(frame_idx)), status, active_prompts
+
+
+def _get_active_prompts_display(state: AppState) -> str:
+    """Get a formatted string showing all active prompts in the inference session."""
+    if state is None or state.inference_session is None:
+        return "**Active prompts:** None"
+
+    if hasattr(state.inference_session, "prompts") and state.inference_session.prompts:
+        prompts_list = sorted(set(state.inference_session.prompts.values()))
+        if prompts_list:
+            prompts_str = ", ".join([f"'{p}'" for p in prompts_list])
+            return f"**Active prompts:** {prompts_str}"
+
+    return "**Active prompts:** None"
 
 
 def propagate_masks(GLOBAL_STATE: gr.State):
@@ -504,6 +633,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
    model = _GLOBAL_TEXT_VIDEO_MODEL
    processor = _GLOBAL_TEXT_VIDEO_PROCESSOR
 
+    # Collect all unique prompts from existing frame annotations
    text_prompt_to_obj_ids = {}
    for frame_idx, frame_texts in GLOBAL_STATE.text_prompts_by_frame_obj.items():
        for obj_id, text_prompt in frame_texts.items():
@@ -512,6 +642,12 @@ def propagate_masks(GLOBAL_STATE: gr.State):
            if obj_id not in text_prompt_to_obj_ids[text_prompt]:
                text_prompt_to_obj_ids[text_prompt].append(obj_id)
 
+    # Also check if there are prompts already in the inference session
+    if hasattr(GLOBAL_STATE.inference_session, "prompts") and GLOBAL_STATE.inference_session.prompts:
+        for prompt_text in GLOBAL_STATE.inference_session.prompts.values():
+            if prompt_text not in text_prompt_to_obj_ids:
+                text_prompt_to_obj_ids[prompt_text] = []
+
    for text_prompt in text_prompt_to_obj_ids:
        text_prompt_to_obj_ids[text_prompt].sort()
 
@@ -519,6 +655,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
        yield GLOBAL_STATE, "No text prompts found. Please add a text prompt first.", gr.update()
        return
 
+    # Add all prompts to the inference session (processor handles deduplication)
    for text_prompt in text_prompt_to_obj_ids.keys():
        GLOBAL_STATE.inference_session = processor.add_text_prompt(
            inference_session=GLOBAL_STATE.inference_session,
@@ -548,6 +685,7 @@ def propagate_masks(GLOBAL_STATE: gr.State):
            object_ids = processed_outputs["object_ids"]
            masks = processed_outputs["masks"]
            scores = processed_outputs["scores"]
+            prompt_to_obj_ids = processed_outputs.get("prompt_to_obj_ids", {})
 
            masks_for_frame = GLOBAL_STATE.masks_by_frame.setdefault(frame_idx, {})
            frame_texts = GLOBAL_STATE.text_prompts_by_frame_obj.setdefault(frame_idx, {})
@@ -561,24 +699,23 @@ def propagate_masks(GLOBAL_STATE: gr.State):
 
            for mask_idx in sorted_indices:
                current_obj_id = int(object_ids[mask_idx].item())
-                _ensure_color_for_obj(GLOBAL_STATE, current_obj_id)
                mask_2d = masks[mask_idx].float().cpu().numpy()
                if mask_2d.ndim == 3:
                    mask_2d = mask_2d.squeeze()
                mask_2d = (mask_2d > 0.0).astype(np.float32)
                masks_for_frame[current_obj_id] = mask_2d
 
+                # Find which prompt detected this object
                found_prompt = None
-                for existing_frame_idx, existing_frame_texts in GLOBAL_STATE.text_prompts_by_frame_obj.items():
-                    if current_obj_id in existing_frame_texts:
-                        found_prompt = existing_frame_texts[current_obj_id]
+                for prompt, obj_ids in prompt_to_obj_ids.items():
+                    if current_obj_id in obj_ids:
+                        found_prompt = prompt
                        break
 
-                if found_prompt is None and text_prompt_to_obj_ids:
-                    found_prompt = list(text_prompt_to_obj_ids.keys())[0]
-
+                # Store prompt and assign color
                if found_prompt:
-                    frame_texts[current_obj_id] = found_prompt
+                    frame_texts[current_obj_id] = found_prompt.strip()
+                    _ensure_color_for_obj(GLOBAL_STATE, current_obj_id)
 
            GLOBAL_STATE.composited_frames.pop(frame_idx, None)
            last_frame_idx = frame_idx
@@ -620,9 +757,76 @@ def propagate_masks(GLOBAL_STATE: gr.State):
    yield GLOBAL_STATE, text, gr.update(value=last_frame_idx)
 
 
-def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
+def reset_prompts(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, str, str]:
+    """Reset prompts and all outputs, but keep processed frames and cached vision features."""
+    if GLOBAL_STATE is None or GLOBAL_STATE.inference_session is None:
+        active_prompts = _get_active_prompts_display(GLOBAL_STATE)
+        return GLOBAL_STATE, None, "No active session to reset.", active_prompts
+
+    if GLOBAL_STATE.active_tab != "text":
+        active_prompts = _get_active_prompts_display(GLOBAL_STATE)
+        return GLOBAL_STATE, None, "Reset prompts is only available for text prompting mode.", active_prompts
+
+    # Reset inference session tracking data but keep cache and processed frames
+    if hasattr(GLOBAL_STATE.inference_session, "reset_tracking_data"):
+        GLOBAL_STATE.inference_session.reset_tracking_data()
+
+    # Manually clear prompts (reset_tracking_data doesn't clear prompts themselves)
+    if hasattr(GLOBAL_STATE.inference_session, "prompts"):
+        GLOBAL_STATE.inference_session.prompts.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "prompt_input_ids"):
+        GLOBAL_STATE.inference_session.prompt_input_ids.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "prompt_embeddings"):
+        GLOBAL_STATE.inference_session.prompt_embeddings.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "prompt_attention_masks"):
+        GLOBAL_STATE.inference_session.prompt_attention_masks.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_prompt_id"):
+        GLOBAL_STATE.inference_session.obj_id_to_prompt_id.clear()
+
+    # Reset detection-tracking fusion state
+    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_score"):
+        GLOBAL_STATE.inference_session.obj_id_to_score.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_tracker_score_frame_wise"):
+        GLOBAL_STATE.inference_session.obj_id_to_tracker_score_frame_wise.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "obj_id_to_last_occluded"):
+        GLOBAL_STATE.inference_session.obj_id_to_last_occluded.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "max_obj_id"):
+        GLOBAL_STATE.inference_session.max_obj_id = -1
+    if hasattr(GLOBAL_STATE.inference_session, "obj_first_frame_idx"):
+        GLOBAL_STATE.inference_session.obj_first_frame_idx.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "unmatched_frame_inds"):
+        GLOBAL_STATE.inference_session.unmatched_frame_inds.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "overlap_pair_to_frame_inds"):
+        GLOBAL_STATE.inference_session.overlap_pair_to_frame_inds.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "trk_keep_alive"):
+        GLOBAL_STATE.inference_session.trk_keep_alive.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "removed_obj_ids"):
+        GLOBAL_STATE.inference_session.removed_obj_ids.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "suppressed_obj_ids"):
+        GLOBAL_STATE.inference_session.suppressed_obj_ids.clear()
+    if hasattr(GLOBAL_STATE.inference_session, "hotstart_removed_obj_ids"):
+        GLOBAL_STATE.inference_session.hotstart_removed_obj_ids.clear()
+
+    # Clear all app state outputs
+    GLOBAL_STATE.masks_by_frame.clear()
+    GLOBAL_STATE.text_prompts_by_frame_obj.clear()
+    GLOBAL_STATE.composited_frames.clear()
+    GLOBAL_STATE.color_by_obj.clear()
+    GLOBAL_STATE.color_by_prompt.clear()
+
+    # Update display
+    current_idx = int(getattr(GLOBAL_STATE, "current_frame_idx", 0))
+    current_idx = max(0, min(current_idx, GLOBAL_STATE.num_frames - 1))
+    preview_img = update_frame_display(GLOBAL_STATE, current_idx)
+    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
+    status = "Prompts and outputs reset. Processed frames and cached vision features preserved."
+
+    return GLOBAL_STATE, preview_img, status, active_prompts
+
+
+def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str, str]:
    if not GLOBAL_STATE.video_frames:
-        return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video."
+        return GLOBAL_STATE, None, 0, 0, "Session reset. Load a new video.", "**Active prompts:** None"
 
    if GLOBAL_STATE.active_tab == "text":
        if GLOBAL_STATE.video_frames:
@@ -645,11 +849,9 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
            GLOBAL_STATE.inference_session = processor.init_video_session(
                video=raw_video,
                inference_device=_GLOBAL_DEVICE,
-                video_storage_device=_GLOBAL_DEVICE,
-                processing_device=_GLOBAL_DEVICE,
-                inference_state_device=_GLOBAL_DEVICE,
+                video_storage_device="cpu",
+                processing_device="cpu",
                dtype=_GLOBAL_DTYPE,
-                max_vision_features_cache_size=1,
            )
 
    GLOBAL_STATE.masks_by_frame.clear()
@@ -657,6 +859,8 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
    GLOBAL_STATE.boxes_by_frame_obj.clear()
    GLOBAL_STATE.text_prompts_by_frame_obj.clear()
    GLOBAL_STATE.composited_frames.clear()
+    GLOBAL_STATE.color_by_obj.clear()
+    GLOBAL_STATE.color_by_prompt.clear()
    GLOBAL_STATE.pending_box_start = None
    GLOBAL_STATE.pending_box_start_frame_idx = None
    GLOBAL_STATE.pending_box_start_obj_id = None
@@ -669,7 +873,8 @@ def reset_session(GLOBAL_STATE: gr.State) -> tuple[AppState, Image.Image, int, int, str]:
    slider_minmax = gr.update(minimum=0, maximum=max(GLOBAL_STATE.num_frames - 1, 0), interactive=True)
    slider_value = gr.update(value=current_idx)
    status = "Session reset. Prompts cleared; video preserved."
-    return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status
+    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
+    return GLOBAL_STATE, preview_img, slider_minmax, slider_value, status, active_prompts
 
 
 def _on_video_change_pointbox(GLOBAL_STATE: gr.State, video):
@@ -684,11 +889,13 @@ def _on_video_change_pointbox(GLOBAL_STATE: gr.State, video):
 
 def _on_video_change_text(GLOBAL_STATE: gr.State, video):
    GLOBAL_STATE, min_idx, max_idx, first_frame, status = init_video_session(GLOBAL_STATE, video, "text")
+    active_prompts = _get_active_prompts_display(GLOBAL_STATE)
    return (
        GLOBAL_STATE,
        gr.update(minimum=min_idx, maximum=max_idx, value=min_idx, interactive=True),
        first_frame,
        status,
+        active_prompts,
    )
 
 
@@ -712,7 +919,7 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
        """
        **Quick start**
        - **Load a video**: Upload your own or pick an example below.
-        - Select a frame and enter a text description to segment objects (e.g., "red car", "penguin"). The text prompt will return all the instances of the object in the frame and not specific ones (e.g. not "penguin on the left" but "penguin").
+        - Select a frame and enter text description(s) to segment objects (e.g., "red car", "penguin"). You can add multiple prompts separated by commas (e.g., "person, bed, lamp") or add them one by one. The text prompt will return all the instances of the object in the frame and not specific ones (e.g. not "penguin on the left" but "penguin").
        """
    )
    with gr.Column():
@@ -741,11 +948,14 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
                propagate_status_text = gr.Markdown(visible=True)
                with gr.Row():
                    text_prompt_input = gr.Textbox(
-                        label="Text Prompt",
-                        placeholder="Enter a text description (e.g., 'person', 'red car', 'short hair')",
+                        label="Text Prompt(s)",
+                        placeholder="Enter text description(s) (e.g., 'person' or 'person, bed, lamp' for multiple)",
                        lines=2,
                    )
-                    text_apply_btn = gr.Button("Apply Text Prompt", variant="primary")
+                    with gr.Column(scale=0):
+                        text_apply_btn = gr.Button("Apply Text Prompt(s)", variant="primary")
+                        reset_prompts_btn = gr.Button("Reset Prompts", variant="secondary")
+                active_prompts_display = gr.Markdown("**Active prompts:** None", visible=True)
                text_status = gr.Markdown(visible=True)
 
                with gr.Row():
@@ -762,7 +972,7 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
                    examples=examples_list_text,
                    inputs=[GLOBAL_STATE, video_in_text],
                    fn=_on_video_change_text,
-                    outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text],
+                    outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text, active_prompts_display],
                    label="Examples",
                    cache_examples=False,
                    examples_per_page=5,
@@ -790,7 +1000,9 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
 
            with gr.Row():
                with gr.Column(scale=1):
-                    video_in_pointbox = gr.Video(label="Upload video", sources=["upload", "webcam"], interactive=True, max_length=7)
+                    video_in_pointbox = gr.Video(
+                        label="Upload video", sources=["upload", "webcam"], interactive=True, max_length=7
+                    )
                    load_status_pointbox = gr.Markdown(visible=True)
                    reset_btn_pointbox = gr.Button("Reset Session", variant="secondary")
                with gr.Column(scale=2):
@@ -850,7 +1062,7 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
    video_in_text.change(
        _on_video_change_text,
        inputs=[GLOBAL_STATE, video_in_text],
-        outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text],
+        outputs=[GLOBAL_STATE, frame_slider_text, preview_text, load_status_text, active_prompts_display],
        show_progress=True,
    )
 
@@ -903,13 +1115,19 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
    )
 
    def _on_text_apply(state: AppState, frame_idx: int, text: str):
-        img, status = on_text_prompt(state, frame_idx, text)
-        return img, status
+        img, status, active_prompts = on_text_prompt(state, frame_idx, text)
+        return img, status, active_prompts
 
    text_apply_btn.click(
        _on_text_apply,
        inputs=[GLOBAL_STATE, frame_slider_text, text_prompt_input],
-        outputs=[preview_text, text_status],
+        outputs=[preview_text, text_status, active_prompts_display],
+    )
+
+    reset_prompts_btn.click(
+        reset_prompts,
+        inputs=[GLOBAL_STATE],
+        outputs=[GLOBAL_STATE, preview_text, text_status, active_prompts_display],
    )
 
    def _render_video(s: AppState):
@@ -962,7 +1180,14 @@ with gr.Blocks(title="SAM3", theme=theme) as demo:
    reset_btn_text.click(
        reset_session,
        inputs=GLOBAL_STATE,
-        outputs=[GLOBAL_STATE, preview_text, frame_slider_text, frame_slider_text, load_status_text],
+        outputs=[
+            GLOBAL_STATE,
+            preview_text,
+            frame_slider_text,
+            frame_slider_text,
+            load_status_text,
+            active_prompts_display,
+        ],
    )
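
A side note on the color scheme introduced above: in text mode, colors are keyed by prompt text rather than by object ID, using a character-sum hash instead of Python's built-in hash() (which is salted per process), so the same prompt maps to the same pastel color across frames, sessions, and restarts. The scheme is self-contained and can be tried standalone; this sketch mirrors pastel_color_for_prompt from the diff:

import colorsys

def pastel_color_for_prompt(prompt_text: str) -> tuple[int, int, int]:
    # Character-code sum, spread across the hue circle with a large
    # multiplicative constant, rendered as a pastel HSV color.
    char_sum = sum(ord(c) for c in prompt_text)
    hue = ((char_sum * 2654435761) % 360) / 360.0
    r_f, g_f, b_f = colorsys.hsv_to_rgb(hue, 0.5, 0.95)
    return int(r_f * 255), int(g_f * 255), int(b_f * 255)

# Deterministic: every object detected by "person" shares one color,
# every object detected by "bed" another, in every run.
print(pastel_color_for_prompt("person"))
print(pastel_color_for_prompt("bed"))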