nkkbr committed on
Commit 60f6935 · 1 Parent(s): b40e131
Files changed (2):
  1. README.md +73 -0
  2. hiera_encoder.py +454 -0
README.md ADDED
@@ -0,0 +1,73 @@
# Hiera Encoder from Meta's SAM2.1 (Segment Anything Model)

Meta's [SAM2 (Segment Anything Model v2)](https://github.com/facebookresearch/sam2) demonstrates state-of-the-art video segmentation capabilities. A core component behind this is the **Hiera** image encoder, which, through supervised training on object segmentation, has learned strong hierarchical visual features.

While Meta has released the full SAM2 models and their weights, these releases are plain **PyTorch** code and are **not integrated with Hugging Face Transformers** or common training tooling such as `Trainer`, `DeepSpeed`, etc.

This repository extracts the **Hiera** module from SAM2 and **wraps it with Hugging Face compatibility** (`PretrainedConfig`, `PreTrainedModel`, etc.), allowing seamless use in Hugging Face-style training and inference workflows.
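
As a quick illustration of the wrapping, the encoder can be built from a `HieraConfig` and saved/reloaded with the standard Hugging Face API. This is a minimal sketch using the default config values from `hiera_encoder.py` (a randomly initialized encoder, not the released base-plus checkpoint); the local directory name is arbitrary:

```python
from hiera_encoder import HieraConfig, HieraVisionModel

# Randomly initialized encoder built from the default config in hiera_encoder.py
config = HieraConfig()
model = HieraVisionModel(config)

# Standard Hugging Face serialization round trip
model.save_pretrained("./hiera-encoder-demo")
reloaded = HieraVisionModel.from_pretrained("./hiera-encoder-demo")
```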

---

## Model Details

- **Original Model**: [facebook/sam2.1-hiera-base-plus](https://huggingface.co/facebook/sam2.1-hiera-base-plus)
- **This Model**: [`nkkbr/hiera-base-plus-in-sam2.1`](https://huggingface.co/nkkbr/hiera-base-plus-in-sam2.1)

This model exposes only the **Hiera encoder** extracted from SAM2.1, wrapped for Hugging Face usage.

---

## Installation

You first need to install Meta's original SAM2 code:

```bash
git clone https://github.com/facebookresearch/sam2.git && cd sam2
pip install -e .
```
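
You also need `hiera_encoder.py` from this repository on your Python path (for example, by downloading it into your working directory), since the wrapper classes used below are defined there. An optional sanity check (a minimal sketch):

```python
# Optional sanity check: Meta's sam2 package and this repository's
# hiera_encoder.py must both be importable before running the examples below.
import sam2
import hiera_encoder

print("sam2 and hiera_encoder imported successfully")
```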

---

## Usage

```python
from hiera_encoder import HieraVisionModel

# Load the Hiera module from Hugging Face
model = HieraVisionModel.from_pretrained("nkkbr/hiera-base-plus-in-sam2.1")

# Get the raw Hiera model
model = model.hiera

# Print model parameters
for name, param in model.named_parameters():
    print(f"{name:50} {param.shape}")
```
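
To run an image through the encoder and inspect the multi-scale features it returns, something like the following works (a minimal sketch; the 1024×1024 resolution is SAM2.1's default input size and is only an assumption here, not a requirement of the wrapper):

```python
import torch
from hiera_encoder import HieraVisionModel

model = HieraVisionModel.from_pretrained("nkkbr/hiera-base-plus-in-sam2.1")
model.eval()

# Dummy batch with one RGB image at SAM2.1's default resolution
pixel_values = torch.randn(1, 3, 1024, 1024)

with torch.no_grad():
    features = model(pixel_values)  # list of per-stage feature maps, each (B, C, H, W)

for i, feat in enumerate(features):
    print(f"stage {i}: {tuple(feat.shape)}")
```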

---

## Weight Consistency Check

To verify that the weights are identical to those in Meta's original SAM2.1 Hiera module (continuing from the Usage snippet above, where `model` was set to the raw Hiera module):

```python
import torch
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Load SAM2.1 predictor from Meta's official release
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-base-plus")
hiera_model_in_predictor = predictor.model.image_encoder.trunk

# Compare weights
for name, param in model.named_parameters():
    if not torch.equal(param, hiera_model_in_predictor.state_dict()[name]):
        print(f"The parameter {name} has different weights in the two models.")

print("Comparison complete!")
```

---

## License

Please refer to the [SAM2 repository](https://github.com/facebookresearch/sam2) for license and usage terms.
hiera_encoder.py ADDED
@@ -0,0 +1,454 @@
# Adapted from Meta's code base: https://github.com/facebookresearch/sam2

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# print(torch.cuda.memory_summary())

import logging
from functools import partial
from typing import List, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from iopath.common.file_io import g_pathmgr

from sam2.modeling.backbones.utils import (
    PatchEmbed,
    window_partition,
    window_unpartition,
)

from sam2.modeling.sam2_utils import DropPath, MLP
from transformers import PretrainedConfig, PreTrainedModel
import json

def do_pool(x: torch.Tensor, pool: nn.Module, norm: nn.Module = None) -> torch.Tensor:
    if pool is None:
        return x
    # (B, H, W, C) -> (B, C, H, W)
    x = x.permute(0, 3, 1, 2)
    x = pool(x)
    # (B, C, H', W') -> (B, H', W', C)
    x = x.permute(0, 2, 3, 1)
    if norm:
        x = norm(x)

    return x


def enhanced_scaled_dot_product_attention(query, key, value):
    """
    Computes scaled dot-product attention with a safeguard for large batch sizes.

    In practice, if the batch size or the resulting tensor size exceeds 2**16,
    it can cause CUDA launch or memory errors due to hardware limitations.
    To address this, we check whether the intermediate tensor size exceeds this threshold.
    If it does, we split the attention computation into smaller chunks,
    perform the attention calculation on each chunk separately,
    and finally merge the results to obtain the final attention output.
    """

    batch_size = query.shape[0]
    if batch_size <= 2**15:
        return F.scaled_dot_product_attention(
            query,
            key,
            value,
        )
    else:
        results = []
        chunk_size = 2**15
        for i in range(0, batch_size, chunk_size):
            q_chunk = query[i:i+chunk_size]
            k_chunk = key[i:i+chunk_size]
            v_chunk = value[i:i+chunk_size]
            out_chunk = F.scaled_dot_product_attention(q_chunk, k_chunk, v_chunk)
            results.append(out_chunk)
        x_chunked = torch.cat(results, dim=0)
        return x_chunked

class MultiScaleAttention(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        q_pool: nn.Module = None,
    ):
        super().__init__()

        self.dim = dim
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.q_pool = q_pool
        self.qkv = nn.Linear(dim, dim_out * 3)
        self.proj = nn.Linear(dim_out, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (B, H * W, 3, nHead, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1)
        # q, k, v with shape (B, H * W, nheads, C)
        q, k, v = torch.unbind(qkv, 2)

        # Q pooling (for downsample at stage changes)
        if self.q_pool:
            q = do_pool(q.reshape(B, H, W, -1), self.q_pool)
            H, W = q.shape[1:3]  # downsampled shape
            q = q.reshape(B, H * W, self.num_heads, -1)

        # Torch's SDPA expects [B, nheads, H*W, C] so we transpose
        # x = F.scaled_dot_product_attention(
        #     q.transpose(1, 2),
        #     k.transpose(1, 2),
        #     v.transpose(1, 2),
        # )

        x = enhanced_scaled_dot_product_attention(
            query=q.transpose(1, 2),
            key=k.transpose(1, 2),
            value=v.transpose(1, 2),
        )

        # Transpose back
        x = x.transpose(1, 2)
        x = x.reshape(B, H, W, -1)

        x = self.proj(x)

        return x


class MultiScaleBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        dim_out: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        drop_path: float = 0.0,
        norm_layer: Union[nn.Module, str] = "LayerNorm",
        q_stride: Tuple[int, int] = None,
        act_layer: nn.Module = nn.GELU,
        window_size: int = 0,
    ):
        super().__init__()

        if isinstance(norm_layer, str):
            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)

        self.dim = dim
        self.dim_out = dim_out
        self.norm1 = norm_layer(dim)

        self.window_size = window_size

        self.pool, self.q_stride = None, q_stride
        if self.q_stride:
            self.pool = nn.MaxPool2d(
                kernel_size=q_stride, stride=q_stride, ceil_mode=False
            )

        self.attn = MultiScaleAttention(
            dim,
            dim_out,
            num_heads=num_heads,
            q_pool=self.pool,
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim_out)
        self.mlp = MLP(
            dim_out,
            int(dim_out * mlp_ratio),
            dim_out,
            num_layers=2,
            activation=act_layer,
        )

        if dim != dim_out:
            self.proj = nn.Linear(dim, dim_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x  # B, H, W, C
        x = self.norm1(x)

        # Skip connection
        if self.dim != self.dim_out:
            shortcut = do_pool(self.proj(x), self.pool)

        # Window partition
        window_size = self.window_size
        if window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, window_size)

        # Window Attention + Q Pooling (if stage change)
        x = self.attn(x)
        if self.q_stride:
            # Shapes have changed due to Q pooling
            window_size = self.window_size // self.q_stride[0]
            H, W = shortcut.shape[1:3]

            pad_h = (window_size - H % window_size) % window_size
            pad_w = (window_size - W % window_size) % window_size
            pad_hw = (H + pad_h, W + pad_w)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        # MLP
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class Hiera(nn.Module):
    """
    Reference: https://arxiv.org/abs/2306.00989
    """

    def __init__(
        self,
        embed_dim: int = 96,  # initial embed dim
        num_heads: int = 1,  # initial number of heads
        drop_path_rate: float = 0.0,  # stochastic depth
        q_pool: int = 3,  # number of q_pool stages
        q_stride: Tuple[int, int] = (2, 2),  # downsample stride bet. stages
        stages: Tuple[int, ...] = (2, 3, 16, 3),  # blocks per stage
        dim_mul: float = 2.0,  # dim_mul factor at stage shift
        head_mul: float = 2.0,  # head_mul factor at stage shift
        window_pos_embed_bkg_spatial_size: Tuple[int, int] = (14, 14),
        # window size per stage, when not using global att.
        window_spec: Tuple[int, ...] = (
            8,
            4,
            14,
            7,
        ),
        # global attn in these blocks
        global_att_blocks: Tuple[int, ...] = (
            12,
            16,
            20,
        ),
        weights_path=None,
        return_interm_layers=True,  # return feats from every stage
    ):
        super().__init__()

        assert len(stages) == len(window_spec)
        self.window_spec = window_spec

        depth = sum(stages)
        self.q_stride = q_stride
        self.stage_ends = [sum(stages[:i]) - 1 for i in range(1, len(stages) + 1)]
        assert 0 <= q_pool <= len(self.stage_ends[:-1])
        self.q_pool_blocks = [x + 1 for x in self.stage_ends[:-1]][:q_pool]
        self.return_interm_layers = return_interm_layers

        self.patch_embed = PatchEmbed(
            embed_dim=embed_dim,
        )
        # Which blocks have global att?
        self.global_att_blocks = global_att_blocks

        # Windowed positional embedding (https://arxiv.org/abs/2311.05613)
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.pos_embed = nn.Parameter(
            torch.zeros(1, embed_dim, *self.window_pos_embed_bkg_spatial_size)
        )
        self.pos_embed_window = nn.Parameter(
            torch.zeros(1, embed_dim, self.window_spec[0], self.window_spec[0])
        )

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, depth)
        ]  # stochastic depth decay rule

        cur_stage = 1
        self.blocks = nn.ModuleList()

        for i in range(depth):
            dim_out = embed_dim
            # lags by a block, so first block of
            # next stage uses an initial window size
            # of previous stage and final window size of current stage
            window_size = self.window_spec[cur_stage - 1]

            if self.global_att_blocks is not None:
                window_size = 0 if i in self.global_att_blocks else window_size

            if i - 1 in self.stage_ends:
                dim_out = int(embed_dim * dim_mul)
                num_heads = int(num_heads * head_mul)
                cur_stage += 1

            block = MultiScaleBlock(
                dim=embed_dim,
                dim_out=dim_out,
                num_heads=num_heads,
                drop_path=dpr[i],
                q_stride=self.q_stride if i in self.q_pool_blocks else None,
                window_size=window_size,
            )

            embed_dim = dim_out
            self.blocks.append(block)

        self.channel_list = (
            [self.blocks[i].dim_out for i in self.stage_ends[::-1]]
            if return_interm_layers
            else [self.blocks[-1].dim_out]
        )

        if weights_path is not None:
            with g_pathmgr.open(weights_path, "rb") as f:
                chkpt = torch.load(f, map_location="cpu")
            # logging.info("loading Hiera", self.load_state_dict(chkpt, strict=False))
            res = self.load_state_dict(chkpt, strict=False)
            logging.info(f"loading Hiera: {res}")

    def _get_pos_embed(self, hw: Tuple[int, int]) -> torch.Tensor:
        h, w = hw
        window_embed = self.pos_embed_window
        pos_embed = F.interpolate(self.pos_embed, size=(h, w), mode="bicubic")
        pos_embed = pos_embed + window_embed.tile(
            [x // y for x, y in zip(pos_embed.shape, window_embed.shape)]
        )
        pos_embed = pos_embed.permute(0, 2, 3, 1)
        return pos_embed

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        x = self.patch_embed(x)
        # x: (B, H, W, C)

        # Add pos embed
        x = x + self._get_pos_embed(x.shape[1:3])

        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if (i == self.stage_ends[-1]) or (
                i in self.stage_ends and self.return_interm_layers
            ):
                feats = x.permute(0, 3, 1, 2)
                outputs.append(feats)

        return outputs

    def get_layer_id(self, layer_name):
        # https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
        num_layers = self.get_num_layers()

        if layer_name.find("rel_pos") != -1:
            return num_layers + 1
        elif layer_name.find("pos_embed") != -1:
            return 0
        elif layer_name.find("patch_embed") != -1:
            return 0
        elif layer_name.find("blocks") != -1:
            return int(layer_name.split("blocks")[1].split(".")[1]) + 1
        else:
            return num_layers + 1

    def get_num_layers(self) -> int:
        return len(self.blocks)

class HieraConfig(PretrainedConfig):
    model_type = "hiera"

    def __init__(
        self,
        embed_dim=96,
        num_heads=1,
        drop_path_rate=0.0,
        q_pool=3,
        q_stride=(2, 2),
        stages=(2, 3, 16, 3),
        dim_mul=2.0,
        head_mul=2.0,
        window_pos_embed_bkg_spatial_size=(14, 14),
        window_spec=(8, 4, 14, 7),
        global_att_blocks=(12, 16, 20),
        weights_path=None,
        return_interm_layers=True,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.drop_path_rate = drop_path_rate
        self.q_pool = q_pool
        self.q_stride = q_stride
        self.stages = stages
        self.dim_mul = dim_mul
        self.head_mul = head_mul
        self.window_pos_embed_bkg_spatial_size = window_pos_embed_bkg_spatial_size
        self.window_spec = window_spec
        self.global_att_blocks = global_att_blocks
        self.weights_path = weights_path
        self.return_interm_layers = return_interm_layers

    @classmethod
    def from_json_file(cls, json_file):
        with open(json_file, "r") as f:
            config_dict = json.load(f)
        return cls(**config_dict)


class HieraVisionModel(PreTrainedModel):
    config_class = HieraConfig
    _no_split_modules = ["Hiera"]

    def __init__(self, config, weights_path=None):
        super().__init__(config)
        self.hiera = Hiera(
            embed_dim=config.embed_dim,
            num_heads=config.num_heads,
            drop_path_rate=config.drop_path_rate,
            q_pool=config.q_pool,
            q_stride=config.q_stride,
            stages=config.stages,
            dim_mul=config.dim_mul,
            head_mul=config.head_mul,
            window_pos_embed_bkg_spatial_size=config.window_pos_embed_bkg_spatial_size,
            window_spec=config.window_spec,
            global_att_blocks=config.global_att_blocks,
            return_interm_layers=config.return_interm_layers,
            weights_path=weights_path,
        )

    def forward(self, x):
        return self.hiera(x)

if __name__ == "__main__":

    model = HieraVisionModel.from_pretrained("nkkbr/hiera-base-plus-in-sam2.1")
    model = model.hiera

    for name, param in model.named_parameters():
        print(f"{name:50} {param.shape}")

    # Check whether the weights are consistent with the hiera module in sam2.1-hiera-base-plus.
    import torch
    from sam2.sam2_image_predictor import SAM2ImagePredictor

    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2.1-hiera-base-plus")
    hiera_model_in_predictor = predictor.model.image_encoder.trunk

    for name, param in model.named_parameters():
        if not torch.equal(param, hiera_model_in_predictor.state_dict()[name]):
            print(f"The parameter {name} has different weights in the two models.")
    print("Comparison complete!")