Feature Extraction
Transformers
Safetensors
English
closp
remote-sensing
text-to-image-retrieval
multimodal
geospatial
SAR
multispectral
crisis-management
earth-observation
contrastive-learning
custom_code
Instructions to use DarthReca/GeoCLOSP-RN with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use DarthReca/GeoCLOSP-RN with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="DarthReca/GeoCLOSP-RN", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("DarthReca/GeoCLOSP-RN", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| from dataclasses import dataclass | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from timm import create_model | |
| from transformers import ( | |
| AutoConfig, | |
| AutoModel, | |
| AutoTokenizer, | |
| PretrainedConfig, | |
| PreTrainedModel, | |
| ) | |
| from transformers.utils import ModelOutput | |
| from .location_encoder import LocationEncoder | |
class CLOSPConfig(PretrainedConfig):
    """
    Hyper-parameters used to build a :class:`CLOSPModel`.

    Args:
        vision_model_key: timm model name used for both the Sentinel-1 and the
            Sentinel-2 image backbones.
        s1_embedding_dim: Input width of the Sentinel-1 linear projection; must
            match the output width of the Sentinel-1 backbone.
        s2_embedding_dim: Input width of the Sentinel-2 linear projection; must
            match the output width of the Sentinel-2 backbone.
        s1_head_dim: Forwarded to timm ``create_model`` as ``num_classes`` for
            the Sentinel-1 backbone.
        s2_head_dim: Forwarded to timm ``create_model`` as ``num_classes`` for
            the Sentinel-2 backbone.
        text_model_name_or_path: Hugging Face id or local path from which the
            text-encoder config and tokenizer are loaded.
        use_location_encoder: Whether the model builds the coordinate encoder
            and its projection.
        location_embedding_dim: Input width of the location projection.
        projection_dim: Width of the shared embedding space all projections
            map into.
        **kwargs: Passed through unchanged to :class:`PretrainedConfig`.
    """

    model_type = "closp"

    def __init__(
        self,
        vision_model_key: str = "vit-s",
        s1_embedding_dim: int = 384,
        s2_embedding_dim: int = 384,
        s1_head_dim: int = 0,
        s2_head_dim: int = 0,
        text_model_name_or_path: str = "distilbert-base-uncased",
        use_location_encoder: bool = True,
        location_embedding_dim: int = 512,
        projection_dim: int = 768,
        **kwargs,
    ):
        # Let the base class consume the generic Hugging Face options first.
        super().__init__(**kwargs)

        # Image branch: one shared backbone name, per-sensor widths/heads.
        self.vision_model_key = vision_model_key
        self.s1_embedding_dim = s1_embedding_dim
        self.s1_head_dim = s1_head_dim
        self.s2_embedding_dim = s2_embedding_dim
        self.s2_head_dim = s2_head_dim

        # Text branch.
        self.text_model_name_or_path = text_model_name_or_path

        # Optional location branch.
        self.use_location_encoder = use_location_encoder
        self.location_embedding_dim = location_embedding_dim

        # Shared embedding space.
        self.projection_dim = projection_dim
# --- Structured Model Output ---
@dataclass
class CLOSPOutput(ModelOutput):
    """
    Outputs of :meth:`CLOSPModel.forward`.

    ``ModelOutput`` subclasses must be dataclasses: ``transformers`` builds the
    dict/tuple views from the dataclass fields and rejects non-dataclass
    subclasses at construction time.

    Attributes:
        loss: Mean cross-entropy over all computed logit matrices; only set
            when ``return_loss=True``.
        logits_per_image: Image-to-text cosine similarities.
        logits_per_text: Text-to-image cosine similarities (transpose of
            ``logits_per_image``).
        logits_per_loc_img: Location-to-image cosine similarities; ``None``
            when the location encoder is disabled.
        logits_per_img_loc: Image-to-location cosine similarities; ``None``
            when the location encoder is disabled.
        image_embeds: L2-normalised image embeddings.
        text_embeds: L2-normalised text embeddings.
        location_embeds: L2-normalised location embeddings, or ``None``.
    """

    loss: torch.FloatTensor = None
    logits_per_image: torch.FloatTensor = None
    logits_per_text: torch.FloatTensor = None
    logits_per_loc_img: torch.FloatTensor = None
    logits_per_img_loc: torch.FloatTensor = None
    image_embeds: torch.FloatTensor = None
    text_embeds: torch.FloatTensor = None
    location_embeds: torch.FloatTensor = None
class CLOSPModel(PreTrainedModel):
    """
    Contrastive model aligning satellite images, text and (optionally) locations.

    Images go through one of two timm backbones (chosen by channel count) plus a
    linear projection, text through a Hugging Face encoder (first-token hidden
    state), and coordinates through an optional ``LocationEncoder`` plus a
    linear projection. Every ``get_*_features`` method returns L2-normalised
    embeddings, so their dot products are cosine similarities.
    """

    config_class = CLOSPConfig

    def __init__(self, config: CLOSPConfig):
        super().__init__(config)

        # --- Vision Encoders ---
        # One backbone per sensor, differing only in input channels
        # (2 for Sentinel-1, 13 for Sentinel-2). pretrained=False: timm does
        # not fetch its own weights here.
        self.s1_encoder = create_model(
            config.vision_model_key,
            in_chans=2,
            num_classes=config.s1_head_dim,
            pretrained=False,
        )
        self.s2_encoder = create_model(
            config.vision_model_key,
            in_chans=13,
            num_classes=config.s2_head_dim,
            pretrained=False,
        )
        # Map each backbone's output width into the shared projection space.
        self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
        self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)

        # --- Text Encoder ---
        # from_config builds the architecture only (no pretrained weights are
        # loaded here); the tokenizer is loaded from the same id/path.
        self.text_model = AutoModel.from_config(
            AutoConfig.from_pretrained(config.text_model_name_or_path)
        )
        self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)

        # --- Location Encoder ---
        if config.use_location_encoder:
            # NOTE(review): constructor args are hard-coded rather than read
            # from config; their meaning is defined in location_encoder.py.
            # Confirm the encoder's output width equals
            # `config.location_embedding_dim` (the projection's input width).
            self.location_encoder = LocationEncoder(512, 2, 256, 10)
            self.location_projection = nn.Linear(
                config.location_embedding_dim, config.projection_dim
            )

    def tokenize_text(self, text: str):
        """
        Tokenize ``text`` with the model's tokenizer.

        Sequences are padded and truncated to the tokenizer's
        ``model_max_length`` and returned as PyTorch tensors (including
        ``input_ids`` and ``attention_mask``), ready for
        :meth:`get_text_features` or :meth:`forward`. The underlying tokenizer
        also accepts a list of strings for batching.
        """
        return self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )

    def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
        """
        Encode a batch of images into L2-normalised embeddings.

        Args:
            image: Image batch with channels on dim 1 (presumably
                (batch, channels, H, W) — TODO confirm). Exactly two channels
                routes to the Sentinel-1 encoder; any other count routes to the
                Sentinel-2 encoder.

        Returns:
            Embeddings whose last dimension is ``config.projection_dim``,
            unit-norm along that dimension.
        """
        # Cast to the model's parameter dtype instead of always float32, so a
        # model loaded in reduced precision (e.g. ``dtype="auto"`` on a
        # half-precision checkpoint) does not hit an input/weight dtype
        # mismatch. For a float32 model this is identical to ``.float()``.
        image = image.to(dtype=self.dtype)
        if image.shape[1] == 2:  # Sentinel-1
            image_features = self.s1_projection(self.s1_encoder(image))
        else:  # Sentinel-2
            image_features = self.s2_projection(self.s2_encoder(image))
        return F.normalize(image_features, p=2, dim=-1)

    def get_text_features(
        self, input_ids: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Encode tokenized text into L2-normalised embeddings.

        The embedding is the last-layer hidden state of the first token
        (index 0) of each sequence; no extra projection is applied.

        Args:
            input_ids: Token ids, e.g. from :meth:`tokenize_text`.
            attention_mask: Matching attention mask.

        Returns:
            Unit-norm embeddings, one per sequence.
        """
        # Only the final layer is used, so intermediate hidden states are not
        # requested (they would only cost memory).
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        text_features = text_outputs.last_hidden_state[:, 0, :]
        return F.normalize(text_features, p=2, dim=-1)

    def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
        """
        Encode coordinates into L2-normalised embeddings.

        Args:
            coords: Coordinates accepted by ``LocationEncoder`` (format defined
                in location_encoder.py — TODO confirm lat/lon order and units).

        Returns:
            Unit-norm embeddings whose last dimension is
            ``config.projection_dim``.

        Raises:
            ValueError: If the model was built with
                ``use_location_encoder=False``.
        """
        if not self.config.use_location_encoder:
            raise ValueError(
                "Location encoder is not enabled for this model. Set `use_location_encoder=True` in config."
            )
        location_features = self.location_encoder(coords)
        location_features = self.location_projection(location_features)
        return F.normalize(location_features, p=2, dim=-1)

    def forward(
        self,
        image: torch.Tensor,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        coords: torch.Tensor = None,
        return_loss: bool = False,
    ) -> CLOSPOutput:
        """
        Compute embeddings and pairwise cosine-similarity logits for a batch.

        Args:
            image: Image batch (see :meth:`get_image_features`).
            input_ids: Token ids for the paired captions.
            attention_mask: Attention mask for ``input_ids``.
            coords: Coordinates for each image; required when the location
                encoder is enabled.
            return_loss: If True, also compute the in-batch contrastive loss,
                assuming the i-th image, caption and location form a positive
                pair.

        Returns:
            A :class:`CLOSPOutput`; location fields are ``None`` when the
            location encoder is disabled, and ``loss`` is ``None`` unless
            ``return_loss`` is True.

        Raises:
            ValueError: If the location encoder is enabled and ``coords`` is
                ``None``.
        """
        image_embeds = self.get_image_features(image)
        text_embeds = self.get_text_features(input_ids, attention_mask)

        # Embeddings are unit-norm, so these products are cosine similarities
        # (no temperature/logit scale is applied).
        logits_per_image = image_embeds @ text_embeds.T
        logits_per_text = logits_per_image.T

        # --- Optional Location Logic ---
        location_embeds = None
        logits_per_loc_img = None
        logits_per_img_loc = None
        if self.config.use_location_encoder:
            if coords is None:
                raise ValueError(
                    "Coordinates must be provided when use_location_encoder is True."
                )
            location_embeds = self.get_location_features(coords)
            logits_per_loc_img = location_embeds @ image_embeds.T
            logits_per_img_loc = image_embeds @ location_embeds.T

        # --- Optional Loss Calculation ---
        loss = None
        if return_loss:
            all_logits = [
                logits_per_image,
                logits_per_text,
                logits_per_loc_img,
                logits_per_img_loc,
            ]
            # In-batch targets: row i's positive is column i. Build them
            # directly on the logits' device instead of moving them afterwards.
            ground_truth = torch.arange(
                len(input_ids), device=logits_per_image.device
            )
            # Average cross-entropy over whichever logit matrices exist
            # (the location terms are skipped when they were not computed).
            losses = [
                F.cross_entropy(logits, ground_truth)
                for logits in all_logits
                if logits is not None
            ]
            loss = sum(losses) / len(losses)

        return CLOSPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            logits_per_loc_img=logits_per_loc_img,
            logits_per_img_loc=logits_per_img_loc,
            image_embeds=image_embeds,
            text_embeds=text_embeds,
            location_embeds=location_embeds,
        )