xiaoyuxi committed on
Commit
0166665
·
1 Parent(s): 49d1cf3
Files changed (2) hide show
  1. app.py +4 -4
  2. models/SpaTrackV2/models/predictor.py +3 -3
app.py CHANGED
@@ -122,7 +122,7 @@ def gpu_run_inference(predictor_arg, image, points, boxes):
122
  return run_inference(predictor_arg, image, points, boxes)
123
 
124
  @spaces.GPU
125
- def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps):
126
  """GPU-accelerated tracking"""
127
  import torchvision.transforms as T
128
  import decord
@@ -211,7 +211,7 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
211
  intrs=intrs, extrs=extrs,
212
  queries=query_xyt,
213
  fps=1, full_point=False, iters_track=4,
214
- query_no_BA=True, fixed_cam=False, stage=1,
215
  support_frame=len(video_tensor)-1, replace_ratio=0.2)
216
 
217
  # Resize results to avoid large I/O
@@ -530,7 +530,7 @@ def reset_points(original_img: str, sel_pix):
530
  print(f"❌ Error in reset_points: {e}")
531
  return None, []
532
 
533
- def launch_viz(grid_size, vo_points, fps, original_image_state):
534
  """Launch visualization with user-specific temp directory"""
535
  if original_image_state is None:
536
  return None, None, None
@@ -560,7 +560,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state):
560
  out_dir = os.path.join(temp_dir, "results")
561
  os.makedirs(out_dir, exist_ok=True)
562
 
563
- gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps)
564
 
565
  # Process results
566
  npz_path = os.path.join(out_dir, "result.npz")
 
122
  return run_inference(predictor_arg, image, points, boxes)
123
 
124
  @spaces.GPU
125
+ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
126
  """GPU-accelerated tracking"""
127
  import torchvision.transforms as T
128
  import decord
 
211
  intrs=intrs, extrs=extrs,
212
  queries=query_xyt,
213
  fps=1, full_point=False, iters_track=4,
214
+ query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
215
  support_frame=len(video_tensor)-1, replace_ratio=0.2)
216
 
217
  # Resize results to avoid large I/O
 
530
  print(f"❌ Error in reset_points: {e}")
531
  return None, []
532
 
533
+ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
534
  """Launch visualization with user-specific temp directory"""
535
  if original_image_state is None:
536
  return None, None, None
 
560
  out_dir = os.path.join(temp_dir, "results")
561
  os.makedirs(out_dir, exist_ok=True)
562
 
563
+ gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
564
 
565
  # Process results
566
  npz_path = os.path.join(out_dir, "result.npz")
models/SpaTrackV2/models/predictor.py CHANGED
@@ -22,8 +22,8 @@ class Predictor(torch.nn.Module):
22
  super().__init__()
23
  self.args = args
24
  self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
25
- self.S_wind = 200
26
- self.overlap = 8
27
 
28
  def to(self, device: Union[str, torch.device]):
29
  self.spatrack.to(device)
@@ -138,7 +138,7 @@ class Predictor(torch.nn.Module):
138
  if extrs is not None:
139
  extrs = torch.cat([extrs, extrs[-1:].repeat(T-extrs.shape[0], 1, 1)], dim=0)
140
  if unc_metric is not None:
141
- unc_metric = torch.cat([unc_metric, unc_metric[-1:].repeat(T-unc_metric.shape[0], 1)], dim=0)
142
  with torch.no_grad():
143
  ret = self.spatrack.forward_stream(video, queries, T_org=T_,
144
  depth=depth, intrs=intrs, unc_metric_in=unc_metric, extrs=extrs, queries_3d=queries_3d,
 
22
  super().__init__()
23
  self.args = args
24
  self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
25
+ self.S_wind = args.Track_cfg.s_wind
26
+ self.overlap = args.Track_cfg.overlap
27
 
28
  def to(self, device: Union[str, torch.device]):
29
  self.spatrack.to(device)
 
138
  if extrs is not None:
139
  extrs = torch.cat([extrs, extrs[-1:].repeat(T-extrs.shape[0], 1, 1)], dim=0)
140
  if unc_metric is not None:
141
+ unc_metric = torch.cat([unc_metric, unc_metric[-1:].repeat(T-unc_metric.shape[0], 1, 1)], dim=0)
142
  with torch.no_grad():
143
  ret = self.spatrack.forward_stream(video, queries, T_org=T_,
144
  depth=depth, intrs=intrs, unc_metric_in=unc_metric, extrs=extrs, queries_3d=queries_3d,