Spaces:
Running
on
Zero
Running
on
Zero
| import gradio as gr | |
| import os | |
| import json | |
| import numpy as np | |
| import cv2 | |
| import base64 | |
| from typing import List, Tuple | |
# Backend Space URL - replace with your actual backend space URL
BACKEND_SPACE_URL = "Yuxihenry/SpatialTrackerV2_Backend" # Replace with actual backend space URL
# Hugging Face access token read from the environment (needed for private Spaces)
hf_token = os.getenv("HF_TOKEN") # Replace with your actual Hugging Face token
# Flag to track if backend is available; set by initialize_backend() at import time
BACKEND_AVAILABLE = False
# Handle to the gr.load()-ed backend Space; stays None until initialize_backend() succeeds
backend_api = None
def initialize_backend():
    """Connect to the backend Space via gr.load() and probe for its API.

    Sets the module-level ``backend_api`` handle and ``BACKEND_AVAILABLE``
    flag as side effects. Returns True on success, False otherwise.
    """
    global backend_api, BACKEND_AVAILABLE
    try:
        print(f"Attempting to connect to backend: {BACKEND_SPACE_URL}")
        backend_api = gr.load(f"spaces/{BACKEND_SPACE_URL}", token=hf_token)
        # Log what gr.load() actually gave us for debugging
        print(f"🔧 Backend API object type: {type(backend_api)}")
        print(f"🔧 Backend API object attributes: {dir(backend_api)}")
        # gr.load() typically exposes the Interface's fn directly, so we probe
        # for the main function name rather than any wrapper names.
        if not hasattr(backend_api, 'process_video_with_points'):
            print("❌ Backend API methods not found")
            print(f"🔧 Available methods: {[attr for attr in dir(backend_api) if not attr.startswith('_')]}")
            BACKEND_AVAILABLE = False
            return False
        BACKEND_AVAILABLE = True
        print("✅ Backend connection successful!")
        print("✅ Backend API methods are available")
        return True
    except Exception as e:
        print(f"❌ Backend connection failed: {e}")
        print("⚠️ Running in standalone mode (some features may be limited)")
        BACKEND_AVAILABLE = False
        return False
def numpy_to_base64(arr):
    """Serialize a numpy array's raw buffer to a base64 ASCII string."""
    raw = arr.tobytes()
    encoded = base64.b64encode(raw)
    return encoded.decode('utf-8')
def base64_to_numpy(b64_str, shape, dtype):
    """Rebuild a numpy array from a base64 string given its shape and dtype."""
    raw = base64.b64decode(b64_str)
    flat = np.frombuffer(raw, dtype=dtype)
    return flat.reshape(shape)
def base64_to_image(b64_str):
    """Decode a base64-encoded image into an RGB numpy array.

    Args:
        b64_str: Base64 string of an encoded image (e.g. PNG/JPEG bytes).

    Returns:
        RGB uint8 image array, or None for empty input or decode failure.
    """
    if not b64_str:
        return None
    try:
        # Decode base64 to raw bytes
        img_bytes = base64.b64decode(b64_str)
        # Wrap bytes in a numpy buffer so OpenCV can decode it
        nparr = np.frombuffer(img_bytes, np.uint8)
        # imdecode returns None (instead of raising) on malformed image data;
        # the original relied on cvtColor(None) raising — check explicitly.
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        if img is None:
            print("Error converting base64 to image: invalid image data")
            return None
        # OpenCV decodes as BGR; convert to RGB for display
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print(f"Error converting base64 to image: {e}")
        return None
def get_video_name(video_path):
    """Return the file name of *video_path* without its extension."""
    base = os.path.basename(video_path)
    stem, _ext = os.path.splitext(base)
    return stem
def extract_first_frame(video_path):
    """Read the first frame of a video file and return it as an RGB array.

    Args:
        video_path: Path to a video file readable by OpenCV.

    Returns:
        RGB uint8 frame array, or None if the video cannot be opened/read.
    """
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            ret, frame = cap.read()
        finally:
            # Release the capture handle even if read() raises
            # (the original leaked it on that path).
            cap.release()
        if ret:
            # OpenCV reads BGR; convert to RGB for Gradio display
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        return None
    except Exception as e:
        print(f"Error extracting first frame: {e}")
        return None
def handle_video_upload(video):
    """Handle a newly uploaded video: extract its first frame, reset points.

    Returns a 6-tuple matching the bound Gradio outputs:
    (original_image_state, display_image, selected_points,
     grid_size, vo_points, fps).
    """
    defaults = (50, 756, 3)
    if video is None:
        return (None, None, [], *defaults)
    try:
        if BACKEND_AVAILABLE and backend_api:
            # Prefer the backend Space when it is reachable
            try:
                # gr.load() exposes the Interface's fn, so call it directly
                result = backend_api.process_video_with_points(video, [], *defaults)
                if isinstance(result, dict) and result.get("success"):
                    # The preview frame is still extracted locally for now
                    preview = extract_first_frame(video)
                    state = json.dumps({"video_path": video, "frame": "backend_processing"})
                    return (state, preview, [], *defaults)
                print("Backend processing failed, using local fallback")
            except Exception as e:
                print(f"Backend API call failed: {e}")
                # Fall through to local processing
        # Fallback: local processing
        print("Using local video processing...")
        preview = extract_first_frame(video)
        # Minimal JSON state so later callbacks can re-find the video
        state = json.dumps({"video_path": video, "frame": "local_processing"})
        return (state, preview, [], *defaults)
    except Exception as e:
        print(f"Error in handle_video_upload: {e}")
        return (None, None, [], *defaults)
def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
    """Handle point selection for SAM.

    Args:
        original_img: JSON state string produced by handle_video_upload, or None.
        sel_pix: Current list of (x, y, point_type) tuples.
        point_type: 'positive_point' or 'negative_point'.
        evt: Gradio select event; evt.index holds the clicked (x, y).

    Returns:
        (display_image, updated_points) for the bound Gradio outputs.
    """
    if original_img is None:
        return None, []
    try:
        if BACKEND_AVAILABLE and backend_api:
            # Try to use backend API
            try:
                display_image_b64, new_sel_pix = backend_api.select_point_api(
                    original_img, sel_pix, point_type, evt.index[0], evt.index[1]
                )
                display_image = base64_to_image(display_image_b64)
                return display_image, new_sel_pix
            except Exception as e:
                print(f"Backend API call failed: {e}")
                # Fall through to local processing
        # Fallback: local processing
        print("Using local point selection...")
        # Parse original image state; tolerate malformed / non-string state.
        # (Was a bare `except:` — narrowed to the exceptions json.loads and
        # .get can actually raise here, preserving the fallback behavior.)
        try:
            state_data = json.loads(original_img)
            video_path = state_data.get("video_path")
        except (json.JSONDecodeError, TypeError, AttributeError):
            video_path = None
        if video_path:
            # Re-extract frame and draw the clicked point on it
            display_image = extract_first_frame(video_path)
            if display_image is not None:
                # Simple visualization: green = positive, red = negative
                x, y = evt.index[0], evt.index[1]
                color = (0, 255, 0) if point_type == 'positive_point' else (255, 0, 0)
                cv2.circle(display_image, (x, y), 5, color, -1)
                # Append the new point to the selection list
                new_sel_pix = sel_pix + [(x, y, point_type)]
                return display_image, new_sel_pix
        return None, sel_pix
    except Exception as e:
        print(f"Error in select_point: {e}")
        return None, sel_pix
def reset_points(original_img: str, sel_pix):
    """Reset all points and clear the mask.

    Args:
        original_img: JSON state string produced by handle_video_upload, or None.
        sel_pix: Current list of selected points (forwarded to the backend).

    Returns:
        (display_image, []) — the clean first frame and an empty point list.
    """
    if original_img is None:
        return None, []
    try:
        if BACKEND_AVAILABLE and backend_api:
            # Try to use backend API
            try:
                display_image_b64, new_sel_pix = backend_api.reset_points_api(original_img, sel_pix)
                display_image = base64_to_image(display_image_b64)
                return display_image, new_sel_pix
            except Exception as e:
                print(f"Backend API call failed: {e}")
                # Fall through to local processing
        # Fallback: local processing
        print("Using local point reset...")
        # Parse original image state; tolerate malformed / non-string state.
        # (Was a bare `except:` — narrowed to the exceptions this parse can raise.)
        try:
            state_data = json.loads(original_img)
            video_path = state_data.get("video_path")
        except (json.JSONDecodeError, TypeError, AttributeError):
            video_path = None
        if video_path:
            # Re-extract the clean first frame without any drawn points
            display_image = extract_first_frame(video_path)
            return display_image, []
        return None, []
    except Exception as e:
        print(f"Error in reset_points: {e}")
        return None, []
def launch_viz(grid_size, vo_points, fps, original_image_state):
    """Launch visualization with user-specific temp directory.

    Args:
        grid_size: Tracking grid size from the UI slider.
        vo_points: Number of points for camera-pose solving.
        fps: Output video FPS.
        original_image_state: JSON state string from handle_video_upload, or None.

    Returns:
        (viz_html, track_video_path) on backend success; (error_html, None)
        when the backend is unavailable; (None, None) on unexpected errors.
    """
    if original_image_state is None:
        return None, None
    try:
        if BACKEND_AVAILABLE and backend_api:
            # Try to use backend API
            try:
                print(f"🔧 Calling backend API with parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
                print(f"🔧 Original image state type: {type(original_image_state)}")
                print(f"🔧 Original image state preview: {str(original_image_state)[:100]}...")
                # Use the main function with points from the state
                # For now, we'll use empty points since we're in local mode
                # NOTE(review): video arg is None and points are [] here —
                # presumably the backend resolves them from earlier calls;
                # confirm against the backend Space's API.
                result = backend_api.process_video_with_points(
                    None, [], grid_size, vo_points, fps
                )
                print(f"✅ Backend API call successful!")
                print(f"🔧 Result type: {type(result)}")
                print(f"🔧 Result: {result}")
                # Parse the result: a dict with success flag plus output paths
                if isinstance(result, dict) and result.get("success"):
                    viz_html = result.get("viz_html_path", "")
                    track_video_path = result.get("track_video_path", "")
                    return viz_html, track_video_path
                else:
                    print("Backend processing failed, showing error message")
                    # Fall through to the error panel below
                    pass
            except Exception as e:
                print(f"❌ Backend API call failed: {e}")
                print(f"🔧 Error type: {type(e)}")
                print(f"🔧 Error details: {str(e)}")
                # Fall through to the error panel below
                pass
        # Fallback: show an HTML panel explaining that the backend is required
        error_message = f"""
        <div style='border: 3px solid #ff6b6b; border-radius: 10px; padding: 20px; background-color: #fff5f5;'>
            <h3 style='color: #d63031; margin-bottom: 15px;'>⚠️ Backend Connection Required</h3>
            <p style='color: #2d3436; line-height: 1.6;'>
                The tracking and visualization features require a connection to the backend Space.
                Please ensure:
            </p>
            <ul style='color: #2d3436; line-height: 1.6;'>
                <li>The backend Space is deployed and running</li>
                <li>The BACKEND_SPACE_URL is correctly configured</li>
                <li>You have proper access permissions to the backend Space</li>
            </ul>
            <div style='background-color: #f8f9fa; border-radius: 5px; padding: 10px; margin-top: 10px;'>
                <p style='color: #2d3436; font-weight: bold; margin: 0 0 5px 0;'>Debug Information:</p>
                <p style='color: #666; font-size: 12px; margin: 0;'>Backend Available: {BACKEND_AVAILABLE}</p>
                <p style='color: #666; font-size: 12px; margin: 0;'>Backend API Object: {backend_api is not None}</p>
                <p style='color: #666; font-size: 12px; margin: 0;'>Backend URL: {BACKEND_SPACE_URL}</p>
            </div>
            <p style='color: #2d3436; font-weight: bold; margin-top: 15px;'>
                Current Status: Backend unavailable - Running in limited mode
            </p>
        </div>
        """
        return error_message, None
    except Exception as e:
        print(f"Error in launch_viz: {e}")
        return None, None
def clear_all():
    """Reset the video input, preview image, and selected-points state."""
    cleared_video, cleared_image, cleared_points = None, None, []
    return cleared_video, cleared_image, cleared_points
def update_tracker_model(vo_points):
    """Placeholder hook fired when the VO-points slider changes.

    No output component is wired to this callback, so it is a no-op.
    """
    return None
# Function to handle both manual upload and example selection
def handle_video_change(video):
    """Handle video change from both manual upload and example selection."""
    if video is None:
        # Nothing selected: clear preview/points and restore default sliders
        return None, None, [], 50, 756, 3
    # Delegate to the shared upload handler (extracts the first frame);
    # it already returns the full 6-tuple for the bound outputs.
    return handle_video_upload(video)
def test_backend_connection():
    """Verify the loaded backend API exposes the expected entry point.

    Clears the module-level BACKEND_AVAILABLE flag when the probe fails.
    Returns True if the backend looks usable, False otherwise.
    """
    global BACKEND_AVAILABLE
    if not backend_api:
        return False
    try:
        # Try a simple capability check to test the connection
        print("Testing backend connection with a simple call...")
        # Probe the same entry point initialize_backend() validated.
        # (Was 'upload_video_api', which could wrongly flip
        # BACKEND_AVAILABLE off after a successful connection.)
        if hasattr(backend_api, 'process_video_with_points'):
            print("✅ Backend API methods are available")
            return True
        print("❌ Backend API methods not found")
        BACKEND_AVAILABLE = False
        return False
    except Exception as e:
        print(f"❌ Backend connection test failed: {e}")
        BACKEND_AVAILABLE = False
        return False
def test_backend_api():
    """Log which of the expected backend API functions are exposed.

    Diagnostic only: prints per-method availability and returns True when
    the probing itself completed, False when the backend is unavailable.
    """
    if not BACKEND_AVAILABLE or not backend_api:
        print("❌ Backend not available for testing")
        return False
    try:
        print("🧪 Testing backend API functions...")
        # Probe the main entry point (validated at connect time) plus the
        # per-action helper APIs used by the UI callbacks.
        methods_to_test = [
            'process_video_with_points',
            'upload_video_api',
            'select_point_api',
            'reset_points_api',
            'run_tracker_api',
        ]
        for method in methods_to_test:
            if hasattr(backend_api, method):
                print(f"✅ {method} is available")
            else:
                print(f"❌ {method} is not available")
        return True
    except Exception as e:
        print(f"❌ Backend API test failed: {e}")
        return False
# Initialize backend connection (runs at import time; sets backend_api
# and BACKEND_AVAILABLE module globals)
print("🔧 Initializing backend connection...")
initialize_backend()
# Test the connection (may clear BACKEND_AVAILABLE if the probe fails)
test_backend_connection()
# Test specific API functions (diagnostic logging only)
test_backend_api()
# Build UI — a single gr.Blocks app; custom CSS tweaks slider/radio/example
# styling (class names target Gradio's internal DOM and may need updating
# across Gradio versions).
with gr.Blocks(css="""
    #advanced_settings .wrap {
        font-size: 14px !important;
    }
    #advanced_settings .gr-slider {
        font-size: 13px !important;
    }
    #advanced_settings .gr-slider .gr-label {
        font-size: 13px !important;
        margin-bottom: 5px !important;
    }
    #advanced_settings .gr-slider .gr-info {
        font-size: 12px !important;
    }
    #point_label_radio .gr-radio-group {
        flex-direction: row !important;
        gap: 15px !important;
    }
    #point_label_radio .gr-radio-group label {
        margin-right: 0 !important;
        margin-bottom: 0 !important;
    }
    /* Style for example videos label */
    .gr-examples .gr-label {
        font-weight: bold !important;
        font-size: 16px !important;
    }
    /* Simple horizontal scroll for examples */
    .gr-examples .gr-table-wrapper {
        overflow-x: auto !important;
        overflow-y: hidden !important;
    }
    .gr-examples .gr-table {
        display: flex !important;
        flex-wrap: nowrap !important;
        min-width: max-content !important;
    }
    .gr-examples .gr-table tbody {
        display: flex !important;
        flex-direction: row !important;
        flex-wrap: nowrap !important;
    }
    .gr-examples .gr-table tbody tr {
        display: flex !important;
        flex-direction: column !important;
        min-width: 150px !important;
        margin-right: 10px !important;
    }
    .gr-examples .gr-table tbody tr td {
        text-align: center !important;
        padding: 5px !important;
    }
""") as demo:
    # Initialize states inside Blocks (per-session, not shared across users)
    selected_points = gr.State([])  # list of (x, y, point_type) tuples
    original_image_state = gr.State()  # JSON state string from handle_video_upload
    with gr.Row():
        # Show backend status with more detailed information; evaluated once
        # at build time, so the banner reflects startup connectivity only.
        status_color = "#28a745" if BACKEND_AVAILABLE else "#dc3545"
        status_text = "Connected" if BACKEND_AVAILABLE else "Disconnected"
        status_icon = "✅" if BACKEND_AVAILABLE else "❌"
        gr.Markdown(f"""
        # ✨ SpaTrackV2 Frontend (Client)
        <div style='background-color: #e6f3ff; padding: 20px; border-radius: 10px; margin: 10px 0;'>
        <h2 style='color: #0066cc; margin-bottom: 15px;'>Instructions:</h2>
        <ol style='font-size: 20px; line-height: 1.6;'>
        <li>🎬 Upload a video or select from examples below</li>
        <li>🎯 Select positive points (green) and negative points (red) on the first frame</li>
        <li>⚡ Click 'Run Tracker and Visualize' when done</li>
        <li>🔍 Iterative 3D result will be shown in the visualization</li>
        </ol>
        <div style='background-color: {status_color}20; border: 2px solid {status_color}; border-radius: 8px; padding: 10px; margin-top: 15px;'>
        <p style='font-size: 18px; color: {status_color}; margin: 0;'>
        {status_icon} Backend Status: {status_text}
        </p>
        <p style='font-size: 14px; color: #666; margin: 5px 0 0 0;'>
        {BACKEND_SPACE_URL}
        </p>
        <p style='font-size: 12px; color: #888; margin: 5px 0 0 0;'>
        {'API methods available' if BACKEND_AVAILABLE else 'Connection failed - using local mode'}
        </p>
        </div>
        </div>
        """)
    with gr.Row():
        # Left column: video input, interactive frame, point controls, settings
        with gr.Column(scale=1):
            video_input = gr.Video(label="Upload Video", format="mp4", height=300)
            # Move Interactive Frame and 2D Tracking under video upload
            with gr.Row():
                display_image = gr.Image(type="numpy", label="📸 Interactive Frame", height=250)
                track_video = gr.Video(label="🎯 2D Tracking Result", height=250)
            with gr.Row():
                # Point-label toggle plus reset/clear controls
                fg_bg_radio = gr.Radio(choices=['positive_point', 'negative_point'],
                                       label='Point label',
                                       value='positive_point',
                                       elem_id="point_label_radio")
                reset_button = gr.Button("Reset points")
                clear_button = gr.Button("Clear All", variant="secondary")
            with gr.Accordion("⚙️ Advanced Settings", open=True, elem_id="advanced_settings"):
                # Defaults mirror the (50, 756, 3) tuple used by the callbacks
                grid_size = gr.Slider(minimum=10, maximum=100, value=50, step=1,
                                      label="Grid Size", info="Size of the tracking grid")
                vo_points = gr.Slider(minimum=256, maximum=4096, value=756, step=50,
                                      label="VO Points", info="Number of points for solving camera pose")
                fps_slider = gr.Slider(minimum=1, maximum=10, value=3, step=1,
                                       label="FPS", info="FPS of the output video")
            viz_button = gr.Button("🚀 Run Tracker and Visualize", variant="primary", size="lg")
        # Right column: example videos and the 3D visualization panel
        with gr.Column(scale=2):
            # Add example videos using gr.Examples; selecting one fires
            # video_input.change just like a manual upload
            examples_component = gr.Examples(
                examples=[
                    "examples/kiss.mp4",
                    "examples/backpack.mp4",
                    "examples/pillow.mp4",
                    "examples/hockey.mp4",
                    "examples/drifting.mp4",
                    "examples/ken_block_0.mp4",
                    "examples/ball.mp4",
                    "examples/kitchen.mp4",
                    "examples/ego_teaser.mp4",
                    "examples/ego_kc1.mp4",
                    "examples/vertical_place.mp4",
                    "examples/robot_unitree.mp4",
                    "examples/droid_robot.mp4",
                    "examples/robot_2.mp4",
                    "examples/cinema_0.mp4",
                ],
                inputs=[video_input],
                label="📁 Example Videos",
                examples_per_page=20  # Show all examples on one page to enable scrolling
            )
            # Initialize with a placeholder interface instead of static file;
            # launch_viz replaces this HTML with the backend's visualization
            viz_iframe = gr.HTML("""
            <div style='border: 3px solid #667eea; border-radius: 10px; overflow: hidden; box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3); background: #f8f9fa; display: flex; align-items: center; justify-content: center; height: 950px;'>
                <div style='text-align: center; color: #666;'>
                    <h3 style='margin-bottom: 20px; color: #667eea;'>🎮 Interactive 3D Tracking</h3>
                    <p style='font-size: 16px; margin-bottom: 10px;'>Upload a video and select points to start tracking</p>
                    <p style='font-size: 14px; color: #999;'>Powered by SpaTrackV2</p>
                </div>
            </div>
            """)
            # Simple description below the visualization
            gr.HTML("""
            <div style='text-align: center; margin-top: 15px; color: #666; font-size: 14px;'>
                🎮 Interactive 3D visualization adapted from <a href="https://tapip3d.github.io/" target="_blank" style="color: #667eea;">TAPIP3D</a>
            </div>
            """)
    # Bind events
    # New video (upload or example) -> extract first frame, reset points/sliders
    video_input.change(
        handle_video_change,
        inputs=[video_input],
        outputs=[original_image_state, display_image, selected_points, grid_size, vo_points, fps_slider]
    )
    # Reset points -> clean first frame, empty point list
    reset_button.click(reset_points,
                       inputs=[original_image_state, selected_points],
                       outputs=[display_image, selected_points])
    # Clear all -> wipe video, preview, and points
    clear_button.click(clear_all,
                       outputs=[video_input, display_image, selected_points])
    # Click on the frame -> add a positive/negative point
    display_image.select(select_point,
                         inputs=[original_image_state, selected_points, fg_bg_radio],
                         outputs=[display_image, selected_points])
    # Update tracker model when vo_points changes (currently a no-op)
    vo_points.change(update_tracker_model,
                     inputs=[vo_points],
                     outputs=[])
    # Run tracking -> 3D visualization HTML + 2D tracking video
    viz_button.click(launch_viz,
                     inputs=[grid_size, vo_points, fps_slider, original_image_state],
                     outputs=[viz_iframe, track_video],
                     )
# Launch the demo only when run as a script (not on import)
if __name__ == "__main__":
    demo.launch()