""" pose_detection.py - Dog Pose Detection using YOLOv8-Pose """ import cv2 import numpy as np import torch from ultralytics import YOLO from typing import Optional, List, Tuple from pathlib import Path class DogPoseDetector: def __init__(self, model_path: str = 'dog-pose-trained.pt', confidence_threshold: float = 0.5, device: str = 'cuda'): self.model_path = model_path self.confidence_threshold = confidence_threshold self.device = device if torch.cuda.is_available() else 'cpu' if not Path(model_path).exists(): raise FileNotFoundError(f"Pose model not found: {model_path}") self.model = YOLO(model_path) self.model.to(self.device) self.keypoint_8_indices = [0, 3, 7, 10, 12, 14, 17, 19] self.skeleton_8 = [ (0, 1), # Nose -> Ear (1, 2), # Ear -> Shoulder (2, 3), # Shoulder -> Hip (This replaces 7 -> 10) (3, 4), # Hip -> Tail Start (4, 5), # Tail Start -> Tail End (3, 6), # Hip -> Knee Left (3, 7) # Hip -> Knee Right ] print(f"✅ Pose Detector initialized on {self.device}") print(f" Model: {model_path}") def detect_pose(self, image: np.ndarray) -> Optional[np.ndarray]: if image is None or image.size == 0: return None try: results = self.model(image, conf=self.confidence_threshold, verbose=False) if not results or len(results) == 0: return None result = results[0] if result.keypoints is None or len(result.keypoints) == 0: return None keypoints = result.keypoints.data[0].cpu().numpy() if keypoints.shape[0] != 24: return None return keypoints except Exception as e: return None def extract_8_keypoints(self, keypoints_24: np.ndarray) -> np.ndarray: if keypoints_24 is None or keypoints_24.shape[0] != 24: return np.zeros((8, 3)) keypoints_8 = keypoints_24[self.keypoint_8_indices].copy() return keypoints_8 def visualize_8_keypoints(self, image: np.ndarray, keypoints_8: np.ndarray, draw_skeleton: bool = True, draw_points: bool = True) -> np.ndarray: vis_image = image.copy() if keypoints_8 is None or keypoints_8.shape[0] != 8: return vis_image if draw_skeleton: for i, (start_idx, end_idx) in enumerate(self.skeleton_8): if (keypoints_8[start_idx][2] > 0.5 and keypoints_8[end_idx][2] > 0.5): start_point = (int(keypoints_8[start_idx][0]), int(keypoints_8[start_idx][1])) end_point = (int(keypoints_8[end_idx][0]), int(keypoints_8[end_idx][1])) cv2.line(vis_image, start_point, end_point, (0, 255, 0), 2, cv2.LINE_AA) if draw_points: for i, kp in enumerate(keypoints_8): if kp[2] > 0.5: x, y = int(kp[0]), int(kp[1]) cv2.circle(vis_image, (x, y), 4, (0, 0, 255), -1) cv2.circle(vis_image, (x, y), 6, (255, 255, 255), 1) return vis_image def create_visualization_video(self, video_path: str, tracks_data: List, output_path: str = "pose_visualization.mp4") -> Optional[str]: try: cap = cv2.VideoCapture(video_path) if not cap.isOpened(): return None fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30 width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(output_path, fourcc, fps, (width, height)) frame_num = 0 processed_frame_idx = 0 while cap.isOpened(): ret, frame = cap.read() if not ret: break if processed_frame_idx < len(tracks_data): frame_tracks = tracks_data[processed_frame_idx] if frame_tracks: for track in frame_tracks: if hasattr(track, 'bbox'): x1, y1, x2, y2 = map(int, track.bbox) color = (0, 255, 0) cv2.rectangle(frame, (x1, y1), (x2, y2), color, 8) if hasattr(track, 'detections') and len(track.detections) > 0: detection = track.detections[-1] if hasattr(detection, 'image_crop') and detection.image_crop is not None: keypoints_24 = self.detect_pose(detection.image_crop) if keypoints_24 is not None: keypoints_8 = self.extract_8_keypoints(keypoints_24) keypoints_8_scaled = keypoints_8.copy() keypoints_8_scaled[:, 0] += x1 keypoints_8_scaled[:, 1] += y1 frame = self.visualize_8_keypoints( frame, keypoints_8_scaled, draw_skeleton=True, draw_points=True ) processed_frame_idx += 1 out.write(frame) frame_num += 1 if frame_num % 30 == 0: print(f"Visualization progress: {frame_num}/{total_frames}") cap.release() out.release() if Path(output_path).exists() and Path(output_path).stat().st_size > 1000: print(f"✅ Visualization video saved: {output_path}") return output_path else: return None except Exception as e: print(f"Visualization error: {e}") return None