# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Dataloader for preprocessed Co3d_v2
# dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International
# See datasets_preprocess/preprocess_co3d.py
# --------------------------------------------------------
import os.path as osp
import os
import sys
import json
import itertools
import time
from collections import deque
import torch
import tqdm
import concurrent.futures
import psutil
import io
from models.SpaTrackV2.models.utils import matrix_to_quaternion
import cv2
from PIL import Image
import numpy as np
from models.SpaTrackV2.datasets.base_sfm_dataset import BaseSfMViewDataset
from models.SpaTrackV2.datasets.dataset_util import imread_cv2, npz_loader

# `fileio` is referenced below for pcache-backed (remote) storage but was never
# imported; guard the import so the loader falls back to plain filesystem
# access (np.load) when the module is unavailable.
try:
    import fileio
except ImportError:
    fileio = None
def bytes_to_gb(num_bytes):
    """Convert a byte count to gigabytes."""
    return num_bytes / (1024 ** 3)


def get_total_size(obj, seen=None):
    """Recursively estimate the memory footprint of `obj` in bytes.

    `seen` tracks visited object ids so shared references are only counted once.
    """
    size = sys.getsizeof(obj)
    if seen is None:
        seen = set()
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    seen.add(obj_id)
    if isinstance(obj, dict):
        size += sum(get_total_size(v, seen) for v in obj.values())
        size += sum(get_total_size(k, seen) for k in obj.keys())
    elif hasattr(obj, '__dict__'):
        size += get_total_size(obj.__dict__, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        size += sum(get_total_size(i, seen) for i in obj)
    return size
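
# A minimal usage sketch for the helpers above (values are illustrative, not
# measured): nested containers are walked recursively, so a dict of lists is
# counted in full rather than as a single pointer.
#
#   cache = {"a": list(range(1000)), "b": list(range(1000))}
#   total = get_total_size(cache)
#   print(f"{total} bytes (~{bytes_to_gb(total):.6f} GB)")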
class Co3d(BaseSfMViewDataset):
    def __init__(self, mask_bg=True, scene_st=None, scene_end=None,
                 debug=False, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg
        self.dataset_label = 'Co3d_v2'
        # load all scenes
        with open(osp.join(self.ROOT, f'selected_seqs_{self.split}.json'), 'r') as f:
            self.scenes = json.load(f)
        self.scenes = {k: v for k, v in self.scenes.items() if len(v) > 0}
        self.scenes = {(k, k2): v2 for k, v in self.scenes.items()
                       for k2, v2 in v.items()}
        if scene_st is not None:
            self.scene_list = list(self.scenes.keys())[scene_st:scene_end]
        else:
            self.scene_list = list(self.scenes.keys())
        # for each scene, we have 100 images ==> 360 degrees (so 25 frames ~= 90 degrees)
        # we prepare all combinations of 75 images from 100 images
        # self.combinations = [sorted(torch.randperm(100)[:self.num_views].tolist()) for i in range(200)]
        self.combinations = [None]
        self.invalidate = {scene: {} for scene in self.scene_list}
        # get the buffer
        self.__buffer__ = False  # self._get_pool_buffer_()

    def __len__(self):
        return len(self.scene_list) * len(self.combinations)
    def _get_metadatapath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.npz')

    def _get_impath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')

    def _get_depthpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06n}.jpg.geometric.png')

    def _get_maskpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')

    def _read_depthmap(self, depthpath, input_metadata):
        # depth is stored as a 16-bit PNG, normalized by the per-frame maximum depth
        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        try:
            depth_max = input_metadata['maximum_depth']
        except Exception:
            # fall back to a fixed scale when the metadata lacks `maximum_depth`
            depth_max = 122
        depthmap = (depthmap.astype(np.float32) / 65535) * np.nan_to_num(depth_max)
        return depthmap
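    # Worked example of the decoding above (illustrative numbers): a raw uint16
    # value of 32768 with `maximum_depth` = 10.0 m decodes to
    # (32768 / 65535) * 10.0 ≈ 5.0 m, i.e. the PNG stores depth linearly
    # rescaled to the full 16-bit range.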
    def _get_views(self, idx, resolution, rng):
        # choose a scene
        obj, instance = self.scene_list[idx // len(self.combinations)]
        image_pool = self.scenes[obj, instance]
        # sample `num_views` frames at a random temporal stride in [3, 5)
        scale_num = int(np.random.uniform(3, 5))
        start = np.random.choice(np.arange(0, max(len(image_pool) - scale_num * self.num_views, 1)))
        img_idxs = np.arange(start, start + scale_num * self.num_views, scale_num).clip(0, len(image_pool) - 1)
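        # Sketch of the sampling above with concrete (illustrative) numbers:
        # with len(image_pool) = 100, num_views = 16 and a drawn stride of 4,
        # `start` is chosen in [0, 100 - 64) and the selected indices are
        # start, start+4, ..., start+60, then clipped to the pool bounds.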
        last = len(image_pool) - 1
        if resolution not in self.invalidate[obj, instance]:  # flag invalid images
            self.invalidate[obj, instance][resolution] = [False for _ in range(len(image_pool))]
        # decide now if we mask the bg
        mask_bg = (self.mask_bg is True) or (self.mask_bg == 'rand' and rng.choice(2))
        imgs_idxs = [int(max(0, min(im_idx, last))) for im_idx in img_idxs]
        imgs_idxs = deque(imgs_idxs)
        # accumulators for the per-view outputs: rgbs, depths, extrinsics, intrinsics
        rgbs = None
        depths = None
        Extrs = None
        Intrs = None
        while len(imgs_idxs) > 0:  # some images (few) have zero depth
            im_idx = imgs_idxs.pop()
            if self.invalidate[obj, instance][resolution][im_idx]:
                # search for a valid image
                random_direction = 2 * rng.choice(2) - 1
                for offset in range(1, len(image_pool)):
                    tentative_im_idx = (im_idx + (random_direction * offset)) % len(image_pool)
                    if not self.invalidate[obj, instance][resolution][tentative_im_idx]:
                        im_idx = tentative_im_idx
                        break
            view_idx = image_pool[im_idx]
            impath = self._get_impath(obj, instance, view_idx)
            depthpath = self._get_depthpath(obj, instance, view_idx)
            # load camera params
            metadata_path = self._get_metadatapath(obj, instance, view_idx)
            if fileio is not None and "pcache" in metadata_path:
                # remote (pcache) storage: go through the npz loader
                input_metadata = npz_loader(metadata_path)
            else:
                # local storage
                input_metadata = np.load(metadata_path)
            camera_pose = input_metadata['camera_pose'].astype(np.float32)
            intrinsics = input_metadata['camera_intrinsics'].astype(np.float32)
            # load image and depth
            rgb_image = imread_cv2(impath)
            depthmap = self._read_depthmap(depthpath, input_metadata)
            if mask_bg:
                # load object mask
                maskpath = self._get_maskpath(obj, instance, view_idx)
                maskmap = imread_cv2(maskpath, cv2.IMREAD_UNCHANGED).astype(np.float32)
                maskmap = (maskmap / 255.0) > 0.1
                # zero out background depth so only the object contributes
                depthmap *= maskmap

            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

            num_valid = (depthmap > 0.0).sum()
            if num_valid == 0:
                # problem, invalidate image and retry
                self.invalidate[obj, instance][resolution][im_idx] = True
                imgs_idxs.append(im_idx)
                continue
            if rgbs is None:
                rgbs = torch.from_numpy(
                    np.array(rgb_image)).permute(2, 0, 1).unsqueeze(0)
                depths = torch.from_numpy(
                    np.array(depthmap)).unsqueeze(0).unsqueeze(0)
                Extrs = torch.from_numpy(camera_pose).unsqueeze(0)
                Intrs = torch.from_numpy(intrinsics).unsqueeze(0)
            else:
                rgbs = torch.cat([rgbs, torch.from_numpy(
                    np.array(rgb_image)).permute(2, 0, 1).unsqueeze(0)], dim=0)
                depths = torch.cat([depths, torch.from_numpy(
                    np.array(depthmap)).unsqueeze(0).unsqueeze(0)], dim=0)
                Extrs = torch.cat([Extrs,
                                   torch.from_numpy(camera_pose).unsqueeze(0)], dim=0)
                Intrs = torch.cat([Intrs,
                                   torch.from_numpy(intrinsics).unsqueeze(0)], dim=0)
        # rgbs = rgbs[:, [2, 1, 0], ...]
        # encode the camera poses (C2W)
        camera_poses = Extrs
        # normalize focal lengths by the image size and average fx/fy
        focal0 = Intrs[:, 0, 0] / resolution[0]
        focal1 = Intrs[:, 1, 1] / resolution[0]
        focal = (focal0.unsqueeze(1) + focal1.unsqueeze(1)) / 2
        # express all poses relative to the first frame
        camera_poses = torch.inverse(camera_poses[:1]) @ camera_poses
        T_center = camera_poses[:, :3, 3].mean(dim=0)
        # rescale translations so the largest camera distance from frame 0 is 1
        Radius = camera_poses[:, :3, 3].norm(dim=1).max()
        if Radius < 1e-2:
            Radius = 1
        camera_poses[:, :3, 3] = camera_poses[:, :3, 3] / Radius
        R = camera_poses[:, :3, :3]
        t = camera_poses[:, :3, 3]
        rot_vec = matrix_to_quaternion(R)
        pose_enc = torch.cat([t, rot_vec, focal], dim=1)
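        # Layout of `pose_enc` per frame (a 3 + 4 + 1 = 8-dim vector):
        #   [tx, ty, tz,  qw, qx, qy, qz,  f]
        # where (tx, ty, tz) is the Radius-normalized translation, the
        # quaternion encodes rotation relative to frame 0 (assuming the
        # real-part-first convention of matrix_to_quaternion), and f is the
        # resolution-normalized focal length.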
        # depth_cano = Radius*focal[:,:,None,None] / depths.clamp(min=1e-6)
        # bring depths into the same normalized scale as the translations
        depth_cano = depths / Radius
        # note: `x == torch.nan` is always False; use torch.isnan to zero out NaNs
        depth_cano[torch.isnan(depth_cano)] = 0
        # placeholder 3D tracks / visibility (Co3d provides no track annotations)
        traj_3d = torch.zeros(self.num_views, self.track_num, 3)
        vis = torch.zeros(self.num_views, self.track_num)
        syn_real = torch.tensor([0])
        metric_rel = torch.tensor([0])
        data_dir = obj
        views = dict(
            rgbs=rgbs,
            depths=depth_cano,
            pose_enc=pose_enc,
            traj_mat=camera_poses,
            intrs=Intrs,
            traj_3d=traj_3d,
            vis=vis,
            syn_real=syn_real,
            metric_rel=metric_rel,
            data_dir=data_dir,
        )
        return views
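
# A minimal sketch of how a `pose_enc` row could be decoded back into a 4x4
# C2W matrix, assuming the real-part-first (w, x, y, z) quaternion convention
# noted above. `decode_pose_enc` is a hypothetical helper added here for
# illustration only; it is not part of the original pipeline.
def decode_pose_enc(pose_enc):
    t, q = pose_enc[:, :3], pose_enc[:, 3:7]
    w, x, y, z = q.unbind(dim=-1)
    # standard unit-quaternion -> rotation-matrix formula
    R = torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ], dim=-1).reshape(-1, 3, 3)
    mat = torch.eye(4).repeat(pose_enc.shape[0], 1, 1)
    mat[:, :3, :3] = R
    mat[:, :3, 3] = t
    return mat
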
if __name__ == "__main__":
    from models.videocam.datasets.base_sfm_dataset import view_name
    from functools import partial
    # from dust3r.viz import SceneViz, auto_cam_size
    # from dust3r.utils.image import rgb

    if fileio is not None:
        DATA_DIR = "pcache://vilabpcacheproxyi-pool.cz50c.alipay.com:39999/mnt/antsys-vilab_datasets_pcache_datasets/3DV_Foundation/co3dv2"
        bakup_dir = "/input_ssd/datasets/3DV_Foundation/co3dv2"
    else:
        DATA_DIR = "/nas3/xyx/CO3D_processed"
        bakup_dir = "/nas3/xyx/3DV_Foundation"

    cpu_num_total = int(sys.argv[1]) if len(sys.argv) > 1 else 8
    cpu_num_per = int(sys.argv[2]) if len(sys.argv) > 2 else 8

    dataset = Co3d(split='train', ROOT=DATA_DIR, mask_bg=True,
                   resolution=384, aug_crop=16, num_views=16)
    rng = np.random.default_rng(seed=0)
    data_ret = dataset._get_views(253, (518, 518), rng)

    # prefetch_data = partial(dataset.prefetch_data, cpu_num_avg=cpu_num_per, bakup_dir=bakup_dir)
    # scene_idx_list = [i for i in range(len(dataset))]
    # if sys.argv[3] == "true":
    #     with concurrent.futures.ProcessPoolExecutor(max_workers=cpu_num_total//20) as executor:
    #         pool_buffer = executor.map(prefetch_data, scene_idx_list)  #, total=len(dataset))
    # elif sys.argv[3] == "false":
    #     for idx in tqdm.tqdm(range(len(dataset))):
    #         prefetch_data(idx)

    from models.videocam.datasets.vis3d_check import vis4d
    os.system("rm -rf /home/xyx/home/codes/SpaTrackerV2/vis_results/test")
    vis4d(data_ret["rgbs"], data_ret["depths"],
          data_ret["traj_mat"], data_ret["intrs"],
          workspace="/home/xyx/home/codes/SpaTrackerV2/vis_results/test")
    import pdb; pdb.set_trace()
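
# Example invocation of the smoke test above (assuming this file is saved as
# co3d.py; the data and workspace paths are the author's and must be adapted,
# and both CPU-count arguments are optional):
#   python co3d.py 8 8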