Spaces:
Running
on
Zero
Running
on
Zero
| import os.path as osp | |
| import os | |
| import sys | |
| import json | |
| import itertools | |
| import time | |
| from collections import deque | |
| import torch | |
| import tqdm | |
| import concurrent.futures | |
| import psutil | |
| import io | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| from models.SpaTrackV2.models.utils import matrix_to_quaternion | |
| from models.SpaTrackV2.datasets.base_sfm_dataset import BaseSfMViewDataset | |
| from models.SpaTrackV2.models.utils import ( | |
| camera_to_pose_encoding, pose_encoding_to_camera | |
| ) | |
| from models.SpaTrackV2.models.camera_transform import normalize_cameras | |
| from models.SpaTrackV2.datasets.tartan_utils.traj_tf import ned2cam | |
| from models.SpaTrackV2.datasets.tartan_utils.cam_tf import pos_quats2SE_matrices | |
| from models.SpaTrackV2.utils.visualizer import Visualizer | |
| try: | |
| from pcache_fileio import fileio | |
| except Exception: | |
| fileio = None | |
| try: | |
| import fsspec | |
| #NOTE: stable version (not public) | |
| PCACHE_HOST = "vilabpcacheproxyi-pool.cz50c.alipay.com" | |
| PCACHE_PORT = 39999 | |
| pcache_kwargs = {"host": PCACHE_HOST, "port": PCACHE_PORT} | |
| pcache_fs = fsspec.filesystem("pcache", pcache_kwargs=pcache_kwargs) | |
| except Exception: | |
| fsspec = None | |
| from models.SpaTrackV2.datasets.dataset_util import ( | |
| imread_cv2, npz_loader, read_video,npy_loader,resize_crop_video | |
| ) | |
| import re | |
class Kubric(BaseSfMViewDataset):
    """Kubric scenes with RGB frames, depth maps, cameras and 3D point tracks.

    Scenes are the sub-directories of ``ROOT`` whose name is exactly four
    digits. Each scene ``<dddd>`` is read as::

        <dddd>/frames/<iii>.png          RGB frame (zero-padded frame index)
        <dddd>/depths/<iii>.npy          depth map
        <dddd>/<dddd>.npy                object array holding a dict with the
                                         keys 'coords', 'coords_depth',
                                         'visibility'
        <dddd>/<dddd>_with_rank.npz      'shared_intrinsics', 'extrinsics'
    """

    # Frames per sequence assumed when sampling clips.
    NUM_FRAMES = 120
    # Temporal stride between sampled frames. The previous expression
    # ``int(np.random.uniform(2, 3))`` always truncated to 2.
    # NOTE(review): if a random stride in {2, 3} was intended, sample it with
    # ``rng.choice([2, 3])`` in ``_sample_frame_indices``.
    FRAME_STRIDE = 2

    def __init__(self, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        # Sorted so that a given index maps to the same scene on every
        # machine and run (os.listdir order is arbitrary).
        self.scene_list = [
            osp.join(self.ROOT, name)
            for name in sorted(os.listdir(self.ROOT))
            if re.fullmatch(r"\d{4}", name)
        ]

    def __len__(self):
        return len(self.scene_list)

    def _get_views(self, idx, resolution, rng):
        """Sample a strided clip from scene ``idx`` and build one training sample.

        Args:
            idx: index into ``self.scene_list``.
            resolution: target resolution; only ``resolution[0]`` is used, both
                as the size passed to ``resize_crop_video`` and to normalize
                the focal lengths.
            rng: NumPy random generator driving the clip start and the track
                sampling. ``None`` falls back to an unseeded
                ``np.random.default_rng()``.

        Returns:
            dict with ``rgbs``, ``depths`` (NaNs set to 0), ``pose_enc``
            (translation, quaternion, normalized focal per frame), ``traj_mat``
            (camera-to-world poses relative to the first frame), ``intrs``,
            ``traj_3d`` and ``vis`` (``self.track_num // 2`` tracks each), the
            flag tensors ``syn_real``, ``metric_rel``, ``static`` and
            ``data_dir`` (the scene path).
        """
        if rng is None:
            rng = np.random.default_rng()
        scene = self.scene_list[idx]
        n_tracks = self.track_num // 2

        img_idxs = self._sample_frame_indices(rng)
        rgbs, depths = self._load_frames(scene, img_idxs)
        intrinsics, extrinsics, traj_3d, vis = self._load_tracks(
            scene, img_idxs, n_tracks, rng)

        # Crop and resize; the helper returns the adjusted intrinsics and
        # (presumably re-projected) tracks alongside the frames.
        rgbs, depths, Intrs, traj_3d = resize_crop_video(
            rgbs, depths, intrinsics, resolution[0], traj_3d)

        # Points whose (u, v) fall outside the cropped image become invisible.
        # An all-zero traj_3d (e.g. the placeholder from _load_tracks) is
        # skipped.
        if traj_3d.sum() != 0:
            u, v = traj_3d[..., 0], traj_3d[..., 1]
            H_, W_ = rgbs.shape[-2:]
            vis = vis & (u > 0) & (u < W_) & (v > 0) & (v < H_)
        traj_3d, vis = self._keep_visible_tracks(traj_3d, vis, n_tracks)

        # Embed the per-frame 3x4 extrinsics in 4x4 matrices and invert them.
        extrs_h = torch.eye(4).repeat(self.num_views, 1, 1)
        extrs_h[:, :3, :] = torch.from_numpy(extrinsics)
        camera_poses = torch.inverse(extrs_h)  # NOTE: C2W
        # Express every pose relative to the first frame.
        camera_poses = torch.inverse(camera_poses[:1]) @ camera_poses

        # Mean of fx and fy, each normalized by the target resolution -> (T, 1).
        focal = ((Intrs[:, 0, 0] / resolution[0]
                  + Intrs[:, 1, 1] / resolution[0]) / 2).unsqueeze(1)

        # Scene-scale normalization is disabled: the radius is fixed to 1, so
        # translations and depths are left unscaled.
        radius = 1
        camera_poses[:, :3, 3] = camera_poses[:, :3, 3] / radius
        pose_enc = torch.cat([camera_poses[:, :3, 3],
                              matrix_to_quaternion(camera_poses[:, :3, :3]),
                              focal], dim=1)
        depth_cano = depths / radius
        traj_3d[..., 2] = traj_3d[..., 2] / radius
        # NaN compares unequal to everything (including torch.nan), so an
        # equality test never matches; use isnan to zero invalid depths.
        depth_cano[torch.isnan(depth_cano)] = 0

        return dict(
            rgbs=rgbs,
            depths=depth_cano,
            pose_enc=pose_enc,
            traj_mat=camera_poses,
            intrs=Intrs,
            traj_3d=traj_3d,
            vis=vis,
            syn_real=torch.tensor([1]),
            metric_rel=torch.tensor([1]),
            static=torch.tensor([0]),
            data_dir=scene,
        )

    def _sample_frame_indices(self, rng):
        """Return ``self.num_views`` frame indices spaced by FRAME_STRIDE from a random start."""
        stride = self.FRAME_STRIDE
        n_starts = max(self.NUM_FRAMES - stride * self.num_views, 1)
        start = int(rng.choice(n_starts))
        idxs = np.arange(start, start + stride * self.num_views, stride)
        # Clips longer than the sequence repeat the last frame.
        return idxs.clip(0, self.NUM_FRAMES - 1).astype(int)

    @staticmethod
    def _load_frames(scene, img_idxs):
        """Read the RGB frames and depth maps of ``img_idxs`` and stack them on axis 0.

        Returns ``(rgbs, depths)`` as NumPy arrays with one leading entry per
        frame.
        """
        frame_dir = osp.join(scene, "frames")
        depth_dir = osp.join(scene, "depths")
        # NOTE(review): the previous "convert BGR to RGB" step was a no-op, so
        # frames are used exactly as imread_cv2 returns them (assumed RGB).
        rgbs = np.stack(
            [imread_cv2(osp.join(frame_dir, f"{int(i):03d}.png")) for i in img_idxs],
            axis=0)
        depths = np.stack(
            [npy_loader(osp.join(depth_dir, f"{int(i):03d}.npy")) for i in img_idxs],
            axis=0)
        return rgbs, depths

    @staticmethod
    def _load_tracks(scene, img_idxs, n_tracks, rng):
        """Load the cameras and sample ``n_tracks`` point tracks for ``img_idxs``.

        Returns:
            ``(intrinsics, extrinsics, traj_3d, vis)``. The first two are NumPy
            arrays with one matrix per frame (the scene's shared intrinsics
            repeated). ``traj_3d`` is a torch tensor of shape
            (T, n_tracks, 3) holding (u, v, z). ``vis`` is a torch tensor of
            shape (T, n_tracks), the negation of the stored 'visibility' array.
            If the track array has an unexpected shape, zero tracks that are
            never visible are returned instead.
        """
        name = osp.basename(scene)
        # allow_pickle is needed because the metadata file stores a Python
        # object; only load trusted dataset files.
        meta = np.load(osp.join(scene, f"{name}.npy"), allow_pickle=True).item()
        # Context manager closes the archive; only the two needed arrays are read.
        with np.load(osp.join(scene, f"{name}_with_rank.npz"), allow_pickle=True) as rank:
            shared_intrinsics = rank["shared_intrinsics"]
            extrinsics = rank["extrinsics"][img_idxs]
        intrinsics = shared_intrinsics[None].repeat(len(img_idxs), 0)

        n_total = meta["coords"].shape[0]
        # Sample with replacement only when the scene has too few tracks
        # (sampling without replacement would raise).
        t_idx = rng.choice(n_total, n_tracks, replace=n_total < n_tracks)
        traj_2d = meta["coords"][t_idx][:, img_idxs].transpose(1, 0, 2)  # (T, N, 2)
        traj_depth = meta["coords_depth"][t_idx][:, img_idxs].transpose(1, 0)[..., None]

        # Divide the stored depth by |K^-1 [u, v, 1]|, the length of the
        # back-projected pixel ray.
        homo = np.concatenate([traj_2d, np.ones_like(traj_depth)], axis=-1)
        ray_norm = np.linalg.norm(
            np.einsum("tij,tnj->tni", np.linalg.inv(intrinsics), homo),
            axis=-1)[..., None]
        traj_3d = np.concatenate([traj_2d, traj_depth / ray_norm], axis=-1)
        vis = ~meta["visibility"][t_idx][:, img_idxs].transpose(1, 0)

        if traj_3d.shape[-1] != 3:
            print("The shape of traj_3d is not correct")
            n_frames = len(img_idxs)
            traj_3d = np.zeros((n_frames, n_tracks, 3))
            # bool, matching the dtype of the normal path
            vis = np.zeros((n_frames, n_tracks), dtype=bool)
        return intrinsics, extrinsics, torch.from_numpy(traj_3d), torch.from_numpy(vis)

    @staticmethod
    def _keep_visible_tracks(traj_3d, vis, n_tracks):
        """Drop tracks never visible in the clip, then pad back to ``n_tracks``.

        Padding repeats the first surviving track. If no track survives, zero
        tracks that are never visible are appended instead, so the output
        always has ``n_tracks`` columns.
        """
        keep = vis.sum(dim=0) > 0
        traj_3d = traj_3d[:, keep]
        vis = vis[:, keep]
        missing = n_tracks - traj_3d.shape[1]
        if missing > 0:
            if traj_3d.shape[1] > 0:
                pad_traj = traj_3d[:, :1].repeat(1, missing, 1)
                pad_vis = vis[:, :1].repeat(1, missing)
            else:
                pad_traj = traj_3d.new_zeros((traj_3d.shape[0], missing, traj_3d.shape[2]))
                pad_vis = vis.new_zeros((vis.shape[0], missing))
            traj_3d = torch.cat([traj_3d, pad_traj], dim=1)
            vis = torch.cat([vis, pad_vis], dim=1)
        return traj_3d, vis
if __name__ == "__main__":
    # Smoke test: load one sample, then run the 2D-track visualizer and the
    # 4D viewer on it. The leftover unconditional pdb breakpoint and the unused
    # imports (view_name, partial) were removed so the script runs unattended.
    DATA_DIR = "/mnt/bn/xyxdata/data/4d_data/kubric"
    dataset = Kubric(split='train', ROOT=DATA_DIR,
                     resolution=518, aug_crop=16, num_views=16)
    rng = np.random.default_rng(seed=0)
    data_ret = dataset._get_views(30, (384, 384), rng)

    # check the 2d tracking vis
    viser = Visualizer(save_dir=".", grayscale=True,
                       fps=10, pad_value=0, tracks_leave_trace=5)
    viser.visualize(video=data_ret["rgbs"][None],
                    tracks=data_ret["traj_3d"][None, ..., :2],
                    visibility=data_ret["vis"][None], filename="test")

    # check the 4d visualization
    from models.videocam.datasets.vis3d_check import vis4d
    vis4d(data_ret["rgbs"], data_ret["depths"],
          data_ret["traj_mat"], data_ret["intrs"], track3d=data_ret["traj_3d"],
          workspace="/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/viser_result/test")