Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (C) 2024-present Naver Corporation. All rights reserved. | |
| # Licensed under CC BY-NC-SA 4.0 (non-commercial use only). | |
| # | |
| # -------------------------------------------------------- | |
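| # Dataloader for preprocessed ARKitScenes (https://github.com/apple/ARKitScenes) | |
| # NOTE: the dataset link and preprocessing-script reference below describe CO3D, not ARKitScenes; | |
| # they appear to be inherited from the CO3D loader: | |
| # CO3D dataset at https://github.com/facebookresearch/co3d - Creative Commons Attribution-NonCommercial 4.0 International | |
| # See datasets_preprocess/preprocess_co3d.py | |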
| # -------------------------------------------------------- | |
| import os.path as osp | |
| import os | |
| import sys | |
| import json | |
| import itertools | |
| import time | |
| from collections import deque | |
| import torch | |
| import tqdm | |
| import concurrent.futures | |
| import psutil | |
| import io | |
| import cv2 | |
| from PIL import Image | |
| import numpy as np | |
| from models.SpaTrackV2.datasets.base_sfm_dataset import BaseSfMViewDataset | |
| from models.SpaTrackV2.models.utils import ( | |
| camera_to_pose_encoding, pose_encoding_to_camera | |
| ) | |
| from models.SpaTrackV2.models.camera_transform import normalize_cameras | |
| try: | |
| from pcache_fileio import fileio | |
| except Exception: | |
| fileio = None | |
| try: | |
| import fsspec | |
| #NOTE: stable version (not public) | |
| PCACHE_HOST = "vilabpcacheproxyi-pool.cz50c.alipay.com" | |
| PCACHE_PORT = 39999 | |
| pcache_kwargs = {"host": PCACHE_HOST, "port": PCACHE_PORT} | |
| pcache_fs = fsspec.filesystem("pcache", pcache_kwargs=pcache_kwargs) | |
| except Exception: | |
| fsspec = None | |
| from models.SpaTrackV2.datasets.dataset_util import imread_cv2, npz_loader | |
def bytes_to_gb(bytes):
    """Convert a byte count to gibibytes (units of 1024**3 bytes).

    Uses true division, so the result is a float even for integer input.
    """
    bytes_per_gib = 1 << 30  # == 1024 ** 3
    return bytes / bytes_per_gib
def get_total_size(obj, seen=None):
    """Recursively estimate the memory footprint of ``obj`` in bytes.

    Each distinct object (by ``id``) is counted at most once; ``seen`` is the
    set of ids already counted and is updated in place, so a caller can share
    it across calls. The traversal follows:

    * dict values, then dict keys;
    * an object's ``__dict__`` when it has one;
    * otherwise the items of any iterable except ``str``/``bytes``/``bytearray``.

    Sizes come from ``sys.getsizeof``, which is shallow, hence the recursion.
    Note that iterating a one-shot iterator consumes it.
    """
    if seen is None:
        seen = set()
    footprint = sys.getsizeof(obj)
    marker = id(obj)
    if marker in seen:
        # Already counted through another reference.
        return 0
    seen.add(marker)

    if isinstance(obj, dict):
        children = itertools.chain(obj.values(), obj.keys())
    elif hasattr(obj, '__dict__'):
        children = (obj.__dict__,)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        children = obj
    else:
        children = ()

    for child in children:
        footprint += get_total_size(child, seen)
    return footprint
class ARKitScenes(BaseSfMViewDataset):
    """Preprocessed ARKitScenes, served as one multi-frame clip per scene.

    Files read under ``ROOT/<split>/``:

    * ``scene_list.json`` -- loaded with ``json.load`` and indexed as a list
      of scene names;
    * ``<scene>/scene_metadata.npz`` -- ``images`` (frame file names),
      ``intrinsics`` (one 6-value row per frame; entries 2-5 are read as
      ``f_x, f_y, c_x, c_y``) and ``trajectories`` (one pose matrix per frame);
    * ``<scene>/lowres_depth/<frame>`` and ``<scene>/vga_wide/<frame>.jpg``.
    """

    def __init__(self, mask_bg=False, scene_st=None, scene_end=None,
                 debug=False, *args, ROOT, **kwargs):
        """Load the scene list of ``self.split``.

        Args:
            mask_bg: ``True``, ``False`` or ``'rand'``; stored as
                ``self.mask_bg`` (not read elsewhere in this class).
            scene_st, scene_end, debug: accepted but currently unused.
            ROOT: dataset root directory (keyword-only).
            *args, **kwargs: forwarded to ``BaseSfMViewDataset``; ``split``
                and ``num_views`` are presumably consumed there -- confirm.
        """
        # Assigned before the base constructor runs, presumably so that it can
        # read self.ROOT -- confirm against BaseSfMViewDataset.
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg
        self.dataset_label = 'ARKitScenes'

        scene_list_path = osp.join(self.ROOT, self.split, 'scene_list.json')
        with open(scene_list_path, 'r', encoding='utf-8') as f:
            self.scene_list = json.load(f)
        # A single placeholder "combination" per scene: the frame window is
        # sampled at random in _get_views, so len(self) == number of scenes.
        self.combinations = [None]
        self.invalidate = {scene: {} for scene in self.scene_list}

    def __len__(self):
        return len(self.scene_list) * len(self.combinations)

    # CO3D-style path helpers (ROOT/<obj>/<instance>/...). They are not used by
    # _get_views, which builds ARKitScenes paths directly.
    def _get_metadatapath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.npz')

    def _get_impath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')

    def _get_depthpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06n}.jpg.geometric.png')

    def _get_maskpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')

    def _read_depthmap(self, depthpath, input_metadata=None):
        """Read a depth image with ``cv2.IMREAD_UNCHANGED`` and cast it to float32.

        ``input_metadata`` is accepted for interface compatibility and ignored.
        """
        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        return depthmap.astype(np.float32)

    def _load_scene_metadata(self, scene_meta_path):
        """Return ``(image_pool, intr_pool, poses_pool)`` lists for one scene.

        ``pcache://`` paths go through ``npz_loader`` when the pcache client is
        available; any other path is read into memory and parsed by ``np.load``.
        """
        def unpack(metadata):
            return (metadata["images"].tolist(),
                    metadata["intrinsics"].tolist(),
                    metadata["trajectories"].tolist())

        if fileio is not None and "pcache://" in scene_meta_path:
            return unpack(npz_loader(scene_meta_path))

        with open(scene_meta_path, 'rb') as f:
            file_content = f.read()
        # np.load on an .npz returns a lazily-reading NpzFile, so the arrays
        # must be extracted before the in-memory buffer is closed.
        with io.BytesIO(file_content) as bio, np.load(bio) as metadata:
            return unpack(metadata)

    def _load_frame(self, instance, img_name, pose, intr_row, resolution, rng):
        """Load one RGB-D frame of scene ``instance`` and crop/resize it.

        Returns ``(rgb, depth, extrinsic, intrinsic)`` tensors:
        - ``rgb`` is channel-first (the HWC image permuted to CHW);
        - ``depth`` has a leading singleton channel;
        - ``extrinsic`` is the pose as float32;
        - ``intrinsic`` is the matrix as returned by ``_crop_resize_if_necessary``.
        """
        scene_dir = osp.join(self.ROOT, self.split, instance)
        depthpath = osp.join(scene_dir, "lowres_depth", img_name)
        # The colour frame has the depth file's name with a .jpg extension.
        # Only the file name is rewritten, so ROOT/split cannot be corrupted.
        impath = osp.join(scene_dir, "vga_wide", img_name.replace(".png", ".jpg"))
        camera_pose = np.array(pose).astype(np.float32)

        f_x, f_y, c_x, c_y = intr_row[2:]
        # NOTE(review): only f_x is rescaled by 2 * 810 / 1920; the origin of
        # this factor is not visible here -- confirm against the preprocessing.
        intrinsics = np.float32([[f_x * 2 * 810 / 1920, 0, c_x], [0, f_y, c_y], [0, 0, 1]])

        rgb_image = imread_cv2(impath)
        # Stored depth is divided by 1000 (presumably millimetres -> metres).
        depthmap = self._read_depthmap(depthpath).astype(np.float32) / 1000
        depthmap[~np.isfinite(depthmap)] = 0  # mark invalid depth with 0
        rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
            rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)

        return (torch.from_numpy(np.array(rgb_image)).permute(2, 0, 1),
                torch.from_numpy(np.array(depthmap)).unsqueeze(0),
                torch.from_numpy(camera_pose),
                torch.from_numpy(intrinsics))

    def _get_views(self, idx, resolution, rng):
        """Sample a clip of ``self.num_views`` frames from one scene.

        Args:
            idx: item index; selects ``self.scene_list[idx // len(self.combinations)]``.
            resolution: target size passed to ``_crop_resize_if_necessary``.
                ``resolution[0]`` normalizes f_x and ``resolution[1]``
                normalizes f_y.
            rng: random generator forwarded to ``_crop_resize_if_necessary``.

        Returns:
            dict with:
            - ``rgbs`` of shape (N, C, H, W);
            - ``depths`` of shape (N, 1, H, W), divided by the scale returned
              by ``normalize_cameras``;
            - ``pose_enc`` from ``camera_to_pose_encoding(..., "absT_quaR_OneFL")``.
        """
        # Frames are drawn from a window of span_scale * num_views indices,
        # with span_scale uniform in [1, 2).
        span_scale = np.random.uniform(1, 2)

        instance = self.scene_list[idx // len(self.combinations)]
        scene_meta_path = osp.join(self.ROOT, self.split, instance, "scene_metadata.npz")
        image_pool, intr_pool, poses_pool = self._load_scene_metadata(scene_meta_path)

        # NOTE(review): window and frame sampling use the global np.random
        # state, not ``rng``; only the crop/resize step receives ``rng``.
        window = span_scale * self.num_views
        start = np.random.choice(np.arange(0, max(len(image_pool) - window, 1)))
        img_idxs = sorted(
            np.random.choice(np.arange(start, start + window),
                             size=self.num_views, replace=False)
        )
        # Clamp into [0, last]. For scenes shorter than the window, several
        # indices can collapse onto the last frame.
        last = len(image_pool) - 1
        frame_ids = [int(max(0, min(im_idx, last))) for im_idx in img_idxs]

        rgb_list, depth_list, extr_list, intr_list = [], [], [], []
        # Frames are loaded from the highest sampled index down, so the clip
        # is stored in descending frame order (existing behaviour, kept as-is).
        for im_idx in reversed(frame_ids):
            rgb, depth, extr, intr = self._load_frame(
                instance, image_pool[im_idx], poses_pool[im_idx],
                intr_pool[im_idx], resolution, rng)
            rgb_list.append(rgb)
            depth_list.append(depth)
            extr_list.append(extr)
            intr_list.append(intr)
        # One stack per output instead of re-concatenating after every frame.
        rgbs = torch.stack(rgb_list)
        depths = torch.stack(depth_list)
        Extrs = torch.stack(extr_list)
        Intrs = torch.stack(intr_list)

        # Trajectories are treated as C2W; invert to W2C.
        camera_poses = torch.inverse(Extrs)
        # Focal lengths normalized by the target resolution, averaged over x/y -> (N, 1).
        focal0 = Intrs[:, 0, 0] / resolution[0]
        focal1 = Intrs[:, 1, 1] / resolution[1]
        focal = (focal0.unsqueeze(1) + focal1.unsqueeze(1)) / 2
        R = camera_poses[:, :3, :3]
        T = camera_poses[:, :3, 3]
        # Embed each intrinsic matrix in a 4x4: rows 0-1 copy its top two rows,
        # and entries [2, 3] and [3, 2] are set to 1.
        K = torch.zeros((Intrs.shape[0], 4, 4))
        K[:, :2, :3] = Intrs[:, :2, :3]
        K[:, 2, 3] = K[:, 3, 2] = 1
        # NOTE(review): PerspectiveCameras is not imported in the visible
        # source (presumably pytorch3d's PerspectiveCameras); this line raises
        # NameError unless it is brought into scope.
        Camera = PerspectiveCameras(
            R=R, T=T, in_ndc=False, K=K, focal_length=focal
        )
        POSE_MODE = "W2C"
        Camera, _, scale = normalize_cameras(Camera, compute_optical=False,
                                             normalize_trans=True, max_norm=True, scale=5,
                                             first_camera=True, pose_mode=POSE_MODE)
        pose_enc = camera_to_pose_encoding(Camera, "absT_quaR_OneFL")
        # Depths are divided by the same factor returned by normalize_cameras,
        # presumably to keep them consistent with the normalized poses.
        return dict(
            rgbs=rgbs,
            depths=depths / scale,
            pose_enc=pose_enc,
        )
if __name__ == "__main__":
    # Smoke test: build the training split and sample one clip from scene 0.
    if fileio is not None:
        # NOTE(review): this pcache path points at a GTAV_540 dataset, not
        # ARKitScenes -- confirm the intended location.
        DATA_DIR = "pcache://vilabpcacheproxyi-pool.cz50c.alipay.com:39999/mnt/antsys-vilab_datasets_pcache_datasets/GTAV_540/GTAV_540"
    else:
        DATA_DIR = "/nas3/xyx/arkitscenes_processed"
    dataset = ARKitScenes(split='Training', ROOT=DATA_DIR, resolution=518, aug_crop=16, num_views=48)
    rng = np.random.default_rng(seed=0)
    views = dataset._get_views(0, (518, 518), rng)
    # Report output shapes so the smoke test shows something useful.
    print({key: tuple(getattr(value, "shape", ())) for key, value in views.items()})