| """ | |
| reid.py - Dog Re-Identification using MegaDescriptor | |
| """ | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| import timm | |
| from typing import Dict, Optional, List | |
| from collections import defaultdict | |
| from PIL import Image | |


class SimplifiedReID:
    """Appearance-based dog re-identification using MegaDescriptor features."""

    def __init__(self, device: str = 'cuda'):
        # Fall back to CPU when CUDA is not available.
        self.device = device if torch.cuda.is_available() else 'cpu'
        self.threshold = 0.40           # cosine-similarity threshold for matching
        self.temp_id_features = {}      # temp_id -> list of normalized feature vectors
        self.next_temp_id = 1
        self.current_frame = 0
        self.current_video_source = "unknown"
        self._initialize_model()
        print(f"✅ ReID initialized on {self.device}")
        print("   Model: MegaDescriptor-L-384")
        print(f"   Threshold: {self.threshold:.2f}")

    def _initialize_model(self):
        """Load the pretrained MegaDescriptor-L-384 backbone and its preprocessing transform."""
        try:
            self.model = timm.create_model(
                'hf-hub:BVRA/MegaDescriptor-L-384',
                pretrained=True
            )
            self.model.to(self.device).eval()
            self.transform = timm.data.create_transform(
                input_size=(384, 384),
                is_training=False,
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5]
            )
            print("   MegaDescriptor-L-384 loaded successfully")
        except Exception as e:
            print(f"❌ CRITICAL: Failed to load MegaDescriptor model: {e}")
            raise RuntimeError(f"Cannot initialize ReID system: {e}") from e

    def set_threshold(self, threshold: float):
        """Set the matching threshold, clamped to [0.10, 0.95]."""
        self.threshold = max(0.10, min(0.95, threshold))
        print(f"ReID threshold updated: {self.threshold:.2f}")

    def set_video_source(self, video_path: str):
        self.current_video_source = video_path
        print(f"Video source set: {video_path}")

    def reset_session(self):
        """Clear all temporary IDs and counters for a new session."""
        self.temp_id_features.clear()
        self.next_temp_id = 1
        self.current_frame = 0
        print("ReID session reset")

    def extract_features(self, image: np.ndarray) -> Optional[np.ndarray]:
        """Return an L2-normalized embedding for a BGR crop, or None on failure."""
        if image is None or image.size == 0:
            return None
        try:
            # OpenCV crops are BGR; the model expects RGB input.
            img_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(img_rgb)
            img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model(img_tensor)
            features = features.squeeze().cpu().numpy()
            # Normalize so that dot products elsewhere are cosine similarities.
            features = features / (np.linalg.norm(features) + 1e-7)
            return features
        except Exception as e:
            print(f"Feature extraction error: {e}")
            return None

    def match_or_register(self, track) -> Dict:
        """Match a track against known temporary IDs, or register a new one."""
        self.current_frame += 1
        # Use the most recent of the track's last three detections that has an image crop.
        detection = None
        for det in reversed(track.detections[-3:]):
            if det.image_crop is not None:
                detection = det
                break
        if detection is None:
            return {'temp_id': 0, 'confidence': 0.0, 'match_type': 'failed'}
        features = self.extract_features(detection.image_crop)
        if features is None:
            return {'temp_id': 0, 'confidence': 0.0, 'match_type': 'failed'}
        detection.features = features
        # Compare against every stored feature of every temporary ID and keep the best score.
        best_temp_id = None
        best_score = -1.0
        for temp_id, features_list in self.temp_id_features.items():
            similarities = [np.dot(features, stored) for stored in features_list]
            if similarities:
                max_sim = max(similarities)
                if max_sim > best_score:
                    best_score = max_sim
                    best_temp_id = temp_id
        if best_temp_id is not None and best_score >= self.threshold:
            # Existing dog: store the new features, keeping at most the 30 most recent.
            self.temp_id_features[best_temp_id].append(features)
            if len(self.temp_id_features[best_temp_id]) > 30:
                self.temp_id_features[best_temp_id] = \
                    self.temp_id_features[best_temp_id][-30:]
            return {
                'temp_id': best_temp_id,
                'confidence': best_score,
                'match_type': 'existing'
            }
        else:
            # No sufficiently similar dog: register a new temporary ID.
            new_temp_id = self.next_temp_id
            self.next_temp_id += 1
            self.temp_id_features[new_temp_id] = [features]
            print(f"   New temp dog ID: {new_temp_id} (threshold: {self.threshold:.2f})")
            return {
                'temp_id': new_temp_id,
                'confidence': 1.0,
                'match_type': 'new'
            }

    def get_temp_id_features(self, temp_id: int) -> Optional[np.ndarray]:
        """Return the L2-normalized mean feature vector for a temporary ID, or None."""
        if temp_id not in self.temp_id_features:
            return None
        features_list = self.temp_id_features[temp_id]
        if not features_list:
            return None
        avg_features = np.mean(np.array(features_list), axis=0)
        # Re-normalize so the averaged vector is again unit length.
        avg_features = avg_features / (np.linalg.norm(avg_features) + 1e-7)
        return avg_features

    def get_statistics(self) -> Dict:
        """Return a summary of the current ReID session."""
        stats = {
            'temp_ids': len(self.temp_id_features),
            'threshold': self.threshold,
            'current_frame': self.current_frame,
            'video_source': self.current_video_source
        }
        stats['feature_counts'] = {
            temp_id: len(features)
            for temp_id, features in self.temp_id_features.items()
        }
        return stats

    def compare_with_permanent_database(self, temp_id: int, database) -> Optional[Dict]:
        """Compare a temporary ID's average features against the permanent dog database."""
        temp_features = self.get_temp_id_features(temp_id)
        if temp_features is None:
            return None
        all_dogs = database.get_all_dogs(active_only=True)
        if all_dogs.empty:
            return None
        best_match = None
        best_score = 0.0
        for _, dog in all_dogs.iterrows():
            dog_id = dog['dog_id']
            stored_features = database.get_dog_features(dog_id)
            if stored_features is None:
                continue
            similarity = np.dot(temp_features, stored_features)
            if similarity > best_score:
                best_score = similarity
                best_match = dog_id
        # Matching against the permanent database uses a stricter threshold.
        permanent_threshold = self.threshold + 0.15
        if best_match is not None and best_score >= permanent_threshold:
            return {
                'dog_id': best_match,
                'confidence': best_score,
                'matched': True
            }
        return None


# Backwards-compatible aliases for older import paths.
ReIDSystem = SimplifiedReID
DogReID = SimplifiedReID
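

# --- Usage sketch (not part of the original module) --------------------------
# A minimal, hedged example of how this class might be driven. The `_Detection`
# and `_Track` classes below are hypothetical stubs that only provide the
# attributes match_or_register() reads (`detections`, `image_crop`, `features`);
# a real pipeline would pass its own tracker objects, and the random crop stands
# in for an actual dog crop from a detector.
if __name__ == "__main__":
    from dataclasses import dataclass, field

    @dataclass
    class _Detection:
        image_crop: np.ndarray
        features: Optional[np.ndarray] = None

    @dataclass
    class _Track:
        detections: list = field(default_factory=list)

    reid = SimplifiedReID(device='cuda')

    # Synthetic BGR crop; replace with a real detector crop in practice.
    dummy_crop = np.random.randint(0, 255, (384, 384, 3), dtype=np.uint8)
    track = _Track(detections=[_Detection(image_crop=dummy_crop)])

    result = reid.match_or_register(track)
    print(f"temp_id={result['temp_id']} "
          f"confidence={result['confidence']:.2f} "
          f"match_type={result['match_type']}")
    print(reid.get_statistics())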