import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from transformers import PretrainedConfig, PreTrainedModel


class LinearProbe(nn.Module):
    """Linear classification head, optionally L2-normalizing its inputs."""

    def __init__(self, input_dim, num_classes, normalize_inputs=False):
        super().__init__()
        self.linear = nn.Linear(input_dim, num_classes)
        self.normalize_inputs = normalize_inputs

    def forward(self, x: torch.Tensor, **kwargs):
        if self.normalize_inputs:
            x = F.normalize(x, p=2, dim=1)
        return self.linear(x)


class CLIPEncoder(nn.Module):
    """Wraps the vision tower of a Hugging Face CLIP checkpoint."""

    def __init__(self, model_name="openai/clip-vit-large-patch14"):
        super().__init__()
        from transformers import CLIPModel, CLIPProcessor

        try:
            self._preprocess = CLIPProcessor.from_pretrained(model_name)
        except Exception:
            # Fall back to a stock CLIP processor if the checkpoint ships without one.
            self._preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
        clip: CLIPModel = CLIPModel.from_pretrained(model_name)
        # Take the vision model from CLIP; it maps an image to vision_embed_dim.
        self.vision_model = clip.vision_model
        self.model_name = model_name
        self.features_dim = self.vision_model.config.hidden_size
        # Keep visual_projection, which maps vision_embed_dim to projection_dim.
        self.visual_projection = clip.visual_projection

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0]

    def forward(self, preprocessed_images: torch.Tensor) -> torch.Tensor:
        return self.vision_model(preprocessed_images).pooler_output

    def get_features_dim(self) -> int:
        return self.features_dim


class DINOEncoder(nn.Module):
    """Wraps a DINOv2 backbone (optionally with registers); uses the CLS token as features."""

    def __init__(self, model_name="facebook/dinov2-with-registers-base"):
        super().__init__()
        from transformers import AutoImageProcessor, AutoModel, Dinov2Model, Dinov2WithRegistersModel

        self._preprocess = AutoImageProcessor.from_pretrained(model_name)
        self.backbone: Dinov2Model | Dinov2WithRegistersModel = AutoModel.from_pretrained(model_name)
        self.features_dim = self.backbone.config.hidden_size

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0]

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # CLS token of the last hidden state.
        return self.backbone(inputs).last_hidden_state[:, 0]

    def get_features_dim(self) -> int:
        return self.features_dim


class PerceptionEncoder(nn.Module):
    """Wraps a Perception Encoder (PE) vision transformer loaded through timm."""

    def __init__(self, model_name="vit_pe_core_large_patch14_336"):
        super().__init__()
        import timm
        from timm.models.eva import Eva

        self.backbone: Eva = timm.create_model(
            model_name,
            pretrained=True,
            dynamic_img_size=True,
        )
        # Get model-specific transforms (normalization, resize).
        data_config = timm.data.resolve_model_data_config(self.backbone)
        data_config["input_size"] = (3, 224, 224)
        self._preprocess = timm.data.create_transform(**data_config, is_training=False)
        # Remove the classification head so the backbone returns features.
        self.backbone.head = nn.Identity()
        self.features_dim = self.backbone.num_features

    def preprocess(self, image: Image.Image) -> torch.Tensor:
        return self._preprocess(image)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.backbone(inputs)

    def get_features_dim(self) -> int:
        return self.features_dim


class GenDConfig(PretrainedConfig):
    model_type = "GenD"

    def __init__(self, backbone: str = "openai/clip-vit-large-patch14", head: str = "linear", **kwargs):
        super().__init__(**kwargs)
        self.backbone = backbone
        self.head = head


class GenD(PreTrainedModel):
    """Vision backbone plus a two-class linear probe, packaged as a Hugging Face PreTrainedModel."""

    config_class = GenDConfig

    def __init__(self, config):
        super().__init__(config)
        self.head = config.head
        self.backbone = config.backbone
        self.config = config
        self._init_feature_extractor()
        self._init_head()

    def _init_feature_extractor(self):
        # Pick the encoder wrapper based on the backbone name.
        backbone = self.backbone
        backbone_lowercase = backbone.lower()
        if "clip" in backbone_lowercase:
            self.feature_extractor = CLIPEncoder(backbone)
        elif "vit_pe" in backbone_lowercase:
            self.feature_extractor = PerceptionEncoder(backbone)
        elif "dino" in backbone_lowercase:
            self.feature_extractor = DINOEncoder(backbone)
        else:
            raise ValueError(f"Unknown backbone: {backbone}")

    def _init_head(self):
        features_dim = self.feature_extractor.get_features_dim()
        match self.head:
            case "linear":
                self.model = LinearProbe(features_dim, 2)
            case "LinearNorm":
                self.model = LinearProbe(features_dim, 2, normalize_inputs=True)
            case _:
                raise ValueError(f"Unknown head: {self.head}")

    def forward(self, inputs: torch.Tensor):
        features = self.feature_extractor(inputs)
        outputs = self.model(features)
        return outputs
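

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how GenD might be instantiated and run on one image,
# assuming the default CLIP checkpoint is downloadable; "example.jpg" is a
# hypothetical file name.
if __name__ == "__main__":
    config = GenDConfig(backbone="openai/clip-vit-large-patch14", head="linear")
    model = GenD(config).eval()

    image = Image.open("example.jpg").convert("RGB")
    # preprocess() returns a single (C, H, W) tensor; add a batch dimension.
    pixel_values = model.feature_extractor.preprocess(image).unsqueeze(0)

    with torch.no_grad():
        logits = model(pixel_values)  # shape: (1, 2)
        probs = F.softmax(logits, dim=-1)
    print(probs)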