import torch
import numpy as np

from .dust3r_utils.batched_sampler import BatchedRandomSampler  # noqa
from .co3d_v2 import Co3d
from .GTAV import GTAV
from .arkitscenes import ARKitScenes
from .wildrgbd import WildRGBD
from .waymo import Waymo
from .scanetpp import ScanNetpp
from .blendedmvs import BlendedMVS
from .tartanair import Tartan
from .vkitti import Vkitti
from .pointodyssey import PointOdy
from .kubricL import Kubric
from .spring import Spring
from .DL3DV import DL3DV
from .ego4d import EgoData
from .replica import Replica
from .CustomVid import CustomVid


def collect_fn_custom(batch):
    """Collate a list of per-sample mappings into a dict of lists.

    Nothing is stacked: for each known field, the values of all samples are
    gathered into a plain Python list in batch order.

    Args:
        batch: List of samples. Each sample must support lookup by the string
            keys "rgbs", "depths", "pose_enc", "traj_mat", "intrs",
            "traj_3d", "vis", "syn_real", "metric_rel", "static" and
            "data_dir". (The previous docstring described tensors such as
            imgs (B, T, 3, H, W) and intrs (B, T, 3, 3); those shapes are
            not checked here -- TODO confirm against the datasets.)

    Returns:
        Dict mapping each of the keys above to the list of that field's
        values, one entry per sample.

    Raises:
        KeyError: If a sample is missing one of the expected keys.
    """
    keys = (
        "rgbs",
        "depths",
        "pose_enc",
        "traj_mat",
        "intrs",
        "traj_3d",
        "vis",
        "syn_real",
        "metric_rel",
        "static",
        "data_dir",
    )
    output = {key: [] for key in keys}
    for item in batch:
        for key in keys:
            output[key].append(item[key])
    # Return the collated dict. The earlier version returned ``batch`` here,
    # which discarded everything gathered above.
    return output


def collect_fn_with_strings(batch):
    """Collate a batch of dicts, stacking arrays/tensors and listing the rest.

    The set of keys is taken from the first sample, and the handling of each
    key is decided from the type of the first sample's value for that key.

    Args:
        batch: List of dictionaries containing the data.

    Returns:
        Dictionary with batched data. Values that are ``np.ndarray`` or
        ``torch.Tensor`` are stacked with ``torch.stack`` (arrays are first
        converted via ``torch.from_numpy``); strings, bytes and any other
        type are kept as a list. An empty ``batch`` yields ``{}``.
    """
    if not batch:
        return {}

    # Get all keys from the first item
    keys = batch[0].keys()

    # Initialize the output dictionary
    output = {}

    for key in keys:
        # Get all values for this key
        values = [item[key] for item in batch]

        # Handle different types of data
        if isinstance(values[0], (str, bytes)):
            # Keep strings as a list
            output[key] = values
        elif isinstance(values[0], (np.ndarray, torch.Tensor)):
            # Stack numerical arrays/tensors
            output[key] = torch.stack(
                [torch.from_numpy(v) if isinstance(v, np.ndarray) else v for v in values]
            )
        else:
            # For any other type, just keep as a list
            output[key] = values

    return output


def get_data_loader(dataset, batch_size, lite,
                    num_workers=8, shuffle=True, drop_last=True, pin_mem=True,
                    use_string_collect_fn=True):
    """Build a ``torch.utils.data.DataLoader`` for ``dataset``.

    Args:
        dataset: A dataset object, or a string that is evaluated as a Python
            expression to produce one.
        batch_size: Batch size given to the sampler factory and the loader.
        lite: Object exposing ``world_size`` and ``local_rank`` attributes.
        num_workers: Number of loader worker processes.
        shuffle: Whether the sampler should shuffle.
        drop_last: Forwarded to both the sampler and the loader.
        pin_mem: Forwarded to the loader as ``pin_memory``.
        use_string_collect_fn: If true, collate with
            ``collect_fn_with_strings``; otherwise with ``collect_fn_custom``.

    Returns:
        The configured ``DataLoader``.
    """
    # pytorch dataset
    if isinstance(dataset, str):
        # SECURITY: eval executes arbitrary code. Only pass strings that come
        # from trusted configuration, never from external input.
        dataset = eval(dataset)

    world_size = lite.world_size
    # NOTE(review): the local rank is used as the sampler rank. On multi-node
    # runs local ranks presumably repeat across nodes -- confirm whether a
    # global rank is intended here.
    rank = lite.local_rank

    try:
        sampler = dataset.make_sampler(batch_size, shuffle=shuffle, world_size=world_size,
                                       rank=rank, drop_last=drop_last)
    except (AttributeError, NotImplementedError):
        # not avail for this dataset
        if torch.distributed.is_initialized():
            sampler = torch.utils.data.DistributedSampler(
                dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, drop_last=drop_last
            )
        elif shuffle:
            sampler = torch.utils.data.RandomSampler(dataset)
        else:
            sampler = torch.utils.data.SequentialSampler(dataset)

    # Choose the appropriate collect_fn based on the parameter
    collate_fn = collect_fn_with_strings if use_string_collect_fn else collect_fn_custom

    data_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=drop_last,
        collate_fn=collate_fn,
    )

    return data_loader