Ruggero1912 committed
Commit f281853 · 1 Parent(s): 86fbb78

The Patch-ioner code is now installed via pip, and the model weights are loaded from the Hugging Face Hub via the model repo id
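In practice, app.py no longer imports the model from the vendored `src/` tree; it imports the pip-installed `patchioner` package and passes either a local YAML config path or a Hugging Face repo id to `Patchioner.from_config` (see the app.py and requirements.txt diffs below). A minimal sketch of the new loading path, assuming the packaged module exposes the same `Patchioner.from_config` entry point that app.py uses:

```python
# Sketch of the workflow introduced by this commit (not part of the diff itself).
# The package is installed from the repo listed in the updated requirements.txt:
#   pip install git+https://github.com/Ruggero1912/Patch-ioner
import torch

from patchioner import Patchioner  # provided by the pip-installed package

device = "cuda" if torch.cuda.is_available() else "cpu"

# app.py falls back to this Hub repo id when no local YAML config is found;
# from_config is expected to resolve the config and weights from the Hub.
repo_id = "https://huggingface.co/Ruggero1912/Patch-ioner_talk2dino_decap_COCO_Captions"
model = Patchioner.from_config(repo_id, device=device)
```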

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitignore +3 -0
  2. README.md +5 -1
  3. app.py +15 -13
  4. configs/mlp.k.yaml +0 -8
  5. configs/mlp.viecap.k.yaml +0 -31
  6. requirements.txt +3 -36
  7. src/INViTE/clipfolder/__init__.py +0 -1
  8. src/INViTE/clipfolder/bpe_simple_vocab_16e6.txt.gz +0 -3
  9. src/INViTE/clipfolder/clip.py +0 -238
  10. src/INViTE/clipfolder/model.py +0 -515
  11. src/INViTE/clipfolder/simple_tokenizer.py +0 -132
  12. src/INViTE/loader.py +0 -72
  13. src/alphaclip/INSTALL.md +0 -113
  14. src/alphaclip/LICENSE +0 -201
  15. src/alphaclip/MANIFEST.in +0 -7
  16. src/alphaclip/README.md +0 -266
  17. src/alphaclip/__init__.py +0 -14
  18. src/alphaclip/alpha_clip/__init__.py +0 -1
  19. src/alphaclip/alpha_clip/alpha_clip.py +0 -254
  20. src/alphaclip/alpha_clip/bpe_simple_vocab_16e6.txt.gz +0 -3
  21. src/alphaclip/alpha_clip/model.py +0 -609
  22. src/alphaclip/alpha_clip/simple_tokenizer.py +0 -132
  23. src/alphaclip/alpha_mask_utils.py +0 -111
  24. src/alphaclip/alphaclip_loader.py +0 -233
  25. src/alphaclip/example.py +0 -76
  26. src/alphaclip/requirements.txt +0 -10
  27. src/alphaclip/setup.py +0 -47
  28. src/alphaclip/test_installation.py +0 -149
  29. src/bbox_utils.py +0 -421
  30. src/clipcap/CLIPCAP_INTEGRATION.md +0 -206
  31. src/clipcap/clipcapTrainREADME.md +0 -301
  32. src/clipcap/clipcapTraining.py +0 -405
  33. src/clipcap/clipcap_dino_parse_coco.py +0 -613
  34. src/clipcap/clipcap_parse_coco.py +0 -51
  35. src/clipcap/entrypoint.py +0 -564
  36. src/clipcap/predict.py +0 -302
  37. src/dataset.py +0 -94
  38. src/datasetMix.py +0 -153
  39. src/decap/decap.py +0 -193
  40. src/decap/decoderTraining.py +0 -464
  41. src/decap/decoder_config.pkl +0 -3
  42. src/decap/im2txtprojection/im2txtprojection.py +0 -500
  43. src/denseclip/clip_loader/README.md +0 -233
  44. src/denseclip/clip_loader/SUMMARY.md +0 -78
  45. src/denseclip/clip_loader/__init__.py +0 -21
  46. src/denseclip/clip_loader/bpe_simple_vocab_16e6.txt.gz +0 -3
  47. src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16.yaml +0 -41
  48. src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16_long_ctx.yaml +0 -41
  49. src/denseclip/clip_loader/denseclip_loader.py +0 -316
  50. src/denseclip/clip_loader/example_usage.py +0 -108
.gitignore CHANGED
@@ -1,3 +1,6 @@
1
  *.pyc
2
  *.pyo
3
  *.pyd
 
 
 
 
1
  *.pyc
2
  *.pyo
3
  *.pyd
4
+ venv/
5
+ .gradio/
6
+ .venv/
README.md CHANGED
@@ -13,4 +13,8 @@ short_description: 'Repo for the Paper "One Patch to Caption Them All: ...'
13
 
14
  ArXiv: arxiv.org/abs/2510.02898
15
 
16
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
13
 
14
  ArXiv: arxiv.org/abs/2510.02898
15
 
16
+ Demo of the Patch-ioner framework, from the paper "One Patch to Caption Them All: A Unified Zero-shot Captioning Framework".
17
+
18
+ The project page is at [paciosoft.com/Patch-ioner](https://paciosoft.com/Patch-ioner).
19
+
20
+
app.py CHANGED
@@ -24,8 +24,7 @@ from PIL import Image
24
  import numpy as np
25
  from typing import List, Dict
26
 
27
- # Import the Patchioner model from the src directory
28
- from src.model import Patchioner
29
 
30
  # Global variable to store the loaded model
31
  loaded_model = None
@@ -33,7 +32,7 @@ model_config_path = None
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
 
35
  # Default model configuration
36
- DEFAULT_MODEL_CONFIG = "mlp.viecap.k.yaml"
37
 
38
  # Example images directory
39
  current_dir = os.path.dirname(__file__)
@@ -50,13 +49,16 @@ def initialize_default_model() -> str:
50
  default_config_path = CONFIGS_DIR / DEFAULT_MODEL_CONFIG
51
 
52
  if not default_config_path.exists():
53
- return f"❌ Default config file not found: {default_config_path}"
 
 
 
 
 
54
 
55
  print(f"Loading default model: {DEFAULT_MODEL_CONFIG}")
56
 
57
- # Load and parse the config
58
- with open(default_config_path, 'r') as f:
59
- config = yaml.safe_load(f)
60
 
61
  # Load the model using the from_config class method
62
  model = Patchioner.from_config(config, device=device)
@@ -553,7 +555,7 @@ def generate_bbox_caption(image_data, image) -> str:
553
  return error_msg
554
 
555
 
556
- def create_gradio_interface():
557
  """Create and configure the Gradio interface."""
558
 
559
  # Get example files
@@ -593,7 +595,7 @@ def create_gradio_interface():
593
  ) as demo:
594
  #gr.HTML(custom_js) # inject custom JS
595
 
596
- gr.Markdown("""
597
  # 🎯 Patchioner Trace Captioning Demo
598
 
599
  This demo allows you to:
@@ -608,7 +610,7 @@ def create_gradio_interface():
608
  3. Use the appropriate tool to mark areas of interest in the image
609
  4. Click "Generate Caption" to get AI-generated descriptions
610
 
611
- **Model:** Using `mlp.karpathy.yaml` configuration (automatically loaded)
612
  """)
613
 
614
  # Initialize model status
@@ -730,7 +732,7 @@ def create_gradio_interface():
730
  outputs=[image_editor, image_annotator]
731
  )
732
 
733
- gr.Markdown("""
734
  ### 💡 Tips:
735
  - **Mode Selection**: Switch between trace and bounding box modes based on your needs
736
  - **Trace Mode**: Draw continuous lines over areas you want to describe
@@ -741,7 +743,7 @@ def create_gradio_interface():
741
  ### 🔧 Technical Details:
742
  - **Trace Mode**: Converts drawings to normalized (x, y) coordinates with timestamps
743
  - **BBox Mode**: Uses bounding box coordinates for region-specific captioning
744
- - **Model Architecture**: Uses `mlp.karpathy.yaml` configuration with CLIP and ViT components
745
  - **Processing**: Each trace/bbox is processed separately to generate corresponding captions
746
  """)
747
 
@@ -762,7 +764,7 @@ if __name__ == "__main__":
762
  print(f"Example images directory: {EXAMPLE_IMAGES_DIR}")
763
  print(f"Configs directory: {CONFIGS_DIR}")
764
 
765
- demo = create_gradio_interface()
766
  if not args.local:
767
  demo.launch()
768
  else:
 
24
  import numpy as np
25
  from typing import List, Dict
26
 
27
+ from patchioner import Patchioner
 
28
 
29
  # Global variable to store the loaded model
30
  loaded_model = None
 
32
  device = "cuda" if torch.cuda.is_available() else "cpu"
33
 
34
  # Default model configuration
35
+ DEFAULT_MODEL_CONFIG = "https://huggingface.co/Ruggero1912/Patch-ioner_talk2dino_decap_COCO_Captions"
36
 
37
  # Example images directory
38
  current_dir = os.path.dirname(__file__)
 
49
  default_config_path = CONFIGS_DIR / DEFAULT_MODEL_CONFIG
50
 
51
  if not default_config_path.exists():
52
+ print( f"❌ Default config file not found: {default_config_path}" )
53
+ config = DEFAULT_MODEL_CONFIG # Assume it's a URL or model identifier
54
+ print( f"Attempting to load model from identifier: {config}" )
55
+
56
+ else:
57
+ config = default_config_path
58
 
59
  print(f"Loading default model: {DEFAULT_MODEL_CONFIG}")
60
 
61
+
 
 
62
 
63
  # Load the model using the from_config class method
64
  model = Patchioner.from_config(config, device=device)
 
555
  return error_msg
556
 
557
 
558
+ def create_gradio_interface(model_config_name : str):
559
  """Create and configure the Gradio interface."""
560
 
561
  # Get example files
 
595
  ) as demo:
596
  #gr.HTML(custom_js) # inject custom JS
597
 
598
+ gr.Markdown(f"""
599
  # 🎯 Patchioner Trace Captioning Demo
600
 
601
  This demo allows you to:
 
610
  3. Use the appropriate tool to mark areas of interest in the image
611
  4. Click "Generate Caption" to get AI-generated descriptions
612
 
613
+ **Model:** Using `{model_config_name}` configuration (automatically loaded)
614
  """)
615
 
616
  # Initialize model status
 
732
  outputs=[image_editor, image_annotator]
733
  )
734
 
735
+ gr.Markdown(f"""
736
  ### 💡 Tips:
737
  - **Mode Selection**: Switch between trace and bounding box modes based on your needs
738
  - **Trace Mode**: Draw continuous lines over areas you want to describe
 
743
  ### ๐Ÿ”ง Technical Details:
744
  - **Trace Mode**: Converts drawings to normalized (x, y) coordinates with timestamps
745
  - **BBox Mode**: Uses bounding box coordinates for region-specific captioning
746
+ - **Model Architecture**: Uses `{model_config_name}` configuration with CLIP and ViT components
747
  - **Processing**: Each trace/bbox is processed separately to generate corresponding captions
748
  """)
749
 
 
764
  print(f"Example images directory: {EXAMPLE_IMAGES_DIR}")
765
  print(f"Configs directory: {CONFIGS_DIR}")
766
 
767
+ demo = create_gradio_interface(DEFAULT_MODEL_CONFIG)
768
  if not args.local:
769
  demo.launch()
770
  else:
configs/mlp.k.yaml DELETED
@@ -1,8 +0,0 @@
1
- decap_weights: 'weights/decap-talk2dino-coco_karpathy-009.pt'
2
- prefix_size: 768
3
- linear_talk2dino: False
4
- support_memory_size: 591753
5
- dino_model: 'dinov2_vitb14_reg'
6
- normalize: True
7
- kkv_attention: False
8
- projection_type: '/raid/datasets/im2txtmemories/coco_train_karpathy.json'
configs/mlp.viecap.k.yaml DELETED
@@ -1,31 +0,0 @@
1
- decap_weights: null
2
- prefix_size: 768
3
- linear_talk2dino: False
4
- support_memory_size: 0
5
- dino_model: 'dinov2_vitb14_reg'
6
- normalize: False
7
- kkv_attention: False
8
- use_talk2dino_project: False
9
- clip_model_name: "ViT-B/16"
10
-
11
-
12
- # nested config
13
- viecap:
14
- clip_hidden_size: 768
15
- suffix: ViT-B16_t2d_
16
- project_length: 10
17
- temperature: 0.01
18
- top_k: 3
19
- threshold: 0.4
20
- language_model: 'gpt2'
21
- name_of_entities_text: coco_entities #vinvl_vgoi_entities
22
- files_path: 'weights/viecap_files/'
23
- prompt_ensemble: True
24
- weight_path: 'weights/viecap-talk2dino-coco_karpathy-0014.pt'
25
- using_hard_prompt: True
26
- soft_prompt_first: True
27
- only_hard_prompt: False
28
- using_greedy_search: True #if false, use beam search
29
- beam_width: 5
30
- text_prompt: None
31
-
requirements.txt CHANGED
@@ -1,36 +1,3 @@
1
- # Core dependencies - absolutely required
2
- torch
3
- transformers==4.46.3
4
- gradio>=4.0.0
5
- gradio_image_annotation
6
-
7
- # Image processing - required for the demo
8
- pillow
9
- torchvision
10
-
11
- # Data handling - required
12
- numpy
13
- tqdm
14
-
15
- # CLIP - essential for the model
16
- git+https://github.com/openai/CLIP.git
17
-
18
- # Model dependencies - needed for core functionality
19
- timm
20
-
21
- # Hugging Face model hosting
22
- huggingface_hub
23
-
24
- h5py
25
-
26
- # Optional: Only include if specifically needed
27
- # h5py # Only needed for some data formats - can be installed conditionally
28
- # scikit-learn # Only for evaluation - not needed for inference
29
- # plotly # Only for plotting - not needed for basic demo
30
- # pandas # Only for data analysis - not needed for basic demo
31
- # matplotlib # Only for plotting - not needed for basic demo
32
- # pycocotools # Only for COCO evaluation - not needed for basic demo
33
- # nbformat # Only for notebooks - not needed for basic demo
34
- # speaksee # Only for evaluation - not needed for basic demo
35
- # munkres # Only for specific evaluation metrics - not needed for basic demo
36
- # open_clip_torch # Only if using open_clip models - not needed for basic demo
 
1
+ git+https://github.com/Ruggero1912/Patch-ioner
2
+ gradio==5.48.0
3
+ gradio_image_annotation
src/INViTE/clipfolder/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .clip import *
 
 
src/INViTE/clipfolder/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
- size 1356917
 
 
 
 
src/INViTE/clipfolder/clip.py DELETED
@@ -1,238 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
- from typing import Any, Union, List
6
- from pkg_resources import packaging
7
-
8
- import torch
9
- from PIL import Image
10
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
- from tqdm import tqdm
12
-
13
- from .model import build_model
14
- from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
-
16
- try:
17
- from torchvision.transforms import InterpolationMode
18
- BICUBIC = InterpolationMode.BICUBIC
19
- except ImportError:
20
- BICUBIC = Image.BICUBIC
21
-
22
-
23
- if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
- warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
-
26
-
27
- __all__ = ["available_models", "load", "tokenize"]
28
- _tokenizer = _Tokenizer()
29
-
30
- _MODELS = {
31
- "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
- "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
- "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
- "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
- "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
- "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
- "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
- "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
- "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
- }
41
-
42
-
43
- def _download(url: str, root: str):
44
- os.makedirs(root, exist_ok=True)
45
- filename = os.path.basename(url)
46
-
47
- expected_sha256 = url.split("/")[-2]
48
- download_target = os.path.join(root, filename)
49
-
50
- if os.path.exists(download_target) and not os.path.isfile(download_target):
51
- raise RuntimeError(f"{download_target} exists and is not a regular file")
52
-
53
- if os.path.isfile(download_target):
54
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
- return download_target
56
- else:
57
- warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
-
59
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
- with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
- while True:
62
- buffer = source.read(8192)
63
- if not buffer:
64
- break
65
-
66
- output.write(buffer)
67
- loop.update(len(buffer))
68
-
69
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
- raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
71
-
72
- return download_target
73
-
74
-
75
- def _convert_image_to_rgb(image):
76
- return image.convert("RGB")
77
-
78
-
79
- def _transform(n_px):
80
- return Compose([
81
- Resize(n_px, interpolation=BICUBIC),
82
- CenterCrop(n_px),
83
- _convert_image_to_rgb,
84
- ToTensor(),
85
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86
- ])
87
-
88
-
89
- def available_models() -> List[str]:
90
- """Returns the names of available CLIP models"""
91
- return list(_MODELS.keys())
92
-
93
-
94
- def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False,
95
- download_root: str = None, extract_last_k_th_token: int = -1, viz: bool = False, image_resolution: int = None):
96
- """Load a CLIP model
97
-
98
- Parameters
99
- ----------
100
- name : str
101
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
102
-
103
- device : Union[str, torch.device]
104
- The device to put the loaded model
105
-
106
- jit : bool
107
- Whether to load the optimized JIT model or more hackable non-JIT model (default).
108
-
109
- download_root: str
110
- path to download the model files; by default, it uses "~/.cache/clip"
111
-
112
- Returns
113
- -------
114
- model : torch.nn.Module
115
- The CLIP model
116
-
117
- preprocess : Callable[[PIL.Image], torch.Tensor]
118
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
119
- """
120
- if name in _MODELS:
121
- model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
122
- elif os.path.isfile(name):
123
- model_path = name
124
- else:
125
- raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
126
-
127
- with open(model_path, 'rb') as opened_file:
128
- try:
129
- # loading JIT archive
130
- model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
131
- state_dict = None
132
- except RuntimeError:
133
- # loading saved state dict
134
- if jit:
135
- warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
136
- jit = False
137
- state_dict = torch.load(opened_file, map_location="cpu")
138
-
139
- if not jit:
140
- model = build_model(state_dict or model.state_dict(), extract_last_k_th_token, viz, image_resolution=image_resolution).to(device)
141
- if str(device) == "cpu":
142
- model.float()
143
- return model, _transform(model.visual.input_resolution)
144
-
145
- # patch the device names
146
- device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
147
- device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
148
-
149
- def patch_device(module):
150
- try:
151
- graphs = [module.graph] if hasattr(module, "graph") else []
152
- except RuntimeError:
153
- graphs = []
154
-
155
- if hasattr(module, "forward1"):
156
- graphs.append(module.forward1.graph)
157
-
158
- for graph in graphs:
159
- for node in graph.findAllNodes("prim::Constant"):
160
- if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
161
- node.copyAttributes(device_node)
162
-
163
- model.apply(patch_device)
164
- patch_device(model.encode_image)
165
- patch_device(model.encode_text)
166
-
167
- # patch dtype to float32 on CPU
168
- if str(device) == "cpu":
169
- float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
170
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
171
- float_node = float_input.node()
172
-
173
- def patch_float(module):
174
- try:
175
- graphs = [module.graph] if hasattr(module, "graph") else []
176
- except RuntimeError:
177
- graphs = []
178
-
179
- if hasattr(module, "forward1"):
180
- graphs.append(module.forward1.graph)
181
-
182
- for graph in graphs:
183
- for node in graph.findAllNodes("aten::to"):
184
- inputs = list(node.inputs())
185
- for i in [1, 2]: # dtype can be the second or third argument to aten::to()
186
- if inputs[i].node()["value"] == 5:
187
- inputs[i].node().copyAttributes(float_node)
188
-
189
- model.apply(patch_float)
190
- patch_float(model.encode_image)
191
- patch_float(model.encode_text)
192
-
193
- model.float()
194
-
195
- return model, _transform(model.input_resolution.item())
196
-
197
-
198
- def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
199
- """
200
- Returns the tokenized representation of given input string(s)
201
-
202
- Parameters
203
- ----------
204
- texts : Union[str, List[str]]
205
- An input string or a list of input strings to tokenize
206
-
207
- context_length : int
208
- The context length to use; all CLIP models use 77 as the context length
209
-
210
- truncate: bool
211
- Whether to truncate the text in case its encoding is longer than the context length
212
-
213
- Returns
214
- -------
215
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
216
- We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
217
- """
218
- if isinstance(texts, str):
219
- texts = [texts]
220
-
221
- sot_token = _tokenizer.encoder["<|startoftext|>"]
222
- eot_token = _tokenizer.encoder["<|endoftext|>"]
223
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
224
- if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
225
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
226
- else:
227
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
228
-
229
- for i, tokens in enumerate(all_tokens):
230
- if len(tokens) > context_length:
231
- if truncate:
232
- tokens = tokens[:context_length]
233
- tokens[-1] = eot_token
234
- else:
235
- raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
236
- result[i, :len(tokens)] = torch.tensor(tokens)
237
-
238
- return result
src/INViTE/clipfolder/model.py DELETED
@@ -1,515 +0,0 @@
1
- from collections import OrderedDict
2
- from typing import Tuple, Union
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import nn
8
-
9
-
10
- class Bottleneck(nn.Module):
11
- expansion = 4
12
-
13
- def __init__(self, inplanes, planes, stride=1):
14
- super().__init__()
15
-
16
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18
- self.bn1 = nn.BatchNorm2d(planes)
19
- self.relu1 = nn.ReLU(inplace=True)
20
-
21
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22
- self.bn2 = nn.BatchNorm2d(planes)
23
- self.relu2 = nn.ReLU(inplace=True)
24
-
25
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26
-
27
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29
- self.relu3 = nn.ReLU(inplace=True)
30
-
31
- self.downsample = None
32
- self.stride = stride
33
-
34
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
35
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36
- self.downsample = nn.Sequential(OrderedDict([
37
- ("-1", nn.AvgPool2d(stride)),
38
- ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39
- ("1", nn.BatchNorm2d(planes * self.expansion))
40
- ]))
41
-
42
- def forward(self, x: torch.Tensor):
43
- identity = x
44
-
45
- out = self.relu1(self.bn1(self.conv1(x)))
46
- out = self.relu2(self.bn2(self.conv2(out)))
47
- out = self.avgpool(out)
48
- out = self.bn3(self.conv3(out))
49
-
50
- if self.downsample is not None:
51
- identity = self.downsample(x)
52
-
53
- out += identity
54
- out = self.relu3(out)
55
- return out
56
-
57
-
58
- class AttentionPool2d(nn.Module):
59
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60
- super().__init__()
61
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62
- self.k_proj = nn.Linear(embed_dim, embed_dim)
63
- self.q_proj = nn.Linear(embed_dim, embed_dim)
64
- self.v_proj = nn.Linear(embed_dim, embed_dim)
65
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66
- self.num_heads = num_heads
67
-
68
- def forward(self, x):
69
- x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72
- x, _ = F.multi_head_attention_forward(
73
- query=x[:1], key=x, value=x,
74
- embed_dim_to_check=x.shape[-1],
75
- num_heads=self.num_heads,
76
- q_proj_weight=self.q_proj.weight,
77
- k_proj_weight=self.k_proj.weight,
78
- v_proj_weight=self.v_proj.weight,
79
- in_proj_weight=None,
80
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81
- bias_k=None,
82
- bias_v=None,
83
- add_zero_attn=False,
84
- dropout_p=0,
85
- out_proj_weight=self.c_proj.weight,
86
- out_proj_bias=self.c_proj.bias,
87
- use_separate_proj_weight=True,
88
- training=self.training,
89
- need_weights=False
90
- )
91
- return x.squeeze(0)
92
-
93
-
94
- class ModifiedResNet(nn.Module):
95
- """
96
- A ResNet class that is similar to torchvision's but contains the following changes:
97
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99
- - The final pooling layer is a QKV attention instead of an average pool
100
- """
101
-
102
- def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103
- super().__init__()
104
- self.output_dim = output_dim
105
- self.input_resolution = input_resolution
106
-
107
- # the 3-layer stem
108
- self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109
- self.bn1 = nn.BatchNorm2d(width // 2)
110
- self.relu1 = nn.ReLU(inplace=True)
111
- self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112
- self.bn2 = nn.BatchNorm2d(width // 2)
113
- self.relu2 = nn.ReLU(inplace=True)
114
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115
- self.bn3 = nn.BatchNorm2d(width)
116
- self.relu3 = nn.ReLU(inplace=True)
117
- self.avgpool = nn.AvgPool2d(2)
118
-
119
- # residual layers
120
- self._inplanes = width # this is a *mutable* variable used during construction
121
- self.layer1 = self._make_layer(width, layers[0])
122
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125
-
126
- embed_dim = width * 32 # the ResNet feature dimension
127
- self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128
-
129
- def _make_layer(self, planes, blocks, stride=1):
130
- layers = [Bottleneck(self._inplanes, planes, stride)]
131
-
132
- self._inplanes = planes * Bottleneck.expansion
133
- for _ in range(1, blocks):
134
- layers.append(Bottleneck(self._inplanes, planes))
135
-
136
- return nn.Sequential(*layers)
137
-
138
- def forward(self, x):
139
- def stem(x):
140
- x = self.relu1(self.bn1(self.conv1(x)))
141
- x = self.relu2(self.bn2(self.conv2(x)))
142
- x = self.relu3(self.bn3(self.conv3(x)))
143
- x = self.avgpool(x)
144
- return x
145
-
146
- x = x.type(self.conv1.weight.dtype)
147
- x = stem(x)
148
- x = self.layer1(x)
149
- x = self.layer2(x)
150
- x = self.layer3(x)
151
- x = self.layer4(x)
152
- x = self.attnpool(x)
153
-
154
- return x
155
-
156
-
157
- class LayerNorm(nn.LayerNorm):
158
- """Subclass torch's LayerNorm to handle fp16."""
159
-
160
- def forward(self, x: torch.Tensor):
161
- orig_type = x.dtype
162
- ret = super().forward(x.type(torch.float32))
163
- return ret.type(orig_type)
164
-
165
-
166
- class QuickGELU(nn.Module):
167
- def forward(self, x: torch.Tensor):
168
- return x * torch.sigmoid(1.702 * x)
169
-
170
-
171
- class ResidualAttentionBlock(nn.Module):
172
- def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, viz: bool = False):
173
- super().__init__()
174
-
175
- if viz:
176
- self.attn = nn.MultiheadAttentionViz(d_model, n_head)
177
- else:
178
- self.attn = nn.MultiheadAttention(d_model, n_head)
179
-
180
- self.ln_1 = LayerNorm(d_model)
181
- self.mlp = nn.Sequential(OrderedDict([
182
- ("c_fc", nn.Linear(d_model, d_model * 4)),
183
- ("gelu", QuickGELU()),
184
- ("c_proj", nn.Linear(d_model * 4, d_model))
185
- ]))
186
- self.ln_2 = LayerNorm(d_model)
187
- self.attn_mask = attn_mask
188
-
189
- """attn_mask โ€“ If specified, a 2D or 3D mask preventing attention to certain positions.
190
- Must be of shape (L, S) or (N·num_heads, L, S),
191
- where N is the batch size, L is the target sequence length, and S is the source sequence length.
192
- A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry
193
- in the batch. Binary, byte, and float masks are supported. For a binary mask,
194
- a True value indicates that the corresponding position is not allowed to attend.
195
- For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend.
196
- For a float mask, the mask values will be added to the attention weight."""
197
-
198
- def attention(self, x: torch.Tensor):
199
- self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
200
- return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
201
-
202
- def forward(self, x: torch.Tensor):
203
- x = x + self.attention(self.ln_1(x))
204
- x = x + self.mlp(self.ln_2(x))
205
- return x
206
-
207
-
208
-
209
- class Transformer(nn.Module):
210
- def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
211
- extract_last_k_th_token: int=-1, viz: bool = False, num_tokens: int = 50):
212
- super().__init__()
213
- self.width = width
214
- self.layers = layers
215
- print('\n\n\n\n\ntransformer total layers', layers)
216
- if extract_last_k_th_token>0:
217
- start_mask_layer = layers - extract_last_k_th_token
218
-
219
- ans = []
220
- for cnt in range(layers):
221
- if cnt < start_mask_layer:
222
- ans.append(ResidualAttentionBlock(width, heads, attn_mask, viz))
223
- else:
224
- print(' mask for layer {}'.format(cnt))
225
- mask = torch.empty(num_tokens, num_tokens)
226
- mask.fill_(float("-inf"))
227
- mask.fill_diagonal_(0)
228
- ans.append(ResidualAttentionBlock(width, heads, mask.cuda(), viz))
229
- # TODO: here is hard coded 50 sequence length
230
- # only attend to themselves
231
-
232
- self.resblocks = nn.Sequential(*ans)
233
- else:
234
- self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, viz) for _ in range(layers)])
235
-
236
- def forward(self, x: torch.Tensor):
237
- return self.resblocks(x)
238
-
239
-
240
- class VisionTransformer(nn.Module):
241
- def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int,
242
- extract_last_k_th_token: int, viz: bool):
243
- super().__init__()
244
- self.input_resolution = input_resolution
245
- self.output_dim = output_dim
246
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
247
-
248
- scale = width ** -0.5
249
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
250
- self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
251
- self.ln_pre = LayerNorm(width)
252
-
253
- self.transformer = Transformer(width, layers, heads, extract_last_k_th_token=extract_last_k_th_token, viz=viz, num_tokens=(input_resolution // patch_size) ** 2 + 1)
254
-
255
- self.ln_post = LayerNorm(width)
256
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
257
-
258
- def forward(self, x: torch.Tensor, get_all_last: bool):
259
- # convert x to conv1 dtype
260
- x = x.type(self.conv1.weight.dtype)
261
- x = self.conv1(x) # shape = [*, width, grid, grid]
262
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
263
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
264
- x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
265
- x = x + self.positional_embedding.to(x.dtype)
266
- x = self.ln_pre(x)
267
-
268
- x = x.permute(1, 0, 2) # NLD -> LND
269
- x = self.transformer(x)
270
- x = x.permute(1, 0, 2) # LND -> NLD
271
-
272
- if get_all_last:
273
- # take all tokens, x is of shape [*, grid ** 2 + 1, width]
274
- # and we apply layer norm to each token separately
275
- # x is of shape [*, grid ** 2 + 1, width]
276
- x = torch.cat([self.ln_post(x[:, idx, :]).unsqueeze(1) for idx in range(x.size(1))], dim=1)
277
- else:
278
- # take the first token (CLS token), x is of shape [*, grid ** 2 + 1, width]
279
- x = self.ln_post(x[:, 0, :])
280
-
281
- if self.proj is not None:
282
- x = x @ self.proj
283
- # the returned x is of shape [*, output_dim] where * is the batch size or if get_all_last is True, [*, grid ** 2 + 1, output_dim]
284
- return x
285
-
286
-
287
- class CLIP(nn.Module):
288
- def __init__(self,
289
- embed_dim: int,
290
- # vision
291
- image_resolution: int,
292
- vision_layers: Union[Tuple[int, int, int, int], int],
293
- vision_width: int,
294
- vision_patch_size: int,
295
- # text
296
- context_length: int,
297
- vocab_size: int,
298
- transformer_width: int,
299
- transformer_heads: int,
300
- transformer_layers: int,
301
- extract_last_k_th_token: int = -1,
302
- viz: bool = False
303
- ):
304
- super().__init__()
305
-
306
- self.context_length = context_length
307
-
308
- if isinstance(vision_layers, (tuple, list)):
309
- vision_heads = vision_width * 32 // 64
310
- self.visual = ModifiedResNet(
311
- layers=vision_layers,
312
- output_dim=embed_dim,
313
- heads=vision_heads,
314
- input_resolution=image_resolution,
315
- width=vision_width
316
- )
317
- else:
318
- vision_heads = vision_width // 64
319
- self.visual = VisionTransformer(
320
- input_resolution=image_resolution,
321
- patch_size=vision_patch_size,
322
- width=vision_width,
323
- layers=vision_layers,
324
- heads=vision_heads,
325
- output_dim=embed_dim,
326
- extract_last_k_th_token=extract_last_k_th_token,
327
- viz=viz
328
- )
329
-
330
- self.transformer = Transformer(
331
- width=transformer_width,
332
- layers=transformer_layers,
333
- heads=transformer_heads,
334
- attn_mask=self.build_attention_mask()
335
- )
336
-
337
- self.vocab_size = vocab_size
338
- self.token_embedding = nn.Embedding(vocab_size, transformer_width)
339
- self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
340
- self.ln_final = LayerNorm(transformer_width)
341
-
342
- self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
343
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
344
-
345
- self.initialize_parameters()
346
-
347
- def initialize_parameters(self):
348
- nn.init.normal_(self.token_embedding.weight, std=0.02)
349
- nn.init.normal_(self.positional_embedding, std=0.01)
350
-
351
- if isinstance(self.visual, ModifiedResNet):
352
- if self.visual.attnpool is not None:
353
- std = self.visual.attnpool.c_proj.in_features ** -0.5
354
- nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
355
- nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
356
- nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
357
- nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
358
-
359
- for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
360
- for name, param in resnet_block.named_parameters():
361
- if name.endswith("bn3.weight"):
362
- nn.init.zeros_(param)
363
-
364
- proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
365
- attn_std = self.transformer.width ** -0.5
366
- fc_std = (2 * self.transformer.width) ** -0.5
367
- for block in self.transformer.resblocks:
368
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
369
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
370
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
371
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
372
-
373
- if self.text_projection is not None:
374
- nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
375
-
376
- def build_attention_mask(self):
377
- # lazily create causal attention mask, with full attention between the vision tokens
378
- # pytorch uses additive attention mask; fill with -inf
379
- mask = torch.empty(self.context_length, self.context_length)
380
- mask.fill_(float("-inf"))
381
- mask.triu_(1) # zero out the lower diagonal
382
- return mask
383
-
384
- @property
385
- def dtype(self):
386
- return self.visual.conv1.weight.dtype
387
-
388
- def encode_image(self, image, get_all_last):
389
- return self.visual(image.type(self.dtype), get_all_last)
390
-
391
- def encode_text(self, text):
392
- x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
393
-
394
- x = x + self.positional_embedding.type(self.dtype)
395
- x = x.permute(1, 0, 2) # NLD -> LND
396
- x = self.transformer(x)
397
- x = x.permute(1, 0, 2) # LND -> NLD
398
- x = self.ln_final(x).type(self.dtype)
399
-
400
- # x.shape = [batch_size, n_ctx, transformer.width]
401
- # take features from the eot embedding (eot_token is the highest number in each sequence)
402
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
403
-
404
- return x
405
-
406
- def forward(self, image, text, get_all_last=False):
407
- image_features = self.encode_image(image, get_all_last)
408
- text_features = self.encode_text(text)
409
-
410
- # normalized features
411
- image_features = image_features / image_features.norm(dim=1, keepdim=True)
412
- text_features = text_features / text_features.norm(dim=1, keepdim=True)
413
-
414
- # cosine similarity as logits
415
- logit_scale = self.logit_scale.exp()
416
-
417
- if get_all_last:
418
- return logit_scale * image_features, text_features
419
-
420
-
421
- logits_per_image = logit_scale * image_features @ text_features.t()
422
- logits_per_text = logits_per_image.t()
423
-
424
- # shape = [global_batch_size, global_batch_size]
425
- return logits_per_image, logits_per_text
426
-
427
-
428
- def convert_weights(model: nn.Module):
429
- """Convert applicable model parameters to fp16"""
430
-
431
- def _convert_weights_to_fp16(l):
432
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
433
- l.weight.data = l.weight.data.half()
434
- if l.bias is not None:
435
- l.bias.data = l.bias.data.half()
436
-
437
- if isinstance(l, nn.MultiheadAttention):
438
- for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
439
- tensor = getattr(l, attr)
440
- if tensor is not None:
441
- tensor.data = tensor.data.half()
442
-
443
- for name in ["text_projection", "proj"]:
444
- if hasattr(l, name):
445
- attr = getattr(l, name)
446
- if attr is not None:
447
- attr.data = attr.data.half()
448
-
449
- model.apply(_convert_weights_to_fp16)
450
-
451
- import torch
452
- import torch.nn.functional as F
453
-
454
- def resize_pos_embed(old_pe: torch.Tensor, new_shape: int) -> torch.Tensor:
455
- # old_pe: [old_num_patches + 1, C]
456
- # new_shape: new_num_patches + 1
457
- cls_token = old_pe[:1]
458
- patch_pe = old_pe[1:]
459
- old_num = int(patch_pe.shape[0] ** 0.5)
460
- new_num = int((new_shape - 1) ** 0.5)
461
-
462
- patch_pe = patch_pe.reshape(1, old_num, old_num, -1).permute(0, 3, 1, 2) # (1, C, H, W)
463
- patch_pe = F.interpolate(patch_pe, size=(new_num, new_num), mode='bicubic', align_corners=False)
464
- patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(1, new_num * new_num, -1)
465
-
466
- return torch.cat([cls_token.unsqueeze(0), patch_pe], dim=1).squeeze(0)
467
-
468
- def build_model(state_dict: dict, extract_last_k_th_token, viz, image_resolution: int = None) -> CLIP:
469
- vit = "visual.proj" in state_dict
470
-
471
- if vit:
472
- vision_width = state_dict["visual.conv1.weight"].shape[0]
473
- vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
474
- vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
475
- grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
476
- if image_resolution is None:
477
- image_resolution = vision_patch_size * grid_size
478
- else:
479
- counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
480
- vision_layers = tuple(counts)
481
- vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
482
- output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
483
- vision_patch_size = None
484
- assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
485
- if image_resolution is None:
486
- image_resolution = output_width * 32
487
-
488
- embed_dim = state_dict["text_projection"].shape[1]
489
- context_length = state_dict["positional_embedding"].shape[0]
490
- vocab_size = state_dict["token_embedding.weight"].shape[0]
491
- transformer_width = state_dict["ln_final.weight"].shape[0]
492
- transformer_heads = transformer_width // 64
493
- transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
494
-
495
- model = CLIP(
496
- embed_dim,
497
- image_resolution, vision_layers, vision_width, vision_patch_size,
498
- context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, extract_last_k_th_token, viz
499
- )
500
-
501
- for key in ["input_resolution", "context_length", "vocab_size"]:
502
- if key in state_dict:
503
- del state_dict[key]
504
-
505
- convert_weights(model)
506
-
507
- pretrained_pe = state_dict['visual.positional_embedding']
508
- model_pe = model.visual.positional_embedding
509
-
510
- if vit and (pretrained_pe.shape != model_pe.shape):
511
- print(f"Interpolating positional embedding from {pretrained_pe.shape} to {model_pe.shape}")
512
- state_dict['visual.positional_embedding'] = resize_pos_embed(pretrained_pe, model_pe.shape[0])
513
-
514
- model.load_state_dict(state_dict)
515
- return model.eval()
src/INViTE/clipfolder/simple_tokenizer.py DELETED
@@ -1,132 +0,0 @@
1
- import gzip
2
- import html
3
- import os
4
- from functools import lru_cache
5
-
6
- import ftfy
7
- import regex as re
8
-
9
-
10
- @lru_cache()
11
- def default_bpe():
12
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
-
14
-
15
- @lru_cache()
16
- def bytes_to_unicode():
17
- """
18
- Returns list of utf-8 byte and a corresponding list of unicode strings.
19
- The reversible bpe codes work on unicode strings.
20
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
- This is a signficant percentage of your normal, say, 32K bpe vocab.
23
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
- And avoids mapping to whitespace/control characters the bpe code barfs on.
25
- """
26
- bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27
- cs = bs[:]
28
- n = 0
29
- for b in range(2**8):
30
- if b not in bs:
31
- bs.append(b)
32
- cs.append(2**8+n)
33
- n += 1
34
- cs = [chr(n) for n in cs]
35
- return dict(zip(bs, cs))
36
-
37
-
38
- def get_pairs(word):
39
- """Return set of symbol pairs in a word.
40
- Word is represented as tuple of symbols (symbols being variable-length strings).
41
- """
42
- pairs = set()
43
- prev_char = word[0]
44
- for char in word[1:]:
45
- pairs.add((prev_char, char))
46
- prev_char = char
47
- return pairs
48
-
49
-
50
- def basic_clean(text):
51
- text = ftfy.fix_text(text)
52
- text = html.unescape(html.unescape(text))
53
- return text.strip()
54
-
55
-
56
- def whitespace_clean(text):
57
- text = re.sub(r'\s+', ' ', text)
58
- text = text.strip()
59
- return text
60
-
61
-
62
- class SimpleTokenizer(object):
63
- def __init__(self, bpe_path: str = default_bpe()):
64
- self.byte_encoder = bytes_to_unicode()
65
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
- merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
- merges = merges[1:49152-256-2+1]
68
- merges = [tuple(merge.split()) for merge in merges]
69
- vocab = list(bytes_to_unicode().values())
70
- vocab = vocab + [v+'</w>' for v in vocab]
71
- for merge in merges:
72
- vocab.append(''.join(merge))
73
- vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
- self.encoder = dict(zip(vocab, range(len(vocab))))
75
- self.decoder = {v: k for k, v in self.encoder.items()}
76
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
- self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
- self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
-
80
- def bpe(self, token):
81
- if token in self.cache:
82
- return self.cache[token]
83
- word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
- pairs = get_pairs(word)
85
-
86
- if not pairs:
87
- return token+'</w>'
88
-
89
- while True:
90
- bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
- if bigram not in self.bpe_ranks:
92
- break
93
- first, second = bigram
94
- new_word = []
95
- i = 0
96
- while i < len(word):
97
- try:
98
- j = word.index(first, i)
99
- new_word.extend(word[i:j])
100
- i = j
101
- except:
102
- new_word.extend(word[i:])
103
- break
104
-
105
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
- new_word.append(first+second)
107
- i += 2
108
- else:
109
- new_word.append(word[i])
110
- i += 1
111
- new_word = tuple(new_word)
112
- word = new_word
113
- if len(word) == 1:
114
- break
115
- else:
116
- pairs = get_pairs(word)
117
- word = ' '.join(word)
118
- self.cache[token] = word
119
- return word
120
-
121
- def encode(self, text):
122
- bpe_tokens = []
123
- text = whitespace_clean(basic_clean(text)).lower()
124
- for token in re.findall(self.pat, text):
125
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
- return bpe_tokens
128
-
129
- def decode(self, tokens):
130
- text = ''.join([self.decoder[token] for token in tokens])
131
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
- return text
src/INViTE/loader.py DELETED
@@ -1,72 +0,0 @@
1
- import torch
2
- from typing import Union
3
- from .clipfolder.clip import load as invite_clip_load, tokenize as invite_clip_tokenize
4
-
5
-
6
- def load_invite_clip(config: dict, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu"):
7
- """
8
- Load an INViTE CLIP model based on the provided configuration.
9
-
10
- This method loads an INViTE CLIP model similar to how RegionCLIP is loaded in the Patchioner class.
11
-
12
- Args:
13
- config (dict): Configuration dictionary containing the following keys:
14
- - name (str): Model name listed by `clip.available_models()`, or path to a model checkpoint
15
- - jit (bool, optional): Whether to load the optimized JIT model. Defaults to False
16
- - download_root (str, optional): Path to download model files. Defaults to '/raid/datasets/models_weights/INViTE'
17
- - extract_last_k_th_token (int, optional): Extract last k-th token. Defaults to -1
18
- - viz (bool, optional): Visualization flag. Defaults to False
19
- device (Union[str, torch.device], optional): Device to load the model on.
20
- Defaults to "cuda" if available, else "cpu"
21
-
22
- Returns:
23
- tuple: (model, preprocess_transform, tokenize_fn)
24
- - model: The loaded INViTE CLIP model
25
- - preprocess_transform: Torchvision transform for preprocessing images
26
- - tokenize_fn: Tokenization function for text processing
27
-
28
- Raises:
29
- KeyError: If required 'name' key is missing from config
30
- RuntimeError: If model loading fails
31
-
32
- Example:
33
- config = {
34
- 'name': 'ViT-B/32',
35
- 'jit': False,
36
- 'download_root': '/raid/datasets/models_weights/INViTE', # optional, this is the default
37
- 'extract_last_k_th_token': -1,
38
- 'viz': False
39
- }
40
- model, preprocess, tokenize = load_invite_clip(config, device='cuda')
41
- """
42
-
43
- # Validate required parameters
44
- if 'name' not in config:
45
- raise KeyError("'name' key is required in config dictionary")
46
-
47
- # Extract parameters with defaults
48
- name = config['name']
49
- jit = config.get('jit', False)
50
- download_root = config.get('download_root', '/raid/datasets/models_weights/INViTE')
51
- extract_last_k_th_token = config.get('extract_last_k_th_token', -1)
52
- viz = config.get('viz', False)
53
-
54
- image_resolution = config.get('resolution', None) # Default resolution if not specified
55
-
56
- # Load the INViTE CLIP model using the clip.load function
57
- try:
58
- model, preprocess_transform = invite_clip_load(
59
- name=name,
60
- device=device,
61
- jit=jit,
62
- download_root=download_root,
63
- extract_last_k_th_token=extract_last_k_th_token,
64
- viz=viz,
65
- image_resolution=image_resolution
66
- )
67
-
68
- # Return model, preprocess transform, and tokenize function
69
- return model, preprocess_transform, invite_clip_tokenize
70
-
71
- except Exception as e:
72
- raise RuntimeError(f"Failed to load INViTE CLIP model '{name}': {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/alphaclip/INSTALL.md DELETED
@@ -1,113 +0,0 @@
1
- # AlphaCLIP Standalone - Installation Guide
2
-
3
- ## Quick Installation
4
-
5
- ### Prerequisites
6
- - Python 3.7 or higher
7
- - pip package manager
8
-
9
- ### Step 1: Install Dependencies
10
-
11
- ```bash
12
- cd alphaclip-standalone
13
- pip install -r requirements.txt
14
- ```
15
-
16
- ### Step 2: Install the Package
17
-
18
- ```bash
19
- # Install in development mode (recommended for testing)
20
- pip install -e .
21
-
22
- # OR install normally
23
- pip install .
24
- ```
25
-
26
- ### Step 3: Test Installation
27
-
28
- ```bash
29
- python test_installation.py
30
- ```
31
-
32
- ### Step 4: Run Example
33
-
34
- ```bash
35
- python example.py
36
- ```
37
-
38
- ## Manual Dependency Installation
39
-
40
- If you encounter issues with the requirements.txt, install dependencies manually:
41
-
42
- ```bash
43
- # Core PyTorch (choose appropriate version for your system)
44
- pip install torch torchvision torchaudio
45
-
46
- # Text processing
47
- pip install ftfy regex tqdm
48
-
49
- # LoRA support
50
- pip install loralib
51
-
52
- # Image processing
53
- pip install Pillow
54
-
55
- # Utilities
56
- pip install numpy packaging
57
- ```
58
-
59
- ## GPU Support
60
-
61
- For CUDA support, make sure you install PyTorch with CUDA:
62
-
63
- ```bash
64
- # For CUDA 11.8
65
- pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
66
-
67
- # For CUDA 12.1
68
- pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
69
-
70
- # Check your CUDA version with: nvidia-smi
71
- ```
72
-
73
- ## Verification
74
-
75
- After installation, verify everything works:
76
-
77
- ```python
78
- from alphaclip_loader import AlphaCLIPLoader
79
-
80
- # This should work without errors
81
- loader = AlphaCLIPLoader()
82
- models = loader.available_models()
83
- print("Available models:", models)
84
- ```
85
-
86
- ## Troubleshooting
87
-
88
- ### Common Issues
89
-
90
- 1. **ImportError: No module named 'loralib'**
91
- ```bash
92
- pip install loralib
93
- ```
94
-
95
- 2. **CUDA out of memory**
96
- - Use CPU: `AlphaCLIPLoader(default_device="cpu")`
97
- - Or use a smaller model like "ViT-B/32"
98
-
99
- 3. **Model download fails**
100
- - Check internet connection
101
- - Ensure you have enough disk space (~1GB per model)
102
- - Models are cached in `~/.cache/clip/`
103
-
104
- 4. **Permission errors**
105
- - Use `--user` flag: `pip install --user -e .`
106
-
107
- ### Getting Help
108
-
109
- If you encounter issues:
110
- 1. Check that all dependencies are properly installed
111
- 2. Run the test script: `python test_installation.py`
112
- 3. Check CUDA compatibility if using GPU
113
- 4. Ensure Python version is 3.7+
src/alphaclip/LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright [Zeyi Sun] [name of copyright owner]
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
src/alphaclip/MANIFEST.in DELETED
@@ -1,7 +0,0 @@
- include README.md
- include requirements.txt
- include LICENSE
- recursive-include alpha_clip *.py
- recursive-include alpha_clip *.gz
- include example.py
- include test_installation.py
 
src/alphaclip/README.md DELETED
@@ -1,266 +0,0 @@
1
- # AlphaCLIP Standalone
2
-
3
- A standalone, easy-to-use version of AlphaCLIP that can be integrated into any project without complex dependencies or setup.
4
-
5
- ## Overview
6
-
7
- AlphaCLIP is an enhanced version of OpenAI's CLIP model that provides improved vision-language understanding capabilities. This standalone package makes it easy to use AlphaCLIP in your projects with minimal setup.
8
-
9
- ## Features
10
-
11
- - **Easy Installation**: Simple pip install with minimal dependencies
12
- - **Clean API**: Intuitive interface for loading models and processing data
13
- - **Device Flexibility**: Automatic CUDA/CPU detection with manual override options
14
- - **Model Variety**: Support for multiple AlphaCLIP model variants
15
- - **Preprocessing Included**: Built-in image preprocessing and text tokenization
16
-
17
- ## Installation
18
-
19
- ### Requirements
20
-
21
- - Python 3.7 or higher
22
- - PyTorch 1.7.1 or higher
23
- - CUDA (optional, for GPU acceleration)
24
-
25
- ### Install from source
26
-
27
- ```bash
28
- # Clone or download this standalone package
29
- cd alphaclip-standalone
30
-
31
- # Install dependencies
32
- pip install -r requirements.txt
33
-
34
- # Install the package
35
- pip install -e .
36
- ```
37
-
38
- ### Core Dependencies
39
-
40
- The package requires the following core dependencies:
41
-
42
- ```
43
- torch>=1.7.1
44
- torchvision
45
- ftfy
46
- regex
47
- tqdm
48
- loralib
49
- Pillow
50
- numpy
51
- packaging
52
- ```
53
-
54
- ## Quick Start
55
-
56
- ### Basic Usage
57
-
58
- ```python
59
- from alphaclip_loader import AlphaCLIPLoader
60
-
61
- # Initialize the loader
62
- loader = AlphaCLIPLoader()
63
-
64
- # Load a model (this will download the model if not cached)
65
- model, preprocess = loader.load_model("ViT-B/16")
66
-
67
- # Tokenize text
68
- text_tokens = loader.tokenize("A photo of a cat")
69
-
70
- # Get text embeddings
71
- text_features = loader.encode_text(model, "A photo of a cat")
72
-
73
- print(f"Text features shape: {text_features.shape}")
74
- ```
75
-
76
- ### Advanced Usage
77
-
78
- ```python
79
- import torch
80
- from PIL import Image
81
- from alphaclip_loader import AlphaCLIPLoader
82
-
83
- # Initialize with specific device
84
- loader = AlphaCLIPLoader(default_device="cuda")
85
-
86
- # Load model with custom options
87
- model, preprocess = loader.load_model(
88
- "ViT-B/16",
89
- device="cuda",
90
- lora_adapt=False,
91
- rank=16
92
- )
93
-
94
- # Process an image
95
- image = Image.open("your_image.jpg")
96
- image_tensor = preprocess(image).unsqueeze(0) # Add batch dimension
97
-
98
- # Get embeddings
99
- with torch.no_grad():
100
- image_features = loader.encode_image(model, image_tensor)
101
- text_features = loader.encode_text(model, ["A photo of a cat", "A dog playing"])
102
-
103
- # Compute similarities
104
- similarities = loader.get_similarity(text_features, image_features)
105
- print(f"Similarities: {similarities}")
106
- ```
107
-
108
- ### One-line Model Loading
109
-
110
- ```python
111
- from alphaclip_loader import load_alphaclip
112
-
113
- # Quick loading function
114
- loader, model, preprocess = load_alphaclip("ViT-B/16", device="cuda")
115
- ```
116
-
117
- ## Available Models
118
-
119
- You can check available models using:
120
-
121
- ```python
122
- from alphaclip_loader import AlphaCLIPLoader
123
-
124
- loader = AlphaCLIPLoader()
125
- models = loader.available_models()
126
- print("Available models:", models)
127
- ```
128
-
129
- Typically includes:
130
- - `ViT-B/32`
131
- - `ViT-B/16`
132
- - `ViT-L/14`
133
- - `ViT-L/14@336px`
134
- - `RN50`, `RN101`, `RN50x4`, `RN50x16`, `RN50x64`
135
-
136
- ## API Reference
137
-
138
- ### AlphaCLIPLoader Class
139
-
140
- #### Methods
141
-
142
- - **`__init__(default_device=None)`**: Initialize loader with optional default device
143
- - **`available_models()`**: Get list of available model names
144
- - **`load_model(name, **kwargs)`**: Load a model with preprocessing function
145
- - **`tokenize(texts, context_length=77, truncate=True)`**: Tokenize text input
146
- - **`encode_text(model, texts)`**: Encode text to embeddings
147
- - **`encode_image(model, images)`**: Encode images to embeddings
148
- - **`get_similarity(text_features, image_features)`**: Compute cosine similarity
149
-
150
- #### load_model Parameters
151
-
152
- - `name`: Model name or checkpoint path
153
- - `alpha_vision_ckpt_pth`: Additional vision checkpoint path (default: "None")
154
- - `device`: Device to load on (default: auto-detect)
155
- - `jit`: Use JIT compilation (default: False)
156
- - `download_root`: Model download directory (default: ~/.cache/clip)
157
- - `lora_adapt`: Use LoRA adaptation (default: False)
158
- - `rank`: LoRA rank if enabled (default: 16)
159
-
160
- ## Example Use Cases
161
-
162
- ### Image-Text Similarity
163
-
164
- ```python
165
- from alphaclip_loader import load_alphaclip
166
- from PIL import Image
167
- import torch
168
-
169
- loader, model, preprocess = load_alphaclip()
170
-
171
- # Load and preprocess image
172
- image = Image.open("cat.jpg")
173
- image_input = preprocess(image).unsqueeze(0)
174
-
175
- # Define candidate texts
176
- texts = ["a cat", "a dog", "a bird", "a car"]
177
-
178
- # Get features
179
- image_features = loader.encode_image(model, image_input)
180
- text_features = loader.encode_text(model, texts)
181
-
182
- # Calculate similarities
183
- similarities = loader.get_similarity(text_features, image_features)
184
-
185
- # Find best match
186
- best_match_idx = similarities.argmax()
187
- print(f"Best match: {texts[best_match_idx]} (score: {similarities[best_match_idx]:.3f})")
188
- ```
189
-
190
- ### Batch Processing
191
-
192
- ```python
193
- from alphaclip_loader import AlphaCLIPLoader
194
- import torch
195
-
196
- loader = AlphaCLIPLoader()
197
- model, preprocess = loader.load_model("ViT-B/16")
198
-
199
- # Process multiple texts at once
200
- texts = [
201
- "A red apple on a table",
202
- "A dog running in the park",
203
- "A beautiful sunset"
204
- ]
205
-
206
- # Batch tokenization and encoding
207
- text_features = loader.encode_text(model, texts)
208
- print(f"Batch text features shape: {text_features.shape}") # [3, 512]
209
- ```
210
-
211
- ## Performance Tips
212
-
213
- 1. **GPU Usage**: Use CUDA for better performance with larger models
214
- 2. **Batch Processing**: Process multiple texts/images together when possible
215
- 3. **Model Caching**: Models are automatically cached after first download
216
- 4. **Memory Management**: Use `torch.no_grad()` during inference to save memory
217
-
218
- ## Troubleshooting
219
-
220
- ### Common Issues
221
-
222
- 1. **CUDA Out of Memory**: Reduce batch size or use CPU
223
- 2. **Model Download Fails**: Check internet connection and disk space
224
- 3. **Import Errors**: Ensure all dependencies are installed
225
-
226
- ### Dependencies Issues
227
-
228
- If you encounter import errors, try:
229
-
230
- ```bash
231
- pip install --upgrade torch torchvision
232
- pip install ftfy regex tqdm loralib
233
- ```
234
-
235
- ## File Structure
236
-
237
- ```
238
- alphaclip-standalone/
239
- โ”œโ”€โ”€ __init__.py # Package initialization
240
- โ”œโ”€โ”€ alphaclip_loader.py # Main loader class
241
- โ”œโ”€โ”€ requirements.txt # Dependencies
242
- โ”œโ”€โ”€ setup.py # Package setup
243
- โ”œโ”€โ”€ README.md # This file
244
- โ””โ”€โ”€ alpha_clip/ # Core AlphaCLIP modules
245
- โ”œโ”€โ”€ __init__.py
246
- โ”œโ”€โ”€ alpha_clip.py # Main AlphaCLIP functions
247
- โ”œโ”€โ”€ model.py # Model architectures
248
- โ”œโ”€โ”€ simple_tokenizer.py # Text tokenization
249
- โ””โ”€โ”€ bpe_simple_vocab_16e6.txt.gz # Tokenizer vocabulary
250
- ```
251
-
252
- ## License
253
-
254
- This standalone package maintains the same license as the original AlphaCLIP project.
255
-
256
- ## Contributing
257
-
258
- This is a standalone distribution. For contributions to the core AlphaCLIP model, please refer to the main AlphaCLIP repository.
259
-
260
- ## Changelog
261
-
262
- ### Version 1.0.0
263
- - Initial standalone release
264
- - Clean API with AlphaCLIPLoader class
265
- - Comprehensive documentation and examples
266
- - Easy installation and setup
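
The API reference in this removed README lists `tokenize(texts, context_length=77, truncate=True)` but never shows it on its own; a minimal sketch, assuming the same loader API:

```python
from alphaclip_loader import AlphaCLIPLoader

loader = AlphaCLIPLoader()

# Each string is padded (or truncated) to CLIP's 77-token context length.
tokens = loader.tokenize(["a red apple on a table", "a dog running in the park"])
print(tokens.shape)  # expected: torch.Size([2, 77])
```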
 
src/alphaclip/__init__.py DELETED
@@ -1,14 +0,0 @@
- """
- AlphaCLIP Standalone Package
-
- A standalone version of AlphaCLIP that can be used independently.
- """
-
- from .alphaclip_loader import AlphaCLIPLoader, load_alphaclip
-
- # Version info
- __version__ = "1.0.0"
- __author__ = "AlphaCLIP Team"
-
- # Make main classes available at package level
- __all__ = ['AlphaCLIPLoader', 'load_alphaclip']
 
src/alphaclip/alpha_clip/__init__.py DELETED
@@ -1 +0,0 @@
- from .alpha_clip import *
 
 
src/alphaclip/alpha_clip/alpha_clip.py DELETED
@@ -1,254 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
- from typing import Any, Union, List
6
- from pkg_resources import packaging
7
-
8
- import torch
9
- from PIL import Image
10
- from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11
- from tqdm import tqdm
12
-
13
- from .model import build_model
14
- from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15
-
16
- try:
17
- from torchvision.transforms import InterpolationMode
18
- BICUBIC = InterpolationMode.BICUBIC
19
- except ImportError:
20
- BICUBIC = Image.BICUBIC
21
-
22
-
23
- if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24
- warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25
-
26
-
27
- __all__ = ["available_models", "load", "tokenize"]
28
- _tokenizer = _Tokenizer()
29
-
30
- _MODELS = {
31
- "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32
- "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33
- "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34
- "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35
- "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36
- "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37
- "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38
- "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39
- "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40
- }
41
-
42
-
43
- def _download(url: str, root: str):
44
- os.makedirs(root, exist_ok=True)
45
- filename = os.path.basename(url)
46
-
47
- expected_sha256 = url.split("/")[-2]
48
- download_target = os.path.join(root, filename)
49
-
50
- if os.path.exists(download_target) and not os.path.isfile(download_target):
51
- raise RuntimeError(f"{download_target} exists and is not a regular file")
52
-
53
- if os.path.isfile(download_target):
54
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55
- return download_target
56
- else:
57
- warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58
-
59
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60
- with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61
- while True:
62
- buffer = source.read(8192)
63
- if not buffer:
64
- break
65
-
66
- output.write(buffer)
67
- loop.update(len(buffer))
68
-
69
- if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70
- raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71
-
72
- return download_target
73
-
74
-
75
- def _convert_image_to_rgb(image):
76
- return image.convert("RGB")
77
-
78
-
79
- def _transform(n_px):
80
- return Compose([
81
- Resize(n_px, interpolation=BICUBIC),
82
- CenterCrop(n_px),
83
- _convert_image_to_rgb,
84
- ToTensor(),
85
- Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86
- ])
87
-
88
-
89
- def available_models() -> List[str]:
90
- """Returns the names of available CLIP models"""
91
- return list(_MODELS.keys())
92
-
93
-
94
- def load(name: str, alpha_vision_ckpt_pth="None", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, lora_adapt=False, rank=16):
95
- """Load a CLIP model
96
-
97
- Parameters
98
- ----------
99
- name : str
100
- A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101
-
102
- alpha_vision_ckpt_pth: str
103
- only changed when inferencing model instead of training
104
-
105
- device : Union[str, torch.device]
106
- The device to put the loaded model
107
-
108
- jit : bool
109
- Whether to load the optimized JIT model or more hackable non-JIT model (default).
110
-
111
- download_root: str
112
- path to download the model files; by default, it uses "~/.cache/clip"
113
-
114
- Returns
115
- -------
116
- model : torch.nn.Module
117
- The CLIP model
118
-
119
- preprocess : Callable[[PIL.Image], torch.Tensor]
120
- A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
121
- """
122
- if name in _MODELS:
123
- model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
124
- elif os.path.isfile(name):
125
- model_path = name
126
- else:
127
- raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
128
-
129
- with open(model_path, 'rb') as opened_file:
130
- try:
131
- # loading JIT archive
132
- model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
133
- state_dict = None
134
- except RuntimeError:
135
- # loading saved state dict
136
- if jit:
137
- warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
138
- jit = False
139
- state_dict = torch.load(opened_file, map_location="cpu")
140
-
141
- if not jit:
142
- model = build_model(state_dict or model.state_dict(), lora_adapt=lora_adapt, rank=rank).to(device)
143
- if str(device) == "cpu":
144
- model.float()
145
- # If a separate checkpoint is provided for the visual encoder (e.g., CLIP), load it
146
- if alpha_vision_ckpt_pth != "None":
147
- # Load the visual encoder weights from the given checkpoint path
148
- model.visual.load_state_dict(torch.load(alpha_vision_ckpt_pth))
149
- # Set the model to evaluation mode
150
- # Note: If LoRA is used, it may merge LoRA weights into the base model here for inference
151
- model.eval() # merge lora params if exists (for inference only)
152
- return model, _transform(model.visual.input_resolution)
153
-
154
- # patch the device names
155
- device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
156
- device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
157
-
158
- def _node_get(node: torch._C.Node, key: str):
159
- """Gets attributes of a node which is polymorphic over return type.
160
-
161
- From https://github.com/pytorch/pytorch/pull/82628
162
- """
163
- sel = node.kindOf(key)
164
- return getattr(node, sel)(key)
165
-
166
- def patch_device(module):
167
- try:
168
- graphs = [module.graph] if hasattr(module, "graph") else []
169
- except RuntimeError:
170
- graphs = []
171
-
172
- if hasattr(module, "forward1"):
173
- graphs.append(module.forward1.graph)
174
-
175
- for graph in graphs:
176
- for node in graph.findAllNodes("prim::Constant"):
177
- if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
178
- node.copyAttributes(device_node)
179
-
180
- model.apply(patch_device)
181
- patch_device(model.encode_image)
182
- patch_device(model.encode_text)
183
-
184
- # patch dtype to float32 on CPU
185
- if str(device) == "cpu":
186
- float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
187
- float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
188
- float_node = float_input.node()
189
-
190
- def patch_float(module):
191
- try:
192
- graphs = [module.graph] if hasattr(module, "graph") else []
193
- except RuntimeError:
194
- graphs = []
195
-
196
- if hasattr(module, "forward1"):
197
- graphs.append(module.forward1.graph)
198
-
199
- for graph in graphs:
200
- for node in graph.findAllNodes("aten::to"):
201
- inputs = list(node.inputs())
202
- for i in [1, 2]: # dtype can be the second or third argument to aten::to()
203
- if _node_get(inputs[i].node(), "value") == 5:
204
- inputs[i].node().copyAttributes(float_node)
205
-
206
- model.apply(patch_float)
207
- patch_float(model.encode_image)
208
- patch_float(model.encode_text)
209
-
210
- model.float()
211
- return model, _transform(model.input_resolution.item())
212
-
213
-
214
- def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]:
215
- """
216
- Returns the tokenized representation of given input string(s)
217
-
218
- Parameters
219
- ----------
220
- texts : Union[str, List[str]]
221
- An input string or a list of input strings to tokenize
222
-
223
- context_length : int
224
- The context length to use; all CLIP models use 77 as the context length
225
-
226
- truncate: bool
227
- Whether to truncate the text in case its encoding is longer than the context length
228
-
229
- Returns
230
- -------
231
- A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
232
- We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
233
- """
234
- if isinstance(texts, str):
235
- texts = [texts]
236
-
237
- sot_token = _tokenizer.encoder["<|startoftext|>"]
238
- eot_token = _tokenizer.encoder["<|endoftext|>"]
239
- all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
240
- if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
241
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
242
- else:
243
- result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
244
-
245
- for i, tokens in enumerate(all_tokens):
246
- if len(tokens) > context_length:
247
- if truncate:
248
- tokens = tokens[:context_length]
249
- tokens[-1] = eot_token
250
- else:
251
- raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
252
- result[i, :len(tokens)] = torch.tensor(tokens)
253
-
254
- return result
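
The deleted module above exposes `available_models`, `load`, and `tokenize` (re-exported by the subpackage `__init__.py` shown earlier); a minimal sketch of using them directly, assuming the OpenAI checkpoint URLs are reachable so the weights can be downloaded and cached:

```python
import torch
from alpha_clip import available_models, load, tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
print(available_models())  # e.g. ["RN50", ..., "ViT-L/14@336px"]

# Downloads and caches the base CLIP weights on first use (~/.cache/clip by default).
model, preprocess = load("ViT-B/16", device=device)
with torch.no_grad():
    text_features = model.encode_text(tokenize(["a photo of a cat"]).to(device))
print(text_features.shape)  # [1, 512] for ViT-B/16
```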
 
src/alphaclip/alpha_clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
- size 1356917
 
src/alphaclip/alpha_clip/model.py DELETED
@@ -1,609 +0,0 @@
1
- from collections import OrderedDict
2
- from typing import Tuple, Union
3
-
4
- import numpy as np
5
- import torch
6
- import torch.nn.functional as F
7
- from torch import nn
8
- import loralib as lora
9
- import math
10
- import collections
11
-
12
- class Bottleneck(nn.Module):
13
- expansion = 4
14
-
15
- def __init__(self, inplanes, planes, stride=1):
16
- super().__init__()
17
-
18
- # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19
- self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20
- self.bn1 = nn.BatchNorm2d(planes)
21
- self.relu1 = nn.ReLU(inplace=True)
22
-
23
- self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24
- self.bn2 = nn.BatchNorm2d(planes)
25
- self.relu2 = nn.ReLU(inplace=True)
26
-
27
- self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28
-
29
- self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30
- self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31
- self.relu3 = nn.ReLU(inplace=True)
32
-
33
- self.downsample = None
34
- self.stride = stride
35
-
36
- if stride > 1 or inplanes != planes * Bottleneck.expansion:
37
- # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38
- self.downsample = nn.Sequential(OrderedDict([
39
- ("-1", nn.AvgPool2d(stride)),
40
- ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
41
- ("1", nn.BatchNorm2d(planes * self.expansion))
42
- ]))
43
-
44
- def forward(self, x: torch.Tensor):
45
- identity = x
46
-
47
- out = self.relu1(self.bn1(self.conv1(x)))
48
- out = self.relu2(self.bn2(self.conv2(out)))
49
- out = self.avgpool(out)
50
- out = self.bn3(self.conv3(out))
51
-
52
- if self.downsample is not None:
53
- identity = self.downsample(x)
54
-
55
- out += identity
56
- out = self.relu3(out)
57
- return out
58
-
59
-
60
- class AttentionPool2d(nn.Module):
61
- def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
62
- super().__init__()
63
- self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
64
- self.k_proj = nn.Linear(embed_dim, embed_dim)
65
- self.q_proj = nn.Linear(embed_dim, embed_dim)
66
- self.v_proj = nn.Linear(embed_dim, embed_dim)
67
- self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
68
- self.num_heads = num_heads
69
-
70
- def forward(self, x):
71
- x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
72
- x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
73
- x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
74
- x, _ = F.multi_head_attention_forward(
75
- query=x[:1], key=x, value=x,
76
- embed_dim_to_check=x.shape[-1],
77
- num_heads=self.num_heads,
78
- q_proj_weight=self.q_proj.weight,
79
- k_proj_weight=self.k_proj.weight,
80
- v_proj_weight=self.v_proj.weight,
81
- in_proj_weight=None,
82
- in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
83
- bias_k=None,
84
- bias_v=None,
85
- add_zero_attn=False,
86
- dropout_p=0,
87
- out_proj_weight=self.c_proj.weight,
88
- out_proj_bias=self.c_proj.bias,
89
- use_separate_proj_weight=True,
90
- training=self.training,
91
- need_weights=False
92
- )
93
- return x.squeeze(0)
94
-
95
-
96
- class ModifiedResNet(nn.Module):
97
- """
98
- A ResNet class that is similar to torchvision's but contains the following changes:
99
- - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
100
- - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
101
- - The final pooling layer is a QKV attention instead of an average pool
102
- """
103
-
104
- def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
105
- super().__init__()
106
- self.output_dim = output_dim
107
- self.input_resolution = input_resolution
108
-
109
- # the 3-layer stem
110
- self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
111
- self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False)
112
- self.bn1 = nn.BatchNorm2d(width // 2)
113
- self.relu1 = nn.ReLU(inplace=True)
114
- self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
115
- self.bn2 = nn.BatchNorm2d(width // 2)
116
- self.relu2 = nn.ReLU(inplace=True)
117
- self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
118
- self.bn3 = nn.BatchNorm2d(width)
119
- self.relu3 = nn.ReLU(inplace=True)
120
- self.avgpool = nn.AvgPool2d(2)
121
-
122
- # residual layers
123
- self._inplanes = width # this is a *mutable* variable used during construction
124
- self.layer1 = self._make_layer(width, layers[0])
125
- self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
126
- self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
127
- self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
128
-
129
- embed_dim = width * 32 # the ResNet feature dimension
130
- self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
131
-
132
- def _make_layer(self, planes, blocks, stride=1):
133
- layers = [Bottleneck(self._inplanes, planes, stride)]
134
-
135
- self._inplanes = planes * Bottleneck.expansion
136
- for _ in range(1, blocks):
137
- layers.append(Bottleneck(self._inplanes, planes))
138
-
139
- return nn.Sequential(*layers)
140
-
141
- def forward(self, x, alpha=None):
142
- def stem(x):
143
- x = self.relu1(self.bn1(self.conv1(x) + self.conv1_alpha(alpha)))
144
- x = self.relu2(self.bn2(self.conv2(x)))
145
- x = self.relu3(self.bn3(self.conv3(x)))
146
- x = self.avgpool(x)
147
- return x
148
-
149
- x = x.type(self.conv1.weight.dtype)
150
- x = stem(x)
151
- x = self.layer1(x)
152
- x = self.layer2(x)
153
- x = self.layer3(x)
154
- x = self.layer4(x)
155
- x = self.attnpool(x)
156
-
157
- return x
158
-
159
-
160
- class LayerNorm(nn.LayerNorm):
161
- """Subclass torch's LayerNorm to handle fp16."""
162
-
163
- def forward(self, x: torch.Tensor):
164
- orig_type = x.dtype
165
- ret = super().forward(x.type(torch.float32))
166
- return ret.type(orig_type)
167
-
168
-
169
- class QuickGELU(nn.Module):
170
- def forward(self, x: torch.Tensor):
171
- return x * torch.sigmoid(1.702 * x)
172
-
173
- class Attention(nn.Module):
174
- def __init__(
175
- self,
176
- dim,
177
- num_heads=8,
178
- qkv_bias=True,
179
- scaled_cosine=False,
180
- scale_heads=False,
181
- logit_scale_max=math.log(1. / 0.01),
182
- attn_drop=0.,
183
- proj_drop=0.,
184
- lora_adapt=False,
185
- rank=16
186
- ):
187
- super().__init__()
188
- self.scaled_cosine = scaled_cosine
189
- self.scale_heads = scale_heads
190
- assert dim % num_heads == 0, 'dim should be divisible by num_heads'
191
- self.num_heads = num_heads
192
- self.head_dim = dim // num_heads
193
- self.scale = self.head_dim ** -0.5
194
- self.logit_scale_max = logit_scale_max
195
-
196
- # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
197
- if lora_adapt:
198
- print("!!!!!!!!!!using lora for qkv projection!!!!!!!!!!")
199
- self.in_proj = lora.MergedLinear(dim, 3*dim, r=rank, enable_lora=[True, False, True])
200
- else:
201
- self.in_proj = nn.Linear(dim, dim * 3)
202
- # self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
203
- # if qkv_bias:
204
- # self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
205
- # else:
206
- # self.in_proj_bias = None
207
-
208
- if self.scaled_cosine:
209
- self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
210
- else:
211
- self.logit_scale = None
212
- self.attn_drop = nn.Dropout(attn_drop)
213
- if self.scale_heads:
214
- self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
215
- else:
216
- self.head_scale = None
217
- self.out_proj = nn.Linear(dim, dim) if not lora_adapt else lora.Linear(dim, dim, r=rank)
218
- self.out_drop = nn.Dropout(proj_drop)
219
-
220
- def forward(self, x, attn_mask = None):
221
- L, N, C = x.shape
222
- q, k, v = self.in_proj(x).chunk(3, dim=-1)
223
- q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
224
- k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
225
- v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
226
-
227
- if self.logit_scale is not None:
228
- attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
229
- logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
230
- attn = attn.view(N, self.num_heads, L, L) * logit_scale
231
- attn = attn.view(-1, L, L)
232
- else:
233
- q = q * self.scale
234
- attn = torch.bmm(q, k.transpose(-2, -1))
235
-
236
- if attn_mask is not None:
237
- if attn_mask.dtype == torch.bool:
238
- new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
239
- new_attn_mask.masked_fill_(attn_mask, float("-inf"))
240
- attn_mask = new_attn_mask
241
- attn += attn_mask
242
-
243
- attn = attn.softmax(dim=-1)
244
- attn = self.attn_drop(attn)
245
-
246
- x = torch.bmm(attn, v)
247
- if self.head_scale is not None:
248
- x = x.view(N, self.num_heads, L, C) * self.head_scale
249
- x = x.view(-1, L, C)
250
- x = x.transpose(0, 1).reshape(L, N, C)
251
- x = self.out_proj(x)
252
- x = self.out_drop(x)
253
- return x, attn
254
-
255
-
256
- class CustomResidualAttentionBlock(nn.Module):
257
- def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
258
- super().__init__()
259
-
260
- self.attn = Attention(d_model, n_head, lora_adapt=lora_adapt, rank=rank)
261
- self.ln_1 = LayerNorm(d_model)
262
- self.mlp = nn.Sequential(OrderedDict([
263
- ("c_fc", nn.Linear(d_model, d_model * 4) if not lora_adapt else lora.Linear(d_model, d_model*4, r=rank)),
264
- ("gelu", QuickGELU()),
265
- ("c_proj", nn.Linear(d_model * 4, d_model) if not lora_adapt else lora.Linear(d_model*4, d_model, r=rank))
266
- ]))
267
- self.ln_2 = LayerNorm(d_model)
268
- self.attn_mask = attn_mask
269
-
270
- def attention(self, x: torch.Tensor):
271
- self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
272
- return self.attn(x, attn_mask=self.attn_mask)
273
-
274
- def forward(self, x: torch.Tensor, return_attn=False):
275
- attn_out, attn = self.attention(self.ln_1(x))
276
- x = x + attn_out
277
- x = x + self.mlp(self.ln_2(x))
278
- if return_attn:
279
- return x, attn
280
- else:
281
- return x
282
-
283
- class ResidualAttentionBlock(nn.Module):
284
- def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
285
- super().__init__()
286
-
287
- self.attn = nn.MultiheadAttention(d_model, n_head)
288
- self.ln_1 = LayerNorm(d_model)
289
- self.mlp = nn.Sequential(OrderedDict([
290
- ("c_fc", nn.Linear(d_model, d_model * 4)),
291
- ("gelu", QuickGELU()),
292
- ("c_proj", nn.Linear(d_model * 4, d_model))
293
- ]))
294
- self.ln_2 = LayerNorm(d_model)
295
- self.attn_mask = attn_mask
296
-
297
- def attention(self, x: torch.Tensor):
298
- self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
299
- return self.attn(x, x, x, attn_mask=self.attn_mask)[0]
300
-
301
- def forward(self, x: torch.Tensor):
302
- x = x + self.attention(self.ln_1(x))
303
- x = x + self.mlp(self.ln_2(x))
304
- return x
305
-
306
- class Transformer(nn.Module):
307
- def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
308
- super().__init__()
309
- self.width = width
310
- self.layers = layers
311
- self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
312
-
313
- def forward(self, x: torch.Tensor):
314
- return self.resblocks(x)
315
-
316
- class CustomTransformer(nn.Module):
317
- def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
318
- super().__init__()
319
- self.width = width
320
- self.layers = layers
321
- self.resblocks = nn.Sequential(*[CustomResidualAttentionBlock(width, heads, attn_mask, lora_adapt=lora_adapt, rank=rank) for _ in range(layers)])
322
-
323
- def forward(self, x: torch.Tensor, return_attn=False):
324
- if return_attn:
325
- for i, block in enumerate(self.resblocks):
326
- if i == len(self.resblocks) - 1:
327
- return block(x, return_attn=True)
328
- else:
329
- x = block(x)
330
- assert False
331
- return self.resblocks(x)
332
-
333
- class VisionTransformer(nn.Module):
334
- def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, lora_adapt=False, rank=16):
335
- super().__init__()
336
- self.input_resolution = input_resolution
337
- self.output_dim = output_dim
338
- self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
339
- self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
340
-
341
- scale = width ** -0.5
342
- self.class_embedding = nn.Parameter(scale * torch.randn(width))
343
- self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
344
- self.ln_pre = LayerNorm(width)
345
-
346
- self.transformer = CustomTransformer(width, layers, heads, lora_adapt=lora_adapt, rank=rank)
347
-
348
- self.ln_post = LayerNorm(width)
349
- self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
350
-
351
- def forward(self, x: torch.Tensor, alpha=None, return_attn=False, return_patches=False):
352
- # if x dtype is different from conv1, convert it
353
- if x.dtype != self.conv1.weight.dtype:
354
- x = x.type(self.conv1.weight.dtype)
355
-
356
- if alpha.dtype != self.conv1_alpha.weight.dtype:
357
- alpha = alpha.type(self.conv1_alpha.weight.dtype)
358
-
359
- x = self.conv1(x) # shape = [*, width, grid, grid]
360
- # ASSUME alpha is always not None!
361
- x = x + self.conv1_alpha(alpha)
362
- x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
363
- x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
364
- x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
365
- x = x + self.positional_embedding.to(x.dtype)
366
- x = self.ln_pre(x)
367
-
368
- x = x.permute(1, 0, 2) # NLD -> LND
369
- if return_attn:
370
- x, attn_last = self.transformer(x, return_attn=True)
371
- else:
372
- x = self.transformer(x, return_attn=False)
373
- x = x.permute(1, 0, 2) # LND -> NLD
374
-
375
- if not return_patches:
376
- x = self.ln_post(x[:, 0, :])
377
- else:
378
- x = self.ln_post(x)
379
-
380
- if self.proj is not None:
381
- x = x @ self.proj
382
- if return_attn:
383
- return x, attn_last
384
- else:
385
- return x
386
-
387
-
388
- class CLIP(nn.Module):
389
- def __init__(self,
390
- embed_dim: int,
391
- # vision
392
- image_resolution: int,
393
- vision_layers: Union[Tuple[int, int, int, int], int],
394
- vision_width: int,
395
- vision_patch_size: int,
396
- # text
397
- context_length: int,
398
- vocab_size: int,
399
- transformer_width: int,
400
- transformer_heads: int,
401
- transformer_layers: int,
402
- lora_adapt = False,
403
- rank = 16,
404
- ):
405
- super().__init__()
406
-
407
- self.context_length = context_length
408
-
409
- if isinstance(vision_layers, (tuple, list)):
410
- vision_heads = vision_width * 32 // 64
411
- self.visual = ModifiedResNet(
412
- layers=vision_layers,
413
- output_dim=embed_dim,
414
- heads=vision_heads,
415
- input_resolution=image_resolution,
416
- width=vision_width
417
- )
418
- else:
419
- vision_heads = vision_width // 64
420
- self.visual = VisionTransformer(
421
- input_resolution=image_resolution,
422
- patch_size=vision_patch_size,
423
- width=vision_width,
424
- layers=vision_layers,
425
- heads=vision_heads,
426
- output_dim=embed_dim,
427
- lora_adapt=lora_adapt,
428
- rank=rank
429
- )
430
-
431
- self.transformer = Transformer(
432
- width=transformer_width,
433
- layers=transformer_layers,
434
- heads=transformer_heads,
435
- attn_mask=self.build_attention_mask()
436
- )
437
-
438
- self.vocab_size = vocab_size
439
- self.token_embedding = nn.Embedding(vocab_size, transformer_width)
440
- self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
441
- self.ln_final = LayerNorm(transformer_width)
442
-
443
- self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
444
- self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
445
-
446
- self.initialize_parameters()
447
-
448
- def initialize_parameters(self):
449
- nn.init.normal_(self.token_embedding.weight, std=0.02)
450
- nn.init.normal_(self.positional_embedding, std=0.01)
451
-
452
- if isinstance(self.visual, ModifiedResNet):
453
- if self.visual.attnpool is not None:
454
- std = self.visual.attnpool.c_proj.in_features ** -0.5
455
- nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
456
- nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
457
- nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
458
- nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
459
-
460
- for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
461
- for name, param in resnet_block.named_parameters():
462
- if name.endswith("bn3.weight"):
463
- nn.init.zeros_(param)
464
-
465
- proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
466
- attn_std = self.transformer.width ** -0.5
467
- fc_std = (2 * self.transformer.width) ** -0.5
468
- for block in self.transformer.resblocks:
469
- nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
470
- nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
471
- nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
472
- nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
473
-
474
- if self.text_projection is not None:
475
- nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
476
-
477
- def build_attention_mask(self):
478
- # lazily create causal attention mask, with full attention between the vision tokens
479
- # pytorch uses additive attention mask; fill with -inf
480
- mask = torch.empty(self.context_length, self.context_length)
481
- mask.fill_(float("-inf"))
482
- mask.triu_(1) # zero out the lower diagonal
483
- return mask
484
-
485
- @property
486
- def dtype(self):
487
- if not hasattr(self.visual, "conv1"):
488
- return self.visual.module.conv1.weight.dtype
489
- return self.visual.conv1.weight.dtype
490
-
491
- def encode_image(self, image, alpha):
492
- assert alpha is not None
493
- return self.visual(image.type(self.dtype), alpha.type(self.dtype))
494
-
495
- def encode_text(self, text):
496
- x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
497
-
498
- x = x + self.positional_embedding.type(self.dtype)
499
- x = x.permute(1, 0, 2) # NLD -> LND
500
- x = self.transformer(x)
501
- x = x.permute(1, 0, 2) # LND -> NLD
502
- x = self.ln_final(x).type(self.dtype)
503
-
504
- # x.shape = [batch_size, n_ctx, transformer.width]
505
- # take features from the eot embedding (eot_token is the highest number in each sequence)
506
- x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
507
-
508
- return x
509
-
510
- def forward(self, image, text, alpha):
511
-
512
- image_features = self.encode_image(image, alpha)
513
- text_features = self.encode_text(text)
514
-
515
- # normalized features
516
- image_features = image_features / image_features.norm(dim=1, keepdim=True)
517
- text_features = text_features / text_features.norm(dim=1, keepdim=True)
518
-
519
- # cosine similarity as logits
520
- logit_scale = self.logit_scale.exp()
521
- logits_per_image = logit_scale * image_features @ text_features.t()
522
- logits_per_text = logits_per_image.t()
523
-
524
- # shape = [global_batch_size, global_batch_size]
525
- return logits_per_image, logits_per_text
526
-
527
-
528
- def convert_weights(model: nn.Module):
529
- """Convert applicable model parameters to fp16"""
530
-
531
- def _convert_weights_to_fp16(l):
532
- if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
533
- l.weight.data = l.weight.data.half()
534
- if l.bias is not None:
535
- l.bias.data = l.bias.data.half()
536
-
537
- if isinstance(l, nn.MultiheadAttention):
538
- for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
539
- tensor = getattr(l, attr)
540
- if tensor is not None:
541
- tensor.data = tensor.data.half()
542
-
543
- for name in ["text_projection", "proj"]:
544
- if hasattr(l, name):
545
- attr = getattr(l, name)
546
- if attr is not None:
547
- attr.data = attr.data.half()
548
-
549
- model.apply(_convert_weights_to_fp16)
550
-
551
-
552
- def build_model(state_dict: dict, lora_adapt=False, rank=16):
553
- vit = "visual.proj" in state_dict
554
-
555
- if vit:
556
- vision_width = state_dict["visual.conv1.weight"].shape[0]
557
- vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
558
- vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
559
- grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
560
- image_resolution = vision_patch_size * grid_size
561
- else:
562
- counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
563
- vision_layers = tuple(counts)
564
- vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
565
- output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
566
- vision_patch_size = None
567
- assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
568
- image_resolution = output_width * 32
569
-
570
- embed_dim = state_dict["text_projection"].shape[1]
571
- context_length = state_dict["positional_embedding"].shape[0]
572
- vocab_size = state_dict["token_embedding.weight"].shape[0]
573
- transformer_width = state_dict["ln_final.weight"].shape[0]
574
- transformer_heads = transformer_width // 64
575
- transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
576
-
577
- # always load lora version
578
- model = CLIP(
579
- embed_dim,
580
- image_resolution, vision_layers, vision_width, vision_patch_size,
581
- context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
582
- lora_adapt=lora_adapt, rank=rank,
583
- )
584
-
585
- for key in ["input_resolution", "context_length", "vocab_size"]:
586
- if key in state_dict:
587
- del state_dict[key]
588
- # para_wb to linear
589
- new_state_dict = collections.OrderedDict()
590
- for k, v in state_dict.items():
591
- if 'visual' in k:
592
- if 'in_proj_weight' in k:
593
- new_state_dict[k.replace('in_proj_weight', 'in_proj.weight')] = v
594
- elif 'in_proj_bias' in k:
595
- new_state_dict[k.replace('in_proj_bias', 'in_proj.bias')] = v
596
- else:
597
- new_state_dict[k] = v
598
- else:
599
- new_state_dict[k] = v
600
-
601
- state_dict = new_state_dict
602
- # add rgba_conv_weight
603
- if 'visual.conv1_alpha.weight' not in state_dict.keys(): # zero initialization on alpha channel
604
- rgb_weight = state_dict['visual.conv1.weight'].clone().detach()
605
- rgba_weigth = torch.zeros_like(rgb_weight)[:, 0:1, :, :]
606
- state_dict['visual.conv1_alpha.weight'] = rgba_weigth
607
- convert_weights(model)
608
- model.load_state_dict(state_dict, strict=False)
609
- return model.eval()
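
Unlike stock CLIP, `encode_image` here requires an `alpha` mask (`assert alpha is not None`): the mask is passed through `conv1_alpha`, a one-channel convolution with the same patch size as `conv1`, so it must match the image spatially. A rough sketch of the expected tensor shapes only; the mask normalization actually used when training Alpha-CLIP checkpoints is not defined in this file, and the all-ones mask below is purely illustrative:

```python
import torch
from alpha_clip import load

model, preprocess = load("ViT-B/16", device="cpu")

image = torch.randn(1, 3, 224, 224)  # stands in for a preprocessed image batch
alpha = torch.ones(1, 1, 224, 224)   # one-channel mask, same spatial size as the image

with torch.no_grad():
    feats = model.encode_image(image, alpha)
print(feats.shape)  # torch.Size([1, 512]) for ViT-B/16
```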
 
src/alphaclip/alpha_clip/simple_tokenizer.py DELETED
@@ -1,132 +0,0 @@
1
- import gzip
2
- import html
3
- import os
4
- from functools import lru_cache
5
-
6
- import ftfy
7
- import regex as re
8
-
9
-
10
- @lru_cache()
11
- def default_bpe():
12
- return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13
-
14
-
15
- @lru_cache()
16
- def bytes_to_unicode():
17
- """
18
- Returns list of utf-8 byte and a corresponding list of unicode strings.
19
- The reversible bpe codes work on unicode strings.
20
- This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21
- When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22
- This is a signficant percentage of your normal, say, 32K bpe vocab.
23
- To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24
- And avoids mapping to whitespace/control characters the bpe code barfs on.
25
- """
26
- bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1))
27
- cs = bs[:]
28
- n = 0
29
- for b in range(2**8):
30
- if b not in bs:
31
- bs.append(b)
32
- cs.append(2**8+n)
33
- n += 1
34
- cs = [chr(n) for n in cs]
35
- return dict(zip(bs, cs))
36
-
37
-
38
- def get_pairs(word):
39
- """Return set of symbol pairs in a word.
40
- Word is represented as tuple of symbols (symbols being variable-length strings).
41
- """
42
- pairs = set()
43
- prev_char = word[0]
44
- for char in word[1:]:
45
- pairs.add((prev_char, char))
46
- prev_char = char
47
- return pairs
48
-
49
-
50
- def basic_clean(text):
51
- text = ftfy.fix_text(text)
52
- text = html.unescape(html.unescape(text))
53
- return text.strip()
54
-
55
-
56
- def whitespace_clean(text):
57
- text = re.sub(r'\s+', ' ', text)
58
- text = text.strip()
59
- return text
60
-
61
-
62
- class SimpleTokenizer(object):
63
- def __init__(self, bpe_path: str = default_bpe()):
64
- self.byte_encoder = bytes_to_unicode()
65
- self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66
- merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67
- merges = merges[1:49152-256-2+1]
68
- merges = [tuple(merge.split()) for merge in merges]
69
- vocab = list(bytes_to_unicode().values())
70
- vocab = vocab + [v+'</w>' for v in vocab]
71
- for merge in merges:
72
- vocab.append(''.join(merge))
73
- vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74
- self.encoder = dict(zip(vocab, range(len(vocab))))
75
- self.decoder = {v: k for k, v in self.encoder.items()}
76
- self.bpe_ranks = dict(zip(merges, range(len(merges))))
77
- self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78
- self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79
-
80
- def bpe(self, token):
81
- if token in self.cache:
82
- return self.cache[token]
83
- word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84
- pairs = get_pairs(word)
85
-
86
- if not pairs:
87
- return token+'</w>'
88
-
89
- while True:
90
- bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91
- if bigram not in self.bpe_ranks:
92
- break
93
- first, second = bigram
94
- new_word = []
95
- i = 0
96
- while i < len(word):
97
- try:
98
- j = word.index(first, i)
99
- new_word.extend(word[i:j])
100
- i = j
101
- except:
102
- new_word.extend(word[i:])
103
- break
104
-
105
- if word[i] == first and i < len(word)-1 and word[i+1] == second:
106
- new_word.append(first+second)
107
- i += 2
108
- else:
109
- new_word.append(word[i])
110
- i += 1
111
- new_word = tuple(new_word)
112
- word = new_word
113
- if len(word) == 1:
114
- break
115
- else:
116
- pairs = get_pairs(word)
117
- word = ' '.join(word)
118
- self.cache[token] = word
119
- return word
120
-
121
- def encode(self, text):
122
- bpe_tokens = []
123
- text = whitespace_clean(basic_clean(text)).lower()
124
- for token in re.findall(self.pat, text):
125
- token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126
- bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127
- return bpe_tokens
128
-
129
- def decode(self, tokens):
130
- text = ''.join([self.decoder[token] for token in tokens])
131
- text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132
- return text
 
src/alphaclip/alpha_mask_utils.py DELETED
@@ -1,111 +0,0 @@
- """
- Utility functions for converting bboxes and traces to alpha masks for AlphaClip.
- """
-
- import torch
- import math
-
-
- def bbox_to_alpha_mask(bbox, grid_size, patch_size, crop_dim):
-     """
-     Convert a single bounding box to an alpha mask for AlphaClip.
-
-     Args:
-         bbox: [x_min, y_min, w, h] format in original coordinates
-         grid_size: Number of patches per side (e.g., 37 for 518/14)
-         patch_size: Size of each patch in pixels
-         crop_dim: Size of the cropped image
-
-     Returns:
-         alpha_mask: Binary mask of shape (grid_size, grid_size)
-     """
-     alpha_mask = torch.zeros((grid_size, grid_size))
-
-     # Convert bbox to patch coordinates
-     x_min, y_min, w, h = bbox
-     x_max = x_min + w
-     y_max = y_min + h
-
-     # Scale to patch grid coordinates
-     x1_patch = int(x_min // patch_size)
-     y1_patch = int(y_min // patch_size)
-     x2_patch = int(x_max // patch_size)
-     y2_patch = int(y_max // patch_size)
-
-     # Clamp to grid bounds
-     x1_patch = max(0, min(x1_patch, grid_size - 1))
-     y1_patch = max(0, min(y1_patch, grid_size - 1))
-     x2_patch = max(0, min(x2_patch, grid_size))  # Allow up to grid_size for exclusive end
-     y2_patch = max(0, min(y2_patch, grid_size))
-
-     # Set the region to 1 (using slice notation for proper indexing)
-     if x2_patch > x1_patch and y2_patch > y1_patch:
-         alpha_mask[y1_patch:y2_patch, x1_patch:x2_patch] = 1.0
-
-     return alpha_mask
-
-
- def bboxes_to_alpha_mask(bboxes, grid_size, patch_size, crop_dim):
-     """
-     Convert multiple bboxes to a single OR-ed alpha mask.
-
-     Args:
-         bboxes: Tensor of bboxes in [x_min, y_min, w, h] format, shape [n_boxes, 4]
-         grid_size: Number of patches per side
-         patch_size: Size of each patch in pixels
-         crop_dim: Size of the cropped image
-
-     Returns:
-         alpha_mask: Binary mask of shape (grid_size, grid_size)
-     """
-     alpha_mask = torch.zeros((grid_size, grid_size))
-
-     for bbox in bboxes:
-         # Skip dummy boxes (negative values)
-         if bbox.sum().item() < 0:
-             continue
-
-         bbox_mask = bbox_to_alpha_mask(bbox, grid_size, patch_size, crop_dim)
-         alpha_mask = torch.logical_or(alpha_mask, bbox_mask).float()
-
-     return alpha_mask
-
-
- def trace_to_alpha_mask(trace, grid_size):
-     """
-     Convert a trace to an alpha mask using the existing map_traces_to_grid function.
-
-     Args:
-         trace: List of trace points with 'x' and 'y' coordinates (normalized 0-1)
-         grid_size: Number of patches per side
-
-     Returns:
-         alpha_mask: Binary mask of shape (grid_size, grid_size)
-     """
-     from src.bbox_utils import map_traces_to_grid
-
-     alpha_mask = map_traces_to_grid(trace, grid_size)
-     # Convert to binary (any value > 0 becomes 1)
-     alpha_mask = (alpha_mask > 0).float()
-
-     return alpha_mask
-
-
- def traces_to_alpha_mask(traces, grid_size):
-     """
-     Convert multiple traces to a single OR-ed alpha mask.
-
-     Args:
-         traces: List of traces
-         grid_size: Number of patches per side
-
-     Returns:
-         alpha_mask: Binary mask of shape (grid_size, grid_size)
-     """
-     alpha_mask = torch.zeros((grid_size, grid_size))
-
-     for trace in traces:
-         trace_mask = trace_to_alpha_mask(trace, grid_size)
-         alpha_mask = torch.logical_or(alpha_mask, trace_mask).float()
-
-     return alpha_mask
 
src/alphaclip/alphaclip_loader.py DELETED
@@ -1,233 +0,0 @@
1
- """
2
- AlphaCLIP Standalone Loader
3
-
4
- This module provides a simple interface to load and use AlphaCLIP models.
5
- It exposes the core functionality of AlphaCLIP in a standalone package.
6
-
7
- Usage:
8
- from alphaclip_loader import AlphaCLIPLoader
9
-
10
- # Initialize the loader
11
- loader = AlphaCLIPLoader()
12
-
13
- # Load a model
14
- model, preprocess = loader.load_model("ViT-B/16")
15
-
16
- # Tokenize text
17
- tokens = loader.tokenize("A photo of a cat")
18
-
19
- # Get available models
20
- models = loader.available_models()
21
- """
22
-
23
- import os
24
- import sys
25
- from typing import Union, List, Tuple, Optional
26
-
27
- # Check for critical dependencies
28
- missing_deps = []
29
- try:
30
- import torch
31
- except ImportError:
32
- missing_deps.append("torch")
33
-
34
- try:
35
- from PIL import Image
36
- except ImportError:
37
- missing_deps.append("Pillow")
38
-
39
- if missing_deps:
40
- raise ImportError(f"Missing required dependencies: {', '.join(missing_deps)}. "
41
- f"Please install them with: pip install {' '.join(missing_deps)}")
42
-
43
- # Add the alpha_clip directory to the path
44
- _current_dir = os.path.dirname(os.path.abspath(__file__))
45
- _alpha_clip_dir = os.path.join(_current_dir, 'alpha_clip')
46
- if _alpha_clip_dir not in sys.path:
47
- sys.path.insert(0, _alpha_clip_dir)
48
-
49
- # Import the alpha_clip modules
50
- try:
51
- #import .alpha_clip
52
- from .alpha_clip import available_models, load, tokenize
53
- except ImportError as e:
54
- raise ImportError(f"Failed to import alpha_clip modules: {e}. Please ensure all dependencies are installed.")
55
-
56
-
57
- class AlphaCLIPLoader:
58
- """
59
- A convenience wrapper for AlphaCLIP functionality.
60
-
61
- This class provides a clean interface to load AlphaCLIP models and
62
- perform text tokenization.
63
- """
64
-
65
- def __init__(self, default_device: Optional[str] = None):
66
- """
67
- Initialize the AlphaCLIP loader.
68
-
69
- Args:
70
- default_device: Default device to load models on. If None, will use
71
- CUDA if available, otherwise CPU.
72
- """
73
- if default_device is None:
74
- self.default_device = "cuda" if torch.cuda.is_available() else "cpu"
75
- else:
76
- self.default_device = default_device
77
-
78
- def available_models(self) -> List[str]:
79
- """
80
- Get list of available AlphaCLIP model names.
81
-
82
- Returns:
83
- List of model names that can be used with load_model()
84
- """
85
- return available_models()
86
-
87
- def load_model(
88
- self,
89
- name: str,
90
- alpha_vision_ckpt_pth: str = "None",
91
- device: Optional[Union[str, torch.device]] = None,
92
- jit: bool = False,
93
- download_root: Optional[str] = None,
94
- lora_adapt: bool = False,
95
- rank: int = 16
96
- ) -> Tuple[torch.nn.Module, callable]:
97
- """
98
- Load an AlphaCLIP model.
99
-
100
- Args:
101
- name: Model name (e.g., "ViT-B/16") or path to checkpoint
102
- alpha_vision_ckpt_pth: Path to additional vision checkpoint
103
- device: Device to load model on (defaults to self.default_device)
104
- jit: Whether to load JIT optimized model
105
- download_root: Directory to download models to
106
- lora_adapt: Whether to use LoRA adaptation
107
- rank: LoRA rank if lora_adapt is True
108
-
109
- Returns:
110
- Tuple of (model, preprocess_function)
111
- """
112
- if device is None:
113
- device = self.default_device
114
-
115
- return load(
116
- name=name,
117
- alpha_vision_ckpt_pth=alpha_vision_ckpt_pth,
118
- device=device,
119
- jit=jit,
120
- download_root=download_root,
121
- lora_adapt=lora_adapt,
122
- rank=rank
123
- )
124
-
125
- def tokenize(
126
- self,
127
- texts: Union[str, List[str]],
128
- context_length: int = 77,
129
- truncate: bool = True
130
- ) -> torch.Tensor:
131
- """
132
- Tokenize text for use with AlphaCLIP models.
133
-
134
- Args:
135
- texts: String or list of strings to tokenize
136
- context_length: Maximum token length (default 77)
137
- truncate: Whether to truncate long texts
138
-
139
- Returns:
140
- Tensor of tokenized text
141
- """
142
- return tokenize(texts, context_length, truncate)
143
-
144
- def encode_text(self, model: torch.nn.Module, texts: Union[str, List[str]]) -> torch.Tensor:
145
- """
146
- Convenience method to tokenize and encode text.
147
-
148
- Args:
149
- model: Loaded AlphaCLIP model
150
- texts: Text(s) to encode
151
-
152
- Returns:
153
- Text embeddings tensor
154
- """
155
- tokens = self.tokenize(texts)
156
- if hasattr(model, 'token_embedding'):
157
- # Move tokens to same device as model
158
- device = next(model.parameters()).device
159
- tokens = tokens.to(device)
160
-
161
- with torch.no_grad():
162
- text_features = model.encode_text(tokens)
163
-
164
- return text_features
165
-
166
- def encode_image(self, model: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
167
- """
168
- Convenience method to encode images.
169
-
170
- Args:
171
- model: Loaded AlphaCLIP model
172
- images: Preprocessed image tensor
173
-
174
- Returns:
175
- Image embeddings tensor
176
- """
177
- with torch.no_grad():
178
- image_features = model.encode_image(images)
179
-
180
- return image_features
181
-
182
- def get_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
183
- """
184
- Compute cosine similarity between text and image features.
185
-
186
- Args:
187
- text_features: Text embedding tensor
188
- image_features: Image embedding tensor
189
-
190
- Returns:
191
- Similarity scores tensor
192
- """
193
- # Normalize features
194
- text_features = text_features / text_features.norm(dim=-1, keepdim=True)
195
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
196
-
197
- # Compute similarity
198
- similarity = (text_features @ image_features.T)
199
- return similarity
200
-
201
-
202
- # Convenience function for quick model loading
203
- def load_alphaclip(
204
- model_name: str = "ViT-B/16",
205
- device: Optional[str] = None,
206
- alpha_vision_ckpt_pth: str = "None",
207
- download_root = '/raid/datasets/models_weights/alphaclip',
208
- **kwargs
209
- ) -> Tuple[AlphaCLIPLoader, torch.nn.Module, callable]:
210
- """
211
- Quick function to load AlphaCLIP with a loader instance.
212
-
213
- Args:
214
- model_name: Name of the model to load
215
- device: Device to use
216
- **kwargs: Additional arguments for model loading
217
-
218
- Returns:
219
- Tuple of (loader, model, preprocess_function)
220
- """
221
- loader = AlphaCLIPLoader(default_device=device)
222
- model, preprocess = loader.load_model(model_name, **kwargs)
223
- return loader, model, preprocess
224
-
225
-
226
- # Make key functions available at module level
227
- __all__ = [
228
- 'AlphaCLIPLoader',
229
- 'load_alphaclip',
230
- 'available_models',
231
- 'load',
232
- 'tokenize'
233
- ]
 
src/alphaclip/example.py DELETED
@@ -1,76 +0,0 @@
- #!/usr/bin/env python3
- """
- Example usage of AlphaCLIP Standalone
-
- This script demonstrates basic usage of the AlphaCLIP standalone package.
- """
-
- import torch
- import numpy as np
- from alphaclip_loader import AlphaCLIPLoader, load_alphaclip
-
- def main():
-     print("AlphaCLIP Standalone Example")
-     print("=" * 40)
-
-     # Check if CUDA is available
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     print(f"Using device: {device}")
-
-     # Method 1: Using the loader class
-     print("\n1. Using AlphaCLIPLoader class:")
-     loader = AlphaCLIPLoader(default_device=device)
-
-     # Show available models
-     models = loader.available_models()
-     print(f"Available models: {models}")
-
-     # Load a model
-     print("\nLoading ViT-B/16 model...")
-     model, preprocess = loader.load_model("ViT-B/16")
-     print(f"Model loaded successfully!")
-
-     # Test text encoding
-     test_texts = [
-         "a photo of a cat",
-         "a dog running in the park",
-         "a beautiful sunset over the ocean"
-     ]
-
-     print(f"\nEncoding {len(test_texts)} texts...")
-     text_features = loader.encode_text(model, test_texts)
-     print(f"Text features shape: {text_features.shape}")
-
-     # Compute similarities between texts
-     print("\nComputing text-to-text similarities:")
-     similarities = loader.get_similarity(text_features, text_features)
-
-     for i, text1 in enumerate(test_texts):
-         for j, text2 in enumerate(test_texts):
-             if i <= j:  # Only show upper triangle
-                 sim = similarities[i, j].item()
-                 print(f"  '{text1}' <-> '{text2}': {sim:.3f}")
-
-     # Method 2: Using the quick loader function
-     print("\n\n2. Using quick loader function:")
-     loader2, model2, preprocess2 = load_alphaclip("ViT-B/16", device=device)
-
-     # Test single text
-     single_text = "a red apple on a wooden table"
-     single_features = loader2.encode_text(model2, single_text)
-     print(f"Single text '{single_text}' encoded to shape: {single_features.shape}")
-
-     # Test tokenization
-     print("\n3. Tokenization example:")
-     tokens = loader.tokenize(test_texts)
-     print(f"Tokenized {len(test_texts)} texts to shape: {tokens.shape}")
-
-     # Show some token examples
-     print("First few tokens for each text:")
-     for i, text in enumerate(test_texts):
-         print(f"  '{text}': {tokens[i][:10].tolist()}...")
-
-     print("\nExample completed successfully!")
-
- if __name__ == "__main__":
-     main()
 
src/alphaclip/requirements.txt DELETED
@@ -1,10 +0,0 @@
- # Core dependencies for AlphaCLIP standalone
- torch>=1.7.1
- torchvision
- ftfy
- regex
- tqdm
- loralib
- Pillow
- numpy
- packaging
 
src/alphaclip/setup.py DELETED
@@ -1,47 +0,0 @@
- """
- Setup script for AlphaCLIP Standalone
- """
-
- from setuptools import setup, find_packages
- import os
-
- # Read requirements
- with open('requirements.txt', 'r') as f:
-     requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]
-
- # Read README if it exists
- readme_content = ""
- if os.path.exists('README.md'):
-     with open('README.md', 'r', encoding='utf-8') as f:
-         readme_content = f.read()
-
- setup(
-     name="alphaclip-standalone",
-     version="1.0.0",
-     author="AlphaCLIP Team",
-     description="Standalone version of AlphaCLIP for easy integration",
-     long_description=readme_content,
-     long_description_content_type="text/markdown",
-     packages=find_packages(),
-     package_data={
-         'alpha_clip': ['*.gz'],  # Include the tokenizer vocabulary file
-     },
-     include_package_data=True,
-     install_requires=requirements,
-     python_requires=">=3.7",
-     classifiers=[
-         "Development Status :: 4 - Beta",
-         "Intended Audience :: Developers",
-         "Intended Audience :: Science/Research",
-         "License :: OSI Approved :: MIT License",
-         "Operating System :: OS Independent",
-         "Programming Language :: Python :: 3",
-         "Programming Language :: Python :: 3.7",
-         "Programming Language :: Python :: 3.8",
-         "Programming Language :: Python :: 3.9",
-         "Programming Language :: Python :: 3.10",
-         "Topic :: Scientific/Engineering :: Artificial Intelligence",
-         "Topic :: Software Development :: Libraries :: Python Modules",
-     ],
-     keywords="clip, vision, language, deep learning, pytorch",
- )
 
src/alphaclip/test_installation.py DELETED
@@ -1,149 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- Test script for AlphaCLIP Standalone
4
-
5
- This script tests the basic functionality of the standalone package
6
- to ensure everything is working correctly.
7
- """
8
-
9
- import sys
10
- import os
11
-
12
- def test_imports():
13
- """Test that all required modules can be imported."""
14
- print("Testing imports...")
15
-
16
- try:
17
- import torch
18
- print(f"โœ“ PyTorch {torch.__version__} imported successfully")
19
- except ImportError as e:
20
- print(f"โœ— Failed to import PyTorch: {e}")
21
- return False
22
-
23
- try:
24
- import torchvision
25
- print(f"โœ“ Torchvision imported successfully")
26
- except ImportError as e:
27
- print(f"โœ— Failed to import torchvision: {e}")
28
- return False
29
-
30
- try:
31
- from alphaclip_loader import AlphaCLIPLoader
32
- print("โœ“ AlphaCLIPLoader imported successfully")
33
- except ImportError as e:
34
- print(f"โœ— Failed to import AlphaCLIPLoader: {e}")
35
- return False
36
-
37
- try:
38
- import loralib
39
- print("โœ“ LoraLib imported successfully")
40
- except ImportError as e:
41
- print(f"โœ— Failed to import loralib: {e}")
42
- return False
43
-
44
- return True
45
-
46
- def test_model_loading():
47
- """Test loading a model."""
48
- print("\nTesting model loading...")
49
-
50
- try:
51
- from alphaclip_loader import AlphaCLIPLoader
52
-
53
- loader = AlphaCLIPLoader(default_device="cpu") # Use CPU for testing
54
- models = loader.available_models()
55
- print(f"โœ“ Available models: {models}")
56
-
57
- # Try to load the smallest model for testing
58
- print("Loading ViT-B/32 model (this may take a while for first download)...")
59
- model, preprocess = loader.load_model("ViT-B/32", device="cpu")
60
- print("โœ“ Model loaded successfully")
61
-
62
- return True
63
-
64
- except Exception as e:
65
- print(f"โœ— Failed to load model: {e}")
66
- return False
67
-
68
- def test_tokenization():
69
- """Test text tokenization."""
70
- print("\nTesting tokenization...")
71
-
72
- try:
73
- from alphaclip_loader import AlphaCLIPLoader
74
-
75
- loader = AlphaCLIPLoader()
76
- test_text = "a photo of a cat"
77
- tokens = loader.tokenize(test_text)
78
- print(f"โœ“ Tokenized '{test_text}' to shape {tokens.shape}")
79
-
80
- # Test batch tokenization
81
- test_texts = ["a cat", "a dog", "a bird"]
82
- batch_tokens = loader.tokenize(test_texts)
83
- print(f"โœ“ Batch tokenized {len(test_texts)} texts to shape {batch_tokens.shape}")
84
-
85
- return True
86
-
87
- except Exception as e:
88
- print(f"โœ— Failed tokenization test: {e}")
89
- return False
90
-
91
- def test_text_encoding():
92
- """Test text encoding with a loaded model."""
93
- print("\nTesting text encoding...")
94
-
95
- try:
96
- from alphaclip_loader import AlphaCLIPLoader
97
-
98
- loader = AlphaCLIPLoader(default_device="cpu")
99
- model, preprocess = loader.load_model("ViT-B/32", device="cpu")
100
-
101
- test_text = "a photo of a cat"
102
- features = loader.encode_text(model, test_text)
103
- print(f"โœ“ Encoded text to features with shape {features.shape}")
104
-
105
- # Test batch encoding
106
- test_texts = ["a cat", "a dog"]
107
- batch_features = loader.encode_text(model, test_texts)
108
- print(f"โœ“ Batch encoded {len(test_texts)} texts to shape {batch_features.shape}")
109
-
110
- return True
111
-
112
- except Exception as e:
113
- print(f"โœ— Failed text encoding test: {e}")
114
- return False
115
-
116
- def main():
117
- """Run all tests."""
118
- print("AlphaCLIP Standalone Test Suite")
119
- print("=" * 40)
120
-
121
- tests = [
122
- test_imports,
123
- test_tokenization,
124
- test_model_loading,
125
- test_text_encoding,
126
- ]
127
-
128
- passed = 0
129
- total = len(tests)
130
-
131
- for test in tests:
132
- try:
133
- if test():
134
- passed += 1
135
- except Exception as e:
136
- print(f"โœ— Test {test.__name__} failed with exception: {e}")
137
-
138
- print(f"\n{'='*40}")
139
- print(f"Test Results: {passed}/{total} tests passed")
140
-
141
- if passed == total:
142
- print("๐ŸŽ‰ All tests passed! AlphaCLIP Standalone is working correctly.")
143
- return 0
144
- else:
145
- print("โŒ Some tests failed. Please check the error messages above.")
146
- return 1
147
-
148
- if __name__ == "__main__":
149
- sys.exit(main())
 
src/bbox_utils.py DELETED
@@ -1,421 +0,0 @@
1
- import torch
2
- from copy import deepcopy
3
- from PIL import ImageDraw
4
- import itertools
5
- import random
6
-
7
-
8
- def extract_bboxes_feats(patch_embeddings, bboxes, gaussian_avg=False,
9
- gaussian_bbox_variance=0.5, get_single_embedding_per_image=False,
10
- patch_size=14, attention_map=None):
11
- """
12
- if get_single_embedding_per_image is True, the weights of all the bounding boxes patches on an image will be summed and the function will return the patch weights depending on this map
13
- """
14
- N = patch_embeddings.shape[0]
15
- N_boxes = bboxes.shape[1]
16
- grid_size = int(patch_embeddings.shape[1]**0.5)
17
- device = patch_embeddings.device
18
-
19
- bboxes //= patch_size
20
- bboxes = bboxes.int()
21
-
22
- # Reshape patches to grid
23
- patch_embeddings = patch_embeddings.view(N, grid_size, grid_size, -1) # Shape (N, grid_size, grid_size, embed_dim)
24
- if attention_map is not None:
25
- attention_map = attention_map.view(N, grid_size, grid_size) # Shape (N, grid_size, grid_size)
26
- # Grid of the sum of the gaussian weights
27
- total_patch_weights = torch.zeros(N, grid_size, grid_size)
28
-
29
- # Extract boxes
30
- x1, y1, w, h = bboxes.unbind(-1) # Separate box dimensions (N, N_boxes)
31
-
32
- # Create mesh grid for slicing
33
- x2 = x1 + w # Exclusive end x
34
- y2 = y1 + h # Exclusive end y
35
-
36
- means = []
37
- for i in range(N):
38
- image_means = []
39
- for j in range(N_boxes):
40
- if bboxes[i, j].sum().item() < 0 and get_single_embedding_per_image:
41
- # this is the case where we receive a dummy box
42
- continue
43
- # Extract the region for each box
44
- region_patches = patch_embeddings[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1, :] # (h, w, embed_dim)
45
-
46
- if attention_map is not None:
47
- patch_weights = attention_map[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1]
48
- patch_weights /= patch_weights.sum()
49
- total_patch_weights[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1] += patch_weights
50
-
51
- weighted_patches = region_patches * patch_weights.to(device).unsqueeze(-1) # (h, w, embed_dim)
52
- region_mean = weighted_patches.sum(dim=(0, 1)) # Weighted mean
53
-
54
- elif gaussian_avg:
55
- # Create Gaussian weights
56
- h_span, w_span = region_patches.shape[:2]
57
- y_coords, x_coords = torch.meshgrid(
58
- torch.linspace(-1, 1, h_span),
59
- torch.linspace(-1, 1, w_span),
60
- indexing="ij"
61
- )
62
- if gaussian_bbox_variance == 0:
63
- patch_weights = torch.zeros((h_span, w_span))
64
- # Determine central indices
65
- center_y = [h_span // 2] if h_span % 2 == 1 else [h_span // 2 - 1, h_span // 2]
66
- center_x = [w_span // 2] if w_span % 2 == 1 else [w_span // 2 - 1, w_span // 2]
67
- # Randomly select one of the central elements in even case
68
- center_y = random.choice(center_y)
69
- center_x = random.choice(center_x)
70
- # Set the selected central element to 1
71
- patch_weights[center_y, center_x] = 1.0
72
- else:
73
- distances = x_coords**2 + y_coords**2
74
- patch_weights = torch.exp(-distances / gaussian_bbox_variance)
75
- patch_weights = patch_weights / patch_weights.sum() # Normalize to sum to 1
76
-
77
- # Apply Gaussian weights to region patches
78
- weighted_patches = region_patches * patch_weights.to(device).unsqueeze(-1) # (h, w, embed_dim)
79
- region_mean = weighted_patches.sum(dim=(0, 1)) # Weighted mean
80
-
81
- # Recording the bbox weight inside the image patch weight map
82
- total_patch_weights[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1] += patch_weights
83
- else:
84
- # Mean pooling case: create uniform weights
85
- h_span, w_span = region_patches.shape[:2]
86
- uniform_weights = torch.ones(h_span, w_span) / (h_span * w_span)
87
-
88
- # Update total_patch_weights for mean pooling
89
- total_patch_weights[i, y1[i,j]:y2[i,j]+1, x1[i,j]:x2[i,j]+1] += uniform_weights
90
-
91
- # Compute mean of the region
92
- region_mean = region_patches.mean(dim=(0, 1))
93
-
94
- # Store the mean
95
- image_means.append(region_mean)
96
- if not get_single_embedding_per_image:
97
- means.append(torch.stack(image_means))
98
-
99
- # Normalizing the weight map so the sum is equal to 1
100
- total_patch_weights /= total_patch_weights.sum(dim=(1,2), keepdim=True)
101
- if not get_single_embedding_per_image:
102
- return torch.stack(means) # Shape (N, N_boxes, embed_dim)
103
- else:
104
- # Expand dimensions to match embeddings
105
- total_patch_weights = total_patch_weights.unsqueeze(-1).to(device)
106
-
107
- # Compute weighted sum
108
- weighted_patch_mean = (total_patch_weights * patch_embeddings).sum(dim=(1, 2))
109
- return weighted_patch_mean
110
- # Shape (N, embed_dim)
111
-
112
- #def adjust_bbox_for_transform(image, bbox, resize_dim, crop_dim):
113
- # """
114
- # Adjusts the bounding box for a resized and center-cropped image.
115
- #
116
- # Args:
117
- # image (PIL.Image): The input image.
118
- # bbox (list): The bounding box in [x1, y1, w, h] format.
119
- # resize_dim (int): The dimension of the shortest side after resizing.
120
- # crop_dim (int): The size of the square crop.
121
- #
122
- # Returns:
123
- # list: The adjusted bounding box in [x1, y1, w, h] format.
124
- # """
125
- # x1, y1, w, h = bbox
126
- # orig_width, orig_height = image.size
127
- #
128
- # # Calculate resize scale for the shortest side
129
- # if orig_width < orig_height:
130
- # scale = resize_dim / orig_width
131
- # resized_width, resized_height = resize_dim, int(orig_height * scale)
132
- # else:
133
- # scale = resize_dim / orig_height
134
- # resized_width, resized_height = int(orig_width * scale), resize_dim
135
- #
136
- # # Scale the bounding box
137
- # x1 *= scale
138
- # y1 *= scale
139
- # w *= scale
140
- # h *= scale
141
- #
142
- # # Calculate cropping offsets
143
- # crop_x = (resized_width - crop_dim) // 2
144
- # crop_y = (resized_height - crop_dim) // 2
145
- #
146
- # # Adjust bounding box for cropping
147
- # x1 -= crop_x
148
- # y1 -= crop_y
149
- #
150
- # # Clamp the bounding box to the cropped area
151
- # x1 = max(0, x1)
152
- # y1 = max(0, y1)
153
- # w = min(w, crop_dim - x1)
154
- # h = min(h, crop_dim - y1)
155
- #
156
- # return [x1, y1, w, h]
157
-
158
- def map_traces_to_grid(traces, n_patch):
159
- grid = torch.zeros((n_patch, n_patch))
160
- patch_size = 1.0 / n_patch
161
-
162
- for trace in traces:
163
- x, y = trace['x'], trace['y']
164
- if 0 <= x <= 1 and 0 <= y <= 1:
165
- grid_x, grid_y = int(x / patch_size), int(y / patch_size)
166
- grid[min(grid_y, n_patch - 1), min(grid_x, n_patch - 1)] += 1
167
-
168
- return grid
169
-
170
- def adjust_bbox_for_transform(image, bbox, resize_dim, crop_dim):
171
- """
172
- Adjusts the bounding box for a resized and center-cropped image.
173
-
174
- Args:
175
- image (PIL.Image): The input image.
176
- bbox (list): The bounding box in [x1, y1, w, h] format.
177
- resize_dim (int): The dimension of the shortest side after resizing.
178
- crop_dim (int): The size of the square crop.
179
-
180
- Returns:
181
- list: The adjusted bounding box in [x1, y1, w, h] format.
182
- """
183
- x1, y1, w, h = bbox
184
- orig_width, orig_height = image.size
185
-
186
- # Scale factors for resizing
187
- if orig_width < orig_height:
188
- scale_w = resize_dim / orig_width
189
- scale_h = (resize_dim * orig_height) / orig_width / orig_height
190
- else:
191
- scale_h = resize_dim / orig_height
192
- scale_w = (resize_dim * orig_width) / orig_height / orig_width
193
-
194
- # New dimensions after resize
195
- new_width = int(orig_width * scale_w)
196
- new_height = int(orig_height * scale_h)
197
-
198
- # Update bounding box for resizing
199
- x1 = x1 * scale_w
200
- y1 = y1 * scale_h
201
- w = w * scale_w
202
- h = h * scale_h
203
-
204
- # Compute cropping offsets
205
- crop_x_offset = max(0, (new_width - crop_dim) // 2)
206
- crop_y_offset = max(0, (new_height - crop_dim) // 2)
207
-
208
- # Adjust bounding box for cropping
209
- x1 -= crop_x_offset
210
- y1 -= crop_y_offset
211
-
212
- # Clip bounding box to crop dimensions
213
- x1 = max(0, min(x1, crop_dim - 1))
214
- y1 = max(0, min(y1, crop_dim - 1))
215
- w = max(0, min(w, crop_dim - x1))
216
- h = max(0, min(h, crop_dim - y1))
217
-
218
- return [x1, y1, w, h]
219
-
220
-
221
-
222
- def adjust_bbox_for_transform_no_scale(image, bbox, target_width, target_height):
223
- """
224
- - Does not preserve the image scale.
225
- Adjusts the bounding box for an image resized to a fixed width and height.
226
-
227
- Args:
228
- image (PIL.Image): The original image.
229
- bbox (list): The bounding box in [x1, y1, w, h] format.
230
- target_width (int): The width of the resized image.
231
- target_height (int): The height of the resized image.
232
-
233
- Returns:
234
- list: The adjusted bounding box in [x1, y1, w, h] format.
235
- """
236
- x1, y1, w, h = bbox
237
- orig_width, orig_height = image.size
238
-
239
- # Calculate scale factors for width and height
240
- scale_w = target_width / orig_width
241
- scale_h = target_height / orig_height
242
-
243
- # Adjust the bounding box
244
- x1 = x1 * scale_w
245
- y1 = y1 * scale_h
246
- w = w * scale_w
247
- h = h * scale_h
248
-
249
- # Return the adjusted bounding box
250
- return [x1, y1, w, h]
251
-
252
-
253
- def draw_bounding_boxes(input_image, bounding_boxes, captions=[""], color="red", width=2, text_background=True, boxes_to_show = None):
254
- """
255
- Draws bounding boxes on an image.
256
-
257
- Args:
258
- image (PIL.Image): The image to draw on.
259
- bounding_boxes (list): A list of bounding boxes, each as [x1, y1, x2, y2].
260
- color (str): The color of the bounding boxes (default is red).
261
- width (int): The width of the bounding box lines (default is 2).
262
-
263
- Returns:
264
- PIL.Image: The image with bounding boxes drawn.
265
- """
266
- # Create a drawing context
267
- image = deepcopy(input_image)
268
- draw = ImageDraw.Draw( image )
269
-
270
- #scale = 720.0 / max(image.size)
271
- if boxes_to_show is not None:
272
- if isinstance(boxes_to_show, int):
273
- indexes_to_show = random.sample(range(len(bounding_boxes)), boxes_to_show)
274
- else:
275
- indexes_to_show = boxes_to_show
276
-
277
- for i, (bbox, cap ) in enumerate(itertools.zip_longest(bounding_boxes, captions, fillvalue="")):
278
-
279
- if boxes_to_show is not None:
280
- if i not in indexes_to_show: continue
281
- #bbox = [ i / scale for i in bbox ]
282
- #x1, y1, w, h = bbox
283
- x1, y1, x2, y2 = bbox
284
-
285
- #x2, y2 = x1 + w, y1 + h # Convert width/height to bottom-right corner
286
- try:
287
- draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
288
- if cap != "":
289
- if text_background:
290
- left,top,right,bottom = draw.multiline_textbbox((x1,y1), cap) #textbbox
291
- draw.rectangle((left-5, top-5, right+5, bottom+5), fill="white")
292
- draw.multiline_text((x1,y1), cap, fill=color) #text
293
-
294
- except Exception as e:
295
- print("exception, i: ", i, f"{x1 = } {y1 = } {x2 = }, {y2 = }")
296
- print(e)
297
-
298
- return image
299
-
300
- def extract_bboxes_feats_double_dino(dino_model, patch_embeddings, bboxes, cls_token, registers_tokens, patch_size, return_type="cls", gaussian_bbox_variance=0.5):
301
- """
302
- Perform a forward pass of the last DINO layer with selected features, batched.
303
-
304
- Args:
305
- dino_model: The DINO model.
306
- patch_embeddings: Patch embeddings before the last layer.
307
- bboxes: Bounding boxes for each image in the batch (BS x N_BOX_MAX x 4).
308
- cls_token: CLS token embedding.
309
- return_type: Type of feature to return ('cls', 'avg', 'gaussian_avg').
310
- gaussian_bbox_variance: Variance for Gaussian averaging.
311
-
312
- Returns:
313
- bbox_features: Features for each bounding box based on return_type.
314
- """
315
- N = patch_embeddings.shape[0] # Batch size
316
- N_boxes = bboxes.shape[1] # Number of bounding boxes
317
- grid_size = int(patch_embeddings.shape[1] ** 0.5) # Assuming square grid
318
- embed_dim = patch_embeddings.shape[-1]
319
-
320
- bboxes_patch_indexes = bboxes.clone()
321
- bboxes_patch_indexes //= patch_size # Scale down bbox coordinates to match patch grid
322
- bboxes_patch_indexes = bboxes_patch_indexes.int()
323
-
324
- # Reshape patches to grid
325
- patch_embeddings = patch_embeddings.view(N, grid_size, grid_size, embed_dim) # (N, grid_size, grid_size, embed_dim)
326
-
327
- if cls_token is not None:
328
- cls_tokens = cls_token.view(N, embed_dim)
329
- if registers_tokens is not None:
330
- patches_offset = 5
331
- else:
332
- patches_offset = 1
333
- else:
334
- assert return_type != "cls"
335
- patches_offset = 0
336
- batch_outputs = []
337
-
338
- #batch_inputs = []
339
-
340
- means = []
341
- for i in range(N): # Iterate over batch
342
- image_means = []
343
-
344
- if cls_token is not None:
345
- cls_cur_img = cls_tokens[i].reshape(1, 1, embed_dim)
346
- if registers_tokens is not None:
347
- cur_img_register_tokens = registers_tokens[i].reshape(1, 4, embed_dim)
348
-
349
- for j in range(N_boxes): # Iterate over bounding boxes
350
- # Extract the region for the bounding box
351
- region_patches_xy = patch_embeddings[i, bboxes_patch_indexes[i, j, 1]:bboxes_patch_indexes[i, j, 3] + 1, bboxes_patch_indexes[i, j, 0]:bboxes_patch_indexes[i, j, 2] + 1, :]
352
- #region_patches = region_patches.reshape(-1, embed_dim) # Flatten to (num_patches, embed_dim)
353
-
354
- #region_patches = region_patches.view(-1, embed_dim) # Flatten to (num_patches, embed_dim)
355
- #cls_cur_img = cls_tokens[i].unsqueeze(0) # Add batch dimension (1, embed_dim)
356
- #region_patches = region_patches.unsqueeze(0) # Add batch dimension (1, num_patches, embed_dim)
357
- region_patches = region_patches_xy.reshape(1,-1, embed_dim)
358
- if cls_token is not None:
359
- inputs = torch.cat([cls_cur_img, region_patches], dim=1) # Concatenate along the token dimension (1, num_patches + 1, embed_dim)
360
- if registers_tokens is not None:
361
- inputs = torch.cat([cls_cur_img, cur_img_register_tokens, region_patches], dim=1) # Concatenate along the token dimension (1, num_patches + 5, embed_dim)
362
- else:
363
- inputs = torch.cat([region_patches], dim=1) # Concatenate along the token dimension (1, num_patches + 1, embed_dim)
364
-
365
- outputs = dino_model.blocks[-1](inputs) # Forward pass
366
- # shape (1, 1 + len(region_patches), 768)
367
- #cls_cur_img = cls_tokens[i]
368
- #cls_cur_img = cls_cur_img.reshape(1, embed_dim)
369
- #inputs = torch.cat([cls_cur_img, region_patches], dim=0) # Add CLS token to inputs
370
- #outputs = dino_model.blocks[-1](inputs) # Forward pass
371
-
372
- batch_outputs.append(outputs)
373
-
374
- region_patches = outputs[0, patches_offset: ,] #(1,45,768) -> (1,1,768)
375
-
376
- if return_type == "gaussian_avg":
377
- #region_patches = outputs[5: ,]
378
- h_span, w_span = region_patches_xy.shape[:2]
379
- y_coords, x_coords = torch.meshgrid(
380
- torch.linspace(-1, 1, h_span),
381
- torch.linspace(-1, 1, w_span),
382
- indexing="ij"
383
- )
384
- distances = x_coords**2 + y_coords**2
385
- gaussian_weights = torch.exp(-distances / gaussian_bbox_variance) # Adjust 0.1 for variance control
386
- gaussian_weights = gaussian_weights / gaussian_weights.sum() # Normalize to sum to 1
387
-
388
- # Apply Gaussian weights to region patches
389
- weighted_patches = region_patches_xy * gaussian_weights.to(next(dino_model.parameters()).device).unsqueeze(-1) # (h, w, embed_dim)
390
- region_mean = weighted_patches.sum(dim=(0,1)) # Weighted mean
391
- #image_means.append(region_mean)
392
- elif return_type == "avg":
393
- # Compute mean of the region
394
- region_mean = region_patches.mean(dim=(0)) # Mean over h, w
395
- elif return_type == "cls":
396
- region_mean = outputs[0, 0, ]
397
- image_means.append(region_mean)
398
-
399
- means.append(torch.stack(image_means))
400
-
401
- stacked_means = torch.stack(means)
402
- #stacked_means = stacked_means.reshape(-1, embed_dim)
403
- return stacked_means
404
-
405
-
406
- def process_bboxes(imgs, bboxes, transform):
407
- transformed_bboxes = []
408
- bboxes = bboxes.tolist()
409
- for img, img_bboxes in zip(imgs, bboxes):
410
- for bbox in img_bboxes:
411
- # Crop the region defined by bbox
412
- x_min, y_min, w, h = bbox
413
- x_max = x_min + w
414
- y_max = y_min + h
415
- cropped_region = img.crop((x_min, y_min, x_max, y_max))
416
-
417
- # Apply the transform to the cropped region
418
- transformed_region = transform(cropped_region)
419
- transformed_bboxes.append(transformed_region)
420
-
421
- return torch.stack(transformed_bboxes)
 
src/clipcap/CLIPCAP_INTEGRATION.md DELETED
@@ -1,206 +0,0 @@
1
- # ClipCap Integration with Patchioner Class
2
-
3
- This document describes how ClipCap models have been integrated into the Patchioner class for DINO feature-based image captioning.
4
-
5
- ## Overview
6
-
7
- ClipCap support has been added to the Patchioner class following the same pattern as other captioning models (VieCap, MeaCap, etc.). This integration allows you to use trained ClipCap models with DINO features for image captioning tasks.
8
-
9
- ## Architecture
10
-
11
- ### Files Added/Modified
12
-
13
- 1. **`src/clipcap/entrypoint.py`** - Main ClipCap integration module
14
- - `ClipCapModel` class for DINO feature-based captioning
15
- - Model classes: `ClipCaptionModel`, `ClipCaptionPrefix`, `MLP`, `TransformerMapper`
16
- - Text generation utilities
17
-
18
- 2. **`src/model.py`** - Modified Patchioner class
19
- - Added `clipcap_config` parameter to constructor
20
- - Added ClipCap initialization logic
21
- - Added ClipCap support to `caption_tokens` method
22
-
23
- 3. **Configuration Files**
24
- - `configs/clipcap_dino_vitb14.k.yaml` - DINOv2-B/14 configuration
25
- - `configs/clipcap_dino_vitl14.k.yaml` - DINOv2-L/14 configuration
26
-
27
- ## Configuration
28
-
29
- ### YAML Configuration Format
30
-
31
- ```yaml
32
- decap_weights: '/path/to/decap/weights.pt'
33
- prefix_size: 768 # DINO feature dimension
34
- support_memory_size: 0
35
- dino_model: 'dinov2_vitb14'
36
- normalize: True
37
- resize_dim: 518
38
- crop_dim: 518
39
- use_talk2dino_project: False
40
-
41
- # ClipCap configuration
42
- clipcap:
43
- language_model: 'gpt2'
44
- prefix_length: 10 # Sequence length for prefix
45
- clip_length: 10 # CLIP sequence length (for transformer mapping)
46
- num_layers: 8 # Number of transformer layers (for transformer mapping)
47
- mapping_type: 'mlp' # 'mlp' or 'transformer'
48
- only_prefix: True # Train only prefix mapping vs full model
49
- temperature: 1.0 # Sampling temperature
50
- top_p: 0.8 # Nucleus sampling parameter
51
- entry_length: 67 # Maximum caption length
52
- stop_token: '.' # Stop token for generation
53
- weight_path: '/path/to/trained/clipcap/model.pt'
54
- ```
55
-
56
- ### Supported DINO Models
57
-
58
- The integration automatically detects DINO feature dimensions:
59
-
60
- - **DINOv2-S/14**: 384 dimensions (`dinov2_vits14`)
61
- - **DINOv2-B/14**: 768 dimensions (`dinov2_vitb14`)
62
- - **DINOv2-L/14**: 1024 dimensions (`dinov2_vitl14`)
63
- - **DINOv2-G/14**: 1536 dimensions (`dinov2_vitg14`)
64
-
65
- ## Usage
66
-
67
- ### 1. Training ClipCap Models
68
-
69
- First, train your ClipCap model with DINO features:
70
-
71
- ```bash
72
- # Extract DINO features
73
- python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14
74
-
75
- # Train ClipCap model
76
- python clipcapTraining.py \
77
- --use_dino \
78
- --dino_model_type dinov2_vitb14 \
79
- --prefix_length 10 \
80
- --mapping_type mlp \
81
- --only_prefix \
82
- --epochs 10
83
- ```
84
-
85
- ### 2. Using ClipCap with Patchioner
86
-
87
- ```python
88
- import torch
89
- from src.model import Patchioner
90
-
91
- # Load model with ClipCap configuration
92
- device = torch.device('cuda')
93
- model = Patchioner.from_config('configs/clipcap_dino_vitb14.k.yaml', device)
94
-
95
- # Generate captions from images
96
- imgs = torch.randn(2, 3, 518, 518).to(device) # Example batch
97
- results = model.forward(imgs, get_cls_capt=True)
98
- captions = results['cls_capt']
99
-
100
- print("Generated captions:")
101
- for i, caption in enumerate(captions):
102
- print(f"Image {i+1}: {caption}")
103
- ```
104
-
105
- ### 3. Using ClipCap Directly
106
-
107
- ```python
108
- from src.clipcap.entrypoint import ClipCapModel
109
- import torch
110
-
111
- # Configuration
112
- config = {
113
- 'language_model': 'gpt2',
114
- 'prefix_length': 10,
115
- 'mapping_type': 'mlp',
116
- 'only_prefix': True,
117
- 'weight_path': '/path/to/trained/model.pt'
118
- }
119
-
120
- # Initialize model
121
- device = torch.device('cuda')
122
- clipcap = ClipCapModel(config, device, dino_feature_dim=768)
123
-
124
- # Generate captions from DINO features
125
- dino_features = torch.randn(2, 768).to(device)
126
- captions = clipcap.forward(dino_features)
127
-
128
- print(captions)
129
- ```
130
-
131
- ## Performance Improvements
132
-
133
- ### Batched Text Generation
134
-
135
- The ClipCap integration includes an efficient batched text generation implementation:
136
-
137
- - **`generate_batched()`**: Processes entire batches simultaneously
138
- - **Significant speedup**: 2-8x faster than sequential processing
139
- - **Memory efficient**: Optimized for GPU memory usage
140
- - **Configurable**: Can fallback to sequential mode if needed
141
-
142
- ### Configuration Options
143
-
144
- ```yaml
145
- clipcap:
146
- use_batched_generation: True # Enable batched generation (recommended)
147
- temperature: 1.0 # Sampling temperature
148
- top_p: 0.8 # Nucleus sampling parameter
149
- entry_length: 67 # Maximum sequence length
150
- ```
151
-
152
- ## Model Architecture Details
153
-
154
- ### ClipCap Model Structure
155
-
156
- 1. **Input**: DINO features (384/768/1024/1536 dimensions)
157
- 2. **Mapping Layer**:
158
- - **MLP**: `DINO_dim โ†’ GPT2_dim * prefix_length`
159
- - **Transformer**: Multi-layer transformer mapping
160
- 3. **GPT-2 Decoder**: Pretrained GPT-2 for text generation
161
- 4. **Output**: Natural language captions
162
-
163
- ### Key Components
164
-
165
- - **`ClipCapModel`**: Main class for DINO-to-text captioning
166
- - **`MLP`/`TransformerMapper`**: Feature mapping from DINO to GPT-2 space
167
- - **Text Generation**: Nucleus sampling with configurable parameters
168
-
169
- ## Integration with Existing Pipeline
170
-
171
- The ClipCap integration follows the established pattern:
172
-
173
- 1. **Configuration**: YAML-based configuration like other models
174
- 2. **Initialization**: Automatic DINO dimension detection
175
- 3. **Forward Pass**: Seamless integration with existing forward methods
176
- 4. **Scoring**: Optional confidence scoring support
177
-
178
- ## Testing
179
-
180
- Run the integration test:
181
-
182
- ```bash
183
- python test_clipcap_integration.py
184
- ```
185
-
186
- This test verifies:
187
- - Configuration loading from YAML
188
- - Model instantiation with ClipCap
189
- - Caption generation with dummy DINO features
190
- - Score computation functionality
191
-
192
-
193
- ## Troubleshooting
194
-
195
- ### Common Issues
196
-
197
- 1. **Dimension Mismatch**: Ensure `prefix_size` matches DINO model dimension
198
- 2. **Missing Weights**: Verify `weight_path` points to trained ClipCap model
199
- 3. **Memory Issues**: Use `only_prefix=True` for lower memory usage
200
- 4. **Generation Quality**: Tune `temperature`, `top_p`, and `entry_length`
201
-
202
- ## References
203
-
204
- - [ClipCap Paper](https://arxiv.org/abs/2111.09734)
205
- - [DINO Paper](https://arxiv.org/abs/2104.14294)
206
- - [DINOv2 Paper](https://arxiv.org/abs/2304.07193)
 
src/clipcap/clipcapTrainREADME.md DELETED
@@ -1,301 +0,0 @@
1
- # ClipCap Training with DINO Features - README
2
-
3
- This guide provides instructions for training ClipCap with DINO visual features instead of CLIP features.
4
-
5
- ## Prerequisites
6
-
7
- 1. Ensure you have the required dependencies installed:
8
- - PyTorch
9
- - torchvision
10
- - transformers
11
- - tqdm
12
- - Pillow
13
- - scikit-image
14
-
15
- 2. Prepare your COCO dataset with the following structure:
16
- ```
17
- ./data/coco/
18
- โ”œโ”€โ”€ annotations/
19
- โ”‚ โ””โ”€โ”€ train_caption.json
20
- โ”œโ”€โ”€ train2014/
21
- โ”‚ โ””โ”€โ”€ COCO_train2014_*.jpg
22
- โ””โ”€โ”€ val2014/
23
- โ””โ”€โ”€ COCO_val2014_*.jpg
24
- ```
25
-
26
- ## Required Files for DINO Feature Extraction
27
-
28
- To start the DINO feature extraction for the COCO dataset, you need:
29
-
30
- ### 1. **COCO Dataset Structure**:
31
- ```
32
- /raid/datasets/coco/ # Main COCO directory (default)
33
- โ”œโ”€โ”€ train2014/ # REQUIRED: Training images
34
- โ”‚ โ””โ”€โ”€ COCO_train2014_*.jpg # Image files
35
- โ”œโ”€โ”€ val2014/ # REQUIRED: Validation images
36
- โ”‚ โ””โ”€โ”€ COCO_val2014_*.jpg # Image files
37
- โ””โ”€โ”€ train_split_karpathy.json # REQUIRED: Karpathy format annotations (default)
38
- ```
39
-
40
- ### 2. **Required Files**:
41
- - **`train_split_karpathy.json`**: COCO caption annotations in Karpathy format (default)
42
- - **Training images**: COCO 2014 training set (COCO_train2014_*.jpg)
43
- - **Validation images**: COCO 2014 validation set (COCO_val2014_*.jpg)
44
-
45
- ### 3. **Annotation Format Support**:
46
-
47
- The script supports two annotation formats:
48
-
49
- #### **A. Karpathy Format** (default, recommended):
50
- ```json
51
- {
52
- "images": [
53
- {"id": 522418, "file_name": "COCO_val2014_000000522418.jpg"}
54
- ],
55
- "annotations": [
56
- {"image_id": 522418, "id": 0, "caption": "A woman wearing a net..."}
57
- ]
58
- }
59
- ```
60
-
61
- #### **B. ClipCap Format** (legacy):
62
- ```json
63
- [
64
- {"image_id": 522418, "caption": "A woman wearing a net..."}
65
- ]
66
- ```
67
-
68
- ### 3. **Specifying Custom Input/Output Paths**:
69
-
70
- You can customize the paths using command-line arguments:
71
-
72
- ```bash
73
- python clipcap_dino_parse_coco.py \
74
- --dino_model_type dinov2_vitb14 \
75
- --coco_images_dir "/path/to/your/coco/dataset" \
76
- --captions_file "/path/to/your/train_caption.json" \
77
- --output_file "/path/to/output/dino_features.pkl"
78
- ```
79
-
80
- **Available path arguments**:
81
- - `--coco_images_dir`: Path to COCO images directory (should contain `train2014/` and `val2014/` subdirs) - **Default: `/raid/datasets/coco`**
82
- - `--captions_file`: Path to COCO captions JSON file (supports both Karpathy and ClipCap formats) - **Default: `/raid/datasets/coco/train_split_karpathy.json`**
83
- - `--output_file`: Custom output file path (optional, auto-generated if not specified)
84
-
85
- ### 4. **Default Behavior** (if no paths specified):
86
- ```bash
87
- # This will use default paths for your setup:
88
- python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14
89
-
90
- # Equivalent to:
91
- python clipcap_dino_parse_coco.py \
92
- --dino_model_type dinov2_vitb14 \
93
- --coco_images_dir "/raid/datasets/coco" \
94
- --captions_file "/raid/datasets/coco/train_split_karpathy.json" \
95
- --output_file "/raid/datasets/coco/coco_karpathy_split_dinov2_vitb14_train.pkl"
96
- ```
97
-
98
- ## Step 1: Extract DINO Features
99
-
100
- First, extract DINO features from the COCO images using the modified feature extraction script:
101
-
102
- ### For DINOv2-B/14 (768-dim features):
103
- ```bash
104
- # Default paths (uses /raid/datasets/coco and Karpathy annotations)
105
- python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14 --resize_dim 518 --crop_dim 518
106
-
107
- # Custom paths
108
- python clipcap_dino_parse_coco.py \
109
- --dino_model_type dinov2_vitb14 \
110
- --coco_images_dir "/your/coco/path" \
111
- --captions_file "/your/coco/train_split_karpathy.json" \
112
- --output_file "/your/output/dino_vitb14_features.pkl"
113
- ```
114
-
115
- ### For DINOv2-L/14 (1024-dim features):
116
- ```bash
117
- # Default paths
118
- python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitl14 --resize_dim 518 --crop_dim 518
119
-
120
- # Custom paths
121
- python clipcap_dino_parse_coco.py \
122
- --dino_model_type dinov2_vitl14 \
123
- --coco_images_dir "/your/coco/path" \
124
- --output_file "/your/output/dino_vitl14_features.pkl"
125
- ```
126
-
127
- ### For DINOv2-S/14 (384-dim features):
128
- ```bash
129
- python clipcap_dino_parse_coco.py --dino_model_type dinov2_vits14 --resize_dim 518 --crop_dim 518
130
- ```
131
-
132
- ### For DINOv2-G/14 (1536-dim features):
133
- ```bash
134
- python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitg14 --resize_dim 518 --crop_dim 518
135
- ```
136
-
137
- **Output**: This will create a file like `/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl` (or your custom path) containing the DINO features and captions.
138
-
139
- ### Check Available Arguments:
140
- ```bash
141
- python clipcap_dino_parse_coco.py --help
142
- ```
143
-
144
- ## Step 2: Train ClipCap with DINO Features
145
-
146
- ### Basic Training Command (MLP with sequence length 10):
147
-
148
- For **DINOv2-B/14** with **MLP mapping** and **prefix length 10**:
149
- ```bash
150
- python clipcapTraining.py \
151
- --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \
152
- --out_dir ./checkpoints_dino_vitb14_mlp_len10 \
153
- --prefix dino_vitb14_mlp_len10 \
154
- --epochs 10 \
155
- --save_every 2 \
156
- --prefix_length 10 \
157
- --bs 32 \
158
- --mapping_type mlp \
159
- --use_dino \
160
- --dino_model_type dinov2_vitb14 \
161
- --only_prefix
162
- ```
163
-
164
- ### Training Options for Different DINO Models:
165
-
166
- #### DINOv2-L/14 (1024-dim):
167
- ```bash
168
- python clipcapTraining.py \
169
- --data ./data/coco/coco_karpathy_split_dinov2_vitl14_train.pkl \
170
- --out_dir ./checkpoints_dino_vitl14_mlp_len10 \
171
- --prefix dino_vitl14_mlp_len10 \
172
- --epochs 10 \
173
- --save_every 2 \
174
- --prefix_length 10 \
175
- --bs 32 \
176
- --mapping_type mlp \
177
- --use_dino \
178
- --dino_model_type dinov2_vitl14 \
179
- --only_prefix
180
- ```
181
-
182
- #### DINOv2-S/14 (384-dim):
183
- ```bash
184
- python clipcapTraining.py \
185
- --data ./data/coco/coco_karpathy_split_dinov2_vits14_train.pkl \
186
- --out_dir ./checkpoints_dino_vits14_mlp_len10 \
187
- --prefix dino_vits14_mlp_len10 \
188
- --epochs 10 \
189
- --save_every 2 \
190
- --prefix_length 10 \
191
- --bs 32 \
192
- --mapping_type mlp \
193
- --use_dino \
194
- --dino_model_type dinov2_vits14 \
195
- --only_prefix
196
- ```
197
-
198
- ### Advanced Training Options:
199
-
200
- #### Train both prefix and GPT (full model):
201
- ```bash
202
- python clipcapTraining.py \
203
- --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \
204
- --out_dir ./checkpoints_dino_vitb14_mlp_len10_full \
205
- --prefix dino_vitb14_mlp_len10_full \
206
- --epochs 10 \
207
- --save_every 2 \
208
- --prefix_length 10 \
209
- --bs 16 \
210
- --mapping_type mlp \
211
- --use_dino \
212
- --dino_model_type dinov2_vitb14
213
- ```
214
-
215
- #### Use Transformer mapping instead of MLP:
216
- ```bash
217
- python clipcapTraining.py \
218
- --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \
219
- --out_dir ./checkpoints_dino_vitb14_transformer_len10 \
220
- --prefix dino_vitb14_transformer_len10 \
221
- --epochs 10 \
222
- --save_every 2 \
223
- --prefix_length 10 \
224
- --bs 32 \
225
- --mapping_type transformer \
226
- --num_layers 8 \
227
- --use_dino \
228
- --dino_model_type dinov2_vitb14 \
229
- --only_prefix
230
- ```
231
-
232
- #### Custom feature dimension (if needed):
233
- ```bash
234
- python clipcapTraining.py \
235
- --data ./data/coco/coco_karpathy_split_dinov2_vitb14_train.pkl \
236
- --out_dir ./checkpoints_dino_custom \
237
- --prefix dino_custom \
238
- --epochs 10 \
239
- --prefix_length 10 \
240
- --bs 32 \
241
- --mapping_type mlp \
242
- --use_dino \
243
- --dino_model_type dinov2_vitb14 \
244
- --dino_feature_dim 768 \
245
- --only_prefix
246
- ```
247
-
248
- ## Key Parameters Explanation:
249
-
250
- - `--use_dino`: Enable DINO mode (required for DINO training)
251
- - `--dino_model_type`: Specify which DINO model was used for feature extraction
252
- - `--dino_feature_dim`: Override automatic feature dimension detection
253
- `--prefix_length`: Number of prefix tokens (10 in the examples above)
254
- - `--mapping_type`: Choose between 'mlp' or 'transformer' mapping
255
- - `--only_prefix`: Train only the mapping layer, freeze GPT-2
256
- - `--bs`: Batch size (adjust based on GPU memory)
257
- - `--epochs`: Number of training epochs
258
- - `--save_every`: Save checkpoint every N epochs
259
-
260
- ## Expected Feature Dimensions:
261
-
262
- - **DINOv2-S/14**: 384 dimensions
263
- - **DINOv2-B/14**: 768 dimensions
264
- - **DINOv2-L/14**: 1024 dimensions
265
- - **DINOv2-G/14**: 1536 dimensions
266
-
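These dimensions can also be read off the backbone itself instead of being hard-coded; a small sketch using the same `torch.hub` entry point as the extraction script (it assumes the DINOv2 hub models expose an `embed_dim` attribute):

```python
import torch

# Load the backbone the same way clipcap_dino_parse_coco.py does and report its width.
dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14")
print(dino.embed_dim)  # 768 for DINOv2-B/14; the value you would pass via --dino_feature_dim
```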
267
- ## Training Tips:
268
-
269
- 1. **Memory Usage**: DINO features have a higher dimensionality than the default CLIP features (768–1536 vs. 512), so you may need to reduce the batch size
270
- 2. **Convergence**: DINO-based models may need a different learning rate or longer training to converge
271
- 3. **Prefix Length**: Experiment with different prefix lengths (5, 10, 20) to find the best trade-off for your setup
272
- 4. **Mapping Type**: MLP is faster; the Transformer mapper may give better results but requires more memory
273
-
274
- ## Output:
275
-
276
- The training will save checkpoints in the specified output directory:
277
- - `{prefix}-{epoch:03d}.pt`: Model checkpoint for each epoch
278
- - `{prefix}_latest.pt`: Latest model checkpoint (updated every 10k iterations)
279
- - `{prefix}.json`: Training configuration
280
-
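A checkpoint saved this way can be reloaded with the classes defined in `clipcapTraining.py`; a sketch for the prefix-only MLP run above (the checkpoint filename is hypothetical and just follows the `{prefix}-{epoch:03d}.pt` pattern):

```python
import torch
from clipcapTraining import ClipCaptionPrefix, MappingType

checkpoint = "./checkpoints_dino_vitb14_mlp_len10/dino_vitb14_mlp_len10-009.pt"  # example epoch-9 file

model = ClipCaptionPrefix(
    prefix_length=10,
    clip_length=10,
    prefix_size=768,              # DINOv2-B/14 feature dimension
    mapping_type=MappingType.MLP,
)
model.load_state_dict(torch.load(checkpoint, map_location="cpu"))
model.eval()
```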
281
- ## Example Full Workflow:
282
-
283
- ```bash
284
- # 1. Extract DINO features
285
- python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14
286
-
287
- # 2. Train ClipCap with DINO features (MLP, length 10, prefix-only)
288
- python clipcapTraining.py \
289
- --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \
290
- --out_dir ./checkpoints_dino_vitb14_mlp_len10 \
291
- --prefix dino_vitb14_mlp_len10 \
292
- --epochs 10 \
293
- --prefix_length 10 \
294
- --bs 32 \
295
- --mapping_type mlp \
296
- --use_dino \
297
- --dino_model_type dinov2_vitb14 \
298
- --only_prefix
299
- ```
300
-
301
- This trains a ClipCap model on DINO features with an MLP mapping network and a prefix length of 10.
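For quick qualitative checks, a caption can be generated straight from a DINO CLS feature with a plain greedy loop. This is only a rough sketch, not the project's generation pipeline: it assumes `model` was loaded as in the Output section above, that features were extracted with the default `cls` feature type, and that `example.jpg` is any local test image.

```python
import torch
import torchvision.transforms as T
from PIL import Image
from transformers import GPT2Tokenizer

# `model` is the ClipCaptionPrefix instance loaded in the previous sketch.
device = next(model.parameters()).device
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Same backbone and preprocessing as the feature-extraction step.
dino = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14").to(device).eval()
preprocess = T.Compose([
    T.Resize(518, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(518),
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
image = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0).to(device)

with torch.no_grad():
    prefix = dino(image)                                           # CLS feature, shape [1, 768]
    embeds = model.clip_project(prefix).view(1, model.prefix_length, -1)
    generated = []
    for _ in range(40):                                            # greedy decoding, stop at '.'
        logits = model.gpt(inputs_embeds=embeds).logits[:, -1, :]
        next_token = logits.argmax(dim=-1)
        generated.append(next_token.item())
        if tokenizer.decode([generated[-1]]).strip() == ".":
            break
        embeds = torch.cat([embeds, model.gpt.transformer.wte(next_token).unsqueeze(1)], dim=1)

print(tokenizer.decode(generated))
```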
src/clipcap/clipcapTraining.py DELETED
@@ -1,405 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from torch.nn import functional as nnf
4
- from torch.utils.data import Dataset, DataLoader
5
- from enum import Enum
6
- from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
7
- from tqdm import tqdm
8
- import os
9
- import pickle
10
- import sys
11
- import argparse
12
- import json
13
- from typing import Tuple, Optional, Union
14
-
15
-
16
- class MappingType(Enum):
17
- MLP = 'mlp'
18
- Transformer = 'transformer'
19
-
20
-
21
- class ClipCocoDataset(Dataset):
22
-
23
- def __len__(self) -> int:
24
- return len(self.captions_tokens)
25
-
26
- def pad_tokens(self, item: int):
27
- tokens = self.captions_tokens[item]
28
- padding = self.max_seq_len - tokens.shape[0]
29
- if padding > 0:
30
- tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
31
- self.captions_tokens[item] = tokens
32
- elif padding < 0:
33
- tokens = tokens[:self.max_seq_len]
34
- self.captions_tokens[item] = tokens
35
- mask = tokens.ge(0) # mask is zero where we out of sequence
36
- tokens[~mask] = 0
37
- mask = mask.float()
38
- mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0) # adding prefix mask
39
- return tokens, mask
40
-
41
- def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
42
- tokens, mask = self.pad_tokens(item)
43
- prefix = self.prefixes[self.caption2embedding[item]]
44
- if self.normalize_prefix:
45
- prefix = prefix.float()
46
- prefix = prefix / prefix.norm(2, -1)
47
- return tokens, mask, prefix
48
-
49
- def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2",
50
- normalize_prefix=False):
51
- self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
52
- self.prefix_length = prefix_length
53
- self.normalize_prefix = normalize_prefix
54
- with open(data_path, 'rb') as f:
55
- all_data = pickle.load(f)
56
- print("Data size is %0d" % len(all_data["clip_embedding"]))
57
- sys.stdout.flush()
58
- self.prefixes = all_data["clip_embedding"]
59
- captions_raw = all_data["captions"]
60
- self.image_ids = [caption["image_id"] for caption in captions_raw]
61
- self.captions = [caption['caption'] for caption in captions_raw]
62
- if os.path.isfile(f"{data_path[:-4]}_tokens.pkl"):
63
- with open(f"{data_path[:-4]}_tokens.pkl", 'rb') as f:
64
- self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f)
65
- else:
66
- self.captions_tokens = []
67
- self.caption2embedding = []
68
- max_seq_len = 0
69
- for caption in captions_raw:
70
- self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64))
71
- self.caption2embedding.append(caption["clip_embedding"])
72
- max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])
73
- # self.max_seq_len = max_seq_len
74
- with open(f"{data_path[:-4]}_tokens.pkl", 'wb') as f:
75
- pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f)
76
- all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float()
77
- self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max()))
78
-
79
-
80
- class MLP(nn.Module):
81
-
82
- def forward(self, x: torch.Tensor) -> torch.Tensor:
83
- return self.model(x)
84
-
85
- def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
86
- super(MLP, self).__init__()
87
- layers = []
88
- for i in range(len(sizes) - 1):
89
- layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
90
- if i < len(sizes) - 2:
91
- layers.append(act())
92
- self.model = nn.Sequential(*layers)
93
-
94
-
95
- class MlpTransformer(nn.Module):
96
- def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
97
- super().__init__()
98
- out_d = out_d if out_d is not None else in_dim
99
- self.fc1 = nn.Linear(in_dim, h_dim)
100
- self.act = act
101
- self.fc2 = nn.Linear(h_dim, out_d)
102
- self.dropout = nn.Dropout(dropout)
103
-
104
- def forward(self, x):
105
- x = self.fc1(x)
106
- x = self.act(x)
107
- x = self.dropout(x)
108
- x = self.fc2(x)
109
- x = self.dropout(x)
110
- return x
111
-
112
- class MultiHeadAttention(nn.Module):
113
-
114
- def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
115
- super().__init__()
116
- self.num_heads = num_heads
117
- head_dim = dim_self // num_heads
118
- self.scale = head_dim ** -0.5
119
- self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
120
- self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
121
- self.project = nn.Linear(dim_self, dim_self)
122
- self.dropout = nn.Dropout(dropout)
123
-
124
- def forward(self, x, y=None, mask=None):
125
- y = y if y is not None else x
126
- b, n, c = x.shape
127
- _, m, d = y.shape
128
- # b n h dh
129
- queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
130
- # b m 2 h dh
131
- keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
132
- keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
133
- attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
134
- if mask is not None:
135
- if mask.dim() == 2:
136
- mask = mask.unsqueeze(1)
137
- attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
138
- attention = attention.softmax(dim=2)
139
- out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
140
- out = self.project(out)
141
- return out, attention
142
-
143
-
144
- class TransformerLayer(nn.Module):
145
-
146
- def forward_with_attention(self, x, y=None, mask=None):
147
- x_, attention = self.attn(self.norm1(x), y, mask)
148
- x = x + x_
149
- x = x + self.mlp(self.norm2(x))
150
- return x, attention
151
-
152
- def forward(self, x, y=None, mask=None):
153
- x = x + self.attn(self.norm1(x), y, mask)[0]
154
- x = x + self.mlp(self.norm2(x))
155
- return x
156
-
157
- def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
158
- norm_layer: nn.Module = nn.LayerNorm):
159
- super().__init__()
160
- self.norm1 = norm_layer(dim_self)
161
- self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
162
- self.norm2 = norm_layer(dim_self)
163
- self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
164
-
165
-
166
- class Transformer(nn.Module):
167
-
168
- def forward_with_attention(self, x, y=None, mask=None):
169
- attentions = []
170
- for layer in self.layers:
171
- x, att = layer.forward_with_attention(x, y, mask)
172
- attentions.append(att)
173
- return x, attentions
174
-
175
- def forward(self, x, y=None, mask=None):
176
- for i, layer in enumerate(self.layers):
177
- if i % 2 == 0 and self.enc_dec: # cross
178
- x = layer(x, y)
179
- elif self.enc_dec: # self
180
- x = layer(x, x, mask)
181
- else: # self or cross
182
- x = layer(x, y, mask)
183
- return x
184
-
185
- def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
186
- mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
187
- super(Transformer, self).__init__()
188
- dim_ref = dim_ref if dim_ref is not None else dim_self
189
- self.enc_dec = enc_dec
190
- if enc_dec:
191
- num_layers = num_layers * 2
192
- layers = []
193
- for i in range(num_layers):
194
- if i % 2 == 0 and enc_dec: # cross
195
- layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
196
- elif enc_dec: # self
197
- layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
198
- else: # self or cross
199
- layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
200
- self.layers = nn.ModuleList(layers)
201
-
202
-
203
- class TransformerMapper(nn.Module):
204
-
205
- def forward(self, x):
206
- x = self.linear(x).view(x.shape[0], self.clip_length, -1)
207
- prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
208
- prefix = torch.cat((x, prefix), dim=1)
209
- out = self.transformer(prefix)[:, self.clip_length:]
210
- return out
211
-
212
- def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
213
- super(TransformerMapper, self).__init__()
214
- self.clip_length = clip_length
215
- self.transformer = Transformer(dim_embedding, 8, num_layers)
216
- self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
217
- self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
218
-
219
-
220
- class ClipCaptionModel(nn.Module):
221
-
222
- def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
223
- return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
224
-
225
- def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
226
- labels: Optional[torch.Tensor] = None):
227
- embedding_text = self.gpt.transformer.wte(tokens)
228
- prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
229
- embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
230
- if labels is not None:
231
- dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
232
- labels = torch.cat((dummy_token, tokens), dim=1)
233
- out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
234
- return out
235
-
236
- def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
237
- num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
238
- super(ClipCaptionModel, self).__init__()
239
- self.prefix_length = prefix_length
240
- self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
241
- self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
242
- if mapping_type == MappingType.MLP:
243
- self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
244
- self.gpt_embedding_size * prefix_length))
245
- else:
246
- self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
247
- clip_length, num_layers)
248
-
249
-
250
- class ClipCaptionPrefix(ClipCaptionModel):
251
-
252
- def parameters(self, recurse: bool = True):
253
- return self.clip_project.parameters()
254
-
255
- def train(self, mode: bool = True):
256
- super(ClipCaptionPrefix, self).train(mode)
257
- self.gpt.eval()
258
- return self
259
-
260
-
261
- def save_config(args: argparse.Namespace):
262
- config = {}
263
- for key, item in args._get_kwargs():
264
- config[key] = item
265
- out_path = os.path.join(args.out_dir, f"{args.prefix}.json")
266
- with open(out_path, 'w') as outfile:
267
- json.dump(config, outfile)
268
-
269
-
270
- def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
271
- with open(config_path) as f:
272
- config = json.load(f)
273
- parser = argparse.ArgumentParser()
274
- parser.set_defaults(**config)
275
- args = parser.parse_args()
276
- if type(epoch_or_latest) is int:
277
- epoch_or_latest = f"-{epoch_or_latest:03d}"
278
- model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
279
- if args.only_prefix:
280
- model = ClipCaptionPrefix(args.prefix_length)
281
- else:
282
- model = ClipCaptionModel(args.prefix_length)
283
- if os.path.isfile(model_path):
284
- print(f"loading model from {model_path}")
285
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
286
- else:
287
- print(f"{model_path} is not exist")
288
- return model, parser
289
-
290
-
291
- def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
292
- lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = "", device = torch.device('cuda:0')):
293
-
294
- batch_size = args.bs
295
- epochs = args.epochs
296
- if not os.path.exists(output_dir):
297
- os.makedirs(output_dir)
298
- model = model.to(device)
299
- model.train()
300
- optimizer = AdamW(model.parameters(), lr=lr)
301
- train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
302
- scheduler = get_linear_schedule_with_warmup(
303
- optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
304
- )
305
- # save_config(args)
306
- for epoch in range(epochs):
307
- print(f">>> Training epoch {epoch}")
308
- sys.stdout.flush()
309
- progress = tqdm(total=len(train_dataloader), desc=output_prefix)
310
- for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
311
- model.zero_grad()
312
- tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
313
- outputs = model(tokens, prefix, mask)
314
- logits = outputs.logits[:, dataset.prefix_length - 1: -1]
315
- loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
316
- loss.backward()
317
- optimizer.step()
318
- scheduler.step()
319
- optimizer.zero_grad()
320
- progress.set_postfix({"loss": loss.item()})
321
- progress.update()
322
- if (idx + 1) % 10000 == 0:
323
- torch.save(
324
- model.state_dict(),
325
- os.path.join(output_dir, f"{output_prefix}_latest.pt"),
326
- )
327
- progress.close()
328
- if epoch % args.save_every == 0 or epoch == epochs - 1:
329
- torch.save(
330
- model.state_dict(),
331
- os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
332
- )
333
- return model
334
-
335
-
336
- def main():
337
- parser = argparse.ArgumentParser()
338
- parser.add_argument('--data', default='/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_train.pkl')
339
- parser.add_argument('--out_dir', default='/raid/datasets/models_weights/clipcap/checkpoints/dinov2b14/')
340
- parser.add_argument('--prefix', default='coco_prefix', help='prefix for saved filenames')
341
- parser.add_argument('--epochs', type=int, default=10)
342
- parser.add_argument('--save_every', type=int, default=1)
343
- parser.add_argument('--prefix_length', type=int, default=10)
344
- parser.add_argument('--prefix_length_clip', type=int, default=10)
345
- parser.add_argument('--bs', type=int, default=40)
346
- parser.add_argument('--only_prefix', dest='only_prefix', action='store_true')
347
- parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer')
348
- parser.add_argument('--num_layers', type=int, default=8)
349
- parser.add_argument('--is_rn', dest='is_rn', action='store_true')
350
- parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true')
351
- # DINO-specific arguments
352
- parser.add_argument('--use_dino', action='store_true', default=False, help='Use DINO features instead of CLIP')
353
- parser.add_argument('--dino_model_type', type=str, default='dinov2_vitb14',
354
- choices=['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'],
355
- help='DINO model type')
356
- parser.add_argument('--dino_feature_dim', type=int, default=None,
357
- help='DINO feature dimension (auto-detected if None)')
358
- parser.add_argument('--device', type=str, default='cuda:0', help='Device to use for training')
359
- args = parser.parse_args()
360
-
361
- if isinstance(args.device, str):
362
- if not args.device.startswith('cuda') and not args.device.startswith('cpu'):
363
- # if it is an integer index, convert to f'cuda:{args.device}'
364
- if args.device.isdigit():
365
- args.device = f'cuda:{args.device}'
366
- else:
367
- raise ValueError(f"Invalid device string: {args.device}")
368
- args.device = torch.device(args.device)
369
-
370
- prefix_length = args.prefix_length
371
- dataset = ClipCocoDataset(args.data, prefix_length, normalize_prefix=args.normalize_prefix)
372
-
373
- # Determine prefix dimension based on model type
374
- if args.use_dino:
375
- if args.dino_feature_dim is not None:
376
- prefix_dim = args.dino_feature_dim
377
- else:
378
- # Auto-detect DINO feature dimensions
379
- dino_dims = {
380
- 'dinov2_vits14': 384,
381
- 'dinov2_vitb14': 768,
382
- 'dinov2_vitl14': 1024,
383
- 'dinov2_vitg14': 1536
384
- }
385
- prefix_dim = dino_dims.get(args.dino_model_type, 768)
386
- print(f"Using DINO features with dimension: {prefix_dim}")
387
- else:
388
- prefix_dim = 640 if args.is_rn else 512
389
- print(f"Using CLIP features with dimension: {prefix_dim}")
390
-
391
- args.mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}[args.mapping_type]
392
- if args.only_prefix:
393
- model = ClipCaptionPrefix(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
394
- num_layers=args.num_layers, mapping_type=args.mapping_type)
395
- print("Train only prefix")
396
- else:
397
- model = ClipCaptionModel(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
398
- num_layers=args.num_layers, mapping_type=args.mapping_type)
399
- print("Train both prefix and GPT")
400
- sys.stdout.flush()
401
- train(dataset, model, args, output_dir=args.out_dir, output_prefix=args.prefix, device=args.device)
402
-
403
-
404
- if __name__ == '__main__':
405
- main()
src/clipcap/clipcap_dino_parse_coco.py DELETED
@@ -1,613 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import skimage.io as io
4
- from PIL import Image
5
- import pickle
6
- import json
7
- import os
8
- from tqdm import tqdm
9
- import argparse
10
- import torchvision.transforms as T
11
- import numpy as np
12
- import yaml
13
- import clip
14
- import sys
15
-
16
- # Add the src directory to the path so we can import ProjectionLayer
17
- sys.path.append(os.path.join(os.path.dirname(__file__), '../..', 'src'))
18
-
19
-
20
- # Container to store intermediate outputs for feature extraction
21
- feats = {}
22
-
23
- def get_self_attention(module, input, output):
24
- """Hook to capture self-attention weights"""
25
- global qkv_attention_out
26
- qkv_attention_out = output
27
-
28
- def get_layer_n_output(module, input, output):
29
- """Hook to capture intermediate layer output"""
30
- feats['intermediate_output'] = output
31
-
32
- def transform_to_standard_dino_out(x, model, num_global_tokens=1):
33
- """Transform raw DINO output to standardized format"""
34
- x_norm = model.norm(x)
35
- if num_global_tokens == 1:
36
- # Standard model without registers
37
- return {
38
- "x_norm_clstoken": x_norm[:, 0],
39
- "x_norm_regtokens": None,
40
- "x_norm_patchtokens": x_norm[:, 1:],
41
- "x_prenorm": x,
42
- }
43
- else:
44
- # Model with registers (num_global_tokens = 5)
45
- return {
46
- "x_norm_clstoken": x_norm[:, 0],
47
- "x_norm_regtokens": x_norm[:, 1:num_global_tokens],
48
- "x_norm_patchtokens": x_norm[:, num_global_tokens:],
49
- "x_prenorm": x,
50
- }
51
-
52
- def process_self_attention(output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
53
- """Process self-attention output to compute attention weights"""
54
- qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
55
- q, k, v = qkv[0] * scale, qkv[1], qkv[2]
56
- attn = q @ k.transpose(-2, -1)
57
- self_attn_maps = attn[:, :, 0, num_global_tokens:] # CLS token attention to patches
58
- self_attn = self_attn_maps.mean(dim=1) # Average over attention heads
59
- self_attn = self_attn.softmax(dim=-1)
60
- if ret_self_attn_maps:
61
- return self_attn, self_attn_maps
62
- else:
63
- return self_attn
64
-
65
-
66
- # Global variables to store hook outputs
67
- dino_layer_n_output = None
68
- qkv_attention_out = None
69
-
70
- def get_layer_n_output(module, input, output):
71
- """Hook to capture intermediate layer output"""
72
- global dino_layer_n_output
73
- dino_layer_n_output = output
74
-
75
-
76
- def select_most_significant_patch(dino_outs, self_attn, criteria, cls_token=None, caption_embedding=None):
77
- """
78
- Select the most significant patch token based on different criteria.
79
-
80
- Args:
81
- dino_outs: Dictionary containing normalized DINO outputs
82
- self_attn: Self-attention weights from CLS to patches [batch_size, num_patches]
83
- criteria: Selection criteria ('max_attention', 'most_similar_to_cls', etc.)
84
- cls_token: CLS token embeddings [batch_size, embed_dim]
85
- caption_embedding: Text caption embeddings [batch_size, embed_dim]
86
-
87
- Returns:
88
- selected_patches: [batch_size, embed_dim] - Selected patch embeddings
89
- """
90
- patch_tokens = dino_outs['x_norm_patchtokens'] # [batch_size, num_patches, embed_dim]
91
- batch_size, num_patches, embed_dim = patch_tokens.shape
92
-
93
- if criteria == "max_attention":
94
- # Select patch with highest attention weight from CLS token
95
- if self_attn is None:
96
- raise ValueError("self_attn required for max_attention criteria")
97
- max_attn_indices = self_attn.argmax(dim=1) # [batch_size]
98
- selected_patches = patch_tokens[torch.arange(batch_size), max_attn_indices]
99
-
100
- elif criteria == "most_similar_to_cls":
101
- # Select patch most similar to CLS token using cosine similarity
102
- if cls_token is None:
103
- raise ValueError("cls_token required for most_similar_to_cls criteria")
104
- # Compute cosine similarity between CLS and all patches
105
- cls_normalized = F.normalize(cls_token, p=2, dim=1) # [batch_size, embed_dim]
106
- patches_normalized = F.normalize(patch_tokens, p=2, dim=2) # [batch_size, num_patches, embed_dim]
107
- similarities = torch.bmm(patches_normalized, cls_normalized.unsqueeze(2)).squeeze(2) # [batch_size, num_patches]
108
- max_sim_indices = similarities.argmax(dim=1) # [batch_size]
109
- selected_patches = patch_tokens[torch.arange(batch_size), max_sim_indices]
110
-
111
- elif criteria == "most_similar_to_caption":
112
- # Select patch most similar to caption embedding
113
- if caption_embedding is None:
114
- raise ValueError("caption_embedding required for most_similar_to_caption criteria")
115
- caption_normalized = F.normalize(caption_embedding, p=2, dim=1) # [batch_size, embed_dim]
116
- patches_normalized = F.normalize(patch_tokens, p=2, dim=2) # [batch_size, num_patches, embed_dim]
117
- similarities = torch.bmm(patches_normalized, caption_normalized.unsqueeze(2)).squeeze(2) # [batch_size, num_patches]
118
- max_sim_indices = similarities.argmax(dim=1) # [batch_size]
119
- selected_patches = patch_tokens[torch.arange(batch_size), max_sim_indices]
120
-
121
- elif criteria == "max_norm":
122
- # Select patch with highest L2 norm
123
- patch_norms = torch.norm(patch_tokens, p=2, dim=2) # [batch_size, num_patches]
124
- max_norm_indices = patch_norms.argmax(dim=1) # [batch_size]
125
- selected_patches = patch_tokens[torch.arange(batch_size), max_norm_indices]
126
-
127
- elif criteria == "centroid_distance":
128
- # Select patch farthest from the centroid of all patches
129
- centroid = patch_tokens.mean(dim=1, keepdim=True) # [batch_size, 1, embed_dim]
130
- distances = torch.norm(patch_tokens - centroid, p=2, dim=2) # [batch_size, num_patches]
131
- max_dist_indices = distances.argmax(dim=1) # [batch_size]
132
- selected_patches = patch_tokens[torch.arange(batch_size), max_dist_indices]
133
-
134
- else:
135
- raise ValueError(f"Unknown patch selection criteria: {criteria}")
136
-
137
- return selected_patches
138
-
139
-
140
- def load_text_encoder(text_encoder_path, device, config_path=None):
141
- """
142
- Load a text encoder model for caption similarity.
143
- Supports Talk2Dino, CLIP, and DINO.txt-based text encoders.
144
- """
145
- if text_encoder_path is None:
146
- return None
147
-
148
- print(f"Loading text encoder from: {text_encoder_path}")
149
-
150
- # Check for DINO.txt model
151
- if text_encoder_path.lower() == 'dinotxt' or text_encoder_path.lower() == 'dino.txt':
152
- # Load DINO.txt model
153
- try:
154
- from src.dinotxt_utils import get_tokenizer
155
-
156
- print("Loading DINO.txt model...")
157
- dinotxt_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l')
158
- dinotxt_model.eval()
159
- dinotxt_model.to(device)
160
-
161
- tokenizer = get_tokenizer()
162
-
163
- return {
164
- 'type': 'dinotxt',
165
- 'model': dinotxt_model,
166
- 'tokenizer': tokenizer
167
- }
168
-
169
- except ImportError:
170
- raise ImportError("Could not import dinotxt_utils. Make sure src/dinotxt_utils.py is accessible.")
171
- except Exception as e:
172
- raise RuntimeError(f"Failed to load DINO.txt model: {e}")
173
-
174
- # Check if it's a Talk2Dino model (expect config and weights)
175
- elif text_encoder_path.endswith('.pth') or text_encoder_path.endswith('.pt'):
176
- # Use provided config or auto-find
177
- if config_path is None:
178
- # Look for corresponding config file
179
- base_path = text_encoder_path.rsplit('.', 1)[0]
180
- config_path = base_path + '.yaml'
181
-
182
- # Alternative config path patterns
183
- if not os.path.exists(config_path):
184
- # Try configs_talk2dino directory
185
- config_name = os.path.basename(base_path) + '.yaml'
186
- config_path = os.path.join(os.path.dirname(__file__), 'configs_talk2dino', config_name)
187
-
188
- if not os.path.exists(config_path):
189
- raise FileNotFoundError(f"Could not find config file for {text_encoder_path}. "
190
- f"Expected at {config_path} or specify --text_encoder_config.")
191
-
192
- # Load Talk2Dino model
193
- try:
194
- from src.model import ProjectionLayer
195
-
196
- print(f"Using config: {config_path}")
197
-
198
- # Load the projection layer
199
- talk2dino = ProjectionLayer.from_config(config_path)
200
- talk2dino.load_state_dict(torch.load(text_encoder_path, map_location=device))
201
- talk2dino.to(device)
202
- talk2dino.eval()
203
-
204
- # Load CLIP model for text encoding
205
- clip_model, _ = clip.load("ViT-B/32", device=device)
206
- clip_model.eval()
207
-
208
- return {
209
- 'type': 'talk2dino',
210
- 'talk2dino': talk2dino,
211
- 'clip_model': clip_model,
212
- 'config_path': config_path
213
- }
214
-
215
- except ImportError:
216
- raise ImportError("Could not import ProjectionLayer. Make sure src/model.py is accessible.")
217
-
218
- else:
219
- # Assume it's a direct model path (CLIP or other)
220
- try:
221
- # Try loading as a CLIP model
222
- clip_model, _ = clip.load(text_encoder_path, device=device)
223
- clip_model.eval()
224
-
225
- return {
226
- 'type': 'clip',
227
- 'clip_model': clip_model
228
- }
229
- except:
230
- raise ValueError(f"Could not load text encoder from {text_encoder_path}. "
231
- f"Supported formats: 1) 'dinotxt' or 'dino.txt' for DINO.txt model, "
232
- f"2) Talk2Dino (.pth/.pt), 3) CLIP model names.")
233
-
234
-
235
- def encode_caption(caption, text_encoder, device):
236
- """
237
- Encode a text caption using the loaded text encoder.
238
- """
239
- if text_encoder is None:
240
- return None
241
-
242
- if text_encoder['type'] == 'dinotxt':
243
- # Use DINO.txt pipeline: tokenize + encode + extract patch-aligned features
244
- with torch.no_grad():
245
- # Tokenize with DINO.txt tokenizer
246
- text_tokens = text_encoder['tokenizer'].tokenize([caption]).to(device)
247
-
248
- # Encode with DINO.txt model
249
- dinotxt_features = text_encoder['model'].encode_text(text_tokens)
250
-
251
- # Extract patch-aligned text embeddings (dimensions 1024:)
252
- # DINO.txt concatenates standard text features [0:1024] and patch-aligned features [1024:]
253
- patch_aligned_features = dinotxt_features[:, 1024:]
254
-
255
- # Normalize the features to match DINO feature space
256
- patch_aligned_features = F.normalize(patch_aligned_features, p=2, dim=-1)
257
- return patch_aligned_features
258
-
259
- elif text_encoder['type'] == 'talk2dino':
260
- # Use Talk2Dino pipeline: CLIP text encoding + Talk2Dino projection
261
- with torch.no_grad():
262
- # Tokenize and encode with CLIP
263
- text_tokens = clip.tokenize([caption]).to(device)
264
- clip_text_features = text_encoder['clip_model'].encode_text(text_tokens)
265
-
266
- # Project through Talk2Dino to DINO space
267
- dino_text_features = text_encoder['talk2dino'].project_clip_txt(clip_text_features)
268
-
269
- # Normalize the encoded text to match DINO feature space
270
- dino_text_features = F.normalize(dino_text_features, p=2, dim=-1)
271
- return dino_text_features
272
-
273
- elif text_encoder['type'] == 'clip':
274
- # Use CLIP directly
275
- with torch.no_grad():
276
- text_tokens = clip.tokenize([caption]).to(device)
277
- clip_text_features = text_encoder['clip_model'].encode_text(text_tokens)
278
-
279
- # Normalize the features
280
- clip_text_features = F.normalize(clip_text_features, p=2, dim=-1)
281
- return clip_text_features
282
-
283
- else:
284
- raise ValueError(f"Unknown text encoder type: {text_encoder['type']}")
285
-
286
-
287
- def main(dino_model_type: str, resize_dim: int = 518, crop_dim: int = 518,
288
- coco_images_dir: str = "/raid/datasets/coco/", captions_file: str = "/raid/datasets/coco/train_split_karpathy.json",
289
- output_file: str = None, feature_type: str = "cls", extract_attention: bool = False,
290
- patch_selection_criteria: str = "max_attention", text_encoder_path: str = None, text_encoder_config: str = None):
291
- """
292
- Extract DINO features from COCO images for ClipCap training.
293
-
294
- Args:
295
- feature_type: Type of features to extract
296
- - "cls": CLS token features (default)
297
- - "avg_patch": Mean pooled patch token features
298
- - "avg_self_attn": Self-attention weighted patch token features
299
- - "most_significant_patch": Single most important patch token
300
- extract_attention: Whether to extract self-attention weights (required for avg_self_attn)
301
- patch_selection_criteria: Criteria for selecting most significant patch
302
- text_encoder_path: Path to text encoder for caption similarity
303
- """
304
- device = torch.device('cuda:0')
305
- dino_model_name = dino_model_type.replace('/', '_')
306
-
307
- # Determine model properties
308
- num_global_tokens = 1 if "reg" not in dino_model_type else 5
309
- patch_size = 14 # DINOv2 uses 14x14 patches
310
- num_patch_tokens = (crop_dim // patch_size) * (crop_dim // patch_size)
311
- num_tokens = num_global_tokens + num_patch_tokens
312
-
313
- # Get embedding dimension based on model type
314
- if 'vitl' in dino_model_type:
315
- embed_dim = 1024
316
- num_attn_heads = 16
317
- elif 'vitb' in dino_model_type:
318
- embed_dim = 768
319
- num_attn_heads = 12
320
- elif 'vits' in dino_model_type:
321
- embed_dim = 384
322
- num_attn_heads = 6
323
- elif 'vitg' in dino_model_type:
324
- embed_dim = 1536
325
- num_attn_heads = 24
326
- else:
327
- raise ValueError(f"Unknown model type: {dino_model_type}")
328
-
329
- scale = (embed_dim // num_attn_heads) ** -0.5
330
-
331
- # Set default output path if not specified
332
- if output_file is None:
333
- if feature_type == "cls":
334
- feature_suffix = ""
335
- elif feature_type == "most_significant_patch":
336
- if patch_selection_criteria == "most_similar_to_caption" and text_encoder_path is not None:
337
- # Determine text encoder type from path to create unique filename
338
- if text_encoder_path.lower() in ['dinotxt', 'dino.txt']:
339
- text_encoder_suffix = "_dinotxt"
340
- elif text_encoder_path.endswith('.pth') or text_encoder_path.endswith('.pt'):
341
- text_encoder_suffix = "_t2d" # Talk2Dino
342
- else:
343
- # CLIP or other models
344
- text_encoder_suffix = "_clip"
345
- feature_suffix = f"_{feature_type}_{patch_selection_criteria}{text_encoder_suffix}"
346
- else:
347
- feature_suffix = f"_{feature_type}_{patch_selection_criteria}"
348
- else:
349
- feature_suffix = f"_{feature_type}"
350
- output_file = f"/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_{dino_model_name}{feature_suffix}_train.pkl"
351
-
352
- # Create output directory if it doesn't exist
353
- os.makedirs(os.path.dirname(output_file), exist_ok=True)
354
-
355
- # Load DINO model
356
- print(f"Loading DINO model: {dino_model_type}")
357
- print(f"Feature type: {feature_type}")
358
- if feature_type == "most_significant_patch":
359
- print(f"Patch selection criteria: {patch_selection_criteria}")
360
- print(f"Model properties: embed_dim={embed_dim}, num_heads={num_attn_heads}, num_global_tokens={num_global_tokens}")
361
-
362
- if 'dinov2' in dino_model_type:
363
- model_family = 'facebookresearch/dinov2'
364
- dino_model = torch.hub.load(model_family, dino_model_type)
365
- else:
366
- raise ValueError(f"Unsupported DINO model type: {dino_model_type}")
367
-
368
- # Setup transforms for DINO
369
- image_transforms = T.Compose([
370
- T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC),
371
- T.CenterCrop(crop_dim),
372
- T.ToTensor(),
373
- T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
374
- ])
375
-
376
- dino_model.eval()
377
- dino_model.to(device)
378
-
379
- # Register hooks if we need attention or intermediate outputs
380
- if feature_type == "avg_self_attn" or extract_attention or \
381
- (feature_type == "most_significant_patch" and patch_selection_criteria in ["max_attention", "most_similar_to_caption"]):
382
- print("Registering hooks for attention extraction...")
383
- dino_model.blocks[-1].attn.qkv.register_forward_hook(get_self_attention)
384
-
385
- if feature_type in ["avg_patch", "avg_self_attn", "most_significant_patch"]:
386
- print("Registering hooks for intermediate output extraction...")
387
- dino_model.blocks[-1].register_forward_hook(get_layer_n_output)
388
-
389
- # Load caption data
390
- print(f"Loading captions from: {captions_file}")
391
- with open(captions_file, 'r') as f:
392
- data = json.load(f)
393
-
394
- # Handle different annotation formats
395
- if isinstance(data, list):
396
- # Original ClipCap format: list of dicts with 'image_id' and 'caption'
397
- annotations = data
398
- print(f"{len(annotations)} captions loaded from json (ClipCap format)")
399
- elif isinstance(data, dict) and 'annotations' in data:
400
- # Karpathy format: dict with 'annotations' key
401
- annotations = data['annotations']
402
- print(f"{len(annotations)} captions loaded from json (Karpathy format)")
403
-
404
- # Create image ID to filename mapping for faster lookup
405
- if 'images' in data:
406
- image_id_to_filename = {img['id']: img['file_name'] for img in data['images']}
407
- else:
408
- image_id_to_filename = {}
409
- else:
410
- raise ValueError("Unsupported annotation format")
411
-
412
- # Load text encoder if needed for caption similarity
413
- text_encoder = None
414
- if feature_type == "most_significant_patch" and patch_selection_criteria == "most_similar_to_caption":
415
- if text_encoder_path is None:
416
- raise ValueError("text_encoder_path required for most_similar_to_caption criteria")
417
- text_encoder = load_text_encoder(text_encoder_path, device, text_encoder_config)
418
- print(f"Loaded text encoder from: {text_encoder_path}")
419
-
420
- all_embeddings = []
421
- all_captions = []
422
-
423
- print(f"Processing images from: {coco_images_dir}")
424
- print(f"Output will be saved to: {output_file}")
425
-
426
- for i, annotation in enumerate(tqdm(annotations)):
427
- img_id = annotation["image_id"]
428
-
429
- # Determine filename based on format
430
- if isinstance(data, list):
431
- # Original format: construct filename from image_id
432
- filename = os.path.join(coco_images_dir, "train2014", f"COCO_train2014_{int(img_id):012d}.jpg")
433
- if not os.path.isfile(filename):
434
- filename = os.path.join(coco_images_dir, "val2014", f"COCO_val2014_{int(img_id):012d}.jpg")
435
- else:
436
- # Karpathy format: use filename from images mapping or construct it
437
- if img_id in image_id_to_filename:
438
- if 'train' in image_id_to_filename[img_id]:
439
- fold = "train2014"
440
- else:
441
- fold = "val2014"
442
- filename = os.path.join(coco_images_dir, fold, image_id_to_filename[img_id])
443
- else:
444
- # Fallback: try to construct filename
445
- filename = os.path.join(coco_images_dir, "train2014", f"COCO_train2014_{int(img_id):012d}.jpg")
446
- if not os.path.isfile(filename):
447
- filename = os.path.join(coco_images_dir, "val2014", f"COCO_val2014_{int(img_id):012d}.jpg")
448
-
449
- if not os.path.isfile(filename):
450
- print(f"Warning: Image not found: {filename}")
451
- continue
452
-
453
- # Load and process image
454
- try:
455
- image = io.imread(filename)
456
- if len(image.shape) == 2: # grayscale
457
- image = Image.fromarray(image).convert('RGB')
458
- else:
459
- image = Image.fromarray(image)
460
- except Exception as e:
461
- print(f"Warning: Failed to load image {filename}: {e}")
462
- continue
463
-
464
- # Apply DINO transforms
465
- image_tensor = image_transforms(image).unsqueeze(0).to(device)
466
-
467
- with torch.no_grad():
468
- # Clear any previous stored data
469
- global dino_layer_n_output, qkv_attention_out
470
- dino_layer_n_output = None
471
- qkv_attention_out = None
472
-
473
- # Extract DINO features
474
- if feature_type == "cls":
475
- # Standard CLS token extraction
476
- features = dino_model(image_tensor)
477
- # For DINOv2, the output is the CLS token by default
478
- if len(features.shape) == 3: # If we get [batch, seq_len, dim]
479
- features = features[:, 0, :] # Take CLS token
480
- prefix = features.cpu()
481
- else:
482
- # For patch-based features, we need intermediate outputs
483
- _ = dino_model(image_tensor) # Forward pass to trigger hooks
484
-
485
- if dino_layer_n_output is None:
486
- raise RuntimeError("No intermediate output captured. Check hook registration.")
487
-
488
- # Transform to standard format
489
- dino_outs = transform_to_standard_dino_out(dino_layer_n_output, dino_model, num_global_tokens)
490
-
491
- if feature_type == "avg_patch":
492
- # Average of patch tokens (excluding global tokens)
493
- prefix = dino_outs['x_norm_patchtokens'].mean(dim=1) # [B, D]
494
- elif feature_type == "avg_self_attn":
495
- # Self-attention weighted average of patch tokens
496
- if qkv_attention_out is None:
497
- raise RuntimeError("No attention output captured. Check hook registration.")
498
-
499
- # Process self-attention to get attention weights
500
- batch_size = qkv_attention_out.shape[0]
501
- self_attn = process_self_attention(
502
- qkv_attention_out,
503
- batch_size,
504
- num_tokens,
505
- num_attn_heads,
506
- embed_dim,
507
- scale,
508
- num_global_tokens
509
- )
510
-
511
- # Compute attention-weighted average
512
- prefix = (self_attn.unsqueeze(-1) * dino_outs['x_norm_patchtokens']).mean(dim=1)
513
- elif feature_type == "most_significant_patch":
514
- # Select single most significant patch based on criteria
515
- self_attn = None
516
- cls_token = None
517
- caption_embedding = None
518
-
519
- # Prepare required inputs based on criteria
520
- if patch_selection_criteria in ["max_attention", "most_similar_to_caption"]:
521
- if qkv_attention_out is None:
522
- raise RuntimeError("No attention output captured. Check hook registration.")
523
- batch_size = qkv_attention_out.shape[0]
524
- self_attn = process_self_attention(
525
- qkv_attention_out,
526
- batch_size,
527
- num_tokens,
528
- num_attn_heads,
529
- embed_dim,
530
- scale,
531
- num_global_tokens
532
- )
533
-
534
- if patch_selection_criteria == "most_similar_to_cls":
535
- cls_token = dino_outs['x_norm_clstoken']
536
-
537
- if patch_selection_criteria == "most_similar_to_caption":
538
- if text_encoder is not None:
539
- caption_embedding = encode_caption(annotation["caption"], text_encoder, device)
540
-
541
- # Select the most significant patch
542
- prefix = select_most_significant_patch(
543
- dino_outs,
544
- self_attn,
545
- patch_selection_criteria,
546
- cls_token=cls_token,
547
- caption_embedding=caption_embedding
548
- )
549
- else:
550
- raise ValueError(f"Unknown feature type: {feature_type}")
551
-
552
- prefix = prefix.cpu()
553
-
554
- # Create annotation in ClipCap format for compatibility
555
- caption_entry = {
556
- "image_id": img_id,
557
- "caption": annotation["caption"],
558
- "clip_embedding": i # Index for the embedding
559
- }
560
-
561
- all_embeddings.append(prefix)
562
- all_captions.append(caption_entry)
563
-
564
- if (i + 1) % 10000 == 0:
565
- # Create output directory if it doesn't exist
566
- os.makedirs(os.path.dirname(output_file), exist_ok=True)
567
- with open(output_file, 'wb') as f:
568
- pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)
569
-
570
- # Create output directory if it doesn't exist
571
- os.makedirs(os.path.dirname(output_file), exist_ok=True)
572
- with open(output_file, 'wb') as f:
573
- pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)
574
-
575
- print('Done')
576
- print("%0d embeddings saved " % len(all_embeddings))
577
- print(f"Feature dimension: {all_embeddings[0].shape[-1]}")
578
- return 0
579
-
580
-
581
- if __name__ == '__main__':
582
- parser = argparse.ArgumentParser(description='Extract DINO features from COCO images for ClipCap training')
583
- parser.add_argument('--dino_model_type', default="dinov2_vitb14",
584
- choices=('dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14',
585
- 'dinov2_vits14_reg', 'dinov2_vitb14_reg', 'dinov2_vitl14_reg', 'dinov2_vitg14_reg'),
586
- help='DINO model type to use for feature extraction')
587
- parser.add_argument('--feature_type', default="cls",
588
- choices=('cls', 'avg_patch', 'avg_self_attn', 'most_significant_patch'),
589
- help='Type of features to extract: cls (CLS token), avg_patch (mean pooled patches), avg_self_attn (attention-weighted patches), most_significant_patch (single most important patch)')
590
- parser.add_argument('--patch_selection_criteria', default="max_attention",
591
- choices=('max_attention', 'most_similar_to_cls', 'most_similar_to_caption', 'max_norm', 'centroid_distance'),
592
- help='Criteria for selecting the most significant patch (only used with most_significant_patch feature_type)')
593
- parser.add_argument('--text_encoder_path', type=str, default=None,
594
- help='Path to text encoder for caption similarity. Supports: 1) "dinotxt" or "dino.txt" for DINO.txt model, 2) Talk2Dino weights (.pth/.pt) - will auto-find config, 3) CLIP model names (e.g., "ViT-B/32")')
595
- parser.add_argument('--text_encoder_config', type=str, default=None,
596
- help='Optional: explicit config path for Talk2Dino models (if not auto-found)')
597
- parser.add_argument('--resize_dim', type=int, default=518, help='Resize dimension for images')
598
- parser.add_argument('--crop_dim', type=int, default=518, help='Crop dimension for images')
599
- parser.add_argument('--coco_images_dir', type=str, default="/raid/datasets/coco",
600
- help='Path to COCO images directory (should contain train2014/ and val2014/ subdirs)')
601
- parser.add_argument('--captions_file', type=str, default="/raid/datasets/coco/train_split_karpathy.json",
602
- help='Path to COCO captions JSON file (supports both Karpathy and ClipCap formats)')
603
- parser.add_argument('--output_file', type=str, default=None,
604
- help='Output pickle file path (default: auto-generated based on model and feature type)')
605
- parser.add_argument('--extract_attention', action='store_true',
606
- help='Extract attention weights (automatically enabled for avg_self_attn feature type)')
607
-
608
- args = parser.parse_args()
609
-
610
- main(args.dino_model_type, args.resize_dim, args.crop_dim,
611
- args.coco_images_dir, args.captions_file, args.output_file,
612
- args.feature_type, args.extract_attention,
613
- args.patch_selection_criteria, args.text_encoder_path, args.text_encoder_config)
src/clipcap/clipcap_parse_coco.py DELETED
@@ -1,51 +0,0 @@
1
- import torch
2
- import skimage.io as io
3
- import clip
4
- from PIL import Image
5
- import pickle
6
- import json
7
- import os
8
- from tqdm import tqdm
9
- import argparse
10
-
11
-
12
- def main(clip_model_type: str):
13
- device = torch.device('cuda:0')
14
- clip_model_name = clip_model_type.replace('/', '_')
15
- out_path = f"./data/coco/oscar_split_{clip_model_name}_train.pkl"
16
- clip_model, preprocess = clip.load(clip_model_type, device=device, jit=False)
17
- with open('./data/coco/annotations/train_caption.json', 'r') as f:
18
- data = json.load(f)
19
- print("%0d captions loaded from json " % len(data))
20
- all_embeddings = []
21
- all_captions = []
22
- for i in tqdm(range(len(data))):
23
- d = data[i]
24
- img_id = d["image_id"]
25
- filename = f"./data/coco/train2014/COCO_train2014_{int(img_id):012d}.jpg"
26
- if not os.path.isfile(filename):
27
- filename = f"./data/coco/val2014/COCO_val2014_{int(img_id):012d}.jpg"
28
- image = io.imread(filename)
29
- image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device)
30
- with torch.no_grad():
31
- prefix = clip_model.encode_image(image).cpu()
32
- d["clip_embedding"] = i
33
- all_embeddings.append(prefix)
34
- all_captions.append(d)
35
- if (i + 1) % 10000 == 0:
36
- with open(out_path, 'wb') as f:
37
- pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)
38
-
39
- with open(out_path, 'wb') as f:
40
- pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)
41
-
42
- print('Done')
43
- print("%0d embeddings saved " % len(all_embeddings))
44
- return 0
45
-
46
-
47
- if __name__ == '__main__':
48
- parser = argparse.ArgumentParser()
49
- parser.add_argument('--clip_model_type', default="ViT-B/32", choices=('RN50', 'RN101', 'RN50x4', 'ViT-B/32'))
50
- args = parser.parse_args()
51
- exit(main(args.clip_model_type))
src/clipcap/entrypoint.py DELETED
@@ -1,564 +0,0 @@
1
- import torch
2
- from torch import nn
3
- import json
4
- import os
5
- from transformers import GPT2Tokenizer, GPT2LMHeadModel
6
- from typing import List, Optional, Tuple, Union
7
- from argparse import Namespace
8
- from enum import Enum
9
-
10
- import torch.nn.functional as nnf
11
-
12
- class MappingType(Enum):
13
- MLP = 'mlp'
14
- Transformer = 'transformer'
15
-
16
-
17
- class MLP(nn.Module):
18
- def forward(self, x: torch.Tensor) -> torch.Tensor:
19
- return self.model(x)
20
-
21
- def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
22
- super(MLP, self).__init__()
23
- layers = []
24
- for i in range(len(sizes) - 1):
25
- layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
26
- if i < len(sizes) - 2:
27
- layers.append(act())
28
- self.model = nn.Sequential(*layers)
29
-
30
-
31
- class MlpTransformer(nn.Module):
32
- def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
33
- super().__init__()
34
- out_d = out_d if out_d is not None else in_dim
35
- self.fc1 = nn.Linear(in_dim, h_dim)
36
- self.act = act
37
- self.fc2 = nn.Linear(h_dim, out_d)
38
- self.dropout = nn.Dropout(dropout)
39
-
40
- def forward(self, x):
41
- x = self.fc1(x)
42
- x = self.act(x)
43
- x = self.dropout(x)
44
- x = self.fc2(x)
45
- x = self.dropout(x)
46
- return x
47
-
48
- class MultiHeadAttention(nn.Module):
49
-
50
- def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
51
- super().__init__()
52
- self.num_heads = num_heads
53
- head_dim = dim_self // num_heads
54
- self.scale = head_dim ** -0.5
55
- self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
56
- self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
57
- self.project = nn.Linear(dim_self, dim_self)
58
- self.dropout = nn.Dropout(dropout)
59
-
60
- def forward(self, x, y=None, mask=None):
61
- y = y if y is not None else x
62
- b, n, c = x.shape
63
- _, m, d = y.shape
64
- # b n h dh
65
- queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
66
- # b m 2 h dh
67
- keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
68
- keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
69
- attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
70
- if mask is not None:
71
- if mask.dim() == 2:
72
- mask = mask.unsqueeze(1)
73
- attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
74
- attention = attention.softmax(dim=2)
75
- out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
76
- out = self.project(out)
77
- return out, attention
78
-
79
-
80
- class TransformerLayer(nn.Module):
81
-
82
- def forward_with_attention(self, x, y=None, mask=None):
83
- x_, attention = self.attn(self.norm1(x), y, mask)
84
- x = x + x_
85
- x = x + self.mlp(self.norm2(x))
86
- return x, attention
87
-
88
- def forward(self, x, y=None, mask=None):
89
- x = x + self.attn(self.norm1(x), y, mask)[0]
90
- x = x + self.mlp(self.norm2(x))
91
- return x
92
-
93
- def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
94
- norm_layer: nn.Module = nn.LayerNorm):
95
- super().__init__()
96
- self.norm1 = norm_layer(dim_self)
97
- self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
98
- self.norm2 = norm_layer(dim_self)
99
- self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
100
-
101
-
102
- class Transformer(nn.Module):
103
-
104
- def forward_with_attention(self, x, y=None, mask=None):
105
- attentions = []
106
- for layer in self.layers:
107
- x, att = layer.forward_with_attention(x, y, mask)
108
- attentions.append(att)
109
- return x, attentions
110
-
111
- def forward(self, x, y=None, mask=None):
112
- for i, layer in enumerate(self.layers):
113
- if i % 2 == 0 and self.enc_dec: # cross
114
- x = layer(x, y)
115
- elif self.enc_dec: # self
116
- x = layer(x, x, mask)
117
- else: # self or cross
118
- x = layer(x, y, mask)
119
- return x
120
-
121
- def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
122
- mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
123
- super(Transformer, self).__init__()
124
- dim_ref = dim_ref if dim_ref is not None else dim_self
125
- self.enc_dec = enc_dec
126
- if enc_dec:
127
- num_layers = num_layers * 2
128
- layers = []
129
- for i in range(num_layers):
130
- if i % 2 == 0 and enc_dec: # cross
131
- layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
132
- elif enc_dec: # self
133
- layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
134
- else: # self or cross
135
- layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
136
- self.layers = nn.ModuleList(layers)
137
-
138
- class TransformerMapper(nn.Module):
139
- def forward(self, x):
140
- x = self.linear(x).view(x.shape[0], self.clip_length, -1)
141
- prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
142
- prefix = torch.cat((x, prefix), dim=1)
143
- out = self.transformer(prefix)[:, self.clip_length:]
144
- return out
145
-
146
- def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
147
- super(TransformerMapper, self).__init__()
148
- self.clip_length = clip_length
149
- self.transformer = Transformer(dim_embedding, 8, num_layers) #nn.Transformer(d_model=dim_embedding, nhead=8, num_encoder_layers=num_layers)
150
- self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
151
- self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
152
-
153
-
154
- class ClipCaptionModel(nn.Module):
155
-
156
- def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
157
- return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
158
-
159
- def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
160
- labels: Optional[torch.Tensor] = None):
161
- embedding_text = self.gpt.transformer.wte(tokens)
162
- prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
163
- embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
164
- if labels is not None:
165
- dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
166
- labels = torch.cat((dummy_token, tokens), dim=1)
167
- out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
168
- return out
169
-
170
- def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
171
- num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
172
- super(ClipCaptionModel, self).__init__()
173
- self.prefix_length = prefix_length
174
- self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
175
- self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
176
- if mapping_type == MappingType.MLP:
177
- self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
178
- self.gpt_embedding_size * prefix_length))
179
- else:
180
- self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
181
- clip_length, num_layers)
182
-
183
-
184
- class ClipCaptionPrefix(ClipCaptionModel):
185
-
186
- def parameters(self, recurse: bool = True):
187
- return self.clip_project.parameters()
188
-
189
- def train(self, mode: bool = True):
190
- super(ClipCaptionPrefix, self).train(mode)
191
- self.gpt.eval()
192
- return self
193
-
194
-
195
- def generate_batched(
196
- model,
197
- tokenizer,
198
- prefix_embeds,
199
- entry_length=67,
200
- top_p=0.8,
201
- temperature=1.0,
202
- stop_token: str = '.',
203
- ):
204
- """
205
- Batched text generation for ClipCap models.
206
-
207
- Args:
208
- model: ClipCap model
209
- tokenizer: GPT2 tokenizer
210
- prefix_embeds: (batch_size, prefix_length, embedding_dim) - prefix embeddings
211
- entry_length: Maximum sequence length to generate
212
- top_p: Nucleus sampling parameter
213
- temperature: Sampling temperature
214
- stop_token: Token to stop generation
215
-
216
- Returns:
217
- List[str]: Generated captions for each item in batch
218
- """
219
- model.eval()
220
- device = next(model.parameters()).device
221
- batch_size = prefix_embeds.shape[0]
222
-
223
- # Initialize
224
- stop_token_index = tokenizer.encode(stop_token)[0]
225
- filter_value = -float("Inf")
226
-
227
- # Track which sequences are still generating
228
- active_sequences = torch.ones(batch_size, dtype=torch.bool, device=device)
229
-
230
- # Initialize token sequences - start with None
231
- tokens = None
232
- generated_embeds = prefix_embeds # Start with prefix embeddings
233
-
234
- with torch.no_grad():
235
- for step in range(entry_length):
236
- # Forward pass for all active sequences
237
- outputs = model.gpt(inputs_embeds=generated_embeds)
238
- logits = outputs.logits[:, -1, :] # Get logits for last token: (batch_size, vocab_size)
239
-
240
- # Apply temperature
241
- logits = logits / (temperature if temperature > 0 else 1.0)
242
-
243
- # Apply nucleus sampling for each sequence in batch
244
- for i in range(batch_size):
245
- if not active_sequences[i]:
246
- continue
247
-
248
- # Sort logits for this sequence
249
- sorted_logits, sorted_indices = torch.sort(logits[i], descending=True)
250
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
251
-
252
- # Find indices to remove (above top_p threshold)
253
- sorted_indices_to_remove = cumulative_probs > top_p
254
- sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
255
- sorted_indices_to_remove[0] = 0
256
-
257
- # Set logits to -inf for tokens to remove
258
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
259
- logits[i, indices_to_remove] = filter_value
260
-
261
- # Clamp logits to avoid extreme values
262
- logits = torch.clamp(logits, min=-1e9, max=1e9) # keep values bounded
263
- # Sample next tokens for all sequences
264
- probs = torch.softmax(logits, dim=-1)
265
-
266
- # if some sequences probs tensor contains NaNs (e.g. all logits were -inf), set stop_token_index prob to 1
267
- for i in range(batch_size):
268
- if torch.isnan(probs[i]).all(): #if not torch.isfinite(probs[i]).any() or probs[i].sum() == 0:
269
- probs[i] = torch.zeros_like(probs[i])
270
- probs[i, stop_token_index] = 1.0
271
-
272
- next_tokens = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
273
-
274
- # Get embeddings for next tokens
275
- next_token_embeds = model.gpt.transformer.wte(next_tokens) # (batch_size, 1, embed_dim)
276
-
277
- # Update token sequences
278
- if tokens is None:
279
- tokens = next_tokens
280
- else:
281
- tokens = torch.cat((tokens, next_tokens), dim=1)
282
-
283
- # Update generated embeddings
284
- generated_embeds = torch.cat((generated_embeds, next_token_embeds), dim=1)
285
-
286
- # Check for stop tokens and update active sequences
287
- for i in range(batch_size):
288
- if active_sequences[i] and next_tokens[i].item() == stop_token_index:
289
- active_sequences[i] = False
290
-
291
- # If all sequences have stopped, break early
292
- if not active_sequences.any():
293
- break
294
-
295
- # Decode all sequences
296
- captions = []
297
- for i in range(batch_size):
298
- if tokens is not None:
299
- token_list = tokens[i].cpu().numpy().tolist()
300
- # Remove padding and decode
301
- caption = tokenizer.decode(token_list)
302
- # Clean up the caption
303
- caption = caption.split(stop_token)[0] + stop_token
304
- captions.append(caption)
305
- else:
306
- captions.append("")
307
-
308
- return captions
309
-
310
-
311
- def generate2(
312
- model,
313
- tokenizer,
314
- tokens=None,
315
- prompt=None,
316
- embed=None,
317
- entry_count=1,
318
- entry_length=67, # maximum number of words
319
- top_p=0.8,
320
- temperature=1.,
321
- stop_token: str = '.',
322
- ):
323
- """
324
- Legacy single-sequence generation function.
325
- For new code, use generate_batched instead.
326
- """
327
- model.eval()
328
- generated_num = 0
329
- generated_list = []
330
- stop_token_index = tokenizer.encode(stop_token)[0]
331
- filter_value = -float("Inf")
332
- device = next(model.parameters()).device
333
-
334
- with torch.no_grad():
335
-
336
- for entry_idx in range(entry_count):
337
- if embed is not None:
338
- generated = embed
339
- else:
340
- if tokens is None:
341
- tokens = torch.tensor(tokenizer.encode(prompt))
342
- tokens = tokens.unsqueeze(0).to(device)
343
-
344
- generated = model.gpt.transformer.wte(tokens)
345
-
346
- for i in range(entry_length):
347
-
348
- outputs = model.gpt(inputs_embeds=generated)
349
- logits = outputs.logits
350
- logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
351
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
352
- cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
353
- sorted_indices_to_remove = cumulative_probs > top_p
354
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
355
- ..., :-1
356
- ].clone()
357
- sorted_indices_to_remove[..., 0] = 0
358
-
359
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
360
- logits[:, indices_to_remove] = filter_value
361
- next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
362
- next_token_embed = model.gpt.transformer.wte(next_token)
363
- if tokens is None:
364
- tokens = next_token
365
- else:
366
- tokens = torch.cat((tokens, next_token), dim=1)
367
- generated = torch.cat((generated, next_token_embed), dim=1)
368
- if stop_token_index == next_token.item():
369
- break
370
-
371
- output_list = list(tokens.squeeze().cpu().numpy())
372
- output_text = tokenizer.decode(output_list)
373
- generated_list.append(output_text)
374
-
375
- return generated_list[0]
376
-
377
-
378
- class ClipCapModel(torch.nn.Module):
379
- """
380
- ClipCap integration for the Patchioner class.
381
- """
382
-
383
- def __init__(self, args, device, dino_feature_dim=768):
384
- super(ClipCapModel, self).__init__()
385
- args_dict = args.copy()
386
- self.args = args = self.load_config(args)
387
- self.device = device
388
- self.dino_feature_dim = dino_feature_dim
389
-
390
- # Initialize tokenizer
391
- self.tokenizer = GPT2Tokenizer.from_pretrained(args.language_model)
392
- if self.tokenizer.pad_token_id is None:
393
- self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
394
-
395
- # Determine mapping type
396
- mapping_type = MappingType.MLP if args.mapping_type.lower() == 'mlp' else MappingType.Transformer
397
-
398
- # Initialize model with DINO feature dimensions
399
- if args.only_prefix:
400
- self.model = ClipCaptionPrefix(
401
- prefix_length=args.prefix_length,
402
- clip_length=args.clip_length,
403
- prefix_size=dino_feature_dim,
404
- num_layers=args.num_layers,
405
- mapping_type=mapping_type
406
- )
407
- else:
408
- self.model = ClipCaptionModel(
409
- prefix_length=args.prefix_length,
410
- clip_length=args.clip_length,
411
- prefix_size=dino_feature_dim,
412
- num_layers=args.num_layers,
413
- mapping_type=mapping_type
414
- )
415
-
416
- # Load trained weights
417
- print(f"Loading ClipCap weights from: {args.weight_path}")
418
- checkpoint = torch.load(args.weight_path, map_location=device)
419
- self.model.load_state_dict(checkpoint, strict=False)
420
- self.model.to(device)
421
- self.model.eval()
422
-
423
- defaults = {
424
- "language_model": "gpt2",
425
- "prefix_length": 10,
426
- "clip_length": 10,
427
- "num_layers": 8,
428
- "mapping_type": "mlp",
429
- "only_prefix": True,
430
- "temperature": 1.0,
431
- "top_p": 0.8,
432
- "entry_length": 67,
433
- "stop_token": ".",
434
- "use_batched_generation": True, # Use batched generation by default
435
- "normalize_prefix": False, # Whether to L2 normalize the input features
436
- "weight_path": "/raid/datasets/models_weights/clipcap/training-features/clipcap_dino_vitb14_len10_mlp.pt"
437
- }
438
-
439
- def load_config(self, args_dict: dict) -> Namespace:
440
- def dict_to_namespace(d):
441
- if isinstance(d, dict):
442
- return Namespace(**{k: dict_to_namespace(v) for k, v in d.items()})
443
- return d
444
-
445
- # Apply defaults
446
- for key, value in self.defaults.items():
447
- if isinstance(value, dict):
448
- for sub_key, sub_value in value.items():
449
- args_dict.setdefault(key, {}).setdefault(sub_key, sub_value)
450
- else:
451
- args_dict.setdefault(key, value)
452
-
453
- args = dict_to_namespace(args_dict)
454
- return args
455
-
456
- def forward(self, dino_features, compute_scores: bool = False) -> List[str]:
457
- """
458
- DINO Features: (batch_size, dino_feature_dim)
459
- - returns: List[str] of generated captions
460
- """
461
- if self.args.use_batched_generation:
462
- return self.forward_batched(dino_features, compute_scores)
463
- else:
464
- return self.forward_sequential(dino_features, compute_scores)
465
-
466
- def forward_batched(self, dino_features, compute_scores: bool = False) -> List[str]:
467
- """
468
- Efficient batched generation for multiple sequences.
469
- """
470
- batch_size = dino_features.shape[0]
471
-
472
- # Apply normalization if specified (to match training)
473
- if self.args.normalize_prefix:
474
- dino_features = dino_features / dino_features.norm(dim=-1, keepdim=True)
475
-
476
- # Generate prefix embeddings for entire batch
477
- with torch.no_grad():
478
- prefix_embeds = self.model.clip_project(dino_features).view(
479
- batch_size, self.args.prefix_length, -1
480
- )
481
-
482
- # Generate captions for entire batch
483
- captions = generate_batched(
484
- model=self.model,
485
- tokenizer=self.tokenizer,
486
- prefix_embeds=prefix_embeds,
487
- entry_length=self.args.entry_length,
488
- temperature=self.args.temperature,
489
- top_p=self.args.top_p,
490
- stop_token=self.args.stop_token
491
- )
492
-
493
- if compute_scores:
494
- # Compute perplexity scores for generated captions
495
- scores = self.compute_perplexity_scores(captions)
496
- return captions, scores
497
- else:
498
- return captions
499
-
500
- def forward_sequential(self, dino_features, compute_scores: bool = False) -> List[str]:
501
- """
502
- Sequential generation for backward compatibility or debugging.
503
- """
504
- batch_size = dino_features.shape[0]
505
- captions = []
506
- scores = []
507
-
508
- # Process each feature in the batch sequentially
509
- for i in range(batch_size):
510
- feature = dino_features[i:i+1] # Keep batch dimension
511
-
512
- # Apply normalization if enabled
513
- if self.args.normalize_prefix:
514
- feature = feature / feature.norm(dim=-1, keepdim=True)
515
-
516
- # Generate prefix embeddings
517
- with torch.no_grad():
518
- prefix_embed = self.model.clip_project(feature).view(1, self.args.prefix_length, -1)
519
-
520
- # Generate caption using legacy function
521
- caption = generate2(
522
- model=self.model,
523
- tokenizer=self.tokenizer,
524
- embed=prefix_embed,
525
- entry_length=self.args.entry_length,
526
- temperature=self.args.temperature,
527
- top_p=self.args.top_p,
528
- stop_token=self.args.stop_token
529
- )
530
-
531
- captions.append(caption)
532
- if compute_scores:
533
- # Compute perplexity for this caption
534
- score = self.compute_perplexity_scores([caption])[0]
535
- scores.append(score)
536
-
537
- return captions if not compute_scores else (captions, scores)
538
-
539
- def compute_perplexity_scores(self, captions: List[str]) -> List[float]:
540
- """
541
- Compute perplexity scores for generated captions.
542
- """
543
- scores = []
544
- self.model.eval()
545
-
546
- with torch.no_grad():
547
- for caption in captions:
548
- try:
549
- # Tokenize caption
550
- tokens = self.tokenizer.encode(caption, return_tensors='pt').to(self.device)
551
-
552
- # Compute loss (negative log-likelihood)
553
- outputs = self.model.gpt(input_ids=tokens, labels=tokens)
554
- loss = outputs.loss
555
-
556
- # Convert to perplexity (lower is better, but we'll use 1/perplexity as score)
557
- perplexity = torch.exp(loss).item()
558
- score = 1.0 / perplexity if perplexity > 0 else 1.0
559
- scores.append(score)
560
- except:
561
- # Fallback score if computation fails
562
- scores.append(1.0)
563
-
564
- return scores
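A minimal usage sketch of the `ClipCapModel` wrapper deleted above, assuming the class is in scope (the config keys mirror its `defaults` dict; the checkpoint path is a placeholder for a locally available ClipCap weight file):

```python
# Hedged sketch: ClipCapModel is the wrapper class defined in the deleted file above;
# the weight path below is a placeholder, not a file shipped with this Space.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

config = {
    "mapping_type": "mlp",            # or "transformer"
    "prefix_length": 10,
    "clip_length": 10,
    "only_prefix": True,
    "use_batched_generation": True,
    "weight_path": "weights/clipcap_dino_vitb14_len10_mlp.pt",  # placeholder checkpoint
}

captioner = ClipCapModel(config, device=device, dino_feature_dim=768)

# One 768-d DINO feature per element to caption (random here, for illustration only).
dino_features = torch.randn(4, 768, device=device)

captions = captioner(dino_features)                               # List[str]
captions, scores = captioner(dino_features, compute_scores=True)  # with 1/perplexity scores
print(captions[0], scores[0])
```

Setting `use_batched_generation: False` routes generation through `forward_sequential` and the legacy `generate2` instead of `generate_batched`.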
src/clipcap/predict.py DELETED
@@ -1,302 +0,0 @@
1
- # Prediction interface for Cog ⚙️
2
- # Reference: https://github.com/replicate/cog/blob/main/docs/python.md
3
-
4
- import clip
5
- import os
6
- from torch import nn
7
- import numpy as np
8
- import torch
9
- import torch.nn.functional as nnf
10
- import sys
11
- from typing import Tuple, List, Union, Optional
12
- from transformers import (
13
- GPT2Tokenizer,
14
- GPT2LMHeadModel,
15
- AdamW,
16
- get_linear_schedule_with_warmup,
17
- )
18
- import skimage.io as io
19
- import PIL.Image
20
-
21
- import cog
22
-
23
- # import torch
24
-
25
- N = type(None)
26
- V = np.array
27
- ARRAY = np.ndarray
28
- ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
29
- VS = Union[Tuple[V, ...], List[V]]
30
- VN = Union[V, N]
31
- VNS = Union[VS, N]
32
- T = torch.Tensor
33
- TS = Union[Tuple[T, ...], List[T]]
34
- TN = Optional[T]
35
- TNS = Union[Tuple[TN, ...], List[TN]]
36
- TSN = Optional[TS]
37
- TA = Union[T, ARRAY]
38
-
39
- WEIGHTS_PATHS = {
40
- "coco": "coco_weights.pt",
41
- "conceptual-captions": "conceptual_weights.pt",
42
- }
43
-
44
- D = torch.device
45
- CPU = torch.device("cpu")
46
-
47
-
48
- class Predictor(cog.Predictor):
49
- def setup(self):
50
- """Load the model into memory to make running multiple predictions efficient"""
51
- self.device = torch.device("cuda")
52
- self.clip_model, self.preprocess = clip.load(
53
- "ViT-B/32", device=self.device, jit=False
54
- )
55
- self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
56
-
57
- self.models = {}
58
- self.prefix_length = 10
59
- for key, weights_path in WEIGHTS_PATHS.items():
60
- model = ClipCaptionModel(self.prefix_length)
61
- model.load_state_dict(torch.load(weights_path, map_location=CPU))
62
- model = model.eval()
63
- model = model.to(self.device)
64
- self.models[key] = model
65
-
66
- @cog.input("image", type=cog.Path, help="Input image")
67
- @cog.input(
68
- "model",
69
- type=str,
70
- options=WEIGHTS_PATHS.keys(),
71
- default="coco",
72
- help="Model to use",
73
- )
74
- @cog.input(
75
- "use_beam_search",
76
- type=bool,
77
- default=False,
78
- help="Whether to apply beam search to generate the output text",
79
- )
80
- def predict(self, image, model, use_beam_search):
81
- """Run a single prediction on the model"""
82
- image = io.imread(image)
83
- model = self.models[model]
84
- pil_image = PIL.Image.fromarray(image)
85
- image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
86
- with torch.no_grad():
87
- prefix = self.clip_model.encode_image(image).to(
88
- self.device, dtype=torch.float32
89
- )
90
- prefix_embed = model.clip_project(prefix).reshape(1, self.prefix_length, -1)
91
- if use_beam_search:
92
- return generate_beam(model, self.tokenizer, embed=prefix_embed)[0]
93
- else:
94
- return generate2(model, self.tokenizer, embed=prefix_embed)
95
-
96
-
97
- class MLP(nn.Module):
98
- def forward(self, x: T) -> T:
99
- return self.model(x)
100
-
101
- def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
102
- super(MLP, self).__init__()
103
- layers = []
104
- for i in range(len(sizes) - 1):
105
- layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
106
- if i < len(sizes) - 2:
107
- layers.append(act())
108
- self.model = nn.Sequential(*layers)
109
-
110
-
111
- class ClipCaptionModel(nn.Module):
112
-
113
- # @functools.lru_cache #FIXME
114
- def get_dummy_token(self, batch_size: int, device: D) -> T:
115
- return torch.zeros(
116
- batch_size, self.prefix_length, dtype=torch.int64, device=device
117
- )
118
-
119
- def forward(
120
- self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None
121
- ):
122
- embedding_text = self.gpt.transformer.wte(tokens)
123
- prefix_projections = self.clip_project(prefix).view(
124
- -1, self.prefix_length, self.gpt_embedding_size
125
- )
126
- # print(embedding_text.size()) #torch.Size([5, 67, 768])
127
- # print(prefix_projections.size()) #torch.Size([5, 1, 768])
128
- embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
129
- if labels is not None:
130
- dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
131
- labels = torch.cat((dummy_token, tokens), dim=1)
132
- out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
133
- return out
134
-
135
- def __init__(self, prefix_length: int, prefix_size: int = 512):
136
- super(ClipCaptionModel, self).__init__()
137
- self.prefix_length = prefix_length
138
- self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
139
- self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
140
- if prefix_length > 10: # not enough memory
141
- self.clip_project = nn.Linear(
142
- prefix_size, self.gpt_embedding_size * prefix_length
143
- )
144
- else:
145
- self.clip_project = MLP(
146
- (
147
- prefix_size,
148
- (self.gpt_embedding_size * prefix_length) // 2,
149
- self.gpt_embedding_size * prefix_length,
150
- )
151
- )
152
-
153
-
154
- class ClipCaptionPrefix(ClipCaptionModel):
155
- def parameters(self, recurse: bool = True):
156
- return self.clip_project.parameters()
157
-
158
- def train(self, mode: bool = True):
159
- super(ClipCaptionPrefix, self).train(mode)
160
- self.gpt.eval()
161
- return self
162
-
163
-
164
- def generate_beam(
165
- model,
166
- tokenizer,
167
- beam_size: int = 5,
168
- prompt=None,
169
- embed=None,
170
- entry_length=67,
171
- temperature=1.0,
172
- stop_token: str = ".",
173
- ):
174
-
175
- model.eval()
176
- stop_token_index = tokenizer.encode(stop_token)[0]
177
- tokens = None
178
- scores = None
179
- device = next(model.parameters()).device
180
- seq_lengths = torch.ones(beam_size, device=device)
181
- is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
182
- with torch.no_grad():
183
- if embed is not None:
184
- generated = embed
185
- else:
186
- if tokens is None:
187
- tokens = torch.tensor(tokenizer.encode(prompt))
188
- tokens = tokens.unsqueeze(0).to(device)
189
- generated = model.gpt.transformer.wte(tokens)
190
- for i in range(entry_length):
191
- outputs = model.gpt(inputs_embeds=generated)
192
- logits = outputs.logits
193
- logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
194
- logits = logits.softmax(-1).log()
195
- if scores is None:
196
- scores, next_tokens = logits.topk(beam_size, -1)
197
- generated = generated.expand(beam_size, *generated.shape[1:])
198
- next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
199
- if tokens is None:
200
- tokens = next_tokens
201
- else:
202
- tokens = tokens.expand(beam_size, *tokens.shape[1:])
203
- tokens = torch.cat((tokens, next_tokens), dim=1)
204
- else:
205
- logits[is_stopped] = -float(np.inf)
206
- logits[is_stopped, 0] = 0
207
- scores_sum = scores[:, None] + logits
208
- seq_lengths[~is_stopped] += 1
209
- scores_sum_average = scores_sum / seq_lengths[:, None]
210
- scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
211
- beam_size, -1
212
- )
213
- next_tokens_source = next_tokens // scores_sum.shape[1]
214
- seq_lengths = seq_lengths[next_tokens_source]
215
- next_tokens = next_tokens % scores_sum.shape[1]
216
- next_tokens = next_tokens.unsqueeze(1)
217
- tokens = tokens[next_tokens_source]
218
- tokens = torch.cat((tokens, next_tokens), dim=1)
219
- generated = generated[next_tokens_source]
220
- scores = scores_sum_average * seq_lengths
221
- is_stopped = is_stopped[next_tokens_source]
222
- next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
223
- generated.shape[0], 1, -1
224
- )
225
- generated = torch.cat((generated, next_token_embed), dim=1)
226
- is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
227
- if is_stopped.all():
228
- break
229
- scores = scores / seq_lengths
230
- output_list = tokens.cpu().numpy()
231
- output_texts = [
232
- tokenizer.decode(output[: int(length)])
233
- for output, length in zip(output_list, seq_lengths)
234
- ]
235
- order = scores.argsort(descending=True)
236
- output_texts = [output_texts[i] for i in order]
237
- return output_texts
238
-
239
-
240
- def generate2(
241
- model,
242
- tokenizer,
243
- tokens=None,
244
- prompt=None,
245
- embed=None,
246
- entry_count=1,
247
- entry_length=67, # maximum number of words
248
- top_p=0.8,
249
- temperature=1.0,
250
- stop_token: str = ".",
251
- ):
252
- model.eval()
253
- generated_num = 0
254
- generated_list = []
255
- stop_token_index = tokenizer.encode(stop_token)[0]
256
- filter_value = -float("Inf")
257
- device = next(model.parameters()).device
258
-
259
- with torch.no_grad():
260
-
261
- for entry_idx in range(entry_count):
262
- if embed is not None:
263
- generated = embed
264
- else:
265
- if tokens is None:
266
- tokens = torch.tensor(tokenizer.encode(prompt))
267
- tokens = tokens.unsqueeze(0).to(device)
268
-
269
- generated = model.gpt.transformer.wte(tokens)
270
-
271
- for i in range(entry_length):
272
-
273
- outputs = model.gpt(inputs_embeds=generated)
274
- logits = outputs.logits
275
- logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
276
- sorted_logits, sorted_indices = torch.sort(logits, descending=True)
277
- cumulative_probs = torch.cumsum(
278
- nnf.softmax(sorted_logits, dim=-1), dim=-1
279
- )
280
- sorted_indices_to_remove = cumulative_probs > top_p
281
- sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
282
- ..., :-1
283
- ].clone()
284
- sorted_indices_to_remove[..., 0] = 0
285
-
286
- indices_to_remove = sorted_indices[sorted_indices_to_remove]
287
- logits[:, indices_to_remove] = filter_value
288
- next_token = torch.argmax(logits, -1).unsqueeze(0)
289
- next_token_embed = model.gpt.transformer.wte(next_token)
290
- if tokens is None:
291
- tokens = next_token
292
- else:
293
- tokens = torch.cat((tokens, next_token), dim=1)
294
- generated = torch.cat((generated, next_token_embed), dim=1)
295
- if stop_token_index == next_token.item():
296
- break
297
-
298
- output_list = list(tokens.squeeze().cpu().numpy())
299
- output_text = tokenizer.decode(output_list)
300
- generated_list.append(output_text)
301
-
302
- return generated_list[0]
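A minimal sketch of the legacy prediction path deleted above, using `ClipCaptionModel` and `generate_beam` directly rather than through Cog (assumptions: `coco_weights.pt` and `example.jpg` are placeholders that must exist locally; everything runs on CPU for simplicity):

```python
# Hedged sketch: ClipCaptionModel and generate_beam are the definitions in the
# deleted predict.py above; file paths are placeholders.
import clip
import torch
from PIL import Image
from transformers import GPT2Tokenizer

device = torch.device("cpu")
clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

model = ClipCaptionModel(prefix_length=10)
model.load_state_dict(torch.load("coco_weights.pt", map_location=device))
model.eval()

image = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    prefix = clip_model.encode_image(image).float()
    prefix_embed = model.clip_project(prefix).reshape(1, 10, -1)

# Beam search returns candidates sorted by score; take the best one.
print(generate_beam(model, tokenizer, embed=prefix_embed)[0])
```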
src/dataset.py DELETED
@@ -1,94 +0,0 @@
1
- import torch
2
- from torch.utils.data import Dataset
3
-
4
- from tqdm import tqdm
5
- import json
6
- from typing import Tuple
7
- import clip
8
- import random
9
- import json
10
- import random
11
- from tqdm import tqdm
12
-
13
- class ClipCocoDataset(Dataset):
14
-
15
- def __len__(self) -> int:
16
- return len(self.captions_tokens)
17
-
18
- def pad_tokens(self, item: int):
19
- tokens = self.captions_tokens[item]
20
- padding = self.max_seq_len - tokens.shape[0]
21
- if padding > 0:
22
- tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64)))
23
- elif padding < 0:
24
- tokens = tokens[:self.max_seq_len]
25
- return tokens
26
-
27
-
28
- def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
29
- # tokens = self.captions_tokens[item]
30
-
31
- clip_tokens = self.pad_tokens(item)
32
- if self.feats is None:
33
- clip_tokens_77 = self.captions_tokens[item]
34
- return clip_tokens, clip_tokens_77
35
- else:
36
- return clip_tokens, self.feats[item]
37
-
38
- def __init__(self, data_path: str, clip_model=None, talk2dino=None, use_dino_feats=False, tokenizer=None):
39
- if tokenizer is not None:
40
- self.clip_tokenizer = tokenizer
41
- else:
42
- print(f"Using default tokenizer")
43
- self.clip_tokenizer = clip.tokenize
44
- self.prefix_length = 10
45
- self.max_seq_len = 20
46
- self.feats = None
47
-
48
- if clip_model is not None:
49
- device = next(clip_model.parameters()).device
50
- print("Pre-extracting features...")
51
-
52
- if not use_dino_feats:
53
- with open(data_path, 'r') as f:
54
- self.captions = [ann['caption'] for ann in json.load(f)['annotations']]
55
- else:
56
- data = torch.load(data_path)
57
- self.captions = [ann['caption'] for ann in data['annotations']]
58
- self.feats = [ann['features'] for ann in data['annotations']]
59
-
60
-
61
- random.shuffle(self.captions)
62
- self.captions_tokens = []
63
-
64
- batch_size = 64
65
- batched_captions = [self.captions[i:i + batch_size] for i in range(0, len(self.captions), batch_size)]
66
-
67
- for batch in tqdm(batched_captions):
68
- try:
69
- # Tokenize the batch of captions
70
- batch_tokens = [torch.tensor(self.clip_tokenizer(caption)[0], dtype=torch.int64) for caption in batch]
71
-
72
- # Pad tokens to the same length for batching
73
- batch_tokens_padded = torch.nn.utils.rnn.pad_sequence(batch_tokens, batch_first=True)
74
- self.captions_tokens.extend(batch_tokens)
75
-
76
- if clip_model is not None:
77
- with torch.no_grad():
78
- # Encode the text batch
79
- feats = clip_model.encode_text(batch_tokens_padded.to(device))
80
-
81
- if talk2dino is not None:
82
- # Project to desired feature space
83
- feats = talk2dino.project_clip_txt(feats).to('cpu')
84
-
85
- # Concatenate features
86
- if self.feats is None:
87
- self.feats = feats
88
- else:
89
- self.feats = torch.cat((self.feats, feats))
90
- except Exception as e:
91
- print(f"Error processing batch: {e}")
92
- print(len(self.captions_tokens))
93
-
94
-
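A minimal sketch of how the `ClipCocoDataset` deleted above is instantiated and batched, assuming the class is in scope (the annotation path is a placeholder COCO-style JSON with an `annotations` list; text features come from standard CLIP ViT-B/16):

```python
# Hedged sketch: ClipCocoDataset is the class defined in the deleted file above;
# the annotation path is a placeholder.
import clip
import torch
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/16", device=device, jit=False)

dataset = ClipCocoDataset(
    "annotations/coco_train_karpathy.json",  # placeholder annotation file
    clip_model=clip_model,                   # text features are pre-extracted at init
    tokenizer=clip.tokenize,
)
loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)

clip_tokens, text_feats = next(iter(loader))
# clip_tokens: (64, 20) padded caption tokens; text_feats: (64, 512) CLIP text embeddings
print(clip_tokens.shape, text_feats.shape)
```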
src/datasetMix.py DELETED
@@ -1,153 +0,0 @@
1
- import torch
2
- from torch.utils.data import Dataset
3
-
4
- from tqdm import tqdm
5
- import json
6
- from typing import Tuple
7
- import clip
8
- import random
9
- import json
10
- import random
11
- from tqdm import tqdm
12
-
13
- from pycocotools.coco import COCO
14
-
15
- class ClipCocoDatasetMix(Dataset):
16
-
17
- def __len__(self) -> int:
18
- return len(self.image_index_list)
19
-
20
- def _pad_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
21
- padding = self.max_seq_len - tokens.shape[0]
22
- if padding > 0:
23
- tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64)))
24
- elif padding < 0:
25
- tokens = tokens[:self.max_seq_len]
26
- return tokens
27
-
28
-
29
- def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
30
-
31
- # get the image index for the item
32
- img_idx = self.image_index_list[item]
33
- # get the caption index for that image
34
- first_caption_idx = self.image_index_list.index(img_idx)
35
-
36
- # the caption index is the item - the first caption index
37
- caption_idx = item - first_caption_idx
38
-
39
- # how many captions are there for that image?
40
- num_captions = len(self.captions_list_of_lists[img_idx])
41
- try:
42
- tokens = self.captions_tokens_list_of_lists[img_idx][caption_idx] #self.captions_list_of_lists[img_idx][caption_idx]
43
- except IndexError:
44
- print(f"{len(self.captions_tokens_list_of_lists)= } - {len(self.captions_tokens_list_of_lists[img_idx])= }")
45
- print(f"IndexError: {img_idx}, {caption_idx}, {num_captions}")
46
- raise
47
- padded_tokens = self._pad_tokens(tokens)
48
-
49
- feats_same_img = self.feats[img_idx][random.choice(range(num_captions))]
50
-
51
- if self.feats is None or len(self.feats) == 0:
52
- raise Exception("Precomputed features required")
53
- else:
54
- return padded_tokens, feats_same_img
55
-
56
- def __init__(self, data_path: str, clip_model=None, talk2dino=None, use_precomputed_feats=False, tokenizer=None):
57
-
58
- batch_size = 64
59
- self.max_seq_len = 20
60
-
61
- if use_precomputed_feats:
62
- raise Exception("Precomputed features not supported")
63
-
64
- if tokenizer is not None:
65
- self.clip_tokenizer = tokenizer
66
- else:
67
- print(f"Using default tokenizer")
68
- self.clip_tokenizer = clip.tokenize
69
-
70
- coco_data = COCO(data_path)
71
- # I want to load the captions from the json file in a list of lists,
72
- # where each list contains the captions for a single image
73
-
74
- self.captions_list_of_lists = []
75
-
76
- self.image_index_list = []
77
-
78
- max_seq_len = 20
79
-
80
- for img_idx, (img_id, image) in enumerate(list(coco_data.imgs.items())):
81
- # get the captions for that image
82
- captions = coco_data.imgToAnns[img_id]
83
- # get the texts of the captions
84
- captions = [cap['caption'] for cap in captions] #[coco_data.anns[cap]['caption'] for cap in captions]
85
- self.captions_list_of_lists.append(captions)
86
- self.image_index_list.append([img_idx] * len(captions))
87
-
88
- #max_seq_len = max(max_seq_len, max([len(caption) for caption in captions]))
89
-
90
- self.max_seq_len = max_seq_len
91
- print(f"Computed Max seq len: {max_seq_len}")
92
-
93
- if clip_model is not None:
94
- device = next(clip_model.parameters()).device
95
- print("Pre-extracting features...")
96
-
97
- #random.shuffle(self.captions_list_of_lists)
98
- # should shuffle in the same way self.image_index_list and self.captions_list_of_lists
99
- # Combine captions and image indices into a list of pairs
100
- combined = list(zip(self.captions_list_of_lists, self.image_index_list, range(len(self.captions_list_of_lists))))
101
-
102
- # Shuffle them together
103
- random.shuffle(combined)
104
-
105
- # Unzip the shuffled pairs back into two separate lists
106
- self.captions_list_of_lists, self.image_index_list, img_idxes_shuffled = zip(*combined)
107
- # Convert back to lists (zip returns tuples)
108
- self.captions_list_of_lists = list(self.captions_list_of_lists)
109
- self.image_index_list = list(self.image_index_list)
110
- img_idxes_shuffled = list(img_idxes_shuffled)
111
-
112
- # self.image_index_list is a list of lists, where each list contains the image index for each caption,
113
- # so we need to flatten it
114
- self.image_index_list = [img_idxes_shuffled.index(item) for sublist in self.image_index_list for item in sublist]
115
-
116
-
117
- self.captions_tokens_list_of_lists = []
118
- self.feats = [] # feats will be a list of tensors, each tensor will be (num_captions, embedding_dimension)
119
- #ignore. # feats shape will be (num_images, num_captions, embedding_dimension)
120
-
121
- #batched_captions = [self.captions[i:i + batch_size] for i in range(0, len(self.captions), batch_size)]
122
-
123
- for captions_list in tqdm(self.captions_list_of_lists, dynamic_ncols=True):
124
- try:
125
- # Tokenize the batch of captions
126
- batch_tokens = [torch.tensor(self.clip_tokenizer(caption)[0], dtype=torch.int64) for caption in captions_list]
127
-
128
- # Pad tokens to the same length for batching
129
- batch_tokens_padded = torch.nn.utils.rnn.pad_sequence(batch_tokens, batch_first=True)
130
- self.captions_tokens_list_of_lists.append(batch_tokens)
131
-
132
- # alternative:
133
- # tokens = self.clip_tokenizer(captions_list, truncate=True).to(device) # shape: (num_captions, context_length)
134
-
135
-
136
- if clip_model is not None:
137
- with torch.no_grad():
138
- # Encode the text batch
139
- feats = clip_model.encode_text(batch_tokens_padded.to(device))
140
-
141
- if talk2dino is not None:
142
- # Project to desired feature space
143
- feats = talk2dino.project_clip_txt(feats).to('cpu')
144
-
145
- self.feats.append(feats.cpu()) # store (num_captions, embed_dim) for each image
146
-
147
- except Exception as e:
148
- print(f"Error processing batch: {e}")
149
-
150
- print(f"Dataset loaded with {len(self.captions_list_of_lists)} images")
151
- print(f"Max seq len: {max_seq_len}")
152
- print(f"Number of captions: {len(self.image_index_list)}")
153
-
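A minimal sketch for the caption-mixing variant deleted above, which pairs one caption's padded tokens with the text features of a randomly chosen caption of the same image (assumptions: the class is in scope, `pycocotools` is installed, and the COCO-format captions file is a placeholder):

```python
# Hedged sketch: ClipCocoDatasetMix is the class defined in the deleted file above;
# the captions file is a placeholder in standard COCO format.
import clip
import torch
from torch.utils.data import DataLoader

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, _ = clip.load("ViT-B/16", device=device, jit=False)

mix_dataset = ClipCocoDatasetMix(
    "annotations/captions_train2017.json",  # placeholder COCO captions file
    clip_model=clip_model,
    tokenizer=clip.tokenize,
)
loader = DataLoader(mix_dataset, batch_size=64, shuffle=True, drop_last=True)

padded_tokens, mixed_feats = next(iter(loader))
print(padded_tokens.shape, mixed_feats.shape)  # expected (64, 20) and (64, 512)
```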
src/decap/decap.py DELETED
@@ -1,193 +0,0 @@
1
- import os
2
- from torch import nn
3
- import numpy as np
4
- import torch
5
- import torch.nn.functional as nnf
6
- import sys
7
- from typing import Tuple, List, Union, Optional
8
- from tqdm import tqdm, trange
9
- import pickle
10
- import PIL.Image as Image
11
- import json
12
- import random
13
- import sys
14
- import clip
15
- import PIL
16
- import random
17
-
18
- from torch.utils.data import Dataset, DataLoader
19
- from enum import Enum
20
- from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
21
- from tqdm import tqdm
22
- import os
23
- import pickle
24
- import sys
25
- import argparse
26
- import json
27
- from typing import Tuple, Optional, Union
28
-
29
- import os
30
- from dotenv import load_dotenv
31
-
32
- load_dotenv()
33
-
34
-
35
- DECAP_DECODER_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "decoder_config.pkl")
36
- DECAP_COCO_WEIGHTS_PATH = None#'../../thesis-data/decap/coco_model/coco_prefix-009.pt'
37
-
38
- class MappingType(Enum):
39
- MLP = 'mlp'
40
- Transformer = 'transformer'
41
-
42
-
43
- class MLP(nn.Module):
44
-
45
- def forward(self, x: torch.Tensor) -> torch.Tensor:
46
- return self.model(x)
47
-
48
- def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
49
- super(MLP, self).__init__()
50
- layers = []
51
- for i in range(len(sizes) - 1):
52
- layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
53
- if i < len(sizes) - 2:
54
- layers.append(act())
55
- self.model = nn.Sequential(*layers)
56
-
57
-
58
- class DeCap(nn.Module):
59
-
60
- def __init__(self,prefix_size: int = 512):
61
- super(DeCap, self).__init__()
62
- # decoder: 4 layers transformer with 4 attention heads
63
- # the decoder is not pretrained
64
- with open(DECAP_DECODER_CONFIG_PATH,'rb') as f:
65
- config = pickle.load(f)
66
- self.decoder = GPT2LMHeadModel(config)
67
- self.embedding_size = self.decoder.transformer.wte.weight.shape[1]
68
- self.clip_project = MLP((prefix_size,self.embedding_size))
69
-
70
- def forward(self, clip_features,tokens):
71
- embedding_text = self.decoder.transformer.wte(tokens)
72
- embedding_clip = self.clip_project(clip_features)
73
- embedding_clip = embedding_clip.reshape(-1,1,self.embedding_size)
74
- embedding_cat = torch.cat([embedding_clip,embedding_text],dim=1)
75
- out = self.decoder(inputs_embeds=embedding_cat)
76
- return out
77
-
78
- from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
79
- _Tokenizer = _Tokenizer()
80
-
81
- def Decoding(model,clip_features):
82
- model.eval()
83
- embedding_cat = model.clip_project(clip_features).reshape(1,1,-1)
84
- entry_length = 30
85
- temperature = 1
86
- tokens = None
87
- for i in range(entry_length):
88
- # print(location_token.shape)
89
- outputs = model.decoder(inputs_embeds=embedding_cat)
90
-
91
- logits = outputs.logits
92
- logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
93
- logits_max = logits.max()
94
- logits = torch.nn.functional.softmax(logits, -1)
95
- next_token = torch.argmax(logits, -1).unsqueeze(0)
96
- next_token_embed = model.decoder.transformer.wte(next_token)
97
-
98
- if tokens is None:
99
- tokens = next_token
100
-
101
- else:
102
- tokens = torch.cat((tokens, next_token), dim=1)
103
- if next_token.item()==49407:
104
- break
105
- embedding_cat = torch.cat((embedding_cat, next_token_embed), dim=1)
106
- try:
107
- output_list = list(tokens.squeeze().cpu().numpy())
108
- output = _Tokenizer.decode(output_list)
109
- except:
110
- output = 'None'
111
- return output
112
-
113
- def decoding_batched(model, clip_features, compute_scores : bool = False, decoding_method : callable = None, return_start_end_tokens : bool = False):
114
- """
115
- Returns the generated sequences for a batch of clip features.
116
- - if compute_scores is True, also returns the scores of the generated sequences.
117
- - returns a list of strings if compute_scores is False, otherwise a tuple of a list of strings and a list of floats.
118
- """
119
-
120
- model.eval()
121
- embedding_cat = model.clip_project(clip_features).view(clip_features.shape[0], 1, -1)
122
- entry_length = 30
123
- temperature = 1
124
- tokens = None
125
- sequence_log_probs = None
126
-
127
- for i in range(entry_length):
128
- outputs = model.decoder(inputs_embeds=embedding_cat)
129
-
130
- logits = outputs.logits[:, -1, :]
131
- logits = logits / (temperature if temperature > 0 else 1.0)
132
-
133
- probs = torch.nn.functional.softmax(logits, -1)
134
-
135
- if compute_scores:
136
- log_probs = torch.log(probs) # Convert to log-probabilities
137
-
138
- next_token = torch.argmax(probs, -1).unsqueeze(1)
139
- next_token_embed = model.decoder.transformer.wte(next_token)
140
-
141
- if tokens is None:
142
- tokens = next_token
143
- if compute_scores:
144
- sequence_log_probs = log_probs.gather(1, next_token) # Store log-prob of first token
145
- else:
146
- tokens = torch.cat((tokens, next_token), dim=1)
147
- if compute_scores:
148
- token_log_probs = log_probs.gather(1, next_token) # Get log-prob of chosen token
149
- sequence_log_probs = torch.cat((sequence_log_probs, token_log_probs), dim=1) # Append
150
-
151
- # Append new token embedding to input
152
- embedding_cat = torch.cat((embedding_cat, next_token_embed), dim=1)
153
-
154
- if compute_scores:
155
- # Compute total sequence scores
156
- sequence_scores = sequence_log_probs.sum(dim=-1) # Sum log-probs over sequence
157
- final_scores = torch.exp(sequence_scores) # Convert log-sum-prob to probability-like score
158
-
159
- try:
160
- outputs = []
161
- for tokens_elem in tokens:
162
- output_list = list(tokens_elem.squeeze().cpu().numpy())
163
- if decoding_method is not None:
164
- output = decoding_method(output_list)
165
- else:
166
- output = _Tokenizer.decode(output_list)
167
-
168
-
169
-
170
- output = output.split('<|endoftext|>')[0]
171
- if not return_start_end_tokens:
172
- output = output.replace('<|startoftext|>', '')
173
- else:
174
- output += '<|endoftext|>'
175
-
176
- outputs.append(output)
177
- except:
178
- outputs = None
179
-
180
- return (outputs, final_scores.cpu().numpy().tolist()) if compute_scores else outputs
181
-
182
-
183
- decap_model = None
184
-
185
- def get_decap_model(device, weights_path = DECAP_COCO_WEIGHTS_PATH, prefix_size=512):
186
- #global decap_model
187
- #if decap_model is not None:
188
- # return decap_model
189
- decap_model = DeCap(prefix_size)
190
- decap_model.load_state_dict(torch.load(weights_path,map_location= torch.device('cpu')), strict=False)
191
- decap_model = decap_model.to(device)
192
- decap_model = decap_model.eval()
193
- return decap_model
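A minimal sketch of running the `DeCap` decoder deleted above on a batch of embeddings through `get_decap_model` and `decoding_batched` (assumptions: the helpers are in scope, `decoder_config.pkl` sits next to the module as in the code above, the checkpoint path is a placeholder, and its prefix size matches the 512-d features used here):

```python
# Hedged sketch: get_decap_model and decoding_batched are the helpers defined in the
# deleted file above; the checkpoint path is a placeholder.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
decap = get_decap_model(device, weights_path="weights/coco_prefix-009.pt", prefix_size=512)

features = torch.randn(8, 512, device=device)
features = features / features.norm(dim=-1, keepdim=True)  # L2-normalize, mirroring how training normalizes text features

captions = decoding_batched(decap, features)                            # List[str]
captions, scores = decoding_batched(decap, features, compute_scores=True)
print(captions[0], scores[0])
```

`decoding_batched` decodes greedily (argmax at each step) and, with `compute_scores=True`, also returns the exponentiated sum of token log-probabilities for each generated caption.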
src/decap/decoderTraining.py DELETED
@@ -1,464 +0,0 @@
1
- import torch
2
- from torch.utils.data import DataLoader
3
- from torch.utils.data.distributed import DistributedSampler
4
- from torch.nn.parallel import DistributedDataParallel as DDP
5
- import torch.distributed as dist
6
-
7
- from im2txtprojection.im2txtprojection import Im2TxtProjector, ProjectionType
8
- from transformers import get_linear_schedule_with_warmup
9
- from torch.optim import AdamW
10
- from tqdm import tqdm
11
- from decap import get_decap_model
12
- import os
13
- import sys
14
- import argparse
15
- import json
16
- from typing import Union
17
- import sys
18
- import clip
19
- import json
20
-
21
- import csv
22
-
23
-
24
- from src.dataset import ClipCocoDataset
25
- from src.datasetMix import ClipCocoDatasetMix
26
- from src.model import DeCap, ProjectionLayer
27
-
28
- DECAP_DECODER_CONFIG_PATH = os.path.join("./decoder_config.pkl")
29
-
30
- def save_config(args: argparse.Namespace):
31
- config = {}
32
- for key, item in args._get_kwargs():
33
- config[key] = item
34
- out_path = os.path.join(args.out_dir, f"{args.prefix}.json")
35
- with open(out_path, 'w') as outfile:
36
- json.dump(config, outfile)
37
-
38
-
39
- def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
40
- with open(config_path) as f:
41
- config = json.load(f)
42
- parser = argparse.ArgumentParser()
43
- parser.set_defaults(**config)
44
- args = parser.parse_args()
45
- if type(epoch_or_latest) is int:
46
- epoch_or_latest = f"-{epoch_or_latest:03d}"
47
- model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
48
- if args.only_prefix:
49
- model = ClipCaptionPrefix(args.prefix_length)
50
- else:
51
- model = ClipCaptionModel(args.prefix_length)
52
- if os.path.isfile(model_path):
53
- print(f"loading model from {model_path}")
54
- model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
55
- else:
56
- print(f"{model_path} does not exist")
57
- return model, parser
58
-
59
-
60
-
61
-
62
- def train_decoder(args,
63
- lr: float = 1e-5, warmup_steps: int = 1000, output_dir: str = ".", output_prefix: str = ""):
64
-
65
- # device = torch.device('cuda:1')
66
- batch_size = args.bs
67
- epochs = args.epochs
68
- if not os.path.exists(output_dir):
69
- os.makedirs(output_dir)
70
- args.is_master = ( args.local_rank == 0 or args.not_distributed != False)
71
-
72
- # set the device
73
- #torch.cuda.set_device(args.local_rank)
74
- #device = torch.device('cuda:'+str(args.local_rank))
75
- if args.not_distributed == False:
76
- torch.cuda.set_device(args.local_rank)
77
- device = torch.device('cuda:'+str(args.local_rank))
78
- dist.init_process_group(backend='nccl', init_method='env://')
79
- else:
80
- device = torch.device('cuda:'+str(args.local_rank))
81
- print(f"NOT DISTRIBUTED")
82
- print(f"Using device {device}")
83
- SEED=42
84
- torch.cuda.manual_seed_all(SEED)
85
-
86
- if args.use_regionclip:
87
- # RegionCLIP typically uses 1024 dimensions for ResNet-50 or 512 for ViT
88
- # We'll determine this from the loaded model
89
- prefix_size = 1024 # Default for RegionCLIP ResNet-50, but will be adjusted if needed
90
- elif args.denseclip_config is not None:
91
- # DenseClip typically uses 512 dimensions (similar to CLIP ViT-B)
92
- from src.denseclip.loader import load_denseclip_config
93
- denseclip_config_dict = load_denseclip_config(args.denseclip_config)
94
- prefix_size = denseclip_config_dict.get('model', {}).get('text', {}).get('embed_dim', None)
95
- if prefix_size is None:
96
- print(f"Warning: Could not determine prefix_size from DenseClip config {args.denseclip_config}. Defaulting to 512.")
97
- prefix_size = 512 # Fallback to a common size
98
-
99
- elif 'H' in args.clip_model or args.use_dinotxt:
100
- prefix_size = 1024
101
- elif args.talk2dino_weights is not None or args.use_dino_feats:
102
- prefix_size = 768
103
- else:
104
- prefix_size = 512
105
-
106
- if args.im_proj:
107
- memory_bank_path = os.path.abspath(args.dataset)
108
- print(f"Using Im2TxtProjector with {memory_bank_path = }")
109
- im_proj = Im2TxtProjector(
110
- type=memory_bank_path,
111
- use_talk2dino=True,
112
- linear_talk2dino=False,
113
- memory_bank_name='coco_karpathy',
114
- device_str=device)
115
-
116
- if args.use_regionclip:
117
- from src.regionclip.loader import load_regionclip_from_checkpoint
118
- from src.regionclip.datasets.clip_prompt_utils import tokenize as regionclip_tokenize
119
-
120
- print("Using RegionCLIP for text encoding.")
121
- if args.regionclip_checkpoint is None:
122
- raise ValueError("RegionCLIP checkpoint path must be provided when using --use-regionclip")
123
-
124
- clip_model = load_regionclip_from_checkpoint(
125
- args.regionclip_checkpoint,
126
- device=device,
127
- config=args.regionclip_config
128
- )
129
- tokenizer = regionclip_tokenize
130
- preprocess = None # RegionCLIP doesn't need preprocessing for text-only training
131
-
132
- # Determine the actual embedding dimension from the loaded model
133
- if hasattr(clip_model, 'text_projection'):
134
- actual_prefix_size = clip_model.text_projection.shape[1]
135
- print(f"RegionCLIP text embedding dimension: {actual_prefix_size}")
136
- if actual_prefix_size != prefix_size:
137
- print(f"Updating prefix_size from {prefix_size} to {actual_prefix_size}")
138
- prefix_size = actual_prefix_size
139
-
140
- # Test RegionCLIP text encoding to ensure it works
141
- try:
142
- test_text = ["A test sentence"]
143
- test_tokens = tokenizer(test_text)
144
- test_features = clip_model.encode_text(test_tokens.to(device))
145
- print(f"RegionCLIP test encoding successful. Output shape: {test_features.shape}")
146
- except Exception as e:
147
- print(f"Warning: RegionCLIP test encoding failed: {e}")
148
- print("This might cause issues during training.")
149
-
150
- elif args.denseclip_config is not None:
151
- from src.denseclip.loader import load_denseclip
152
-
153
- print(f"Using DenseClip for text encoding with config: {args.denseclip_config}")
154
-
155
- try:
156
- clip_model = load_denseclip(
157
- config_name=args.denseclip_config,
158
- device=device
159
- )
160
-
161
- # Try to use DenseClip's tokenizer first
162
- try:
163
- from src.denseclip.loader import DenseCLIP_tokenize
164
- tokenizer = DenseCLIP_tokenize
165
- print("Using DenseClip tokenizer")
166
- except ImportError:
167
- # Fallback to CLIP tokenizer if DenseClip tokenizer is not available
168
- import clip
169
- tokenizer = clip.tokenize
170
- print("Warning: DenseClip tokenizer not available, using CLIP tokenizer")
171
-
172
- preprocess = None # DenseClip doesn't need preprocessing for text-only training
173
-
174
- # Determine the actual embedding dimension from the loaded model
175
- if hasattr(clip_model, 'text_encoder') and hasattr(clip_model.text_encoder, 'embed_dim'):
176
- actual_prefix_size = clip_model.text_encoder.embed_dim
177
- print(f"DenseClip text embedding dimension: {actual_prefix_size}")
178
- if actual_prefix_size != prefix_size:
179
- print(f"Updating prefix_size from {prefix_size} to {actual_prefix_size}")
180
- prefix_size = actual_prefix_size
181
-
182
- # Test DenseClip text encoding to ensure it works
183
- test_text = ["A test sentence"]
184
- test_tokens = tokenizer(test_text)
185
- if hasattr(test_tokens, 'to'):
186
- test_tokens = test_tokens.to(device)
187
- test_features = clip_model.encode_text(test_tokens)
188
- print(f"DenseClip test encoding successful. Output shape: {test_features.shape}")
189
-
190
- except Exception as e:
191
- print(f"Error loading DenseClip model: {e}")
192
- raise e
193
-
194
- elif args.use_open_clip:
195
- from open_clip import create_model_and_transforms, tokenize
196
- print("Using open_clip for model loading.")
197
- clip_model, preprocess_train, preprocess_val = create_model_and_transforms(model_name=args.clip_model, pretrained="laion2b_s32b_b79k", device=device)
198
- preprocess = preprocess_train
199
- tokenizer = tokenize
200
-
201
- elif args.use_dinotxt:
202
- from src.dinotxt_utils import get_tokenizer
203
- clip_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l')
204
- tokenizer = get_tokenizer().tokenize
205
- else:
206
- clip_model, preprocess = clip.load(args.clip_model, device=device, jit=False)
207
- tokenizer = clip.tokenize
208
- clip_model.eval()
209
- clip_model.to(device)
210
-
211
- # Create model after determining the correct prefix_size
212
- if args.decap_weights is None:
213
- model = DeCap(prefix_size)
214
- else:
215
- model = get_decap_model(device, args.decap_weights, prefix_size)
216
-
217
- if args.talk2dino_weights is not None:
218
- # loading Talk2DINO
219
- print(f"Loading Talk2DINO weights from {args.talk2dino_weights}")
220
- talk2dino = ProjectionLayer.from_config(args.talk2dino_config)
221
- talk2dino.load_state_dict(torch.load(args.talk2dino_weights, device))
222
- talk2dino.to(device)
223
- talk2dino.eval()
224
-
225
- else:
226
- talk2dino = None
227
-
228
-
229
- loss_ce = torch.nn.CrossEntropyLoss(ignore_index=0,label_smoothing=0.1)
230
- model.to(device)
231
-
232
- if args.not_distributed == False:
233
- model = DDP(
234
- model,
235
- device_ids=[args.local_rank],
236
- output_device=args.local_rank,
237
- find_unused_parameters=True
238
- )
239
-
240
- if not args.pre_extract_features:
241
- print("Features pre-extraction de-activated")
242
- if args.mix_captions:
243
- print("Using mix captions")
244
- dataset = ClipCocoDatasetMix(args.dataset, use_precomputed_feats=args.use_dino_feats, tokenizer=tokenizer)
245
- else:
246
- dataset = ClipCocoDataset(args.dataset, use_dino_feats=args.use_dino_feats, tokenizer=tokenizer)
247
- else:
248
- if args.mix_captions:
249
- print("Using mix captions")
250
- dataset = ClipCocoDatasetMix(args.dataset, clip_model=clip_model, talk2dino=talk2dino, tokenizer=tokenizer)
251
- else:
252
- dataset = ClipCocoDataset(args.dataset, clip_model=clip_model, talk2dino=talk2dino, tokenizer=tokenizer)
253
-
254
-
255
- optimizer = AdamW(model.parameters(),lr=lr)
256
-
257
- print(f"Going to construct DataLoader with {len(dataset)} samples")
258
- if args.not_distributed == False:
259
- sampler = DistributedSampler(dataset)
260
- train_dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size, drop_last=True)
261
- else:
262
- train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
263
-
264
- print("DataLoader constructed")
265
- scheduler = get_linear_schedule_with_warmup(
266
- optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
267
- )
268
-
269
-
270
- for epoch in range(epochs):
271
-
272
- epoch_loss = 0.0
273
- epoch_acc = 0.0
274
- num_batches = 0
275
-
276
- loss_token_save,ac_save= 0,0
277
- sys.stdout.flush()
278
- if args.is_master:
279
- print(f">>> Training epoch {epoch}")
280
- progress = tqdm(total=int(len(train_dataloader)/10), desc=output_prefix, dynamic_ncols=True)
281
-
282
- if args.not_distributed == False:
283
- dist.barrier()
284
-
285
- for idx,(clip_tokens, pipeline_input) in enumerate(train_dataloader):
286
-
287
-
288
- clip_tokens, pipeline_input = clip_tokens.to(device), pipeline_input.to(device)
289
-
290
- with torch.no_grad():
291
- if not args.pre_extract_features and not args.use_dino_feats:
292
- if args.use_regionclip:
293
- # RegionCLIP text encoding
294
- feature_text = clip_model.encode_text(pipeline_input)
295
- elif args.denseclip_config is not None:
296
- # DenseClip text encoding
297
- feature_text = clip_model.encode_text(pipeline_input)
298
- else:
299
- # Standard CLIP or OpenCLIP text encoding
300
- feature_text = clip_model.encode_text(pipeline_input)
301
-
302
- if args.use_dinotxt:
303
- feature_text = feature_text[:, 1024:] # patch-aligned text embedding
304
-
305
- if args.talk2dino_weights is not None:
306
- feature_text = talk2dino.project_clip_txt(feature_text)
307
- else:
308
- feature_text = pipeline_input
309
- if args.im_proj:
310
- feature_text = im_proj.project(feature_text, normalize=True)
311
-
312
- feature_text /= feature_text.norm(dim=-1, keepdim=True)
313
-
314
- if args.gaussian_noise != 0:
315
- feature_text += args.gaussian_noise * torch.randn(feature_text.shape).to(device)
316
- feature_text /= feature_text.norm(dim=-1, keepdim=True)
317
-
318
-
319
- outputs = model(feature_text.float(),clip_tokens)
320
- logits = outputs
321
-
322
- logits = logits.logits
323
-
324
- logits = logits[:,: -1]
325
- clip_tokens = clip_tokens.flatten()
326
- logits = logits.reshape(-1, logits.shape[-1])
327
-
328
- loss_token = loss_ce(logits, clip_tokens)
329
- ac=((logits.argmax(1)==clip_tokens)*(clip_tokens>0)).sum()/(clip_tokens>0).sum()
330
- optimizer.zero_grad()
331
- loss_all = loss_token
332
- loss_all.backward()
333
- optimizer.step()
334
- scheduler.step()
335
-
336
- epoch_loss += loss_token.item()
337
- epoch_acc += ac.item()
338
- num_batches += 1
339
-
340
- if args.is_master:
341
-
342
- if(idx+1) %10 == 0:
343
- progress.set_postfix({"loss_token": loss_token_save/10.0,"acc_token":ac_save/10.0})
344
- progress.update()
345
- loss_token_save,ac_save= 0,0
346
- else:
347
- loss_token_save += loss_token.item()
348
- ac_save += ac.item()
349
-
350
- if args.is_master:
351
- log_dir = os.path.join('./log', f"{args.dataset}.txt")#'./log/'+args.dataset+'.txt'
352
- with open(log_dir,'w') as f:
353
- f.writelines('epoch ' +str(epoch) +': '+ progress.postfix+'\r\n')
354
- progress.close()
355
- if epoch % args.save_every == 0 or epoch == epochs - 1:
356
- torch.save(
357
- model.state_dict(),
358
- os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
359
- )
360
-
361
- # after the epoch, we need to synchronize the loss and accuracy across all processes
362
- loss_tensor = torch.tensor(epoch_loss, device=device)
363
- acc_tensor = torch.tensor(epoch_acc, device=device)
364
- count_tensor = torch.tensor(num_batches, device=device)
365
-
366
- if args.not_distributed == False:
367
- # sum on all processes
368
- torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM)
369
- torch.distributed.all_reduce(acc_tensor, op=torch.distributed.ReduceOp.SUM)
370
- torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM)
371
-
372
- # compute global mean
373
- avg_loss = loss_tensor.item() / count_tensor.item()
374
- avg_acc = acc_tensor.item() / count_tensor.item()
375
-
376
- if args.is_master:
377
- epoch_loss_current = {'epoch': epoch, 'loss': avg_loss, 'accuracy': avg_acc}
378
- #epoch_losses.append(epoch_loss_current)
379
- print(f"Epoch {epoch} loss: {avg_loss}, accuracy: {avg_acc}")
380
-
381
- loss_csv_path = os.path.join(output_dir, f"{output_prefix}_epoch_losses.csv")
382
- with open(loss_csv_path, 'a', newline='') as csvfile:
383
- writer = csv.DictWriter(csvfile, fieldnames=['epoch', 'loss', 'accuracy'])
384
- # Write the header only if the file is empty
385
- if os.stat(loss_csv_path).st_size == 0:
386
- writer.writeheader()
387
- writer.writerow(epoch_loss_current)
388
- return model
389
-
390
- # DeCap CLIP B16 karpathy train split:
391
- #python decapTraining.py --out_dir weights_clip_b16_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy
392
- # DeCap with proj -> but it is actually not needed.
393
- #python decapTraining.py --out_dir weights_clip_b16_proj_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --im_proj
394
-
395
- # Patchioner DINOv2 karpathy train split with proj:
396
- #python decapTraining.py --out_dir weights_dino_b14_proj_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml --pre_extract_features --im_proj
397
- # Patchioner DINOv2 karpathy train split
398
- #python decapTraining.py --out_dir weights_dino_b14_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml
399
- #python decapTraining.py --out_dir weights_dino_b14_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml --use_dino_feats --pre_extract_features
400
-
401
- # DeCap CLIP B32 karpathy train split:
402
- #python decapTraining.py --out_dir weights_clip_b32_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --clip_model ViT-B/32
403
-
404
- # DeCap with RegionCLIP text encoder:
405
- #python decoderTraining.py --out_dir weights_regionclip_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --use-regionclip
406
-
407
- # DeCap with DenseClip text encoder:
408
- #python decoderTraining.py --out_dir weights_denseclip_segmentation_vitb16_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --denseclip-config denseclip_segmentation_vitb16
409
-
410
-
411
- def main():
412
- parser = argparse.ArgumentParser()
413
- parser.add_argument('--decap_weights', type=str, default=None, help="If set, the DeCap initialization is not random")
414
- parser.add_argument('--clip_model', type=str, default='ViT-B/16', help="CLIP configuration")
415
- parser.add_argument('--use_dinotxt', default=None, action='store_true', help="If set, use the patch-aligned DINO.txt text embeddings")
416
- parser.add_argument('--gaussian_noise', type=float, default=0, help="Standard deviation of the Gaussian noise to apply to the text input")
417
- parser.add_argument('--out_dir', default='./coco_model')
418
- parser.add_argument('--prefix', default='./coco_prefix', help='prefix for saved filenames')
419
- parser.add_argument('--dataset', default='coco', help='coco or cc3m or bookcorpus')
420
- parser.add_argument('--epochs', type=int, default=10)
421
- parser.add_argument('--save_every', type=int, default=1)
422
- parser.add_argument('--prefix_length', type=int, default=1)
423
- parser.add_argument('--prefix_length_clip', type=int, default=1)
424
- parser.add_argument('--bs', type=int, default=64)
425
- parser.add_argument('--talk2dino_weights', type=str, default=None, help="Talk2DINO weights. If None, the training will be performed without Talk2DINO.")
426
- parser.add_argument('--talk2dino_config', type=str, default=None, help="Talk2DINO config. Valid only if the weights are set.")
428
- parser.add_argument('--use_dino_feats', action="store_true", default=False, help="If set, use the pre-extracted DINOv2 features")
429
- parser.add_argument('--im_proj', action="store_true", default=False, help="If set, apply the projection to the input features")
430
- parser.add_argument('--pre_extract_features', action="store_true", default=False, help="If set, the features are extracted during data loading")
430
- parser.add_argument('--only_prefix', dest='only_prefix', action='store_true')
431
- parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer')
432
- parser.add_argument('--num_layers', type=int, default=8)
433
- parser.add_argument('--is_rn', dest='is_rn', action='store_true')
434
- parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true')
435
- parser.add_argument('--local-rank', type=int, default=-1, metavar='N', help='Local process rank.')
436
- parser.add_argument('--not-distributed', type=int, default=False, metavar='N', help='Not Distributed toggle.')
437
- parser.add_argument('--use-open-clip', action='store_true', default=False, help='Use OpenCLIP instead of CLIP')
438
- parser.add_argument('--mix-captions', action='store_true', default=False, help='Mix captions from the same image')
439
- parser.add_argument('--use-regionclip', action='store_true', default=False, help='Use RegionCLIP for text encoding')
440
- parser.add_argument('--regionclip-checkpoint', type=str, default='/raid/datasets/models_weights/regionclip/regionclip_pretrained-cc_rn50x4.pth', help='Path to RegionCLIP checkpoint file')
441
- parser.add_argument('--regionclip-config', type=str, default='pretrain/RegionCLIP_RN50x4.yaml', help='Path to RegionCLIP config file or config name')
442
- parser.add_argument('--denseclip-config', type=str, default=None, help='Path to DenseClip config file or config name')
443
- args = parser.parse_args()
444
-
445
- # Validate RegionCLIP arguments
446
- if args.use_regionclip and args.regionclip_checkpoint is None:
447
- parser.error("--regionclip-checkpoint is required when using --use-regionclip")
448
-
449
- if args.use_regionclip and args.use_open_clip:
450
- parser.error("Cannot use both --use-regionclip and --use-open-clip at the same time")
451
-
452
- # Validate DenseClip arguments
453
- if args.denseclip_config is not None and args.use_regionclip:
454
- parser.error("Cannot use both --denseclip-config and --use-regionclip at the same time")
455
-
456
- if args.denseclip_config is not None and args.use_open_clip:
457
- parser.error("Cannot use both --denseclip-config and --use-open-clip at the same time")
458
-
459
-
460
- train_decoder(args, output_dir=args.out_dir, output_prefix=args.prefix)
461
-
462
-
463
- if __name__ == '__main__':
464
- main()
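
The one-line accuracy expression inside the training loop above is easy to misread. As a minimal standalone sketch (assuming, as the loop does, that padding positions in the target tokens are marked with id 0), it computes the fraction of non-padding tokens whose argmax prediction matches the target:

```python
import torch

def masked_token_accuracy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """logits: (N, vocab_size) flattened predictions; targets: (N,) token ids, 0 = padding."""
    mask = targets > 0                                   # ignore padding positions
    correct = (logits.argmax(dim=1) == targets) & mask   # correct AND non-padding
    return correct.sum() / mask.sum()
```
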
src/decap/decoder_config.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb
3
- size 1744
src/decap/im2txtprojection/im2txtprojection.py DELETED
@@ -1,500 +0,0 @@
1
- from enum import Enum
2
- import numpy as np
3
- import math
4
- import json
5
- import random
6
- import torch
7
- from tqdm import tqdm
8
- import os
9
- import h5py
10
- from typing import Tuple
11
- from dotenv import load_dotenv
12
- from src.dinotxt_utils import get_tokenizer
13
-
14
- load_dotenv()
15
-
16
- class ProjectionType(Enum):
17
- COCO_CAPTIONS = 'coco_captions'
18
- MS_MARCO_QUERIES_A = 'ms_marco_queries_a'
19
- CC3M_BLIP = 'cc3m_blip_captions'
20
- VISUAL_GENOME = 'vg_captions'
21
- VISUAL_GENOME_TEST = "vg_dense_captions_test"
22
- ONLINE_TEXTS = "online_texts"
23
-
24
- class Im2TxtProjector:
25
- """
26
- Im2TxtProjector creates and manages text embedding memory banks for different models:
27
- - Standard CLIP models
28
- - OpenCLIP models
29
- - RegionCLIP models
30
- - DenseClip models
31
- - Talk2DINO projected embeddings
32
-
33
- For RegionCLIP usage, pass regionclip_config as a dict with:
34
- {
35
- 'checkpoint': '/path/to/regionclip_checkpoint.pth',
36
- 'config_name': 'RegionCLIP_RN50.yaml' # optional
37
- }
38
-
39
- For DenseClip usage, pass denseclip_config as a string with the config file name:
40
- 'denseclip_vitb16' # or other valid DenseClip config name
41
- """
42
-
43
- SUPPORT_MEMORY_SIZE = 500000
44
-
45
- __IM2TXT_MEMORY_PATH = os.getenv("IM2TXT_MEMORY_PATH")
46
-
47
- if __IM2TXT_MEMORY_PATH is None:
48
- default_path = "weights/im2txtmemories" #os.path.join(os.path.dirname(__file__), "../../../im2txtmemories")
49
- print(f"[!] Warning: IM2TXT_MEMORY_PATH not set in environment variables, using '{default_path}' [!]")
50
- __IM2TXT_MEMORY_PATH = default_path
51
-
52
- __DECAP_FOLDER = os.path.join(os.path.dirname(__file__), "../")
53
- __TALK2DINO_CONFIG_WEIGHTS_PATH = __DECAP_FOLDER
54
-
55
- captions_dataType = 'train2017'
56
- ANNOTATIONS_CAPTION_FILE_PATH = os.path.join(__DECAP_FOLDER, 'captions_{}.json'.format(captions_dataType))
57
- VG_ANNOTATIONS_DENSE_CAPTIONS_FILE_PATH = '/raid/datasets/densecaptioning-annotations/data/vg/controlcap/vg1.2/train.json'
58
- VG_ANNOTATIONS_DENSE_CAPTIONS_TEST_FILE_PATH = '/raid/datasets/densecaptioning-annotations/data/vg/controlcap/vg1.2/test.json'
59
-
60
- CC3M_BLIP_FILE_PATH = os.path.join(__DECAP_FOLDER, "blipv2_captions.txt")
61
- MS_MARCO_QUERIES_FILE_PATH = '/raid/datasets/MSMarco/queries/queries.train.tsv'
62
-
63
- @staticmethod
64
- def create_regionclip_config(checkpoint_path: str, config_name: str = None):
65
- """
66
- Helper method to create RegionCLIP configuration dictionary.
67
-
68
- Args:
69
- checkpoint_path (str): Path to RegionCLIP checkpoint file
70
- config_name (str, optional): RegionCLIP config name (e.g., 'RegionCLIP_RN50.yaml')
71
-
72
- Returns:
73
- dict: Configuration dictionary for RegionCLIP
74
- """
75
- return {
76
- 'checkpoint': checkpoint_path,
77
- 'config_name': config_name
78
- }
79
-
80
- def __init__(self, type = ProjectionType.COCO_CAPTIONS, verbose : bool = True, device_str = "cpu", use_talk2dino : bool = True,
81
- support_memory_size : int = SUPPORT_MEMORY_SIZE, batch_size=1000,
82
- clip_modelname = None, linear_talk2dino : bool = False,
83
- normalize_memory_embs : bool = False, talk2dino_attn_type='qkv', online_texts=None,
84
- memory_bank_name = None, use_open_clip = False, regionclip_config=None, invite_config=None, denseclip_config=None) -> None:
85
- """
86
- - normalize_memory_embs -> normalizes the embeddings memory (required for projection in CLIP space)
87
- - type : ProjectionType -> the source of the support memory to build. It can be either a ProjectionType value or the path to a file containing the captions.
88
-
89
- """
90
- # check if the hdf5 file already exists, otherwise build the support memory for this type
91
-
92
- #if type not in ProjectionType.mro()
93
-
94
- self.type = type
95
- self.device_str = device_str
96
- self.device = torch.device(self.device_str)
97
- self.use_talk2dino = use_talk2dino
98
- self.linear_talk2dino = linear_talk2dino
99
- self.talk2dino_attn_type = talk2dino_attn_type
100
- self.online_texts = online_texts
101
- self.use_open_clip = use_open_clip
102
- self.regionclip_config = regionclip_config
103
- self.invite_config = invite_config
104
- self.denseclip_config = denseclip_config
105
-
106
- if use_open_clip:
107
- assert use_talk2dino is False, "use_open_clip and use_talk2dino cannot be used together"
108
-
109
- if regionclip_config is not None:
110
- assert use_talk2dino is False, "regionclip_config and use_talk2dino cannot be used together"
111
- assert use_open_clip is False, "regionclip_config and use_open_clip cannot be used together"
112
-
113
- if invite_config is not None:
114
- # overwrite clip_modelname with invite_config['name'] if provided
115
- clip_modelname = invite_config.get('name', clip_modelname)
116
- assert use_talk2dino is False, "invite_config and use_talk2dino cannot be used together"
117
-
118
- if denseclip_config is not None:
119
- assert use_talk2dino is False, "denseclip_config and use_talk2dino cannot be used together"
120
- assert use_open_clip is False, "denseclip_config and use_open_clip cannot be used together"
121
- assert regionclip_config is None, "denseclip_config and regionclip_config cannot be used together"
122
-
123
-
124
- if clip_modelname is None:
125
- if self.use_talk2dino:
126
- clip_modelname = "ViT-B/16"
127
- elif regionclip_config is not None:
128
- # For RegionCLIP, we'll use a generic identifier since the model type is in the config
129
- clip_modelname = "RegionCLIP"
130
- elif denseclip_config is not None:
131
- # For DenseClip, we'll use a generic identifier since the model type is in the config
132
- clip_modelname = "DenseClip"
133
- else:
134
- clip_modelname = "ViT-B/32"
135
- self.clip_modelname = clip_modelname
136
-
137
- self.SUPPORT_MEMORY_SIZE = support_memory_size
138
- if use_talk2dino:
139
- prefix = ""
140
- postfix = '-B16' if use_talk2dino is True else use_talk2dino
141
- if linear_talk2dino:
142
- postfix += "-linear"
143
- elif regionclip_config is not None:
144
- prefix = "regionclip-"
145
- postfix = ""
146
- elif denseclip_config is not None:
147
- prefix = "denseclip-"
148
- postfix = ""
149
- else:
150
- prefix = "clip-"
151
- postfix = ""
152
- if talk2dino_attn_type != 'qkv':
153
- self.talk2dino_attn_type_str = f"_{talk2dino_attn_type}"
154
- else:
155
- self.talk2dino_attn_type_str = ''
156
- if isinstance(type, ProjectionType):
157
- dataset_name = type.value
158
- elif memory_bank_name is not None:
159
- dataset_name = memory_bank_name
160
- else:
161
- dataset_name = 'coco'
162
-
163
- if use_open_clip:
164
- postfix += "-open_clip"
165
- elif regionclip_config is not None:
166
- postfix += "-regionclip"
167
- # Add checkpoint identifier to make filename unique
168
- checkpoint_path = regionclip_config.get('checkpoint', '')
169
- checkpoint_name = os.path.basename(checkpoint_path).replace('.pth', '').replace('.pt', '')
170
- if checkpoint_name:
171
- postfix += f"-{checkpoint_name}"
172
- elif denseclip_config is not None:
173
- postfix += "-denseclip"
174
- # Add config identifier to make filename unique
175
- config_name = os.path.basename(denseclip_config).replace('.yaml', '').replace('.yml', '')
176
- if config_name:
177
- postfix += f"-{config_name}"
178
-
179
- self.H5PY_FILE_PATH = os.path.join( self.__IM2TXT_MEMORY_PATH, prefix + f'{dataset_name}{self.talk2dino_attn_type_str}_text_embeddings{postfix}-{clip_modelname.replace("/", ".")}-{self.SUPPORT_MEMORY_SIZE}.h5' )
180
- self.H5PY_EMBEDDINGS_DATASET_NAME = '{}-embeddings'.format(dataset_name)
181
- self.H5PY_TEXT_DATASET_NAME = '{}-text'.format(dataset_name)
182
-
183
- embs_dataset, text_dataset = self._load_support_memory()
184
-
185
- if text_dataset is None:
186
- if verbose:
187
- model_type = "RegionCLIP" if regionclip_config is not None else ("DenseClip" if denseclip_config is not None else ("OpenCLIP" if use_open_clip else "CLIP"))
188
- print(f"[+] Going to build support memory for the given data type: {type} using {model_type} [+]")
189
- embs_dataset, text_dataset = self._build_support_memory(batch_size)
190
- if verbose: print(f"[+] Done [+]")
191
-
192
- if self.type != ProjectionType.ONLINE_TEXTS:
193
- embs_dataset, text_dataset = self._load_support_memory()
194
-
195
- print(f"[-] loaded memory from {os.path.abspath( self.H5PY_FILE_PATH )} [-]")
196
- if regionclip_config is not None:
197
- print(f"[-] Using RegionCLIP text embeddings from checkpoint: {regionclip_config.get('checkpoint', 'Unknown')} [-]")
198
- elif denseclip_config is not None:
199
- print(f"[-] Using DenseClip text embeddings from config: {denseclip_config} [-]")
200
-
201
- self.text_dataset = text_dataset
202
- self.embs_dataset = torch.tensor(embs_dataset[:]).to(self.device)
203
- self.embs_dataset = self.embs_dataset[self.embs_dataset.norm(dim=-1) != 0]
204
-
205
-
206
- if normalize_memory_embs:
207
- self.embs_dataset /= self.embs_dataset.norm(dim=-1,keepdim=True).float()
208
-
209
-
210
-
211
- def project(self, image_embedding, temperature : float = 0.01, normalize : bool = False, return_argmax_text : bool = False, return_n_best_sims=None) -> torch.TensorType:
212
- if not isinstance(image_embedding, torch.Tensor):
213
- print(f"the type of image_embedding is '{type(image_embedding)}' converting it to torch tensor")
214
- image_embedding = torch.tensor(image_embedding, dtype=torch.float).to(self.device)
215
-
216
- orig_device = image_embedding.device
217
-
218
- if image_embedding.device != self.device:
219
- image_embedding = image_embedding.to(self.device)
220
-
221
- if image_embedding.dtype != float:
222
- #print(f"[-] image_embedding.dtype is {image_embedding.dtype}, converting it to float [-]")
223
- image_embedding = image_embedding.float()
224
-
225
- embs_dataset = self.embs_dataset / self.embs_dataset.norm(dim=-1, keepdim=True)
226
- image_embedding /= image_embedding.norm(dim=-1,keepdim=True)
227
-
228
- sim = image_embedding@embs_dataset.T.float()
229
- if return_argmax_text:
230
- argmax_texts = [self.text_dataset[idx].decode() for idx in sim.argmax(dim=-1)]
231
- if return_n_best_sims:
232
- return argmax_texts, sim.sort(dim=-1, descending=True).values[:, :return_n_best_sims].tolist()
233
- return argmax_texts
234
- softmax_sim = (sim / temperature).softmax(dim=-1)
235
- prefix_embedding = softmax_sim @ self.embs_dataset.float()
236
-
237
- if normalize:
238
- prefix_embedding /= prefix_embedding.norm(dim=-1,keepdim=True)
239
-
240
- if return_n_best_sims:
241
- return prefix_embedding.to(orig_device), sim.sort(dim=-1, descending=True).values[:, :return_n_best_sims].tolist()
242
-
243
- return prefix_embedding.to(orig_device)
244
-
245
- def _load_support_memory(self) -> Tuple[np.ndarray, np.ndarray]:
246
- if self.type == ProjectionType.ONLINE_TEXTS:
247
- print(f"[-] _load_support_memory: support memory for provided texts will be constructed [-]")
248
- return None, None
249
- if not os.path.exists(self.H5PY_FILE_PATH):
250
- print(f"[-] _load_support_memory: the path '{self.H5PY_FILE_PATH}' does not exist [-]")
251
- return None, None
252
-
253
- with h5py.File(self.H5PY_FILE_PATH, 'r') as hf:
254
-
255
- if self.H5PY_EMBEDDINGS_DATASET_NAME in hf:
256
- embeddings_dataset = hf[self.H5PY_EMBEDDINGS_DATASET_NAME][:]
257
- text_dataset = hf[self.H5PY_TEXT_DATASET_NAME][:]
258
- else:
259
- embeddings_dataset = None
260
- text_dataset = None
261
- if 'DINO.txt' in self.clip_modelname:
262
- embeddings_dataset = embeddings_dataset[:, 1024:] # Get patch-aligned text embeddings
263
- return embeddings_dataset, text_dataset
264
-
265
-
266
-
267
- def _build_support_memory(self, batch_size = 1000) -> Tuple[np.ndarray, np.ndarray]:
268
- ## construct the support memory
269
-
270
- self._load_models()
271
-
272
- if self.type == ProjectionType.COCO_CAPTIONS:
273
- from pycocotools.coco import COCO
274
- coco_obj = COCO(Im2TxtProjector.ANNOTATIONS_CAPTION_FILE_PATH)
275
- data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE)
276
- data = [ d['caption'] for d in data ]
277
- elif self.type == ProjectionType.VISUAL_GENOME:
278
- from pycocotools.coco import COCO
279
- coco_obj = COCO(Im2TxtProjector.VG_ANNOTATIONS_DENSE_CAPTIONS_FILE_PATH)
280
- # data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE)
281
- data = list(coco_obj.anns.values())[:self.SUPPORT_MEMORY_SIZE]
282
- data = [ d['caption'] for d in data ]
283
- elif self.type == ProjectionType.VISUAL_GENOME_TEST:
284
- from pycocotools.coco import COCO
285
- coco_obj = COCO(Im2TxtProjector.VG_ANNOTATIONS_DENSE_CAPTIONS_TEST_FILE_PATH)
286
- # data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE)
287
- data = list(coco_obj.anns.values())[:self.SUPPORT_MEMORY_SIZE]
288
- data = [ d['caption'] for d in data ]
289
- elif self.type == ProjectionType.MS_MARCO_QUERIES_A:
290
- print(f"Loading MSMarco queries from file ", Im2TxtProjector.MS_MARCO_QUERIES_FILE_PATH)
291
- with open(Im2TxtProjector.MS_MARCO_QUERIES_FILE_PATH, "r") as input_file:
292
- lines = input_file.readlines()
293
- data = random.sample(lines, k=self.SUPPORT_MEMORY_SIZE)
294
- data = [ d.split("\t")[1].replace("\n", "") for d in data ]
295
- print(f"Loaded from file '{self.SUPPORT_MEMORY_SIZE}' lines, example of line: '{data[0]}'")
296
- elif self.type == ProjectionType.CC3M_BLIP:
297
- print(f"Loading cc3m captions txt file ", Im2TxtProjector.CC3M_BLIP_FILE_PATH)
298
- with open(Im2TxtProjector.CC3M_BLIP_FILE_PATH, "r") as input_file:
299
- lines = input_file.readlines()
300
- data = random.sample(lines, k=self.SUPPORT_MEMORY_SIZE)
301
- data = [ d.replace("\n", "") for d in data ]
302
- print(f"Loaded from file '{len(data)}' lines, example of line: '{data[0]}'")
310
- elif self.type == ProjectionType.ONLINE_TEXTS:
311
- data = self.online_texts
312
- print(f"Loaded online_texts '{len(data)}' lines, example of line: '{data[0]}'")
313
- elif type(self.type) == str:
314
- if os.path.exists(self.type):
315
- path = self.type
316
- from pycocotools.coco import COCO
317
- coco_obj = COCO(path)
318
- data = random.sample(list(coco_obj.anns.values()), k=min(self.SUPPORT_MEMORY_SIZE, len(coco_obj.anns)))
319
- data = [ d['caption'] for d in data ]
320
- else:
321
- #data = random.sample(data,500000)
322
- print(f"[!] Unimplemented data type '{self.type}'[!]")
323
- return None, None
324
-
325
- text_features = []
326
- captions = []
327
-
328
- self.clip_model.eval()
329
-
330
- n_txts = len(data)
331
- n_batch = math.ceil(n_txts / batch_size)
332
- for i in tqdm(range(n_batch)):
333
- start = i * batch_size
334
- end = start + batch_size if i < n_batch - 1 else n_txts
335
-
336
- texts = data[start:end]
337
- with torch.no_grad():
338
- texts_token = self.tokenizer(texts).to(self.device)
339
- text_feature = self.clip_model.encode_text(texts_token)
340
- if self.use_talk2dino:
341
- text_feature = self.talk2dino.project_clip_txt(text_feature)
342
- text_features.append(text_feature)
343
- captions.extend(texts)
344
-
345
- text_features = torch.cat(text_features,dim=0)
346
-
347
-
348
- #text_features /= text_features.norm(dim=-1,keepdim=True).float()
349
-
350
- # store captions and text features in hdf5 dataset
351
-
352
- text_features_ndarray = text_features.cpu().numpy()
353
-
354
- assert len(text_features_ndarray) == len(captions), f"len(text_features_ndarray) = {len(text_features_ndarray)} != len(captions) = {len(captions)}"
355
-
356
- #if not os.path.exists(self.H5PY_FILE_PATH):
357
- # print(f"os.path '{self.H5PY_FILE_PATH}' does not exists")
358
-
359
- EMBEDDINGS_DIMENSION = text_features_ndarray.shape[1]
360
-
361
- if self.type != ProjectionType.ONLINE_TEXTS:
362
- with h5py.File(self.H5PY_FILE_PATH, 'w') as hf:
363
-
364
- if self.H5PY_EMBEDDINGS_DATASET_NAME in hf:
365
- embeddings_dataset = hf[self.H5PY_EMBEDDINGS_DATASET_NAME]
366
- text_dataset = hf[self.H5PY_TEXT_DATASET_NAME]
367
- print(f"[!] Dataset '{self.H5PY_EMBEDDINGS_DATASET_NAME}' already exists! Going to overwrite [!]")
368
- else:
369
- embeddings_dataset = hf.create_dataset(self.H5PY_EMBEDDINGS_DATASET_NAME, shape=(self.SUPPORT_MEMORY_SIZE, EMBEDDINGS_DIMENSION), dtype='float32')
370
- text_dataset = hf.create_dataset(self.H5PY_TEXT_DATASET_NAME, shape=(self.SUPPORT_MEMORY_SIZE, ), dtype=h5py.string_dtype(encoding='utf-8')) #, dtype='str'
371
-
372
- for num_row in range(len(text_features_ndarray)):
373
- embeddings_dataset[num_row] = text_features_ndarray[num_row]
374
- text_dataset[num_row] = captions[num_row]
375
- else:
376
- embeddings_dataset = text_features_ndarray
377
- text_dataset = [x.encode() for x in captions]
378
-
379
- return embeddings_dataset, text_dataset
380
-
381
- clip_model = None
382
- def _load_models(self):
383
-
384
- if self.clip_model is not None:
385
- # case already done
386
- return
387
-
388
- if self.use_open_clip:
389
- print("[-] loading open_clip model [-]")
390
- assert self.clip_modelname is not None, "clip_modelname must be provided when using open_clip"
391
- from open_clip import create_model_and_transforms, tokenize
392
- self.clip_model, preprocess_train, preprocess_val = create_model_and_transforms(self.clip_modelname, pretrained="laion2b_s32b_b79k", device=self.device)
393
- self.preprocess = preprocess_train
394
- self.tokenizer = tokenize
395
- return
396
-
397
- if self.regionclip_config is not None:
398
- print("[-] loading RegionCLIP model [-]")
399
- from src.regionclip.loader import load_regionclip_from_checkpoint
400
- from src.regionclip.datasets.clip_prompt_utils import tokenize as regionclip_tokenize
401
-
402
- regionclip_checkpoint = self.regionclip_config.get('checkpoint', None)
403
- if regionclip_checkpoint is None:
404
- raise ValueError("RegionCLIP checkpoint not specified in the configuration")
405
- regionclip_config_name = self.regionclip_config.get('config_name', None)
406
-
407
- print(f"[-] Loading RegionCLIP from checkpoint: {regionclip_checkpoint} [-]")
408
- if regionclip_config_name:
409
- print(f"[-] Using RegionCLIP config: {regionclip_config_name} [-]")
410
-
411
- self.clip_model = load_regionclip_from_checkpoint(
412
- regionclip_checkpoint,
413
- device=self.device,
414
- config=regionclip_config_name
415
- )
416
- self.tokenizer = regionclip_tokenize
417
- self.preprocess = None # RegionCLIP doesn't need preprocessing for text encoding
418
-
419
- # Test RegionCLIP text encoding to ensure it works
420
- try:
421
- test_text = ["A test sentence for RegionCLIP"]
422
- test_tokens = self.tokenizer(test_text)
423
- test_features = self.clip_model.encode_text(test_tokens.to(self.device))
424
- print(f"[-] RegionCLIP text encoding test successful. Output shape: {test_features.shape} [-]")
425
- except Exception as e:
426
- print(f"[!] Warning: RegionCLIP text encoding test failed: {e} [!]")
427
- raise e
428
-
429
- return
430
-
431
- if self.denseclip_config is not None:
432
- print("[-] loading DenseClip model [-]")
433
- from src.denseclip.loader import load_denseclip, DenseCLIP_tokenize
434
-
435
- print(f"[-] Loading DenseClip from config: {self.denseclip_config} [-]")
436
-
437
- # Load DenseClip model
438
- self.clip_model = load_denseclip(
439
- config_name=self.denseclip_config,
440
- device=self.device
441
- )
442
-
443
- # DenseClip should have encode_text method and a tokenizer
444
- # We need to check if DenseClip has a tokenizer method
445
- if DenseCLIP_tokenize is not None:
446
- self.tokenizer = DenseCLIP_tokenize
447
- else:
448
- # Fallback to CLIP tokenizer if DenseClip doesn't provide one
449
- import clip
450
- self.tokenizer = clip.tokenize
451
- print("[!] Warning: DenseClip model doesn't have tokenizer, using CLIP tokenizer [!]")
452
-
453
- self.preprocess = None # DenseClip doesn't need preprocessing for text encoding
454
-
455
- # Test DenseClip text encoding to ensure it works
456
- try:
457
- test_text = ["A test sentence for DenseClip"]
458
- test_tokens = self.tokenizer(test_text)
459
- if hasattr(test_tokens, 'to'):
460
- test_tokens = test_tokens.to(self.device)
461
- test_features = self.clip_model.encode_text(test_tokens)
462
- print(f"[-] DenseClip text encoding test successful. Output shape: {test_features.shape} [-]")
463
- except Exception as e:
464
- print(f"[!] Warning: DenseClip text encoding test failed: {e} [!]")
465
- raise e
466
-
467
- return
468
-
469
- import clip
470
- if self.clip_modelname is None:
471
- clip_model_name = "ViT-B/16" if self.use_talk2dino else "ViT-B/32"
472
- else:
473
- clip_model_name = self.clip_modelname
474
- if 'DINO.txt' not in clip_model_name:
475
- self.clip_model, self.preprocess = clip.load(clip_model_name, device=self.device, jit=False)
476
- self.tokenizer = clip.tokenize
477
- if self.use_talk2dino:
478
- # loading Talk2DINO
479
- if type(self.use_talk2dino) == str:
480
- proj_name = self.use_talk2dino
481
- elif self.linear_talk2dino is False:
482
- proj_name = 'vitb_mlp_infonce'
483
- else:
484
- proj_name = 'vitb_linear_infonce'
485
-
486
-
487
- config = os.path.join(self.__TALK2DINO_CONFIG_WEIGHTS_PATH, "configs_talk2dino", proj_name + '.yaml')
488
- weights = os.path.join(self.__TALK2DINO_CONFIG_WEIGHTS_PATH, "weights_talk2dino", proj_name + self.talk2dino_attn_type_str + '.pth')
489
- #import sys
490
- #import os
491
- #add_path = os.path.abspath( os.path.dirname("../"))
492
- ##print(add_path)
493
- #sys.path.insert(1, add_path )
494
- from src.model import ProjectionLayer
495
- self.talk2dino = ProjectionLayer.from_config(config)
496
- self.talk2dino.load_state_dict(torch.load((weights), self.device))
497
- self.talk2dino.to(self.device)
498
- else:
499
- self.clip_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l').to(self.device)
500
- self.tokenizer = get_tokenizer().tokenize
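
For reference, the default path of `Im2TxtProjector.project` above is a temperature-sharpened attention over the text memory bank. A minimal sketch of that projection (assuming a pre-built memory tensor `memory` of text embeddings and the default temperature of 0.01):

```python
import torch
import torch.nn.functional as F

def project_to_text_memory(x: torch.Tensor, memory: torch.Tensor, temperature: float = 0.01) -> torch.Tensor:
    """x: (N, D) image/patch embeddings; memory: (K, D) text embeddings from the support memory."""
    x = F.normalize(x.float(), dim=-1)
    sim = x @ F.normalize(memory.float(), dim=-1).T    # cosine similarities, shape (N, K)
    weights = (sim / temperature).softmax(dim=-1)      # sharp distribution over memory entries
    return weights @ memory.float()                    # weighted combination of text embeddings
```

With a low temperature the weights concentrate on the few most similar captions, which is also why `return_argmax_text` above can be used to read back the nearest caption directly.
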
src/denseclip/clip_loader/README.md DELETED
@@ -1,233 +0,0 @@
1
- # DenseCLIP to CLIP Loader
2
-
3
- A simple interface for loading DenseCLIP checkpoints as CLIP-like models for text and image encoding.
4
-
5
- ## Overview
6
-
7
- This module provides a clean API to load DenseCLIP models and use them like standard CLIP models for encoding text and images. It abstracts away the complexity of DenseCLIP's detection/segmentation components and exposes only the core vision-language encoding functionality.
8
-
9
- ## Features
10
-
11
- - ✅ **Simple API**: Load DenseCLIP models with a single function call
12
- - ✅ **CLIP-like Interface**: Familiar `encode_text()` and `encode_image()` methods
13
- - ✅ **Flexible Configuration**: YAML-based configuration system
14
- - ✅ **Multiple Input Types**: Support for PIL Images, image tensors, strings, and text lists
15
- - ✅ **Automatic Preprocessing**: Built-in image preprocessing pipeline
16
- - ✅ **Device Management**: Automatic GPU/CPU detection and placement
17
-
18
- ## Quick Start
19
-
20
- ```python
21
- from clip_loader import load_clip
22
-
23
- # Load DenseCLIP model with default configuration
24
- model = load_clip('denseclip_segmentation_vitb16')
25
-
26
- # Encode text
27
- texts = ["a photo of a cat", "a photo of a dog"]
28
- text_features = model.encode_text(texts)
29
-
30
- # Encode images (PIL Images)
31
- from PIL import Image
32
- images = [Image.open("cat.jpg"), Image.open("dog.jpg")]
33
- image_features = model.encode_image(images)
34
-
35
- # Compute similarities
36
- similarities = model.compute_similarity(image_features, text_features)
37
- print(f"Image-text similarities: {similarities}")
38
- ```
39
-
40
- ## Configuration
41
-
42
- Models are configured using YAML files in the `configs/` directory. The main configuration for DenseCLIP ViT-B/16 is in `configs/denseclip_vitb16.yaml`.
43
-
44
- ### Configuration Structure
45
-
46
- ```yaml
47
- model:
48
- name: "denseclip_vitb16"
49
- type: "vit"
50
-
51
- vision:
52
- image_resolution: 224
53
- vision_layers: 12
54
- vision_width: 768
55
- vision_patch_size: 16
56
- embed_dim: 512
57
-
58
- text:
59
- context_length: 13 # DenseCLIP uses shorter context
60
- vocab_size: 49408
61
- transformer_width: 512
62
- transformer_heads: 8
63
- transformer_layers: 12
64
- embed_dim: 512
65
-
66
- checkpoint:
67
- path: "/path/to/denseclip/checkpoint.pth"
68
- format: "denseclip"
69
-
70
- preprocessing:
71
- image_mean: [0.48145466, 0.4578275, 0.40821073]
72
- image_std: [0.26862954, 0.26130258, 0.27577711]
73
- normalize: true
74
- ```
75
-
76
- ## API Reference
77
-
78
- ### Core Functions
79
-
80
- #### `load_clip(config_name, checkpoint_path=None, device='auto')`
81
-
82
- Load a DenseCLIP model with the specified configuration.
83
-
84
- **Parameters:**
85
- - `config_name` (str): Name of config file (without .yaml extension)
86
- - `checkpoint_path` (str, optional): Path to checkpoint file (overrides config)
87
- - `device` (str): Device to load on ('auto', 'cpu', 'cuda')
88
-
89
- **Returns:**
90
- - `DenseCLIPModel`: Loaded model ready for inference
91
-
92
- #### `load_denseclip_model(config_path, checkpoint_path=None, device='auto')`
93
-
94
- Load a DenseCLIP model from configuration file path.
95
-
96
- ### DenseCLIPModel Methods
97
-
98
- #### `encode_text(texts)`
99
-
100
- Encode text into feature vectors.
101
-
102
- **Parameters:**
103
- - `texts` (str or List[str]): Text string(s) to encode
104
-
105
- **Returns:**
106
- - `torch.Tensor`: Normalized text features [batch_size, embed_dim]
107
-
108
- #### `encode_image(images)`
109
-
110
- Encode images into feature vectors.
111
-
112
- **Parameters:**
113
- - `images`: PIL Image, List[PIL.Image], or preprocessed tensor
114
-
115
- **Returns:**
116
- - `torch.Tensor`: Normalized image features [batch_size, embed_dim]
117
-
118
- #### `compute_similarity(image_features, text_features, temperature=1.0)`
119
-
120
- Compute similarity between image and text features.
121
-
122
- **Parameters:**
123
- - `image_features` (torch.Tensor): Image features [N, embed_dim]
124
- - `text_features` (torch.Tensor): Text features [M, embed_dim]
125
- - `temperature` (float): Temperature scaling factor
126
-
127
- **Returns:**
128
- - `torch.Tensor`: Similarity matrix [N, M]
129
-
130
- ## Examples
131
-
132
- ### Basic Text-Image Retrieval
133
-
134
- ```python
135
- from clip_loader import load_clip
136
- from PIL import Image
137
-
138
- # Load model
139
- model = load_clip('denseclip_vitb16')
140
-
141
- # Load and encode images
142
- images = [
143
- Image.open("cat.jpg"),
144
- Image.open("dog.jpg"),
145
- Image.open("car.jpg")
146
- ]
147
- image_features = model.encode_image(images)
148
-
149
- # Encode text queries
150
- queries = [
151
- "a cute cat",
152
- "a happy dog",
153
- "a red car"
154
- ]
155
- text_features = model.encode_text(queries)
156
-
157
- # Find best matches
158
- similarities = model.compute_similarity(image_features, text_features)
159
- best_matches = similarities.argmax(dim=0)  # index of the best image for each text query
160
-
161
- for i, query in enumerate(queries):
162
- best_image_idx = best_matches[i]
163
- score = similarities[best_image_idx, i].item()
164
- print(f"Query '{query}' -> Image {best_image_idx} (score: {score:.3f})")
165
- ```
166
-
167
- ### Zero-Shot Classification
168
-
169
- ```python
170
- from clip_loader import load_clip
171
- from PIL import Image
172
-
173
- model = load_clip('denseclip_vitb16')
174
-
175
- # Load test image
176
- image = Image.open("test_image.jpg")
177
- image_features = model.encode_image(image)
178
-
179
- # Define class labels
180
- class_labels = [
181
- "a photo of a cat",
182
- "a photo of a dog",
183
- "a photo of a bird",
184
- "a photo of a car",
185
- "a photo of a house"
186
- ]
187
-
188
- # Encode labels
189
- text_features = model.encode_text(class_labels)
190
-
191
- # Classify
192
- similarities = model.compute_similarity(image_features, text_features)
193
- probabilities = similarities.softmax(dim=-1)
194
-
195
- # Show results
196
- for i, label in enumerate(class_labels):
197
- prob = probabilities[0, i].item()
198
- print(f"{label}: {prob:.3f}")
199
- ```
200
-
201
- ### Custom Configuration
202
-
203
- ```python
204
- from clip_loader import load_denseclip_model
205
-
206
- # Load with custom config
207
- model = load_denseclip_model(
208
- config_path='configs/custom_config.yaml',
209
- checkpoint_path='/path/to/custom/checkpoint.pth',
210
- device='cuda:1'
211
- )
212
- ```
213
-
214
- ## Requirements
215
-
216
- - PyTorch >= 1.9.0
217
- - torchvision >= 0.10.0
218
- - Pillow >= 8.0.0
219
- - PyYAML >= 5.4.0
220
-
221
- ## Notes
222
-
223
- - DenseCLIP uses a shorter text context length (13) compared to standard CLIP (77)
224
- - The model preserves ~98% similarity with original CLIP text representations
225
- - Image preprocessing follows CLIP's standard normalization
226
- - All features are L2-normalized for cosine similarity computation
227
-
228
- ## Supported Models
229
-
230
- Currently supported:
231
- - `denseclip_vitb16`: DenseCLIP with ViT-B/16 backbone
232
-
233
- To add support for other DenseCLIP variants, create new configuration files in the `configs/` directory.
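
One note on the `temperature` argument of `compute_similarity`, which the examples above leave at its default: dividing the similarities by a temperature before the softmax controls how peaked the resulting distribution is. A small sketch, reusing the `model`, `image_features`, and `text_features` variables from the zero-shot classification example:

```python
# Lower temperature -> sharper (more confident) class probabilities.
logits = model.compute_similarity(image_features, text_features, temperature=0.01)
probabilities = logits.softmax(dim=-1)
print(probabilities[0])
```
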
src/denseclip/clip_loader/SUMMARY.md DELETED
@@ -1,78 +0,0 @@
1
- # DenseCLIP to CLIP Loader - Quick Start
2
-
3
- ## ✅ Successfully Created!
4
-
5
- The `clip_loader` module provides a simple interface to load DenseCLIP checkpoints as CLIP-like models for text and image encoding.
6
-
7
- ## Structure
8
-
9
- ```
10
- /raid/homes/giacomo.pacini/DenseCLIP/clip_loader/
11
- ├── __init__.py # Module initialization
12
- ├── denseclip_loader.py # Main loader implementation
13
- ├── example_usage.py # Example script
14
- ├── requirements.txt # Dependencies
15
- ├── README.md # Full documentation
16
- └── configs/
17
-     └── denseclip_vitb16.yaml # Configuration for ViT-B/16 model
18
- ```
19
-
20
- ## 🚀 Quick Usage
21
-
22
- ```python
23
- from clip_loader import load_clip
24
-
25
- # Load DenseCLIP model with default configuration
26
- model = load_clip('denseclip_vitb16')
27
-
28
- # Encode text
29
- texts = ["a photo of a cat", "a photo of a dog"]
30
- text_features = model.encode_text(texts) # Shape: [2, 512]
31
-
32
- # Encode images (if you have PIL Images)
33
- # image_features = model.encode_image(images)
34
-
35
- # Compute similarities
36
- similarities = model.compute_similarity(text_features, text_features)
37
- print(f"Cat-Dog similarity: {similarities[0, 1]:.3f}")
38
- ```
39
-
40
- ## ✅ Test Results
41
-
42
- - **✅ Model loads successfully** from DenseCLIP checkpoint
43
- - **✅ Text encoding works** (shape: [batch_size, 512])
44
- - **✅ Features are normalized** (L2 norm = 1.0)
45
- - **✅ Similarities make sense** (Cat-Dog: 0.872, Car-Person: lower)
46
- - **✅ Zero-shot classification** shows logical patterns
47
- - **✅ Model has 157M parameters** (94M vision + 63M text)
48
-
49
- ## 🔧 Key Features
50
-
51
- - **Simple API**: Just call `load_clip()` and start encoding
52
- - **Handles DenseCLIP specifics**: Automatically extracts weights from segmentation checkpoint
53
- - **CLIP-compatible**: Same interface as OpenAI CLIP
54
- - **Flexible configuration**: YAML-based configuration system
55
- - **GPU ready**: Automatic device detection and placement
56
- - **Context length**: Uses DenseCLIP's shorter context (13 vs 77)
57
-
58
- ## 🎯 Use Cases
59
-
60
- 1. **Text-Image Retrieval**: Encode both and compute similarities
61
- 2. **Zero-Shot Classification**: Encode class descriptions and compare
62
- 3. **Text Similarity**: Compare text representations
63
- 4. **Feature Extraction**: Get dense vector representations
64
-
65
- ## Configuration
66
-
67
- The model uses `/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP.pth` by default. You can override this by modifying `configs/denseclip_vitb16.yaml` or passing a custom checkpoint path.
68
-
69
- ## What's Different from Standard CLIP
70
-
71
- - **Shorter context length**: 13 tokens vs 77
72
- - **Higher image resolution**: 640px vs 224px
73
- - **Fine-tuned weights**: Adapted for dense prediction tasks
74
- - **High text similarity**: ~98% similarity with original CLIP representations
75
-
76
- ## 🎉 Ready to Use!
77
-
78
- The loader is fully functional and ready for use in your projects. See `README.md` for detailed documentation and more examples.
src/denseclip/clip_loader/__init__.py DELETED
@@ -1,21 +0,0 @@
1
- """
2
- DenseCLIP to CLIP Loader
3
-
4
- A simple interface for loading DenseCLIP checkpoints as CLIP-like models
5
- for text and image encoding.
6
- """
7
-
8
- from .denseclip_loader import (
9
- DenseCLIPModel,
10
- load_denseclip_model,
11
- load_clip,
12
- load_config
13
- )
14
-
15
- __version__ = "1.0.0"
16
- __all__ = [
17
- "DenseCLIPModel",
18
- "load_denseclip_model",
19
- "load_clip",
20
- "load_config"
21
- ]
src/denseclip/clip_loader/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
- size 1356917
src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16.yaml DELETED
@@ -1,41 +0,0 @@
1
- # DenseCLIP ViT-B/16 Configuration
2
- # Configuration for loading DenseCLIP checkpoint as a CLIP-like model
3
-
4
- model:
5
- name: "denseclip_vitb16"
6
- type: "vit" # vision transformer
7
-
8
- # Vision encoder configuration
9
- vision:
10
- image_resolution: 640
11
- vision_layers: 12
12
- vision_width: 768
13
- vision_patch_size: 16
14
- embed_dim: 512
15
-
16
- # Text encoder configuration
17
- text:
18
- context_length: 13 # DenseCLIP uses shorter context
19
- vocab_size: 49408
20
- transformer_width: 512
21
- transformer_heads: 8
22
- transformer_layers: 12
23
- embed_dim: 512
24
-
25
- # Checkpoint information
26
- checkpoint:
27
- path: "/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP.pth"
28
- format: "denseclip" # vs "openai_clip"
29
-
30
- # Processing configuration
31
- preprocessing:
32
- image_mean: [0.48145466, 0.4578275, 0.40821073]
33
- image_std: [0.26862954, 0.26130258, 0.27577711]
34
- normalize: true
35
-
36
- # Optional overrides
37
- overrides:
38
- # Set to true to use OpenAI CLIP tokenizer instead of DenseCLIP's
39
- use_openai_tokenizer: false
40
- # Set custom context length (will resize positional embeddings if needed)
41
- custom_context_length: null
src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16_long_ctx.yaml DELETED
@@ -1,41 +0,0 @@
1
- # DenseCLIP ViT-B/16 Configuration
2
- # Configuration for loading DenseCLIP checkpoint as a CLIP-like model
3
-
4
- model:
5
- name: "denseclip_vitb16"
6
- type: "vit" # vision transformer
7
-
8
- # Vision encoder configuration
9
- vision:
10
- image_resolution: 640
11
- vision_layers: 12
12
- vision_width: 768
13
- vision_patch_size: 16
14
- embed_dim: 512
15
-
16
- # Text encoder configuration
17
- text:
18
- context_length: 77 # standard CLIP context length (long-context variant)
19
- vocab_size: 49408
20
- transformer_width: 512
21
- transformer_heads: 8
22
- transformer_layers: 12
23
- embed_dim: 512
24
-
25
- # Checkpoint information
26
- checkpoint:
27
- path: "/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP_long_ctx.pth"
28
- format: "denseclip" # vs "openai_clip"
29
-
30
- # Processing configuration
31
- preprocessing:
32
- image_mean: [0.48145466, 0.4578275, 0.40821073]
33
- image_std: [0.26862954, 0.26130258, 0.27577711]
34
- normalize: true
35
-
36
- # Optional overrides
37
- overrides:
38
- # Set to true to use OpenAI CLIP tokenizer instead of DenseCLIP's
39
- use_openai_tokenizer: false
40
- # Set custom context length (will resize positional embeddings if needed)
41
- custom_context_length: null
src/denseclip/clip_loader/denseclip_loader.py DELETED
@@ -1,316 +0,0 @@
1
- #!/usr/bin/env python3
2
- """
3
- DenseCLIP to CLIP Loader
4
-
5
- A simple interface for loading DenseCLIP checkpoints as CLIP-like models
6
- for text and image encoding.
7
- """
8
-
9
- import os
10
- import sys
11
- import yaml
12
- import torch
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
- from typing import Union, List, Tuple, Optional, Dict, Any
16
- from PIL import Image
17
- import torchvision.transforms as transforms
18
-
19
- # Import local model components
20
- try:
21
- from .models import CLIPVisionTransformer, CLIPTextEncoder, ResidualAttentionBlock, LayerNorm, QuickGELU
22
- from .tokenizer import tokenize
23
- except ImportError:
24
- # Fallback for direct execution
25
- from models import CLIPVisionTransformer, CLIPTextEncoder, ResidualAttentionBlock, LayerNorm, QuickGELU
26
- from tokenizer import tokenize
27
-
28
-
29
- class DenseCLIPModel(nn.Module):
30
- """
31
- A CLIP-like model loaded from DenseCLIP checkpoints.
32
- Provides simple text and image encoding functionality.
33
- """
34
-
35
- def __init__(self, config: Dict[str, Any]):
36
- super().__init__()
37
-
38
- self.config = config
39
-
40
- # Initialize vision encoder
41
- vision_config = config['model']['vision']
42
- self.visual = CLIPVisionTransformer(
43
- input_resolution=vision_config['image_resolution'],
44
- patch_size=vision_config['vision_patch_size'],
45
- width=vision_config['vision_width'],
46
- layers=vision_config['vision_layers'],
47
- heads=vision_config['vision_width'] // 64,
48
- output_dim=vision_config['embed_dim']
49
- )
50
-
51
- # Initialize text encoder
52
- text_config = config['model']['text']
53
- self.text_encoder = CLIPTextEncoder(
54
- context_length=text_config['context_length'],
55
- vocab_size=text_config['vocab_size'],
56
- transformer_width=text_config['transformer_width'],
57
- transformer_heads=text_config['transformer_heads'],
58
- transformer_layers=text_config['transformer_layers'],
59
- embed_dim=text_config['embed_dim']
60
- )
61
-
62
- # Store configuration for preprocessing
63
- self.context_length = text_config['context_length']
64
- self.image_resolution = vision_config['image_resolution']
65
-
66
- # Initialize preprocessing
67
- self._setup_preprocessing()
68
-
69
- def _setup_preprocessing(self):
70
- """Setup image preprocessing pipeline"""
71
- preprocess_config = self.config['preprocessing']
72
-
73
- self.preprocess = transforms.Compose([
74
- transforms.Resize(self.image_resolution, interpolation=transforms.InterpolationMode.BICUBIC),
75
- transforms.CenterCrop(self.image_resolution),
76
- transforms.ToTensor(),
77
- transforms.Normalize(
78
- mean=preprocess_config['image_mean'],
79
- std=preprocess_config['image_std']
80
- )
81
- ])
82
-
83
- def encode_image(self, images: Union[torch.Tensor, List[Image.Image], Image.Image]) -> torch.Tensor:
84
- """
85
- Encode images into feature vectors
86
-
87
- Args:
88
- images: PIL Images, list of PIL Images, or preprocessed tensor
89
-
90
- Returns:
91
- Normalized image features [batch_size, embed_dim]
92
- """
93
- if isinstance(images, (list, tuple)):
94
- # List of PIL Images
95
- image_tensors = torch.stack([self.preprocess(img) for img in images])
96
- elif isinstance(images, Image.Image):
97
- # Single PIL Image
98
- image_tensors = self.preprocess(images).unsqueeze(0)
99
- elif isinstance(images, torch.Tensor):
100
- # Already preprocessed tensor
101
- image_tensors = images
102
- else:
103
- raise ValueError(f"Unsupported image type: {type(images)}")
104
-
105
- # Move to same device as model
106
- device = next(self.parameters()).device
107
- image_tensors = image_tensors.to(device)
108
-
109
- # Encode
110
- with torch.no_grad():
111
- image_features = self.visual(image_tensors)
112
- image_features = F.normalize(image_features, dim=-1)
113
-
114
- return image_features
115
-
116
- def encode_text(self, texts: Union[str, List[str]]) -> torch.Tensor:
117
- """
118
- Encode texts into feature vectors
119
-
120
- Args:
121
- texts: Single text string or list of text strings
122
-
123
- Returns:
124
- Normalized text features [batch_size, embed_dim]
125
- """
126
- if isinstance(texts, str):
127
- texts = [texts]
128
-
129
- # Tokenize if necessary
130
- if isinstance(texts, list):
131
- tokens = tokenize(texts, context_length=self.context_length)
132
- elif isinstance(texts, torch.Tensor):
133
- if texts.dim() == 1:
134
- # Single tokenized text
135
- tokens = texts.unsqueeze(0)
136
- else:
137
- tokens = texts
138
- else:
139
- raise ValueError(f"Unsupported text type: {type(texts)}")
140
- # Move to same device as model
141
- device = next(self.parameters()).device
142
- tokens = tokens.to(device)
143
-
144
- # Encode
145
- with torch.no_grad():
146
- text_features = self.text_encoder(tokens)
147
- text_features = F.normalize(text_features, dim=-1)
148
-
149
- return text_features
150
-
151
- def compute_similarity(self,
152
- image_features: torch.Tensor,
153
- text_features: torch.Tensor,
154
- temperature: float = 1.0) -> torch.Tensor:
155
- """
156
- Compute similarity between image and text features
157
-
158
- Args:
159
- image_features: Normalized image features [N, embed_dim]
160
- text_features: Normalized text features [M, embed_dim]
161
- temperature: Temperature for scaling similarities
162
-
163
- Returns:
164
- Similarity matrix [N, M]
165
- """
166
- return (image_features @ text_features.t()) / temperature
167
-
168
- def forward(self, images: torch.Tensor, texts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
169
- """
170
- Forward pass for both image and text encoding
171
-
172
- Args:
173
- images: Preprocessed image tensor [batch_size, 3, H, W]
174
- texts: Tokenized text tensor [batch_size, context_length]
175
-
176
- Returns:
177
- Tuple of (image_features, text_features)
178
- """
179
- image_features = self.visual(images)
180
- text_features = self.text_encoder(texts)
181
-
182
- # Normalize features
183
- image_features = F.normalize(image_features, dim=-1)
184
- text_features = F.normalize(text_features, dim=-1)
185
-
186
- return image_features, text_features
187
-
188
-
189
- def load_config(config_path: str) -> Dict[str, Any]:
190
- """Load configuration from YAML file"""
191
- with open(config_path, 'r') as f:
192
- config = yaml.safe_load(f)
193
- return config
194
-
195
-
196
- def load_denseclip_weights(checkpoint_path: str) -> Dict[str, torch.Tensor]:
197
- """Load DenseCLIP checkpoint and extract relevant weights"""
198
- print(f"Loading DenseCLIP checkpoint from: {checkpoint_path}")
199
- checkpoint = torch.load(checkpoint_path, map_location='cpu')
200
-
201
- if 'state_dict' not in checkpoint:
202
- raise ValueError("Checkpoint doesn't contain 'state_dict'")
203
-
204
- state_dict = checkpoint['state_dict']
205
-
206
- # Extract vision and text encoder weights
207
- vision_weights = {}
208
- text_weights = {}
209
-
210
- for key, value in state_dict.items():
211
- if key.startswith('backbone.'):
212
- # Remove 'backbone.' prefix for vision encoder
213
- new_key = key[len('backbone.'):]
214
- vision_weights[new_key] = value
215
- elif key.startswith('text_encoder.'):
216
- # Remove 'text_encoder.' prefix
217
- new_key = key[len('text_encoder.'):]
218
- text_weights[new_key] = value
219
-
220
- print(f"Extracted {len(vision_weights)} vision parameters")
221
- print(f"Extracted {len(text_weights)} text parameters")
222
-
223
- return {
224
- 'vision': vision_weights,
225
- 'text': text_weights,
226
- 'full_state_dict': state_dict
227
- }
228
-
229
-
230
- def load_denseclip_model(config_path: str,
231
- checkpoint_path: Optional[str] = None,
232
- device: str = 'auto') -> DenseCLIPModel:
233
- """
234
- Load a DenseCLIP model from configuration and checkpoint
235
-
236
- Args:
237
- config_path: Path to YAML configuration file
238
- checkpoint_path: Optional path to checkpoint (overrides config)
239
- device: Device to load model on ('auto', 'cpu', 'cuda')
240
-
241
- Returns:
242
- Loaded DenseCLIPModel ready for inference
243
- """
244
- # Load configuration
245
- config = load_config(config_path)
246
-
247
- # Override checkpoint path if provided
248
- if checkpoint_path is not None:
249
- config['checkpoint']['path'] = checkpoint_path
250
-
251
- # Create model
252
- model = DenseCLIPModel(config)
253
-
254
- # Load weights
255
- checkpoint_path = config['checkpoint']['path']
256
- if os.path.exists(checkpoint_path):
257
- weights = load_denseclip_weights(checkpoint_path)
258
-
259
- # Load vision encoder weights
260
- if weights['vision']:
261
- missing_v, unexpected_v = model.visual.load_state_dict(weights['vision'], strict=False)
262
- if missing_v:
263
- print(f"Missing vision keys: {len(missing_v)} (expected for FPN/post-norm components)")
264
- if unexpected_v:
265
- # Filter out expected mismatches
266
- important_unexpected = [k for k in unexpected_v if not any(x in k for x in ['fpn', 'ln_post', 'proj'])]
267
- if important_unexpected:
268
- print(f"Unexpected vision keys: {important_unexpected}")
269
- else:
270
- print(f"โœ“ Vision weights loaded (ignoring {len(unexpected_v)} FPN/post-norm parameters)")
271
-
272
- # Load text encoder weights
273
- if weights['text']:
274
- missing_t, unexpected_t = model.text_encoder.load_state_dict(weights['text'], strict=False)
275
- if missing_t:
276
- print(f"Missing text keys: {len(missing_t)}")
277
- if unexpected_t:
278
- print(f"Unexpected text keys: {unexpected_t}")
279
-
280
- print("โœ“ Model weights loaded successfully")
281
- else:
282
- print(f"โš  Checkpoint not found at {checkpoint_path}, using random weights")
283
-
284
- # Setup device
285
- if device == 'auto':
286
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
287
-
288
- model = model.to(device)
289
- model.eval()
290
-
291
- print(f"โœ“ Model loaded on {device}")
292
- return model
293
-
294
-
295
- # Convenience function
296
- def load_clip(config_name: str = 'denseclip_vitb16',
297
- checkpoint_path: Optional[str] = None,
298
- device: str = 'auto') -> DenseCLIPModel:
299
- """
300
- Convenience function to load a DenseCLIP model
301
-
302
- Args:
303
- config_name: Name of config file (without .yaml extension)
304
- checkpoint_path: Optional path to checkpoint
305
- device: Device to load on
306
-
307
- Returns:
308
- Loaded DenseCLIPModel
309
- """
310
- current_dir = os.path.dirname(os.path.abspath(__file__))
311
- config_path = os.path.join(current_dir, 'configs', f'{config_name}.yaml')
312
-
313
- if not os.path.exists(config_path):
314
- raise FileNotFoundError(f"Config file not found: {config_path}")
315
-
316
- return load_denseclip_model(config_path, checkpoint_path, device)
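
For reference, a minimal sketch of how the removed loader's public entry point was used. It mirrors `example_usage.py`, deleted below; the config name and the availability of a local checkpoint are assumptions inherited from that example.

```python
# Sketch based on the removed clip_loader API; not part of the current codebase.
from denseclip_loader import load_clip

# Resolves configs/denseclip_segmentation_vitb16.yaml next to the loader module
# and loads the checkpoint path declared inside that YAML file.
model = load_clip('denseclip_segmentation_vitb16', device='auto')

text_features = model.encode_text(["a photo of a cat", "a photo of a dog"])
similarities = model.compute_similarity(text_features, text_features)
print(similarities)
```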
src/denseclip/clip_loader/example_usage.py DELETED
@@ -1,108 +0,0 @@
-#!/usr/bin/env python3
-"""
-Example usage of the DenseCLIP to CLIP loader
-"""
-
-import sys
-import os
-
-# Add the clip_loader to path
-current_dir = os.path.dirname(os.path.abspath(__file__))
-sys.path.append(current_dir)
-
-from denseclip_loader import load_clip
-import torch
-
-
-def main():
-    print("🚀 DenseCLIP to CLIP Loader Example")
-    print("=" * 50)
-
-    # Load model
-    print("Loading DenseCLIP model...")
-    try:
-        model = load_clip('denseclip_segmentation_vitb16')
-        print("✅ Model loaded successfully!")
-    except Exception as e:
-        print(f"❌ Error loading model: {e}")
-        return
-
-    print(f"Model device: {next(model.parameters()).device}")
-    print(f"Text context length: {model.context_length}")
-    print(f"Image resolution: {model.image_resolution}")
-
-    # Test text encoding
-    print("\n📝 Testing text encoding...")
-    texts = [
-        "a photo of a cat",
-        "a photo of a dog",
-        "a photo of a car",
-        "a person walking",
-        "a beautiful sunset"
-    ]
-
-    text_features = model.encode_text(texts)
-    print(f"Text features shape: {text_features.shape}")
-    print(f"Text features norm: {text_features.norm(dim=-1)}")  # Should be ~1.0 (normalized)
-
-    # Test text-text similarities
-    print("\n🔍 Text-to-text similarities:")
-    text_similarities = model.compute_similarity(text_features, text_features)
-
-    print(f"{'Text':<20} {'Self-sim':<10} {'vs Cat':<10} {'vs Dog':<10}")
-    print("-" * 50)
-    for i, text in enumerate(texts):
-        self_sim = text_similarities[i, i].item()
-        cat_sim = text_similarities[i, 0].item()
-        dog_sim = text_similarities[i, 1].item()
-        print(f"{text:<20} {self_sim:<10.3f} {cat_sim:<10.3f} {dog_sim:<10.3f}")
-
-    # Test zero-shot classification concepts
-    print("\n🎯 Zero-shot classification example:")
-    test_queries = [
-        "an animal",
-        "a vehicle",
-        "a person",
-        "nature scene"
-    ]
-
-    query_features = model.encode_text(test_queries)
-    classification_similarities = model.compute_similarity(text_features, query_features)
-
-    print(f"{'Original Text':<20} {'Animal':<8} {'Vehicle':<8} {'Person':<8} {'Nature':<8}")
-    print("-" * 60)
-    for i, text in enumerate(texts):
-        sims = classification_similarities[i]
-        print(f"{text:<20} {sims[0]:<8.3f} {sims[1]:<8.3f} {sims[2]:<8.3f} {sims[3]:<8.3f}")
-
-    # Test feature statistics
-    print("\n📊 Feature statistics:")
-    print(f"Text feature mean: {text_features.mean():.6f}")
-    print(f"Text feature std: {text_features.std():.6f}")
-    print(f"Text feature min: {text_features.min():.6f}")
-    print(f"Text feature max: {text_features.max():.6f}")
-
-    # Test model components
-    print("\n🔧 Model architecture:")
-    print(f"Vision encoder: {type(model.visual).__name__}")
-    print(f"Text encoder: {type(model.text_encoder).__name__}")
-
-    # Count parameters
-    vision_params = sum(p.numel() for p in model.visual.parameters())
-    text_params = sum(p.numel() for p in model.text_encoder.parameters())
-    total_params = vision_params + text_params
-
-    print(f"\n📈 Parameter count:")
-    print(f"Vision encoder: {vision_params:,}")
-    print(f"Text encoder: {text_params:,}")
-    print(f"Total: {total_params:,}")
-
-    print("\n✅ All tests completed successfully!")
-    print("\n💡 Usage tip:")
-    print("   from clip_loader import load_clip")
-    print("   model = load_clip('denseclip_vitb16')")
-    print("   features = model.encode_text(['your text here'])")
-
-
-if __name__ == "__main__":
-    main()
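
The example above exercises only the `load_clip` convenience wrapper. The removed `load_denseclip_model` could also be called directly with an explicit config file and a checkpoint override, as its docstring describes; a minimal sketch follows, with both paths below as placeholders.

```python
# Sketch of the removed direct-loading path; file paths are placeholders.
from denseclip_loader import load_denseclip_model

model = load_denseclip_model(
    config_path='configs/denseclip_segmentation_vitb16.yaml',  # YAML with model + checkpoint settings
    checkpoint_path='checkpoints/denseclip_vitb16.pth',        # overrides config['checkpoint']['path']
    device='cpu',
)
features = model.encode_text(["a person walking"])
```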