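"""TapVid3D dataset loader for SpaTrackV2 evaluation.

Each sample decodes the JPEG frames of one video, loads a metric depth estimate
(UniDepth or MegaSaM), and packs the ground-truth 3D tracks, visibility, query
points, and camera intrinsics into a DeltaData record.
"""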
import io
import os

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
from einops import rearrange
from PIL import Image

from models.SpaTrackV2.datasets.delta_utils import DeltaData, least_square_align
from models.SpaTrackV2.utils.model_utils import sample_features5d, bilinear_sampler
from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import align_depth_affine

UINT16_MAX = 65535
TAPVID3D_ROOT = None
def get_jpeg_byte_hw(jpeg_bytes: bytes):
    """Return the (height, width) of a JPEG image given its raw bytes."""
    with io.BytesIO(jpeg_bytes) as img_bytes:
        img = Image.open(img_bytes)
        img = img.convert("RGB")
    return np.array(img).shape[:2]
def get_new_hw_with_given_smallest_side_length(*, orig_height: int, orig_width: int, smallest_side_length: int = 256):
    """Compute the resized (height, width) so that the smallest side equals `smallest_side_length`."""
    orig_shape = np.array([orig_height, orig_width])
    scaling_factor = smallest_side_length / np.min(orig_shape)
    resized_shape = np.round(orig_shape * scaling_factor)
    return (int(resized_shape[0]), int(resized_shape[1])), scaling_factor
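# Illustrative worked example (values chosen for this comment only): for a 480x640
# frame and smallest_side_length=256, scaling_factor = 256 / 480 ~= 0.533 and the
# returned shape is (256, 341).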
def project_points_to_video_frame(camera_pov_points3d, camera_intrinsics, height, width):
    """Project 3d points to 2d image plane."""
    u_d = camera_pov_points3d[..., 0] / (camera_pov_points3d[..., 2] + 1e-8)
    v_d = camera_pov_points3d[..., 1] / (camera_pov_points3d[..., 2] + 1e-8)
    f_u, f_v, c_u, c_v = camera_intrinsics
    u_d = u_d * f_u + c_u
    v_d = v_d * f_v + c_v
    # Mask of points that are in front of the camera and within image boundary
    masks = camera_pov_points3d[..., 2] >= 1
    masks = masks & (u_d >= 0) & (u_d < width) & (v_d >= 0) & (v_d < height)
    return np.stack([u_d, v_d], axis=-1), masks
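# Illustrative worked example (values chosen for this comment only): a camera-frame
# point (X, Y, Z) = (0.5, 0.2, 2.0) with intrinsics (fx, fy, cx, cy) = (500, 500, 320, 240)
# projects to u = 0.5 / 2.0 * 500 + 320 = 445 and v = 0.2 / 2.0 * 500 + 240 = 290,
# and is kept by the mask since Z >= 1 and (u, v) lies inside a 640x480 frame.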
class TapVid3DDataset(Dataset):
    def __init__(
        self,
        data_root,
        mega_data_root=None,
        datatype="pstudio",
        crop_size=256,
        debug=False,
        use_metric_depth=True,
        split="minival",
        depth_type="megasam",
        read_from_s3=False,
    ):
        self.datatype = datatype
        self.data_root = os.path.join(data_root, datatype)
        if mega_data_root is not None:
            self.mega_data_root = os.path.join(mega_data_root, datatype)
        else:
            self.mega_data_root = None
        self.video_names = sorted([f.split(".")[0] for f in os.listdir(self.data_root) if f.endswith(".npz")])
        self.debug = debug
        self.crop_size = crop_size
        self.use_metric_depth = use_metric_depth
        self.depth_type = depth_type
        print(f"Found {len(self.video_names)} samples for TapVid3D {datatype}")
    def __len__(self):
        if self.debug:
            return 10
        return len(self.video_names)
    def __getitem__(self, index):
        video_name = self.video_names[index]
        gt_path = os.path.join(self.data_root, f"{video_name}.npz")
        try:
            in_npz = np.load(gt_path, allow_pickle=True)
        except (OSError, ValueError):
            # Skip samples whose annotation file is missing or unreadable.
            return None
        # Decode the JPEG byte strings into an RGB video array of shape (T, H, W, 3).
        images_jpeg_bytes = in_npz["images_jpeg_bytes"]
        video = []
        for frame_bytes in images_jpeg_bytes:
            arr = np.frombuffer(frame_bytes, np.uint8)
            image_bgr = cv2.imdecode(arr, flags=cv2.IMREAD_UNCHANGED)
            image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
            video.append(image_rgb)
        video = np.stack(video, axis=0)
        metric_extrs = None
        if self.use_metric_depth:
            try:
                if self.depth_type == "unidepth":
                    metric_videodepth = in_npz["depth_preds"]  # NOTE UniDepth
                elif self.depth_type == "megasam":
                    mega_path = os.path.join(self.mega_data_root, f"{video_name}.npz")
                    if os.path.exists(mega_path):
                        mega_meta_npz = np.load(mega_path, allow_pickle=True)
                        # Resize the MegaSaM depth maps (T, H, W) to the video resolution,
                        # using nearest-neighbour interpolation.
                        metric_videodepth = mega_meta_npz["depths"]
                        metric_videodepth = cv2.resize(
                            metric_videodepth.transpose(1, 2, 0),
                            (video.shape[2], video.shape[1]),
                            interpolation=cv2.INTER_NEAREST,
                        ).transpose(2, 0, 1)
                        metric_videodepth_unidepth = in_npz["depth_preds"]  # NOTE UniDepth
                        metric_extrs = mega_meta_npz["extrinsics"]
                        # NOTE: a scale/shift alignment of the MegaSaM depth to the UniDepth
                        # prediction (via align_depth_affine) and an mp4 visualization of the
                        # resulting depth maps were used for debugging and are disabled here.
                    else:
                        return None
                videodepth = metric_videodepth
            except Exception:
                return None
        else:
            # Align the relative DepthCrafter disparity to a metric depth prediction with a
            # least-squares fit.
            try:
                # The UniDepth prediction is taken as the alignment target; without this
                # assignment the variable would be undefined on this branch.
                metric_videodepth = in_npz["depth_preds"]  # NOTE UniDepth
                videodisp = in_npz["depth_preds_depthcrafter"]
            except KeyError:
                T, H, W, _ = video.shape
                videodisp = np.ones(video.shape[:3], dtype=video.dtype)
                metric_videodepth = np.ones(video.shape[:3], dtype=video.dtype)
            videodisp = videodisp.astype(np.float32) / UINT16_MAX
            videodepth = least_square_align(metric_videodepth, videodisp, return_align_scalar=False)
        queries_xyt = in_npz["queries_xyt"]
        tracks_xyz = in_npz["tracks_XYZ"]
        visibles = in_npz["visibility"]
        intrinsics_params = in_npz["fx_fy_cx_cy"]
        tracks_uv, _ = project_points_to_video_frame(tracks_xyz, intrinsics_params, video.shape[1], video.shape[2])
        scaling_factor = 1.0
        intrinsics_params_resized = intrinsics_params * scaling_factor
        intrinsic_mat = np.array(
            [
                [intrinsics_params_resized[0], 0, intrinsics_params_resized[2]],
                [0, intrinsics_params_resized[1], intrinsics_params_resized[3]],
                [0, 0, 1],
            ]
        )
        intrinsic_mat = torch.from_numpy(intrinsic_mat).float()
        intrinsic_mat = intrinsic_mat[None].repeat(video.shape[0], 1, 1)
        video = torch.from_numpy(video).permute(0, 3, 1, 2).float()
        videodepth = torch.from_numpy(videodepth).float().unsqueeze(1)
        segs = torch.ones_like(videodepth)
        trajectory_3d = torch.from_numpy(tracks_xyz).float()  # T N D
        trajectory_2d = torch.from_numpy(tracks_uv).float()  # T N 2
        visibility = torch.from_numpy(visibles)
        query_points = torch.from_numpy(queries_xyt).float()
        sample_coords = torch.cat([query_points[:, 2:3], query_points[:, :2]], dim=-1)[None, None, ...]  # 1 1 N 3
        # If the depth maps are stored at a different resolution than the RGB frames,
        # rescale the (x, y) query coordinates before sampling depth at the query points.
        rgb_h, rgb_w = video.shape[2], video.shape[3]
        depth_h, depth_w = videodepth.shape[2], videodepth.shape[3]
        if rgb_h != depth_h or rgb_w != depth_w:
            sample_coords[..., 1] = sample_coords[..., 1] * depth_w / rgb_w
            sample_coords[..., 2] = sample_coords[..., 2] * depth_h / rgb_h
        query_points_depth = sample_features5d(videodepth[None], sample_coords, interp_mode="nearest")
        query_points_depth = query_points_depth.squeeze(0, 1)
        query_points_3d = torch.cat(
            [query_points[:, 2:3], query_points[:, :2], query_points_depth], dim=-1
        )  # NOTE: queries are stored as (x, y, t); here they are reordered to (t, x, y) plus the sampled depth
        data = DeltaData(
            video=video,
            videodepth=videodepth,
            segmentation=segs,
            trajectory=trajectory_2d,
            trajectory3d=trajectory_3d,
            visibility=visibility,
            seq_name=video_name,
            query_points=query_points_3d,
            intrs=intrinsic_mat,
            extrs=metric_extrs,
        )
        return data
if __name__ == "__main__":
    dataset = TapVid3DDataset(
        data_root="/mnt/bn/xyxdata/home/codes/my_projs/tapnet/tapvid3d_dataset",
        datatype="pstudio",
        depth_type="megasam",
        mega_data_root="/mnt/bn/xyxdata/home/codes/my_projs/TAPIP3D/outputs/inference",
    )
    print(len(dataset))
    for i in range(len(dataset)):
        data = dataset[i]
        if data is None:
            print(f"data is None at index {i}")
        print(data)
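# A minimal DataLoader sketch (not part of the original script): since __getitem__ can
# return None for missing or unreadable files, a collate_fn that filters out None
# entries is assumed here; batch_size=1 keeps variable-length videos intact.
#
#   def skip_none_collate(batch):
#       batch = [b for b in batch if b is not None]
#       return batch[0] if batch else None
#
#   loader = DataLoader(dataset, batch_size=1, num_workers=4, collate_fn=skip_none_collate)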