File size: 22,706 Bytes
dcceb35
bf11931
 
 
 
 
 
 
 
 
 
5a56247
6f5c5a6
 
 
 
 
 
 
 
 
 
b08da78
 
da74499
 
 
 
 
 
b08da78
 
 
 
 
 
da74499
b08da78
 
 
6f5c5a6
 
 
 
 
 
bf11931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5c5a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf11931
 
 
6f5c5a6
bf11931
 
6f5c5a6
 
 
da74499
 
 
 
 
 
 
 
 
 
 
 
6f5c5a6
 
 
 
bf11931
6f5c5a6
 
 
bf11931
6f5c5a6
 
 
 
 
bf11931
6f5c5a6
 
 
 
bf11931
 
 
 
 
 
 
 
 
 
 
6f5c5a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf11931
6f5c5a6
 
 
 
 
 
bf11931
6f5c5a6
 
 
 
 
 
 
 
 
 
 
 
bf11931
6f5c5a6
bf11931
 
 
 
 
 
 
 
 
 
 
6f5c5a6
 
 
 
 
 
 
 
 
 
bf11931
6f5c5a6
 
bf11931
6f5c5a6
 
 
 
 
 
bf11931
6f5c5a6
 
 
 
 
 
bf11931
 
 
 
 
 
 
 
 
 
 
6f5c5a6
 
 
b08da78
 
 
 
da74499
 
 
 
6f5c5a6
b08da78
 
da74499
 
b08da78
da74499
 
 
 
 
 
 
 
 
6f5c5a6
b08da78
 
 
6f5c5a6
 
bf11931
6f5c5a6
b08da78
6f5c5a6
 
 
 
 
 
 
 
 
 
 
b08da78
 
 
 
 
 
6f5c5a6
 
 
 
 
 
bf11931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b08da78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5c5a6
 
 
 
b08da78
 
 
 
 
 
bf11931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b08da78
6f5c5a6
 
 
 
 
bf11931
 
 
 
 
 
 
 
 
6f5c5a6
 
 
 
 
 
 
b08da78
 
 
6f5c5a6
bf11931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b08da78
bf11931
b08da78
 
 
 
 
 
bf11931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dcceb35
bf11931
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
import gradio as gr
import os
import json
import numpy as np
import cv2
import base64
from typing import List, Tuple

# Backend Space URL - replace with your actual backend space URL
BACKEND_SPACE_URL = "Yuxihenry/SpatialTrackerV2_Backend"  # Replace with actual backend space URL
# HF_TOKEN must be set in the environment when the backend Space is private.
hf_token = os.getenv("HF_TOKEN")  # Replace with your actual Hugging Face token

# Flag to track if backend is available
BACKEND_AVAILABLE = False
# Handle returned by gr.load() once initialize_backend() succeeds; None until then.
backend_api = None

def initialize_backend():
    """Connect to the remote backend Space and record its availability.

    Mutates the module globals ``backend_api`` and ``BACKEND_AVAILABLE`` as a
    side effect.

    Returns:
        bool: True when the backend exposes the expected API entry point.
    """
    global backend_api, BACKEND_AVAILABLE
    try:
        print(f"Attempting to connect to backend: {BACKEND_SPACE_URL}")
        backend_api = gr.load(f"spaces/{BACKEND_SPACE_URL}", token=hf_token)

        # Dump what gr.load() actually returned, for debugging.
        print(f"🔧 Backend API object type: {type(backend_api)}")
        print(f"🔧 Backend API object attributes: {dir(backend_api)}")

        # gr.load() typically exposes the Interface's fn function directly,
        # so probe for the main function name rather than any wrapper names.
        BACKEND_AVAILABLE = hasattr(backend_api, 'process_video_with_points')
        if BACKEND_AVAILABLE:
            print("✅ Backend connection successful!")
            print("✅ Backend API methods are available")
        else:
            print("❌ Backend API methods not found")
            print(f"🔧 Available methods: {[attr for attr in dir(backend_api) if not attr.startswith('_')]}")
        return BACKEND_AVAILABLE

    except Exception as e:
        print(f"❌ Backend connection failed: {e}")
        print("⚠️  Running in standalone mode (some features may be limited)")
        BACKEND_AVAILABLE = False
        return False

def numpy_to_base64(arr):
    """Serialize a numpy array's raw buffer as a base64 ASCII string.

    Note: shape and dtype are NOT encoded; the caller must carry them
    separately to reconstruct the array with base64_to_numpy().
    """
    raw_bytes = arr.tobytes()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')

def base64_to_numpy(b64_str, shape, dtype):
    """Rebuild a numpy array from a base64 payload plus its shape and dtype."""
    raw = base64.b64decode(b64_str)
    flat = np.frombuffer(raw, dtype=dtype)
    return flat.reshape(shape)

def base64_to_image(b64_str):
    """Decode a base64-encoded image into an RGB numpy array.

    Returns None for empty input or when decoding fails for any reason.
    """
    if not b64_str:
        return None
    try:
        # base64 -> raw bytes -> flat uint8 buffer for OpenCV
        buffer = np.frombuffer(base64.b64decode(b64_str), np.uint8)
        decoded = cv2.imdecode(buffer, cv2.IMREAD_COLOR)
        # OpenCV decodes to BGR; convert to RGB for the Gradio image widget.
        # (If imdecode returned None, cvtColor raises and we fall through.)
        return cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None

def get_video_name(video_path):
    """Return the video's base filename with its extension stripped."""
    base = os.path.basename(video_path)
    stem, _ext = os.path.splitext(base)
    return stem

def extract_first_frame(video_path):
    """Extract the first frame of a video as an RGB numpy array.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        The first frame as an RGB numpy array, or None if the video
        cannot be opened/read or any error occurs.
    """
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            ret, frame = cap.read()
        finally:
            # Release unconditionally: the original only released on the
            # success path, leaking the capture handle if read() raised.
            cap.release()

        if ret:
            # Convert BGR (OpenCV default) to RGB for display
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return None
    except Exception as e:
        print(f"Error extracting first frame: {e}")
        return None

def handle_video_upload(video):
    """Handle video upload and extract first frame.

    Tries the backend Space first, then falls back to local frame extraction.

    Args:
        video: Path to the uploaded video file, or None.

    Returns:
        Tuple of (original_image_state JSON string, first-frame RGB image,
        selected-points list, grid_size, vo_points, fps). Returns the
        defaults (None, None, [], 50, 756, 3) when no video is given or
        an unexpected error occurs.
    """
    if video is None:
        return None, None, [], 50, 756, 3
    
    try:
        if BACKEND_AVAILABLE and backend_api:
            # Try to use backend API
            try:
                # Use the main function directly since gr.load() exposes the Interface's fn
                result = backend_api.process_video_with_points(video, [], 50, 756, 3)
                # Parse the result to extract what we need
                if isinstance(result, dict) and result.get("success"):
                    # For now, just extract the first frame locally
                    display_image = extract_first_frame(video)
                    original_image_state = json.dumps({"video_path": video, "frame": "backend_processing"})
                    return original_image_state, display_image, [], 50, 756, 3
                else:
                    print("Backend processing failed, using local fallback")
                    # Fallback to local processing
                    pass
            except Exception as e:
                print(f"Backend API call failed: {e}")
                # Fallback to local processing
                pass
        
        # Fallback: local processing
        print("Using local video processing...")
        display_image = extract_first_frame(video)
        
        # Create a simple state representation
        original_image_state = json.dumps({
            "video_path": video,
            "frame": "local_processing"
        })
        
        # Default settings
        grid_size_val, vo_points_val, fps_val = 50, 756, 3
        
        return original_image_state, display_image, [], grid_size_val, vo_points_val, fps_val
        
    except Exception as e:
        print(f"Error in handle_video_upload: {e}")
        return None, None, [], 50, 756, 3

def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
    """Handle point selection for SAM.

    Args:
        original_img: JSON state string produced by handle_video_upload.
        sel_pix: Previously selected points as (x, y, point_type) tuples.
        point_type: 'positive_point' or 'negative_point'.
        evt: Gradio select event carrying the clicked (x, y) coordinates.

    Returns:
        Tuple of (updated display image or None, updated points list).
    """
    if original_img is None:
        return None, []
    
    try:
        if BACKEND_AVAILABLE and backend_api:
            # Try to use backend API
            try:
                display_image_b64, new_sel_pix = backend_api.select_point_api(
                    original_img, sel_pix, point_type, evt.index[0], evt.index[1]
                )
                display_image = base64_to_image(display_image_b64)
                return display_image, new_sel_pix
            except Exception as e:
                print(f"Backend API call failed: {e}")
                # Fallback to local processing
                pass
        
        # Fallback: local processing
        print("Using local point selection...")
        
        # Parse original image state
        try:
            state_data = json.loads(original_img)
            video_path = state_data.get("video_path")
        except (TypeError, ValueError):
            # Was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt. json.JSONDecodeError is a ValueError
            # subclass; TypeError covers non-string state.
            video_path = None
        
        if video_path:
            # Re-extract frame and add point
            display_image = extract_first_frame(video_path)
            if display_image is not None:
                # Add point to the image (simple visualization)
                x, y = evt.index[0], evt.index[1]
                # Green for positive (foreground), red for negative (background)
                color = (0, 255, 0) if point_type == 'positive_point' else (255, 0, 0)
                cv2.circle(display_image, (x, y), 5, color, -1)
                
                # Update selected points
                new_sel_pix = sel_pix + [(x, y, point_type)]
                return display_image, new_sel_pix
        
        return None, sel_pix
        
    except Exception as e:
        print(f"Error in select_point: {e}")
        return None, sel_pix

def reset_points(original_img: str, sel_pix):
    """Reset all points and clear the mask.

    Args:
        original_img: JSON state string produced by handle_video_upload.
        sel_pix: Current list of selected points (forwarded to the backend).

    Returns:
        Tuple of (clean first-frame image or None, empty points list).
    """
    if original_img is None:
        return None, []
    
    try:
        if BACKEND_AVAILABLE and backend_api:
            # Try to use backend API
            try:
                display_image_b64, new_sel_pix = backend_api.reset_points_api(original_img, sel_pix)
                display_image = base64_to_image(display_image_b64)
                return display_image, new_sel_pix
            except Exception as e:
                print(f"Backend API call failed: {e}")
                # Fallback to local processing
                pass
        
        # Fallback: local processing
        print("Using local point reset...")
        
        # Parse original image state
        try:
            state_data = json.loads(original_img)
            video_path = state_data.get("video_path")
        except (TypeError, ValueError):
            # Was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt. json.JSONDecodeError is a ValueError
            # subclass; TypeError covers non-string state.
            video_path = None
        
        if video_path:
            # Re-extract frame without points
            display_image = extract_first_frame(video_path)
            return display_image, []
        
        return None, []
        
    except Exception as e:
        print(f"Error in reset_points: {e}")
        return None, []

def launch_viz(grid_size, vo_points, fps, original_image_state):
    """Run the tracker via the backend and return visualization outputs.

    Args:
        grid_size: Tracking grid size from the UI slider.
        vo_points: Number of points used for solving camera pose.
        fps: Output video frame rate.
        original_image_state: JSON state string from handle_video_upload.

    Returns:
        (viz_html, track_video_path) on backend success; (error HTML, None)
        when the backend is unavailable or fails; (None, None) when there is
        no state or an unexpected error occurs.
    """
    if original_image_state is None:
        return None, None
    
    try:
        if BACKEND_AVAILABLE and backend_api:
            # Try to use backend API
            try:
                print(f"🔧 Calling backend API with parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
                print(f"🔧 Original image state type: {type(original_image_state)}")
                print(f"🔧 Original image state preview: {str(original_image_state)[:100]}...")
                
                # Use the main function with points from the state
                # For now, we'll use empty points since we're in local mode
                # NOTE(review): video is passed as None here — confirm the
                # backend recovers the video from session state, otherwise
                # this call can never succeed.
                result = backend_api.process_video_with_points(
                    None, [], grid_size, vo_points, fps
                )
                
                print(f"✅ Backend API call successful!")
                print(f"🔧 Result type: {type(result)}")
                print(f"🔧 Result: {result}")
                
                # Parse the result
                if isinstance(result, dict) and result.get("success"):
                    viz_html = result.get("viz_html_path", "")
                    track_video_path = result.get("track_video_path", "")
                    return viz_html, track_video_path
                else:
                    print("Backend processing failed, showing error message")
                    # Fallback to error message
                    pass
            except Exception as e:
                print(f"❌ Backend API call failed: {e}")
                print(f"🔧 Error type: {type(e)}")
                print(f"🔧 Error details: {str(e)}")
                # Fallback to local processing
                pass
        
        # Fallback: show message that backend is required
        error_message = f"""
        <div style='border: 3px solid #ff6b6b; border-radius: 10px; padding: 20px; background-color: #fff5f5;'>
            <h3 style='color: #d63031; margin-bottom: 15px;'>⚠️ Backend Connection Required</h3>
            <p style='color: #2d3436; line-height: 1.6;'>
                The tracking and visualization features require a connection to the backend Space. 
                Please ensure:
            </p>
            <ul style='color: #2d3436; line-height: 1.6;'>
                <li>The backend Space is deployed and running</li>
                <li>The BACKEND_SPACE_URL is correctly configured</li>
                <li>You have proper access permissions to the backend Space</li>
            </ul>
            <div style='background-color: #f8f9fa; border-radius: 5px; padding: 10px; margin-top: 10px;'>
                <p style='color: #2d3436; font-weight: bold; margin: 0 0 5px 0;'>Debug Information:</p>
                <p style='color: #666; font-size: 12px; margin: 0;'>Backend Available: {BACKEND_AVAILABLE}</p>
                <p style='color: #666; font-size: 12px; margin: 0;'>Backend API Object: {backend_api is not None}</p>
                <p style='color: #666; font-size: 12px; margin: 0;'>Backend URL: {BACKEND_SPACE_URL}</p>
            </div>
            <p style='color: #2d3436; font-weight: bold; margin-top: 15px;'>
                Current Status: Backend unavailable - Running in limited mode
            </p>
        </div>
        """
        return error_message, None
        
    except Exception as e:
        print(f"Error in launch_viz: {e}")
        return None, None

def clear_all():
    """Reset the UI: clears video input, display image, and selected points."""
    # Outputs map to (video_input, display_image, selected_points).
    return (None, None, [])

def update_tracker_model(vo_points):
    """No-op hook for vo_points slider changes; the slider has no output."""
    return None

# Function to handle both manual upload and example selection
def handle_video_change(video):
    """Handle video change from both manual upload and example selection.

    Returns the same 6-tuple as handle_video_upload: (state, display image,
    selected points, grid_size, vo_points, fps).
    """
    if video is None:
        # Nothing selected: clear the frame/points and restore slider defaults.
        return None, None, [], 50, 756, 3

    # Delegate to the upload handler, which extracts the first frame and
    # builds the state tuple; its result is forwarded unchanged.
    return handle_video_upload(video)

def test_backend_connection():
    """Probe the loaded backend API object for the expected entry point.

    Clears the BACKEND_AVAILABLE flag when the probe fails.

    Returns:
        bool: True when backend_api exposes 'upload_video_api'.
    """
    global BACKEND_AVAILABLE
    if not backend_api:
        return False
    
    try:
        print("Testing backend connection with a simple call...")
        # Checking for the method is enough to confirm gr.load() wired the
        # API object up; no real request is made here.
        if not hasattr(backend_api, 'upload_video_api'):
            print("❌ Backend API methods not found")
            BACKEND_AVAILABLE = False
            return False
        print("✅ Backend API methods are available")
        return True
    except Exception as e:
        print(f"❌ Backend connection test failed: {e}")
        BACKEND_AVAILABLE = False
        return False

def test_backend_api():
    """Report which expected backend API functions are reachable.

    Returns:
        bool: True when the probe ran (regardless of per-method results),
        False when the backend is unavailable or probing raised.
    """
    if not BACKEND_AVAILABLE or not backend_api:
        print("❌ Backend not available for testing")
        return False
    
    try:
        print("🧪 Testing backend API functions...")
        
        # Probe each known endpoint and report its availability individually.
        expected = ('upload_video_api', 'select_point_api', 'reset_points_api', 'run_tracker_api')
        for name in expected:
            if hasattr(backend_api, name):
                print(f"✅ {name} is available")
            else:
                print(f"❌ {name} is not available")
        
        return True
    except Exception as e:
        print(f"❌ Backend API test failed: {e}")
        return False

# NOTE: the three calls below run at import time, so loading this module
# performs network access to the backend Space before the UI is built.

# Initialize backend connection
print("🔧 Initializing backend connection...")
initialize_backend()

# Test the connection
test_backend_connection()

# Test specific API functions
test_backend_api()

# Build UI
with gr.Blocks(css="""
    #advanced_settings .wrap {
        font-size: 14px !important;
    }
    #advanced_settings .gr-slider {
        font-size: 13px !important;
    }
    #advanced_settings .gr-slider .gr-label {
        font-size: 13px !important;
        margin-bottom: 5px !important;
    }
    #advanced_settings .gr-slider .gr-info {
        font-size: 12px !important;
    }
    #point_label_radio .gr-radio-group {
        flex-direction: row !important;
        gap: 15px !important;
    }
    #point_label_radio .gr-radio-group label {
        margin-right: 0 !important;
        margin-bottom: 0 !important;
    }
    /* Style for example videos label */
    .gr-examples .gr-label {
        font-weight: bold !important;
        font-size: 16px !important;
    }
    /* Simple horizontal scroll for examples */
    .gr-examples .gr-table-wrapper {
        overflow-x: auto !important;
        overflow-y: hidden !important;
    }
    .gr-examples .gr-table {
        display: flex !important;
        flex-wrap: nowrap !important;
        min-width: max-content !important;
    }
    .gr-examples .gr-table tbody {
        display: flex !important;
        flex-direction: row !important;
        flex-wrap: nowrap !important;
    }
    .gr-examples .gr-table tbody tr {
        display: flex !important;
        flex-direction: column !important;
        min-width: 150px !important;
        margin-right: 10px !important;
    }
    .gr-examples .gr-table tbody tr td {
        text-align: center !important;
        padding: 5px !important;
    }
""") as demo:
    # Initialize states inside Blocks
    selected_points = gr.State([])
    original_image_state = gr.State()  # Store original image in state
    
    with gr.Row():
        # Show backend status with more detailed information
        status_color = "#28a745" if BACKEND_AVAILABLE else "#dc3545"
        status_text = "Connected" if BACKEND_AVAILABLE else "Disconnected"
        status_icon = "✅" if BACKEND_AVAILABLE else "❌"
        
        gr.Markdown(f"""
        # ✨ SpaTrackV2 Frontend (Client)
        <div style='background-color: #e6f3ff; padding: 20px; border-radius: 10px; margin: 10px 0;'>
        <h2 style='color: #0066cc; margin-bottom: 15px;'>Instructions:</h2>
        <ol style='font-size: 20px; line-height: 1.6;'>
            <li>🎬 Upload a video or select from examples below</li>
            <li>🎯 Select positive points (green) and negative points (red) on the first frame</li>
            <li>⚡ Click 'Run Tracker and Visualize' when done</li>
            <li>🔍 Iterative 3D result will be shown in the visualization</li>
        </ol>
        <div style='background-color: {status_color}20; border: 2px solid {status_color}; border-radius: 8px; padding: 10px; margin-top: 15px;'>
            <p style='font-size: 18px; color: {status_color}; margin: 0;'>
                {status_icon} Backend Status: {status_text}
            </p>
            <p style='font-size: 14px; color: #666; margin: 5px 0 0 0;'>
                {BACKEND_SPACE_URL}
            </p>
            <p style='font-size: 12px; color: #888; margin: 5px 0 0 0;'>
                {'API methods available' if BACKEND_AVAILABLE else 'Connection failed - using local mode'}
            </p>
        </div>
        </div>
        """)

    with gr.Row():
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video", format="mp4", height=300)
            
            # Move Interactive Frame and 2D Tracking under video upload
            with gr.Row():
                display_image = gr.Image(type="numpy", label="📸 Interactive Frame", height=250)
                track_video = gr.Video(label="🎯 2D Tracking Result", height=250)
            
            with gr.Row():
                fg_bg_radio = gr.Radio(choices=['positive_point', 'negative_point'], 
                                       label='Point label', 
                                       value='positive_point',
                                       elem_id="point_label_radio")
                reset_button = gr.Button("Reset points")
                clear_button = gr.Button("Clear All", variant="secondary")
            
            with gr.Accordion("⚙️ Advanced Settings", open=True, elem_id="advanced_settings"):
                grid_size = gr.Slider(minimum=10, maximum=100, value=50, step=1, 
                                      label="Grid Size", info="Size of the tracking grid")
                vo_points = gr.Slider(minimum=256, maximum=4096, value=756, step=50,
                                      label="VO Points", info="Number of points for solving camera pose")
                fps_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1,
                                      label="FPS", info="FPS of the output video")
            
            viz_button = gr.Button("🚀 Run Tracker and Visualize", variant="primary", size="lg")

        with gr.Column(scale=2):
            # Add example videos using gr.Examples
            examples_component = gr.Examples(
                examples=[
                    "examples/kiss.mp4",
                    "examples/backpack.mp4",
                    "examples/pillow.mp4",
                    "examples/hockey.mp4",
                    "examples/drifting.mp4",
                    "examples/ken_block_0.mp4",
                    "examples/ball.mp4",
                    "examples/kitchen.mp4", 
                    "examples/ego_teaser.mp4",
                    "examples/ego_kc1.mp4",
                    "examples/vertical_place.mp4",
                    "examples/robot_unitree.mp4",
                    "examples/droid_robot.mp4",
                    "examples/robot_2.mp4",
                    "examples/cinema_0.mp4",
                ],
                inputs=[video_input],
                label="📁 Example Videos",
                examples_per_page=20  # Show all examples on one page to enable scrolling
            )
            
            # Initialize with a placeholder interface instead of static file
            viz_iframe = gr.HTML("""
                                <div style='border: 3px solid #667eea; border-radius: 10px; overflow: hidden; box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3); background: #f8f9fa; display: flex; align-items: center; justify-content: center; height: 950px;'>
                                    <div style='text-align: center; color: #666;'>
                                        <h3 style='margin-bottom: 20px; color: #667eea;'>🎮 Interactive 3D Tracking</h3>
                                        <p style='font-size: 16px; margin-bottom: 10px;'>Upload a video and select points to start tracking</p>
                                        <p style='font-size: 14px; color: #999;'>Powered by SpaTrackV2</p>
                                    </div>
                                </div>
                                """)
            
            # Simple description below the visualization
            gr.HTML("""
            <div style='text-align: center; margin-top: 15px; color: #666; font-size: 14px;'>
                🎮 Interactive 3D visualization adapted from <a href="https://tapip3d.github.io/" target="_blank" style="color: #667eea;">TAPIP3D</a>
            </div>
            """)

    # Bind events
    video_input.change(
        handle_video_change,
        inputs=[video_input], 
        outputs=[original_image_state, display_image, selected_points, grid_size, vo_points, fps_slider]
    )
    
    reset_button.click(reset_points,
                     inputs=[original_image_state, selected_points],
                     outputs=[display_image, selected_points])
    
    clear_button.click(clear_all,
                      outputs=[video_input, display_image, selected_points])
    
    display_image.select(select_point,
                      inputs=[original_image_state, selected_points, fg_bg_radio],
                      outputs=[display_image, selected_points])

    # Update tracker model when vo_points changes
    vo_points.change(update_tracker_model,
                    inputs=[vo_points],
                    outputs=[])
    
    viz_button.click(launch_viz,
                    inputs=[grid_size, vo_points, fps_slider, original_image_state],
                    outputs=[viz_iframe, track_video],
                    )

# Launch the demo
if __name__ == "__main__":
    demo.launch()