import collections.abc import dataclasses import enum import inspect import types from collections.abc import Mapping as MappingABC from functools import cached_property from typing import ( Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Type, Union, ) import torch import transformers class StrEnum(str, enum.Enum): """ A minimal drop-in replacement for backports.strenum.StrEnum """ def __str__(self): return str(self.value) def __new__(cls, value): # Create new instance that properly handles string initialization if isinstance(value, str): obj = str.__new__(cls, value) obj._value_ = value return obj return super().__new__(cls, value) @classmethod def _missing_(cls, value): # Enhanced lookup by string value with better error handling if isinstance(value, str): for member in cls: if member.value == value: return member # Return None to let enum handle the KeyError return None def __eq__(self, other): # Allow comparison with string values if isinstance(other, str): return self.value == other return super().__eq__(other) def __hash__(self): # Ensure consistent hashing return hash(self.value) class _cached_classproperty: def __init__(self, func): self.func = func self._values = {} def __get__(self, obj, klass): if klass not in self._values.keys(): self._values[klass] = self.func.__get__(obj, klass)() return self._values[klass] def cached_classproperty(func): if not isinstance(func, (classmethod, staticmethod)): func = classmethod(func) return _cached_classproperty(func) @dataclasses.dataclass class Dataclass: def __post_init__(self): pass @classmethod def make_empty(cls) -> "Dataclass": return cls( **{ k: (v.make_empty() if inspect.isclass(v) and issubclass(v, Dataclass) else None) for (k, v) in cls.types.items() } ) @cached_classproperty def fields(cls) -> Tuple[dataclasses.Field, ...]: """Returns a sorted list of the Field objects""" return tuple(sorted(dataclasses.fields(cls), key=lambda x: x.name)) @cached_classproperty def types(cls) -> Dict[str, type]: return {f.name: f.type for f in cls.fields} def as_json(self, recursive: bool = True) -> dict: return {k: v.as_json() if isinstance(v, Dataclass) and recursive else v for (k, v) in self.items()} @classmethod def keys(cls) -> List[str]: return [field.name for field in cls.fields] def values(self): return [getattr(self, field.name) for field in self.fields] def items(self, recursive: bool = False): for key, value in zip(self.keys(), self.values(), strict=True): if recursive and isinstance(value, Dataclass): for subkey, subvalue in value.items(recursive=True): yield (f"{key}.{subkey}", subvalue) else: yield (key, value) def replace(self, **kwargs): """ Return a new instance of Dataclass with the kwargs overwritten. """ kwargs = maybe_chained_keys_to_nested_dict(kwargs) data = self.as_json(recursive=False) for key, value in kwargs.items(): value_type = self.types.get(key, None) if value_type is None: raise KeyError(f"Dataclass {self.__class__} does not have a field {key}") value_type = get_maybe_optional_type(value_type) if inspect.isclass(value_type) and issubclass(value_type, Dataclass): if isinstance(value, dict): data[key] = data[key].replace(**value) else: data[key] = value else: data[key] = value return self.__class__(**data) def apply(self, fcn: Callable, recursive: bool = True, skip_nones: bool = False) -> "Dataclass": def fcn_wrapper(value: Any) -> Any: if value is None and skip_nones: return None if isinstance(value, dict) and recursive: return type(value)(**{k: fcn(v) for (k, v) in value.items()}) if isinstance(value, (list, tuple)) and recursive: return type(value)([fcn(v) for v in value]) if isinstance(value, Dataclass) and recursive: return value.apply(fcn, recursive=True, skip_nones=skip_nones) return fcn(value) return self.__class__(**{key: fcn_wrapper(value) for (key, value) in self.items()}) def __getitem__(self, index) -> "Dataclass": def extract(obj): if obj is None: return None if isinstance(obj, torch.Tensor): return obj[index] raise ValueError(f"Cannot slice {obj.__class__.__name__} object") return self.apply(extract) class Config: def __init__(self, **kwargs): self._apply_defaults() self._set_attributes(**kwargs) super().__init__() self.__post_init__() def _apply_defaults(self): """ Initializes all annotated fields with defaults or sensible instances. """ annotations = getattr(self, "__annotations__", {}) for key, type_hint in annotations.items(): # Skip if already set via class-level value or __init__ kwarg if hasattr(self, key): continue # Case 1: class variable has a default (declared at class level) if key in self.__class__.__dict__: setattr(self, key, getattr(self.__class__, key)) continue # Case 2: if the type is another Config subclass, default-construct it if inspect.isclass(type_hint) and issubclass(type_hint, Config): setattr(self, key, type_hint()) continue # Case 3: fallback None (or empty dict for mappings) if hasattr(type_hint, "__origin__") and type_hint.__origin__ in ( dict, Dict, MappingABC, ): setattr(self, key, {}) else: setattr(self, key, None) def _set_attributes(self, **kwargs): subconfig_types = self._subconfig_types for key, value in kwargs.items(): if key in subconfig_types: if not isinstance(value, Mapping): raise ValueError( f"{self.__class__.__name__}.{key} expects dict-like object for nested config, but got: {value}" ) setattr(self, key, subconfig_types[key](**value)) else: setattr(self, key, value) def keys(self) -> List[str]: """Get all annotated keys including those from parent classes.""" all_keys = {} # Walk through MRO in reverse to respect inheritance order for cls in reversed(self.__class__.__mro__): if cls is object: continue all_keys.update(getattr(cls, "__annotations__", {})) return list(all_keys.keys()) def items(self) -> Iterable[Tuple[str, Any]]: for key in self.keys(): yield (key, getattr(self, key)) @cached_classproperty def _subconfig_types(cls) -> dict[str, Type]: keys = { key: value for (key, value) in cls.__annotations__.items() if inspect.isclass(value) and issubclass(value, Config) } for base in cls.__bases__: if not issubclass(base, Config): continue keys = {**keys, **base._subconfig_types} return keys def __post_init__(self): pass def as_json(self) -> dict: data = {} for key, value in self.items(): if isinstance(value, Config): data[key] = value.as_json() elif ( isinstance(value, collections.abc.Sequence) and len(value) > 0 and isinstance(value[0], Config) ): data[key] = [v.as_json() for v in value] elif ( isinstance(value, collections.abc.Mapping) and len(value) > 0 and isinstance(next(iter(value.values())), Config) ): data[key] = {k: v.as_json() for k, v in value.items()} else: data[key] = value return data class HFConfigMixin(transformers.PretrainedConfig): """ Bridge between your Config system and HF PretrainedConfig. Usage: class SPEAR1Config(HFConfigMixin, Config): model_type = "spear1" processor_config: PaliGemmaProcessorConfig ... """ def __init__(self, **kwargs): # Let HF's machinery initialize its own attributes / defaults first. # PretrainedConfig.__init__ will set things like `model_type`, # `_name_or_path`, `architectures`, and keep a `kwargs`->dict of extra items. super().__init__(**kwargs) # Now initialize your Config behavior: set defaults and construct nested configs. # We call Config.__init__ explicitly because HFConfigMixin inherits from PretrainedConfig, # and the user's concrete class will use multiple-inheritance with Config. # (This approach mirrors the earlier MRO design: class Concrete(HFConfigMixin, Config).) # We pass kwargs again so nested configs get overridden by user kwargs. # Note: Config.__init__ itself calls super().__init__() — but because we are calling # Config.__init__ directly (not via super()) the MRO won't re-call PretrainedConfig.__init__ here. # (I.e., we are deliberately calling the concrete base initializer.) Config.__init__(self, **kwargs) # type: ignore[name-defined] def to_dict(self) -> Dict[str, Any]: """ Merge HF PretrainedConfig serialization and Config.as_json(). Strategy: 1. Take HF dict (super().to_dict()) so HF metadata/defaults are present. 2. Take our nested config dict (Config.as_json(self)). 3. Update the HF dict with our nested config dict so annotated fields (nested configs, lists/dicts that should be recursively serialized) take precedence. """ # HF's representation (contains model_type, etc.). This is trusted HF serialization. hf = super().to_dict() # Our nested config representation (recursively serializes Config objects). # Do not call self.to_dict() because that would recurse back here. cfg_json = Config.as_json(self) # type: ignore[name-defined] # Merge: prefer cfg_json values for keys present in our config (so nested configs # are represented as dicts rather than raw objects or omitted). merged: Dict[str, Any] = dict(hf) merged.update(cfg_json) return merged @classmethod def from_dict( cls: Type["HFConfigMixin"], config_dict: Dict[str, Any], **kwargs, ) -> "HFConfigMixin": """ Construct by delegating to the class constructor — that will instantiate nested configs. This is simple and consistent with PretrainedConfig.from_dict/from_pretrained behavior. """ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) instance = cls(**config_dict) if return_unused_kwargs: # Return tuple of (instance, unused_kwargs) if requested # Since we consume everything in __init__, unused is typically empty return instance, {} return instance class Configurable: def __init__(self, config: Config): self._config = config @property def config(self) -> Config: return self._config class RotationFormat(StrEnum): """Determines how rotations will be encoded in the loaded batch""" EULER = "euler" QUATERNION = "quaternion" ROTMAT = "rotmat" class ResizeMode(StrEnum): """ Different modes for resizing images. """ MATCH_WIDTH = "match_width" MATCH_HEIGHT = "match_height" MATCH_MAX = "match_max" NAIVE = "naive" SMART = "smart" PAD = "pad" CROP = "crop" class Normalization(StrEnum): """Action normalization types""" NONE = "none" BOUNDS = "bounds" BOUNDS_Q99 = "bounds_q99" MEAN = "mean" def expand_dims(tensor: torch.Tensor, ndim: int, order: Sequence[int]) -> torch.Tensor: """ Expand the dimensions of `tensor` to `ndim` such that all new dimensions have size of 1 Args: tensor: torch.Tensor of any shape ndim: Number of output dimensions. Must be >= `tensor.ndim` order: Sequence of size `tensor.ndim + 1`. Contains only values of 1 and a single value of -1, indicating where the new `ndim - tensor.ndim` dimensions will be inserted Returns: torch.Tensor with dimensions `ndim`, a view of `tensor` Ex: expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, -1, 1, 1]).shape -> [2, 1, 1, 3, 4] expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[-1, 1, 1, 1]).shape -> [1, 1, 2, 3, 4] expand_dims(torch.ones([2, 3, 4]), ndim=5, order=[1, 1, 1, -1]).shape -> [2, 3, 4, 1, 1] """ assert tensor.ndim <= ndim, f"{tensor.ndim} > {ndim}; shape={tensor.shape}" assert len(order) == tensor.ndim + 1, f"{len(order)} != {tensor.ndim + 1}; shape={tensor.shape}" order = list(order) assert order.count(-1) == 1, "Order must have exactly one value of -1" assert order.count(1) == len(order) - 1, "Order must have exactly len(order) - 1 values of 1" if tensor.ndim == ndim: return tensor insert_index = order.index(-1) view = list(tensor.shape[:insert_index]) + [1] * (ndim - tensor.ndim) + list(tensor.shape[insert_index:]) tensor = tensor.view(view) return tensor def merge_dicts_recursive(dict_1: Dict[str, Any], dict_2: Dict[str, Any]) -> Dict[str, Any]: """ Merges dict_1 with dict_2 recursively. Handles clashing keys: 1. If both values are dicts, merges them recursively 2. If any value is not a dict, raises ValueError """ merged = dict(dict_1) for key, value in dict_2.items(): if key in merged: if not type(merged[key]) is type(value) is dict: raise ValueError(f"Multiple values provided for key {key}: {merged[key]} and {value}") merged[key] = merge_dicts_recursive(merged[key], value) else: merged[key] = value return merged def maybe_chained_keys_to_nested_dict(data: Dict[str, Any]) -> Dict[str, Any]: """Converts a dict with keys of the form "key1.key2.key3" to a nested dict""" unpacked_data: Dict[str, Any] = {} for key, value in data.items(): if "." not in key: unpacked_data = merge_dicts_recursive(unpacked_data, {key: value}) else: (mainkey, subkey) = key.split(".", maxsplit=1) nested_value = maybe_chained_keys_to_nested_dict({subkey: value}) unpacked_data = merge_dicts_recursive(unpacked_data, {mainkey: nested_value}) return unpacked_data def annotation_is_union(type_value: Type) -> bool: return getattr(type_value, "__origin__", None) is Union or type(type_value) is types.UnionType def annotation_is_optional(type_value: Type) -> bool: if annotation_is_union(type_value): union_args = set(type_value.__args__) if len(union_args) == 2 and type(None) in union_args: return True return False def get_maybe_optional_type(type_value: Type[Optional[Any]]) -> Type[Any]: if annotation_is_optional(type_value): type_args = type_value.__args__ if type_args[1] is type(None): return type_args[0] return type_args[1] return type_value @dataclasses.dataclass class RoboticsTarget(Dataclass): control_tokens_ids: Optional[torch.Tensor] text_tokens_ids: Optional[torch.Tensor] translation: torch.Tensor rotation: torch.Tensor gripper: torch.Tensor valid_mask: torch.Tensor @dataclasses.dataclass class RoboticsControlPlan(Dataclass): translation_m: torch.Tensor rotmat: torch.Tensor gripper_prob: torch.Tensor valid_mask: torch.Tensor def __post_init__(self): super().__post_init__() assert self.translation_m.ndim == 3, self.translation_m.shape assert self.rotmat.ndim == 3, self.rotmat.shape assert self.gripper_prob.ndim == 3, self.gripper_prob.shape @dataclasses.dataclass class RoboticsInput(Dataclass): images: Dict[str, torch.Tensor] input_ids: torch.Tensor attn_mask: torch.Tensor ee_pose_translation: torch.Tensor ee_pose_rotation: torch.Tensor gripper: torch.Tensor joints: torch.Tensor control_tokens_ids: Optional[torch.Tensor] @property def inputs_embeds(self) -> Optional[torch.Tensor]: return None @property def past_key_values(self) -> Optional[List[torch.Tensor]]: return None @cached_property def multimodal_indices(self) -> torch.Tensor: """ Returns a torch.Tensor containing only the indices of the batch examples which are multimodal. Return shape is [B] """ return torch.arange(self.input_ids.shape[0], dtype=torch.int64, device=self.input_ids.device) @cached_property def unimodal_indices(self) -> torch.Tensor: """ Returns a torch.Tensor containing only the indices of the batch examples which are unimodal. Return shape is [B] """ return torch.tensor([], dtype=torch.int64, device=self.input_ids.device) @dataclasses.dataclass class FlowInput(Dataclass): timestep: torch.Tensor translation_t: torch.Tensor rotation_t: torch.Tensor gripper_t: torch.Tensor translation_t0: torch.Tensor rotation_t0: torch.Tensor gripper_t0: torch.Tensor @dataclasses.dataclass class RoboticsFlowInput(RoboticsInput): """Input to the entire Robotics VLM""" flow_input: FlowInput @dataclasses.dataclass class DiffusionInput(Dataclass): timestep: torch.Tensor noised_translation: torch.Tensor noised_rotation: torch.Tensor noised_gripper: torch.Tensor @dataclasses.dataclass class LLMOutput(Dataclass): """Fork of transformers.modeling_outputs.CausalLMOutputWithPast""" input_ids: torch.Tensor logits: Optional[torch.Tensor] output_ids: Optional[torch.Tensor] loss: Optional[torch.Tensor] past_key_values: List[Tuple[torch.Tensor, torch.Tensor]] hidden_states: List[torch.Tensor] text_indices: torch.Tensor image_indices: torch.Tensor @classmethod def from_transformers( cls, input_ids: torch.Tensor, llm_output: transformers.modeling_outputs.CausalLMOutputWithPast, text_indices: Optional[torch.Tensor], image_indices: Optional[torch.Tensor], ) -> "LLMOutput": return LLMOutput( input_ids=input_ids, logits=llm_output.logits, output_ids=None, loss=llm_output.loss, past_key_values=( list(llm_output.past_key_values) if llm_output.past_key_values is not None else [] ), hidden_states=(list(llm_output.hidden_states) if llm_output.hidden_states is not None else []), text_indices=text_indices, image_indices=image_indices, ) def compress(self) -> "LLMOutput": """ Compress the data contained in the class so it can be moved between CPU and GPU or concatenated much faster: - hidden_states - huge tensors; take a lot of CPU time to move across devices or concat - past_key_values - huge tensors; take a lot of CPU time to move across devices or concat - logits - huge last dimension; takes a lot of CPU time to move across devices or concat """ replace: Dict[str, Any] = { "hidden_states": [], "past_key_values": [], "loss": None, "input_ids": None, } if self.logits is not None: replace["logits"] = None if self.output_ids is None or self.output_ids.shape[1] != self.text_indices.shape[0]: replace["output_ids"] = ( torch.index_select(self.logits, dim=1, index=self.text_indices) .argmax(dim=-1) .to(dtype=torch.int64) ) return self.replace(**replace) @dataclasses.dataclass class RoboticsOutput(Dataclass): translation: Optional[torch.Tensor] rotation: Optional[torch.Tensor] gripper: Optional[torch.Tensor] token_logits: Optional[torch.Tensor] token_ids: Optional[torch.Tensor] llm_output: LLMOutput def compress(self) -> "RoboticsOutput": """ Compress output and drop unnecessary components to speed up transfer GPU <-> CPU. Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which can reach millions or billions of values for large vocab_size """ replace: Dict[str, Any] = { "llm_output": self.llm_output.compress(), "token_logits": None, } if self.token_logits is not None and self.token_ids is None: replace["token_ids"] = torch.argmax(self.token_logits, dim=-1) return self.replace(**replace) @dataclasses.dataclass class VLMOutput(Dataclass): llm_output: LLMOutput vit_tokens: Optional[torch.Tensor] attn_mask: torch.Tensor def compress(self) -> "VLMOutput": """ Compress output and drop unnecessary components to speed up transfer GPU <-> CPU. Note that LLM logits can be extremely expensive since their size is [B, S, vocab_size], which can reach millions or billions of values for large vocab_size """ return self.replace(llm_output=self.llm_output.compress()) def is_quaternion(quaternion: torch.Tensor) -> bool: return quaternion.shape[-1] == 4 def quaternion_half_cover(quaternion: torch.Tensor) -> torch.Tensor: """ Flip quaternions so they cover only a half the space. If the q_w is negative, flip the quaternion. If q_w is 0, then choose such that the first non-zero component is positive. Note that geometrically, this doesn't correspond to a single hemisphere of the unit sphere. Follows https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_quat.html#scipy.spatial.transform.Rotation.as_quat """ assert is_quaternion(quaternion), quaternion.shape with torch.no_grad(): is_zero = quaternion == 0 flip_condition = ( (quaternion[..., -1:] < 0) | is_zero[..., -1:] & (quaternion[..., 0:1] < 0) | is_zero[..., -1:] & is_zero[..., 0:1] & (quaternion[..., 1:2] < 0) | is_zero[..., -1:] & is_zero[..., 0:1] & is_zero[..., 1:2] & (quaternion[..., 2:3] < 0) ) quaternion = torch.where(flip_condition, -quaternion, quaternion) return quaternion def is_rotmat_3x3(rotmat: torch.Tensor) -> bool: return rotmat.shape[-2:] == torch.Size([3, 3]) def is_rotmat_9(rotmat: torch.Tensor) -> bool: return rotmat.shape[-1] == 9 def rotmat_as_9(rotmat: torch.Tensor) -> torch.Tensor: """Convert any rotmat input to [..., 9] shape""" if is_rotmat_9(rotmat): return rotmat if is_rotmat_3x3(rotmat): return rotmat.reshape(*rotmat.shape[:-2], 9) raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix") def is_rotmat(rotmat: torch.Tensor) -> bool: """ Checks if the tensor shape matches that of a rotmat. However, it's not guaranteed the data is a valid rotmat. `is_orthonormal_rotmat` performs this additional check. NOTE: This might incorrectly return True if the underlying data is euler angles and accidentally `rotmat.shape[-2:] == [3, 3]`. This would happen very rarely, but use with caution """ return is_rotmat_3x3(rotmat) or is_rotmat_9(rotmat) def rotmat_as_3x3(rotmat: torch.Tensor) -> torch.Tensor: """Convert any rotmat input to [..., 3, 3] shape""" if rotmat.shape[-1] == 9: return rotmat.reshape(*rotmat.shape[:-1], 3, 3) if rotmat.shape[-2:] == torch.Size([3, 3]): return rotmat raise ValueError(f"Can't convert tensor of shape {rotmat.shape} to a 3x3 rotation matrix")