Spaces:
Running
on
Zero
Running
on
Zero
xiaoyuxi
commited on
Commit
·
151b615
1
Parent(s):
9193cab
add online
Browse files
app.py
CHANGED
|
@@ -43,7 +43,9 @@ except ImportError as e:
|
|
| 43 |
raise
|
| 44 |
|
| 45 |
# Constants
|
| 46 |
-
|
|
|
|
|
|
|
| 47 |
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
| 48 |
MARKERS = [1, 5] # Cross for negative, Star for positive
|
| 49 |
MARKER_SIZE = 8
|
|
@@ -88,8 +90,10 @@ vggt4track_model = vggt4track_model.to("cuda")
|
|
| 88 |
|
| 89 |
# Global model initialization
|
| 90 |
print("🚀 Initializing local models...")
|
| 91 |
-
|
| 92 |
-
|
|
|
|
|
|
|
| 93 |
predictor = get_sam_predictor()
|
| 94 |
print("✅ Models loaded successfully!")
|
| 95 |
|
|
@@ -128,9 +132,13 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
|
|
| 128 |
if tracker_model_arg is None or tracker_viser_arg is None:
|
| 129 |
print("Initializing tracker models inside GPU function...")
|
| 130 |
out_dir = os.path.join(temp_dir, "results")
|
| 131 |
-
os.makedirs(out_dir, exist_ok=True)
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
# Setup paths
|
| 136 |
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
|
@@ -148,7 +156,10 @@ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name,
|
|
| 148 |
if scale < 1:
|
| 149 |
new_h, new_w = int(h * scale), int(w * scale)
|
| 150 |
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
# Move to GPU
|
| 154 |
video_tensor = video_tensor.cuda()
|
|
@@ -526,7 +537,7 @@ def reset_points(original_img: str, sel_pix):
|
|
| 526 |
print(f"❌ Error in reset_points: {e}")
|
| 527 |
return None, []
|
| 528 |
|
| 529 |
-
def launch_viz(grid_size, vo_points, fps, original_image_state,
|
| 530 |
"""Launch visualization with user-specific temp directory"""
|
| 531 |
if original_image_state is None:
|
| 532 |
return None, None, None
|
|
@@ -538,7 +549,7 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
|
| 538 |
video_name = frame_data.get('video_name', 'video')
|
| 539 |
|
| 540 |
print(f"🚀 Starting tracking for video: {video_name}")
|
| 541 |
-
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
|
| 542 |
|
| 543 |
# Check for mask files
|
| 544 |
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
|
@@ -552,11 +563,11 @@ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
|
| 552 |
mask_path = mask_files[0] if mask_files else None
|
| 553 |
|
| 554 |
# Run tracker
|
| 555 |
-
print("🎯 Running tracker...")
|
| 556 |
out_dir = os.path.join(temp_dir, "results")
|
| 557 |
os.makedirs(out_dir, exist_ok=True)
|
| 558 |
|
| 559 |
-
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=
|
| 560 |
|
| 561 |
# Process results
|
| 562 |
npz_path = os.path.join(out_dir, "result.npz")
|
|
@@ -609,6 +620,7 @@ def clear_all_with_download():
|
|
| 609 |
gr.update(value=50),
|
| 610 |
gr.update(value=756),
|
| 611 |
gr.update(value=3),
|
|
|
|
| 612 |
None, # tracking_video_download
|
| 613 |
None) # HTML download component
|
| 614 |
|
|
@@ -641,6 +653,13 @@ def get_video_settings(video_name):
|
|
| 641 |
|
| 642 |
return video_settings.get(video_name, (50, 756, 3))
|
| 643 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
# Create the Gradio interface
|
| 645 |
print("🎨 Creating Gradio interface...")
|
| 646 |
|
|
@@ -846,7 +865,7 @@ with gr.Blocks(
|
|
| 846 |
""")
|
| 847 |
|
| 848 |
# Status indicator
|
| 849 |
-
gr.Markdown("**Status:** 🟢 Local Processing Mode")
|
| 850 |
|
| 851 |
# Main content area - video upload left, 3D visualization right
|
| 852 |
with gr.Row():
|
|
@@ -945,18 +964,29 @@ with gr.Blocks(
|
|
| 945 |
with gr.Row():
|
| 946 |
gr.Markdown("### ⚙️ Tracking Parameters")
|
| 947 |
with gr.Row():
|
| 948 |
-
|
| 949 |
-
|
| 950 |
-
|
| 951 |
-
|
| 952 |
-
|
| 953 |
-
|
| 954 |
-
|
| 955 |
-
|
| 956 |
-
|
| 957 |
-
|
| 958 |
-
|
| 959 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 960 |
|
| 961 |
# Advanced Point Selection with SAM - Collapsed by default
|
| 962 |
with gr.Row():
|
|
@@ -1082,6 +1112,12 @@ with gr.Blocks(
|
|
| 1082 |
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
| 1083 |
)
|
| 1084 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1085 |
interactive_frame.select(
|
| 1086 |
fn=select_point,
|
| 1087 |
inputs=[original_image_state, selected_points, point_type],
|
|
@@ -1096,12 +1132,12 @@ with gr.Blocks(
|
|
| 1096 |
|
| 1097 |
clear_all_btn.click(
|
| 1098 |
fn=clear_all_with_download,
|
| 1099 |
-
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
|
| 1100 |
)
|
| 1101 |
|
| 1102 |
launch_btn.click(
|
| 1103 |
fn=launch_viz,
|
| 1104 |
-
inputs=[grid_size, vo_points, fps, original_image_state],
|
| 1105 |
outputs=[viz_html, tracking_video_download, html_download]
|
| 1106 |
)
|
| 1107 |
|
|
|
|
| 43 |
raise
|
| 44 |
|
| 45 |
# Constants
|
| 46 |
+
MAX_FRAMES_OFFLINE = 80
|
| 47 |
+
MAX_FRAMES_ONLINE = 300
|
| 48 |
+
|
| 49 |
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
| 50 |
MARKERS = [1, 5] # Cross for negative, Star for positive
|
| 51 |
MARKER_SIZE = 8
|
|
|
|
| 90 |
|
| 91 |
# Global model initialization
|
| 92 |
print("🚀 Initializing local models...")
|
| 93 |
+
tracker_model_offline = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Offline")
|
| 94 |
+
tracker_model_offline.eval()
|
| 95 |
+
tracker_model_online = Predictor.from_pretrained("Yuxihenry/SpatialTrackerV2-Online")
|
| 96 |
+
tracker_model_online.eval()
|
| 97 |
predictor = get_sam_predictor()
|
| 98 |
print("✅ Models loaded successfully!")
|
| 99 |
|
|
|
|
| 132 |
if tracker_model_arg is None or tracker_viser_arg is None:
|
| 133 |
print("Initializing tracker models inside GPU function...")
|
| 134 |
out_dir = os.path.join(temp_dir, "results")
|
| 135 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 136 |
+
if mode == "offline":
|
| 137 |
+
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
| 138 |
+
tracker_model=tracker_model_offline.cuda())
|
| 139 |
+
else:
|
| 140 |
+
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points,
|
| 141 |
+
tracker_model=tracker_model_online.cuda())
|
| 142 |
|
| 143 |
# Setup paths
|
| 144 |
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
|
|
|
| 156 |
if scale < 1:
|
| 157 |
new_h, new_w = int(h * scale), int(w * scale)
|
| 158 |
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 159 |
+
if mode == "offline":
|
| 160 |
+
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_OFFLINE]
|
| 161 |
+
else:
|
| 162 |
+
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES_ONLINE]
|
| 163 |
|
| 164 |
# Move to GPU
|
| 165 |
video_tensor = video_tensor.cuda()
|
|
|
|
| 537 |
print(f"❌ Error in reset_points: {e}")
|
| 538 |
return None, []
|
| 539 |
|
| 540 |
+
def launch_viz(grid_size, vo_points, fps, original_image_state, processing_mode):
|
| 541 |
"""Launch visualization with user-specific temp directory"""
|
| 542 |
if original_image_state is None:
|
| 543 |
return None, None, None
|
|
|
|
| 549 |
video_name = frame_data.get('video_name', 'video')
|
| 550 |
|
| 551 |
print(f"🚀 Starting tracking for video: {video_name}")
|
| 552 |
+
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}, mode={processing_mode}")
|
| 553 |
|
| 554 |
# Check for mask files
|
| 555 |
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
|
|
|
| 563 |
mask_path = mask_files[0] if mask_files else None
|
| 564 |
|
| 565 |
# Run tracker
|
| 566 |
+
print(f"🎯 Running tracker in {processing_mode} mode...")
|
| 567 |
out_dir = os.path.join(temp_dir, "results")
|
| 568 |
os.makedirs(out_dir, exist_ok=True)
|
| 569 |
|
| 570 |
+
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=processing_mode)
|
| 571 |
|
| 572 |
# Process results
|
| 573 |
npz_path = os.path.join(out_dir, "result.npz")
|
|
|
|
| 620 |
gr.update(value=50),
|
| 621 |
gr.update(value=756),
|
| 622 |
gr.update(value=3),
|
| 623 |
+
gr.update(value="offline"), # processing_mode
|
| 624 |
None, # tracking_video_download
|
| 625 |
None) # HTML download component
|
| 626 |
|
|
|
|
| 653 |
|
| 654 |
return video_settings.get(video_name, (50, 756, 3))
|
| 655 |
|
| 656 |
+
def update_status_indicator(processing_mode):
|
| 657 |
+
"""Update status indicator based on processing mode"""
|
| 658 |
+
if processing_mode == "offline":
|
| 659 |
+
return "**Status:** 🟢 Local Processing Mode (Offline)"
|
| 660 |
+
else:
|
| 661 |
+
return "**Status:** 🔵 Cloud Processing Mode (Online)"
|
| 662 |
+
|
| 663 |
# Create the Gradio interface
|
| 664 |
print("🎨 Creating Gradio interface...")
|
| 665 |
|
|
|
|
| 865 |
""")
|
| 866 |
|
| 867 |
# Status indicator
|
| 868 |
+
status_indicator = gr.Markdown("**Status:** 🟢 Local Processing Mode (Offline)")
|
| 869 |
|
| 870 |
# Main content area - video upload left, 3D visualization right
|
| 871 |
with gr.Row():
|
|
|
|
| 964 |
with gr.Row():
|
| 965 |
gr.Markdown("### ⚙️ Tracking Parameters")
|
| 966 |
with gr.Row():
|
| 967 |
+
# 添加模式选择器
|
| 968 |
+
with gr.Column(scale=1):
|
| 969 |
+
processing_mode = gr.Radio(
|
| 970 |
+
choices=["offline", "online"],
|
| 971 |
+
value="offline",
|
| 972 |
+
label="Processing Mode",
|
| 973 |
+
info="Offline: default mode | Online: Sliding Window Mode"
|
| 974 |
+
)
|
| 975 |
+
with gr.Column(scale=1):
|
| 976 |
+
grid_size = gr.Slider(
|
| 977 |
+
minimum=10, maximum=100, step=10, value=50,
|
| 978 |
+
label="Grid Size", info="Tracking detail level"
|
| 979 |
+
)
|
| 980 |
+
with gr.Column(scale=1):
|
| 981 |
+
vo_points = gr.Slider(
|
| 982 |
+
minimum=100, maximum=2000, step=50, value=756,
|
| 983 |
+
label="VO Points", info="Motion accuracy"
|
| 984 |
+
)
|
| 985 |
+
with gr.Column(scale=1):
|
| 986 |
+
fps = gr.Slider(
|
| 987 |
+
minimum=1, maximum=20, step=1, value=3,
|
| 988 |
+
label="FPS", info="Processing speed"
|
| 989 |
+
)
|
| 990 |
|
| 991 |
# Advanced Point Selection with SAM - Collapsed by default
|
| 992 |
with gr.Row():
|
|
|
|
| 1112 |
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
| 1113 |
)
|
| 1114 |
|
| 1115 |
+
processing_mode.change(
|
| 1116 |
+
fn=update_status_indicator,
|
| 1117 |
+
inputs=[processing_mode],
|
| 1118 |
+
outputs=[status_indicator]
|
| 1119 |
+
)
|
| 1120 |
+
|
| 1121 |
interactive_frame.select(
|
| 1122 |
fn=select_point,
|
| 1123 |
inputs=[original_image_state, selected_points, point_type],
|
|
|
|
| 1132 |
|
| 1133 |
clear_all_btn.click(
|
| 1134 |
fn=clear_all_with_download,
|
| 1135 |
+
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, processing_mode, tracking_video_download, html_download]
|
| 1136 |
)
|
| 1137 |
|
| 1138 |
launch_btn.click(
|
| 1139 |
fn=launch_viz,
|
| 1140 |
+
inputs=[grid_size, vo_points, fps, original_image_state, processing_mode],
|
| 1141 |
outputs=[viz_html, tracking_video_download, html_download]
|
| 1142 |
)
|
| 1143 |
|