import collections import collections.abc import re import warnings from abc import abstractmethod from functools import cached_property from typing import Dict, List, Optional, Sequence, Tuple, TypeVar import numpy as np import PIL.Image import roma import torch import torchvision.transforms.v2 import transformers import yaml from .common_spear import ( Configurable, FlowInput, Normalization, ResizeMode, RoboticsControlPlan, RoboticsFlowInput, RoboticsInput, RoboticsOutput, RoboticsTarget, RotationFormat, expand_dims, is_quaternion, is_rotmat, is_rotmat_3x3, is_rotmat_9, quaternion_half_cover, rotmat_as_3x3, rotmat_as_9, ) from .configuration_spear import ( ControlDataIOConfig, ImageSizeConfig, PaliGemmaProcessorConfig, ) class VLMProcessor(Configurable): @abstractmethod def preprocess_inputs( self, chat: List[str], images: Dict[str, List[PIL.Image.Image]] ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: ... @property @abstractmethod def tokenizer(self) -> transformers.PreTrainedTokenizerBase: pass @property @abstractmethod def image_sizes(self) -> Dict[str, ImageSizeConfig]: pass class EmptyTokenizer(Configurable): """ Takes the LLM hidden states from `llm_layer_indices` and concatenates them to produce the desired result. Includes the hidden states for the image tokens. """ def __init__(self, config, tokenizer: transformers.PreTrainedTokenizerBase) -> None: super().__init__(config) self.tokenizer = tokenizer def __call__(self, *_) -> str: return "" def np_unique( data: np.ndarray, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """ Compute unique elements in data and corresponding indices. np.unique returns the values in a sorted order, even if the source is not sorted. Thus, if you simply run np.unique on unsorted data, the indices you will get will be invalid. """ (_, indices, inverse) = np.unique(data, return_index=True, return_inverse=True) (_, indices_of_first_occurence, inverse_indices, counts) = np.unique( indices[inverse], return_index=True, return_inverse=True, return_counts=True ) unique_ids = data[indices_of_first_occurence] return unique_ids, indices_of_first_occurence, inverse_indices, counts def euler_to_rotmat(angles: torch.Tensor) -> torch.Tensor: """ Args: angles: Euler angles in radians in the format 'xyz', shape [..., 3] Returns: torch.Tensor of shape [..., 3, 3] containing rotation matrices """ return roma.euler_to_rotmat(convention="xyz", angles=angles, degrees=False) def euler_to_unit_quaternion(angles: torch.Tensor) -> torch.Tensor: """ Args: angles: Euler angles in radians in the format 'xyz', shape [..., 3] Returns: torch.Tensor of shape [..., 4] containing unit quaternions """ return roma.euler_to_unitquat(convention="xyz", angles=angles, degrees=False, normalize=True) def normalize_quaternion(quaternion: torch.Tensor, eps: float = 1e-08) -> torch.Tensor: """ Args: quaternion: Unnormalized quaternion, torch.Tensor of shape [..., 4] eps: Small constant to prevent division by zero Returns: torch.Tensor of shape [..., 4] of unit quaternions """ return quaternion / (quaternion.norm(dim=-1, keepdim=True).detach() + eps) def quaternion_to_euler(quaternion: torch.Tensor) -> torch.Tensor: """ Args: quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized Returns: torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3) """ unit_quat = normalize_quaternion(quaternion) rotmat = roma.unitquat_to_euler(convention="xyz", quat=unit_quat, as_tuple=False, degrees=False) return rotmat def quaternion_to_rotmat(quaternion: torch.Tensor) -> torch.Tensor: """ Args: quaternion: torch.Tensor of shape [..., 4]; Can be non-normalized Returns: torch.Tensor of shape [..., 3, 3] containing rotation matrices in SO(3) """ unit_quat = normalize_quaternion(quaternion) rotmat = roma.unitquat_to_rotmat(unit_quat) return rotmat def rotmat_to_unit_quaternion(rotmat: torch.Tensor) -> torch.Tensor: """ Args: rotmat: Batch of rotation matrices, shape [..., 3, 3] Returns: Batch of unit quaternions, shape [..., 4] """ rotmat = rotmat_as_3x3(rotmat) return roma.rotmat_to_unitquat(rotmat) def rotmat_to_euler(rotmat: torch.Tensor) -> torch.Tensor: """ Args: rotmat: Batch of rotation matrices, shape [..., 3, 3] Returns: Batch of Euler angles in radiant, shape [..., 3] """ rotmat = rotmat_as_3x3(rotmat) return roma.rotmat_to_euler(convention="xyz", rotmat=rotmat, as_tuple=False, degrees=False) def symmetric_orthogonalization(x: torch.Tensor) -> torch.Tensor: """ Maps 9D input vectors onto SO(3) via symmetric orthogonalization. - Let SVD(M) = U \Sigma V^T - Returned value is SVD+(M) = U diag(1, 1, det(UV^T)) V^T - det(UV^T) ensures that det(SVD+(M)) = 1 - The return value is a rotation matrix (ortonormal) with the least-squares distance to M Args: x: Input matrices, not necessarily orthonormal, shape [..., 9] or [..., 3, 3] Returns: torch.Tensor with the same shape as x, where each inner 3x3 matrix is in SO(3) """ with warnings.catch_warnings(): warnings.filterwarnings( "ignore", message="In CPU autocast, but the target dtype is not supported. Disabling autocast.", ) with torch.autocast(device_type=x.device.type, dtype=torch.float32): matrices = x.view(-1, 3, 3) matrices = matrices.to(dtype=torch.float32) (u, s, v) = torch.svd(matrices) vt = torch.transpose(v, 1, 2) det = torch.det(torch.matmul(u, vt)).view(-1, 1, 1) diag_vt = torch.cat((vt[:, :2, :], vt[:, -1:, :] * det), dim=1) result = torch.matmul(u, diag_vt) result = result.view(*x.shape) result = result.to(dtype=x.dtype) return result def is_rotmat_orthonormal( rotmat: torch.Tensor, epsilon: float = 1e-06, reduction: str = "none" ) -> torch.Tensor | bool: """ Check if a rotation matrix is orthonormal or not. Args: rotmat: torch.Tensor of shape [..., 3, 3] or [..., 9] epsilon: Tolerance for numerical comparisons. Bigger values allow for more freedom. Generally, anything smaller than 1e-6 might incorrectly detect some otrhonormal matrices as not reduction: 'none' - returns torch.Tensor of bools with the same batch shape 'all' - returns a bool, True is ALL matrices in the batch are orthonormal Returns: torch.Tensor with the same batch shape or bool """ assert is_rotmat(rotmat) rotmat = rotmat_as_3x3(rotmat.to(dtype=torch.float32)) is_orthonormal = roma.is_orthonormal_matrix(rotmat, epsilon=epsilon) if reduction == "none": return is_orthonormal if reduction == "all": return bool(torch.all(is_orthonormal).item()) raise ValueError(f"Unknown reduction mode {reduction}") def is_orthonormal_rotmat(rotmat: torch.Tensor) -> bool: """ Checks if the tensor shape matches that of a rotmat. If the last dimensions of shape are 3x3, also checks if the data is a valid rotmat. This is to avoid a possible clash with euler angles when accidentally `rotmat.shape[-2:] == [3, 3]` """ return ( is_rotmat_9(rotmat) or is_rotmat_3x3(rotmat) and is_rotmat_orthonormal(rotmat, epsilon=0.01, reduction="all") ) def is_euler(euler: torch.Tensor) -> bool: return euler.shape[-1] == 3 and not is_orthonormal_rotmat(euler) def normalize_rotation(rotation: torch.Tensor) -> torch.Tensor: if is_quaternion(rotation): return normalize_quaternion(rotation) if is_euler(rotation): return rotation if is_rotmat(rotation): is_flat = is_rotmat_9(rotation) rotation = rotmat_as_3x3(rotation) if is_flat else rotation rotmat = roma.special_gramschmidt(rotation) rotmat = rotmat_as_9(rotmat) if is_flat else rotmat return rotmat raise ValueError(f"Unknown rotation format: {rotation.shape}") def rotation_format_from_tensor(rotation) -> RotationFormat: if is_quaternion(rotation): return RotationFormat.QUATERNION if is_orthonormal_rotmat(rotation): return RotationFormat.ROTMAT if is_euler(rotation): return RotationFormat.EULER raise ValueError(f"Tensor shape {rotation.shape} is not a valid rotation format") def is_unit_quaternion( quaternion: torch.Tensor, epsilon: float = 1e-08, reduction: str = "none" ) -> torch.Tensor | bool: """ Check if a quternion is normalized or not. Args: quaternion: torch.Tensor of shape [..., 4] tolerance: Tolerance for numerical comparisons reduction: 'none' - returns torch.Tensor of bools with the same batch shape 'all' - returns a bool, True if ALL quaternions in the batch are normalized Returns: torch.Tensor with the same batch shape or bool """ assert is_quaternion(quaternion) is_norm = torch.isclose( quaternion.norm(dim=-1, keepdim=True), torch.tensor(1.0, dtype=quaternion.dtype, device=quaternion.device), atol=epsilon, ) if reduction == "none": return is_norm if reduction == "all": return bool(torch.all(is_norm).item()) raise ValueError(f"Unknown reduction mode {reduction}") def convert_rotation( rotation: torch.Tensor | np.ndarray, output_format: RotationFormat, autonorm: bool = True, half_cover: bool = True, ) -> torch.Tensor | np.ndarray: is_np = isinstance(rotation, np.ndarray) if is_np: rotation = torch.from_numpy(rotation) if is_quaternion(rotation): if autonorm and not is_unit_quaternion(rotation, reduction="all"): rotation = normalize_quaternion(rotation) if output_format == RotationFormat.QUATERNION: output = rotation elif output_format == RotationFormat.ROTMAT: output = rotmat_as_9(quaternion_to_rotmat(rotation)) elif output_format == RotationFormat.EULER: output = quaternion_to_euler(rotation) else: raise NotImplementedError(f"Unsupported rotation format: {output_format}") elif is_orthonormal_rotmat(rotation): if autonorm and not is_rotmat_orthonormal(rotation, epsilon=0.01, reduction="all"): rotation = symmetric_orthogonalization(rotation) if output_format == RotationFormat.QUATERNION: output = rotmat_to_unit_quaternion(rotation) elif output_format == RotationFormat.ROTMAT: output = rotmat_as_9(rotation) elif output_format == RotationFormat.EULER: output = rotmat_to_euler(rotation) else: raise NotImplementedError(f"Unsupported rotation format: {output_format}") elif is_euler(rotation): if output_format == RotationFormat.QUATERNION: output = euler_to_unit_quaternion(rotation) elif output_format == RotationFormat.ROTMAT: output = rotmat_as_9(euler_to_rotmat(rotation)) elif output_format == RotationFormat.EULER: output = rotation else: raise NotImplementedError(f"Unsupported rotation format: {output_format}") else: raise ValueError(f"Unknown rotation encoding with shape {rotation.shape}") if output_format == RotationFormat.QUATERNION and half_cover: output = quaternion_half_cover(output) if is_np: output = output.numpy() return output def delta_to_relative_rotations(rotation_sequence: torch.Tensor) -> torch.Tensor: """ Transform a sequence of rotation representations encoded w.r.t. the PREVIOUS rotation frame in the sequence to the 0-th element preceding the sequence Ex: `rotation_sequence` contains the rotations: R_01, R_12, R_23, R_34, where R0 is the base frame, implicitly encoded in R_01 and R_10 converts from R0 frame to R1 frame Output: R_01, R_02, R_03, R_04 Args: rotation_sequence: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4], containing either rotation matrices (R_01, R_12, R_23, R_34, ...) or quaternions Returns: torch.Tensor of shape [..., S, 9], [..., S, 3, 3] or [..., S, 4] containing transformed rotations (R_01, R_02, R_03, R_04, ...) TODO: Can you make it work without for loop """ assert rotation_sequence.ndim >= 3, rotation_sequence.shape rotation_format: RotationFormat = rotation_format_from_tensor(rotation_sequence) rotation_sequence = convert_rotation(rotation_sequence, RotationFormat.QUATERNION) batch_dims = np.arange(rotation_sequence.ndim - 2) delta_rotations = torch.cat( [rotation_sequence[..., :1, :]] + [ roma.quat_composition(rotation_sequence[..., :i, :].permute(-2, *batch_dims, -1).unsqueeze(-2)) for i in range(2, rotation_sequence.shape[-2] + 1) ], dim=-2, ) delta_rotations = convert_rotation(delta_rotations, rotation_format) return delta_rotations def assert_np_hwc_or_hw_image(image: np.ndarray | PIL.Image.Image) -> np.ndarray: """Make sure image is of type np.ndarray and HWC format""" if isinstance(image, PIL.Image.Image): image = np.asarray(image) assert isinstance(image, np.ndarray), type(image) assert image.ndim in [2, 3], image.shape if image.ndim == 3: assert image.shape[-1] <= 4, image.shape return image def hw_from_image(image: PIL.Image.Image | np.ndarray) -> tuple[int, int]: if isinstance(image, np.ndarray): (height, width) = image.shape[:2] else: (width, height) = image.size return height, width def pad_image( image: PIL.Image.Image | np.ndarray, target_size: dict[str, int], pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0, ) -> PIL.Image.Image | np.ndarray: """Pad image adding a symmetric border around the height/width.""" assert isinstance(image, (PIL.Image.Image, np.ndarray)), type(image) (height, width) = hw_from_image(image) (target_width, target_height) = (target_size["width"], target_size["height"]) if width == target_width and height == target_height: return image assert target_width >= width, f"Can't pad image of width {width} to {target_width}" assert target_height >= height, f"Can't pad image of height {height} to {target_height}" (horizontal_pad, vertical_pad) = ( int((target_width - width) / 2), int((target_height - height) / 2), ) if isinstance(image, np.ndarray): padding = ((vertical_pad, vertical_pad), (horizontal_pad, horizontal_pad)) + ((0, 0),) * ( image.ndim - 2 ) image = np.pad(image, padding, mode="constant", constant_values=pad_value) else: padding = (horizontal_pad, vertical_pad, horizontal_pad, vertical_pad) image = torchvision.transforms.v2.functional.pad( image, padding=padding, fill=pad_value, padding_mode="constant" ) return image def pad_image_to_ratio( image: PIL.Image.Image | np.ndarray, target_wh_ratio: float, pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0, ) -> PIL.Image.Image | np.ndarray: """Pad image to a target aspect ratio.""" (height, width) = hw_from_image(image) wh_ratio = width / height if target_wh_ratio >= wh_ratio: pad_size = {"width": round(height * target_wh_ratio), "height": height} else: pad_size = {"width": width, "height": round(width / target_wh_ratio)} image = pad_image(image, target_size=pad_size, pad_value=pad_value) return image def crop_image( image: np.ndarray | PIL.Image.Image, start_height: int, start_width: int, target_height: int, target_width: int, ) -> np.ndarray | PIL.Image.Image: np_image = assert_np_hwc_or_hw_image(image) (height, width) = hw_from_image(image) assert target_width <= width, f"Can't crop image of width {width} to {target_width}" assert target_height <= height, f"Can't crop image of width {height} to {target_height}" (start_height, start_width) = (round(start_height), round(start_width)) (target_height, target_width) = (round(target_height), round(target_width)) np_image = np_image[ start_height : start_height + target_height, start_width : start_width + target_width, ..., ] image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image return image def crop_image_center( image: np.ndarray | PIL.Image.Image, target_size: dict[str, int] ) -> np.ndarray | PIL.Image.Image: np_image = assert_np_hwc_or_hw_image(image) (height, width) = np_image.shape[:2] (target_height, target_width) = (target_size["height"], target_size["width"]) assert target_width <= width, f"Can't crop image of width {width} to {target_width}" assert target_height <= height, f"Can't crop image of width {height} to {target_height}" top = (height - target_height) // 2 left = (width - target_width) // 2 np_image = crop_image(np_image, top, left, target_height, target_width) image = PIL.Image.fromarray(np_image) if isinstance(image, PIL.Image.Image) else np_image return image def crop_image_to_ratio( image: PIL.Image.Image | np.ndarray, target_wh_ratio: float ) -> PIL.Image.Image | np.ndarray: """Pad image to a target aspect ratio.""" (height, width) = hw_from_image(image) wh_ratio = width / height if target_wh_ratio >= wh_ratio: crop_size = {"width": width, "height": round(width / target_wh_ratio)} else: crop_size = {"width": round(height * target_wh_ratio), "height": height} image = crop_image_center(image, target_size=crop_size) return image def crop_and_pad_image_to_ratio( image: PIL.Image.Image | np.ndarray, target_wh_ratio: float, mode: ResizeMode | str, pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0, ) -> PIL.Image.Image | np.ndarray: """ Crop and pad an image to a target size depending on the mode. It's expected that the source image and target size have different aspect ratios. Args: image: The image to crop and pad. target_size: The target size to crop and pad the image to. mode: The mode to use for cropping and padding. """ (height, width) = hw_from_image(image) wh_ratio = width / height if np.isclose(wh_ratio, target_wh_ratio, rtol=0.01, atol=0.0001): return image if mode == ResizeMode.SMART: aspect_ratio = max(width, height) / min(width, height) target_ratio = max(target_wh_ratio, 1 / target_wh_ratio) if aspect_ratio == 1: if target_ratio >= 4 / 3 - 0.01: crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4 image = crop_image_to_ratio(image, crop_wh_ratio) else: pass elif aspect_ratio <= 4 / 3 + 0.01: if wh_ratio >= 1.0 != (target_wh_ratio >= 1.0): image = crop_image_to_ratio(image, 1.0) elif wh_ratio >= 1.0 != (target_wh_ratio >= 1.0): image = crop_image_to_ratio(image, 1.0) elif target_ratio >= 4 / 3 + 0.01: pass else: crop_wh_ratio = 4 / 3 if target_wh_ratio >= 1.0 else 3 / 4 image = crop_image_to_ratio(image, crop_wh_ratio) image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value) elif mode == ResizeMode.PAD: image = pad_image_to_ratio(image, target_wh_ratio, pad_value=pad_value) elif mode == ResizeMode.CROP: image = crop_image_to_ratio(image, target_wh_ratio) else: raise ValueError(f"Mode {mode} not supported") return image def is_single_channel_image(image: np.ndarray | PIL.Image.Image) -> bool: if isinstance(image, PIL.Image.Image): return image.mode in [ "1", "L", "LA", "La", "P", "PA", "F", "I", "I;16", "I;16L", "I;16B", "I;16N", ] if isinstance(image, np.ndarray): return image.ndim == 2 or image.ndim == 3 and image.shape[2] == 1 raise ValueError(f"Unsupported image type: {type(image)}") def is_binary_mask(image: np.ndarray | PIL.Image.Image) -> bool: image = np.asarray(image) return image.dtype in [np.uint8, np.bool_] and np.max(image) == 1 def resize_image( image: PIL.Image.Image | np.ndarray, target_size: dict[str, int], mode: ResizeMode | str, resample: PIL.Image.Resampling | str = "auto", pad_value: tuple[int, int, int] | tuple[float, float, float] | int | float = 0, ) -> PIL.Image.Image | np.ndarray: (target_width, target_height) = (target_size["width"], target_size["height"]) (height, width) = hw_from_image(image) if height == target_height and width == target_width: return image if resample == "auto": if is_single_channel_image(image): resample = PIL.Image.Resampling.BILINEAR else: resample = PIL.Image.Resampling.LANCZOS else: assert isinstance(resample, PIL.Image.Resampling), resample if is_single_channel_image(image) and resample not in [ PIL.Image.Resampling.BILINEAR, PIL.Image.Resampling.BICUBIC, ]: raise ValueError( f"Single channel images must be resized with bilinear or bicubic, but got {resample}" ) if is_bin_mask := is_binary_mask(image): image = np.asarray(image).astype(np.uint8) * 255 if mode == ResizeMode.SMART: image = crop_and_pad_image_to_ratio( image, target_wh_ratio=target_width / target_height, mode=mode, pad_value=pad_value, ) pil_image = PIL.Image.fromarray(image) if isinstance(image, np.ndarray) else image if mode in [ResizeMode.NAIVE, ResizeMode.SMART]: pil_image = pil_image.resize((target_width, target_height), resample=resample) else: raise NotImplementedError(f"Mode {mode} not supported") image = np.asarray(pil_image) if isinstance(image, np.ndarray) else pil_image if is_bin_mask: image = image.astype(np.uint8) > 127 return image def is_global_norm( norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list], ) -> bool: """Return true if norm is NONE or global for all datasets""" return norm == Normalization.NONE or isinstance(norm, collections.abc.Mapping) def is_mean_norm( norm: Normalization | Dict[str, torch.Tensor | np.ndarray | tuple | list], ) -> bool: """Return true if norm is based on mean and std""" return ( norm == Normalization.MEAN or isinstance(norm, collections.abc.Mapping) and set(norm.keys()) == {"mean", "std"} ) def _broadcast_shapes( value: torch.Tensor, low: torch.Tensor, high: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: """ Broadcast shapes for normalization: Args: value: torch.Tensor of shape [..., num_components]. The entire shape might be: - [num_components]: `value` has no batch dimension - [num_datasets, num_components]: `value` contains entries *aligned* with the dataset bounds contained in `low` and `high` - [num_datasets, ..., num_components]: `value` contains entries *aligned* with the dataset bounds contained in `low` and `high` - [..., num_components]: `value` contains multiple dimensions. In this case, `low` and `high` must be for a single dataset, i.e. `num_datasets = 1` low: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `low` contains normalization bounds for a single dataset high: torch.Tensor, shape [num_datasets, num_components], where `num_datasets` can be 1 when `high` contains normalization bounds for a single dataset Returns: Tuple of torch.Tensors (low, high), where `low` and `high` have the same number of dimensions as `value` """ assert low.ndim == high.ndim == 2, f"{low.shape} != {high.shape} or ndim != 2" assert value.shape[-1] == low.shape[-1] == high.shape[-1], f"{value.shape} != {low.shape} / {high.shape}" if value.ndim == low.ndim == high.ndim: return low, high if value.ndim < low.ndim: assert low.ndim == high.ndim == 2, f"{low.shape}, {high.shape}" assert low.shape[0] == high.shape[0] == 1, f"{low.shape}, {high.shape}" (low, high) = (low.view(-1), high.view(-1)) return low, high if low.shape[0] == high.shape[0] == 1: low = expand_dims(low.view(-1), ndim=value.ndim, order=[-1, 1]) high = expand_dims(high.view(-1), ndim=value.ndim, order=[-1, 1]) else: assert value.shape[0] == low.shape[0] == high.shape[0], f"{value.shape} != {low.shape} / {high.shape}" low = expand_dims(low, ndim=value.ndim, order=[1, -1, 1]) high = expand_dims(high, ndim=value.ndim, order=[1, -1, 1]) return low, high def unnormalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: (mean, std) = _broadcast_shapes(value, mean, std) (mean, std) = (mean.to(device=value.device), std.to(device=value.device)) return value * (std + 1e-08) + mean def unnormalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor: (low, high) = _broadcast_shapes(value, low, high) (low, high) = (low.to(device=value.device), high.to(device=value.device)) return 0.5 * (value + 1) * (high - low) + low def normalize_gripper_by_bounds( value: torch.Tensor, low: torch.Tensor, high: torch.Tensor, binary: bool = True ) -> torch.Tensor: """ If binary, normalize to [0, 1], otherwise normalize to [-1, 1] """ (low, high) = _broadcast_shapes(value, low, high) (low, high) = (low.to(device=value.device), high.to(device=value.device)) if binary: return torch.clamp((value - low) / torch.clamp(high - low, min=1e-08), min=0.0, max=1.0) return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0) def normalize_by_moments(value: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor: (mean, std) = _broadcast_shapes(value, mean, std) (mean, std) = (mean.to(device=value.device), std.to(device=value.device)) return (value - mean) / (std + 1e-08) def normalize_by_bounds(value: torch.Tensor, low: torch.Tensor, high: torch.Tensor) -> torch.Tensor: (low, high) = _broadcast_shapes(value, low, high) (low, high) = (low.to(device=value.device), high.to(device=value.device)) return torch.clamp(2 * (value - low) / torch.clamp(high - low, min=1e-08) - 1, min=-1.0, max=1.0) def invert_gripper(gripper: np.ndarray, low: float, high: float) -> np.ndarray: if low < 0.0: return np.clip(-gripper, low, high) return high - np.clip(gripper, low, high) GRIPPER_BOUNDS = { "bridge": (0.0, 1.0), "bridge_orig": (0.0, 1.0), "droid": (0.0, 1.0), "roboset": (0.0, 1.0), } def preprocess_gripper_observation( gripper: np.ndarray, dataset_name: str | np.ndarray, binary: bool = True ) -> np.ndarray: """ Preprocess gripper observation depending on dataset. Input is the raw gripper observation from the dataset or from the robot and output is normalized continuous value. - if `binary`, output is in [0, 1], with 0 = closed and 1 = open. - otherwise, output is in [-1, 1], with -1 = closed and 1 = open. Dataset-specific gripper observations: bridge: continuous; ~[0=closed; 1=open] bridge_orig: continuous; ~[0=closed; 1=open] droid: continuous; [0=open, 1=closed] roboset: continuous; [0=open, 1=closed] """ if isinstance(dataset_name, np.ndarray): assert np.unique(dataset_name).size == 1, dataset_name dataset_name = str(dataset_name[0]) if dataset_name in [ "droid", "roboset", ]: (low, high) = GRIPPER_BOUNDS[dataset_name] gripper = normalize_gripper_by_bounds( torch.from_numpy(invert_gripper(gripper, low=low, high=high)), low=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][0], dtype=torch.float32), high=torch.full(gripper.shape, GRIPPER_BOUNDS[dataset_name][1], dtype=torch.float32), binary=binary, ).numpy() elif dataset_name in [ "bridge", "bridge_orig", ]: (low, high) = GRIPPER_BOUNDS[dataset_name] gripper = normalize_gripper_by_bounds( torch.from_numpy(gripper), low=torch.full(gripper.shape, low, dtype=torch.float32), high=torch.full(gripper.shape, high, dtype=torch.float32), binary=binary, ).numpy() else: raise NotImplementedError(f"Unknown dataset: {dataset_name}") return gripper def rotation_norm_bounds( rotation_norm: Normalization, rotation_format: RotationFormat, stats: Dict[str, Dict[str, Dict[str, List[float]]]], dataset_names: List[str], ) -> Dict[str, Dict[str, torch.Tensor]]: if rotation_format == RotationFormat.EULER and rotation_norm != Normalization.NONE: if rotation_norm == Normalization.BOUNDS: results = { dataset_name: { "low": torch.tensor(dataset_stats["euler"]["min"]), "high": torch.tensor(dataset_stats["euler"]["max"]), } for (dataset_name, dataset_stats) in stats.items() } elif rotation_norm == Normalization.BOUNDS_Q99: results = { dataset_name: { "low": torch.tensor(dataset_stats["euler"]["q01"]), "high": torch.tensor(dataset_stats["euler"]["q99"]), } for (dataset_name, dataset_stats) in stats.items() } else: raise NotImplementedError(f"Normalization type {rotation_norm} not yet implemented") else: assert rotation_norm == Normalization.NONE, rotation_norm if rotation_format == RotationFormat.EULER: rotation_size = 3 elif rotation_format == RotationFormat.QUATERNION: rotation_size = 4 else: rotation_size = 9 results = { dataset_name: { "low": -1 * torch.ones(rotation_size, dtype=torch.float32), "high": 1 * torch.ones(rotation_size, dtype=torch.float32), } for dataset_name in dataset_names } return results def translation_norm_bounds( translation_norm: Normalization | tuple, stats: Dict[str, Dict[str, Dict[str, List[float]]]], dataset_names: List[str], ) -> Dict[str, Dict[str, torch.Tensor]]: if isinstance(translation_norm, (Normalization, str)) and translation_norm != Normalization.NONE: if translation_norm == Normalization.BOUNDS: results = { dataset_name: { "low": torch.tensor(dataset_stats["translation"]["min"]), "high": torch.tensor(dataset_stats["translation"]["max"]), } for (dataset_name, dataset_stats) in stats.items() } elif translation_norm == Normalization.BOUNDS_Q99: results = { dataset_name: { "low": torch.tensor(dataset_stats["translation"]["q01"]), "high": torch.tensor(dataset_stats["translation"]["q99"]), } for (dataset_name, dataset_stats) in stats.items() } elif translation_norm == Normalization.MEAN: results = { dataset_name: { "mean": torch.tensor(dataset_stats["translation"]["mean"]), "std": torch.tensor(dataset_stats["translation"]["std"]), } for (dataset_name, dataset_stats) in stats.items() } else: raise NotImplementedError(f"Normalization type {translation_norm} not yet implemented") elif isinstance(translation_norm, Normalization) and translation_norm == Normalization.NONE: results = { dataset_name: { "low": -1 * torch.ones(3, dtype=torch.float32), "high": 1 * torch.ones(3, dtype=torch.float32), } for dataset_name in dataset_names } else: assert isinstance(translation_norm, collections.abc.Mapping), type(translation_norm) assert all((len(value) == 3 for value in translation_norm.values())), translation_norm assert set(translation_norm.keys()) in ( {"low", "high"}, {"mean", "std"}, ), translation_norm results = { dataset_name: { key: torch.tensor(value, dtype=torch.float32) for (key, value) in translation_norm.items() } for dataset_name in dataset_names } return results VLAMProcessorConfigT = TypeVar("VLAMProcessorConfigT") class VLAMProcessor(Configurable): def __init__(self, config: VLAMProcessorConfigT, vlm_processor: VLMProcessor): super().__init__(config) self.vlm_processor = vlm_processor self.control_tokenizer = EmptyTokenizer( config=self.config.control_tokenizer_config, tokenizer=self.tokenizer ) self.norm_bounds: Dict[str, Dict[str, Dict[str, torch.Tensor]]] = { "obs_translation": self.obs_translation_norm_bounds, "obs_rotation": self.obs_rotation_norm_bounds, "translation": self.translation_norm_bounds, "rotation": self.rotation_norm_bounds, "joints": self.joints_norm_bounds, } @property def tokenizer(self) -> transformers.PreTrainedTokenizerBase: return self.vlm_processor.tokenizer @property def image_sizes(self) -> Dict[str, ImageSizeConfig]: return self.vlm_processor.image_sizes @property def camera_names(self) -> List[str]: return list(self.vlm_processor.image_sizes.keys()) @property def control_io_config(self) -> ControlDataIOConfig: return self.config.control_io_config @cached_property def rotation_components(self) -> int: if self.config.rotation_format == RotationFormat.EULER: return 3 if self.config.rotation_format == RotationFormat.QUATERNION: return 4 if self.config.rotation_format == RotationFormat.ROTMAT: return 9 raise NotImplementedError(self.config.rotation_format) @abstractmethod def policy_control_plan_from_model_target( self, target: RoboticsTarget, dataset_name: np.ndarray ) -> RoboticsControlPlan: pass @abstractmethod def policy_control_plan_from_model_output( self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor, ) -> RoboticsControlPlan: pass def resize_image( self, camera_name: str, image: PIL.Image.Image | np.ndarray ) -> PIL.Image.Image | np.ndarray: return resize_image( image, target_size={ "width": self.image_sizes[camera_name].width, "height": self.image_sizes[camera_name].height, }, mode=self.config.image_resize, resample=PIL.Image.Resampling.LANCZOS, ) def preprocess_inputs( self, chat: List[str], images: Dict[str, PIL.Image.Image | List[PIL.Image.Image]], ee_pose_translation: np.ndarray, ee_pose_rotation: np.ndarray, gripper: np.ndarray, joints: np.ndarray, dataset_name: np.ndarray, inference_mode: bool, control_target: Optional[RoboticsTarget] = None, ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: """ Preprocess the inputs for a single example Args: instruction: Language instruction images: History of input images with increasing timestamps ee_pose_translation: np.ndarray, shape [..., num_past_scalars, 3] ee_pose_rotation: np.ndarray, shape [..., num_past_scalars, 3 | 4 | 9] joints: np.ndarray, shape [..., num_past_scalars, <= 7] dataset_name: 1D np.ndarray inference_mode: If True, prepare the input for inference (e.g. don't include target any tokens in the input if relevant). If control_target is available, it should still be preprocessed for test dataset comparison control_target: RoboticsTarget, each component of shape [..., num_control_steps, num_control_components]. Provided only when available, usually during training and dataset test Returns: Dict containing torch.Tensor with inputs """ del control_target del inference_mode inputs = self.vlm_processor.preprocess_inputs(chat=chat, images=images) images: Dict[str, torch.Tensor] = inputs["images"] input_ids: torch.Tensor = inputs["input_ids"][..., : self.tokenizer.model_max_length] target_text_tokens_ids: torch.Tensor = inputs["target_ids"][..., : self.tokenizer.model_max_length] attn_mask = torch.ones(input_ids.shape, dtype=torch.bool) ee_pose_translation = torch.tensor(ee_pose_translation, dtype=torch.float32) ee_pose_rotation = torch.tensor(ee_pose_rotation, dtype=torch.float32) ee_pose_rotation = convert_rotation(ee_pose_rotation, self.config.rotation_format, autonorm=True) gripper = preprocess_gripper_observation(gripper, dataset_name) gripper = torch.tensor(gripper, dtype=torch.float32) ee_pose_translation = self.normalize( ee_pose_translation, dataset_name=dataset_name, key="obs_translation" ) ee_pose_rotation = self.normalize(ee_pose_rotation, dataset_name=dataset_name, key="obs_rotation") joints = torch.tensor(joints, dtype=torch.float32) if joints.shape[-1] < 7: missing_size = 7 - joints.shape[-1] joints = torch.cat([joints, torch.zeros([*joints.shape[:-1], missing_size])], dim=-1) joints = self.normalize(joints, dataset_name=dataset_name, key="joints") outputs = { "images": images, "input_ids": input_ids, "target_text_tokens_ids": target_text_tokens_ids, "attn_mask": attn_mask, "ee_pose_translation": ee_pose_translation, "ee_pose_rotation": ee_pose_rotation, "gripper": gripper, "joints": joints, "control_tokens_ids": None, "target_control_tokens_ids": None, } return outputs def create_input( self, chat: List[str], images: Dict[str, List[PIL.Image.Image]], ee_pose_translation: np.ndarray, ee_pose_rotation: np.ndarray, gripper: np.ndarray, joints: np.ndarray, dataset_name: np.ndarray, inference_mode: bool, control_target: Optional[RoboticsTarget] = None, ) -> RoboticsInput: inputs = self.preprocess_inputs( chat=chat, images=images, ee_pose_translation=ee_pose_translation, ee_pose_rotation=ee_pose_rotation, gripper=gripper, joints=joints, dataset_name=dataset_name, inference_mode=inference_mode, control_target=control_target, ) inputs.pop("target_text_tokens_ids") inputs.pop("target_control_tokens_ids") return RoboticsInput(**inputs) def normalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: if is_mean_norm(getattr(self.config, f"{key}_norm")): (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) output = normalize_by_moments(value, mean=mean, std=std) else: (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) output = normalize_by_bounds(value, low=low, high=high) return output def unnormalize(self, value: torch.Tensor, dataset_name: np.ndarray, key: str) -> torch.Tensor: if is_mean_norm(getattr(self.config, f"{key}_norm")): (mean, std) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) output = unnormalize_by_moments(value, mean=mean, std=std) else: (low, high) = self._norm_bounds_from_dataset_name(dataset_name, component_key=key) output = unnormalize_by_bounds(value, low=low, high=high) return output def _norm_bounds_from_dataset_name( self, dataset_name: np.ndarray, component_key: str ) -> Tuple[torch.Tensor, torch.Tensor]: """ Create an array of normalization bounds corresponding to dataset names Args: dataset_name: Array of shape [B] of dataset names for which to fetch the low and high normalization bounds. Note the values can be repeating component_key: str. One of 'action', 'translation', 'rotation'. Indicates for which control to compute the normalization bounds Returns: Tuple of low and high bounds or norm and std, each of shape [B, -1] """ norm = getattr(self.config, f"{component_key}_norm") if is_mean_norm(norm): (stats_key_1, stats_key_2) = ("mean", "std") else: (stats_key_1, stats_key_2) = ("low", "high") if component_key == "joints": if not isinstance(norm, collections.abc.Mapping): raise NotImplementedError() stats = { key: torch.from_numpy(np.tile(np.reshape(value, [1, -1]), [len(dataset_name), 1])) for (key, value) in self.joints_norm_bounds["ANY"].items() } return tuple(stats.values()) component_size = list(list(self.norm_bounds[component_key].values())[0].values())[0].shape[-1] if self.dataset_names == ["ANY"]: stats_1 = self.norm_bounds[component_key]["ANY"][stats_key_1] stats_2 = self.norm_bounds[component_key]["ANY"][stats_key_2] stats_1 = np.repeat(np.expand_dims(stats_1, axis=0), len(dataset_name), axis=0) stats_2 = np.repeat(np.expand_dims(stats_2, axis=0), len(dataset_name), axis=0) else: (unique_names, _, inverse_indices, _) = np_unique(dataset_name) stats_1 = np.zeros([len(unique_names), component_size], dtype=np.float32) stats_2 = np.zeros([len(unique_names), component_size], dtype=np.float32) for i, ds_name in enumerate(unique_names): stats_1[i] = self.norm_bounds[component_key][ds_name][stats_key_1].numpy() stats_2[i] = self.norm_bounds[component_key][ds_name][stats_key_2].numpy() stats_1 = stats_1[inverse_indices] stats_2 = stats_2[inverse_indices] return torch.from_numpy(stats_1), torch.from_numpy(stats_2) @cached_property def obs_rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: return rotation_norm_bounds( rotation_norm=self.config.obs_rotation_norm, rotation_format=self.config.rotation_format, stats=self._observation_stats, dataset_names=self.dataset_names, ) @cached_property def obs_translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: return translation_norm_bounds( translation_norm=self.config.obs_translation_norm, stats=self._observation_stats, dataset_names=self.dataset_names, ) @cached_property def rotation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: return rotation_norm_bounds( rotation_norm=self.config.rotation_norm, rotation_format=self.config.rotation_format, stats=self._control_stats, dataset_names=self.dataset_names, ) @cached_property def translation_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: return translation_norm_bounds( translation_norm=self.config.translation_norm, stats=self._control_stats, dataset_names=self.dataset_names, ) @cached_property def joints_norm_bounds(self) -> Dict[str, Dict[str, torch.Tensor]]: """ NOTE: - Joint values across all joints and all datasets vary in the range [-2pi; 2pi] - The effective range of a single joint is in practice one of [-2pi; 0], [-pi; pi], [0; 2pi] - It's possible to shift all ranges to [-pi; pi], but it requires careful handling for each joint """ low = torch.tensor(self.config.joints_norm["low"], dtype=torch.float32) high = torch.tensor(self.config.joints_norm["high"], dtype=torch.float32) results = {"ANY": {"low": low, "high": high}} return results @cached_property def _observation_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]: return { "bridge": { "euler": { "max": [3.141592653589793, 1.570796251296997, 3.141204357147217], "mean": [ -0.25754162314671525, -0.12370228389510128, 0.1620053749182691, ], "min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938], "q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896], "q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236], "std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], }, "gripper": { "max": [1.0370277166366577], "min": [0.04637829214334488], "q01": [0.05192930996417999], "q99": [1.0118417739868164], }, "joints": { "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, "translation": { "max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], "mean": [ 0.309032678604126, 0.03403777256608009, 0.061277542263269424, ], "min": [ -0.04167502000927925, -0.2889411449432373, -0.13934996724128723, ], "q01": [ 0.1711955964565277, -0.15639324486255646, -0.048255354166030884, ], "q99": [ 0.4604376256465912, 0.24112474918365479, 0.18886254727840424, ], "std": [ 0.0635896623134613, 0.09153717756271362, 0.049334850162267685, ], }, }, "bridge_orig": { "euler": { "max": [3.141592653589793, 1.570796251296997, 3.141204357147217], "mean": [ -0.25754162314671525, -0.12370228389510128, 0.1620053749182691, ], "min": [-3.141592653492551, -1.4832241535186768, -3.14153790473938], "q01": [-3.138795563420751, -0.56544608771801, -1.4952478170394896], "q99": [3.138720980629329, 0.2677614077925682, 2.0032371997833236], "std": [3.0257414011616577, 0.1622662085147332, 0.6404942954645315], }, "gripper": { "max": [1.0370277166366577], "min": [0.04637829214334488], "q01": [0.05192930996417999], "q99": [1.0118417739868164], }, "joints": { "max": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "mean": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "min": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "q01": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "q99": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], "std": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], }, "translation": { "max": [0.5862360596656799, 0.4034728705883026, 0.3568263053894043], "mean": [ 0.309032678604126, 0.03403777256608009, 0.061277542263269424, ], "min": [ -0.04167502000927925, -0.2889411449432373, -0.13934996724128723, ], "q01": [ 0.1711955964565277, -0.15639324486255646, -0.048255354166030884, ], "q99": [ 0.4604376256465912, 0.24112474918365479, 0.18886254727840424, ], "std": [ 0.0635896623134613, 0.09153717756271362, 0.049334850162267685, ], }, }, "droid": { "euler": { "max": [3.141592502593994, 1.5705928802490234, 3.1415867805480957], "mean": [ 0.3140628098409554, -0.09296274023036387, -0.07227215454779846, ], "min": [ -3.141592502593994, -1.5691150426864624, -3.1415374279022217, ], "q01": [ -3.1378602981567383, -1.2125312042236327, -2.1614069032669065, ], "q99": [3.137854380607605, 0.9200375998020163, 1.9367506909370364], "std": [2.926265757944871, 0.363273475703332, 0.7576065217938824], }, "gripper": { "max": [1.0], "min": [0.0], "q01": [0.0], "q99": [0.9911894202232361], }, "joints": { "max": [ 2.668445110321045, 1.5691218376159668, 2.666306734085083, -0.3114914000034332, 2.6624162197113037, 4.28157901763916, 2.752457857131958, ], "mean": [ 0.023137084334640106, 0.2704989977282293, -0.01451389357228282, -2.018709403792315, -0.042720520800030394, 2.350281188152209, 0.12424663946659845, ], "min": [ -2.6536705493927, -1.547789216041565, -2.6781487464904785, -2.9409868717193604, -2.6705946922302246, 0.24893812835216522, -2.7615714073181152, ], "q01": [ -0.9026106441020965, -0.8547340619564057, -0.9028875434398651, -2.7698556280136106, -1.6851656341552732, 1.2335169839859008, -1.9587260699272155, ], "q99": [ 0.9569852340221403, 1.4148830294609054, 0.7693877756595566, -0.4545914208889008, 1.5623322343826267, 3.475611729621887, 2.263479118347167, ], "std": [ 0.31695080251469465, 0.49522214687158767, 0.27993538230553827, 0.478161574676113, 0.4969961591445458, 0.45101008525403846, 0.7287264344068457, ], }, "translation": { "max": [0.8575563430786133, 0.799155592918396, 1.0043904781341553], "mean": [ 0.5283099395864883, 0.005363794653877434, 0.3120132207021294, ], "min": [ -0.15604186058044434, -0.827903687953949, -0.2347021996974945, ], "q01": [ 0.26669957995414734, -0.43774398624897004, -0.048167889714241026, ], "q99": [0.7774086785316463, 0.428325751423835, 0.776091011762619], "std": [ 0.1148424841779685, 0.17489566608140428, 0.16541062032731538, ], }, }, "roboset": { "euler": { "max": [3.1415449294818236, 1.5705575529715636, 3.141527342124582], "mean": [ -0.0398455755412464, 1.0518070390619125, -0.015345692503002759, ], "min": [ -3.1415813300509536, -1.5222832468962035, -3.141575300866071, ], "q01": [ -2.9414386317311187, -0.24976770655101155, -2.985256521212579, ], "q99": [2.9380437893235993, 1.5403010739503078, 2.9746912523985025], "std": [1.7866587696177456, 0.40620530263065, 1.7288511340250616], }, "gripper": { "max": [0.83056640625], "min": [0.0001499652862548828], "q01": [0.0001499652862548828], "q99": [0.82666015625], }, "joints": { "max": [ 0.96240234375, 1.1162109375, 1.1064453125, -0.98095703125, 2.30859375, 1.576171875, 1.7412109375, ], "mean": [ 0.005913593806326389, 0.1877261847257614, 0.04653879255056381, -2.0529513359069824, -0.011298442259430885, 0.6185526251792908, -0.01701134257018566, ], "min": [ -0.8330078125, -0.74658203125, -0.8642578125, -2.892578125, -1.390625, -0.24658203125, -2.953125, ], "q01": [ -0.41015625, -0.5302734375, -0.6455078125, -2.57421875, -0.76416015625, -0.0386962890625, -1.435546875, ], "q99": [ 0.66455078125, 0.9501953125, 0.7529296875, -1.251953125, 0.75244140625, 1.2314453125, 1.384765625, ], "std": [ 0.17915399372577667, 0.32234326004981995, 0.26069700717926025, 0.31767210364341736, 0.205329030752182, 0.33385637402534485, 0.6263682842254639, ], }, "translation": { "max": [0.5747738480567932, 0.3972920775413513, 0.7443570494651794], "mean": [ 0.3331542909145355, 0.019357483834028244, 0.37330344319343567, ], "min": [ 0.09978063404560089, -0.29593944549560547, 0.10065606236457825, ], "q01": [ 0.18437016010284424, -0.25699371099472046, 0.15134164690971375, ], "q99": [0.543661892414093, 0.29646238684654236, 0.6682320833206177], "std": [ 0.07849054038524628, 0.12241040915250778, 0.1460595279932022, ], }, }, } @cached_property def _control_stats(self) -> Dict[str, Dict[str, Dict[str, List[float]]]]: if is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.translation_norm): return {} with open(self.config.control_stats_path, "r") as file: stats = yaml.safe_load(file) if self.config.delta_controls: if self.control_io_config.future_controls_sequence_stride_sec is None: horizon = 0.0 else: horizon = self.control_io_config.future_controls_sequence_stride_sec elif self.control_io_config.future_controls_sequence_stride_sec is None: if self.control_io_config.future_controls_sequence_length == 1: horizon = 0.0 else: raise NotImplementedError() else: horizon = ( self.control_io_config.future_controls_sequence_length * self.control_io_config.future_controls_sequence_stride_sec ) key = f"horizon_{round(horizon, 2)}s" if key in stats: stats = stats[key] else: raise ValueError( f"Missing control statistics key {key} for future_controls_sequence_length={self.config.control_io_config.future_controls_sequence_length} future_controls_sequence_stride_sec={self.config.control_io_config.future_controls_sequence_stride_sec}. Available keys: [{stats.keys()}]" ) return stats @cached_property def dataset_names(self) -> List[str]: if ( is_global_norm(self.config.rotation_norm) and is_global_norm(self.config.obs_rotation_norm) and is_global_norm(self.config.translation_norm) and is_global_norm(self.config.obs_translation_norm) ): return ["ANY"] return list(set(self._control_stats.keys()) | set(self._observation_stats.keys())) def delta_to_relative_translations(translation_sequence: torch.Tensor) -> torch.Tensor: """ Transform a sequence of translation vectors encoded w.r.t. PREVIOUS frame in the sequence to encoding w.r.t. the 0-th element preceding the sequence Ex: Sequence of points: T1, T2, T3, T4 `translation_sequence` contains the vectors: T0T1, T1T2, T2T3, T3T4, where T0 is the base frame, implicitly encoded in T0T1 Output: T0T1, T0T2, T0T3, T0T4 Args: translation_sequence: torch.Tensor of shape [..., S, 3], containing the translation vectors, where S corresponds to the sequence dimension Returns: torch.Tensor of the same shape as translation_sequence, containing delta translations """ assert translation_sequence.ndim >= 3, translation_sequence.shape delta_translations = torch.cumsum(translation_sequence, dim=-2) return delta_translations class RegressionProcessor(VLAMProcessor): def policy_control_plan_from_model_target( self, target: RoboticsTarget, dataset_name: np.ndarray ) -> RoboticsControlPlan: translation_m = self.unnormalize(target.translation, dataset_name=dataset_name, key="translation") rotation = self.unnormalize(target.rotation, dataset_name=dataset_name, key="rotation") rotmat = convert_rotation(rotation, RotationFormat.ROTMAT) gripper_prob = target.gripper if self.config.delta_controls: translation_m = delta_to_relative_translations(translation_m) rotmat = delta_to_relative_rotations(rotmat) return RoboticsControlPlan( translation_m=translation_m, rotmat=rotmat, gripper_prob=gripper_prob, valid_mask=target.valid_mask, ) def policy_control_plan_from_model_output( self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor, ) -> RoboticsControlPlan: """Called during inference to create control plan from model output""" translation_m = self.unnormalize( model_output.translation, dataset_name=dataset_name, key="translation" ) rotation = self.unnormalize(model_output.rotation, dataset_name=dataset_name, key="rotation") rotmat = convert_rotation(rotation, RotationFormat.ROTMAT, autonorm=True) gripper_prob = torch.sigmoid(model_output.gripper) if self.config.delta_controls: translation_m = delta_to_relative_translations(translation_m) rotmat = delta_to_relative_rotations(rotmat) return RoboticsControlPlan( translation_m=translation_m, rotmat=rotmat, gripper_prob=gripper_prob, valid_mask=valid_mask, ) class PiZeroFlowMatchingProcessor(RegressionProcessor): def __init__(self, **kwargs): super().__init__(**kwargs) self.generator: torch.Generator = torch.Generator() @cached_property def beta_distribution(self) -> torch.distributions.Beta: return torch.distributions.Beta( self.config.distribution_hyperparams.get("alpha", 1.5), self.config.distribution_hyperparams.get("beta", 1.0), ) def create_input(self, *args, **kwargs) -> RoboticsFlowInput: """In practice used only during inference""" inputs = super().create_input(*args, **kwargs) flow_input: FlowInput = self.sample_t0_input(batch_size=1, device=torch.device("cpu")) inputs = RoboticsFlowInput(**inputs.as_json(), flow_input=flow_input[0, ...]) return inputs def sample_timestep(self, batch_size: int) -> torch.Tensor: if self.config.timestep_distribution.lower() == "uniform": eps = 1e-05 sample = (torch.rand(1, generator=self.generator) + torch.arange(batch_size) / batch_size) % ( 1 - eps ) elif self.config.timestep_distribution.lower() == "beta": sample = self.beta_distribution.sample([batch_size, 1, 1]) sample = (1 - self.config.sig_min) * (1 - sample) else: raise NotImplementedError(self.config.timestep_distribution) sample = sample.view(batch_size, 1, 1) return sample def _psi_t(self, timestep: torch.Tensor, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: return (1 - (1 - self.config.sig_min) * timestep) * x_0 + timestep * x_1 def _dpsi_dt(self, x_0: torch.Tensor, x_1: torch.Tensor) -> torch.Tensor: return x_1 - (1 - self.config.sig_min) * x_0 def sample_t0_input(self, batch_size: int, device: torch.device) -> FlowInput: if self.config.r0_distribution == "normal": controls_t0 = torch.randn( [ batch_size, self.config.control_io_config.future_controls_sequence_length, 3 + self.rotation_components + 1, ], generator=self.generator, ).to(device=device) (translation_t0, rotation_t0, gripper_t0) = torch.split( controls_t0, [3, self.rotation_components, 1], dim=-1 ) rotation_t0 = normalize_rotation(rotation_t0) elif self.config.r0_distribution == "uniform": controls_t0 = torch.randn( [ batch_size, self.config.control_io_config.future_controls_sequence_length, 4, ], generator=self.generator, ).to(device=device) (translation_t0, gripper_t0) = torch.split(controls_t0, [3, 1], dim=-1) rotation_t0 = convert_rotation( roma.random_unitquat( ( batch_size, self.config.control_io_config.future_controls_sequence_length, ), device=device, ), self.config.rotation_format, ) else: raise NotImplementedError(self.config.r0_distribution) if self.config.rotation_format == RotationFormat.QUATERNION: rotation_t0 = quaternion_half_cover(rotation_t0) timestep = torch.zeros([batch_size, 1, 1], device=device) return FlowInput( timestep=timestep, translation_t0=translation_t0, rotation_t0=rotation_t0, gripper_t0=gripper_t0, translation_t=None, rotation_t=None, gripper_t=None, ) def policy_control_plan_from_model_output( self, model_output: RoboticsOutput, dataset_name: np.ndarray, valid_mask: torch.Tensor, ) -> RoboticsControlPlan: if self.config.translation_norm == Normalization.NONE or is_mean_norm(self.config.translation_norm): model_output = model_output.replace(translation=torch.clamp(model_output.translation, -1, 1)) if self.config.rotation_norm == Normalization.NONE or is_mean_norm(self.config.rotation_norm): model_output = model_output.replace(rotation=torch.clamp(model_output.rotation, -1, 1)) control_plan = super().policy_control_plan_from_model_output(model_output, dataset_name, valid_mask) control_plan = control_plan.replace(gripper_prob=torch.clamp(model_output.gripper, 0, 1)) return control_plan def make_causal_mask(shape: Sequence[int]) -> torch.Tensor: """ Create a causal attention mask of shape `shape` Args: shape: Shape of the output mask, the last two dimensions correspond to [query_seq_len, kv_seq_len] Returns: torch.Tensor of dtype torch.bool. False values indicate that the row (i.e. query) can't attend to the corresponding column (i.e. key) Example: shape = (3, 5) -> Mask the upper triangular part [ [ 1, 0, 0, 0, 0], [ 1, 1, 0, 0, 0], [ 1, 1, 1, 0, 0] ] """ return torch.tril(torch.ones(shape, dtype=torch.bool), diagonal=0) def enable_full_attn_blocks(attn_mask: torch.Tensor, full_attn: torch.Tensor) -> torch.Tensor: """ Enable full bi-directional attention in `attn_mask` inside specific blocks Args: attn_mask: Existing attention mask of shape [..., query_seq_len, kv_seq_len] and dtype torch.bool where False values indicate disabled attention full_attn: torch.Tensor of shape [query_seq_len], dtype torch.bool. Blocks of True values indicate positions where full bi-directional attention should be enabled Example: 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, -> 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, """ assert full_attn.dtype == torch.bool, full_attn.dtype assert full_attn.ndim == 1, full_attn.shape assert full_attn.shape[0] == attn_mask.shape[-2], f"{full_attn.shape[0]}, {attn_mask.shape}" if attn_mask.shape[-1] != attn_mask.shape[-2]: raise NotImplementedError("Only self-attention supported right now.") x = full_attn.view(-1, 1) & full_attn.view(1, -1) x = x | make_causal_mask([full_attn.shape[0], full_attn.shape[0]]) x = torch.cumprod(x, dim=1).to(dtype=torch.bool) x = x & x.permute(1, 0) mask_positions = torch.sum(x, dim=0) == 1 & ~full_attn mask_indices = torch.where(mask_positions)[0] x[mask_indices, mask_indices] = 0 attn_mask = attn_mask | expand_dims(x, ndim=attn_mask.ndim, order=[-1, 1, 1]) return attn_mask IGNORE_INDEX = -100 class PaliGemmaProcessor(VLMProcessor): def __init__( self, config: PaliGemmaProcessorConfig, hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor, **kwargs, ): del kwargs super().__init__(config) self.hf_processor = hf_processor self.hf_processor.image_processor.size = dict(self.config.image_sizes["main"].as_json()) self.hf_processor.image_seq_length = self.config.num_image_tokens["main"] self.hf_processor.image_processor.image_seq_length = self.config.num_image_tokens["main"] self.bos_id: int = self.tokenizer.bos_token_id self.eos_id: int = self.tokenizer.eos_token_id self.sep_token = "\n" self.sep_id: int = self.tokenizer( self.sep_token, padding=False, add_special_tokens=False, return_attention_mask=False, )["input_ids"][0] self.image_token_id: int = self.tokenizer( self.config.image_token, padding=False, add_special_tokens=False, return_attention_mask=False, )["input_ids"][0] self.image_tokens: list[int] = [self.image_token_id] * sum(self.config.num_image_tokens.values()) self.bbox_pattern = re.compile( "\\[(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+),\\s*(\\d+\\.\\d+)\\]" ) def preprocess_inputs( self, chat: List[str], images: Dict[str, List[PIL.Image.Image]] ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: """ Based on PaliGemma paper https://arxiv.org/pdf/2407.07726 and example code at https://ai.google.dev/gemma/docs/paligemma/fine-tuning-paligemma#create_model_inputs Chat must be always made of separate messages from user and model, always starting with user ... ... Args: chat: List[str] of even size where each entry corresponds to a different turn in the conversation images: Dict[str, List[PIL.Image.Image]] where different cameras correspond to different keys in the Dict and the List corresponds to history of images """ for key, value in images.items(): if not isinstance(value, list): raise TypeError(f"Camera {key} contains values of type {type(value)} instead of list") (input_ids, target_ids) = ([], []) for i, text in enumerate(chat): text = text.replace(self.sep_token, " ").replace("", "") text = self.bbox_pattern.sub(self._bbox_to_loc_tokens, text) turn_input_ids: List[int] = self.tokenizer( text, padding=False, add_special_tokens=False, return_attention_mask=False, )["input_ids"] if i % 2 == 0: turn_target_ids = [IGNORE_INDEX] * len(turn_input_ids) else: turn_target_ids = turn_input_ids if i != len(chat) - 1: turn_input_ids = turn_input_ids + [self.sep_id] turn_target_ids = turn_target_ids + [IGNORE_INDEX] input_ids = input_ids + turn_input_ids target_ids = target_ids + turn_target_ids input_ids = [self.bos_id] + input_ids + [self.eos_id] target_ids = [IGNORE_INDEX] + target_ids + [self.eos_id] image_tokens = self.image_tokens if self.config.max_language_tokens > 0: input_ids = input_ids[: self.config.max_language_tokens] target_ids = target_ids[: self.config.max_language_tokens] input_ids = image_tokens + input_ids target_ids = [IGNORE_INDEX] * len(image_tokens) + target_ids input_ids = torch.tensor(input_ids, dtype=torch.int64) target_ids = torch.tensor(target_ids, dtype=torch.int64) image_tensors: Dict[str, torch.Tensor] = { f"{camera_name}.siglip": self.hf_processor.image_processor( camera_images, size=self.config.image_sizes[camera_name].as_json(), return_tensors="pt", )["pixel_values"] for (camera_name, camera_images) in images.items() } attn_mask = make_causal_mask([len(input_ids), len(input_ids)]) attn_mask = enable_full_attn_blocks(attn_mask, full_attn=target_ids == IGNORE_INDEX) return { "input_ids": input_ids, "target_ids": target_ids, "images": image_tensors, "attn_mask": attn_mask, } @property def tokenizer(self) -> transformers.PreTrainedTokenizerBase: return self.hf_processor.tokenizer @staticmethod def _bbox_to_loc_tokens(match: str) -> str: """ https://developers.googleblog.com/en/gemma-explained-paligemma-architecture/ """ floats = list(map(float, match.groups())) transformed = [f"" for num in floats] return f"[{', '.join(transformed)}]" @property def image_sizes(self) -> Dict[str, ImageSizeConfig]: return self.config.image_sizes class PaliGemmaDepthProcessor(PaliGemmaProcessor): def __init__( self, config: PaliGemmaProcessorConfig, hf_processor: transformers.models.paligemma.processing_paligemma.PaliGemmaProcessor, depth_tokens: int, ): super().__init__(config, hf_processor) vocab_size = len(self.tokenizer) self.depth_token_ids = np.arange(vocab_size - depth_tokens, vocab_size) self.depth_input_transforms = { camera_name: torchvision.transforms.v2.Compose( [ torchvision.transforms.v2.Resize( size=(camera_image_size.height, camera_image_size.width), interpolation=torchvision.transforms.v2.InterpolationMode.BICUBIC, max_size=None, antialias=True, ), torchvision.transforms.v2.ToTensor(), torchvision.transforms.v2.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ] ) for (camera_name, camera_image_size) in self.config.image_sizes.items() } def preprocess_inputs( self, chat: List[str], images: Dict[str, List[PIL.Image.Image]] ) -> Dict[str, torch.Tensor | Dict[str, torch.Tensor]]: inputs = super().preprocess_inputs(chat=chat, images=images) depth_images: Dict[str, torch.Tensor] = { f"{camera_name}.depth": torch.stack( self.depth_input_transforms[camera_name](camera_images), dim=0 ) for (camera_name, camera_images) in images.items() } inputs["images"] = {**inputs["images"], **depth_images} return inputs @property def num_depth_tokens(self) -> int: return len(self.depth_token_ids)