# Source: uploaded by mustafa2ak, commit d4c3148 ("Create reid.py").
"""
reid.py - Dog Re-Identification using MegaDescriptor
"""
import numpy as np
import cv2
import torch
import timm
from typing import Dict, Optional, List
from collections import defaultdict
from PIL import Image
class SimplifiedReID:
    """Dog re-identification backed by the MegaDescriptor-L-384 embedding model.

    Each call to :meth:`match_or_register` embeds the newest image crop of a
    track and compares it, by dot product of L2-normalised vectors (cosine
    similarity), against a per-session gallery of temporary IDs. The crop
    either joins the closest temporary ID or registers a new one. A temporary
    ID's averaged embedding can later be compared against a permanent
    database via :meth:`compare_with_permanent_database`.
    """

    # Maximum number of embeddings kept per temporary ID (oldest are dropped).
    MAX_FEATURES_PER_ID = 30
    # Extra similarity required, on top of ``threshold``, to match a permanent dog.
    PERMANENT_MARGIN = 0.15

    def __init__(self, device: str = 'cuda'):
        """Load the model and start an empty ReID session.

        Args:
            device: Torch device to run on. Whatever is requested, ``'cpu'``
                is used when CUDA is not available.

        Raises:
            RuntimeError: If the MegaDescriptor model cannot be loaded.
        """
        self.device = device if torch.cuda.is_available() else 'cpu'
        # Minimum similarity for a crop to join an existing temporary ID.
        self.threshold = 0.40
        # temp_id -> list of L2-normalised embeddings, most recent last.
        self.temp_id_features: Dict[int, List[np.ndarray]] = {}
        self.next_temp_id = 1
        # Incremented once per match_or_register() call.
        self.current_frame = 0
        self.current_video_source = "unknown"
        self._initialize_model()
        print(f"✅ ReID initialized on {self.device}")
        print(" Model: MegaDescriptor-L-384")
        print(f" Threshold: {self.threshold:.2f}")

    def _initialize_model(self):
        """Load MegaDescriptor-L-384 from the HF hub and build its eval transform.

        Raises:
            RuntimeError: If the model or transform cannot be created. The
                underlying exception is chained as ``__cause__``.
        """
        try:
            self.model = timm.create_model(
                'hf-hub:BVRA/MegaDescriptor-L-384',
                pretrained=True,
            )
            self.model.to(self.device).eval()
            self.transform = timm.data.create_transform(
                input_size=(384, 384),
                is_training=False,
                mean=[0.5, 0.5, 0.5],
                std=[0.5, 0.5, 0.5],
            )
            print(" MegaDescriptor-L-384 loaded successfully")
        except Exception as e:
            print(f"❌ CRITICAL: Failed to load MegaDescriptor model: {e}")
            # Chain the cause so the real load failure stays in the traceback.
            raise RuntimeError(f"Cannot initialize ReID system: {e}") from e

    def set_threshold(self, threshold: float):
        """Set the temp-ID matching threshold, clamped to [0.10, 0.95]."""
        self.threshold = max(0.10, min(0.95, float(threshold)))
        print(f"ReID threshold updated: {self.threshold:.2f}")

    def set_video_source(self, video_path: str):
        """Record which video the current session is processing (reported in stats)."""
        self.current_video_source = video_path
        print(f"Video source set: {video_path}")

    def reset_session(self):
        """Forget all temporary IDs and restart ID numbering and the call counter."""
        self.temp_id_features.clear()
        self.next_temp_id = 1
        self.current_frame = 0
        print("ReID session reset")

    def extract_features(self, image: np.ndarray) -> Optional[np.ndarray]:
        """Embed an OpenCV image crop with MegaDescriptor.

        Args:
            image: Image array in OpenCV channel order. Accepted layouts are
                BGR (H, W, 3), BGRA (H, W, 4), and grayscale (H, W) or (H, W, 1).

        Returns:
            A 1-D L2-normalised embedding, or ``None`` if the image is empty
            or embedding fails. Failures are printed, not raised.
        """
        if image is None or image.size == 0:
            return None
        try:
            img_rgb = self._to_rgb(image)
            pil_img = Image.fromarray(img_rgb)
            img_tensor = self.transform(pil_img).unsqueeze(0).to(self.device)
            with torch.no_grad():
                features = self.model(img_tensor)
            # reshape(-1) rather than squeeze(): always yields a 1-D vector.
            features = features.reshape(-1).cpu().numpy()
            return features / (np.linalg.norm(features) + 1e-7)
        except Exception as e:
            # Best-effort: one bad crop must not abort tracking.
            print(f"Feature extraction error: {e}")
            return None

    @staticmethod
    def _to_rgb(image: np.ndarray) -> np.ndarray:
        """Convert a grayscale, BGR or BGRA OpenCV image to 3-channel RGB."""
        if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
            return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.ndim == 3 and image.shape[2] == 4:
            return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def match_or_register(self, track) -> Dict:
        """Assign the newest usable crop of ``track`` to a temporary dog ID.

        The most recent of the last three detections that has an
        ``image_crop`` is embedded. The embedding is stored on that detection
        as ``features``. It is then scored against every temporary ID, where
        an ID's score is its best similarity over its stored embeddings. If
        the best score reaches ``self.threshold``, the crop joins that ID.
        Otherwise a new temporary ID is created.

        Args:
            track: Object exposing a sliceable ``detections`` sequence whose
                items have an ``image_crop`` attribute.

        Returns:
            Dict with keys:
                ``temp_id``: the assigned ID, or 0 on failure.
                ``confidence``: a Python float.
                ``match_type``: ``'existing'``, ``'new'`` or ``'failed'``.
        """
        self.current_frame += 1
        detection = self._latest_detection_with_crop(track)
        if detection is None:
            return {'temp_id': 0, 'confidence': 0.0, 'match_type': 'failed'}
        features = self.extract_features(detection.image_crop)
        if features is None:
            return {'temp_id': 0, 'confidence': 0.0, 'match_type': 'failed'}
        detection.features = features

        best_temp_id, best_score = self._best_gallery_match(features)
        if best_temp_id is not None and best_score >= self.threshold:
            gallery = self.temp_id_features[best_temp_id]
            gallery.append(features)
            # Keep only the most recent embeddings so stale appearances age out.
            if len(gallery) > self.MAX_FEATURES_PER_ID:
                del gallery[:-self.MAX_FEATURES_PER_ID]
            return {
                'temp_id': best_temp_id,
                'confidence': best_score,
                'match_type': 'existing',
            }

        new_temp_id = self.next_temp_id
        self.next_temp_id += 1
        self.temp_id_features[new_temp_id] = [features]
        print(f" New temp dog ID: {new_temp_id} (threshold: {self.threshold:.2f})")
        return {'temp_id': new_temp_id, 'confidence': 1.0, 'match_type': 'new'}

    @staticmethod
    def _latest_detection_with_crop(track):
        """Return the newest of the last three detections that has an image crop, else None."""
        for det in reversed(track.detections[-3:]):
            if det.image_crop is not None:
                return det
        return None

    def _best_gallery_match(self, features: np.ndarray):
        """Return ``(temp_id, score)`` of the most similar temporary ID, or ``(None, -1.0)``.

        A temporary ID's score is the maximum dot product between
        ``features`` and any of its stored embeddings. On ties, the ID that
        comes first in dict order (earliest registered) wins.
        """
        best_temp_id = None
        best_score = -1.0
        for temp_id, gallery in self.temp_id_features.items():
            if not gallery:
                continue
            score = float(np.max(np.stack(gallery) @ features))
            if score > best_score:
                best_score = score
                best_temp_id = temp_id
        return best_temp_id, best_score

    def get_temp_id_features(self, temp_id: int) -> Optional[np.ndarray]:
        """Return the L2-normalised mean embedding of ``temp_id``, or None if it has none."""
        features_list = self.temp_id_features.get(temp_id)
        if not features_list:
            return None
        avg_features = np.mean(np.stack(features_list), axis=0)
        return avg_features / (np.linalg.norm(avg_features) + 1e-7)

    def get_statistics(self) -> Dict:
        """Return a snapshot of session state, including per-ID embedding counts."""
        return {
            'temp_ids': len(self.temp_id_features),
            'threshold': self.threshold,
            'current_frame': self.current_frame,
            'video_source': self.current_video_source,
            'feature_counts': {
                temp_id: len(features)
                for temp_id, features in self.temp_id_features.items()
            },
        }

    def compare_with_permanent_database(self, temp_id: int, database) -> Optional[Dict]:
        """Find the permanent dog whose stored embedding best matches ``temp_id``.

        The candidates are the rows from ``database.get_all_dogs(active_only=True)``.
        For each row, ``database.get_dog_features(dog_id)`` is compared with
        the temporary ID's mean embedding. Stored embeddings are
        re-normalised here, so the score is a cosine similarity even if the
        database keeps unnormalised vectors. Dogs with no embedding, or an
        all-zero one, are skipped.

        Returns:
            ``{'dog_id', 'confidence' (float), 'matched': True}`` when the
            best score reaches ``self.threshold + PERMANENT_MARGIN``.
            Otherwise ``None``.
        """
        temp_features = self.get_temp_id_features(temp_id)
        if temp_features is None:
            return None
        all_dogs = database.get_all_dogs(active_only=True)
        if all_dogs.empty:
            return None

        best_match = None
        best_score = 0.0
        for _, dog in all_dogs.iterrows():
            dog_id = dog['dog_id']
            stored = database.get_dog_features(dog_id)
            if stored is None:
                continue
            stored = np.asarray(stored, dtype=np.float64).reshape(-1)
            norm = np.linalg.norm(stored)
            if norm == 0:
                # A zero vector has no direction; dividing would produce NaNs.
                continue
            similarity = float(np.dot(temp_features, stored / norm))
            if similarity > best_score:
                best_score = similarity
                best_match = dog_id

        permanent_threshold = self.threshold + self.PERMANENT_MARGIN
        if best_match is not None and best_score >= permanent_threshold:
            return {'dog_id': best_match, 'confidence': best_score, 'matched': True}
        return None
# Alternative public names bound to the same class object as SimplifiedReID
# (presumably kept for existing importers; verify before removing).
ReIDSystem = DogReID = SimplifiedReID