Spaces:
Running
on
Zero
Running
on
Zero
| import os.path as osp | |
| import os | |
| import sys | |
| import json | |
| import itertools | |
| import time | |
| from collections import deque | |
| import torch | |
| import tqdm | |
| import concurrent.futures | |
| import psutil | |
| import io | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| from models.SpaTrackV2.models.utils import matrix_to_quaternion | |
| from models.SpaTrackV2.datasets.base_sfm_dataset import BaseSfMViewDataset | |
| from models.SpaTrackV2.models.utils import ( | |
| camera_to_pose_encoding, pose_encoding_to_camera | |
| ) | |
| from models.SpaTrackV2.models.camera_transform import normalize_cameras | |
| from models.SpaTrackV2.datasets.tartan_utils.traj_tf import ned2cam | |
| from models.SpaTrackV2.datasets.tartan_utils.cam_tf import pos_quats2SE_matrices | |
| from models.SpaTrackV2.datasets.dataset_util import ( | |
| imread_cv2, npz_loader, read_video,npy_loader,resize_crop_video,logsig_fn | |
| ) | |
| import copy | |
class Vkitti(BaseSfMViewDataset):
    """Virtual KITTI clip dataset.

    One item per ``<ROOT>/<scene>/<case>`` directory. ``_get_views`` samples
    ``self.num_views`` frames from one randomly chosen camera of that clip and
    returns RGB frames, depth maps, first-frame-relative camera-to-world poses
    and the resized intrinsics.
    """

    # Rendering variations looked up under every scene directory.
    CASE_LIST = (
        "15-deg-left", "15-deg-right", "30-deg-left", "30-deg-right", "clone",
        "fog", "morning", "overcast", "rain", "sunset",
    )
    # Cameras available per clip (frames/rgb/<camera>, frames/depth/<camera>).
    CAMERAS = ("Camera_0", "Camera_1")
    # Step between sampled frames. The previous code drew it from
    # np.random.uniform(1, 1), which is always 1.
    FRAME_STRIDE = 1
    # Raw depth PNG values are divided by this (presumably centimetres ->
    # metres; confirm against the dataset documentation).
    DEPTH_SCALE = 100

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        self.dataset_label = 'Vkitti'
        # Sort so that the idx -> clip mapping does not depend on the
        # (arbitrary) order returned by os.listdir.
        scene_dirs = sorted(
            os.path.join(self.ROOT, name)
            for name in os.listdir(self.ROOT)
            if os.path.isdir(os.path.join(self.ROOT, name))
        )
        # Skip variations missing on disk instead of crashing later, when
        # such an index happens to be sampled.
        self.scene_list = [
            os.path.join(scene, case)
            for scene in scene_dirs
            for case in self.CASE_LIST
            if os.path.isdir(os.path.join(scene, case))
        ]
        # Fixed pinhole intrinsics shared by every frame.
        self.intr = np.array([[725.0, 0, 620.5], [0, 725.0, 187.0], [0, 0, 1]])

    def __len__(self):
        return len(self.scene_list)

    @staticmethod
    def _load_extrinsics(scene, cam_picked):
        """Read ``<scene>/extrinsic.txt`` and return ``{frame_id: 4x4 float32 matrix}`` for one camera.

        The first line is a header. Each following row is ``frame cameraID``
        followed by the matrix entries, reshaped row-major to 4x4.
        """
        pose_path = os.path.join(scene, "extrinsic.txt")
        with open(pose_path, 'r') as f:
            lines = f.readlines()[1:]
        # Ignore blank lines (e.g. a trailing newline) so np.array stays rectangular.
        rows = [line.split() for line in lines if line.strip()]
        pose_data = np.array(rows, dtype=np.float32)
        cam_id = int(cam_picked.rsplit('_', 1)[-1])
        # Select by the cameraID column rather than assuming the two cameras'
        # rows strictly alternate.
        pose_data = pose_data[pose_data[:, 1] == cam_id]
        mats = pose_data[:, 2:].reshape(-1, 4, 4)
        # Key by frame id rather than by row position.
        return {int(frame): mat for frame, mat in zip(pose_data[:, 0], mats)}

    def _get_views(self, idx, resolution, rng):
        """Load ``self.num_views`` frames of clip ``idx``.

        Args:
            idx: index into ``self.scene_list``.
            resolution: target size; only ``resolution[0]`` is used, for
                cropping/resizing and for normalising the focal length.
            rng: random generator used for camera and start-frame sampling
                (assumed to be a ``numpy.random.Generator``; only ``choice``
                is called on it).

        Returns:
            dict with keys ``rgbs``, ``depths``, ``pose_enc`` (translation,
            ``matrix_to_quaternion`` rotation, normalised focal),
            ``traj_mat`` (camera-to-world, relative to the first frame),
            ``intrs``, ``traj_3d`` / ``vis`` (zero placeholders),
            ``syn_real``, ``metric_rel``, ``static`` flags and ``data_dir``.

        Raises:
            FileNotFoundError: if the chosen camera has no ``.jpg`` frames.
            KeyError: if a sampled frame has no row in ``extrinsic.txt``.
        """
        scene = self.scene_list[idx]
        # Use the supplied rng (not the global np.random) so sampling honours
        # the caller's seeding.
        cam_picked = str(rng.choice(self.CAMERAS))
        rgbs_dir_root = os.path.join(scene, 'frames', 'rgb', cam_picked)
        depth_dir_root = os.path.join(scene, 'frames', 'depth', cam_picked)
        extr_by_frame = self._load_extrinsics(scene, cam_picked)

        img_names = sorted(n for n in os.listdir(rgbs_dir_root) if n.endswith('.jpg'))
        T = len(img_names)
        if T == 0:
            raise FileNotFoundError(f"no .jpg frames found in {rgbs_dir_root}")

        stride = self.FRAME_STRIDE
        # A start is valid when its last frame, start + stride*(num_views-1),
        # is still inside the clip.
        n_starts = max(T - stride * (self.num_views - 1), 1)
        start = int(rng.choice(n_starts))
        # Clips shorter than num_views repeat their last frame.
        idxs = np.arange(start, start + stride * self.num_views, stride).clip(0, T - 1)

        rgbs, depths, extrs = [], [], []
        for i in idxs:
            img_name = img_names[i]
            # "rgb_00042.jpg" -> "00042"; the matching depth file is "depth_00042.png".
            frame_str = os.path.splitext(img_name)[0].rsplit('_', 1)[-1]
            rgbs.append(imread_cv2(os.path.join(rgbs_dir_root, img_name)))
            depth_path = os.path.join(depth_dir_root, f"depth_{frame_str}.png")
            depths.append(imread_cv2(depth_path, cv2.IMREAD_UNCHANGED) / self.DEPTH_SCALE)
            extrs.append(extr_by_frame[int(frame_str)])
        rgbs = np.stack(rgbs, axis=0)
        depths = np.stack(depths, axis=0)
        extrs = np.stack(extrs, axis=0)

        intrinsics = self.intr[None, ...].repeat(len(rgbs), axis=0)
        # TODO: optional augmentation - reverse the frame order of
        # rgbs / depths / extrs / intrinsics together.
        rgbs, depths, Intrs = resize_crop_video(rgbs, depths, intrinsics, resolution[0])

        # Inverting the stored extrinsics gives camera-to-world poses.
        camera_poses = torch.inverse(torch.from_numpy(extrs))  # NOTE: C2W
        # Express every pose relative to the first frame.
        camera_poses = torch.inverse(camera_poses[:1]) @ camera_poses
        # Scene-scale normalisation is disabled (radius fixed to 1): poses and
        # depths keep the units produced above.
        radius = 1
        camera_poses[:, :3, 3] = camera_poses[:, :3, 3] / radius

        # Focal length normalised by resolution[0], averaged over fx and fy.
        focal0 = Intrs[:, 0, 0] / resolution[0]
        focal1 = Intrs[:, 1, 1] / resolution[0]
        focal = (focal0.unsqueeze(1) + focal1.unsqueeze(1)) / 2

        R = camera_poses[:, :3, :3]
        t = camera_poses[:, :3, 3]
        rot_vec = matrix_to_quaternion(R)
        pose_enc = torch.cat([t, rot_vec, focal], dim=1)
        depth_cano = depths / radius

        views = dict(
            rgbs=rgbs,
            depths=depth_cano,
            pose_enc=pose_enc,
            traj_mat=camera_poses,
            intrs=Intrs,
            # This dataset provides no point tracks: zero placeholders.
            traj_3d=torch.zeros(self.num_views, self.track_num, 3),
            vis=torch.zeros(self.num_views, self.track_num),
            syn_real=torch.tensor([1]),
            metric_rel=torch.tensor([1]),
            static=torch.tensor([0]),
            data_dir=scene,
        )
        return views
if __name__ == "__main__":
    # Manual smoke test: load one clip of the dataset and dump a 3D
    # visualisation of it. Paths are machine specific and can be overridden
    # through the VKITTI_ROOT / VKITTI_VIS_DIR environment variables.
    # NOTE(review): the rest of this file imports from models.SpaTrackV2;
    # confirm models.videocam.datasets.vis3d_check still exists.
    from models.videocam.datasets.vis3d_check import vis4d

    DATA_DIR = os.environ.get("VKITTI_ROOT", "/mnt/bn/xyxdata/data/4d_data/vkitti/")
    VIS_DIR = os.environ.get(
        "VKITTI_VIS_DIR",
        "/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/vis_results/test",
    )
    dataset = Vkitti(split='train', ROOT=DATA_DIR,
                     resolution=518, aug_crop=16, num_views=16)
    rng = np.random.default_rng(seed=0)
    data_ret = dataset._get_views(0, (518, 518), rng)
    vis4d(data_ret["rgbs"], data_ret["depths"],
          data_ret["traj_mat"], data_ret["intrs"], workspace=VIS_DIR)