yermandy committed
Commit 4ab1055 · verified · 1 Parent(s): a6872cf

Upload folder using huggingface_hub

Files changed (3)
  1. config.json +10 -0
  2. model.safetensors +3 -0
  3. modeling_gend.py +163 -0
config.json ADDED
@@ -0,0 +1,10 @@
+ {
+   "architectures": [
+     "GenD"
+   ],
+   "backbone": "facebook/dinov3-vitl16-pretrain-lvd1689m",
+   "dtype": "float32",
+   "head": "LinearNorm",
+   "model_type": "GenD",
+   "transformers_version": "4.56.2"
+ }
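The two custom keys here, "backbone" and "head", are the ones consumed by GenDConfig in modeling_gend.py below; the remaining keys are standard transformers metadata. A minimal sketch of the equivalent config object built in code (the import assumes modeling_gend.py sits next to your script; the comments reflect the routing logic in the modeling file):

# Mirrors config.json above; import path is an assumption, not part of the commit.
from modeling_gend import GenDConfig

config = GenDConfig(
    backbone="facebook/dinov3-vitl16-pretrain-lvd1689m",  # "dino" in the name -> DINOEncoder
    head="LinearNorm",                                     # linear probe over L2-normalized features
)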
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:76d45aa5528a9f82520b8c251578248ea4569e32ea2bf5b4e7cd850bbe51f1aa
+ size 1212579400
modeling_gend.py ADDED
@@ -0,0 +1,163 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from PIL import Image
+ from transformers import PretrainedConfig, PreTrainedModel
+
+
+ class LinearProbe(nn.Module):
+     def __init__(self, input_dim, num_classes, normalize_inputs=False):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, num_classes)
+         self.normalize_inputs = normalize_inputs
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         if self.normalize_inputs:
+             x = F.normalize(x, p=2, dim=1)
+
+         return self.linear(x)
+
+
+ class CLIPEncoder(nn.Module):
+     def __init__(self, model_name="openai/clip-vit-large-patch14"):
+         super().__init__()
+
+         from transformers import CLIPModel, CLIPProcessor
+
+         try:
+             self._preprocess = CLIPProcessor.from_pretrained(model_name)
+         except Exception:
+             self._preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+
+         clip: CLIPModel = CLIPModel.from_pretrained(model_name)
+
+         # take vision model from CLIP, maps image to vision_embed_dim
+         self.vision_model = clip.vision_model
+
+         self.model_name = model_name
+
+         self.features_dim = self.vision_model.config.hidden_size
+
+         # take visual_projection, maps vision_embed_dim to projection_dim
+         self.visual_projection = clip.visual_projection
+
+     def preprocess(self, image: Image.Image) -> torch.Tensor:
+         return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0]
+
+     def forward(self, preprocessed_images: torch.Tensor) -> torch.Tensor:
+         return self.vision_model(preprocessed_images).pooler_output
+
+     def get_features_dim(self):
+         return self.features_dim
+
+
+ class DINOEncoder(nn.Module):
+     def __init__(self, model_name="facebook/dinov2-with-registers-base"):
+         super().__init__()
+
+         from transformers import AutoImageProcessor, AutoModel, Dinov2Model, Dinov2WithRegistersModel
+
+         self._preprocess = AutoImageProcessor.from_pretrained(model_name)
+         self.backbone: Dinov2Model | Dinov2WithRegistersModel = AutoModel.from_pretrained(model_name)
+
+         self.features_dim = self.backbone.config.hidden_size
+
+     def preprocess(self, image: Image.Image) -> torch.Tensor:
+         return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0]
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return self.backbone(inputs).last_hidden_state[:, 0]
+
+     def get_features_dim(self) -> int:
+         return self.features_dim
+
+
+ class PerceptionEncoder(nn.Module):
+     def __init__(self, model_name="vit_pe_core_large_patch14_336"):
+         super().__init__()
+
+         import timm
+         from timm.models.eva import Eva
+
+         self.backbone: Eva = timm.create_model(
+             model_name,
+             pretrained=True,
+             dynamic_img_size=True,
+         )
+
+         # Get model specific transforms (normalization, resize)
+         data_config = timm.data.resolve_model_data_config(self.backbone)
+         data_config["input_size"] = (3, 224, 224)
+
+         self._preprocess = timm.data.create_transform(**data_config, is_training=False)
+
+         # Remove head
+         self.backbone.head = nn.Identity()
+
+         self.features_dim = self.backbone.num_features
+
+     def preprocess(self, image: Image.Image) -> torch.Tensor:
+         return self._preprocess(image)
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return self.backbone(inputs)
+
+     def get_features_dim(self) -> int:
+         return self.features_dim
+
+
+ class GenDConfig(PretrainedConfig):
+     model_type = "GenD"
+
+     def __init__(self, backbone: str = "openai/clip-vit-large-patch14", head: str = "linear", **kwargs):
+         super().__init__(**kwargs)
+         self.backbone = backbone
+         self.head = head
+
+
+ class GenD(PreTrainedModel):
+     config_class = GenDConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.head = config.head
+         self.backbone = config.backbone
+         self.config = config
+
+         self._init_feature_extractor()
+         self._init_head()
+
+     def _init_feature_extractor(self):
+         backbone = self.backbone
+         backbone_lowercase = backbone.lower()
+
+         if "clip" in backbone_lowercase:
+             self.feature_extractor = CLIPEncoder(backbone)
+
+         elif "vit_pe" in backbone_lowercase:
+             self.feature_extractor = PerceptionEncoder(backbone)
+
+         elif "dino" in backbone_lowercase:
+             self.feature_extractor = DINOEncoder(backbone)
+
+         else:
+             raise ValueError(f"Unknown backbone: {backbone}")
+
+     def _init_head(self):
+         features_dim = self.feature_extractor.get_features_dim()
+
+         match self.head:
+             case "linear":
+                 self.model = LinearProbe(features_dim, 2)
+
+             case "LinearNorm":
+                 self.model = LinearProbe(features_dim, 2, True)
+
+             case _:
+                 raise ValueError(f"Unknown head: {self.head}")
+
+     def forward(self, inputs: torch.Tensor):
+         features = self.feature_extractor(inputs)
+         outputs = self.model.forward(features)
+         return outputs
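For completeness, an end-to-end inference sketch. The paths are hypothetical: "./gend" stands for a local directory containing the three files from this commit, and "example.jpg" is a placeholder image; neither is part of the commit itself.

import torch
from PIL import Image

from modeling_gend import GenD

# Hypothetical checkpoint directory holding config.json and model.safetensors.
model = GenD.from_pretrained("./gend")
model.eval()

# Placeholder input image.
image = Image.open("example.jpg").convert("RGB")

# With the dinov3 backbone from config.json, feature_extractor is a DINOEncoder;
# its preprocess() returns a single CHW tensor, so add a batch dimension.
pixel_values = model.feature_extractor.preprocess(image).unsqueeze(0)

with torch.no_grad():
    logits = model(pixel_values)  # shape (1, 2), from the two-class linear probe
    probs = logits.softmax(dim=-1)

print(probs)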