import os.path as osp
import os
import sys
import json
import itertools
import time
from collections import deque
import torch
import tqdm
import concurrent.futures
import psutil
import io
import cv2
from PIL import Image
import numpy as np

from models.SpaTrackV2.models.utils import matrix_to_quaternion
from models.SpaTrackV2.datasets.base_sfm_dataset import BaseSfMViewDataset
from models.SpaTrackV2.models.utils import (
    camera_to_pose_encoding, pose_encoding_to_camera
)
from models.SpaTrackV2.models.camera_transform import normalize_cameras
from models.SpaTrackV2.datasets.tartan_utils.traj_tf import ned2cam
from models.SpaTrackV2.datasets.tartan_utils.cam_tf import pos_quats2SE_matrices
from models.SpaTrackV2.utils.visualizer import Visualizer
import glob
from models.SpaTrackV2.datasets.dataset_util import (
    imread_cv2, npz_loader, read_video, npy_loader, resize_crop_video
)


class Spring(BaseSfMViewDataset):
    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        scene_list = os.listdir(self.ROOT)
        self.scene_list = sorted([osp.join(self.ROOT, scene) for scene in scene_list])
        # drop scenes that are known to be problematic
        self.except_list = [6]
        self.scene_list = [
            self.scene_list[i] for i in range(len(self.scene_list))
            if i not in self.except_list
        ]

    def __len__(self):
        return len(self.scene_list)

    def _get_views(self, idx, resolution, rng):
        # scene root dir
        scene = self.scene_list[idx]
        img_len = len(glob.glob(os.path.join(scene, "rgb", "*.png")))

        # sample num_views frames with a random stride and start offset;
        # fall back to using every frame if the scene is too short
        try:
            scale_num = int(np.random.uniform(2, 3))
            start = np.random.choice(np.arange(0, max(img_len - scale_num * self.num_views, 1)))
            img_idxs = np.arange(
                start, start + scale_num * self.num_views, scale_num
            ).clip(0, img_len - 1)
        except Exception:
            img_idxs = np.arange(0, img_len, 1)

        frame_dir = os.path.join(scene, "rgb")
        depth_dir = os.path.join(scene, "depth")
        cam_dir = os.path.join(scene, "cam")

        rgbs = []
        depths = []
        extrs = []
        intrs = []
        # read the frames and depths
        for idx_i in img_idxs:
            idx_i = int(idx_i)
            img_path = os.path.join(frame_dir, f"{idx_i:04d}.png")
            depth_path = os.path.join(depth_dir, f"{idx_i:04d}.npy")
            # read depth and image
            img = imread_cv2(img_path)
            depth = npy_loader(depth_path)
            depth[~np.isfinite(depth)] = 0  # zero out invalid depth values
            rgbs.append(img)
            depths.append(depth)
            # read extrinsics and intrinsics
            cam = np.load(osp.join(cam_dir, f"{idx_i:04d}.npz"))
            extrs.append(cam["pose"])
            intrs.append(cam["intrinsics"])

        rgbs = np.stack(rgbs, axis=0)      # (T, H, W, C)
        depths = np.stack(depths, axis=0)  # (T, H, W)
        extrs = np.stack(extrs, axis=0)    # (T, 4, 4)
        intrs = np.stack(intrs, axis=0)    # (T, 3, 3)

        # convert BGR to RGB
        # rgbs = rgbs[..., [2, 1, 0]]
        T, H, W, _ = rgbs.shape
        intrinsics = intrs
        extrinsics = extrs

        # Spring has no track annotations, so allocate empty placeholders
        # with the expected shapes
        traj_2d = np.zeros((self.num_views, self.track_num, 2))
        traj_3d = np.zeros((self.num_views, self.track_num, 3))
        vis = np.zeros((self.num_views, self.track_num))
        poses = extrinsics
        # get tensor track
        traj_2d = torch.from_numpy(traj_2d)
        traj_3d = torch.from_numpy(traj_3d)
        vis = torch.from_numpy(vis)

        # crop and resize
        rgbs, depths, Intrs = resize_crop_video(rgbs, depths, intrinsics, resolution[0])

        # encode the camera poses
        Extrs = torch.from_numpy(poses)
        camera_poses = Extrs  # NOTE: C2W
        focal0 = Intrs[:, 0, 0] / resolution[0]
        focal1 = Intrs[:, 1, 1] / resolution[0]
        focal = (focal0.unsqueeze(1) + focal1.unsqueeze(1)) / 2
        # normalize so that the first frame is the world origin
        camera_poses = torch.inverse(camera_poses[:1]) @ camera_poses
        T_center = camera_poses[:, :3, 3].mean(dim=0)
        Radius = camera_poses[:, :3, 3].norm(dim=1).max()
        # if Radius < 1e-2:
        #     Radius = 1
        camera_poses[:, :3, 3] = camera_poses[:, :3, 3] / Radius
        R = camera_poses[:, :3, :3]
        t = camera_poses[:, :3, 3]
        rot_vec = matrix_to_quaternion(R)
        # pose encoding: [translation (3), quaternion (4), normalized focal (1)]
        pose_enc = torch.cat([t, rot_vec, focal], dim=1)
        # scale depth by the same radius so depth and camera translation stay consistent
        # depth_cano = Radius * focal[:, :, None, None] / depths.clamp(min=1e-6)
        depth_cano = depths / Radius

        # TODO: DEBUG depth range
        # metric_depth = depth_cano.clone()
        # metric_depth[metric_depth == torch.inf] = 0
        # _depths = metric_depth[metric_depth > 0].reshape(-1)
        # q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
        # q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
        # iqr = q75 - q25
        # upper_bound = (q75 + 0.8 * iqr).clamp(min=1e-6, max=10 * q25)
        # _depth_roi = torch.tensor(
        #     [1e-1, upper_bound.item()],
        #     dtype=metric_depth.dtype,
        #     device=metric_depth.device
        # )
        # mask_roi = (metric_depth > _depth_roi[0]) & (metric_depth < _depth_roi[1])
        # depth_cano = depth_cano * mask_roi.float()

        # zero out NaNs (note: `x == torch.nan` is always False, so use torch.isnan)
        depth_cano[torch.isnan(depth_cano)] = 0

        traj_3d = torch.zeros(self.num_views, self.track_num, 3)
        vis = torch.zeros(self.num_views, self.track_num)

        syn_real = torch.tensor([1])
        metric_rel = torch.tensor([1])
        data_dir = scene

        views = dict(
            rgbs=rgbs,
            depths=depth_cano,
            pose_enc=pose_enc,
            traj_mat=camera_poses,
            intrs=Intrs,
            traj_3d=traj_3d,
            vis=vis,
            syn_real=syn_real,
            metric_rel=metric_rel,
            data_dir=data_dir,
        )
        return views


if __name__ == "__main__":
    from models.SpaTrackV2.datasets.base_sfm_dataset import view_name
    from functools import partial
    import shutil

    DATA_DIR = "/mnt/bn/xyxdata/data/4d_data/spring/processed"

    dataset = Spring(split='train', ROOT=DATA_DIR, resolution=518,
                     aug_crop=16, num_views=32)

    rng = np.random.default_rng(seed=0)
    data_ret = dataset._get_views(6, (518, 518), rng)

    from models.SpaTrackV2.datasets.vis3d_check import vis4d
    vis4d(data_ret["rgbs"], data_ret["depths"], data_ret["traj_mat"], data_ret["intrs"],
          track3d=data_ret["traj_3d"],
          workspace="/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/viser_result/test")
    import pdb; pdb.set_trace()