import os.path as osp
import os
import sys
import json
import itertools
import time
from collections import deque

import torch
import tqdm
import concurrent.futures
import psutil
import io
import cv2
from PIL import Image
import numpy as np

# PerspectiveCameras is required by _get_views below; it is assumed to come
# from PyTorch3D, whose camera objects normalize_cameras and
# camera_to_pose_encoding operate on.
from pytorch3d.renderer import PerspectiveCameras

from models.SpaTrackV2.datasets.base_sfm_dataset import BaseSfMViewDataset
from models.SpaTrackV2.models.utils import (
    camera_to_pose_encoding, pose_encoding_to_camera
)
from models.SpaTrackV2.models.camera_transform import normalize_cameras
try:
    from pcache_fileio import fileio
except Exception:
    fileio = None
from models.SpaTrackV2.datasets.dataset_util import imread_cv2, npz_loader
def bytes_to_gb(num_bytes):
    # convert a byte count to gigabytes
    return num_bytes / (1024 ** 3)


def get_total_size(obj, seen=None):
    # recursively estimate the memory footprint of `obj` in bytes;
    # `seen` tracks object ids so shared references are only counted once
    size = sys.getsizeof(obj)
    if seen is None:
        seen = set()
    obj_id = id(obj)
    if obj_id in seen:
        return 0
    seen.add(obj_id)
    if isinstance(obj, dict):
        size += sum([get_total_size(v, seen) for v in obj.values()])
        size += sum([get_total_size(k, seen) for k in obj.keys()])
    elif hasattr(obj, '__dict__'):
        size += get_total_size(obj.__dict__, seen)
    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
        size += sum([get_total_size(i, seen) for i in obj])
    return size
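

# Illustrative usage (a sketch, not part of the training pipeline): measure a
# loaded sample and report its size in GB. `sample` is a hypothetical dict of
# tensors/arrays, e.g. the `views` dict returned by BlendedMVS._get_views below.
def report_sample_size(sample):
    print(f"approx. sample size: {bytes_to_gb(get_total_size(sample)):.4f} GB")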


class BlendedMVS(BaseSfMViewDataset):
    def __init__(self, mask_bg=False, scene_st=None, scene_end=None,
                 debug=False, *args, ROOT, **kwargs):
        self.ROOT = ROOT
        super().__init__(*args, **kwargs)
        assert mask_bg in (True, False, 'rand')
        self.mask_bg = mask_bg
        self.dataset_label = 'BlendedMVS'
        # load all scenes
        self.scene_list = os.listdir(self.ROOT)

    def __len__(self):
        # each scene is sampled once per view combination, matching the
        # `idx // len(self.combinations)` lookup in _get_views
        return len(self.scene_list) * len(self.combinations)
    def _get_metadatapath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.npz')

    def _get_impath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'images', f'frame{view_idx:06n}.jpg')

    def _get_depthpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'depths', f'frame{view_idx:06n}.jpg.geometric.png')

    def _get_maskpath(self, obj, instance, view_idx):
        return osp.join(self.ROOT, obj, instance, 'masks', f'frame{view_idx:06n}.png')

    def _read_depthmap(self, depthpath, input_metadata=None):
        depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
        depthmap = depthmap.astype(np.float32)
        return depthmap
    def _get_views(self, idx, resolution, rng):
        instance = osp.join(self.ROOT, self.scene_list[idx // len(self.combinations)])
        # randomly choose between the DSLR and iPhone captures of this scene,
        # falling back to whichever metadata file actually exists
        choices = ["scene_metadata_dslr.npz", "scene_metadata_iphone.npz"]
        c_idx = np.random.choice([0, 1])
        choice = choices[c_idx]
        max_try = 30
        for i in range(max_try):
            if os.path.exists(os.path.join(instance, choice)):
                break
            else:
                choice = choices[1 - c_idx]
        scale_num = np.random.uniform(1, 2)
        # load the scene metadata and sort frames by image name
        scene_meta = npz_loader(os.path.join(instance, choice))
        re_org_idx = np.argsort(scene_meta['images'])
        image_pool = scene_meta["images"][re_org_idx]
        poses_pool = scene_meta["trajectories"][re_org_idx]
        intr_pool = scene_meta["intrinsics"][re_org_idx]
        # pick a random start frame, then sample num_views frames from a
        # window that is scale_num times longer than the clip itself
        start = np.random.choice(np.arange(0, max(int(len(image_pool) - scale_num * self.num_views), 1)))
        img_idxs = sorted(
            np.random.choice(np.arange(start, int(start + scale_num * self.num_views)),
                             size=self.num_views, replace=False)
        )
        # clamp the sampled indices into the valid frame range
        last = len(image_pool) - 1
        imgs_idxs = [int(max(0, min(im_idx, last))) for im_idx in img_idxs]
        imgs_idxs = deque(imgs_idxs)
        # output: {rgbs, depths, camera_enc, principal_point, image_size}
        rgbs = None
        depths = None
        Extrs = None
        Intrs = None
        while len(imgs_idxs) > 0:  # some images (few) have zero depth
            im_idx = imgs_idxs.pop()
            impath = image_pool[im_idx]
            impath = os.path.join(instance, 'images', impath).replace("JPG", "jpg")
            depthpath = impath.replace("images", "depth").replace(".jpg", ".png")
            camera_pose = poses_pool[im_idx]
            intrinsics = intr_pool[im_idx]
            rgb_image = imread_cv2(impath)
            depthmap = imread_cv2(depthpath, cv2.IMREAD_UNCHANGED)
            depthmap = depthmap / 1000.0  # depth PNGs store millimetres
            depthmap[~np.isfinite(depthmap)] = 0  # invalid
            # crop and resize
            rgb_image, depthmap, intrinsics = self._crop_resize_if_necessary(
                rgb_image, depthmap, intrinsics, resolution, rng=rng, info=impath)
            if rgbs is None:
                rgbs = torch.from_numpy(
                    np.array(rgb_image)).permute(2, 0, 1).unsqueeze(0)
                depths = torch.from_numpy(
                    np.array(depthmap)).unsqueeze(0).unsqueeze(0)
                Extrs = torch.from_numpy(camera_pose).unsqueeze(0)
                Intrs = torch.from_numpy(intrinsics).unsqueeze(0)
            else:
                rgbs = torch.cat([rgbs, torch.from_numpy(
                    np.array(rgb_image)).permute(2, 0, 1).unsqueeze(0)], dim=0)
                depths = torch.cat([depths, torch.from_numpy(
                    np.array(depthmap)).unsqueeze(0).unsqueeze(0)], dim=0)
                Extrs = torch.cat([Extrs,
                                   torch.from_numpy(camera_pose).unsqueeze(0)], dim=0)
                Intrs = torch.cat([Intrs,
                                   torch.from_numpy(intrinsics).unsqueeze(0)], dim=0)
        # encode the camera poses: convert C2W to W2C
        camera_poses = torch.inverse(Extrs)
        focal0 = Intrs[:, 0, 0] / resolution[0]
        focal1 = Intrs[:, 1, 1] / resolution[1]
        focal = (focal0.unsqueeze(1) + focal1.unsqueeze(1)) / 2
        R = camera_poses[:, :3, :3]
        T = camera_poses[:, :3, 3]
        # 4x4 calibration matrix in the PyTorch3D screen-space convention
        K = torch.zeros((Intrs.shape[0], 4, 4))
        K[:, :2, :3] = Intrs[:, :2, :3]
        K[:, 2, 3] = K[:, 3, 2] = 1
        Camera = PerspectiveCameras(
            R=R, T=T, in_ndc=False, K=K, focal_length=focal
        )
        POSE_MODE = "W2C"
        Camera, _, scale = normalize_cameras(Camera, compute_optical=False,
                                             normalize_trans=True, max_norm=True, scale=5,
                                             first_camera=True, pose_mode=POSE_MODE)
        pose_enc = camera_to_pose_encoding(Camera,
                                           "absT_quaR_OneFL")
        views = dict(
            rgbs=rgbs,
            depths=depths / scale,  # keep depth consistent with the normalized poses
            pose_enc=pose_enc,
        )
        return views
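

# A minimal sanity check for the dict returned by _get_views (a sketch; the
# 8-dim pose encoding assumes "absT_quaR_OneFL" packs 3 translation values,
# 4 quaternion values and 1 focal length per view, which is an assumption here).
def check_views(views, num_views, resolution):
    T = num_views
    H, W = resolution
    assert views["rgbs"].shape == (T, 3, H, W)
    assert views["depths"].shape == (T, 1, H, W)
    assert views["pose_enc"].shape[-1] == 8
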
if __name__ == "__main__":
    from models.SpaTrackV2.datasets.base_sfm_dataset import view_name
    from functools import partial
    # from dust3r.viz import SceneViz, auto_cam_size
    # from dust3r.utils.image import rgb

    if fileio is not None:
        DATA_DIR = "pcache://vilabpcacheproxyi-pool.cz50c.alipay.com:39999/mnt/antsys-vilab_datasets_pcache_datasets/3DV_Foundation/BlendedMVS_processed_v3"
        bakup_dir = "/input_ssd/datasets/3DV_Foundation/co3dv2"
    else:
        DATA_DIR = "/nas3/xyx/GTAV_540/GTAV_540/GTAV_540"
        bakup_dir = "/nas3/xyx/GTAV_540/GTAV_540/GTAV_540"

    cpu_num_total = int(sys.argv[1]) if len(sys.argv) > 1 else 8
    cpu_num_per = int(sys.argv[2]) if len(sys.argv) > 2 else 8

    dataset = BlendedMVS(split='train', ROOT=DATA_DIR, resolution=518,
                         aug_crop=16, num_views=48)
    rng = np.random.default_rng(seed=0)
    for i in range(len(dataset)):
        view = dataset._get_views(i, (518, 518), rng)
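        # Print per-clip tensor shapes so the smoke test is observable
        # (illustrative addition, not part of the original loop).
        print(i, view["rgbs"].shape, view["depths"].shape, view["pose_enc"].shape)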