Spaces: Running on Zero
Commit · f281853
1 Parent(s): 86fbb78
The patchioner code is now installed via pip, and the model weights are loaded from a Hugging Face repo id.
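In practice this means the Space no longer vendors the model code: it installs the patchioner package from the Git URL added to requirements.txt and passes the Hugging Face repo URL directly to Patchioner.from_config, as the app.py diff below shows. A minimal sketch of that usage follows; the assumption that from_config resolves the Hub identifier and downloads the weights itself comes from this commit, not from separately documented API.

# pip install git+https://github.com/Ruggero1912/Patch-ioner   (the only model dependency left in requirements.txt)
import torch
from patchioner import Patchioner  # package import used by the updated app.py

# Hub identifier used as DEFAULT_MODEL_CONFIG in the updated app.py
REPO_ID = "https://huggingface.co/Ruggero1912/Patch-ioner_talk2dino_decap_COCO_Captions"

device = "cuda" if torch.cuda.is_available() else "cpu"
model = Patchioner.from_config(REPO_ID, device=device)  # resolves the repo id and fetches the weights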
This view is limited to 50 files because it contains too many changes.
- .gitignore +3 -0
- README.md +5 -1
- app.py +15 -13
- configs/mlp.k.yaml +0 -8
- configs/mlp.viecap.k.yaml +0 -31
- requirements.txt +3 -36
- src/INViTE/clipfolder/__init__.py +0 -1
- src/INViTE/clipfolder/bpe_simple_vocab_16e6.txt.gz +0 -3
- src/INViTE/clipfolder/clip.py +0 -238
- src/INViTE/clipfolder/model.py +0 -515
- src/INViTE/clipfolder/simple_tokenizer.py +0 -132
- src/INViTE/loader.py +0 -72
- src/alphaclip/INSTALL.md +0 -113
- src/alphaclip/LICENSE +0 -201
- src/alphaclip/MANIFEST.in +0 -7
- src/alphaclip/README.md +0 -266
- src/alphaclip/__init__.py +0 -14
- src/alphaclip/alpha_clip/__init__.py +0 -1
- src/alphaclip/alpha_clip/alpha_clip.py +0 -254
- src/alphaclip/alpha_clip/bpe_simple_vocab_16e6.txt.gz +0 -3
- src/alphaclip/alpha_clip/model.py +0 -609
- src/alphaclip/alpha_clip/simple_tokenizer.py +0 -132
- src/alphaclip/alpha_mask_utils.py +0 -111
- src/alphaclip/alphaclip_loader.py +0 -233
- src/alphaclip/example.py +0 -76
- src/alphaclip/requirements.txt +0 -10
- src/alphaclip/setup.py +0 -47
- src/alphaclip/test_installation.py +0 -149
- src/bbox_utils.py +0 -421
- src/clipcap/CLIPCAP_INTEGRATION.md +0 -206
- src/clipcap/clipcapTrainREADME.md +0 -301
- src/clipcap/clipcapTraining.py +0 -405
- src/clipcap/clipcap_dino_parse_coco.py +0 -613
- src/clipcap/clipcap_parse_coco.py +0 -51
- src/clipcap/entrypoint.py +0 -564
- src/clipcap/predict.py +0 -302
- src/dataset.py +0 -94
- src/datasetMix.py +0 -153
- src/decap/decap.py +0 -193
- src/decap/decoderTraining.py +0 -464
- src/decap/decoder_config.pkl +0 -3
- src/decap/im2txtprojection/im2txtprojection.py +0 -500
- src/denseclip/clip_loader/README.md +0 -233
- src/denseclip/clip_loader/SUMMARY.md +0 -78
- src/denseclip/clip_loader/__init__.py +0 -21
- src/denseclip/clip_loader/bpe_simple_vocab_16e6.txt.gz +0 -3
- src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16.yaml +0 -41
- src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16_long_ctx.yaml +0 -41
- src/denseclip/clip_loader/denseclip_loader.py +0 -316
- src/denseclip/clip_loader/example_usage.py +0 -108
.gitignore CHANGED
@@ -1,3 +1,6 @@
 *.pyc
 *.pyo
 *.pyd
+venv/
+.gradio/
+.venv/
README.md CHANGED
@@ -13,4 +13,8 @@ short_description: 'Repo for the Paper "One Patch to Caption Them All: ...'
 
 ArXiv: arxiv.org/abs/2510.02898
 
-
+Demo of the Patch-ioner framework, from the paper "One Patch to Caption Them All: A Unified Zero-shot Captioning Framework".
+
+The project page is at [paciosoft.com/Patch-ioner](https://paciosoft.com/Patch-ioner).
+
+
app.py CHANGED
@@ -24,8 +24,7 @@ from PIL import Image
 import numpy as np
 from typing import List, Dict
 
-
-from src.model import Patchioner
+from patchioner import Patchioner
 
 # Global variable to store the loaded model
 loaded_model = None
@@ -33,7 +32,7 @@ model_config_path = None
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Default model configuration
-DEFAULT_MODEL_CONFIG = "
+DEFAULT_MODEL_CONFIG = "https://huggingface.co/Ruggero1912/Patch-ioner_talk2dino_decap_COCO_Captions"
 
 # Example images directory
 current_dir = os.path.dirname(__file__)
@@ -50,13 +49,16 @@ def initialize_default_model() -> str:
 default_config_path = CONFIGS_DIR / DEFAULT_MODEL_CONFIG
 
 if not default_config_path.exists():
-
+print( f"โ Default config file not found: {default_config_path}" )
+config = DEFAULT_MODEL_CONFIG # Assume it's a URL or model identifier
+print( f"Attempting to load model from identifier: {config}" )
+
+else:
+config = default_config_path
 
 print(f"Loading default model: {DEFAULT_MODEL_CONFIG}")
 
-
-with open(default_config_path, 'r') as f:
-config = yaml.safe_load(f)
+
 
 # Load the model using the from_config class method
 model = Patchioner.from_config(config, device=device)
@@ -553,7 +555,7 @@ def generate_bbox_caption(image_data, image) -> str:
 return error_msg
 
 
-def create_gradio_interface():
+def create_gradio_interface(model_config_name : str):
 """Create and configure the Gradio interface."""
 
 # Get example files
@@ -593,7 +595,7 @@ def create_gradio_interface():
 ) as demo:
 #gr.HTML(custom_js) # inject custom JS
 
-gr.Markdown("""
+gr.Markdown(f"""
 # ๐ฏ Patchioner Trace Captioning Demo
 
 This demo allows you to:
@@ -608,7 +610,7 @@ def create_gradio_interface():
 3. Use the appropriate tool to mark areas of interest in the image
 4. Click "Generate Caption" to get AI-generated descriptions
 
-**Model:** Using `
+**Model:** Using `{model_config_name}` configuration (automatically loaded)
 """)
 
 # Initialize model status
@@ -730,7 +732,7 @@ def create_gradio_interface():
 outputs=[image_editor, image_annotator]
 )
 
-gr.Markdown("""
+gr.Markdown(f"""
 ### ๐ก Tips:
 - **Mode Selection**: Switch between trace and bounding box modes based on your needs
 - **Trace Mode**: Draw continuous lines over areas you want to describe
@@ -741,7 +743,7 @@ def create_gradio_interface():
 ### ๐ง Technical Details:
 - **Trace Mode**: Converts drawings to normalized (x, y) coordinates with timestamps
 - **BBox Mode**: Uses bounding box coordinates for region-specific captioning
-- **Model Architecture**: Uses `
+- **Model Architecture**: Uses `{model_config_name}` configuration with CLIP and ViT components
 - **Processing**: Each trace/bbox is processed separately to generate corresponding captions
 """)
 
@@ -762,7 +764,7 @@ if __name__ == "__main__":
 print(f"Example images directory: {EXAMPLE_IMAGES_DIR}")
 print(f"Configs directory: {CONFIGS_DIR}")
 
-demo = create_gradio_interface()
+demo = create_gradio_interface(DEFAULT_MODEL_CONFIG)
 if not args.local:
 demo.launch()
 else:
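The new loading path in initialize_default_model falls back from a local YAML config to the Hub identifier when the file is missing. The sketch below replays that logic outside the app; the resolve_config helper name and the local "configs" directory value are illustrative, and the behaviour of Patchioner.from_config with a Hub URL is taken from the diff above rather than from package documentation.

from pathlib import Path
from patchioner import Patchioner  # pip-installed package, as imported in the updated app.py

CONFIGS_DIR = Path("configs")  # mirrors the app's configs directory (illustrative)
DEFAULT_MODEL_CONFIG = "https://huggingface.co/Ruggero1912/Patch-ioner_talk2dino_decap_COCO_Captions"

def resolve_config(name: str):
    """Return the local config path when it exists, otherwise the identifier itself."""
    local = CONFIGS_DIR / name
    return local if local.exists() else name  # a Hub URL / repo id is passed through unchanged

# Either a Path to a YAML file or the Hub identifier string is handed to from_config
model = Patchioner.from_config(resolve_config(DEFAULT_MODEL_CONFIG), device="cpu")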
configs/mlp.k.yaml DELETED
@@ -1,8 +0,0 @@
-decap_weights: 'weights/decap-talk2dino-coco_karpathy-009.pt'
-prefix_size: 768
-linear_talk2dino: False
-support_memory_size: 591753
-dino_model: 'dinov2_vitb14_reg'
-normalize: True
-kkv_attention: False
-projection_type: '/raid/datasets/im2txtmemories/coco_train_karpathy.json'
configs/mlp.viecap.k.yaml DELETED
@@ -1,31 +0,0 @@
-decap_weights: null
-prefix_size: 768
-linear_talk2dino: False
-support_memory_size: 0
-dino_model: 'dinov2_vitb14_reg'
-normalize: False
-kkv_attention: False
-use_talk2dino_project: False
-clip_model_name: "ViT-B/16"
-
-
-# nested config
-viecap:
-  clip_hidden_size: 768
-  suffix: ViT-B16_t2d_
-  project_length: 10
-  temperature: 0.01
-  top_k: 3
-  threshold: 0.4
-  language_model: 'gpt2'
-  name_of_entities_text: coco_entities #vinvl_vgoi_entities
-  files_path: 'weights/viecap_files/'
-  prompt_ensemble: True
-  weight_path: 'weights/viecap-talk2dino-coco_karpathy-0014.pt'
-  using_hard_prompt: True
-  soft_prompt_first: True
-  only_hard_prompt: False
-  using_greedy_search: True #if false, use beam search
-  beam_width: 5
-  text_prompt: None
-
requirements.txt CHANGED
@@ -1,36 +1,3 @@
-
-
-
-gradio>=4.0.0
-gradio_image_annotation
-
-# Image processing - required for the demo
-pillow
-torchvision
-
-# Data handling - required
-numpy
-tqdm
-
-# CLIP - essential for the model
-git+https://github.com/openai/CLIP.git
-
-# Model dependencies - needed for core functionality
-timm
-
-# Hugging Face model hosting
-huggingface_hub
-
-h5py
-
-# Optional: Only include if specifically needed
-# h5py # Only needed for some data formats - can be installed conditionally
-# scikit-learn # Only for evaluation - not needed for inference
-# plotly # Only for plotting - not needed for basic demo
-# pandas # Only for data analysis - not needed for basic demo
-# matplotlib # Only for plotting - not needed for basic demo
-# pycocotools # Only for COCO evaluation - not needed for basic demo
-# nbformat # Only for notebooks - not needed for basic demo
-# speaksee # Only for evaluation - not needed for basic demo
-# munkres # Only for specific evaluation metrics - not needed for basic demo
-# open_clip_torch # Only if using open_clip models - not needed for basic demo
+git+https://github.com/Ruggero1912/Patch-ioner
+gradio==5.48.0
+gradio_image_annotation
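Since the requirements now boil down to the Patch-ioner package itself plus a pinned Gradio, a quick import check is enough to confirm the Space's runtime environment. A minimal sanity-check sketch, assuming the importable module names match the pip package names (patchioner, gradio_image_annotation):

import gradio
import gradio_image_annotation  # noqa: F401 -- only checking that the pinned dependency imports
from patchioner import Patchioner  # noqa: F401 -- import name used by the updated app.py

# requirements.txt pins gradio==5.48.0; fail loudly if the resolved version drifted
assert gradio.__version__.startswith("5.48"), f"unexpected Gradio version: {gradio.__version__}"
print("environment OK")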
src/INViTE/clipfolder/__init__.py DELETED
@@ -1 +0,0 @@
-from .clip import *
src/INViTE/clipfolder/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
-size 1356917
src/INViTE/clipfolder/clip.py DELETED
@@ -1,238 +0,0 @@
[Entire 238-line file removed: the vendored copy of the CLIP loader, including the _MODELS download URLs, _download with SHA256 verification, the _transform preprocessing pipeline, available_models(), load() extended with the INViTE-specific extract_last_k_th_token / viz / image_resolution options, and tokenize().]
src/INViTE/clipfolder/model.py DELETED
@@ -1,515 +0,0 @@
[Entire 515-line file removed: the vendored CLIP model definition (Bottleneck, AttentionPool2d, ModifiedResNet, LayerNorm, QuickGELU, ResidualAttentionBlock, Transformer, VisionTransformer, CLIP), plus convert_weights for fp16 casting, resize_pos_embed for positional-embedding interpolation, and build_model for constructing the model from a state dict.]
src/INViTE/clipfolder/simple_tokenizer.py
DELETED
|
@@ -1,132 +0,0 @@
|
|
| 1 |
-
import gzip
|
| 2 |
-
import html
|
| 3 |
-
import os
|
| 4 |
-
from functools import lru_cache
|
| 5 |
-
|
| 6 |
-
import ftfy
|
| 7 |
-
import regex as re
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
@lru_cache()
|
| 11 |
-
def default_bpe():
|
| 12 |
-
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
@lru_cache()
|
| 16 |
-
def bytes_to_unicode():
|
| 17 |
-
"""
|
| 18 |
-
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
| 19 |
-
The reversible bpe codes work on unicode strings.
|
| 20 |
-
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a signficant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8+n)
-            n += 1
-    cs = [chr(n) for n in cs]
-    return dict(zip(bs, cs))
-
-
-def get_pairs(word):
-    """Return set of symbol pairs in a word.
-    Word is represented as tuple of symbols (symbols being variable-length strings).
-    """
-    pairs = set()
-    prev_char = word[0]
-    for char in word[1:]:
-        pairs.add((prev_char, char))
-        prev_char = char
-    return pairs
-
-
-def basic_clean(text):
-    text = ftfy.fix_text(text)
-    text = html.unescape(html.unescape(text))
-    return text.strip()
-
-
-def whitespace_clean(text):
-    text = re.sub(r'\s+', ' ', text)
-    text = text.strip()
-    return text
-
-
-class SimpleTokenizer(object):
-    def __init__(self, bpe_path: str = default_bpe()):
-        self.byte_encoder = bytes_to_unicode()
-        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
-        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
-        merges = merges[1:49152-256-2+1]
-        merges = [tuple(merge.split()) for merge in merges]
-        vocab = list(bytes_to_unicode().values())
-        vocab = vocab + [v+'</w>' for v in vocab]
-        for merge in merges:
-            vocab.append(''.join(merge))
-        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
-        self.encoder = dict(zip(vocab, range(len(vocab))))
-        self.decoder = {v: k for k, v in self.encoder.items()}
-        self.bpe_ranks = dict(zip(merges, range(len(merges))))
-        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
-        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
-
-    def bpe(self, token):
-        if token in self.cache:
-            return self.cache[token]
-        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
-        pairs = get_pairs(word)
-
-        if not pairs:
-            return token+'</w>'
-
-        while True:
-            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
-            if bigram not in self.bpe_ranks:
-                break
-            first, second = bigram
-            new_word = []
-            i = 0
-            while i < len(word):
-                try:
-                    j = word.index(first, i)
-                    new_word.extend(word[i:j])
-                    i = j
-                except:
-                    new_word.extend(word[i:])
-                    break
-
-                if word[i] == first and i < len(word)-1 and word[i+1] == second:
-                    new_word.append(first+second)
-                    i += 2
-                else:
-                    new_word.append(word[i])
-                    i += 1
-            new_word = tuple(new_word)
-            word = new_word
-            if len(word) == 1:
-                break
-            else:
-                pairs = get_pairs(word)
-        word = ' '.join(word)
-        self.cache[token] = word
-        return word
-
-    def encode(self, text):
-        bpe_tokens = []
-        text = whitespace_clean(basic_clean(text)).lower()
-        for token in re.findall(self.pat, text):
-            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
-            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
-        return bpe_tokens
-
-    def decode(self, tokens):
-        text = ''.join([self.decoder[token] for token in tokens])
-        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
-        return text
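For context, a minimal sketch of how this BPE tokenizer is typically used (illustration only; it assumes the equivalent `SimpleTokenizer` from the upstream `clip` package, since this vendored copy is being removed):

```python
# Illustrative sketch; assumes the upstream `clip` package provides the same
# SimpleTokenizer API as the vendored copy deleted above.
from clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("a photo of a cat")   # BPE token ids (no start/end-of-text markers)
print(ids)
print(tok.decode(ids))                 # round-trips back to the input text
```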
src/INViTE/loader.py
DELETED
@@ -1,72 +0,0 @@
-import torch
-from typing import Union
-from .clipfolder.clip import load as invite_clip_load, tokenize as invite_clip_tokenize
-
-
-def load_invite_clip(config: dict, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu"):
-    """
-    Load an INViTE CLIP model based on the provided configuration.
-
-    This method loads an INViTE CLIP model similar to how RegionCLIP is loaded in the Patchioner class.
-
-    Args:
-        config (dict): Configuration dictionary containing the following keys:
-            - name (str): Model name listed by `clip.available_models()`, or path to a model checkpoint
-            - jit (bool, optional): Whether to load the optimized JIT model. Defaults to False
-            - download_root (str, optional): Path to download model files. Defaults to '/raid/datasets/models_weights/INViTE'
-            - extract_last_k_th_token (int, optional): Extract last k-th token. Defaults to -1
-            - viz (bool, optional): Visualization flag. Defaults to False
-        device (Union[str, torch.device], optional): Device to load the model on.
-            Defaults to "cuda" if available, else "cpu"
-
-    Returns:
-        tuple: (model, preprocess_transform, tokenize_fn)
-            - model: The loaded INViTE CLIP model
-            - preprocess_transform: Torchvision transform for preprocessing images
-            - tokenize_fn: Tokenization function for text processing
-
-    Raises:
-        KeyError: If required 'name' key is missing from config
-        RuntimeError: If model loading fails
-
-    Example:
-        config = {
-            'name': 'ViT-B/32',
-            'jit': False,
-            'download_root': '/raid/datasets/models_weights/INViTE',  # optional, this is the default
-            'extract_last_k_th_token': -1,
-            'viz': False
-        }
-        model, preprocess, tokenize = load_invite_clip(config, device='cuda')
-    """
-
-    # Validate required parameters
-    if 'name' not in config:
-        raise KeyError("'name' key is required in config dictionary")
-
-    # Extract parameters with defaults
-    name = config['name']
-    jit = config.get('jit', False)
-    download_root = config.get('download_root', '/raid/datasets/models_weights/INViTE')
-    extract_last_k_th_token = config.get('extract_last_k_th_token', -1)
-    viz = config.get('viz', False)
-
-    image_resolution = config.get('resolution', None)  # Default resolution if not specified
-
-    # Load the INViTE CLIP model using the clip.load function
-    try:
-        model, preprocess_transform = invite_clip_load(
-            name=name,
-            device=device,
-            jit=jit,
-            download_root=download_root,
-            extract_last_k_th_token=extract_last_k_th_token,
-            viz=viz,
-            image_resolution=image_resolution
-        )
-
-        # Return model, preprocess transform, and tokenize function
-        return model, preprocess_transform, invite_clip_tokenize
-
-    except Exception as e:
-        raise RuntimeError(f"Failed to load INViTE CLIP model '{name}': {str(e)}")
src/alphaclip/INSTALL.md
DELETED
@@ -1,113 +0,0 @@
-# AlphaCLIP Standalone - Installation Guide
-
-## Quick Installation
-
-### Prerequisites
-- Python 3.7 or higher
-- pip package manager
-
-### Step 1: Install Dependencies
-
-```bash
-cd alphaclip-standalone
-pip install -r requirements.txt
-```
-
-### Step 2: Install the Package
-
-```bash
-# Install in development mode (recommended for testing)
-pip install -e .
-
-# OR install normally
-pip install .
-```
-
-### Step 3: Test Installation
-
-```bash
-python test_installation.py
-```
-
-### Step 4: Run Example
-
-```bash
-python example.py
-```
-
-## Manual Dependency Installation
-
-If you encounter issues with the requirements.txt, install dependencies manually:
-
-```bash
-# Core PyTorch (choose appropriate version for your system)
-pip install torch torchvision torchaudio
-
-# Text processing
-pip install ftfy regex tqdm
-
-# LoRA support
-pip install loralib
-
-# Image processing
-pip install Pillow
-
-# Utilities
-pip install numpy packaging
-```
-
-## GPU Support
-
-For CUDA support, make sure you install PyTorch with CUDA:
-
-```bash
-# For CUDA 11.8
-pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
-
-# For CUDA 12.1
-pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
-
-# Check your CUDA version with: nvidia-smi
-```
-
-## Verification
-
-After installation, verify everything works:
-
-```python
-from alphaclip_loader import AlphaCLIPLoader
-
-# This should work without errors
-loader = AlphaCLIPLoader()
-models = loader.available_models()
-print("Available models:", models)
-```
-
-## Troubleshooting
-
-### Common Issues
-
-1. **ImportError: No module named 'loralib'**
-   ```bash
-   pip install loralib
-   ```
-
-2. **CUDA out of memory**
-   - Use CPU: `AlphaCLIPLoader(default_device="cpu")`
-   - Or use a smaller model like "ViT-B/32"
-
-3. **Model download fails**
-   - Check internet connection
-   - Ensure you have enough disk space (~1GB per model)
-   - Models are cached in `~/.cache/clip/`
-
-4. **Permission errors**
-   - Use `--user` flag: `pip install --user -e .`
-
-### Getting Help
-
-If you encounter issues:
-1. Check that all dependencies are properly installed
-2. Run the test script: `python test_installation.py`
-3. Check CUDA compatibility if using GPU
-4. Ensure Python version is 3.7+
src/alphaclip/LICENSE
DELETED
@@ -1,201 +0,0 @@
-                              Apache License
-                        Version 2.0, January 2004
-                     http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   1. Definitions.
-
-      "License" shall mean the terms and conditions for use, reproduction,
-      and distribution as defined by Sections 1 through 9 of this document.
-
-      "Licensor" shall mean the copyright owner or entity authorized by
-      the copyright owner that is granting the License.
-
-      "Legal Entity" shall mean the union of the acting entity and all
-      other entities that control, are controlled by, or are under common
-      control with that entity. For the purposes of this definition,
-      "control" means (i) the power, direct or indirect, to cause the
-      direction or management of such entity, whether by contract or
-      otherwise, or (ii) ownership of fifty percent (50%) or more of the
-      outstanding shares, or (iii) beneficial ownership of such entity.
-
-      "You" (or "Your") shall mean an individual or Legal Entity
-      exercising permissions granted by this License.
-
-      "Source" form shall mean the preferred form for making modifications,
-      including but not limited to software source code, documentation
-      source, and configuration files.
-
-      "Object" form shall mean any form resulting from mechanical
-      transformation or translation of a Source form, including but
-      not limited to compiled object code, generated documentation,
-      and conversions to other media types.
-
-      "Work" shall mean the work of authorship, whether in Source or
-      Object form, made available under the License, as indicated by a
-      copyright notice that is included in or attached to the work
-      (an example is provided in the Appendix below).
-
-      "Derivative Works" shall mean any work, whether in Source or Object
-      form, that is based on (or derived from) the Work and for which the
-      editorial revisions, annotations, elaborations, or other modifications
-      represent, as a whole, an original work of authorship. For the purposes
-      of this License, Derivative Works shall not include works that remain
-      separable from, or merely link (or bind by name) to the interfaces of,
-      the Work and Derivative Works thereof.
-
-      "Contribution" shall mean any work of authorship, including
-      the original version of the Work and any modifications or additions
-      to that Work or Derivative Works thereof, that is intentionally
-      submitted to Licensor for inclusion in the Work by the copyright owner
-      or by an individual or Legal Entity authorized to submit on behalf of
-      the copyright owner. For the purposes of this definition, "submitted"
-      means any form of electronic, verbal, or written communication sent
-      to the Licensor or its representatives, including but not limited to
-      communication on electronic mailing lists, source code control systems,
-      and issue tracking systems that are managed by, or on behalf of, the
-      Licensor for the purpose of discussing and improving the Work, but
-      excluding communication that is conspicuously marked or otherwise
-      designated in writing by the copyright owner as "Not a Contribution."
-
-      "Contributor" shall mean Licensor and any individual or Legal Entity
-      on behalf of whom a Contribution has been received by Licensor and
-      subsequently incorporated within the Work.
-
-   2. Grant of Copyright License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      copyright license to reproduce, prepare Derivative Works of,
-      publicly display, publicly perform, sublicense, and distribute the
-      Work and such Derivative Works in Source or Object form.
-
-   3. Grant of Patent License. Subject to the terms and conditions of
-      this License, each Contributor hereby grants to You a perpetual,
-      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
-      (except as stated in this section) patent license to make, have made,
-      use, offer to sell, sell, import, and otherwise transfer the Work,
-      where such license applies only to those patent claims licensable
-      by such Contributor that are necessarily infringed by their
-      Contribution(s) alone or by combination of their Contribution(s)
-      with the Work to which such Contribution(s) was submitted. If You
-      institute patent litigation against any entity (including a
-      cross-claim or counterclaim in a lawsuit) alleging that the Work
-      or a Contribution incorporated within the Work constitutes direct
-      or contributory patent infringement, then any patent licenses
-      granted to You under this License for that Work shall terminate
-      as of the date such litigation is filed.
-
-   4. Redistribution. You may reproduce and distribute copies of the
-      Work or Derivative Works thereof in any medium, with or without
-      modifications, and in Source or Object form, provided that You
-      meet the following conditions:
-
-      (a) You must give any other recipients of the Work or
-          Derivative Works a copy of this License; and
-
-      (b) You must cause any modified files to carry prominent notices
-          stating that You changed the files; and
-
-      (c) You must retain, in the Source form of any Derivative Works
-          that You distribute, all copyright, patent, trademark, and
-          attribution notices from the Source form of the Work,
-          excluding those notices that do not pertain to any part of
-          the Derivative Works; and
-
-      (d) If the Work includes a "NOTICE" text file as part of its
-          distribution, then any Derivative Works that You distribute must
-          include a readable copy of the attribution notices contained
-          within such NOTICE file, excluding those notices that do not
-          pertain to any part of the Derivative Works, in at least one
-          of the following places: within a NOTICE text file distributed
-          as part of the Derivative Works; within the Source form or
-          documentation, if provided along with the Derivative Works; or,
-          within a display generated by the Derivative Works, if and
-          wherever such third-party notices normally appear. The contents
-          of the NOTICE file are for informational purposes only and
-          do not modify the License. You may add Your own attribution
-          notices within Derivative Works that You distribute, alongside
-          or as an addendum to the NOTICE text from the Work, provided
-          that such additional attribution notices cannot be construed
-          as modifying the License.
-
-      You may add Your own copyright statement to Your modifications and
-      may provide additional or different license terms and conditions
-      for use, reproduction, or distribution of Your modifications, or
-      for any such Derivative Works as a whole, provided Your use,
-      reproduction, and distribution of the Work otherwise complies with
-      the conditions stated in this License.
-
-   5. Submission of Contributions. Unless You explicitly state otherwise,
-      any Contribution intentionally submitted for inclusion in the Work
-      by You to the Licensor shall be under the terms and conditions of
-      this License, without any additional terms or conditions.
-      Notwithstanding the above, nothing herein shall supersede or modify
-      the terms of any separate license agreement you may have executed
-      with Licensor regarding such Contributions.
-
-   6. Trademarks. This License does not grant permission to use the trade
-      names, trademarks, service marks, or product names of the Licensor,
-      except as required for reasonable and customary use in describing the
-      origin of the Work and reproducing the content of the NOTICE file.
-
-   7. Disclaimer of Warranty. Unless required by applicable law or
-      agreed to in writing, Licensor provides the Work (and each
-      Contributor provides its Contributions) on an "AS IS" BASIS,
-      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
-      implied, including, without limitation, any warranties or conditions
-      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
-      PARTICULAR PURPOSE. You are solely responsible for determining the
-      appropriateness of using or redistributing the Work and assume any
-      risks associated with Your exercise of permissions under this License.
-
-   8. Limitation of Liability. In no event and under no legal theory,
-      whether in tort (including negligence), contract, or otherwise,
-      unless required by applicable law (such as deliberate and grossly
-      negligent acts) or agreed to in writing, shall any Contributor be
-      liable to You for damages, including any direct, indirect, special,
-      incidental, or consequential damages of any character arising as a
-      result of this License or out of the use or inability to use the
-      Work (including but not limited to damages for loss of goodwill,
-      work stoppage, computer failure or malfunction, or any and all
-      other commercial damages or losses), even if such Contributor
-      has been advised of the possibility of such damages.
-
-   9. Accepting Warranty or Additional Liability. While redistributing
-      the Work or Derivative Works thereof, You may choose to offer,
-      and charge a fee for, acceptance of support, warranty, indemnity,
-      or other liability obligations and/or rights consistent with this
-      License. However, in accepting such obligations, You may act only
-      on Your own behalf and on Your sole responsibility, not on behalf
-      of any other Contributor, and only if You agree to indemnify,
-      defend, and hold each Contributor harmless for any liability
-      incurred by, or claims asserted against, such Contributor by reason
-      of your accepting any such warranty or additional liability.
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-      To apply the Apache License to your work, attach the following
-      boilerplate notice, with the fields enclosed by brackets "[]"
-      replaced with your own identifying information. (Don't include
-      the brackets!) The text should be enclosed in the appropriate
-      comment syntax for the file format. We also recommend that a
-      file or class name and description of purpose be included on the
-      same "printed page" as the copyright notice for easier
-      identification within third-party archives.
-
-   Copyright [Zeyi Sun] [name of copyright owner]
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
src/alphaclip/MANIFEST.in
DELETED
@@ -1,7 +0,0 @@
-include README.md
-include requirements.txt
-include LICENSE
-recursive-include alpha_clip *.py
-recursive-include alpha_clip *.gz
-include example.py
-include test_installation.py
src/alphaclip/README.md
DELETED
@@ -1,266 +0,0 @@
-# AlphaCLIP Standalone
-
-A standalone, easy-to-use version of AlphaCLIP that can be integrated into any project without complex dependencies or setup.
-
-## Overview
-
-AlphaCLIP is an enhanced version of OpenAI's CLIP model that provides improved vision-language understanding capabilities. This standalone package makes it easy to use AlphaCLIP in your projects with minimal setup.
-
-## Features
-
-- **Easy Installation**: Simple pip install with minimal dependencies
-- **Clean API**: Intuitive interface for loading models and processing data
-- **Device Flexibility**: Automatic CUDA/CPU detection with manual override options
-- **Model Variety**: Support for multiple AlphaCLIP model variants
-- **Preprocessing Included**: Built-in image preprocessing and text tokenization
-
-## Installation
-
-### Requirements
-
-- Python 3.7 or higher
-- PyTorch 1.7.1 or higher
-- CUDA (optional, for GPU acceleration)
-
-### Install from source
-
-```bash
-# Clone or download this standalone package
-cd alphaclip-standalone
-
-# Install dependencies
-pip install -r requirements.txt
-
-# Install the package
-pip install -e .
-```
-
-### Core Dependencies
-
-The package requires the following core dependencies:
-
-```
-torch>=1.7.1
-torchvision
-ftfy
-regex
-tqdm
-loralib
-Pillow
-numpy
-packaging
-```
-
-## Quick Start
-
-### Basic Usage
-
-```python
-from alphaclip_loader import AlphaCLIPLoader
-
-# Initialize the loader
-loader = AlphaCLIPLoader()
-
-# Load a model (this will download the model if not cached)
-model, preprocess = loader.load_model("ViT-B/16")
-
-# Tokenize text
-text_tokens = loader.tokenize("A photo of a cat")
-
-# Get text embeddings
-text_features = loader.encode_text(model, "A photo of a cat")
-
-print(f"Text features shape: {text_features.shape}")
-```
-
-### Advanced Usage
-
-```python
-import torch
-from PIL import Image
-from alphaclip_loader import AlphaCLIPLoader
-
-# Initialize with specific device
-loader = AlphaCLIPLoader(default_device="cuda")
-
-# Load model with custom options
-model, preprocess = loader.load_model(
-    "ViT-B/16",
-    device="cuda",
-    lora_adapt=False,
-    rank=16
-)
-
-# Process an image
-image = Image.open("your_image.jpg")
-image_tensor = preprocess(image).unsqueeze(0)  # Add batch dimension
-
-# Get embeddings
-with torch.no_grad():
-    image_features = loader.encode_image(model, image_tensor)
-    text_features = loader.encode_text(model, ["A photo of a cat", "A dog playing"])
-
-# Compute similarities
-similarities = loader.get_similarity(text_features, image_features)
-print(f"Similarities: {similarities}")
-```
-
-### One-line Model Loading
-
-```python
-from alphaclip_loader import load_alphaclip
-
-# Quick loading function
-loader, model, preprocess = load_alphaclip("ViT-B/16", device="cuda")
-```
-
-## Available Models
-
-You can check available models using:
-
-```python
-from alphaclip_loader import AlphaCLIPLoader
-
-loader = AlphaCLIPLoader()
-models = loader.available_models()
-print("Available models:", models)
-```
-
-Typically includes:
-- `ViT-B/32`
-- `ViT-B/16`
-- `ViT-L/14`
-- `ViT-L/14@336px`
-- `RN50`, `RN101`, `RN50x4`, `RN50x16`, `RN50x64`
-
-## API Reference
-
-### AlphaCLIPLoader Class
-
-#### Methods
-
-- **`__init__(default_device=None)`**: Initialize loader with optional default device
-- **`available_models()`**: Get list of available model names
-- **`load_model(name, **kwargs)`**: Load a model with preprocessing function
-- **`tokenize(texts, context_length=77, truncate=True)`**: Tokenize text input
-- **`encode_text(model, texts)`**: Encode text to embeddings
-- **`encode_image(model, images)`**: Encode images to embeddings
-- **`get_similarity(text_features, image_features)`**: Compute cosine similarity
-
-#### load_model Parameters
-
-- `name`: Model name or checkpoint path
-- `alpha_vision_ckpt_pth`: Additional vision checkpoint path (default: "None")
-- `device`: Device to load on (default: auto-detect)
-- `jit`: Use JIT compilation (default: False)
-- `download_root`: Model download directory (default: ~/.cache/clip)
-- `lora_adapt`: Use LoRA adaptation (default: False)
-- `rank`: LoRA rank if enabled (default: 16)
-
-## Example Use Cases
-
-### Image-Text Similarity
-
-```python
-from alphaclip_loader import load_alphaclip
-from PIL import Image
-import torch
-
-loader, model, preprocess = load_alphaclip()
-
-# Load and preprocess image
-image = Image.open("cat.jpg")
-image_input = preprocess(image).unsqueeze(0)
-
-# Define candidate texts
-texts = ["a cat", "a dog", "a bird", "a car"]
-
-# Get features
-image_features = loader.encode_image(model, image_input)
-text_features = loader.encode_text(model, texts)
-
-# Calculate similarities
-similarities = loader.get_similarity(text_features, image_features)
-
-# Find best match
-best_match_idx = similarities.argmax()
-print(f"Best match: {texts[best_match_idx]} (score: {similarities[best_match_idx]:.3f})")
-```
-
-### Batch Processing
-
-```python
-from alphaclip_loader import AlphaCLIPLoader
-import torch
-
-loader = AlphaCLIPLoader()
-model, preprocess = loader.load_model("ViT-B/16")
-
-# Process multiple texts at once
-texts = [
-    "A red apple on a table",
-    "A dog running in the park",
-    "A beautiful sunset"
-]
-
-# Batch tokenization and encoding
-text_features = loader.encode_text(model, texts)
-print(f"Batch text features shape: {text_features.shape}")  # [3, 512]
-```
-
-## Performance Tips
-
-1. **GPU Usage**: Use CUDA for better performance with larger models
-2. **Batch Processing**: Process multiple texts/images together when possible
-3. **Model Caching**: Models are automatically cached after first download
-4. **Memory Management**: Use `torch.no_grad()` during inference to save memory
-
-## Troubleshooting
-
-### Common Issues
-
-1. **CUDA Out of Memory**: Reduce batch size or use CPU
-2. **Model Download Fails**: Check internet connection and disk space
-3. **Import Errors**: Ensure all dependencies are installed
-
-### Dependencies Issues
-
-If you encounter import errors, try:
-
-```bash
-pip install --upgrade torch torchvision
-pip install ftfy regex tqdm loralib
-```
-
-## File Structure
-
-```
-alphaclip-standalone/
-├── __init__.py              # Package initialization
-├── alphaclip_loader.py      # Main loader class
-├── requirements.txt         # Dependencies
-├── setup.py                 # Package setup
-├── README.md                # This file
-└── alpha_clip/              # Core AlphaCLIP modules
-    ├── __init__.py
-    ├── alpha_clip.py        # Main AlphaCLIP functions
-    ├── model.py             # Model architectures
-    ├── simple_tokenizer.py  # Text tokenization
-    └── bpe_simple_vocab_16e6.txt.gz  # Tokenizer vocabulary
-```
-
-## License
-
-This standalone package maintains the same license as the original AlphaCLIP project.
-
-## Contributing
-
-This is a standalone distribution. For contributions to the core AlphaCLIP model, please refer to the main AlphaCLIP repository.
-
-## Changelog
-
-### Version 1.0.0
-- Initial standalone release
-- Clean API with AlphaCLIPLoader class
-- Comprehensive documentation and examples
-- Easy installation and setup
src/alphaclip/__init__.py
DELETED
@@ -1,14 +0,0 @@
-"""
-AlphaCLIP Standalone Package
-
-A standalone version of AlphaCLIP that can be used independently.
-"""
-
-from .alphaclip_loader import AlphaCLIPLoader, load_alphaclip
-
-# Version info
-__version__ = "1.0.0"
-__author__ = "AlphaCLIP Team"
-
-# Make main classes available at package level
-__all__ = ['AlphaCLIPLoader', 'load_alphaclip']
src/alphaclip/alpha_clip/__init__.py
DELETED
@@ -1 +0,0 @@
-from .alpha_clip import *
src/alphaclip/alpha_clip/alpha_clip.py
DELETED
@@ -1,254 +0,0 @@
-import hashlib
-import os
-import urllib
-import warnings
-from typing import Any, Union, List
-from pkg_resources import packaging
-
-import torch
-from PIL import Image
-from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
-from tqdm import tqdm
-
-from .model import build_model
-from .simple_tokenizer import SimpleTokenizer as _Tokenizer
-
-try:
-    from torchvision.transforms import InterpolationMode
-    BICUBIC = InterpolationMode.BICUBIC
-except ImportError:
-    BICUBIC = Image.BICUBIC
-
-
-if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
-    warnings.warn("PyTorch version 1.7.1 or higher is recommended")
-
-
-__all__ = ["available_models", "load", "tokenize"]
-_tokenizer = _Tokenizer()
-
-_MODELS = {
-    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
-    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
-    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
-    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
-    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
-    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
-    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
-    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
-    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
-}
-
-
-def _download(url: str, root: str):
-    os.makedirs(root, exist_ok=True)
-    filename = os.path.basename(url)
-
-    expected_sha256 = url.split("/")[-2]
-    download_target = os.path.join(root, filename)
-
-    if os.path.exists(download_target) and not os.path.isfile(download_target):
-        raise RuntimeError(f"{download_target} exists and is not a regular file")
-
-    if os.path.isfile(download_target):
-        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
-            return download_target
-        else:
-            warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
-
-    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
-            while True:
-                buffer = source.read(8192)
-                if not buffer:
-                    break
-
-                output.write(buffer)
-                loop.update(len(buffer))
-
-    if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
-        raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
-
-    return download_target
-
-
-def _convert_image_to_rgb(image):
-    return image.convert("RGB")
-
-
-def _transform(n_px):
-    return Compose([
-        Resize(n_px, interpolation=BICUBIC),
-        CenterCrop(n_px),
-        _convert_image_to_rgb,
-        ToTensor(),
-        Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
-    ])
-
-
-def available_models() -> List[str]:
-    """Returns the names of available CLIP models"""
-    return list(_MODELS.keys())
-
-
-def load(name: str, alpha_vision_ckpt_pth="None", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, lora_adapt=False, rank=16):
-    """Load a CLIP model
-
-    Parameters
-    ----------
-    name : str
-        A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
-
-    alpha_vision_ckpt_pth: str
-        only changed when inferencing model instead of training
-
-    device : Union[str, torch.device]
-        The device to put the loaded model
-
-    jit : bool
-        Whether to load the optimized JIT model or more hackable non-JIT model (default).
-
-    download_root: str
-        path to download the model files; by default, it uses "~/.cache/clip"
-
-    Returns
-    -------
-    model : torch.nn.Module
-        The CLIP model
-
-    preprocess : Callable[[PIL.Image], torch.Tensor]
-        A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
-    """
-    if name in _MODELS:
-        model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
-    elif os.path.isfile(name):
-        model_path = name
-    else:
-        raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
-
-    with open(model_path, 'rb') as opened_file:
-        try:
-            # loading JIT archive
-            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
-            state_dict = None
-        except RuntimeError:
-            # loading saved state dict
-            if jit:
-                warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
-                jit = False
-            state_dict = torch.load(opened_file, map_location="cpu")
-
-    if not jit:
-        model = build_model(state_dict or model.state_dict(), lora_adapt=lora_adapt, rank=rank).to(device)
-        if str(device) == "cpu":
-            model.float()
-        # If a separate checkpoint is provided for the visual encoder (e.g., CLIP), load it
-        if alpha_vision_ckpt_pth != "None":
-            # Load the visual encoder weights from the given checkpoint path
-            model.visual.load_state_dict(torch.load(alpha_vision_ckpt_pth))
-            # Set the model to evaluation mode
-            # Note: If LoRA is used, it may merge LoRA weights into the base model here for inference
-            model.eval()  # merge lora params if exists (for inference only)
-        return model, _transform(model.visual.input_resolution)
-
-    # patch the device names
-    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
-    device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
-
-    def _node_get(node: torch._C.Node, key: str):
-        """Gets attributes of a node which is polymorphic over return type.
-
-        From https://github.com/pytorch/pytorch/pull/82628
-        """
-        sel = node.kindOf(key)
-        return getattr(node, sel)(key)
-
-    def patch_device(module):
-        try:
-            graphs = [module.graph] if hasattr(module, "graph") else []
-        except RuntimeError:
-            graphs = []
-
-        if hasattr(module, "forward1"):
-            graphs.append(module.forward1.graph)
-
-        for graph in graphs:
-            for node in graph.findAllNodes("prim::Constant"):
-                if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
-                    node.copyAttributes(device_node)
-
-    model.apply(patch_device)
-    patch_device(model.encode_image)
-    patch_device(model.encode_text)
-
-    # patch dtype to float32 on CPU
-    if str(device) == "cpu":
-        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
-        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
-        float_node = float_input.node()
-
-        def patch_float(module):
-            try:
-                graphs = [module.graph] if hasattr(module, "graph") else []
-            except RuntimeError:
-                graphs = []
-
-            if hasattr(module, "forward1"):
-                graphs.append(module.forward1.graph)
-
-            for graph in graphs:
-                for node in graph.findAllNodes("aten::to"):
-                    inputs = list(node.inputs())
-                    for i in [1, 2]:  # dtype can be the second or third argument to aten::to()
-                        if _node_get(inputs[i].node(), "value") == 5:
-                            inputs[i].node().copyAttributes(float_node)
-
-        model.apply(patch_float)
-        patch_float(model.encode_image)
-        patch_float(model.encode_text)
-
-        model.float()
-    return model, _transform(model.input_resolution.item())
-
-
-def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]:
-    """
-    Returns the tokenized representation of given input string(s)
-
-    Parameters
-    ----------
-    texts : Union[str, List[str]]
-        An input string or a list of input strings to tokenize
-
-    context_length : int
-        The context length to use; all CLIP models use 77 as the context length
-
-    truncate: bool
-        Whether to truncate the text in case its encoding is longer than the context length
-
-    Returns
-    -------
-    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
-    We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
-    """
-    if isinstance(texts, str):
-        texts = [texts]
-
-    sot_token = _tokenizer.encoder["<|startoftext|>"]
-    eot_token = _tokenizer.encoder["<|endoftext|>"]
-    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
-    if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
-        result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
-    else:
-        result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
-
-    for i, tokens in enumerate(all_tokens):
-        if len(tokens) > context_length:
-            if truncate:
-                tokens = tokens[:context_length]
-                tokens[-1] = eot_token
-            else:
-                raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
-        result[i, :len(tokens)] = torch.tensor(tokens)
-
-    return result
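As a reference for the `tokenize` helper removed above, a small hedged sketch of its expected contract (it pads or truncates every string to `context_length`, 77 by default):

```python
# Illustrative sketch of the tokenize() function shown in the deleted file above.
tokens = tokenize(["a photo of a cat", "a dog"])
print(tokens.shape)  # expected: torch.Size([2, 77])
```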
src/alphaclip/alpha_clip/bpe_simple_vocab_16e6.txt.gz
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
-size 1356917
src/alphaclip/alpha_clip/model.py
DELETED
@@ -1,609 +0,0 @@
-from collections import OrderedDict
-from typing import Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import nn
-import loralib as lora
-import math
-import collections
-
-class Bottleneck(nn.Module):
-    expansion = 4
-
-    def __init__(self, inplanes, planes, stride=1):
-        super().__init__()
-
-        # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
-        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.relu1 = nn.ReLU(inplace=True)
-
-        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.relu2 = nn.ReLU(inplace=True)
-
-        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
-
-        self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
-        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
-        self.relu3 = nn.ReLU(inplace=True)
-
-        self.downsample = None
-        self.stride = stride
-
-        if stride > 1 or inplanes != planes * Bottleneck.expansion:
-            # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
-            self.downsample = nn.Sequential(OrderedDict([
-                ("-1", nn.AvgPool2d(stride)),
-                ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
-                ("1", nn.BatchNorm2d(planes * self.expansion))
-            ]))
-
-    def forward(self, x: torch.Tensor):
-        identity = x
-
-        out = self.relu1(self.bn1(self.conv1(x)))
-        out = self.relu2(self.bn2(self.conv2(out)))
-        out = self.avgpool(out)
-        out = self.bn3(self.conv3(out))
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu3(out)
-        return out
-
-
-class AttentionPool2d(nn.Module):
-    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
-        super().__init__()
-        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
-        self.k_proj = nn.Linear(embed_dim, embed_dim)
-        self.q_proj = nn.Linear(embed_dim, embed_dim)
-        self.v_proj = nn.Linear(embed_dim, embed_dim)
-        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
-        self.num_heads = num_heads
-
-    def forward(self, x):
-        x = x.flatten(start_dim=2).permute(2, 0, 1)  # NCHW -> (HW)NC
-        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
-        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
-        x, _ = F.multi_head_attention_forward(
-            query=x[:1], key=x, value=x,
-            embed_dim_to_check=x.shape[-1],
-            num_heads=self.num_heads,
-            q_proj_weight=self.q_proj.weight,
-            k_proj_weight=self.k_proj.weight,
-            v_proj_weight=self.v_proj.weight,
-            in_proj_weight=None,
-            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
-            bias_k=None,
-            bias_v=None,
-            add_zero_attn=False,
-            dropout_p=0,
-            out_proj_weight=self.c_proj.weight,
-            out_proj_bias=self.c_proj.bias,
-            use_separate_proj_weight=True,
-            training=self.training,
-            need_weights=False
-        )
-        return x.squeeze(0)
-
-
-class ModifiedResNet(nn.Module):
-    """
-    A ResNet class that is similar to torchvision's but contains the following changes:
-    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
-    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
-    - The final pooling layer is a QKV attention instead of an average pool
-    """
-
-    def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
-        super().__init__()
-        self.output_dim = output_dim
-        self.input_resolution = input_resolution
-
-        # the 3-layer stem
-        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
-        self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False)
-        self.bn1 = nn.BatchNorm2d(width // 2)
-        self.relu1 = nn.ReLU(inplace=True)
-        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
-        self.bn2 = nn.BatchNorm2d(width // 2)
-        self.relu2 = nn.ReLU(inplace=True)
-        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
-        self.bn3 = nn.BatchNorm2d(width)
-        self.relu3 = nn.ReLU(inplace=True)
-        self.avgpool = nn.AvgPool2d(2)
-
-        # residual layers
-        self._inplanes = width  # this is a *mutable* variable used during construction
-        self.layer1 = self._make_layer(width, layers[0])
-        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
-        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
-        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
-
-        embed_dim = width * 32  # the ResNet feature dimension
-        self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
-
-    def _make_layer(self, planes, blocks, stride=1):
-        layers = [Bottleneck(self._inplanes, planes, stride)]
-
-        self._inplanes = planes * Bottleneck.expansion
-        for _ in range(1, blocks):
-            layers.append(Bottleneck(self._inplanes, planes))
-
-        return nn.Sequential(*layers)
-
-    def forward(self, x, alpha=None):
-        def stem(x):
-            x = self.relu1(self.bn1(self.conv1(x) + self.conv1_alpha(alpha)))
-            x = self.relu2(self.bn2(self.conv2(x)))
-            x = self.relu3(self.bn3(self.conv3(x)))
-            x = self.avgpool(x)
-            return x
-
-        x = x.type(self.conv1.weight.dtype)
-        x = stem(x)
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-        x = self.attnpool(x)
-
-        return x
-
-
-class LayerNorm(nn.LayerNorm):
-    """Subclass torch's LayerNorm to handle fp16."""
-
-    def forward(self, x: torch.Tensor):
-        orig_type = x.dtype
-        ret = super().forward(x.type(torch.float32))
-        return ret.type(orig_type)
-
-
-class QuickGELU(nn.Module):
-    def forward(self, x: torch.Tensor):
-        return x * torch.sigmoid(1.702 * x)
-
-class Attention(nn.Module):
-    def __init__(
-        self,
-        dim,
-        num_heads=8,
-        qkv_bias=True,
-        scaled_cosine=False,
-        scale_heads=False,
-        logit_scale_max=math.log(1. / 0.01),
-        attn_drop=0.,
-        proj_drop=0.,
-        lora_adapt=False,
-        rank=16
-    ):
-        super().__init__()
-        self.scaled_cosine = scaled_cosine
-        self.scale_heads = scale_heads
-        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
-        self.num_heads = num_heads
-        self.head_dim = dim // num_heads
-        self.scale = self.head_dim ** -0.5
-        self.logit_scale_max = logit_scale_max
-
-        # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original
-        if lora_adapt:
-            print("!!!!!!!!!!using lora for qkv projection!!!!!!!!!!")
-            self.in_proj = lora.MergedLinear(dim, 3*dim, r=rank, enable_lora=[True, False, True])
-        else:
-            self.in_proj = nn.Linear(dim, dim * 3)
-        # self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale)
-        # if qkv_bias:
-        #     self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3))
-        # else:
-        #     self.in_proj_bias = None
-
-        if self.scaled_cosine:
-            self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
-        else:
-            self.logit_scale = None
-        self.attn_drop = nn.Dropout(attn_drop)
-        if self.scale_heads:
-            self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1)))
-        else:
-            self.head_scale = None
-        self.out_proj = nn.Linear(dim, dim) if not lora_adapt else lora.Linear(dim, dim, r=rank)
-        self.out_drop = nn.Dropout(proj_drop)
-
-    def forward(self, x, attn_mask = None):
-        L, N, C = x.shape
-        q, k, v = self.in_proj(x).chunk(3, dim=-1)
-        q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
-        k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
-        v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1)
-
-        if self.logit_scale is not None:
-            attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2))
-            logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp()
-            attn = attn.view(N, self.num_heads, L, L) * logit_scale
-            attn = attn.view(-1, L, L)
-        else:
-            q = q * self.scale
-            attn = torch.bmm(q, k.transpose(-2, -1))
-
-        if attn_mask is not None:
-            if attn_mask.dtype == torch.bool:
-                new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype)
-                new_attn_mask.masked_fill_(attn_mask, float("-inf"))
-                attn_mask = new_attn_mask
-            attn += attn_mask
-
-        attn = attn.softmax(dim=-1)
-
attn = self.attn_drop(attn)
|
| 245 |
-
|
| 246 |
-
x = torch.bmm(attn, v)
|
| 247 |
-
if self.head_scale is not None:
|
| 248 |
-
x = x.view(N, self.num_heads, L, C) * self.head_scale
|
| 249 |
-
x = x.view(-1, L, C)
|
| 250 |
-
x = x.transpose(0, 1).reshape(L, N, C)
|
| 251 |
-
x = self.out_proj(x)
|
| 252 |
-
x = self.out_drop(x)
|
| 253 |
-
return x, attn
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
class CustomResidualAttentionBlock(nn.Module):
|
| 257 |
-
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
|
| 258 |
-
super().__init__()
|
| 259 |
-
|
| 260 |
-
self.attn = Attention(d_model, n_head, lora_adapt=lora_adapt, rank=rank)
|
| 261 |
-
self.ln_1 = LayerNorm(d_model)
|
| 262 |
-
self.mlp = nn.Sequential(OrderedDict([
|
| 263 |
-
("c_fc", nn.Linear(d_model, d_model * 4) if not lora_adapt else lora.Linear(d_model, d_model*4, r=rank)),
|
| 264 |
-
("gelu", QuickGELU()),
|
| 265 |
-
("c_proj", nn.Linear(d_model * 4, d_model) if not lora_adapt else lora.Linear(d_model*4, d_model, r=rank))
|
| 266 |
-
]))
|
| 267 |
-
self.ln_2 = LayerNorm(d_model)
|
| 268 |
-
self.attn_mask = attn_mask
|
| 269 |
-
|
| 270 |
-
def attention(self, x: torch.Tensor):
|
| 271 |
-
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
| 272 |
-
return self.attn(x, attn_mask=self.attn_mask)
|
| 273 |
-
|
| 274 |
-
def forward(self, x: torch.Tensor, return_attn=False):
|
| 275 |
-
attn_out, attn = self.attention(self.ln_1(x))
|
| 276 |
-
x = x + attn_out
|
| 277 |
-
x = x + self.mlp(self.ln_2(x))
|
| 278 |
-
if return_attn:
|
| 279 |
-
return x, attn
|
| 280 |
-
else:
|
| 281 |
-
return x
|
| 282 |
-
|
| 283 |
-
class ResidualAttentionBlock(nn.Module):
|
| 284 |
-
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
| 285 |
-
super().__init__()
|
| 286 |
-
|
| 287 |
-
self.attn = nn.MultiheadAttention(d_model, n_head)
|
| 288 |
-
self.ln_1 = LayerNorm(d_model)
|
| 289 |
-
self.mlp = nn.Sequential(OrderedDict([
|
| 290 |
-
("c_fc", nn.Linear(d_model, d_model * 4)),
|
| 291 |
-
("gelu", QuickGELU()),
|
| 292 |
-
("c_proj", nn.Linear(d_model * 4, d_model))
|
| 293 |
-
]))
|
| 294 |
-
self.ln_2 = LayerNorm(d_model)
|
| 295 |
-
self.attn_mask = attn_mask
|
| 296 |
-
|
| 297 |
-
def attention(self, x: torch.Tensor):
|
| 298 |
-
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
| 299 |
-
return self.attn(x, x, x, attn_mask=self.attn_mask)[0]
|
| 300 |
-
|
| 301 |
-
def forward(self, x: torch.Tensor):
|
| 302 |
-
x = x + self.attention(self.ln_1(x))
|
| 303 |
-
x = x + self.mlp(self.ln_2(x))
|
| 304 |
-
return x
|
| 305 |
-
|
| 306 |
-
class Transformer(nn.Module):
|
| 307 |
-
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
| 308 |
-
super().__init__()
|
| 309 |
-
self.width = width
|
| 310 |
-
self.layers = layers
|
| 311 |
-
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
| 312 |
-
|
| 313 |
-
def forward(self, x: torch.Tensor):
|
| 314 |
-
return self.resblocks(x)
|
| 315 |
-
|
| 316 |
-
class CustomTransformer(nn.Module):
|
| 317 |
-
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16):
|
| 318 |
-
super().__init__()
|
| 319 |
-
self.width = width
|
| 320 |
-
self.layers = layers
|
| 321 |
-
self.resblocks = nn.Sequential(*[CustomResidualAttentionBlock(width, heads, attn_mask, lora_adapt=lora_adapt, rank=rank) for _ in range(layers)])
|
| 322 |
-
|
| 323 |
-
def forward(self, x: torch.Tensor, return_attn=False):
|
| 324 |
-
if return_attn:
|
| 325 |
-
for i, block in enumerate(self.resblocks):
|
| 326 |
-
if i == len(self.resblocks) - 1:
|
| 327 |
-
return block(x, return_attn=True)
|
| 328 |
-
else:
|
| 329 |
-
x = block(x)
|
| 330 |
-
assert False
|
| 331 |
-
return self.resblocks(x)
|
| 332 |
-
|
| 333 |
-
class VisionTransformer(nn.Module):
|
| 334 |
-
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, lora_adapt=False, rank=16):
|
| 335 |
-
super().__init__()
|
| 336 |
-
self.input_resolution = input_resolution
|
| 337 |
-
self.output_dim = output_dim
|
| 338 |
-
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 339 |
-
self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
| 340 |
-
|
| 341 |
-
scale = width ** -0.5
|
| 342 |
-
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
| 343 |
-
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
| 344 |
-
self.ln_pre = LayerNorm(width)
|
| 345 |
-
|
| 346 |
-
self.transformer = CustomTransformer(width, layers, heads, lora_adapt=lora_adapt, rank=rank)
|
| 347 |
-
|
| 348 |
-
self.ln_post = LayerNorm(width)
|
| 349 |
-
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
| 350 |
-
|
| 351 |
-
def forward(self, x: torch.Tensor, alpha=None, return_attn=False, return_patches=False):
|
| 352 |
-
# if x dtype is different from conv1, convert it
|
| 353 |
-
if x.dtype != self.conv1.weight.dtype:
|
| 354 |
-
x = x.type(self.conv1.weight.dtype)
|
| 355 |
-
|
| 356 |
-
if alpha.dtype != self.conv1_alpha.weight.dtype:
|
| 357 |
-
alpha = alpha.type(self.conv1_alpha.weight.dtype)
|
| 358 |
-
|
| 359 |
-
x = self.conv1(x) # shape = [*, width, grid, grid]
|
| 360 |
-
# ASSUME alpha is always not None!
|
| 361 |
-
x = x + self.conv1_alpha(alpha)
|
| 362 |
-
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
| 363 |
-
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
| 364 |
-
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
| 365 |
-
x = x + self.positional_embedding.to(x.dtype)
|
| 366 |
-
x = self.ln_pre(x)
|
| 367 |
-
|
| 368 |
-
x = x.permute(1, 0, 2) # NLD -> LND
|
| 369 |
-
if return_attn:
|
| 370 |
-
x, attn_last = self.transformer(x, return_attn=True)
|
| 371 |
-
else:
|
| 372 |
-
x = self.transformer(x, return_attn=False)
|
| 373 |
-
x = x.permute(1, 0, 2) # LND -> NLD
|
| 374 |
-
|
| 375 |
-
if not return_patches:
|
| 376 |
-
x = self.ln_post(x[:, 0, :])
|
| 377 |
-
else:
|
| 378 |
-
x = self.ln_post(x)
|
| 379 |
-
|
| 380 |
-
if self.proj is not None:
|
| 381 |
-
x = x @ self.proj
|
| 382 |
-
if return_attn:
|
| 383 |
-
return x, attn_last
|
| 384 |
-
else:
|
| 385 |
-
return x
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
class CLIP(nn.Module):
|
| 389 |
-
def __init__(self,
|
| 390 |
-
embed_dim: int,
|
| 391 |
-
# vision
|
| 392 |
-
image_resolution: int,
|
| 393 |
-
vision_layers: Union[Tuple[int, int, int, int], int],
|
| 394 |
-
vision_width: int,
|
| 395 |
-
vision_patch_size: int,
|
| 396 |
-
# text
|
| 397 |
-
context_length: int,
|
| 398 |
-
vocab_size: int,
|
| 399 |
-
transformer_width: int,
|
| 400 |
-
transformer_heads: int,
|
| 401 |
-
transformer_layers: int,
|
| 402 |
-
lora_adapt = False,
|
| 403 |
-
rank = 16,
|
| 404 |
-
):
|
| 405 |
-
super().__init__()
|
| 406 |
-
|
| 407 |
-
self.context_length = context_length
|
| 408 |
-
|
| 409 |
-
if isinstance(vision_layers, (tuple, list)):
|
| 410 |
-
vision_heads = vision_width * 32 // 64
|
| 411 |
-
self.visual = ModifiedResNet(
|
| 412 |
-
layers=vision_layers,
|
| 413 |
-
output_dim=embed_dim,
|
| 414 |
-
heads=vision_heads,
|
| 415 |
-
input_resolution=image_resolution,
|
| 416 |
-
width=vision_width
|
| 417 |
-
)
|
| 418 |
-
else:
|
| 419 |
-
vision_heads = vision_width // 64
|
| 420 |
-
self.visual = VisionTransformer(
|
| 421 |
-
input_resolution=image_resolution,
|
| 422 |
-
patch_size=vision_patch_size,
|
| 423 |
-
width=vision_width,
|
| 424 |
-
layers=vision_layers,
|
| 425 |
-
heads=vision_heads,
|
| 426 |
-
output_dim=embed_dim,
|
| 427 |
-
lora_adapt=lora_adapt,
|
| 428 |
-
rank=rank
|
| 429 |
-
)
|
| 430 |
-
|
| 431 |
-
self.transformer = Transformer(
|
| 432 |
-
width=transformer_width,
|
| 433 |
-
layers=transformer_layers,
|
| 434 |
-
heads=transformer_heads,
|
| 435 |
-
attn_mask=self.build_attention_mask()
|
| 436 |
-
)
|
| 437 |
-
|
| 438 |
-
self.vocab_size = vocab_size
|
| 439 |
-
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
| 440 |
-
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
| 441 |
-
self.ln_final = LayerNorm(transformer_width)
|
| 442 |
-
|
| 443 |
-
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
| 444 |
-
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
| 445 |
-
|
| 446 |
-
self.initialize_parameters()
|
| 447 |
-
|
| 448 |
-
def initialize_parameters(self):
|
| 449 |
-
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
| 450 |
-
nn.init.normal_(self.positional_embedding, std=0.01)
|
| 451 |
-
|
| 452 |
-
if isinstance(self.visual, ModifiedResNet):
|
| 453 |
-
if self.visual.attnpool is not None:
|
| 454 |
-
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
| 455 |
-
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
| 456 |
-
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
| 457 |
-
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
| 458 |
-
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
| 459 |
-
|
| 460 |
-
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
| 461 |
-
for name, param in resnet_block.named_parameters():
|
| 462 |
-
if name.endswith("bn3.weight"):
|
| 463 |
-
nn.init.zeros_(param)
|
| 464 |
-
|
| 465 |
-
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
| 466 |
-
attn_std = self.transformer.width ** -0.5
|
| 467 |
-
fc_std = (2 * self.transformer.width) ** -0.5
|
| 468 |
-
for block in self.transformer.resblocks:
|
| 469 |
-
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
| 470 |
-
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
| 471 |
-
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
| 472 |
-
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
| 473 |
-
|
| 474 |
-
if self.text_projection is not None:
|
| 475 |
-
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
| 476 |
-
|
| 477 |
-
def build_attention_mask(self):
|
| 478 |
-
# lazily create causal attention mask, with full attention between the vision tokens
|
| 479 |
-
# pytorch uses additive attention mask; fill with -inf
|
| 480 |
-
mask = torch.empty(self.context_length, self.context_length)
|
| 481 |
-
mask.fill_(float("-inf"))
|
| 482 |
-
mask.triu_(1) # zero out the lower diagonal
|
| 483 |
-
return mask
|
| 484 |
-
|
| 485 |
-
@property
|
| 486 |
-
def dtype(self):
|
| 487 |
-
if not hasattr(self.visual, "conv1"):
|
| 488 |
-
return self.visual.module.conv1.weight.dtype
|
| 489 |
-
return self.visual.conv1.weight.dtype
|
| 490 |
-
|
| 491 |
-
def encode_image(self, image, alpha):
|
| 492 |
-
assert alpha is not None
|
| 493 |
-
return self.visual(image.type(self.dtype), alpha.type(self.dtype))
|
| 494 |
-
|
| 495 |
-
def encode_text(self, text):
|
| 496 |
-
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
| 497 |
-
|
| 498 |
-
x = x + self.positional_embedding.type(self.dtype)
|
| 499 |
-
x = x.permute(1, 0, 2) # NLD -> LND
|
| 500 |
-
x = self.transformer(x)
|
| 501 |
-
x = x.permute(1, 0, 2) # LND -> NLD
|
| 502 |
-
x = self.ln_final(x).type(self.dtype)
|
| 503 |
-
|
| 504 |
-
# x.shape = [batch_size, n_ctx, transformer.width]
|
| 505 |
-
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
| 506 |
-
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
| 507 |
-
|
| 508 |
-
return x
|
| 509 |
-
|
| 510 |
-
def forward(self, image, text, alpha):
|
| 511 |
-
|
| 512 |
-
image_features = self.encode_image(image, alpha)
|
| 513 |
-
text_features = self.encode_text(text)
|
| 514 |
-
|
| 515 |
-
# normalized features
|
| 516 |
-
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
| 517 |
-
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
| 518 |
-
|
| 519 |
-
# cosine similarity as logits
|
| 520 |
-
logit_scale = self.logit_scale.exp()
|
| 521 |
-
logits_per_image = logit_scale * image_features @ text_features.t()
|
| 522 |
-
logits_per_text = logits_per_image.t()
|
| 523 |
-
|
| 524 |
-
# shape = [global_batch_size, global_batch_size]
|
| 525 |
-
return logits_per_image, logits_per_text
|
| 526 |
-
|
| 527 |
-
|
| 528 |
-
def convert_weights(model: nn.Module):
|
| 529 |
-
"""Convert applicable model parameters to fp16"""
|
| 530 |
-
|
| 531 |
-
def _convert_weights_to_fp16(l):
|
| 532 |
-
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
| 533 |
-
l.weight.data = l.weight.data.half()
|
| 534 |
-
if l.bias is not None:
|
| 535 |
-
l.bias.data = l.bias.data.half()
|
| 536 |
-
|
| 537 |
-
if isinstance(l, nn.MultiheadAttention):
|
| 538 |
-
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
| 539 |
-
tensor = getattr(l, attr)
|
| 540 |
-
if tensor is not None:
|
| 541 |
-
tensor.data = tensor.data.half()
|
| 542 |
-
|
| 543 |
-
for name in ["text_projection", "proj"]:
|
| 544 |
-
if hasattr(l, name):
|
| 545 |
-
attr = getattr(l, name)
|
| 546 |
-
if attr is not None:
|
| 547 |
-
attr.data = attr.data.half()
|
| 548 |
-
|
| 549 |
-
model.apply(_convert_weights_to_fp16)
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
def build_model(state_dict: dict, lora_adapt=False, rank=16):
|
| 553 |
-
vit = "visual.proj" in state_dict
|
| 554 |
-
|
| 555 |
-
if vit:
|
| 556 |
-
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
| 557 |
-
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
| 558 |
-
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
| 559 |
-
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 560 |
-
image_resolution = vision_patch_size * grid_size
|
| 561 |
-
else:
|
| 562 |
-
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
| 563 |
-
vision_layers = tuple(counts)
|
| 564 |
-
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
| 565 |
-
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
| 566 |
-
vision_patch_size = None
|
| 567 |
-
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
| 568 |
-
image_resolution = output_width * 32
|
| 569 |
-
|
| 570 |
-
embed_dim = state_dict["text_projection"].shape[1]
|
| 571 |
-
context_length = state_dict["positional_embedding"].shape[0]
|
| 572 |
-
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
| 573 |
-
transformer_width = state_dict["ln_final.weight"].shape[0]
|
| 574 |
-
transformer_heads = transformer_width // 64
|
| 575 |
-
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
| 576 |
-
|
| 577 |
-
# always load lora version
|
| 578 |
-
model = CLIP(
|
| 579 |
-
embed_dim,
|
| 580 |
-
image_resolution, vision_layers, vision_width, vision_patch_size,
|
| 581 |
-
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers,
|
| 582 |
-
lora_adapt=lora_adapt, rank=rank,
|
| 583 |
-
)
|
| 584 |
-
|
| 585 |
-
for key in ["input_resolution", "context_length", "vocab_size"]:
|
| 586 |
-
if key in state_dict:
|
| 587 |
-
del state_dict[key]
|
| 588 |
-
# para_wb to linear
|
| 589 |
-
new_state_dict = collections.OrderedDict()
|
| 590 |
-
for k, v in state_dict.items():
|
| 591 |
-
if 'visual' in k:
|
| 592 |
-
if 'in_proj_weight' in k:
|
| 593 |
-
new_state_dict[k.replace('in_proj_weight', 'in_proj.weight')] = v
|
| 594 |
-
elif 'in_proj_bias' in k:
|
| 595 |
-
new_state_dict[k.replace('in_proj_bias', 'in_proj.bias')] = v
|
| 596 |
-
else:
|
| 597 |
-
new_state_dict[k] = v
|
| 598 |
-
else:
|
| 599 |
-
new_state_dict[k] = v
|
| 600 |
-
|
| 601 |
-
state_dict = new_state_dict
|
| 602 |
-
# add rgba_conv_weight
|
| 603 |
-
if 'visual.conv1_alpha.weight' not in state_dict.keys(): # zero initialization on alpha channel
|
| 604 |
-
rgb_weight = state_dict['visual.conv1.weight'].clone().detach()
|
| 605 |
-
rgba_weigth = torch.zeros_like(rgb_weight)[:, 0:1, :, :]
|
| 606 |
-
state_dict['visual.conv1_alpha.weight'] = rgba_weigth
|
| 607 |
-
convert_weights(model)
|
| 608 |
-
model.load_state_dict(state_dict, strict=False)
|
| 609 |
-
return model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
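Note (not part of the commit): a minimal sketch of how the alpha channel flows through the deleted model above. `encode_image` takes the preprocessed RGB tensor plus a one-channel alpha map at the same resolution; the map is embedded by `conv1_alpha` and added to the patch embeddings. The import path and the ViT-B/16 hyperparameters below are assumptions for illustration, not something fixed by this commit.

import torch
from alpha_clip.model import CLIP  # hypothetical import path for the deleted module

# Standard ViT-B/16 sizes, assumed here only to make the sketch self-contained.
model = CLIP(
    embed_dim=512, image_resolution=224, vision_layers=12, vision_width=768,
    vision_patch_size=16, context_length=77, vocab_size=49408,
    transformer_width=512, transformer_heads=8, transformer_layers=12,
).eval()

image = torch.randn(1, 3, 224, 224)   # preprocessed RGB image
alpha = torch.zeros(1, 1, 224, 224)   # alpha map: 1 inside the region of interest
alpha[:, :, 64:160, 64:160] = 1.0     # e.g. highlight a central box

with torch.no_grad():
    feats = model.encode_image(image, alpha)  # alpha is mandatory (assert alpha is not None)
print(feats.shape)  # torch.Size([1, 512])
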
src/alphaclip/alpha_clip/simple_tokenizer.py
DELETED
@@ -1,132 +0,0 @@
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")


@lru_cache()
def bytes_to_unicode():
    """
    Returns list of utf-8 byte and a corresponding list of unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a signficant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    And avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except:
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

src/alphaclip/alpha_mask_utils.py
DELETED
@@ -1,111 +0,0 @@
"""
Utility functions for converting bboxes and traces to alpha masks for AlphaClip.
"""

import torch
import math


def bbox_to_alpha_mask(bbox, grid_size, patch_size, crop_dim):
    """
    Convert a single bounding box to an alpha mask for AlphaClip.

    Args:
        bbox: [x_min, y_min, w, h] format in original coordinates
        grid_size: Number of patches per side (e.g., 37 for 518/14)
        patch_size: Size of each patch in pixels
        crop_dim: Size of the cropped image

    Returns:
        alpha_mask: Binary mask of shape (grid_size, grid_size)
    """
    alpha_mask = torch.zeros((grid_size, grid_size))

    # Convert bbox to patch coordinates
    x_min, y_min, w, h = bbox
    x_max = x_min + w
    y_max = y_min + h

    # Scale to patch grid coordinates
    x1_patch = int(x_min // patch_size)
    y1_patch = int(y_min // patch_size)
    x2_patch = int(x_max // patch_size)
    y2_patch = int(y_max // patch_size)

    # Clamp to grid bounds
    x1_patch = max(0, min(x1_patch, grid_size - 1))
    y1_patch = max(0, min(y1_patch, grid_size - 1))
    x2_patch = max(0, min(x2_patch, grid_size))  # Allow up to grid_size for exclusive end
    y2_patch = max(0, min(y2_patch, grid_size))

    # Set the region to 1 (using slice notation for proper indexing)
    if x2_patch > x1_patch and y2_patch > y1_patch:
        alpha_mask[y1_patch:y2_patch, x1_patch:x2_patch] = 1.0

    return alpha_mask


def bboxes_to_alpha_mask(bboxes, grid_size, patch_size, crop_dim):
    """
    Convert multiple bboxes to a single OR-ed alpha mask.

    Args:
        bboxes: Tensor of bboxes in [x_min, y_min, w, h] format, shape [n_boxes, 4]
        grid_size: Number of patches per side
        patch_size: Size of each patch in pixels
        crop_dim: Size of the cropped image

    Returns:
        alpha_mask: Binary mask of shape (grid_size, grid_size)
    """
    alpha_mask = torch.zeros((grid_size, grid_size))

    for bbox in bboxes:
        # Skip dummy boxes (negative values)
        if bbox.sum().item() < 0:
            continue

        bbox_mask = bbox_to_alpha_mask(bbox, grid_size, patch_size, crop_dim)
        alpha_mask = torch.logical_or(alpha_mask, bbox_mask).float()

    return alpha_mask


def trace_to_alpha_mask(trace, grid_size):
    """
    Convert a trace to an alpha mask using the existing map_traces_to_grid function.

    Args:
        trace: List of trace points with 'x' and 'y' coordinates (normalized 0-1)
        grid_size: Number of patches per side

    Returns:
        alpha_mask: Binary mask of shape (grid_size, grid_size)
    """
    from src.bbox_utils import map_traces_to_grid

    alpha_mask = map_traces_to_grid(trace, grid_size)
    # Convert to binary (any value > 0 becomes 1)
    alpha_mask = (alpha_mask > 0).float()

    return alpha_mask


def traces_to_alpha_mask(traces, grid_size):
    """
    Convert multiple traces to a single OR-ed alpha mask.

    Args:
        traces: List of traces
        grid_size: Number of patches per side

    Returns:
        alpha_mask: Binary mask of shape (grid_size, grid_size)
    """
    alpha_mask = torch.zeros((grid_size, grid_size))

    for trace in traces:
        trace_mask = trace_to_alpha_mask(trace, grid_size)
        alpha_mask = torch.logical_or(alpha_mask, trace_mask).float()

    return alpha_mask

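Note (not part of the commit): a small worked example of the bbox-to-mask conversion defined above, assuming the deleted module is importable as `src.alphaclip.alpha_mask_utils`. With a 224x224 crop and 16-pixel patches the grid is 14x14; the box [x_min, y_min, w, h] = [32, 48, 64, 64] covers patch columns 2..5 and rows 3..6, i.e. a 4x4 block of cells.

import torch
from src.alphaclip.alpha_mask_utils import bbox_to_alpha_mask, bboxes_to_alpha_mask  # assumed path

mask = bbox_to_alpha_mask(bbox=[32, 48, 64, 64], grid_size=14, patch_size=16, crop_dim=224)
print(mask.sum().item())   # 16.0 -> a 4x4 block of patches is set to 1
print(mask[3:7, 2:6])      # the activated block

# Several boxes are OR-ed together; a box with negative values is treated as a dummy and skipped.
boxes = torch.tensor([[32, 48, 64, 64], [-1, -1, -1, -1]], dtype=torch.float)
union = bboxes_to_alpha_mask(boxes, grid_size=14, patch_size=16, crop_dim=224)
print(torch.equal(union, mask))  # True
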
src/alphaclip/alphaclip_loader.py
DELETED
@@ -1,233 +0,0 @@
"""
AlphaCLIP Standalone Loader

This module provides a simple interface to load and use AlphaCLIP models.
It exposes the core functionality of AlphaCLIP in a standalone package.

Usage:
    from alphaclip_loader import AlphaCLIPLoader

    # Initialize the loader
    loader = AlphaCLIPLoader()

    # Load a model
    model, preprocess = loader.load_model("ViT-B/16")

    # Tokenize text
    tokens = loader.tokenize("A photo of a cat")

    # Get available models
    models = loader.available_models()
"""

import os
import sys
from typing import Union, List, Tuple, Optional

# Check for critical dependencies
missing_deps = []
try:
    import torch
except ImportError:
    missing_deps.append("torch")

try:
    from PIL import Image
except ImportError:
    missing_deps.append("Pillow")

if missing_deps:
    raise ImportError(f"Missing required dependencies: {', '.join(missing_deps)}. "
                      f"Please install them with: pip install {' '.join(missing_deps)}")

# Add the alpha_clip directory to the path
_current_dir = os.path.dirname(os.path.abspath(__file__))
_alpha_clip_dir = os.path.join(_current_dir, 'alpha_clip')
if _alpha_clip_dir not in sys.path:
    sys.path.insert(0, _alpha_clip_dir)

# Import the alpha_clip modules
try:
    #import .alpha_clip
    from .alpha_clip import available_models, load, tokenize
except ImportError as e:
    raise ImportError(f"Failed to import alpha_clip modules: {e}. Please ensure all dependencies are installed.")


class AlphaCLIPLoader:
    """
    A convenience wrapper for AlphaCLIP functionality.

    This class provides a clean interface to load AlphaCLIP models and
    perform text tokenization.
    """

    def __init__(self, default_device: Optional[str] = None):
        """
        Initialize the AlphaCLIP loader.

        Args:
            default_device: Default device to load models on. If None, will use
                            CUDA if available, otherwise CPU.
        """
        if default_device is None:
            self.default_device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.default_device = default_device

    def available_models(self) -> List[str]:
        """
        Get list of available AlphaCLIP model names.

        Returns:
            List of model names that can be used with load_model()
        """
        return available_models()

    def load_model(
        self,
        name: str,
        alpha_vision_ckpt_pth: str = "None",
        device: Optional[Union[str, torch.device]] = None,
        jit: bool = False,
        download_root: Optional[str] = None,
        lora_adapt: bool = False,
        rank: int = 16
    ) -> Tuple[torch.nn.Module, callable]:
        """
        Load an AlphaCLIP model.

        Args:
            name: Model name (e.g., "ViT-B/16") or path to checkpoint
            alpha_vision_ckpt_pth: Path to additional vision checkpoint
            device: Device to load model on (defaults to self.default_device)
            jit: Whether to load JIT optimized model
            download_root: Directory to download models to
            lora_adapt: Whether to use LoRA adaptation
            rank: LoRA rank if lora_adapt is True

        Returns:
            Tuple of (model, preprocess_function)
        """
        if device is None:
            device = self.default_device

        return load(
            name=name,
            alpha_vision_ckpt_pth=alpha_vision_ckpt_pth,
            device=device,
            jit=jit,
            download_root=download_root,
            lora_adapt=lora_adapt,
            rank=rank
        )

    def tokenize(
        self,
        texts: Union[str, List[str]],
        context_length: int = 77,
        truncate: bool = True
    ) -> torch.Tensor:
        """
        Tokenize text for use with AlphaCLIP models.

        Args:
            texts: String or list of strings to tokenize
            context_length: Maximum token length (default 77)
            truncate: Whether to truncate long texts

        Returns:
            Tensor of tokenized text
        """
        return tokenize(texts, context_length, truncate)

    def encode_text(self, model: torch.nn.Module, texts: Union[str, List[str]]) -> torch.Tensor:
        """
        Convenience method to tokenize and encode text.

        Args:
            model: Loaded AlphaCLIP model
            texts: Text(s) to encode

        Returns:
            Text embeddings tensor
        """
        tokens = self.tokenize(texts)
        if hasattr(model, 'token_embedding'):
            # Move tokens to same device as model
            device = next(model.parameters()).device
            tokens = tokens.to(device)

        with torch.no_grad():
            text_features = model.encode_text(tokens)

        return text_features

    def encode_image(self, model: torch.nn.Module, images: torch.Tensor) -> torch.Tensor:
        """
        Convenience method to encode images.

        Args:
            model: Loaded AlphaCLIP model
            images: Preprocessed image tensor

        Returns:
            Image embeddings tensor
        """
        with torch.no_grad():
            image_features = model.encode_image(images)

        return image_features

    def get_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor:
        """
        Compute cosine similarity between text and image features.

        Args:
            text_features: Text embedding tensor
            image_features: Image embedding tensor

        Returns:
            Similarity scores tensor
        """
        # Normalize features
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Compute similarity
        similarity = (text_features @ image_features.T)
        return similarity


# Convenience function for quick model loading
def load_alphaclip(
    model_name: str = "ViT-B/16",
    device: Optional[str] = None,
    alpha_vision_ckpt_pth: str = "None",
    download_root = '/raid/datasets/models_weights/alphaclip',
    **kwargs
) -> Tuple[AlphaCLIPLoader, torch.nn.Module, callable]:
    """
    Quick function to load AlphaCLIP with a loader instance.

    Args:
        model_name: Name of the model to load
        device: Device to use
        **kwargs: Additional arguments for model loading

    Returns:
        Tuple of (loader, model, preprocess_function)
    """
    loader = AlphaCLIPLoader(default_device=device)
    model, preprocess = loader.load_model(model_name, **kwargs)
    return loader, model, preprocess


# Make key functions available at module level
__all__ = [
    'AlphaCLIPLoader',
    'load_alphaclip',
    'available_models',
    'load',
    'tokenize'
]

src/alphaclip/example.py
DELETED
@@ -1,76 +0,0 @@
#!/usr/bin/env python3
"""
Example usage of AlphaCLIP Standalone

This script demonstrates basic usage of the AlphaCLIP standalone package.
"""

import torch
import numpy as np
from alphaclip_loader import AlphaCLIPLoader, load_alphaclip

def main():
    print("AlphaCLIP Standalone Example")
    print("=" * 40)

    # Check if CUDA is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Method 1: Using the loader class
    print("\n1. Using AlphaCLIPLoader class:")
    loader = AlphaCLIPLoader(default_device=device)

    # Show available models
    models = loader.available_models()
    print(f"Available models: {models}")

    # Load a model
    print("\nLoading ViT-B/16 model...")
    model, preprocess = loader.load_model("ViT-B/16")
    print(f"Model loaded successfully!")

    # Test text encoding
    test_texts = [
        "a photo of a cat",
        "a dog running in the park",
        "a beautiful sunset over the ocean"
    ]

    print(f"\nEncoding {len(test_texts)} texts...")
    text_features = loader.encode_text(model, test_texts)
    print(f"Text features shape: {text_features.shape}")

    # Compute similarities between texts
    print("\nComputing text-to-text similarities:")
    similarities = loader.get_similarity(text_features, text_features)

    for i, text1 in enumerate(test_texts):
        for j, text2 in enumerate(test_texts):
            if i <= j:  # Only show upper triangle
                sim = similarities[i, j].item()
                print(f"  '{text1}' <-> '{text2}': {sim:.3f}")

    # Method 2: Using the quick loader function
    print("\n\n2. Using quick loader function:")
    loader2, model2, preprocess2 = load_alphaclip("ViT-B/16", device=device)

    # Test single text
    single_text = "a red apple on a wooden table"
    single_features = loader2.encode_text(model2, single_text)
    print(f"Single text '{single_text}' encoded to shape: {single_features.shape}")

    # Test tokenization
    print("\n3. Tokenization example:")
    tokens = loader.tokenize(test_texts)
    print(f"Tokenized {len(test_texts)} texts to shape: {tokens.shape}")

    # Show some token examples
    print("First few tokens for each text:")
    for i, text in enumerate(test_texts):
        print(f"  '{text}': {tokens[i][:10].tolist()}...")

    print("\nExample completed successfully!")

if __name__ == "__main__":
    main()

src/alphaclip/requirements.txt
DELETED
@@ -1,10 +0,0 @@
# Core dependencies for AlphaCLIP standalone
torch>=1.7.1
torchvision
ftfy
regex
tqdm
loralib
Pillow
numpy
packaging

src/alphaclip/setup.py
DELETED
@@ -1,47 +0,0 @@
"""
Setup script for AlphaCLIP Standalone
"""

from setuptools import setup, find_packages
import os

# Read requirements
with open('requirements.txt', 'r') as f:
    requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')]

# Read README if it exists
readme_content = ""
if os.path.exists('README.md'):
    with open('README.md', 'r', encoding='utf-8') as f:
        readme_content = f.read()

setup(
    name="alphaclip-standalone",
    version="1.0.0",
    author="AlphaCLIP Team",
    description="Standalone version of AlphaCLIP for easy integration",
    long_description=readme_content,
    long_description_content_type="text/markdown",
    packages=find_packages(),
    package_data={
        'alpha_clip': ['*.gz'],  # Include the tokenizer vocabulary file
    },
    include_package_data=True,
    install_requires=requirements,
    python_requires=">=3.7",
    classifiers=[
        "Development Status :: 4 - Beta",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        "Programming Language :: Python :: 3",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    keywords="clip, vision, language, deep learning, pytorch",
)

src/alphaclip/test_installation.py
DELETED
@@ -1,149 +0,0 @@
#!/usr/bin/env python3
"""
Test script for AlphaCLIP Standalone

This script tests the basic functionality of the standalone package
to ensure everything is working correctly.
"""

import sys
import os

def test_imports():
    """Test that all required modules can be imported."""
    print("Testing imports...")

    try:
        import torch
        print(f"✓ PyTorch {torch.__version__} imported successfully")
    except ImportError as e:
        print(f"✗ Failed to import PyTorch: {e}")
        return False

    try:
        import torchvision
        print(f"✓ Torchvision imported successfully")
    except ImportError as e:
        print(f"✗ Failed to import torchvision: {e}")
        return False

    try:
        from alphaclip_loader import AlphaCLIPLoader
        print("✓ AlphaCLIPLoader imported successfully")
    except ImportError as e:
        print(f"✗ Failed to import AlphaCLIPLoader: {e}")
        return False

    try:
        import loralib
        print("✓ LoraLib imported successfully")
    except ImportError as e:
        print(f"✗ Failed to import loralib: {e}")
        return False

    return True

def test_model_loading():
    """Test loading a model."""
    print("\nTesting model loading...")

    try:
        from alphaclip_loader import AlphaCLIPLoader

        loader = AlphaCLIPLoader(default_device="cpu")  # Use CPU for testing
        models = loader.available_models()
        print(f"✓ Available models: {models}")

        # Try to load the smallest model for testing
        print("Loading ViT-B/32 model (this may take a while for first download)...")
        model, preprocess = loader.load_model("ViT-B/32", device="cpu")
        print("✓ Model loaded successfully")

        return True

    except Exception as e:
        print(f"✗ Failed to load model: {e}")
        return False

def test_tokenization():
    """Test text tokenization."""
    print("\nTesting tokenization...")

    try:
        from alphaclip_loader import AlphaCLIPLoader

        loader = AlphaCLIPLoader()
        test_text = "a photo of a cat"
        tokens = loader.tokenize(test_text)
        print(f"✓ Tokenized '{test_text}' to shape {tokens.shape}")

        # Test batch tokenization
        test_texts = ["a cat", "a dog", "a bird"]
        batch_tokens = loader.tokenize(test_texts)
        print(f"✓ Batch tokenized {len(test_texts)} texts to shape {batch_tokens.shape}")

        return True

    except Exception as e:
        print(f"✗ Failed tokenization test: {e}")
        return False

def test_text_encoding():
    """Test text encoding with a loaded model."""
    print("\nTesting text encoding...")

    try:
        from alphaclip_loader import AlphaCLIPLoader

        loader = AlphaCLIPLoader(default_device="cpu")
        model, preprocess = loader.load_model("ViT-B/32", device="cpu")

        test_text = "a photo of a cat"
        features = loader.encode_text(model, test_text)
        print(f"✓ Encoded text to features with shape {features.shape}")

        # Test batch encoding
        test_texts = ["a cat", "a dog"]
        batch_features = loader.encode_text(model, test_texts)
        print(f"✓ Batch encoded {len(test_texts)} texts to shape {batch_features.shape}")

        return True

    except Exception as e:
        print(f"✗ Failed text encoding test: {e}")
        return False

def main():
    """Run all tests."""
    print("AlphaCLIP Standalone Test Suite")
    print("=" * 40)

    tests = [
        test_imports,
        test_tokenization,
        test_model_loading,
        test_text_encoding,
    ]

    passed = 0
    total = len(tests)

    for test in tests:
        try:
            if test():
                passed += 1
        except Exception as e:
            print(f"✗ Test {test.__name__} failed with exception: {e}")

    print(f"\n{'='*40}")
    print(f"Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! AlphaCLIP Standalone is working correctly.")
        return 0
    else:
        print("✗ Some tests failed. Please check the error messages above.")
        return 1

if __name__ == "__main__":
    sys.exit(main())

src/bbox_utils.py
DELETED
@@ -1,421 +0,0 @@
import torch
from copy import deepcopy
from PIL import ImageDraw
import itertools
import random


def extract_bboxes_feats(patch_embeddings, bboxes, gaussian_avg=False,
                         gaussian_bbox_variance=0.5, get_single_embedding_per_image=False,
                         patch_size=14, attention_map=None):
    """
    if get_single_embedding_per_image is True, the weights of all the bounding boxes patches on an image will be summed and the function will return the patch weights depending on this map
    """
    N = patch_embeddings.shape[0]
    N_boxes = bboxes.shape[1]
    grid_size = int(patch_embeddings.shape[1]**0.5)
    device = patch_embeddings.device

    bboxes //= patch_size
    bboxes = bboxes.int()

    # Reshape patches to grid
    patch_embeddings = patch_embeddings.view(N, grid_size, grid_size, -1)  # Shape (N, grid_size, grid_size, embed_dim)
    if attention_map is not None:
        attention_map = attention_map.view(N, grid_size, grid_size)  # Shape (N, grid_size, grid_size)
    # Grid of the sum of the gaussian weights
    total_patch_weights = torch.zeros(N, grid_size, grid_size)

    # Extract boxes
    x1, y1, w, h = bboxes.unbind(-1)  # Separate box dimensions (N, N_boxes)

    # Create mesh grid for slicing
    x2 = x1 + w  # Exclusive end x
    y2 = y1 + h  # Exclusive end y

    means = []
    for i in range(N):
        image_means = []
        for j in range(N_boxes):
            if bboxes[i, j].sum().item() < 0 and get_single_embedding_per_image:
                # this is the case where we receive a dummy box
                continue
            # Extract the region for each box
            region_patches = patch_embeddings[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1, :]  # (h, w, embed_dim)

            if attention_map is not None:
                patch_weights = attention_map[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1]
                patch_weights /= patch_weights.sum()
                total_patch_weights[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1] += patch_weights

                weighted_patches = region_patches * patch_weights.to(device).unsqueeze(-1)  # (h, w, embed_dim)
                region_mean = weighted_patches.sum(dim=(0, 1))  # Weighted mean

            elif gaussian_avg:
                # Create Gaussian weights
                h_span, w_span = region_patches.shape[:2]
                y_coords, x_coords = torch.meshgrid(
                    torch.linspace(-1, 1, h_span),
                    torch.linspace(-1, 1, w_span),
                    indexing="ij"
                )
                if gaussian_bbox_variance == 0:
                    patch_weights = torch.zeros((h_span, w_span))
                    # Determine central indices
                    center_y = [h_span // 2] if h_span % 2 == 1 else [h_span // 2 - 1, h_span // 2]
                    center_x = [w_span // 2] if w_span % 2 == 1 else [w_span // 2 - 1, w_span // 2]
                    # Randomly select one of the central elements in even case
                    center_y = random.choice(center_y)
                    center_x = random.choice(center_x)
                    # Set the selected central element to 1
                    patch_weights[center_y, center_x] = 1.0
                else:
                    distances = x_coords**2 + y_coords**2
                    patch_weights = torch.exp(-distances / gaussian_bbox_variance)
                    patch_weights = patch_weights / patch_weights.sum()  # Normalize to sum to 1

                # Apply Gaussian weights to region patches
                weighted_patches = region_patches * patch_weights.to(device).unsqueeze(-1)  # (h, w, embed_dim)
                region_mean = weighted_patches.sum(dim=(0, 1))  # Weighted mean

                # Recording the bbox weight inside the image patch weight map
                total_patch_weights[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1] += patch_weights
            else:
                # Mean pooling case: create uniform weights
                h_span, w_span = region_patches.shape[:2]
                uniform_weights = torch.ones(h_span, w_span) / (h_span * w_span)

                # Update total_patch_weights for mean pooling
                total_patch_weights[i, y1[i,j]:y2[i,j]+1, x1[i,j]:x2[i,j]+1] += uniform_weights

                # Compute mean of the region
                region_mean = region_patches.mean(dim=(0, 1))

            # Store the mean
            image_means.append(region_mean)
        if not get_single_embedding_per_image:
            means.append(torch.stack(image_means))

    # Normalizing the weight map so the sum is equal to 1
    total_patch_weights /= total_patch_weights.sum(dim=(1,2), keepdim=True)
    if not get_single_embedding_per_image:
        return torch.stack(means)  # Shape (N, N_boxes, embed_dim)
    else:
        # Expand dimensions to match embeddings
        total_patch_weights = total_patch_weights.unsqueeze(-1).to(device)

        # Compute weighted sum
        weighted_patch_mean = (total_patch_weights * patch_embeddings).sum(dim=(1, 2))
        return weighted_patch_mean
        # Shape (N, embed_dim)

#def adjust_bbox_for_transform(image, bbox, resize_dim, crop_dim):
#    """
#    Adjusts the bounding box for a resized and center-cropped image.
#
#    Args:
#        image (PIL.Image): The input image.
#        bbox (list): The bounding box in [x1, y1, w, h] format.
#        resize_dim (int): The dimension of the shortest side after resizing.
#        crop_dim (int): The size of the square crop.
#
#    Returns:
#        list: The adjusted bounding box in [x1, y1, w, h] format.
#    """
#    x1, y1, w, h = bbox
#    orig_width, orig_height = image.size
#
#    # Calculate resize scale for the shortest side
#    if orig_width < orig_height:
#        scale = resize_dim / orig_width
#        resized_width, resized_height = resize_dim, int(orig_height * scale)
#    else:
#        scale = resize_dim / orig_height
#        resized_width, resized_height = int(orig_width * scale), resize_dim
#
#    # Scale the bounding box
#    x1 *= scale
#    y1 *= scale
#    w *= scale
#    h *= scale
#
#    # Calculate cropping offsets
#    crop_x = (resized_width - crop_dim) // 2
#    crop_y = (resized_height - crop_dim) // 2
#
#    # Adjust bounding box for cropping
#    x1 -= crop_x
#    y1 -= crop_y
#
#    # Clamp the bounding box to the cropped area
#    x1 = max(0, x1)
#    y1 = max(0, y1)
#    w = min(w, crop_dim - x1)
#    h = min(h, crop_dim - y1)
#
#    return [x1, y1, w, h]

def map_traces_to_grid(traces, n_patch):
    grid = torch.zeros((n_patch, n_patch))
    patch_size = 1.0 / n_patch

    for trace in traces:
        x, y = trace['x'], trace['y']
        if 0 <= x <= 1 and 0 <= y <= 1:
            grid_x, grid_y = int(x / patch_size), int(y / patch_size)
            grid[min(grid_y, n_patch - 1), min(grid_x, n_patch - 1)] += 1

    return grid

def adjust_bbox_for_transform(image, bbox, resize_dim, crop_dim):
    """
    Adjusts the bounding box for a resized and center-cropped image.

    Args:
        image (PIL.Image): The input image.
        bbox (list): The bounding box in [x1, y1, w, h] format.
        resize_dim (int): The dimension of the shortest side after resizing.
        crop_dim (int): The size of the square crop.

    Returns:
        list: The adjusted bounding box in [x1, y1, w, h] format.
    """
    x1, y1, w, h = bbox
    orig_width, orig_height = image.size

    # Scale factors for resizing
    if orig_width < orig_height:
        scale_w = resize_dim / orig_width
        scale_h = (resize_dim * orig_height) / orig_width / orig_height
    else:
        scale_h = resize_dim / orig_height
        scale_w = (resize_dim * orig_width) / orig_height / orig_width

    # New dimensions after resize
    new_width = int(orig_width * scale_w)
    new_height = int(orig_height * scale_h)

    # Update bounding box for resizing
    x1 = x1 * scale_w
    y1 = y1 * scale_h
    w = w * scale_w
    h = h * scale_h

    # Compute cropping offsets
    crop_x_offset = max(0, (new_width - crop_dim) // 2)
    crop_y_offset = max(0, (new_height - crop_dim) // 2)

    # Adjust bounding box for cropping
    x1 -= crop_x_offset
    y1 -= crop_y_offset

    # Clip bounding box to crop dimensions
    x1 = max(0, min(x1, crop_dim - 1))
    y1 = max(0, min(y1, crop_dim - 1))
    w = max(0, min(w, crop_dim - x1))
    h = max(0, min(h, crop_dim - y1))

    return [x1, y1, w, h]



def adjust_bbox_for_transform_no_scale(image, bbox, target_width, target_height):
    """
    - Does not preserve the image scale.
    Adjusts the bounding box for an image resized to a fixed width and height.

    Args:
        image (PIL.Image): The original image.
        bbox (list): The bounding box in [x1, y1, w, h] format.
        target_width (int): The width of the resized image.
        target_height (int): The height of the resized image.

    Returns:
        list: The adjusted bounding box in [x1, y1, w, h] format.
    """
    x1, y1, w, h = bbox
    orig_width, orig_height = image.size

    # Calculate scale factors for width and height
    scale_w = target_width / orig_width
    scale_h = target_height / orig_height

    # Adjust the bounding box
    x1 = x1 * scale_w
    y1 = y1 * scale_h
    w = w * scale_w
    h = h * scale_h

    # Return the adjusted bounding box
    return [x1, y1, w, h]


def draw_bounding_boxes(input_image, bounding_boxes, captions=[""], color="red", width=2, text_background=True, boxes_to_show = None):
    """
    Draws bounding boxes on an image.

    Args:
        image (PIL.Image): The image to draw on.
        bounding_boxes (list): A list of bounding boxes, each as [x1, y1, x2, y2].
        color (str): The color of the bounding boxes (default is red).
        width (int): The width of the bounding box lines (default is 2).

    Returns:
        PIL.Image: The image with bounding boxes drawn.
    """
    # Create a drawing context
    image = deepcopy(input_image)
    draw = ImageDraw.Draw( image )

    #scale = 720.0 / max(image.size)
    if boxes_to_show is not None:
        if isinstance(boxes_to_show, int):
            indexes_to_show = random.sample(range(len(bounding_boxes)), boxes_to_show)
        else:
            indexes_to_show = boxes_to_show

    for i, (bbox, cap ) in enumerate(itertools.zip_longest(bounding_boxes, captions, fillvalue="")):

        if boxes_to_show is not None:
            if i not in indexes_to_show: continue
        #bbox = [ i / scale for i in bbox ]
        #x1, y1, w, h = bbox
        x1, y1, x2, y2 = bbox

        #x2, y2 = x1 + w, y1 + h # Convert width/height to bottom-right corner
        try:
            draw.rectangle([x1, y1, x2, y2], outline=color, width=width)
            if cap != "":
                if text_background:
                    left,top,right,bottom = draw.multiline_textbbox((x1,y1), cap) #textbbox
                    draw.rectangle((left-5, top-5, right+5, bottom+5), fill="white")
                draw.multiline_text((x1,y1), cap, fill=color) #text

        except Exception as e:
            print("exception, i: ", i, f"{x1 = } {y1 = } {x2 = }, {y2 = }")
            print(e)

    return image

def extract_bboxes_feats_double_dino(dino_model, patch_embeddings, bboxes, cls_token, registers_tokens, patch_size, return_type="cls", gaussian_bbox_variance=0.5):
    """
    Perform a forward pass of the last DINO layer with selected features, batched.

    Args:
        dino_model: The DINO model.
        patch_embeddings: Patch embeddings before the last layer.
        bboxes: Bounding boxes for each image in the batch (BS x N_BOX_MAX x 4).
        cls_token: CLS token embedding.
        return_type: Type of feature to return ('cls', 'avg', 'gaussian_avg').
        gaussian_bbox_variance: Variance for Gaussian averaging.

    Returns:
        bbox_features: Features for each bounding box based on return_type.
    """
    N = patch_embeddings.shape[0]  # Batch size
    N_boxes = bboxes.shape[1]  # Number of bounding boxes
    grid_size = int(patch_embeddings.shape[1] ** 0.5)  # Assuming square grid
    embed_dim = patch_embeddings.shape[-1]

    bboxes_patch_indexes = bboxes.clone()
    bboxes_patch_indexes //= patch_size  # Scale down bbox coordinates to match patch grid
    bboxes_patch_indexes = bboxes_patch_indexes.int()

    # Reshape patches to grid
    patch_embeddings = patch_embeddings.view(N, grid_size, grid_size, embed_dim)  # (N, grid_size, grid_size, embed_dim)

    if cls_token is not None:
        cls_tokens = cls_token.view(N, embed_dim)
        if registers_tokens is not None:
            patches_offset = 5
        else:
            patches_offset = 1
    else:
        assert return_type != "cls"
        patches_offset = 0
    batch_outputs = []

    #batch_inputs = []

    means = []
    for i in range(N):  # Iterate over batch
        image_means = []

        if cls_token is not None:
            cls_cur_img = cls_tokens[i].reshape(1, 1, embed_dim)
        if registers_tokens is not None:
            cur_img_register_tokens = registers_tokens[i].reshape(1, 4, embed_dim)

        for j in range(N_boxes):  # Iterate over bounding boxes
            # Extract the region for the bounding box
            region_patches_xy = patch_embeddings[i, bboxes_patch_indexes[i, j, 1]:bboxes_patch_indexes[i, j, 3] + 1, bboxes_patch_indexes[i, j, 0]:bboxes_patch_indexes[i, j, 2] + 1, :]
            #region_patches = region_patches.reshape(-1, embed_dim) # Flatten to (num_patches, embed_dim)

            #region_patches = region_patches.view(-1, embed_dim) # Flatten to (num_patches, embed_dim)
            #cls_cur_img = cls_tokens[i].unsqueeze(0) # Add batch dimension (1, embed_dim)
            #region_patches = region_patches.unsqueeze(0) # Add batch dimension (1, num_patches, embed_dim)
            region_patches = region_patches_xy.reshape(1,-1, embed_dim)
            if cls_token is not None:
                inputs = torch.cat([cls_cur_img, region_patches], dim=1)  # Concatenate along the token dimension (1, num_patches + 1, embed_dim)
                if registers_tokens is not None:
                    inputs = torch.cat([cls_cur_img, cur_img_register_tokens, region_patches], dim=1)  # Concatenate along the token dimension (1, num_patches + 5, embed_dim)
            else:
                inputs = torch.cat([region_patches], dim=1)  # Concatenate along the token dimension (1, num_patches + 1, embed_dim)

            outputs = dino_model.blocks[-1](inputs)  # Forward pass
            # shape (1, 1 + len(region_patches), 768)
            #cls_cur_img = cls_tokens[i]
            #cls_cur_img = cls_cur_img.reshape(1, embed_dim)
            #inputs = torch.cat([cls_cur_img, region_patches], dim=0) # Add CLS token to inputs
            #outputs = dino_model.blocks[-1](inputs) # Forward pass

            batch_outputs.append(outputs)

            region_patches = outputs[0, patches_offset: ,]  #(1,45,768) -> (1,1,768)

            if return_type == "gaussian_avg":
                #region_patches = outputs[5: ,]
                h_span, w_span = region_patches_xy.shape[:2]
                y_coords, x_coords = torch.meshgrid(
                    torch.linspace(-1, 1, h_span),
                    torch.linspace(-1, 1, w_span),
                    indexing="ij"
                )
                distances = x_coords**2 + y_coords**2
                gaussian_weights = torch.exp(-distances / gaussian_bbox_variance)  # Adjust 0.1 for variance control
                gaussian_weights = gaussian_weights / gaussian_weights.sum()  # Normalize to sum to 1

                # Apply Gaussian weights to region patches
                weighted_patches = region_patches_xy * gaussian_weights.to(next(dino_model.parameters()).device).unsqueeze(-1)  # (h, w, embed_dim)
                region_mean = weighted_patches.sum(dim=(0,1))  # Weighted mean
                #image_means.append(region_mean)
            elif return_type == "avg":
                # Compute mean of the region
                region_mean = region_patches.mean(dim=(0))  # Mean over h, w
            elif return_type == "cls":
                region_mean = outputs[0, 0, ]
            image_means.append(region_mean)

        means.append(torch.stack(image_means))

    stacked_means = torch.stack(means)
    #stacked_means = stacked_means.reshape(-1, embed_dim)
    return stacked_means


def process_bboxes(imgs, bboxes, transform):
    transformed_bboxes = []
    bboxes = bboxes.tolist()
    for img, img_bboxes in zip(imgs, bboxes):
        for bbox in img_bboxes:
            # Crop the region defined by bbox
            x_min, y_min, w, h = bbox
            x_max = x_min + w
            y_max = y_min + h
            cropped_region = img.crop((x_min, y_min, x_max, y_max))

            # Apply the transform to the cropped region
            transformed_region = transform(cropped_region)
            transformed_bboxes.append(transformed_region)

    return torch.stack(transformed_bboxes)
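For context, a minimal usage sketch of the `extract_bboxes_feats` function removed above. The import path, tensor shapes, and box values are illustrative assumptions (a DINO-style 14-pixel patch grid of 37x37 patches, matching a 518x518 crop); only the function signature and its pooling modes come from the file itself.

```python
# Sketch, assuming the module was importable as src.bbox_utils before this commit.
import torch

from src.bbox_utils import extract_bboxes_feats

N, grid, dim, patch = 2, 37, 768, 14                     # 37x37 patches ~ a 518x518 crop
patch_embeddings = torch.randn(N, grid * grid, dim)      # (N, num_patches, embed_dim)

# One [x1, y1, w, h] box per image, in pixel coordinates of the transformed image.
bboxes = torch.tensor([[[28.0, 28.0, 140.0, 140.0]],
                       [[70.0, 14.0, 210.0, 196.0]]])

# The function divides the boxes by patch_size in place, so pass a clone per call.
mean_feats = extract_bboxes_feats(patch_embeddings, bboxes.clone(), patch_size=patch)
print(mean_feats.shape)                                  # (N, N_boxes, embed_dim)

# Gaussian-weighted pooling centred on each box instead of a plain mean.
gauss_feats = extract_bboxes_feats(patch_embeddings, bboxes.clone(),
                                   gaussian_avg=True, patch_size=patch)
```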
src/clipcap/CLIPCAP_INTEGRATION.md
DELETED
@@ -1,206 +0,0 @@
# ClipCap Integration with Patchioner Class

This document describes how ClipCap models have been integrated into the Patchioner class for DINO feature-based image captioning.

## Overview

ClipCap support has been added to the Patchioner class following the same pattern as other captioning models (VieCap, MeaCap, etc.). This integration allows you to use trained ClipCap models with DINO features for image captioning tasks.

## Architecture

### Files Added/Modified

1. **`src/clipcap/entrypoint.py`** - Main ClipCap integration module
   - `ClipCapModel` class for DINO feature-based captioning
   - Model classes: `ClipCaptionModel`, `ClipCaptionPrefix`, `MLP`, `TransformerMapper`
   - Text generation utilities

2. **`src/model.py`** - Modified Patchioner class
   - Added `clipcap_config` parameter to constructor
   - Added ClipCap initialization logic
   - Added ClipCap support to `caption_tokens` method

3. **Configuration Files**
   - `configs/clipcap_dino_vitb14.k.yaml` - DINOv2-B/14 configuration
   - `configs/clipcap_dino_vitl14.k.yaml` - DINOv2-L/14 configuration

## Configuration

### YAML Configuration Format

```yaml
decap_weights: '/path/to/decap/weights.pt'
prefix_size: 768  # DINO feature dimension
support_memory_size: 0
dino_model: 'dinov2_vitb14'
normalize: True
resize_dim: 518
crop_dim: 518
use_talk2dino_project: False

# ClipCap configuration
clipcap:
  language_model: 'gpt2'
  prefix_length: 10   # Sequence length for prefix
  clip_length: 10     # CLIP sequence length (for transformer mapping)
  num_layers: 8       # Number of transformer layers (for transformer mapping)
  mapping_type: 'mlp' # 'mlp' or 'transformer'
  only_prefix: True   # Train only prefix mapping vs full model
  temperature: 1.0    # Sampling temperature
  top_p: 0.8          # Nucleus sampling parameter
  entry_length: 67    # Maximum caption length
  stop_token: '.'     # Stop token for generation
  weight_path: '/path/to/trained/clipcap/model.pt'
```

### Supported DINO Models

The integration automatically detects DINO feature dimensions:

- **DINOv2-S/14**: 384 dimensions (`dinov2_vits14`)
- **DINOv2-B/14**: 768 dimensions (`dinov2_vitb14`)
- **DINOv2-L/14**: 1024 dimensions (`dinov2_vitl14`)
- **DINOv2-G/14**: 1536 dimensions (`dinov2_vitg14`)

## Usage

### 1. Training ClipCap Models

First, train your ClipCap model with DINO features:

```bash
# Extract DINO features
python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14

# Train ClipCap model
python clipcapTraining.py \
    --use_dino \
    --dino_model_type dinov2_vitb14 \
    --prefix_length 10 \
    --mapping_type mlp \
    --only_prefix \
    --epochs 10
```

### 2. Using ClipCap with Patchioner

```python
import torch
from src.model import Patchioner

# Load model with ClipCap configuration
device = torch.device('cuda')
model = Patchioner.from_config('configs/clipcap_dino_vitb14.k.yaml', device)

# Generate captions from images
imgs = torch.randn(2, 3, 518, 518).to(device)  # Example batch
results = model.forward(imgs, get_cls_capt=True)
captions = results['cls_capt']

print("Generated captions:")
for i, caption in enumerate(captions):
    print(f"Image {i+1}: {caption}")
```

### 3. Using ClipCap Directly

```python
from src.clipcap.entrypoint import ClipCapModel
import torch

# Configuration
config = {
    'language_model': 'gpt2',
    'prefix_length': 10,
    'mapping_type': 'mlp',
    'only_prefix': True,
    'weight_path': '/path/to/trained/model.pt'
}

# Initialize model
device = torch.device('cuda')
clipcap = ClipCapModel(config, device, dino_feature_dim=768)

# Generate captions from DINO features
dino_features = torch.randn(2, 768).to(device)
captions = clipcap.forward(dino_features)

print(captions)
```

## Performance Improvements

### Batched Text Generation

The ClipCap integration includes an efficient batched text generation implementation:

- **`generate_batched()`**: Processes entire batches simultaneously
- **Significant speedup**: 2-8x faster than sequential processing
- **Memory efficient**: Optimized for GPU memory usage
- **Configurable**: Can fall back to sequential mode if needed

### Configuration Options

```yaml
clipcap:
  use_batched_generation: True  # Enable batched generation (recommended)
  temperature: 1.0              # Sampling temperature
  top_p: 0.8                    # Nucleus sampling parameter
  entry_length: 67              # Maximum sequence length
```

## Model Architecture Details

### ClipCap Model Structure

1. **Input**: DINO features (384/768/1024/1536 dimensions)
2. **Mapping Layer**:
   - **MLP**: `DINO_dim → GPT2_dim * prefix_length`
   - **Transformer**: Multi-layer transformer mapping
3. **GPT-2 Decoder**: Pretrained GPT-2 for text generation
4. **Output**: Natural language captions

### Key Components

- **`ClipCapModel`**: Main class for DINO-to-text captioning
- **`MLP`/`TransformerMapper`**: Feature mapping from DINO to GPT-2 space
- **Text Generation**: Nucleus sampling with configurable parameters

## Integration with Existing Pipeline

The ClipCap integration follows the established pattern:

1. **Configuration**: YAML-based configuration like other models
2. **Initialization**: Automatic DINO dimension detection
3. **Forward Pass**: Seamless integration with existing forward methods
4. **Scoring**: Optional confidence scoring support

## Testing

Run the integration test:

```bash
python test_clipcap_integration.py
```

This test verifies:
- Configuration loading from YAML
- Model instantiation with ClipCap
- Caption generation with dummy DINO features
- Score computation functionality


## Troubleshooting

### Common Issues

1. **Dimension Mismatch**: Ensure `prefix_size` matches the DINO model dimension
2. **Missing Weights**: Verify `weight_path` points to a trained ClipCap model
3. **Memory Issues**: Use `only_prefix=True` for lower memory usage
4. **Generation Quality**: Tune `temperature`, `top_p`, and `entry_length`

## References

- [ClipCap Paper](https://arxiv.org/abs/2111.09734)
- [DINO Paper](https://arxiv.org/abs/2104.14294)
- [DINOv2 Paper](https://arxiv.org/abs/2304.07193)
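To make the mapping described in "Model Architecture Details" concrete, here is a minimal, illustrative sketch of the shape arithmetic: a DINO feature vector is projected to `prefix_length` GPT-2 token embeddings that the decoder consumes as a prefix. This is not the repository's implementation; the dimension table and the MLP hidden-layer sizing mirror the deleted `clipcapTraining.py` shown later in this diff, while the variable names are illustrative.

```python
# Sketch of the DINO_dim -> GPT2_dim * prefix_length mapping (illustrative only).
import torch
import torch.nn as nn

DINO_DIMS = {"dinov2_vits14": 384, "dinov2_vitb14": 768,
             "dinov2_vitl14": 1024, "dinov2_vitg14": 1536}

dino_model = "dinov2_vitb14"
dino_dim = DINO_DIMS[dino_model]      # automatic dimension lookup, as in the table above
gpt2_dim, prefix_length = 768, 10     # GPT-2 hidden size and prefix length from the config

# 'mlp' mapping type: project, then reshape into prefix_length pseudo-token embeddings.
mlp = nn.Sequential(
    nn.Linear(dino_dim, (gpt2_dim * prefix_length) // 2),
    nn.Tanh(),
    nn.Linear((gpt2_dim * prefix_length) // 2, gpt2_dim * prefix_length),
)

dino_features = torch.randn(2, dino_dim)                       # a batch of 2 image features
prefix = mlp(dino_features).view(-1, prefix_length, gpt2_dim)  # (2, 10, 768)
print(prefix.shape)
```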
src/clipcap/clipcapTrainREADME.md
DELETED
@@ -1,301 +0,0 @@
# ClipCap Training with DINO Features - README

This guide provides instructions for training ClipCap with DINO visual features instead of CLIP features.

## Prerequisites

1. Ensure you have the required dependencies installed:
   - PyTorch
   - torchvision
   - transformers
   - tqdm
   - Pillow
   - scikit-image

2. Prepare your COCO dataset with the following structure:
   ```
   ./data/coco/
   ├── annotations/
   │   └── train_caption.json
   ├── train2014/
   │   └── COCO_train2014_*.jpg
   └── val2014/
       └── COCO_val2014_*.jpg
   ```

## Required Files for DINO Feature Extraction

To start the DINO feature extraction for the COCO dataset, you need:

### 1. **COCO Dataset Structure**:
```
/raid/datasets/coco/              # Main COCO directory (default)
├── train2014/                    # REQUIRED: Training images
│   └── COCO_train2014_*.jpg      # Image files
├── val2014/                      # REQUIRED: Validation images
│   └── COCO_val2014_*.jpg        # Image files
└── train_split_karpathy.json     # REQUIRED: Karpathy format annotations (default)
```

### 2. **Required Files**:
- **`train_split_karpathy.json`**: COCO caption annotations in Karpathy format (default)
- **Training images**: COCO 2014 training set (COCO_train2014_*.jpg)
- **Validation images**: COCO 2014 validation set (COCO_val2014_*.jpg)

### 3. **Annotation Format Support**:

The script supports two annotation formats:

#### **A. Karpathy Format** (default, recommended):
```json
{
  "images": [
    {"id": 522418, "file_name": "COCO_val2014_000000522418.jpg"}
  ],
  "annotations": [
    {"image_id": 522418, "id": 0, "caption": "A woman wearing a net..."}
  ]
}
```

#### **B. ClipCap Format** (legacy):
```json
[
  {"image_id": 522418, "caption": "A woman wearing a net..."}
]
```

### 4. **Specifying Custom Input/Output Paths**:

You can customize the paths using command-line arguments:

```bash
python clipcap_dino_parse_coco.py \
    --dino_model_type dinov2_vitb14 \
    --coco_images_dir "/path/to/your/coco/dataset" \
    --captions_file "/path/to/your/train_caption.json" \
    --output_file "/path/to/output/dino_features.pkl"
```

**Available path arguments**:
- `--coco_images_dir`: Path to COCO images directory (should contain `train2014/` and `val2014/` subdirs) - **Default: `/raid/datasets/coco`**
- `--captions_file`: Path to COCO captions JSON file (supports both Karpathy and ClipCap formats) - **Default: `/raid/datasets/coco/train_split_karpathy.json`**
- `--output_file`: Custom output file path (optional, auto-generated if not specified)

### 5. **Default Behavior** (if no paths specified):
```bash
# This will use default paths for your setup:
python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14

# Equivalent to:
python clipcap_dino_parse_coco.py \
    --dino_model_type dinov2_vitb14 \
    --coco_images_dir "/raid/datasets/coco" \
    --captions_file "/raid/datasets/coco/train_split_karpathy.json" \
    --output_file "/raid/datasets/coco/coco_karpathy_split_dinov2_vitb14_train.pkl"
```

## Step 1: Extract DINO Features

First, extract DINO features from the COCO images using the modified feature extraction script:

### For DINOv2-B/14 (768-dim features):
```bash
# Default paths (uses /raid/datasets/coco and Karpathy annotations)
python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14 --resize_dim 518 --crop_dim 518

# Custom paths
python clipcap_dino_parse_coco.py \
    --dino_model_type dinov2_vitb14 \
    --coco_images_dir "/your/coco/path" \
    --captions_file "/your/coco/train_split_karpathy.json" \
    --output_file "/your/output/dino_vitb14_features.pkl"
```

### For DINOv2-L/14 (1024-dim features):
```bash
# Default paths
python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitl14 --resize_dim 518 --crop_dim 518

# Custom paths
python clipcap_dino_parse_coco.py \
    --dino_model_type dinov2_vitl14 \
    --coco_images_dir "/your/coco/path" \
    --output_file "/your/output/dino_vitl14_features.pkl"
```

### For DINOv2-S/14 (384-dim features):
```bash
python clipcap_dino_parse_coco.py --dino_model_type dinov2_vits14 --resize_dim 518 --crop_dim 518
```

### For DINOv2-G/14 (1536-dim features):
```bash
python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitg14 --resize_dim 518 --crop_dim 518
```

**Output**: This will create a file like `/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl` (or your custom path) containing the DINO features and captions.

### Check Available Arguments:
```bash
python clipcap_dino_parse_coco.py --help
```

## Step 2: Train ClipCap with DINO Features

### Basic Training Command (MLP with sequence length 10):

For **DINOv2-B/14** with **MLP mapping** and **prefix length 10**:
```bash
python clipcapTraining.py \
    --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \
    --out_dir ./checkpoints_dino_vitb14_mlp_len10 \
    --prefix dino_vitb14_mlp_len10 \
    --epochs 10 \
    --save_every 2 \
    --prefix_length 10 \
    --bs 32 \
    --mapping_type mlp \
    --use_dino \
    --dino_model_type dinov2_vitb14 \
    --only_prefix
```

### Training Options for Different DINO Models:

#### DINOv2-L/14 (1024-dim):
```bash
python clipcapTraining.py \
    --data ./data/coco/coco_karpathy_split_dinov2_vitl14_train.pkl \
    --out_dir ./checkpoints_dino_vitl14_mlp_len10 \
    --prefix dino_vitl14_mlp_len10 \
    --epochs 10 \
    --save_every 2 \
    --prefix_length 10 \
    --bs 32 \
    --mapping_type mlp \
    --use_dino \
    --dino_model_type dinov2_vitl14 \
    --only_prefix
```

#### DINOv2-S/14 (384-dim):
```bash
python clipcapTraining.py \
    --data ./data/coco/coco_karpathy_split_dinov2_vits14_train.pkl \
    --out_dir ./checkpoints_dino_vits14_mlp_len10 \
    --prefix dino_vits14_mlp_len10 \
    --epochs 10 \
    --save_every 2 \
    --prefix_length 10 \
    --bs 32 \
    --mapping_type mlp \
    --use_dino \
    --dino_model_type dinov2_vits14 \
    --only_prefix
```

### Advanced Training Options:

#### Train both prefix and GPT (full model):
```bash
python clipcapTraining.py \
    --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \
    --out_dir ./checkpoints_dino_vitb14_mlp_len10_full \
    --prefix dino_vitb14_mlp_len10_full \
    --epochs 10 \
    --save_every 2 \
    --prefix_length 10 \
    --bs 16 \
    --mapping_type mlp \
    --use_dino \
    --dino_model_type dinov2_vitb14
```

#### Use Transformer mapping instead of MLP:
```bash
python clipcapTraining.py \
    --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \
    --out_dir ./checkpoints_dino_vitb14_transformer_len10 \
    --prefix dino_vitb14_transformer_len10 \
    --epochs 10 \
    --save_every 2 \
    --prefix_length 10 \
    --bs 32 \
    --mapping_type transformer \
    --num_layers 8 \
    --use_dino \
    --dino_model_type dinov2_vitb14 \
    --only_prefix
```

#### Custom feature dimension (if needed):
```bash
python clipcapTraining.py \
    --data ./data/coco/coco_karpathy_split_dinov2_vitb14_train.pkl \
    --out_dir ./checkpoints_dino_custom \
    --prefix dino_custom \
    --epochs 10 \
    --prefix_length 10 \
    --bs 32 \
    --mapping_type mlp \
    --use_dino \
    --dino_model_type dinov2_vitb14 \
    --dino_feature_dim 768 \
    --only_prefix
```

## Key Parameters Explanation:

- `--use_dino`: Enable DINO mode (required for DINO training)
- `--dino_model_type`: Specify which DINO model was used for feature extraction
- `--dino_feature_dim`: Override automatic feature dimension detection
- `--prefix_length`: Number of prefix tokens (set to 10 as requested)
- `--mapping_type`: Choose between 'mlp' or 'transformer' mapping
- `--only_prefix`: Train only the mapping layer, freeze GPT-2
- `--bs`: Batch size (adjust based on GPU memory)
- `--epochs`: Number of training epochs
- `--save_every`: Save a checkpoint every N epochs

## Expected Feature Dimensions:

- **DINOv2-S/14**: 384 dimensions
- **DINOv2-B/14**: 768 dimensions
- **DINOv2-L/14**: 1024 dimensions
- **DINOv2-G/14**: 1536 dimensions

## Training Tips:

1. **Memory Usage**: DINO features are typically larger than CLIP features, so you might need to reduce the batch size
2. **Convergence**: DINO-based models may require different learning rates or longer training
3. **Prefix Length**: Experiment with different prefix lengths (5, 10, 20) for optimal performance
4. **Mapping Type**: MLP is faster; Transformer might give better results but requires more memory

## Output:

The training will save checkpoints in the specified output directory:
- `{prefix}-{epoch:03d}.pt`: Model checkpoint for each epoch
- `{prefix}_latest.pt`: Latest model checkpoint (updated every 10k iterations)
- `{prefix}.json`: Training configuration

## Example Full Workflow:

```bash
# 1. Extract DINO features
python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14

# 2. Train ClipCap with DINO features (MLP, length 10, prefix-only)
python clipcapTraining.py \
    --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \
    --out_dir ./checkpoints_dino_vitb14_mlp_len10 \
    --prefix dino_vitb14_mlp_len10 \
    --epochs 10 \
    --prefix_length 10 \
    --bs 32 \
    --mapping_type mlp \
    --use_dino \
    --dino_model_type dinov2_vitb14 \
    --only_prefix
```

This will train a ClipCap model using DINO features with MLP mapping and sequence length 10 as requested.
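The README above describes both the Karpathy annotation format (default) and the legacy flat ClipCap format. As a hedged, illustrative addition: a minimal sketch of converting one to the other, using only the two JSON layouts quoted in the README. The input path is the default quoted there; the output filename is purely illustrative.

```python
# Sketch: flatten Karpathy-format annotations into the legacy ClipCap list
# (both formats are shown verbatim in the README above).
import json

with open("/raid/datasets/coco/train_split_karpathy.json") as f:
    karpathy = json.load(f)

# Keep only the fields the legacy format expects: image_id and caption.
clipcap_style = [
    {"image_id": ann["image_id"], "caption": ann["caption"]}
    for ann in karpathy["annotations"]
]

with open("train_caption_clipcap_format.json", "w") as f:  # illustrative output name
    json.dump(clipcap_style, f)

print(f"Converted {len(clipcap_style)} captions.")
```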
src/clipcap/clipcapTraining.py
DELETED
@@ -1,405 +0,0 @@
import torch
import torch.nn as nn
from torch.nn import functional as nnf
from torch.utils.data import Dataset, DataLoader
from enum import Enum
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
from tqdm import tqdm
import os
import pickle
import sys
import argparse
import json
from typing import Tuple, Optional, Union


class MappingType(Enum):
    MLP = 'mlp'
    Transformer = 'transformer'


class ClipCocoDataset(Dataset):

    def __len__(self) -> int:
        return len(self.captions_tokens)

    def pad_tokens(self, item: int):
        tokens = self.captions_tokens[item]
        padding = self.max_seq_len - tokens.shape[0]
        if padding > 0:
            tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1))
            self.captions_tokens[item] = tokens
        elif padding < 0:
            tokens = tokens[:self.max_seq_len]
            self.captions_tokens[item] = tokens
        mask = tokens.ge(0)  # mask is zero where we are out of sequence
        tokens[~mask] = 0
        mask = mask.float()
        mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0)  # adding prefix mask
        return tokens, mask

    def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
        tokens, mask = self.pad_tokens(item)
        prefix = self.prefixes[self.caption2embedding[item]]
        if self.normalize_prefix:
            prefix = prefix.float()
            prefix = prefix / prefix.norm(2, -1)
        return tokens, mask, prefix

    def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2",
                 normalize_prefix=False):
        self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type)
        self.prefix_length = prefix_length
        self.normalize_prefix = normalize_prefix
        with open(data_path, 'rb') as f:
            all_data = pickle.load(f)
        print("Data size is %0d" % len(all_data["clip_embedding"]))
        sys.stdout.flush()
        self.prefixes = all_data["clip_embedding"]
        captions_raw = all_data["captions"]
        self.image_ids = [caption["image_id"] for caption in captions_raw]
        self.captions = [caption['caption'] for caption in captions_raw]
        if os.path.isfile(f"{data_path[:-4]}_tokens.pkl"):
            with open(f"{data_path[:-4]}_tokens.pkl", 'rb') as f:
                self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f)
        else:
            self.captions_tokens = []
            self.caption2embedding = []
            max_seq_len = 0
            for caption in captions_raw:
                self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64))
                self.caption2embedding.append(caption["clip_embedding"])
                max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0])
            # self.max_seq_len = max_seq_len
            with open(f"{data_path[:-4]}_tokens.pkl", 'wb') as f:
                pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f)
        all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float()
        self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max()))


class MLP(nn.Module):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super(MLP, self).__init__()
        layers = []
        for i in range(len(sizes) - 1):
            layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
            if i < len(sizes) - 2:
                layers.append(act())
        self.model = nn.Sequential(*layers)


class MlpTransformer(nn.Module):
    def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
        super().__init__()
        out_d = out_d if out_d is not None else in_dim
        self.fc1 = nn.Linear(in_dim, h_dim)
        self.act = act
        self.fc2 = nn.Linear(h_dim, out_d)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

class MultiHeadAttention(nn.Module):

    def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim_self // num_heads
        self.scale = head_dim ** -0.5
        self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
        self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
        self.project = nn.Linear(dim_self, dim_self)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, y=None, mask=None):
        y = y if y is not None else x
        b, n, c = x.shape
        _, m, d = y.shape
        # b n h dh
        queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
        # b m 2 h dh
        keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
        keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
        attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
        if mask is not None:
            if mask.dim() == 2:
                mask = mask.unsqueeze(1)
            attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
        attention = attention.softmax(dim=2)
        out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
        out = self.project(out)
        return out, attention


class TransformerLayer(nn.Module):

    def forward_with_attention(self, x, y=None, mask=None):
        x_, attention = self.attn(self.norm1(x), y, mask)
        x = x + x_
        x = x + self.mlp(self.norm2(x))
        return x, attention

    def forward(self, x, y=None, mask=None):
        x = x + self.attn(self.norm1(x), y, mask)[0]
        x = x + self.mlp(self.norm2(x))
        return x

    def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
                 norm_layer: nn.Module = nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(dim_self)
        self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
        self.norm2 = norm_layer(dim_self)
        self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)


class Transformer(nn.Module):

    def forward_with_attention(self, x, y=None, mask=None):
        attentions = []
        for layer in self.layers:
            x, att = layer.forward_with_attention(x, y, mask)
            attentions.append(att)
        return x, attentions

    def forward(self, x, y=None, mask=None):
        for i, layer in enumerate(self.layers):
            if i % 2 == 0 and self.enc_dec:  # cross
                x = layer(x, y)
            elif self.enc_dec:  # self
                x = layer(x, x, mask)
            else:  # self or cross
                x = layer(x, y, mask)
        return x

    def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
                 mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
        super(Transformer, self).__init__()
        dim_ref = dim_ref if dim_ref is not None else dim_self
        self.enc_dec = enc_dec
        if enc_dec:
            num_layers = num_layers * 2
        layers = []
        for i in range(num_layers):
            if i % 2 == 0 and enc_dec:  # cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            elif enc_dec:  # self
                layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
            else:  # self or cross
                layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
        self.layers = nn.ModuleList(layers)


class TransformerMapper(nn.Module):

    def forward(self, x):
        x = self.linear(x).view(x.shape[0], self.clip_length, -1)
        prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
        prefix = torch.cat((x, prefix), dim=1)
        out = self.transformer(prefix)[:, self.clip_length:]
        return out

    def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
        super(TransformerMapper, self).__init__()
        self.clip_length = clip_length
        self.transformer = Transformer(dim_embedding, 8, num_layers)
        self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)


class ClipCaptionModel(nn.Module):

    def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None):
        embedding_text = self.gpt.transformer.wte(tokens)
        prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
        if labels is not None:
            dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((dummy_token, tokens), dim=1)
        out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
        return out

    def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
                 num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
        super(ClipCaptionModel, self).__init__()
        self.prefix_length = prefix_length
        self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        if mapping_type == MappingType.MLP:
            self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
                                     self.gpt_embedding_size * prefix_length))
        else:
            self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
                                                  clip_length, num_layers)


class ClipCaptionPrefix(ClipCaptionModel):

    def parameters(self, recurse: bool = True):
        return self.clip_project.parameters()

    def train(self, mode: bool = True):
        super(ClipCaptionPrefix, self).train(mode)
        self.gpt.eval()
        return self


def save_config(args: argparse.Namespace):
    config = {}
    for key, item in args._get_kwargs():
        config[key] = item
    out_path = os.path.join(args.out_dir, f"{args.prefix}.json")
    with open(out_path, 'w') as outfile:
        json.dump(config, outfile)


def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
    with open(config_path) as f:
        config = json.load(f)
    parser = argparse.ArgumentParser()
    parser.set_defaults(**config)
    args = parser.parse_args()
    if type(epoch_or_latest) is int:
        epoch_or_latest = f"-{epoch_or_latest:03d}"
    model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
    if args.only_prefix:
        model = ClipCaptionPrefix(args.prefix_length)
    else:
        model = ClipCaptionModel(args.prefix_length)
    if os.path.isfile(model_path):
        print(f"loading model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    else:
        print(f"{model_path} does not exist")
    return model, parser


def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args,
          lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = "", device = torch.device('cuda:0')):

    batch_size = args.bs
    epochs = args.epochs
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    model = model.to(device)
    model.train()
    optimizer = AdamW(model.parameters(), lr=lr)
    train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
    )
    # save_config(args)
    for epoch in range(epochs):
        print(f">>> Training epoch {epoch}")
        sys.stdout.flush()
        progress = tqdm(total=len(train_dataloader), desc=output_prefix)
        for idx, (tokens, mask, prefix) in enumerate(train_dataloader):
            model.zero_grad()
            tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32)
            outputs = model(tokens, prefix, mask)
|
| 314 |
-
logits = outputs.logits[:, dataset.prefix_length - 1: -1]
|
| 315 |
-
loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0)
|
| 316 |
-
loss.backward()
|
| 317 |
-
optimizer.step()
|
| 318 |
-
scheduler.step()
|
| 319 |
-
optimizer.zero_grad()
|
| 320 |
-
progress.set_postfix({"loss": loss.item()})
|
| 321 |
-
progress.update()
|
| 322 |
-
if (idx + 1) % 10000 == 0:
|
| 323 |
-
torch.save(
|
| 324 |
-
model.state_dict(),
|
| 325 |
-
os.path.join(output_dir, f"{output_prefix}_latest.pt"),
|
| 326 |
-
)
|
| 327 |
-
progress.close()
|
| 328 |
-
if epoch % args.save_every == 0 or epoch == epochs - 1:
|
| 329 |
-
torch.save(
|
| 330 |
-
model.state_dict(),
|
| 331 |
-
os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
|
| 332 |
-
)
|
| 333 |
-
return model
|
| 334 |
-
|
| 335 |
-
|
| 336 |
-
def main():
|
| 337 |
-
parser = argparse.ArgumentParser()
|
| 338 |
-
parser.add_argument('--data', default='/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_train.pkl')
|
| 339 |
-
parser.add_argument('--out_dir', default='/raid/datasets/models_weights/clipcap/checkpoints/dinov2b14/')
|
| 340 |
-
parser.add_argument('--prefix', default='coco_prefix', help='prefix for saved filenames')
|
| 341 |
-
parser.add_argument('--epochs', type=int, default=10)
|
| 342 |
-
parser.add_argument('--save_every', type=int, default=1)
|
| 343 |
-
parser.add_argument('--prefix_length', type=int, default=10)
|
| 344 |
-
parser.add_argument('--prefix_length_clip', type=int, default=10)
|
| 345 |
-
parser.add_argument('--bs', type=int, default=40)
|
| 346 |
-
parser.add_argument('--only_prefix', dest='only_prefix', action='store_true')
|
| 347 |
-
parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer')
|
| 348 |
-
parser.add_argument('--num_layers', type=int, default=8)
|
| 349 |
-
parser.add_argument('--is_rn', dest='is_rn', action='store_true')
|
| 350 |
-
parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true')
|
| 351 |
-
# DINO-specific arguments
|
| 352 |
-
parser.add_argument('--use_dino', action='store_true', default=False, help='Use DINO features instead of CLIP')
|
| 353 |
-
parser.add_argument('--dino_model_type', type=str, default='dinov2_vitb14',
|
| 354 |
-
choices=['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'],
|
| 355 |
-
help='DINO model type')
|
| 356 |
-
parser.add_argument('--dino_feature_dim', type=int, default=None,
|
| 357 |
-
help='DINO feature dimension (auto-detected if None)')
|
| 358 |
-
parser.add_argument('--device', type=str, default='cuda:0', help='Device to use for training')
|
| 359 |
-
args = parser.parse_args()
|
| 360 |
-
|
| 361 |
-
if isinstance(args.device, str):
|
| 362 |
-
if not args.device.startswith('cuda') and not args.device.startswith('cpu'):
|
| 363 |
-
# if it is an integer index, convert to f'cuda:{args.device}'
|
| 364 |
-
if args.device.isdigit():
|
| 365 |
-
args.device = f'cuda:{args.device}'
|
| 366 |
-
else:
|
| 367 |
-
raise ValueError(f"Invalid device string: {args.device}")
|
| 368 |
-
args.device = torch.device(args.device)
|
| 369 |
-
|
| 370 |
-
prefix_length = args.prefix_length
|
| 371 |
-
dataset = ClipCocoDataset(args.data, prefix_length, normalize_prefix=args.normalize_prefix)
|
| 372 |
-
|
| 373 |
-
# Determine prefix dimension based on model type
|
| 374 |
-
if args.use_dino:
|
| 375 |
-
if args.dino_feature_dim is not None:
|
| 376 |
-
prefix_dim = args.dino_feature_dim
|
| 377 |
-
else:
|
| 378 |
-
# Auto-detect DINO feature dimensions
|
| 379 |
-
dino_dims = {
|
| 380 |
-
'dinov2_vits14': 384,
|
| 381 |
-
'dinov2_vitb14': 768,
|
| 382 |
-
'dinov2_vitl14': 1024,
|
| 383 |
-
'dinov2_vitg14': 1536
|
| 384 |
-
}
|
| 385 |
-
prefix_dim = dino_dims.get(args.dino_model_type, 768)
|
| 386 |
-
print(f"Using DINO features with dimension: {prefix_dim}")
|
| 387 |
-
else:
|
| 388 |
-
prefix_dim = 640 if args.is_rn else 512
|
| 389 |
-
print(f"Using CLIP features with dimension: {prefix_dim}")
|
| 390 |
-
|
| 391 |
-
args.mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}[args.mapping_type]
|
| 392 |
-
if args.only_prefix:
|
| 393 |
-
model = ClipCaptionPrefix(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
|
| 394 |
-
num_layers=args.num_layers, mapping_type=args.mapping_type)
|
| 395 |
-
print("Train only prefix")
|
| 396 |
-
else:
|
| 397 |
-
model = ClipCaptionModel(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim,
|
| 398 |
-
num_layers=args.num_layers, mapping_type=args.mapping_type)
|
| 399 |
-
print("Train both prefix and GPT")
|
| 400 |
-
sys.stdout.flush()
|
| 401 |
-
train(dataset, model, args, output_dir=args.out_dir, output_prefix=args.prefix, device=args.device)
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
if __name__ == '__main__':
|
| 405 |
-
main()
src/clipcap/clipcap_dino_parse_coco.py
DELETED
|
@@ -1,613 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
import torch.nn.functional as F
|
| 3 |
-
import skimage.io as io
|
| 4 |
-
from PIL import Image
|
| 5 |
-
import pickle
|
| 6 |
-
import json
|
| 7 |
-
import os
|
| 8 |
-
from tqdm import tqdm
|
| 9 |
-
import argparse
|
| 10 |
-
import torchvision.transforms as T
|
| 11 |
-
import numpy as np
|
| 12 |
-
import yaml
|
| 13 |
-
import clip
|
| 14 |
-
import sys
|
| 15 |
-
|
| 16 |
-
# Add the src directory to the path so we can import ProjectionLayer
|
| 17 |
-
sys.path.append(os.path.join(os.path.dirname(__file__), '../..', 'src'))
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
# Container to store intermediate outputs for feature extraction
|
| 21 |
-
feats = {}
|
| 22 |
-
|
| 23 |
-
def get_self_attention(module, input, output):
|
| 24 |
-
"""Hook to capture self-attention weights"""
|
| 25 |
-
global qkv_attention_out
|
| 26 |
-
qkv_attention_out = output
|
| 27 |
-
|
| 28 |
-
def get_layer_n_output(module, input, output):
|
| 29 |
-
"""Hook to capture intermediate layer output"""
|
| 30 |
-
feats['intermediate_output'] = output
|
| 31 |
-
|
| 32 |
-
def transform_to_standard_dino_out(x, model, num_global_tokens=1):
|
| 33 |
-
"""Transform raw DINO output to standardized format"""
|
| 34 |
-
x_norm = model.norm(x)
|
| 35 |
-
if num_global_tokens == 1:
|
| 36 |
-
# Standard model without registers
|
| 37 |
-
return {
|
| 38 |
-
"x_norm_clstoken": x_norm[:, 0],
|
| 39 |
-
"x_norm_regtokens": None,
|
| 40 |
-
"x_norm_patchtokens": x_norm[:, 1:],
|
| 41 |
-
"x_prenorm": x,
|
| 42 |
-
}
|
| 43 |
-
else:
|
| 44 |
-
# Model with registers (num_global_tokens = 5)
|
| 45 |
-
return {
|
| 46 |
-
"x_norm_clstoken": x_norm[:, 0],
|
| 47 |
-
"x_norm_regtokens": x_norm[:, 1:num_global_tokens],
|
| 48 |
-
"x_norm_patchtokens": x_norm[:, num_global_tokens:],
|
| 49 |
-
"x_prenorm": x,
|
| 50 |
-
}
|
| 51 |
-
|
| 52 |
-
def process_self_attention(output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False):
|
| 53 |
-
"""Process self-attention output to compute attention weights"""
|
| 54 |
-
qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4)
|
| 55 |
-
q, k, v = qkv[0] * scale, qkv[1], qkv[2]
|
| 56 |
-
attn = q @ k.transpose(-2, -1)
|
| 57 |
-
self_attn_maps = attn[:, :, 0, num_global_tokens:] # CLS token attention to patches
|
| 58 |
-
self_attn = self_attn_maps.mean(dim=1) # Average over attention heads
|
| 59 |
-
self_attn = self_attn.softmax(dim=-1)
|
| 60 |
-
if ret_self_attn_maps:
|
| 61 |
-
return self_attn, self_attn_maps
|
| 62 |
-
else:
|
| 63 |
-
return self_attn
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
# Global variables to store hook outputs
|
| 67 |
-
dino_layer_n_output = None
|
| 68 |
-
qkv_attention_out = None
|
| 69 |
-
|
| 70 |
-
def get_layer_n_output(module, input, output):
|
| 71 |
-
"""Hook to capture intermediate layer output"""
|
| 72 |
-
global dino_layer_n_output
|
| 73 |
-
dino_layer_n_output = output
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def select_most_significant_patch(dino_outs, self_attn, criteria, cls_token=None, caption_embedding=None):
|
| 77 |
-
"""
|
| 78 |
-
Select the most significant patch token based on different criteria.
|
| 79 |
-
|
| 80 |
-
Args:
|
| 81 |
-
dino_outs: Dictionary containing normalized DINO outputs
|
| 82 |
-
self_attn: Self-attention weights from CLS to patches [batch_size, num_patches]
|
| 83 |
-
criteria: Selection criteria ('max_attention', 'most_similar_to_cls', etc.)
|
| 84 |
-
cls_token: CLS token embeddings [batch_size, embed_dim]
|
| 85 |
-
caption_embedding: Text caption embeddings [batch_size, embed_dim]
|
| 86 |
-
|
| 87 |
-
Returns:
|
| 88 |
-
selected_patches: [batch_size, embed_dim] - Selected patch embeddings
|
| 89 |
-
"""
|
| 90 |
-
patch_tokens = dino_outs['x_norm_patchtokens'] # [batch_size, num_patches, embed_dim]
|
| 91 |
-
batch_size, num_patches, embed_dim = patch_tokens.shape
|
| 92 |
-
|
| 93 |
-
if criteria == "max_attention":
|
| 94 |
-
# Select patch with highest attention weight from CLS token
|
| 95 |
-
if self_attn is None:
|
| 96 |
-
raise ValueError("self_attn required for max_attention criteria")
|
| 97 |
-
max_attn_indices = self_attn.argmax(dim=1) # [batch_size]
|
| 98 |
-
selected_patches = patch_tokens[torch.arange(batch_size), max_attn_indices]
|
| 99 |
-
|
| 100 |
-
elif criteria == "most_similar_to_cls":
|
| 101 |
-
# Select patch most similar to CLS token using cosine similarity
|
| 102 |
-
if cls_token is None:
|
| 103 |
-
raise ValueError("cls_token required for most_similar_to_cls criteria")
|
| 104 |
-
# Compute cosine similarity between CLS and all patches
|
| 105 |
-
cls_normalized = F.normalize(cls_token, p=2, dim=1) # [batch_size, embed_dim]
|
| 106 |
-
patches_normalized = F.normalize(patch_tokens, p=2, dim=2) # [batch_size, num_patches, embed_dim]
|
| 107 |
-
similarities = torch.bmm(patches_normalized, cls_normalized.unsqueeze(2)).squeeze(2) # [batch_size, num_patches]
|
| 108 |
-
max_sim_indices = similarities.argmax(dim=1) # [batch_size]
|
| 109 |
-
selected_patches = patch_tokens[torch.arange(batch_size), max_sim_indices]
|
| 110 |
-
|
| 111 |
-
elif criteria == "most_similar_to_caption":
|
| 112 |
-
# Select patch most similar to caption embedding
|
| 113 |
-
if caption_embedding is None:
|
| 114 |
-
raise ValueError("caption_embedding required for most_similar_to_caption criteria")
|
| 115 |
-
caption_normalized = F.normalize(caption_embedding, p=2, dim=1) # [batch_size, embed_dim]
|
| 116 |
-
patches_normalized = F.normalize(patch_tokens, p=2, dim=2) # [batch_size, num_patches, embed_dim]
|
| 117 |
-
similarities = torch.bmm(patches_normalized, caption_normalized.unsqueeze(2)).squeeze(2) # [batch_size, num_patches]
|
| 118 |
-
max_sim_indices = similarities.argmax(dim=1) # [batch_size]
|
| 119 |
-
selected_patches = patch_tokens[torch.arange(batch_size), max_sim_indices]
|
| 120 |
-
|
| 121 |
-
elif criteria == "max_norm":
|
| 122 |
-
# Select patch with highest L2 norm
|
| 123 |
-
patch_norms = torch.norm(patch_tokens, p=2, dim=2) # [batch_size, num_patches]
|
| 124 |
-
max_norm_indices = patch_norms.argmax(dim=1) # [batch_size]
|
| 125 |
-
selected_patches = patch_tokens[torch.arange(batch_size), max_norm_indices]
|
| 126 |
-
|
| 127 |
-
elif criteria == "centroid_distance":
|
| 128 |
-
# Select patch farthest from the centroid of all patches
|
| 129 |
-
centroid = patch_tokens.mean(dim=1, keepdim=True) # [batch_size, 1, embed_dim]
|
| 130 |
-
distances = torch.norm(patch_tokens - centroid, p=2, dim=2) # [batch_size, num_patches]
|
| 131 |
-
max_dist_indices = distances.argmax(dim=1) # [batch_size]
|
| 132 |
-
selected_patches = patch_tokens[torch.arange(batch_size), max_dist_indices]
|
| 133 |
-
|
| 134 |
-
else:
|
| 135 |
-
raise ValueError(f"Unknown patch selection criteria: {criteria}")
|
| 136 |
-
|
| 137 |
-
return selected_patches
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
def load_text_encoder(text_encoder_path, device, config_path=None):
|
| 141 |
-
"""
|
| 142 |
-
Load a text encoder model for caption similarity.
|
| 143 |
-
Supports Talk2Dino, CLIP, and DINO.txt-based text encoders.
|
| 144 |
-
"""
|
| 145 |
-
if text_encoder_path is None:
|
| 146 |
-
return None
|
| 147 |
-
|
| 148 |
-
print(f"Loading text encoder from: {text_encoder_path}")
|
| 149 |
-
|
| 150 |
-
# Check for DINO.txt model
|
| 151 |
-
if text_encoder_path.lower() == 'dinotxt' or text_encoder_path.lower() == 'dino.txt':
|
| 152 |
-
# Load DINO.txt model
|
| 153 |
-
try:
|
| 154 |
-
from src.dinotxt_utils import get_tokenizer
|
| 155 |
-
|
| 156 |
-
print("Loading DINO.txt model...")
|
| 157 |
-
dinotxt_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l')
|
| 158 |
-
dinotxt_model.eval()
|
| 159 |
-
dinotxt_model.to(device)
|
| 160 |
-
|
| 161 |
-
tokenizer = get_tokenizer()
|
| 162 |
-
|
| 163 |
-
return {
|
| 164 |
-
'type': 'dinotxt',
|
| 165 |
-
'model': dinotxt_model,
|
| 166 |
-
'tokenizer': tokenizer
|
| 167 |
-
}
|
| 168 |
-
|
| 169 |
-
except ImportError:
|
| 170 |
-
raise ImportError("Could not import dinotxt_utils. Make sure src/dinotxt_utils.py is accessible.")
|
| 171 |
-
except Exception as e:
|
| 172 |
-
raise RuntimeError(f"Failed to load DINO.txt model: {e}")
|
| 173 |
-
|
| 174 |
-
# Check if it's a Talk2Dino model (expect config and weights)
|
| 175 |
-
elif text_encoder_path.endswith('.pth') or text_encoder_path.endswith('.pt'):
|
| 176 |
-
# Use provided config or auto-find
|
| 177 |
-
if config_path is None:
|
| 178 |
-
# Look for corresponding config file
|
| 179 |
-
base_path = text_encoder_path.rsplit('.', 1)[0]
|
| 180 |
-
config_path = base_path + '.yaml'
|
| 181 |
-
|
| 182 |
-
# Alternative config path patterns
|
| 183 |
-
if not os.path.exists(config_path):
|
| 184 |
-
# Try configs_talk2dino directory
|
| 185 |
-
config_name = os.path.basename(base_path) + '.yaml'
|
| 186 |
-
config_path = os.path.join(os.path.dirname(__file__), 'configs_talk2dino', config_name)
|
| 187 |
-
|
| 188 |
-
if not os.path.exists(config_path):
|
| 189 |
-
raise FileNotFoundError(f"Could not find config file for {text_encoder_path}. "
|
| 190 |
-
f"Expected at {config_path} or specify --text_encoder_config.")
|
| 191 |
-
|
| 192 |
-
# Load Talk2Dino model
|
| 193 |
-
try:
|
| 194 |
-
from src.model import ProjectionLayer
|
| 195 |
-
|
| 196 |
-
print(f"Using config: {config_path}")
|
| 197 |
-
|
| 198 |
-
# Load the projection layer
|
| 199 |
-
talk2dino = ProjectionLayer.from_config(config_path)
|
| 200 |
-
talk2dino.load_state_dict(torch.load(text_encoder_path, map_location=device))
|
| 201 |
-
talk2dino.to(device)
|
| 202 |
-
talk2dino.eval()
|
| 203 |
-
|
| 204 |
-
# Load CLIP model for text encoding
|
| 205 |
-
clip_model, _ = clip.load("ViT-B/32", device=device)
|
| 206 |
-
clip_model.eval()
|
| 207 |
-
|
| 208 |
-
return {
|
| 209 |
-
'type': 'talk2dino',
|
| 210 |
-
'talk2dino': talk2dino,
|
| 211 |
-
'clip_model': clip_model,
|
| 212 |
-
'config_path': config_path
|
| 213 |
-
}
|
| 214 |
-
|
| 215 |
-
except ImportError:
|
| 216 |
-
raise ImportError("Could not import ProjectionLayer. Make sure src/model.py is accessible.")
|
| 217 |
-
|
| 218 |
-
else:
|
| 219 |
-
# Assume it's a direct model path (CLIP or other)
|
| 220 |
-
try:
|
| 221 |
-
# Try loading as a CLIP model
|
| 222 |
-
clip_model, _ = clip.load(text_encoder_path, device=device)
|
| 223 |
-
clip_model.eval()
|
| 224 |
-
|
| 225 |
-
return {
|
| 226 |
-
'type': 'clip',
|
| 227 |
-
'clip_model': clip_model
|
| 228 |
-
}
|
| 229 |
-
except:
|
| 230 |
-
raise ValueError(f"Could not load text encoder from {text_encoder_path}. "
|
| 231 |
-
f"Supported formats: 1) 'dinotxt' or 'dino.txt' for DINO.txt model, "
|
| 232 |
-
f"2) Talk2Dino (.pth/.pt), 3) CLIP model names.")
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
def encode_caption(caption, text_encoder, device):
|
| 236 |
-
"""
|
| 237 |
-
Encode a text caption using the loaded text encoder.
|
| 238 |
-
"""
|
| 239 |
-
if text_encoder is None:
|
| 240 |
-
return None
|
| 241 |
-
|
| 242 |
-
if text_encoder['type'] == 'dinotxt':
|
| 243 |
-
# Use DINO.txt pipeline: tokenize + encode + extract patch-aligned features
|
| 244 |
-
with torch.no_grad():
|
| 245 |
-
# Tokenize with DINO.txt tokenizer
|
| 246 |
-
text_tokens = text_encoder['tokenizer'].tokenize([caption]).to(device)
|
| 247 |
-
|
| 248 |
-
# Encode with DINO.txt model
|
| 249 |
-
dinotxt_features = text_encoder['model'].encode_text(text_tokens)
|
| 250 |
-
|
| 251 |
-
# Extract patch-aligned text embeddings (dimensions 1024:)
|
| 252 |
-
# DINO.txt concatenates standard text features [0:1024] and patch-aligned features [1024:]
|
| 253 |
-
patch_aligned_features = dinotxt_features[:, 1024:]
|
| 254 |
-
|
| 255 |
-
# Normalize the features to match DINO feature space
|
| 256 |
-
patch_aligned_features = F.normalize(patch_aligned_features, p=2, dim=-1)
|
| 257 |
-
return patch_aligned_features
|
| 258 |
-
|
| 259 |
-
elif text_encoder['type'] == 'talk2dino':
|
| 260 |
-
# Use Talk2Dino pipeline: CLIP text encoding + Talk2Dino projection
|
| 261 |
-
with torch.no_grad():
|
| 262 |
-
# Tokenize and encode with CLIP
|
| 263 |
-
text_tokens = clip.tokenize([caption]).to(device)
|
| 264 |
-
clip_text_features = text_encoder['clip_model'].encode_text(text_tokens)
|
| 265 |
-
|
| 266 |
-
# Project through Talk2Dino to DINO space
|
| 267 |
-
dino_text_features = text_encoder['talk2dino'].project_clip_txt(clip_text_features)
|
| 268 |
-
|
| 269 |
-
# Normalize the encoded text to match DINO feature space
|
| 270 |
-
dino_text_features = F.normalize(dino_text_features, p=2, dim=-1)
|
| 271 |
-
return dino_text_features
|
| 272 |
-
|
| 273 |
-
elif text_encoder['type'] == 'clip':
|
| 274 |
-
# Use CLIP directly
|
| 275 |
-
with torch.no_grad():
|
| 276 |
-
text_tokens = clip.tokenize([caption]).to(device)
|
| 277 |
-
clip_text_features = text_encoder['clip_model'].encode_text(text_tokens)
|
| 278 |
-
|
| 279 |
-
# Normalize the features
|
| 280 |
-
clip_text_features = F.normalize(clip_text_features, p=2, dim=-1)
|
| 281 |
-
return clip_text_features
|
| 282 |
-
|
| 283 |
-
else:
|
| 284 |
-
raise ValueError(f"Unknown text encoder type: {text_encoder['type']}")
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
def main(dino_model_type: str, resize_dim: int = 518, crop_dim: int = 518,
|
| 288 |
-
coco_images_dir: str = "/raid/datasets/coco/", captions_file: str = "/raid/datasets/coco/train_split_karpathy.json",
|
| 289 |
-
output_file: str = None, feature_type: str = "cls", extract_attention: bool = False,
|
| 290 |
-
patch_selection_criteria: str = "max_attention", text_encoder_path: str = None, text_encoder_config: str = None):
|
| 291 |
-
"""
|
| 292 |
-
Extract DINO features from COCO images for ClipCap training.
|
| 293 |
-
|
| 294 |
-
Args:
|
| 295 |
-
feature_type: Type of features to extract
|
| 296 |
-
- "cls": CLS token features (default)
|
| 297 |
-
- "avg_patch": Mean pooled patch token features
|
| 298 |
-
- "avg_self_attn": Self-attention weighted patch token features
|
| 299 |
-
- "most_significant_patch": Single most important patch token
|
| 300 |
-
extract_attention: Whether to extract self-attention weights (required for avg_self_attn)
|
| 301 |
-
patch_selection_criteria: Criteria for selecting most significant patch
|
| 302 |
-
text_encoder_path: Path to text encoder for caption similarity
|
| 303 |
-
"""
|
| 304 |
-
device = torch.device('cuda:0')
|
| 305 |
-
dino_model_name = dino_model_type.replace('/', '_')
|
| 306 |
-
|
| 307 |
-
# Determine model properties
|
| 308 |
-
num_global_tokens = 1 if "reg" not in dino_model_type else 5
|
| 309 |
-
patch_size = 14 # DINOv2 uses 14x14 patches
|
| 310 |
-
num_patch_tokens = (crop_dim // patch_size) * (crop_dim // patch_size)
|
| 311 |
-
num_tokens = num_global_tokens + num_patch_tokens
|
| 312 |
-
|
| 313 |
-
# Get embedding dimension based on model type
|
| 314 |
-
if 'vitl' in dino_model_type:
|
| 315 |
-
embed_dim = 1024
|
| 316 |
-
num_attn_heads = 16
|
| 317 |
-
elif 'vitb' in dino_model_type:
|
| 318 |
-
embed_dim = 768
|
| 319 |
-
num_attn_heads = 12
|
| 320 |
-
elif 'vits' in dino_model_type:
|
| 321 |
-
embed_dim = 384
|
| 322 |
-
num_attn_heads = 6
|
| 323 |
-
elif 'vitg' in dino_model_type:
|
| 324 |
-
embed_dim = 1536
|
| 325 |
-
num_attn_heads = 24
|
| 326 |
-
else:
|
| 327 |
-
raise ValueError(f"Unknown model type: {dino_model_type}")
|
| 328 |
-
|
| 329 |
-
scale = (embed_dim // num_attn_heads) ** -0.5
|
| 330 |
-
|
| 331 |
-
# Set default output path if not specified
|
| 332 |
-
if output_file is None:
|
| 333 |
-
if feature_type == "cls":
|
| 334 |
-
feature_suffix = ""
|
| 335 |
-
elif feature_type == "most_significant_patch":
|
| 336 |
-
if patch_selection_criteria == "most_similar_to_caption" and text_encoder_path is not None:
|
| 337 |
-
# Determine text encoder type from path to create unique filename
|
| 338 |
-
if text_encoder_path.lower() in ['dinotxt', 'dino.txt']:
|
| 339 |
-
text_encoder_suffix = "_dinotxt"
|
| 340 |
-
elif text_encoder_path.endswith('.pth') or text_encoder_path.endswith('.pt'):
|
| 341 |
-
text_encoder_suffix = "_t2d" # Talk2Dino
|
| 342 |
-
else:
|
| 343 |
-
# CLIP or other models
|
| 344 |
-
text_encoder_suffix = "_clip"
|
| 345 |
-
feature_suffix = f"_{feature_type}_{patch_selection_criteria}{text_encoder_suffix}"
|
| 346 |
-
else:
|
| 347 |
-
feature_suffix = f"_{feature_type}_{patch_selection_criteria}"
|
| 348 |
-
else:
|
| 349 |
-
feature_suffix = f"_{feature_type}"
|
| 350 |
-
output_file = f"/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_{dino_model_name}{feature_suffix}_train.pkl"
|
| 351 |
-
|
| 352 |
-
# Create output directory if it doesn't exist
|
| 353 |
-
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
| 354 |
-
|
| 355 |
-
# Load DINO model
|
| 356 |
-
print(f"Loading DINO model: {dino_model_type}")
|
| 357 |
-
print(f"Feature type: {feature_type}")
|
| 358 |
-
if feature_type == "most_significant_patch":
|
| 359 |
-
print(f"Patch selection criteria: {patch_selection_criteria}")
|
| 360 |
-
print(f"Model properties: embed_dim={embed_dim}, num_heads={num_attn_heads}, num_global_tokens={num_global_tokens}")
|
| 361 |
-
|
| 362 |
-
if 'dinov2' in dino_model_type:
|
| 363 |
-
model_family = 'facebookresearch/dinov2'
|
| 364 |
-
dino_model = torch.hub.load(model_family, dino_model_type)
|
| 365 |
-
else:
|
| 366 |
-
raise ValueError(f"Unsupported DINO model type: {dino_model_type}")
|
| 367 |
-
|
| 368 |
-
# Setup transforms for DINO
|
| 369 |
-
image_transforms = T.Compose([
|
| 370 |
-
T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC),
|
| 371 |
-
T.CenterCrop(crop_dim),
|
| 372 |
-
T.ToTensor(),
|
| 373 |
-
T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
| 374 |
-
])
|
| 375 |
-
|
| 376 |
-
dino_model.eval()
|
| 377 |
-
dino_model.to(device)
|
| 378 |
-
|
| 379 |
-
# Register hooks if we need attention or intermediate outputs
|
| 380 |
-
if feature_type == "avg_self_attn" or extract_attention or \
|
| 381 |
-
(feature_type == "most_significant_patch" and patch_selection_criteria in ["max_attention", "most_similar_to_caption"]):
|
| 382 |
-
print("Registering hooks for attention extraction...")
|
| 383 |
-
dino_model.blocks[-1].attn.qkv.register_forward_hook(get_self_attention)
|
| 384 |
-
|
| 385 |
-
if feature_type in ["avg_patch", "avg_self_attn", "most_significant_patch"]:
|
| 386 |
-
print("Registering hooks for intermediate output extraction...")
|
| 387 |
-
dino_model.blocks[-1].register_forward_hook(get_layer_n_output)
|
| 388 |
-
|
| 389 |
-
# Load caption data
|
| 390 |
-
print(f"Loading captions from: {captions_file}")
|
| 391 |
-
with open(captions_file, 'r') as f:
|
| 392 |
-
data = json.load(f)
|
| 393 |
-
|
| 394 |
-
# Handle different annotation formats
|
| 395 |
-
if isinstance(data, list):
|
| 396 |
-
# Original ClipCap format: list of dicts with 'image_id' and 'caption'
|
| 397 |
-
annotations = data
|
| 398 |
-
print(f"{len(annotations)} captions loaded from json (ClipCap format)")
|
| 399 |
-
elif isinstance(data, dict) and 'annotations' in data:
|
| 400 |
-
# Karpathy format: dict with 'annotations' key
|
| 401 |
-
annotations = data['annotations']
|
| 402 |
-
print(f"{len(annotations)} captions loaded from json (Karpathy format)")
|
| 403 |
-
|
| 404 |
-
# Create image ID to filename mapping for faster lookup
|
| 405 |
-
if 'images' in data:
|
| 406 |
-
image_id_to_filename = {img['id']: img['file_name'] for img in data['images']}
|
| 407 |
-
else:
|
| 408 |
-
image_id_to_filename = {}
|
| 409 |
-
else:
|
| 410 |
-
raise ValueError("Unsupported annotation format")
|
| 411 |
-
|
| 412 |
-
# Load text encoder if needed for caption similarity
|
| 413 |
-
text_encoder = None
|
| 414 |
-
if feature_type == "most_significant_patch" and patch_selection_criteria == "most_similar_to_caption":
|
| 415 |
-
if text_encoder_path is None:
|
| 416 |
-
raise ValueError("text_encoder_path required for most_similar_to_caption criteria")
|
| 417 |
-
text_encoder = load_text_encoder(text_encoder_path, device, text_encoder_config)
|
| 418 |
-
print(f"Loaded text encoder from: {text_encoder_path}")
|
| 419 |
-
|
| 420 |
-
all_embeddings = []
|
| 421 |
-
all_captions = []
|
| 422 |
-
|
| 423 |
-
print(f"Processing images from: {coco_images_dir}")
|
| 424 |
-
print(f"Output will be saved to: {output_file}")
|
| 425 |
-
|
| 426 |
-
for i, annotation in enumerate(tqdm(annotations)):
|
| 427 |
-
img_id = annotation["image_id"]
|
| 428 |
-
|
| 429 |
-
# Determine filename based on format
|
| 430 |
-
if isinstance(data, list):
|
| 431 |
-
# Original format: construct filename from image_id
|
| 432 |
-
filename = os.path.join(coco_images_dir, "train2014", f"COCO_train2014_{int(img_id):012d}.jpg")
|
| 433 |
-
if not os.path.isfile(filename):
|
| 434 |
-
filename = os.path.join(coco_images_dir, "val2014", f"COCO_val2014_{int(img_id):012d}.jpg")
|
| 435 |
-
else:
|
| 436 |
-
# Karpathy format: use filename from images mapping or construct it
|
| 437 |
-
if img_id in image_id_to_filename:
|
| 438 |
-
if 'train' in image_id_to_filename[img_id]:
|
| 439 |
-
fold = "train2014"
|
| 440 |
-
else:
|
| 441 |
-
fold = "val2014"
|
| 442 |
-
filename = os.path.join(coco_images_dir, fold, image_id_to_filename[img_id])
|
| 443 |
-
else:
|
| 444 |
-
# Fallback: try to construct filename
|
| 445 |
-
filename = os.path.join(coco_images_dir, "train2014", f"COCO_train2014_{int(img_id):012d}.jpg")
|
| 446 |
-
if not os.path.isfile(filename):
|
| 447 |
-
filename = os.path.join(coco_images_dir, "val2014", f"COCO_val2014_{int(img_id):012d}.jpg")
|
| 448 |
-
|
| 449 |
-
if not os.path.isfile(filename):
|
| 450 |
-
print(f"Warning: Image not found: {filename}")
|
| 451 |
-
continue
|
| 452 |
-
|
| 453 |
-
# Load and process image
|
| 454 |
-
try:
|
| 455 |
-
image = io.imread(filename)
|
| 456 |
-
if len(image.shape) == 2: # grayscale
|
| 457 |
-
image = Image.fromarray(image).convert('RGB')
|
| 458 |
-
else:
|
| 459 |
-
image = Image.fromarray(image)
|
| 460 |
-
except Exception as e:
|
| 461 |
-
print(f"Warning: Failed to load image {filename}: {e}")
|
| 462 |
-
continue
|
| 463 |
-
|
| 464 |
-
# Apply DINO transforms
|
| 465 |
-
image_tensor = image_transforms(image).unsqueeze(0).to(device)
|
| 466 |
-
|
| 467 |
-
with torch.no_grad():
|
| 468 |
-
# Clear any previous stored data
|
| 469 |
-
global dino_layer_n_output, qkv_attention_out
|
| 470 |
-
dino_layer_n_output = None
|
| 471 |
-
qkv_attention_out = None
|
| 472 |
-
|
| 473 |
-
# Extract DINO features
|
| 474 |
-
if feature_type == "cls":
|
| 475 |
-
# Standard CLS token extraction
|
| 476 |
-
features = dino_model(image_tensor)
|
| 477 |
-
# For DINOv2, the output is the CLS token by default
|
| 478 |
-
if len(features.shape) == 3: # If we get [batch, seq_len, dim]
|
| 479 |
-
features = features[:, 0, :] # Take CLS token
|
| 480 |
-
prefix = features.cpu()
|
| 481 |
-
else:
|
| 482 |
-
# For patch-based features, we need intermediate outputs
|
| 483 |
-
_ = dino_model(image_tensor) # Forward pass to trigger hooks
|
| 484 |
-
|
| 485 |
-
if dino_layer_n_output is None:
|
| 486 |
-
raise RuntimeError("No intermediate output captured. Check hook registration.")
|
| 487 |
-
|
| 488 |
-
# Transform to standard format
|
| 489 |
-
dino_outs = transform_to_standard_dino_out(dino_layer_n_output, dino_model, num_global_tokens)
|
| 490 |
-
|
| 491 |
-
if feature_type == "avg_patch":
|
| 492 |
-
# Average of patch tokens (excluding global tokens)
|
| 493 |
-
prefix = dino_outs['x_norm_patchtokens'].mean(dim=1) # [B, D]
|
| 494 |
-
elif feature_type == "avg_self_attn":
|
| 495 |
-
# Self-attention weighted average of patch tokens
|
| 496 |
-
if qkv_attention_out is None:
|
| 497 |
-
raise RuntimeError("No attention output captured. Check hook registration.")
|
| 498 |
-
|
| 499 |
-
# Process self-attention to get attention weights
|
| 500 |
-
batch_size = qkv_attention_out.shape[0]
|
| 501 |
-
self_attn = process_self_attention(
|
| 502 |
-
qkv_attention_out,
|
| 503 |
-
batch_size,
|
| 504 |
-
num_tokens,
|
| 505 |
-
num_attn_heads,
|
| 506 |
-
embed_dim,
|
| 507 |
-
scale,
|
| 508 |
-
num_global_tokens
|
| 509 |
-
)
|
| 510 |
-
|
| 511 |
-
# Compute attention-weighted average
|
| 512 |
-
prefix = (self_attn.unsqueeze(-1) * dino_outs['x_norm_patchtokens']).mean(dim=1)
|
| 513 |
-
elif feature_type == "most_significant_patch":
|
| 514 |
-
# Select single most significant patch based on criteria
|
| 515 |
-
self_attn = None
|
| 516 |
-
cls_token = None
|
| 517 |
-
caption_embedding = None
|
| 518 |
-
|
| 519 |
-
# Prepare required inputs based on criteria
|
| 520 |
-
if patch_selection_criteria in ["max_attention", "most_similar_to_caption"]:
|
| 521 |
-
if qkv_attention_out is None:
|
| 522 |
-
raise RuntimeError("No attention output captured. Check hook registration.")
|
| 523 |
-
batch_size = qkv_attention_out.shape[0]
|
| 524 |
-
self_attn = process_self_attention(
|
| 525 |
-
qkv_attention_out,
|
| 526 |
-
batch_size,
|
| 527 |
-
num_tokens,
|
| 528 |
-
num_attn_heads,
|
| 529 |
-
embed_dim,
|
| 530 |
-
scale,
|
| 531 |
-
num_global_tokens
|
| 532 |
-
)
|
| 533 |
-
|
| 534 |
-
if patch_selection_criteria == "most_similar_to_cls":
|
| 535 |
-
cls_token = dino_outs['x_norm_clstoken']
|
| 536 |
-
|
| 537 |
-
if patch_selection_criteria == "most_similar_to_caption":
|
| 538 |
-
if text_encoder is not None:
|
| 539 |
-
caption_embedding = encode_caption(annotation["caption"], text_encoder, device)
|
| 540 |
-
|
| 541 |
-
# Select the most significant patch
|
| 542 |
-
prefix = select_most_significant_patch(
|
| 543 |
-
dino_outs,
|
| 544 |
-
self_attn,
|
| 545 |
-
patch_selection_criteria,
|
| 546 |
-
cls_token=cls_token,
|
| 547 |
-
caption_embedding=caption_embedding
|
| 548 |
-
)
|
| 549 |
-
else:
|
| 550 |
-
raise ValueError(f"Unknown feature type: {feature_type}")
|
| 551 |
-
|
| 552 |
-
prefix = prefix.cpu()
|
| 553 |
-
|
| 554 |
-
# Create annotation in ClipCap format for compatibility
|
| 555 |
-
caption_entry = {
|
| 556 |
-
"image_id": img_id,
|
| 557 |
-
"caption": annotation["caption"],
|
| 558 |
-
"clip_embedding": i # Index for the embedding
|
| 559 |
-
}
|
| 560 |
-
|
| 561 |
-
all_embeddings.append(prefix)
|
| 562 |
-
all_captions.append(caption_entry)
|
| 563 |
-
|
| 564 |
-
if (i + 1) % 10000 == 0:
|
| 565 |
-
# Create output directory if it doesn't exist
|
| 566 |
-
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
| 567 |
-
with open(output_file, 'wb') as f:
|
| 568 |
-
pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)
|
| 569 |
-
|
| 570 |
-
# Create output directory if it doesn't exist
|
| 571 |
-
os.makedirs(os.path.dirname(output_file), exist_ok=True)
|
| 572 |
-
with open(output_file, 'wb') as f:
|
| 573 |
-
pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)
|
| 574 |
-
|
| 575 |
-
print('Done')
|
| 576 |
-
print("%0d embeddings saved " % len(all_embeddings))
|
| 577 |
-
print(f"Feature dimension: {all_embeddings[0].shape[-1]}")
|
| 578 |
-
return 0
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
if __name__ == '__main__':
|
| 582 |
-
parser = argparse.ArgumentParser(description='Extract DINO features from COCO images for ClipCap training')
|
| 583 |
-
parser.add_argument('--dino_model_type', default="dinov2_vitb14",
|
| 584 |
-
choices=('dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14',
|
| 585 |
-
'dinov2_vits14_reg', 'dinov2_vitb14_reg', 'dinov2_vitl14_reg', 'dinov2_vitg14_reg'),
|
| 586 |
-
help='DINO model type to use for feature extraction')
|
| 587 |
-
parser.add_argument('--feature_type', default="cls",
|
| 588 |
-
choices=('cls', 'avg_patch', 'avg_self_attn', 'most_significant_patch'),
|
| 589 |
-
help='Type of features to extract: cls (CLS token), avg_patch (mean pooled patches), avg_self_attn (attention-weighted patches), most_significant_patch (single most important patch)')
|
| 590 |
-
parser.add_argument('--patch_selection_criteria', default="max_attention",
|
| 591 |
-
choices=('max_attention', 'most_similar_to_cls', 'most_similar_to_caption', 'max_norm', 'centroid_distance'),
|
| 592 |
-
help='Criteria for selecting the most significant patch (only used with most_significant_patch feature_type)')
|
| 593 |
-
parser.add_argument('--text_encoder_path', type=str, default=None,
|
| 594 |
-
help='Path to text encoder for caption similarity. Supports: 1) "dinotxt" or "dino.txt" for DINO.txt model, 2) Talk2Dino weights (.pth/.pt) - will auto-find config, 3) CLIP model names (e.g., "ViT-B/32")')
|
| 595 |
-
parser.add_argument('--text_encoder_config', type=str, default=None,
|
| 596 |
-
help='Optional: explicit config path for Talk2Dino models (if not auto-found)')
|
| 597 |
-
parser.add_argument('--resize_dim', type=int, default=518, help='Resize dimension for images')
|
| 598 |
-
parser.add_argument('--crop_dim', type=int, default=518, help='Crop dimension for images')
|
| 599 |
-
parser.add_argument('--coco_images_dir', type=str, default="/raid/datasets/coco",
|
| 600 |
-
help='Path to COCO images directory (should contain train2014/ and val2014/ subdirs)')
|
| 601 |
-
parser.add_argument('--captions_file', type=str, default="/raid/datasets/coco/train_split_karpathy.json",
|
| 602 |
-
help='Path to COCO captions JSON file (supports both Karpathy and ClipCap formats)')
|
| 603 |
-
parser.add_argument('--output_file', type=str, default=None,
|
| 604 |
-
help='Output pickle file path (default: auto-generated based on model and feature type)')
|
| 605 |
-
parser.add_argument('--extract_attention', action='store_true',
|
| 606 |
-
help='Extract attention weights (automatically enabled for avg_self_attn feature type)')
|
| 607 |
-
|
| 608 |
-
args = parser.parse_args()
|
| 609 |
-
|
| 610 |
-
main(args.dino_model_type, args.resize_dim, args.crop_dim,
|
| 611 |
-
args.coco_images_dir, args.captions_file, args.output_file,
|
| 612 |
-
args.feature_type, args.extract_attention,
|
| 613 |
-
args.patch_selection_criteria, args.text_encoder_path, args.text_encoder_config)
|
src/clipcap/clipcap_parse_coco.py
DELETED
|
@@ -1,51 +0,0 @@
import torch
import skimage.io as io
import clip
from PIL import Image
import pickle
import json
import os
from tqdm import tqdm
import argparse


def main(clip_model_type: str):
    device = torch.device('cuda:0')
    clip_model_name = clip_model_type.replace('/', '_')
    out_path = f"./data/coco/oscar_split_{clip_model_name}_train.pkl"
    clip_model, preprocess = clip.load(clip_model_type, device=device, jit=False)
    with open('./data/coco/annotations/train_caption.json', 'r') as f:
        data = json.load(f)
    print("%0d captions loaded from json " % len(data))
    all_embeddings = []
    all_captions = []
    for i in tqdm(range(len(data))):
        d = data[i]
        img_id = d["image_id"]
        filename = f"./data/coco/train2014/COCO_train2014_{int(img_id):012d}.jpg"
        if not os.path.isfile(filename):
            filename = f"./data/coco/val2014/COCO_val2014_{int(img_id):012d}.jpg"
        image = io.imread(filename)
        image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device)
        with torch.no_grad():
            prefix = clip_model.encode_image(image).cpu()
        d["clip_embedding"] = i
        all_embeddings.append(prefix)
        all_captions.append(d)
        if (i + 1) % 10000 == 0:
            with open(out_path, 'wb') as f:
                pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)

    with open(out_path, 'wb') as f:
        pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f)

    print('Done')
    print("%0d embeddings saved " % len(all_embeddings))
    return 0


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--clip_model_type', default="ViT-B/32", choices=('RN50', 'RN101', 'RN50x4', 'ViT-B/32'))
    args = parser.parse_args()
    exit(main(args.clip_model_type))
src/clipcap/entrypoint.py
DELETED
|
@@ -1,564 +0,0 @@
|
|
| 1 |
-
import torch
|
| 2 |
-
from torch import nn
|
| 3 |
-
import json
|
| 4 |
-
import os
|
| 5 |
-
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
| 6 |
-
from typing import List, Optional, Tuple, Union
|
| 7 |
-
from argparse import Namespace
|
| 8 |
-
from enum import Enum
|
| 9 |
-
|
| 10 |
-
import torch.nn.functional as nnf
|
| 11 |
-
|
| 12 |
-
class MappingType(Enum):
|
| 13 |
-
MLP = 'mlp'
|
| 14 |
-
Transformer = 'transformer'
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
class MLP(nn.Module):
|
| 18 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 19 |
-
return self.model(x)
|
| 20 |
-
|
| 21 |
-
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
|
| 22 |
-
super(MLP, self).__init__()
|
| 23 |
-
layers = []
|
| 24 |
-
for i in range(len(sizes) - 1):
|
| 25 |
-
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
|
| 26 |
-
if i < len(sizes) - 2:
|
| 27 |
-
layers.append(act())
|
| 28 |
-
self.model = nn.Sequential(*layers)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
class MlpTransformer(nn.Module):
|
| 32 |
-
def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
|
| 33 |
-
super().__init__()
|
| 34 |
-
out_d = out_d if out_d is not None else in_dim
|
| 35 |
-
self.fc1 = nn.Linear(in_dim, h_dim)
|
| 36 |
-
self.act = act
|
| 37 |
-
self.fc2 = nn.Linear(h_dim, out_d)
|
| 38 |
-
self.dropout = nn.Dropout(dropout)
|
| 39 |
-
|
| 40 |
-
def forward(self, x):
|
| 41 |
-
x = self.fc1(x)
|
| 42 |
-
x = self.act(x)
|
| 43 |
-
x = self.dropout(x)
|
| 44 |
-
x = self.fc2(x)
|
| 45 |
-
x = self.dropout(x)
|
| 46 |
-
return x
|
| 47 |
-
|
| 48 |
-
class MultiHeadAttention(nn.Module):
|
| 49 |
-
|
| 50 |
-
def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
|
| 51 |
-
super().__init__()
|
| 52 |
-
self.num_heads = num_heads
|
| 53 |
-
head_dim = dim_self // num_heads
|
| 54 |
-
self.scale = head_dim ** -0.5
|
| 55 |
-
self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
|
| 56 |
-
self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
|
| 57 |
-
self.project = nn.Linear(dim_self, dim_self)
|
| 58 |
-
self.dropout = nn.Dropout(dropout)
|
| 59 |
-
|
| 60 |
-
def forward(self, x, y=None, mask=None):
|
| 61 |
-
y = y if y is not None else x
|
| 62 |
-
b, n, c = x.shape
|
| 63 |
-
_, m, d = y.shape
|
| 64 |
-
# b n h dh
|
| 65 |
-
queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
|
| 66 |
-
# b m 2 h dh
|
| 67 |
-
keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
|
| 68 |
-
keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
|
| 69 |
-
attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
|
| 70 |
-
if mask is not None:
|
| 71 |
-
if mask.dim() == 2:
|
| 72 |
-
mask = mask.unsqueeze(1)
|
| 73 |
-
attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
|
| 74 |
-
attention = attention.softmax(dim=2)
|
| 75 |
-
out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
|
| 76 |
-
out = self.project(out)
|
| 77 |
-
return out, attention
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
class TransformerLayer(nn.Module):
|
| 81 |
-
|
| 82 |
-
def forward_with_attention(self, x, y=None, mask=None):
|
| 83 |
-
x_, attention = self.attn(self.norm1(x), y, mask)
|
| 84 |
-
x = x + x_
|
| 85 |
-
x = x + self.mlp(self.norm2(x))
|
| 86 |
-
return x, attention
|
| 87 |
-
|
| 88 |
-
def forward(self, x, y=None, mask=None):
|
| 89 |
-
x = x + self.attn(self.norm1(x), y, mask)[0]
|
| 90 |
-
x = x + self.mlp(self.norm2(x))
|
| 91 |
-
return x
|
| 92 |
-
|
| 93 |
-
def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
|
| 94 |
-
norm_layer: nn.Module = nn.LayerNorm):
|
| 95 |
-
super().__init__()
|
| 96 |
-
self.norm1 = norm_layer(dim_self)
|
| 97 |
-
self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
|
| 98 |
-
self.norm2 = norm_layer(dim_self)
|
| 99 |
-
self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
class Transformer(nn.Module):
|
| 103 |
-
|
| 104 |
-
def forward_with_attention(self, x, y=None, mask=None):
|
| 105 |
-
attentions = []
|
| 106 |
-
for layer in self.layers:
|
| 107 |
-
x, att = layer.forward_with_attention(x, y, mask)
|
| 108 |
-
attentions.append(att)
|
| 109 |
-
return x, attentions
|
| 110 |
-
|
| 111 |
-
def forward(self, x, y=None, mask=None):
|
| 112 |
-
for i, layer in enumerate(self.layers):
|
| 113 |
-
if i % 2 == 0 and self.enc_dec: # cross
|
| 114 |
-
x = layer(x, y)
|
| 115 |
-
elif self.enc_dec: # self
|
| 116 |
-
x = layer(x, x, mask)
|
| 117 |
-
else: # self or cross
|
| 118 |
-
x = layer(x, y, mask)
|
| 119 |
-
return x
|
| 120 |
-
|
| 121 |
-
def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
|
| 122 |
-
mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
|
| 123 |
-
super(Transformer, self).__init__()
|
| 124 |
-
dim_ref = dim_ref if dim_ref is not None else dim_self
|
| 125 |
-
self.enc_dec = enc_dec
|
| 126 |
-
if enc_dec:
|
| 127 |
-
num_layers = num_layers * 2
|
| 128 |
-
layers = []
|
| 129 |
-
for i in range(num_layers):
|
| 130 |
-
if i % 2 == 0 and enc_dec: # cross
|
| 131 |
-
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
|
| 132 |
-
elif enc_dec: # self
|
| 133 |
-
layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
|
| 134 |
-
else: # self or cross
|
| 135 |
-
layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
|
| 136 |
-
self.layers = nn.ModuleList(layers)
|
| 137 |
-
|
| 138 |
-
class TransformerMapper(nn.Module):
|
| 139 |
-
def forward(self, x):
|
| 140 |
-
x = self.linear(x).view(x.shape[0], self.clip_length, -1)
|
| 141 |
-
prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
|
| 142 |
-
prefix = torch.cat((x, prefix), dim=1)
|
| 143 |
-
out = self.transformer(prefix)[:, self.clip_length:]
|
| 144 |
-
return out
|
| 145 |
-
|
| 146 |
-
def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
|
| 147 |
-
super(TransformerMapper, self).__init__()
|
| 148 |
-
self.clip_length = clip_length
|
| 149 |
-
self.transformer = Transformer(dim_embedding, 8, num_layers) #nn.Transformer(d_model=dim_embedding, nhead=8, num_encoder_layers=num_layers)
|
| 150 |
-
self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
|
| 151 |
-
self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
class ClipCaptionModel(nn.Module):
|
| 155 |
-
|
| 156 |
-
def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
|
| 157 |
-
return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
|
| 158 |
-
|
| 159 |
-
def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
|
| 160 |
-
labels: Optional[torch.Tensor] = None):
|
| 161 |
-
embedding_text = self.gpt.transformer.wte(tokens)
|
| 162 |
-
prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
|
| 163 |
-
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
|
| 164 |
-
if labels is not None:
|
| 165 |
-
dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
|
| 166 |
-
labels = torch.cat((dummy_token, tokens), dim=1)
|
| 167 |
-
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
|
| 168 |
-
return out
|
| 169 |
-
|
| 170 |
-
def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512,
|
| 171 |
-
num_layers: int = 8, mapping_type: MappingType = MappingType.MLP):
|
| 172 |
-
super(ClipCaptionModel, self).__init__()
|
| 173 |
-
self.prefix_length = prefix_length
|
| 174 |
-
self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
|
| 175 |
-
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
| 176 |
-
if mapping_type == MappingType.MLP:
|
| 177 |
-
self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2,
|
| 178 |
-
self.gpt_embedding_size * prefix_length))
|
| 179 |
-
else:
|
| 180 |
-
self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length,
|
| 181 |
-
clip_length, num_layers)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
class ClipCaptionPrefix(ClipCaptionModel):
|
| 185 |
-
|
| 186 |
-
def parameters(self, recurse: bool = True):
|
| 187 |
-
return self.clip_project.parameters()
|
| 188 |
-
|
| 189 |
-
def train(self, mode: bool = True):
|
| 190 |
-
super(ClipCaptionPrefix, self).train(mode)
|
| 191 |
-
self.gpt.eval()
|
| 192 |
-
return self
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
def generate_batched(
|
| 196 |
-
model,
|
| 197 |
-
tokenizer,
|
| 198 |
-
prefix_embeds,
|
| 199 |
-
entry_length=67,
|
| 200 |
-
top_p=0.8,
|
| 201 |
-
temperature=1.0,
|
| 202 |
-
stop_token: str = '.',
|
| 203 |
-
):
|
| 204 |
-
"""
|
| 205 |
-
Batched text generation for ClipCap models.
|
| 206 |
-
|
| 207 |
-
Args:
|
| 208 |
-
model: ClipCap model
|
| 209 |
-
tokenizer: GPT2 tokenizer
|
| 210 |
-
prefix_embeds: (batch_size, prefix_length, embedding_dim) - prefix embeddings
|
| 211 |
-
entry_length: Maximum sequence length to generate
|
| 212 |
-
top_p: Nucleus sampling parameter
|
| 213 |
-
temperature: Sampling temperature
|
| 214 |
-
stop_token: Token to stop generation
|
| 215 |
-
|
| 216 |
-
Returns:
|
| 217 |
-
List[str]: Generated captions for each item in batch
|
| 218 |
-
"""
|
| 219 |
-
model.eval()
|
| 220 |
-
device = next(model.parameters()).device
|
| 221 |
-
batch_size = prefix_embeds.shape[0]
|
| 222 |
-
|
| 223 |
-
# Initialize
|
| 224 |
-
stop_token_index = tokenizer.encode(stop_token)[0]
|
| 225 |
-
filter_value = -float("Inf")
|
| 226 |
-
|
| 227 |
-
# Track which sequences are still generating
|
| 228 |
-
active_sequences = torch.ones(batch_size, dtype=torch.bool, device=device)
|
| 229 |
-
|
| 230 |
-
# Initialize token sequences - start with None
|
| 231 |
-
tokens = None
|
| 232 |
-
generated_embeds = prefix_embeds # Start with prefix embeddings
|
| 233 |
-
|
| 234 |
-
with torch.no_grad():
|
| 235 |
-
for step in range(entry_length):
|
| 236 |
-
# Forward pass for all active sequences
|
| 237 |
-
outputs = model.gpt(inputs_embeds=generated_embeds)
|
| 238 |
-
logits = outputs.logits[:, -1, :] # Get logits for last token: (batch_size, vocab_size)
|
| 239 |
-
|
| 240 |
-
# Apply temperature
|
| 241 |
-
logits = logits / (temperature if temperature > 0 else 1.0)
|
| 242 |
-
|
| 243 |
-
# Apply nucleus sampling for each sequence in batch
|
| 244 |
-
for i in range(batch_size):
|
| 245 |
-
if not active_sequences[i]:
|
| 246 |
-
continue
|
| 247 |
-
|
| 248 |
-
# Sort logits for this sequence
|
| 249 |
-
sorted_logits, sorted_indices = torch.sort(logits[i], descending=True)
|
| 250 |
-
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 251 |
-
|
| 252 |
-
# Find indices to remove (above top_p threshold)
|
| 253 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
| 254 |
-
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone()
|
| 255 |
-
sorted_indices_to_remove[0] = 0
|
| 256 |
-
|
| 257 |
-
# Set logits to -inf for tokens to remove
|
| 258 |
-
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 259 |
-
logits[i, indices_to_remove] = filter_value
|
| 260 |
-
|
| 261 |
-
# Clamp logits to avoid extreme values
|
| 262 |
-
logits = torch.clamp(logits, min=-1e9, max=1e9) # keep values bounded
|
| 263 |
-
# Sample next tokens for all sequences
|
| 264 |
-
probs = torch.softmax(logits, dim=-1)
|
| 265 |
-
|
| 266 |
-
# if some sequences probs tensor contains NaNs (e.g. all logits were -inf), set stop_token_index prob to 1
|
| 267 |
-
for i in range(batch_size):
|
| 268 |
-
if torch.isnan(probs[i]).all(): #if not torch.isfinite(probs[i]).any() or probs[i].sum() == 0:
|
| 269 |
-
probs[i] = torch.zeros_like(probs[i])
|
| 270 |
-
probs[i, stop_token_index] = 1.0
|
| 271 |
-
|
| 272 |
-
next_tokens = torch.multinomial(probs, num_samples=1) # (batch_size, 1)
|
| 273 |
-
|
| 274 |
-
# Get embeddings for next tokens
|
| 275 |
-
next_token_embeds = model.gpt.transformer.wte(next_tokens) # (batch_size, 1, embed_dim)
|
| 276 |
-
|
| 277 |
-
# Update token sequences
|
| 278 |
-
if tokens is None:
|
| 279 |
-
tokens = next_tokens
|
| 280 |
-
else:
|
| 281 |
-
tokens = torch.cat((tokens, next_tokens), dim=1)
|
| 282 |
-
|
| 283 |
-
# Update generated embeddings
|
| 284 |
-
generated_embeds = torch.cat((generated_embeds, next_token_embeds), dim=1)
|
| 285 |
-
|
| 286 |
-
# Check for stop tokens and update active sequences
|
| 287 |
-
for i in range(batch_size):
|
| 288 |
-
if active_sequences[i] and next_tokens[i].item() == stop_token_index:
|
| 289 |
-
active_sequences[i] = False
|
| 290 |
-
|
| 291 |
-
# If all sequences have stopped, break early
|
| 292 |
-
if not active_sequences.any():
|
| 293 |
-
break
|
| 294 |
-
|
| 295 |
-
# Decode all sequences
|
| 296 |
-
captions = []
|
| 297 |
-
for i in range(batch_size):
|
| 298 |
-
if tokens is not None:
|
| 299 |
-
token_list = tokens[i].cpu().numpy().tolist()
|
| 300 |
-
# Remove padding and decode
|
| 301 |
-
caption = tokenizer.decode(token_list)
|
| 302 |
-
# Clean up the caption
|
| 303 |
-
caption = caption.split(stop_token)[0] + stop_token
|
| 304 |
-
captions.append(caption)
|
| 305 |
-
else:
|
| 306 |
-
captions.append("")
|
| 307 |
-
|
| 308 |
-
return captions
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
def generate2(
|
| 312 |
-
model,
|
| 313 |
-
tokenizer,
|
| 314 |
-
tokens=None,
|
| 315 |
-
prompt=None,
|
| 316 |
-
embed=None,
|
| 317 |
-
entry_count=1,
|
| 318 |
-
entry_length=67, # maximum number of words
|
| 319 |
-
top_p=0.8,
|
| 320 |
-
temperature=1.,
|
| 321 |
-
stop_token: str = '.',
|
| 322 |
-
):
|
| 323 |
-
"""
|
| 324 |
-
Legacy single-sequence generation function.
|
| 325 |
-
For new code, use generate_batched instead.
|
| 326 |
-
"""
|
| 327 |
-
model.eval()
|
| 328 |
-
generated_num = 0
|
| 329 |
-
generated_list = []
|
| 330 |
-
stop_token_index = tokenizer.encode(stop_token)[0]
|
| 331 |
-
filter_value = -float("Inf")
|
| 332 |
-
device = next(model.parameters()).device
|
| 333 |
-
|
| 334 |
-
with torch.no_grad():
|
| 335 |
-
|
| 336 |
-
for entry_idx in range(entry_count):
|
| 337 |
-
if embed is not None:
|
| 338 |
-
generated = embed
|
| 339 |
-
else:
|
| 340 |
-
if tokens is None:
|
| 341 |
-
tokens = torch.tensor(tokenizer.encode(prompt))
|
| 342 |
-
tokens = tokens.unsqueeze(0).to(device)
|
| 343 |
-
|
| 344 |
-
generated = model.gpt.transformer.wte(tokens)
|
| 345 |
-
|
| 346 |
-
for i in range(entry_length):
|
| 347 |
-
|
| 348 |
-
outputs = model.gpt(inputs_embeds=generated)
|
| 349 |
-
logits = outputs.logits
|
| 350 |
-
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
| 351 |
-
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 352 |
-
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
|
| 353 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
| 354 |
-
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
| 355 |
-
..., :-1
|
| 356 |
-
].clone()
|
| 357 |
-
sorted_indices_to_remove[..., 0] = 0
|
| 358 |
-
|
| 359 |
-
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 360 |
-
logits[:, indices_to_remove] = filter_value
|
| 361 |
-
next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)
|
| 362 |
-
next_token_embed = model.gpt.transformer.wte(next_token)
|
| 363 |
-
if tokens is None:
|
| 364 |
-
tokens = next_token
|
| 365 |
-
else:
|
| 366 |
-
tokens = torch.cat((tokens, next_token), dim=1)
|
| 367 |
-
generated = torch.cat((generated, next_token_embed), dim=1)
|
| 368 |
-
if stop_token_index == next_token.item():
|
| 369 |
-
break
|
| 370 |
-
|
| 371 |
-
output_list = list(tokens.squeeze().cpu().numpy())
|
| 372 |
-
output_text = tokenizer.decode(output_list)
|
| 373 |
-
generated_list.append(output_text)
|
| 374 |
-
|
| 375 |
-
return generated_list[0]
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
class ClipCapModel(torch.nn.Module):
|
| 379 |
-
"""
|
| 380 |
-
ClipCap integration for the Patchioner class.
|
| 381 |
-
"""
|
| 382 |
-
|
| 383 |
-
def __init__(self, args, device, dino_feature_dim=768):
|
| 384 |
-
super(ClipCapModel, self).__init__()
|
| 385 |
-
args_dict = args.copy()
|
| 386 |
-
self.args = args = self.load_config(args)
|
| 387 |
-
self.device = device
|
| 388 |
-
self.dino_feature_dim = dino_feature_dim
|
| 389 |
-
|
| 390 |
-
# Initialize tokenizer
|
| 391 |
-
self.tokenizer = GPT2Tokenizer.from_pretrained(args.language_model)
|
| 392 |
-
if self.tokenizer.pad_token_id is None:
|
| 393 |
-
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 394 |
-
|
| 395 |
-
# Determine mapping type
|
| 396 |
-
mapping_type = MappingType.MLP if args.mapping_type.lower() == 'mlp' else MappingType.Transformer
|
| 397 |
-
|
| 398 |
-
# Initialize model with DINO feature dimensions
|
| 399 |
-
if args.only_prefix:
|
| 400 |
-
self.model = ClipCaptionPrefix(
|
| 401 |
-
prefix_length=args.prefix_length,
|
| 402 |
-
clip_length=args.clip_length,
|
| 403 |
-
prefix_size=dino_feature_dim,
|
| 404 |
-
num_layers=args.num_layers,
|
| 405 |
-
mapping_type=mapping_type
|
| 406 |
-
)
|
| 407 |
-
else:
|
| 408 |
-
self.model = ClipCaptionModel(
|
| 409 |
-
prefix_length=args.prefix_length,
|
| 410 |
-
clip_length=args.clip_length,
|
| 411 |
-
prefix_size=dino_feature_dim,
|
| 412 |
-
num_layers=args.num_layers,
|
| 413 |
-
mapping_type=mapping_type
|
| 414 |
-
)
|
| 415 |
-
|
| 416 |
-
# Load trained weights
|
| 417 |
-
print(f"Loading ClipCap weights from: {args.weight_path}")
|
| 418 |
-
checkpoint = torch.load(args.weight_path, map_location=device)
|
| 419 |
-
self.model.load_state_dict(checkpoint, strict=False)
|
| 420 |
-
self.model.to(device)
|
| 421 |
-
self.model.eval()
|
| 422 |
-
|
| 423 |
-
defaults = {
|
| 424 |
-
"language_model": "gpt2",
|
| 425 |
-
"prefix_length": 10,
|
| 426 |
-
"clip_length": 10,
|
| 427 |
-
"num_layers": 8,
|
| 428 |
-
"mapping_type": "mlp",
|
| 429 |
-
"only_prefix": True,
|
| 430 |
-
"temperature": 1.0,
|
| 431 |
-
"top_p": 0.8,
|
| 432 |
-
"entry_length": 67,
|
| 433 |
-
"stop_token": ".",
|
| 434 |
-
"use_batched_generation": True, # Use batched generation by default
|
| 435 |
-
"normalize_prefix": False, # Whether to L2 normalize the input features
|
| 436 |
-
"weight_path": "/raid/datasets/models_weights/clipcap/training-features/clipcap_dino_vitb14_len10_mlp.pt"
|
| 437 |
-
}
|
| 438 |
-
|
| 439 |
-
def load_config(self, args_dict: dict) -> Namespace:
|
| 440 |
-
def dict_to_namespace(d):
|
| 441 |
-
if isinstance(d, dict):
|
| 442 |
-
return Namespace(**{k: dict_to_namespace(v) for k, v in d.items()})
|
| 443 |
-
return d
|
| 444 |
-
|
| 445 |
-
# Apply defaults
|
| 446 |
-
for key, value in self.defaults.items():
|
| 447 |
-
if isinstance(value, dict):
|
| 448 |
-
for sub_key, sub_value in value.items():
|
| 449 |
-
args_dict.setdefault(key, {}).setdefault(sub_key, sub_value)
|
| 450 |
-
else:
|
| 451 |
-
args_dict.setdefault(key, value)
|
| 452 |
-
|
| 453 |
-
args = dict_to_namespace(args_dict)
|
| 454 |
-
return args
|
| 455 |
-
|
| 456 |
-
def forward(self, dino_features, compute_scores: bool = False) -> List[str]:
|
| 457 |
-
"""
|
| 458 |
-
DINO Features: (batch_size, dino_feature_dim)
|
| 459 |
-
- returns: List[str] of generated captions
|
| 460 |
-
"""
|
| 461 |
-
if self.args.use_batched_generation:
|
| 462 |
-
return self.forward_batched(dino_features, compute_scores)
|
| 463 |
-
else:
|
| 464 |
-
return self.forward_sequential(dino_features, compute_scores)
|
| 465 |
-
|
| 466 |
-
def forward_batched(self, dino_features, compute_scores: bool = False) -> List[str]:
|
| 467 |
-
"""
|
| 468 |
-
Efficient batched generation for multiple sequences.
|
| 469 |
-
"""
|
| 470 |
-
batch_size = dino_features.shape[0]
|
| 471 |
-
|
| 472 |
-
# Apply normalization if specified (to match training)
|
| 473 |
-
if self.args.normalize_prefix:
|
| 474 |
-
dino_features = dino_features / dino_features.norm(dim=-1, keepdim=True)
|
| 475 |
-
|
| 476 |
-
# Generate prefix embeddings for entire batch
|
| 477 |
-
with torch.no_grad():
|
| 478 |
-
prefix_embeds = self.model.clip_project(dino_features).view(
|
| 479 |
-
batch_size, self.args.prefix_length, -1
|
| 480 |
-
)
|
| 481 |
-
|
| 482 |
-
# Generate captions for entire batch
|
| 483 |
-
captions = generate_batched(
|
| 484 |
-
model=self.model,
|
| 485 |
-
tokenizer=self.tokenizer,
|
| 486 |
-
prefix_embeds=prefix_embeds,
|
| 487 |
-
entry_length=self.args.entry_length,
|
| 488 |
-
temperature=self.args.temperature,
|
| 489 |
-
top_p=self.args.top_p,
|
| 490 |
-
stop_token=self.args.stop_token
|
| 491 |
-
)
|
| 492 |
-
|
| 493 |
-
if compute_scores:
|
| 494 |
-
# Compute perplexity scores for generated captions
|
| 495 |
-
scores = self.compute_perplexity_scores(captions)
|
| 496 |
-
return captions, scores
|
| 497 |
-
else:
|
| 498 |
-
return captions
|
| 499 |
-
|
| 500 |
-
def forward_sequential(self, dino_features, compute_scores: bool = False) -> List[str]:
|
| 501 |
-
"""
|
| 502 |
-
Sequential generation for backward compatibility or debugging.
|
| 503 |
-
"""
|
| 504 |
-
batch_size = dino_features.shape[0]
|
| 505 |
-
captions = []
|
| 506 |
-
scores = []
|
| 507 |
-
|
| 508 |
-
# Process each feature in the batch sequentially
|
| 509 |
-
for i in range(batch_size):
|
| 510 |
-
feature = dino_features[i:i+1] # Keep batch dimension
|
| 511 |
-
|
| 512 |
-
# Apply normalization if enabled
|
| 513 |
-
if self.args.normalize_prefix:
|
| 514 |
-
feature = feature / feature.norm(dim=-1, keepdim=True)
|
| 515 |
-
|
| 516 |
-
# Generate prefix embeddings
|
| 517 |
-
with torch.no_grad():
|
| 518 |
-
prefix_embed = self.model.clip_project(feature).view(1, self.args.prefix_length, -1)
|
| 519 |
-
|
| 520 |
-
# Generate caption using legacy function
|
| 521 |
-
caption = generate2(
|
| 522 |
-
model=self.model,
|
| 523 |
-
tokenizer=self.tokenizer,
|
| 524 |
-
embed=prefix_embed,
|
| 525 |
-
entry_length=self.args.entry_length,
|
| 526 |
-
temperature=self.args.temperature,
|
| 527 |
-
top_p=self.args.top_p,
|
| 528 |
-
stop_token=self.args.stop_token
|
| 529 |
-
)
|
| 530 |
-
|
| 531 |
-
captions.append(caption)
|
| 532 |
-
if compute_scores:
|
| 533 |
-
# Compute perplexity for this caption
|
| 534 |
-
score = self.compute_perplexity_scores([caption])[0]
|
| 535 |
-
scores.append(score)
|
| 536 |
-
|
| 537 |
-
return captions if not compute_scores else (captions, scores)
|
| 538 |
-
|
| 539 |
-
def compute_perplexity_scores(self, captions: List[str]) -> List[float]:
|
| 540 |
-
"""
|
| 541 |
-
Compute perplexity scores for generated captions.
|
| 542 |
-
"""
|
| 543 |
-
scores = []
|
| 544 |
-
self.model.eval()
|
| 545 |
-
|
| 546 |
-
with torch.no_grad():
|
| 547 |
-
for caption in captions:
|
| 548 |
-
try:
|
| 549 |
-
# Tokenize caption
|
| 550 |
-
tokens = self.tokenizer.encode(caption, return_tensors='pt').to(self.device)
|
| 551 |
-
|
| 552 |
-
# Compute loss (negative log-likelihood)
|
| 553 |
-
outputs = self.model.gpt(input_ids=tokens, labels=tokens)
|
| 554 |
-
loss = outputs.loss
|
| 555 |
-
|
| 556 |
-
# Convert to perplexity (lower is better, but we'll use 1/perplexity as score)
|
| 557 |
-
perplexity = torch.exp(loss).item()
|
| 558 |
-
score = 1.0 / perplexity if perplexity > 0 else 1.0
|
| 559 |
-
scores.append(score)
|
| 560 |
-
except:
|
| 561 |
-
# Fallback score if computation fails
|
| 562 |
-
scores.append(1.0)
|
| 563 |
-
|
| 564 |
-
return scores
|
|
|
|
|
src/clipcap/predict.py DELETED
@@ -1,302 +0,0 @@
| 1 |
-
# Prediction interface for Cog ⚙️
|
| 2 |
-
# Reference: https://github.com/replicate/cog/blob/main/docs/python.md
|
| 3 |
-
|
| 4 |
-
import clip
|
| 5 |
-
import os
|
| 6 |
-
from torch import nn
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn.functional as nnf
|
| 10 |
-
import sys
|
| 11 |
-
from typing import Tuple, List, Union, Optional
|
| 12 |
-
from transformers import (
|
| 13 |
-
GPT2Tokenizer,
|
| 14 |
-
GPT2LMHeadModel,
|
| 15 |
-
AdamW,
|
| 16 |
-
get_linear_schedule_with_warmup,
|
| 17 |
-
)
|
| 18 |
-
import skimage.io as io
|
| 19 |
-
import PIL.Image
|
| 20 |
-
|
| 21 |
-
import cog
|
| 22 |
-
|
| 23 |
-
# import torch
|
| 24 |
-
|
| 25 |
-
N = type(None)
|
| 26 |
-
V = np.array
|
| 27 |
-
ARRAY = np.ndarray
|
| 28 |
-
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
|
| 29 |
-
VS = Union[Tuple[V, ...], List[V]]
|
| 30 |
-
VN = Union[V, N]
|
| 31 |
-
VNS = Union[VS, N]
|
| 32 |
-
T = torch.Tensor
|
| 33 |
-
TS = Union[Tuple[T, ...], List[T]]
|
| 34 |
-
TN = Optional[T]
|
| 35 |
-
TNS = Union[Tuple[TN, ...], List[TN]]
|
| 36 |
-
TSN = Optional[TS]
|
| 37 |
-
TA = Union[T, ARRAY]
|
| 38 |
-
|
| 39 |
-
WEIGHTS_PATHS = {
|
| 40 |
-
"coco": "coco_weights.pt",
|
| 41 |
-
"conceptual-captions": "conceptual_weights.pt",
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
D = torch.device
|
| 45 |
-
CPU = torch.device("cpu")
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
class Predictor(cog.Predictor):
|
| 49 |
-
def setup(self):
|
| 50 |
-
"""Load the model into memory to make running multiple predictions efficient"""
|
| 51 |
-
self.device = torch.device("cuda")
|
| 52 |
-
self.clip_model, self.preprocess = clip.load(
|
| 53 |
-
"ViT-B/32", device=self.device, jit=False
|
| 54 |
-
)
|
| 55 |
-
self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
|
| 56 |
-
|
| 57 |
-
self.models = {}
|
| 58 |
-
self.prefix_length = 10
|
| 59 |
-
for key, weights_path in WEIGHTS_PATHS.items():
|
| 60 |
-
model = ClipCaptionModel(self.prefix_length)
|
| 61 |
-
model.load_state_dict(torch.load(weights_path, map_location=CPU))
|
| 62 |
-
model = model.eval()
|
| 63 |
-
model = model.to(self.device)
|
| 64 |
-
self.models[key] = model
|
| 65 |
-
|
| 66 |
-
@cog.input("image", type=cog.Path, help="Input image")
|
| 67 |
-
@cog.input(
|
| 68 |
-
"model",
|
| 69 |
-
type=str,
|
| 70 |
-
options=WEIGHTS_PATHS.keys(),
|
| 71 |
-
default="coco",
|
| 72 |
-
help="Model to use",
|
| 73 |
-
)
|
| 74 |
-
@cog.input(
|
| 75 |
-
"use_beam_search",
|
| 76 |
-
type=bool,
|
| 77 |
-
default=False,
|
| 78 |
-
help="Whether to apply beam search to generate the output text",
|
| 79 |
-
)
|
| 80 |
-
def predict(self, image, model, use_beam_search):
|
| 81 |
-
"""Run a single prediction on the model"""
|
| 82 |
-
image = io.imread(image)
|
| 83 |
-
model = self.models[model]
|
| 84 |
-
pil_image = PIL.Image.fromarray(image)
|
| 85 |
-
image = self.preprocess(pil_image).unsqueeze(0).to(self.device)
|
| 86 |
-
with torch.no_grad():
|
| 87 |
-
prefix = self.clip_model.encode_image(image).to(
|
| 88 |
-
self.device, dtype=torch.float32
|
| 89 |
-
)
|
| 90 |
-
prefix_embed = model.clip_project(prefix).reshape(1, self.prefix_length, -1)
|
| 91 |
-
if use_beam_search:
|
| 92 |
-
return generate_beam(model, self.tokenizer, embed=prefix_embed)[0]
|
| 93 |
-
else:
|
| 94 |
-
return generate2(model, self.tokenizer, embed=prefix_embed)
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
class MLP(nn.Module):
|
| 98 |
-
def forward(self, x: T) -> T:
|
| 99 |
-
return self.model(x)
|
| 100 |
-
|
| 101 |
-
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
|
| 102 |
-
super(MLP, self).__init__()
|
| 103 |
-
layers = []
|
| 104 |
-
for i in range(len(sizes) - 1):
|
| 105 |
-
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
|
| 106 |
-
if i < len(sizes) - 2:
|
| 107 |
-
layers.append(act())
|
| 108 |
-
self.model = nn.Sequential(*layers)
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
class ClipCaptionModel(nn.Module):
|
| 112 |
-
|
| 113 |
-
# @functools.lru_cache #FIXME
|
| 114 |
-
def get_dummy_token(self, batch_size: int, device: D) -> T:
|
| 115 |
-
return torch.zeros(
|
| 116 |
-
batch_size, self.prefix_length, dtype=torch.int64, device=device
|
| 117 |
-
)
|
| 118 |
-
|
| 119 |
-
def forward(
|
| 120 |
-
self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None
|
| 121 |
-
):
|
| 122 |
-
embedding_text = self.gpt.transformer.wte(tokens)
|
| 123 |
-
prefix_projections = self.clip_project(prefix).view(
|
| 124 |
-
-1, self.prefix_length, self.gpt_embedding_size
|
| 125 |
-
)
|
| 126 |
-
# print(embedding_text.size()) #torch.Size([5, 67, 768])
|
| 127 |
-
# print(prefix_projections.size()) #torch.Size([5, 1, 768])
|
| 128 |
-
embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
|
| 129 |
-
if labels is not None:
|
| 130 |
-
dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
|
| 131 |
-
labels = torch.cat((dummy_token, tokens), dim=1)
|
| 132 |
-
out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
|
| 133 |
-
return out
|
| 134 |
-
|
| 135 |
-
def __init__(self, prefix_length: int, prefix_size: int = 512):
|
| 136 |
-
super(ClipCaptionModel, self).__init__()
|
| 137 |
-
self.prefix_length = prefix_length
|
| 138 |
-
self.gpt = GPT2LMHeadModel.from_pretrained("gpt2")
|
| 139 |
-
self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
|
| 140 |
-
if prefix_length > 10: # not enough memory
|
| 141 |
-
self.clip_project = nn.Linear(
|
| 142 |
-
prefix_size, self.gpt_embedding_size * prefix_length
|
| 143 |
-
)
|
| 144 |
-
else:
|
| 145 |
-
self.clip_project = MLP(
|
| 146 |
-
(
|
| 147 |
-
prefix_size,
|
| 148 |
-
(self.gpt_embedding_size * prefix_length) // 2,
|
| 149 |
-
self.gpt_embedding_size * prefix_length,
|
| 150 |
-
)
|
| 151 |
-
)
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
class ClipCaptionPrefix(ClipCaptionModel):
|
| 155 |
-
def parameters(self, recurse: bool = True):
|
| 156 |
-
return self.clip_project.parameters()
|
| 157 |
-
|
| 158 |
-
def train(self, mode: bool = True):
|
| 159 |
-
super(ClipCaptionPrefix, self).train(mode)
|
| 160 |
-
self.gpt.eval()
|
| 161 |
-
return self
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def generate_beam(
|
| 165 |
-
model,
|
| 166 |
-
tokenizer,
|
| 167 |
-
beam_size: int = 5,
|
| 168 |
-
prompt=None,
|
| 169 |
-
embed=None,
|
| 170 |
-
entry_length=67,
|
| 171 |
-
temperature=1.0,
|
| 172 |
-
stop_token: str = ".",
|
| 173 |
-
):
|
| 174 |
-
|
| 175 |
-
model.eval()
|
| 176 |
-
stop_token_index = tokenizer.encode(stop_token)[0]
|
| 177 |
-
tokens = None
|
| 178 |
-
scores = None
|
| 179 |
-
device = next(model.parameters()).device
|
| 180 |
-
seq_lengths = torch.ones(beam_size, device=device)
|
| 181 |
-
is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
|
| 182 |
-
with torch.no_grad():
|
| 183 |
-
if embed is not None:
|
| 184 |
-
generated = embed
|
| 185 |
-
else:
|
| 186 |
-
if tokens is None:
|
| 187 |
-
tokens = torch.tensor(tokenizer.encode(prompt))
|
| 188 |
-
tokens = tokens.unsqueeze(0).to(device)
|
| 189 |
-
generated = model.gpt.transformer.wte(tokens)
|
| 190 |
-
for i in range(entry_length):
|
| 191 |
-
outputs = model.gpt(inputs_embeds=generated)
|
| 192 |
-
logits = outputs.logits
|
| 193 |
-
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
| 194 |
-
logits = logits.softmax(-1).log()
|
| 195 |
-
if scores is None:
|
| 196 |
-
scores, next_tokens = logits.topk(beam_size, -1)
|
| 197 |
-
generated = generated.expand(beam_size, *generated.shape[1:])
|
| 198 |
-
next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
|
| 199 |
-
if tokens is None:
|
| 200 |
-
tokens = next_tokens
|
| 201 |
-
else:
|
| 202 |
-
tokens = tokens.expand(beam_size, *tokens.shape[1:])
|
| 203 |
-
tokens = torch.cat((tokens, next_tokens), dim=1)
|
| 204 |
-
else:
|
| 205 |
-
logits[is_stopped] = -float(np.inf)
|
| 206 |
-
logits[is_stopped, 0] = 0
|
| 207 |
-
scores_sum = scores[:, None] + logits
|
| 208 |
-
seq_lengths[~is_stopped] += 1
|
| 209 |
-
scores_sum_average = scores_sum / seq_lengths[:, None]
|
| 210 |
-
scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(
|
| 211 |
-
beam_size, -1
|
| 212 |
-
)
|
| 213 |
-
next_tokens_source = next_tokens // scores_sum.shape[1]
|
| 214 |
-
seq_lengths = seq_lengths[next_tokens_source]
|
| 215 |
-
next_tokens = next_tokens % scores_sum.shape[1]
|
| 216 |
-
next_tokens = next_tokens.unsqueeze(1)
|
| 217 |
-
tokens = tokens[next_tokens_source]
|
| 218 |
-
tokens = torch.cat((tokens, next_tokens), dim=1)
|
| 219 |
-
generated = generated[next_tokens_source]
|
| 220 |
-
scores = scores_sum_average * seq_lengths
|
| 221 |
-
is_stopped = is_stopped[next_tokens_source]
|
| 222 |
-
next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(
|
| 223 |
-
generated.shape[0], 1, -1
|
| 224 |
-
)
|
| 225 |
-
generated = torch.cat((generated, next_token_embed), dim=1)
|
| 226 |
-
is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
|
| 227 |
-
if is_stopped.all():
|
| 228 |
-
break
|
| 229 |
-
scores = scores / seq_lengths
|
| 230 |
-
output_list = tokens.cpu().numpy()
|
| 231 |
-
output_texts = [
|
| 232 |
-
tokenizer.decode(output[: int(length)])
|
| 233 |
-
for output, length in zip(output_list, seq_lengths)
|
| 234 |
-
]
|
| 235 |
-
order = scores.argsort(descending=True)
|
| 236 |
-
output_texts = [output_texts[i] for i in order]
|
| 237 |
-
return output_texts
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
def generate2(
|
| 241 |
-
model,
|
| 242 |
-
tokenizer,
|
| 243 |
-
tokens=None,
|
| 244 |
-
prompt=None,
|
| 245 |
-
embed=None,
|
| 246 |
-
entry_count=1,
|
| 247 |
-
entry_length=67, # maximum number of words
|
| 248 |
-
top_p=0.8,
|
| 249 |
-
temperature=1.0,
|
| 250 |
-
stop_token: str = ".",
|
| 251 |
-
):
|
| 252 |
-
model.eval()
|
| 253 |
-
generated_num = 0
|
| 254 |
-
generated_list = []
|
| 255 |
-
stop_token_index = tokenizer.encode(stop_token)[0]
|
| 256 |
-
filter_value = -float("Inf")
|
| 257 |
-
device = next(model.parameters()).device
|
| 258 |
-
|
| 259 |
-
with torch.no_grad():
|
| 260 |
-
|
| 261 |
-
for entry_idx in range(entry_count):
|
| 262 |
-
if embed is not None:
|
| 263 |
-
generated = embed
|
| 264 |
-
else:
|
| 265 |
-
if tokens is None:
|
| 266 |
-
tokens = torch.tensor(tokenizer.encode(prompt))
|
| 267 |
-
tokens = tokens.unsqueeze(0).to(device)
|
| 268 |
-
|
| 269 |
-
generated = model.gpt.transformer.wte(tokens)
|
| 270 |
-
|
| 271 |
-
for i in range(entry_length):
|
| 272 |
-
|
| 273 |
-
outputs = model.gpt(inputs_embeds=generated)
|
| 274 |
-
logits = outputs.logits
|
| 275 |
-
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
| 276 |
-
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 277 |
-
cumulative_probs = torch.cumsum(
|
| 278 |
-
nnf.softmax(sorted_logits, dim=-1), dim=-1
|
| 279 |
-
)
|
| 280 |
-
sorted_indices_to_remove = cumulative_probs > top_p
|
| 281 |
-
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
|
| 282 |
-
..., :-1
|
| 283 |
-
].clone()
|
| 284 |
-
sorted_indices_to_remove[..., 0] = 0
|
| 285 |
-
|
| 286 |
-
indices_to_remove = sorted_indices[sorted_indices_to_remove]
|
| 287 |
-
logits[:, indices_to_remove] = filter_value
|
| 288 |
-
next_token = torch.argmax(logits, -1).unsqueeze(0)
|
| 289 |
-
next_token_embed = model.gpt.transformer.wte(next_token)
|
| 290 |
-
if tokens is None:
|
| 291 |
-
tokens = next_token
|
| 292 |
-
else:
|
| 293 |
-
tokens = torch.cat((tokens, next_token), dim=1)
|
| 294 |
-
generated = torch.cat((generated, next_token_embed), dim=1)
|
| 295 |
-
if stop_token_index == next_token.item():
|
| 296 |
-
break
|
| 297 |
-
|
| 298 |
-
output_list = list(tokens.squeeze().cpu().numpy())
|
| 299 |
-
output_text = tokenizer.decode(output_list)
|
| 300 |
-
generated_list.append(output_text)
|
| 301 |
-
|
| 302 |
-
return generated_list[0]
|
|
|
|
|
|
src/dataset.py DELETED
@@ -1,94 +0,0 @@
| 1 |
-
import torch
|
| 2 |
-
from torch.utils.data import Dataset
|
| 3 |
-
|
| 4 |
-
from tqdm import tqdm
|
| 5 |
-
import json
|
| 6 |
-
from typing import Tuple
|
| 7 |
-
import clip
|
| 8 |
-
import random
|
| 9 |
-
import json
|
| 10 |
-
import random
|
| 11 |
-
from tqdm import tqdm
|
| 12 |
-
|
| 13 |
-
class ClipCocoDataset(Dataset):
|
| 14 |
-
|
| 15 |
-
def __len__(self) -> int:
|
| 16 |
-
return len(self.captions_tokens)
|
| 17 |
-
|
| 18 |
-
def pad_tokens(self, item: int):
|
| 19 |
-
tokens = self.captions_tokens[item]
|
| 20 |
-
padding = self.max_seq_len - tokens.shape[0]
|
| 21 |
-
if padding > 0:
|
| 22 |
-
tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64)))
|
| 23 |
-
elif padding < 0:
|
| 24 |
-
tokens = tokens[:self.max_seq_len]
|
| 25 |
-
return tokens
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
|
| 29 |
-
# tokens = self.captions_tokens[item]
|
| 30 |
-
|
| 31 |
-
clip_tokens = self.pad_tokens(item)
|
| 32 |
-
if self.feats is None:
|
| 33 |
-
clip_tokens_77 = self.captions_tokens[item]
|
| 34 |
-
return clip_tokens, clip_tokens_77
|
| 35 |
-
else:
|
| 36 |
-
return clip_tokens, self.feats[item]
|
| 37 |
-
|
| 38 |
-
def __init__(self, data_path: str, clip_model=None, talk2dino=None, use_dino_feats=False, tokenizer=None):
|
| 39 |
-
if tokenizer is not None:
|
| 40 |
-
self.clip_tokenizer = tokenizer
|
| 41 |
-
else:
|
| 42 |
-
print(f"Using default tokenizer")
|
| 43 |
-
self.clip_tokenizer = clip.tokenize
|
| 44 |
-
self.prefix_length = 10
|
| 45 |
-
self.max_seq_len = 20
|
| 46 |
-
self.feats = None
|
| 47 |
-
|
| 48 |
-
if clip_model is not None:
|
| 49 |
-
device = next(clip_model.parameters()).device
|
| 50 |
-
print("Pre-extracting features...")
|
| 51 |
-
|
| 52 |
-
if not use_dino_feats:
|
| 53 |
-
with open(data_path, 'r') as f:
|
| 54 |
-
self.captions = [ann['caption'] for ann in json.load(f)['annotations']]
|
| 55 |
-
else:
|
| 56 |
-
data = torch.load(data_path)
|
| 57 |
-
self.captions = [ann['caption'] for ann in data['annotations']]
|
| 58 |
-
self.feats = [ann['features'] for ann in data['annotations']]
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
random.shuffle(self.captions)
|
| 62 |
-
self.captions_tokens = []
|
| 63 |
-
|
| 64 |
-
batch_size = 64
|
| 65 |
-
batched_captions = [self.captions[i:i + batch_size] for i in range(0, len(self.captions), batch_size)]
|
| 66 |
-
|
| 67 |
-
for batch in tqdm(batched_captions):
|
| 68 |
-
try:
|
| 69 |
-
# Tokenize the batch of captions
|
| 70 |
-
batch_tokens = [torch.tensor(self.clip_tokenizer(caption)[0], dtype=torch.int64) for caption in batch]
|
| 71 |
-
|
| 72 |
-
# Pad tokens to the same length for batching
|
| 73 |
-
batch_tokens_padded = torch.nn.utils.rnn.pad_sequence(batch_tokens, batch_first=True)
|
| 74 |
-
self.captions_tokens.extend(batch_tokens)
|
| 75 |
-
|
| 76 |
-
if clip_model is not None:
|
| 77 |
-
with torch.no_grad():
|
| 78 |
-
# Encode the text batch
|
| 79 |
-
feats = clip_model.encode_text(batch_tokens_padded.to(device))
|
| 80 |
-
|
| 81 |
-
if talk2dino is not None:
|
| 82 |
-
# Project to desired feature space
|
| 83 |
-
feats = talk2dino.project_clip_txt(feats).to('cpu')
|
| 84 |
-
|
| 85 |
-
# Concatenate features
|
| 86 |
-
if self.feats is None:
|
| 87 |
-
self.feats = feats
|
| 88 |
-
else:
|
| 89 |
-
self.feats = torch.cat((self.feats, feats))
|
| 90 |
-
except Exception as e:
|
| 91 |
-
print(f"Error processing batch: {e}")
|
| 92 |
-
print(len(self.captions_tokens))
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
src/datasetMix.py DELETED
@@ -1,153 +0,0 @@
| 1 |
-
import torch
|
| 2 |
-
from torch.utils.data import Dataset
|
| 3 |
-
|
| 4 |
-
from tqdm import tqdm
|
| 5 |
-
import json
|
| 6 |
-
from typing import Tuple
|
| 7 |
-
import clip
|
| 8 |
-
import random
|
| 9 |
-
import json
|
| 10 |
-
import random
|
| 11 |
-
from tqdm import tqdm
|
| 12 |
-
|
| 13 |
-
from pycocotools.coco import COCO
|
| 14 |
-
|
| 15 |
-
class ClipCocoDatasetMix(Dataset):
|
| 16 |
-
|
| 17 |
-
def __len__(self) -> int:
|
| 18 |
-
return len(self.image_index_list)
|
| 19 |
-
|
| 20 |
-
def _pad_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
|
| 21 |
-
padding = self.max_seq_len - tokens.shape[0]
|
| 22 |
-
if padding > 0:
|
| 23 |
-
tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64)))
|
| 24 |
-
elif padding < 0:
|
| 25 |
-
tokens = tokens[:self.max_seq_len]
|
| 26 |
-
return tokens
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]:
|
| 30 |
-
|
| 31 |
-
# get the image index for the item
|
| 32 |
-
img_idx = self.image_index_list[item]
|
| 33 |
-
# get the caption index for that image
|
| 34 |
-
first_caption_idx = self.image_index_list.index(img_idx)
|
| 35 |
-
|
| 36 |
-
# the caption index is the item - the first caption index
|
| 37 |
-
caption_idx = item - first_caption_idx
|
| 38 |
-
|
| 39 |
-
# how many captions are there for that image?
|
| 40 |
-
num_captions = len(self.captions_list_of_lists[img_idx])
|
| 41 |
-
try:
|
| 42 |
-
tokens = self.captions_tokens_list_of_lists[img_idx][caption_idx] #self.captions_list_of_lists[img_idx][caption_idx]
|
| 43 |
-
except IndexError:
|
| 44 |
-
print(f"{len(self.captions_tokens_list_of_lists)= } - {len(self.captions_tokens_list_of_lists[img_idx])= }")
|
| 45 |
-
print(f"IndexError: {img_idx}, {caption_idx}, {num_captions}")
|
| 46 |
-
raise
|
| 47 |
-
padded_tokens = self._pad_tokens(tokens)
|
| 48 |
-
|
| 49 |
-
feats_same_img = self.feats[img_idx][random.choice(range(num_captions))]
|
| 50 |
-
|
| 51 |
-
if self.feats is None or len(self.feats) == 0:
|
| 52 |
-
raise Exception("Precomputed features required")
|
| 53 |
-
else:
|
| 54 |
-
return padded_tokens, feats_same_img
|
| 55 |
-
|
| 56 |
-
def __init__(self, data_path: str, clip_model=None, talk2dino=None, use_precomputed_feats=False, tokenizer=None):
|
| 57 |
-
|
| 58 |
-
batch_size = 64
|
| 59 |
-
self.max_seq_len = 20
|
| 60 |
-
|
| 61 |
-
if use_precomputed_feats:
|
| 62 |
-
raise Exception("Precomputed features not supported")
|
| 63 |
-
|
| 64 |
-
if tokenizer is not None:
|
| 65 |
-
self.clip_tokenizer = tokenizer
|
| 66 |
-
else:
|
| 67 |
-
print(f"Using default tokenizer")
|
| 68 |
-
self.clip_tokenizer = clip.tokenize
|
| 69 |
-
|
| 70 |
-
coco_data = COCO(data_path)
|
| 71 |
-
# I want to load the captions from the json file in a list of lists,
|
| 72 |
-
# where each list contains the captions for a single image
|
| 73 |
-
|
| 74 |
-
self.captions_list_of_lists = []
|
| 75 |
-
|
| 76 |
-
self.image_index_list = []
|
| 77 |
-
|
| 78 |
-
max_seq_len = 20
|
| 79 |
-
|
| 80 |
-
for img_idx, (img_id, image) in enumerate(list(coco_data.imgs.items())):
|
| 81 |
-
# get the captions for that image
|
| 82 |
-
captions = coco_data.imgToAnns[img_id]
|
| 83 |
-
# get the texts of the captions
|
| 84 |
-
captions = [cap['caption'] for cap in captions] #[coco_data.anns[cap]['caption'] for cap in captions]
|
| 85 |
-
self.captions_list_of_lists.append(captions)
|
| 86 |
-
self.image_index_list.append([img_idx] * len(captions))
|
| 87 |
-
|
| 88 |
-
#max_seq_len = max(max_seq_len, max([len(caption) for caption in captions]))
|
| 89 |
-
|
| 90 |
-
self.max_seq_len = max_seq_len
|
| 91 |
-
print(f"Computed Max seq len: {max_seq_len}")
|
| 92 |
-
|
| 93 |
-
if clip_model is not None:
|
| 94 |
-
device = next(clip_model.parameters()).device
|
| 95 |
-
print("Pre-extracting features...")
|
| 96 |
-
|
| 97 |
-
#random.shuffle(self.captions_list_of_lists)
|
| 98 |
-
# should shuffle in the same way self.image_index_list and self.captions_list_of_lists
|
| 99 |
-
# Combine captions and image indices into a list of pairs
|
| 100 |
-
combined = list(zip(self.captions_list_of_lists, self.image_index_list, range(len(self.captions_list_of_lists))))
|
| 101 |
-
|
| 102 |
-
# Shuffle them together
|
| 103 |
-
random.shuffle(combined)
|
| 104 |
-
|
| 105 |
-
# Unzip the shuffled pairs back into two separate lists
|
| 106 |
-
self.captions_list_of_lists, self.image_index_list, img_idxes_shuffled = zip(*combined)
|
| 107 |
-
# Convert back to lists (zip returns tuples)
|
| 108 |
-
self.captions_list_of_lists = list(self.captions_list_of_lists)
|
| 109 |
-
self.image_index_list = list(self.image_index_list)
|
| 110 |
-
img_idxes_shuffled = list(img_idxes_shuffled)
|
| 111 |
-
|
| 112 |
-
# self.image_index_list is a list of lists, where each list contains the image index for each caption,
|
| 113 |
-
# so we need to flatten it
|
| 114 |
-
self.image_index_list = [img_idxes_shuffled.index(item) for sublist in self.image_index_list for item in sublist]
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
self.captions_tokens_list_of_lists = []
|
| 118 |
-
self.feats = [] # feats will be a list of tensors, each tensor will be (num_captions, embedding_dimension)
|
| 119 |
-
#ignore. # feats shape will be (num_images, num_captions, embedding_dimension)
|
| 120 |
-
|
| 121 |
-
#batched_captions = [self.captions[i:i + batch_size] for i in range(0, len(self.captions), batch_size)]
|
| 122 |
-
|
| 123 |
-
for captions_list in tqdm(self.captions_list_of_lists, dynamic_ncols=True):
|
| 124 |
-
try:
|
| 125 |
-
# Tokenize the batch of captions
|
| 126 |
-
batch_tokens = [torch.tensor(self.clip_tokenizer(caption)[0], dtype=torch.int64) for caption in captions_list]
|
| 127 |
-
|
| 128 |
-
# Pad tokens to the same length for batching
|
| 129 |
-
batch_tokens_padded = torch.nn.utils.rnn.pad_sequence(batch_tokens, batch_first=True)
|
| 130 |
-
self.captions_tokens_list_of_lists.append(batch_tokens)
|
| 131 |
-
|
| 132 |
-
# alternative:
|
| 133 |
-
# tokens = self.clip_tokenizer(captions_list, truncate=True).to(device) # shape: (num_captions, context_length)
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
if clip_model is not None:
|
| 137 |
-
with torch.no_grad():
|
| 138 |
-
# Encode the text batch
|
| 139 |
-
feats = clip_model.encode_text(batch_tokens_padded.to(device))
|
| 140 |
-
|
| 141 |
-
if talk2dino is not None:
|
| 142 |
-
# Project to desired feature space
|
| 143 |
-
feats = talk2dino.project_clip_txt(feats).to('cpu')
|
| 144 |
-
|
| 145 |
-
self.feats.append(feats.cpu()) # store (num_captions, embed_dim) for each image
|
| 146 |
-
|
| 147 |
-
except Exception as e:
|
| 148 |
-
print(f"Error processing batch: {e}")
|
| 149 |
-
|
| 150 |
-
print(f"Dataset loaded with {len(self.captions_list_of_lists)} images")
|
| 151 |
-
print(f"Max seq len: {max_seq_len}")
|
| 152 |
-
print(f"Number of captions: {len(self.image_index_list)}")
|
| 153 |
-
|
|
|
|
|
src/decap/decap.py DELETED
@@ -1,193 +0,0 @@
| 1 |
-
import os
|
| 2 |
-
from torch import nn
|
| 3 |
-
import numpy as np
|
| 4 |
-
import torch
|
| 5 |
-
import torch.nn.functional as nnf
|
| 6 |
-
import sys
|
| 7 |
-
from typing import Tuple, List, Union, Optional
|
| 8 |
-
from tqdm import tqdm, trange
|
| 9 |
-
import pickle
|
| 10 |
-
import PIL.Image as Image
|
| 11 |
-
import json
|
| 12 |
-
import random
|
| 13 |
-
import sys
|
| 14 |
-
import clip
|
| 15 |
-
import PIL
|
| 16 |
-
import random
|
| 17 |
-
|
| 18 |
-
from torch.utils.data import Dataset, DataLoader
|
| 19 |
-
from enum import Enum
|
| 20 |
-
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup
|
| 21 |
-
from tqdm import tqdm
|
| 22 |
-
import os
|
| 23 |
-
import pickle
|
| 24 |
-
import sys
|
| 25 |
-
import argparse
|
| 26 |
-
import json
|
| 27 |
-
from typing import Tuple, Optional, Union
|
| 28 |
-
|
| 29 |
-
import os
|
| 30 |
-
from dotenv import load_dotenv
|
| 31 |
-
|
| 32 |
-
load_dotenv()
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
DECAP_DECODER_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "decoder_config.pkl")
|
| 36 |
-
DECAP_COCO_WEIGHTS_PATH = None#'../../thesis-data/decap/coco_model/coco_prefix-009.pt'
|
| 37 |
-
|
| 38 |
-
class MappingType(Enum):
|
| 39 |
-
MLP = 'mlp'
|
| 40 |
-
Transformer = 'transformer'
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
class MLP(nn.Module):
|
| 44 |
-
|
| 45 |
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 46 |
-
return self.model(x)
|
| 47 |
-
|
| 48 |
-
def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
|
| 49 |
-
super(MLP, self).__init__()
|
| 50 |
-
layers = []
|
| 51 |
-
for i in range(len(sizes) - 1):
|
| 52 |
-
layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
|
| 53 |
-
if i < len(sizes) - 2:
|
| 54 |
-
layers.append(act())
|
| 55 |
-
self.model = nn.Sequential(*layers)
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
class DeCap(nn.Module):
|
| 59 |
-
|
| 60 |
-
def __init__(self,prefix_size: int = 512):
|
| 61 |
-
super(DeCap, self).__init__()
|
| 62 |
-
# decoder: 4 layers transformer with 4 attention heads
|
| 63 |
-
# the decoder is not pretrained
|
| 64 |
-
with open(DECAP_DECODER_CONFIG_PATH,'rb') as f:
|
| 65 |
-
config = pickle.load(f)
|
| 66 |
-
self.decoder = GPT2LMHeadModel(config)
|
| 67 |
-
self.embedding_size = self.decoder.transformer.wte.weight.shape[1]
|
| 68 |
-
self.clip_project = MLP((prefix_size,self.embedding_size))
|
| 69 |
-
|
| 70 |
-
def forward(self, clip_features,tokens):
|
| 71 |
-
embedding_text = self.decoder.transformer.wte(tokens)
|
| 72 |
-
embedding_clip = self.clip_project(clip_features)
|
| 73 |
-
embedding_clip = embedding_clip.reshape(-1,1,self.embedding_size)
|
| 74 |
-
embedding_cat = torch.cat([embedding_clip,embedding_text],dim=1)
|
| 75 |
-
out = self.decoder(inputs_embeds=embedding_cat)
|
| 76 |
-
return out
|
| 77 |
-
|
| 78 |
-
from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
|
| 79 |
-
_Tokenizer = _Tokenizer()
|
| 80 |
-
|
| 81 |
-
def Decoding(model,clip_features):
|
| 82 |
-
model.eval()
|
| 83 |
-
embedding_cat = model.clip_project(clip_features).reshape(1,1,-1)
|
| 84 |
-
entry_length = 30
|
| 85 |
-
temperature = 1
|
| 86 |
-
tokens = None
|
| 87 |
-
for i in range(entry_length):
|
| 88 |
-
# print(location_token.shape)
|
| 89 |
-
outputs = model.decoder(inputs_embeds=embedding_cat)
|
| 90 |
-
|
| 91 |
-
logits = outputs.logits
|
| 92 |
-
logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
|
| 93 |
-
logits_max = logits.max()
|
| 94 |
-
logits = torch.nn.functional.softmax(logits, -1)
|
| 95 |
-
next_token = torch.argmax(logits, -1).unsqueeze(0)
|
| 96 |
-
next_token_embed = model.decoder.transformer.wte(next_token)
|
| 97 |
-
|
| 98 |
-
if tokens is None:
|
| 99 |
-
tokens = next_token
|
| 100 |
-
|
| 101 |
-
else:
|
| 102 |
-
tokens = torch.cat((tokens, next_token), dim=1)
|
| 103 |
-
if next_token.item()==49407:
|
| 104 |
-
break
|
| 105 |
-
embedding_cat = torch.cat((embedding_cat, next_token_embed), dim=1)
|
| 106 |
-
try:
|
| 107 |
-
output_list = list(tokens.squeeze().cpu().numpy())
|
| 108 |
-
output = _Tokenizer.decode(output_list)
|
| 109 |
-
except:
|
| 110 |
-
output = 'None'
|
| 111 |
-
return output
|
| 112 |
-
|
| 113 |
-
def decoding_batched(model, clip_features, compute_scores : bool = False, decoding_method : callable = None, return_start_end_tokens : bool = False):
|
| 114 |
-
"""
|
| 115 |
-
Returns the generated sequences for a batch of clip features.
|
| 116 |
-
- if compute_scores is True, also returns the scores of the generated sequences.
|
| 117 |
-
- returns a list of strings if compute_scores is False, otherwise a tuple of a list of strings and a list of floats.
|
| 118 |
-
"""
|
| 119 |
-
|
| 120 |
-
model.eval()
|
| 121 |
-
embedding_cat = model.clip_project(clip_features).view(clip_features.shape[0], 1, -1)
|
| 122 |
-
entry_length = 30
|
| 123 |
-
temperature = 1
|
| 124 |
-
tokens = None
|
| 125 |
-
sequence_log_probs = None
|
| 126 |
-
|
| 127 |
-
for i in range(entry_length):
|
| 128 |
-
outputs = model.decoder(inputs_embeds=embedding_cat)
|
| 129 |
-
|
| 130 |
-
logits = outputs.logits[:, -1, :]
|
| 131 |
-
logits = logits / (temperature if temperature > 0 else 1.0)
|
| 132 |
-
|
| 133 |
-
probs = torch.nn.functional.softmax(logits, -1)
|
| 134 |
-
|
| 135 |
-
if compute_scores:
|
| 136 |
-
log_probs = torch.log(probs) # Convert to log-probabilities
|
| 137 |
-
|
| 138 |
-
next_token = torch.argmax(probs, -1).unsqueeze(1)
|
| 139 |
-
next_token_embed = model.decoder.transformer.wte(next_token)
|
| 140 |
-
|
| 141 |
-
if tokens is None:
|
| 142 |
-
tokens = next_token
|
| 143 |
-
if compute_scores:
|
| 144 |
-
sequence_log_probs = log_probs.gather(1, next_token) # Store log-prob of first token
|
| 145 |
-
else:
|
| 146 |
-
tokens = torch.cat((tokens, next_token), dim=1)
|
| 147 |
-
if compute_scores:
|
| 148 |
-
token_log_probs = log_probs.gather(1, next_token) # Get log-prob of chosen token
|
| 149 |
-
sequence_log_probs = torch.cat((sequence_log_probs, token_log_probs), dim=1) # Append
|
| 150 |
-
|
| 151 |
-
# Append new token embedding to input
|
| 152 |
-
embedding_cat = torch.cat((embedding_cat, next_token_embed), dim=1)
|
| 153 |
-
|
| 154 |
-
if compute_scores:
|
| 155 |
-
# Compute total sequence scores
|
| 156 |
-
sequence_scores = sequence_log_probs.sum(dim=-1) # Sum log-probs over sequence
|
| 157 |
-
final_scores = torch.exp(sequence_scores) # Convert log-sum-prob to probability-like score
|
| 158 |
-
|
| 159 |
-
try:
|
| 160 |
-
outputs = []
|
| 161 |
-
for tokens_elem in tokens:
|
| 162 |
-
output_list = list(tokens_elem.squeeze().cpu().numpy())
|
| 163 |
-
if decoding_method is not None:
|
| 164 |
-
output = decoding_method(output_list)
|
| 165 |
-
else:
|
| 166 |
-
output = _Tokenizer.decode(output_list)
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
output = output.split('<|endoftext|>')[0]
|
| 171 |
-
if not return_start_end_tokens:
|
| 172 |
-
output = output.replace('<|startoftext|>', '')
|
| 173 |
-
else:
|
| 174 |
-
output += '<|endoftext|>'
|
| 175 |
-
|
| 176 |
-
outputs.append(output)
|
| 177 |
-
except:
|
| 178 |
-
outputs = None
|
| 179 |
-
|
| 180 |
-
return (outputs, final_scores.cpu().numpy().tolist()) if compute_scores else outputs
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
decap_model = None
|
| 184 |
-
|
| 185 |
-
def get_decap_model(device, weights_path = DECAP_COCO_WEIGHTS_PATH, prefix_size=512):
|
| 186 |
-
#global decap_model
|
| 187 |
-
#if decap_model is not None:
|
| 188 |
-
# return decap_model
|
| 189 |
-
decap_model = DeCap(prefix_size)
|
| 190 |
-
decap_model.load_state_dict(torch.load(weights_path,map_location= torch.device('cpu')), strict=False)
|
| 191 |
-
decap_model = decap_model.to(device)
|
| 192 |
-
decap_model = decap_model.eval()
|
| 193 |
-
return decap_model
|
|
|
|
|
|
|
src/decap/decoderTraining.py DELETED
@@ -1,464 +0,0 @@
| 1 |
-
import torch
|
| 2 |
-
from torch.utils.data import DataLoader
|
| 3 |
-
from torch.utils.data.distributed import DistributedSampler
|
| 4 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 5 |
-
import torch.distributed as dist
|
| 6 |
-
|
| 7 |
-
from im2txtprojection.im2txtprojection import Im2TxtProjector, ProjectionType
|
| 8 |
-
from transformers import get_linear_schedule_with_warmup
|
| 9 |
-
from torch.optim import AdamW
|
| 10 |
-
from tqdm import tqdm
|
| 11 |
-
from decap import get_decap_model
|
| 12 |
-
import os
|
| 13 |
-
import sys
|
| 14 |
-
import argparse
|
| 15 |
-
import json
|
| 16 |
-
from typing import Union
|
| 17 |
-
import sys
|
| 18 |
-
import clip
import json

import csv

from src.dataset import ClipCocoDataset
from src.datasetMix import ClipCocoDatasetMix
from src.model import DeCap, ProjectionLayer

DECAP_DECODER_CONFIG_PATH = os.path.join("./decoder_config.pkl")

def save_config(args: argparse.Namespace):
    config = {}
    for key, item in args._get_kwargs():
        config[key] = item
    out_path = os.path.join(args.out_dir, f"{args.prefix}.json")
    with open(out_path, 'w') as outfile:
        json.dump(config, outfile)


def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'):
    with open(config_path) as f:
        config = json.load(f)
    parser = argparse.ArgumentParser()
    parser.set_defaults(**config)
    args = parser.parse_args()
    if type(epoch_or_latest) is int:
        epoch_or_latest = f"-{epoch_or_latest:03d}"
    model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt")
    if args.only_prefix:
        model = ClipCaptionPrefix(args.prefix_length)
    else:
        model = ClipCaptionModel(args.prefix_length)
    if os.path.isfile(model_path):
        print(f"loading model from {model_path}")
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    else:
        print(f"{model_path} is not exist")
    return model, parser


def train_decoder(args,
                  lr: float = 1e-5, warmup_steps: int = 1000, output_dir: str = ".", output_prefix: str = ""):

    # device = torch.device('cuda:1')
    batch_size = args.bs
    epochs = args.epochs
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    args.is_master = ( args.local_rank == 0 or args.not_distributed != False)

    # set the device
    #torch.cuda.set_device(args.local_rank)
    #device = torch.device('cuda:'+str(args.local_rank))
    if args.not_distributed == False:
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda:'+str(args.local_rank))
        dist.init_process_group(backend='nccl', init_method='env://')
    else:
        device = torch.device('cuda:'+str(args.local_rank))
        print(f"NOT DISTRIBUTED")
    print(f"Using device {device}")
    SEED=42
    torch.cuda.manual_seed_all(SEED)

    if args.use_regionclip:
        # RegionCLIP typically uses 1024 dimensions for ResNet-50 or 512 for ViT
        # We'll determine this from the loaded model
        prefix_size = 1024  # Default for RegionCLIP ResNet-50, but will be adjusted if needed
    elif args.denseclip_config is not None:
        # DenseClip typically uses 512 dimensions (similar to CLIP ViT-B)
        from src.denseclip.loader import load_denseclip_config
        denseclip_config_dict = load_denseclip_config(args.denseclip_config)
        prefix_size = denseclip_config_dict.get('model', {}).get('text', {}).get('embed_dim', None)
        if prefix_size is None:
            print(f"Warning: Could not determine prefix_size from DenseClip config {args.denseclip_config}. Defaulting to 512.")
            prefix_size = 512  # Fallback to a common size
    elif 'H' in args.clip_model or args.use_dinotxt:
        prefix_size = 1024
    elif args.talk2dino_weights is not None or args.use_dino_feats:
        prefix_size = 768
    else:
        prefix_size = 512

    if args.im_proj:
        memory_bank_path = os.path.abspath(args.dataset)
        print(f"Using Im2TxtProjector with {memory_bank_path = }")
        im_proj = Im2TxtProjector(
            type=memory_bank_path,
            use_talk2dino=True,
            linear_talk2dino=False,
            memory_bank_name='coco_karpathy',
            device_str=device)

    if args.use_regionclip:
        from src.regionclip.loader import load_regionclip_from_checkpoint
        from src.regionclip.datasets.clip_prompt_utils import tokenize as regionclip_tokenize

        print("Using RegionCLIP for text encoding.")
        if args.regionclip_checkpoint is None:
            raise ValueError("RegionCLIP checkpoint path must be provided when using --use-regionclip")

        clip_model = load_regionclip_from_checkpoint(
            args.regionclip_checkpoint,
            device=device,
            config=args.regionclip_config
        )
        tokenizer = regionclip_tokenize
        preprocess = None  # RegionCLIP doesn't need preprocessing for text-only training

        # Determine the actual embedding dimension from the loaded model
        if hasattr(clip_model, 'text_projection'):
            actual_prefix_size = clip_model.text_projection.shape[1]
            print(f"RegionCLIP text embedding dimension: {actual_prefix_size}")
            if actual_prefix_size != prefix_size:
                print(f"Updating prefix_size from {prefix_size} to {actual_prefix_size}")
                prefix_size = actual_prefix_size

        # Test RegionCLIP text encoding to ensure it works
        try:
            test_text = ["A test sentence"]
            test_tokens = tokenizer(test_text)
            test_features = clip_model.encode_text(test_tokens.to(device))
            print(f"RegionCLIP test encoding successful. Output shape: {test_features.shape}")
        except Exception as e:
            print(f"Warning: RegionCLIP test encoding failed: {e}")
            print("This might cause issues during training.")

    elif args.denseclip_config is not None:
        from src.denseclip.loader import load_denseclip

        print(f"Using DenseClip for text encoding with config: {args.denseclip_config}")

        try:
            clip_model = load_denseclip(
                config_name=args.denseclip_config,
                device=device
            )

            # Try to use DenseClip's tokenizer first
            try:
                from src.denseclip.loader import DenseCLIP_tokenize
                tokenizer = DenseCLIP_tokenize
                print("Using DenseClip tokenizer")
            except ImportError:
                # Fallback to CLIP tokenizer if DenseClip tokenizer is not available
                import clip
                tokenizer = clip.tokenize
                print("Warning: DenseClip tokenizer not available, using CLIP tokenizer")

            preprocess = None  # DenseClip doesn't need preprocessing for text-only training

            # Determine the actual embedding dimension from the loaded model
            if hasattr(clip_model, 'text_encoder') and hasattr(clip_model.text_encoder, 'embed_dim'):
                actual_prefix_size = clip_model.text_encoder.embed_dim
                print(f"DenseClip text embedding dimension: {actual_prefix_size}")
                if actual_prefix_size != prefix_size:
                    print(f"Updating prefix_size from {prefix_size} to {actual_prefix_size}")
                    prefix_size = actual_prefix_size

            # Test DenseClip text encoding to ensure it works
            test_text = ["A test sentence"]
            test_tokens = tokenizer(test_text)
            if hasattr(test_tokens, 'to'):
                test_tokens = test_tokens.to(device)
            test_features = clip_model.encode_text(test_tokens)
            print(f"DenseClip test encoding successful. Output shape: {test_features.shape}")

        except Exception as e:
            print(f"Error loading DenseClip model: {e}")
            raise e

    elif args.use_open_clip:
        from open_clip import create_model_and_transforms, tokenize
        print("Using open_clip for model loading.")
        clip_model, preprocess_train, preprocess_val = create_model_and_transforms(model_name=args.clip_model, pretrained="laion2b_s32b_b79k", device=device)
        preprocess = preprocess_train
        tokenizer = tokenize

    elif args.use_dinotxt:
        from src.dinotxt_utils import get_tokenizer
        clip_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l')
        tokenizer = get_tokenizer().tokenize
    else:
        clip_model, preprocess = clip.load(args.clip_model, device=device, jit=False)
        tokenizer = clip.tokenize
    clip_model.eval()
    clip_model.to(device)

    # Create model after determining the correct prefix_size
    if args.decap_weights is None:
        model = DeCap(prefix_size)
    else:
        model = get_decap_model(device, args.decap_weights, prefix_size)

    if args.talk2dino_weights is not None:
        # loading Talk2DINO
        print(f"Loading Talk2DINO weights from {args.talk2dino_weights}")
        talk2dino = ProjectionLayer.from_config(args.talk2dino_config)
        talk2dino.load_state_dict(torch.load(args.talk2dino_weights, device))
        talk2dino.to(device)
        talk2dino.eval()
    else:
        talk2dino = None

    loss_ce = torch.nn.CrossEntropyLoss(ignore_index=0,label_smoothing=0.1)
    model.to(device)

    if args.not_distributed == False:
        model = DDP(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True
        )

    if not args.pre_extract_features:
        print("Features pre-extraction de-activated")
        if args.mix_captions:
            print("Using mix captions")
            dataset = ClipCocoDatasetMix(args.dataset, use_precomputed_feats=args.use_dino_feats, tokenizer=tokenizer)
        else:
            dataset = ClipCocoDataset(args.dataset, use_dino_feats=args.use_dino_feats, tokenizer=tokenizer)
    else:
        if args.mix_captions:
            print("Using mix captions")
            dataset = ClipCocoDatasetMix(args.dataset, clip_model=clip_model, talk2dino=talk2dino, tokenizer=tokenizer)
        else:
            dataset = ClipCocoDataset(args.dataset, clip_model=clip_model, talk2dino=talk2dino, tokenizer=tokenizer)

    optimizer = AdamW(model.parameters(),lr=lr)

    print(f"Going to construct DataLoader with {len(dataset)} samples")
    if args.not_distributed == False:
        sampler = DistributedSampler(dataset)
        train_dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size, drop_last=True)
    else:
        train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    print("DataLoader constructed")
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader)
    )

    for epoch in range(epochs):

        epoch_loss = 0.0
        epoch_acc = 0.0
        num_batches = 0

        loss_token_save,ac_save= 0,0
        sys.stdout.flush()
        if args.is_master:
            print(f">>> Training epoch {epoch}")
            progress = tqdm(total=int(len(train_dataloader)/10), desc=output_prefix, dynamic_ncols=True)

        if args.not_distributed == False:
            dist.barrier()

        for idx,(clip_tokens, pipeline_input) in enumerate(train_dataloader):

            clip_tokens, pipeline_input = clip_tokens.to(device), pipeline_input.to(device)

            with torch.no_grad():
                if not args.pre_extract_features and not args.use_dino_feats:
                    if args.use_regionclip:
                        # RegionCLIP text encoding
                        feature_text = clip_model.encode_text(pipeline_input)
                    elif args.denseclip_config is not None:
                        # DenseClip text encoding
                        feature_text = clip_model.encode_text(pipeline_input)
                    else:
                        # Standard CLIP or OpenCLIP text encoding
                        feature_text = clip_model.encode_text(pipeline_input)

                    if args.use_dinotxt:
                        feature_text = feature_text[:, 1024:]  # patch-aligned text embedding

                    if args.talk2dino_weights is not None:
                        feature_text = talk2dino.project_clip_txt(feature_text)
                else:
                    feature_text = pipeline_input
                if args.im_proj:
                    feature_text = im_proj.project(feature_text, normalize=True)

                feature_text /= feature_text.norm(dim=-1, keepdim=True)

                if args.gaussian_noise != 0:
                    feature_text += args.gaussian_noise * torch.randn(feature_text.shape).to(device)
                    feature_text /= feature_text.norm(dim=-1, keepdim=True)

            outputs = model(feature_text.float(),clip_tokens)
            logits = outputs

            logits = logits.logits

            logits = logits[:,: -1]
            clip_tokens = clip_tokens.flatten()
            logits = logits.reshape(-1, logits.shape[-1])

            loss_token = loss_ce(logits, clip_tokens)
            ac=((logits.argmax(1)==clip_tokens)*(clip_tokens>0)).sum()/(clip_tokens>0).sum()
            optimizer.zero_grad()
            loss_all = loss_token
            loss_all.backward()
            optimizer.step()
            scheduler.step()

            epoch_loss += loss_token.item()
            epoch_acc += ac.item()
            num_batches += 1

            if args.is_master:
                if(idx+1) %10 == 0:
                    progress.set_postfix({"loss_token": loss_token_save/10.0,"acc_token":ac_save/10.0})
                    progress.update()
                    loss_token_save,ac_save= 0,0
                else:
                    loss_token_save += loss_token.item()
                    ac_save += ac.item()

        if args.is_master:
            log_dir = os.path.join('./log', f"{args.dataset}.txt")  #'./log/'+args.dataset+'.txt'
            with open(log_dir,'w') as f:
                f.writelines('epoch ' +str(epoch) +': '+ progress.postfix+'\r\n')
            progress.close()
            if epoch % args.save_every == 0 or epoch == epochs - 1:
                torch.save(
                    model.state_dict(),
                    os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"),
                )

        # after the epoch, we need to synchronize the loss and accuracy across all processes
        loss_tensor = torch.tensor(epoch_loss, device=device)
        acc_tensor = torch.tensor(epoch_acc, device=device)
        count_tensor = torch.tensor(num_batches, device=device)

        if args.not_distributed == False:
            # sum on all processes
            torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(acc_tensor, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM)

        # compute global mean
        avg_loss = loss_tensor.item() / count_tensor.item()
        avg_acc = acc_tensor.item() / count_tensor.item()

        if args.is_master:
            epoch_loss_current = {'epoch': epoch, 'loss': avg_loss, 'accuracy': avg_acc}
            #epoch_losses.append(epoch_loss_current)
            print(f"Epoch {epoch} loss: {avg_loss}, accuracy: {avg_acc}")

            loss_csv_path = os.path.join(output_dir, f"{output_prefix}_epoch_losses.csv")
            with open(loss_csv_path, 'a', newline='') as csvfile:
                writer = csv.DictWriter(csvfile, fieldnames=['epoch', 'loss', 'accuracy'])
                # Write the header only if the file is empty
                if os.stat(loss_csv_path).st_size == 0:
                    writer.writeheader()
                writer.writerow(epoch_loss_current)
    return model

# DeCap CLIP B16 karpathy train split:
#python decapTraining.py --out_dir weights_clip_b16_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy
# DECAP with proj -> but actually it is not needed.
#python decapTraining.py --out_dir weights_clip_b16_proj_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --im_proj

# Patchioner DINOv2 karpathy train split with proj:
#python decapTraining.py --out_dir weights_dino_b14_proj_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml --pre_extract_features --im_proj
# Patchioner DINOv2 karpathy train split
#python decapTraining.py --out_dir weights_dino_b14_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml
#python decapTraining.py --out_dir weights_dino_b14_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml --use_dino_feats --pre_extract_features

# DeCap CLIP B32 karpathy train split:
#python decapTraining.py --out_dir weights_clip_b32_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --clip_model ViT-B/32

# DeCap with RegionCLIP text encoder:
#python decoderTraining.py --out_dir weights_regionclip_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --use-regionclip

# DeCap with DenseClip text encoder:
#python decoderTraining.py --out_dir weights_denseclip_segmentation_vitb16_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --denseclip-config denseclip_segmentation_vitb16


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--decap_weights', type=str, default=None, help="If setted the Decap initialization is not random")
    parser.add_argument('--clip_model', type=str, default='ViT-B/16', help="CLIP configuration")
    parser.add_argument('--use_dinotxt', default=None, action='store_true', help="CLIP configuration")
    parser.add_argument('--gaussian_noise', type=float, default=0, help="Standard deviation of the Gaussian noise to apply to the text input")
    parser.add_argument('--out_dir', default='./coco_model')
    parser.add_argument('--prefix', default='./coco_prefix', help='prefix for saved filenames')
    parser.add_argument('--dataset', default='coco', help='coco or cc3m or bookcorpus')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--save_every', type=int, default=1)
    parser.add_argument('--prefix_length', type=int, default=1)
    parser.add_argument('--prefix_length_clip', type=int, default=1)
    parser.add_argument('--bs', type=int, default=64)
    parser.add_argument('--talk2dino_weights', type=str, default=None, help="Talk2DINO weights. If None, the training will be performed without Talk2DINO.")
    parser.add_argument('--talk2dino_config', type=str, default=None, help="Talk2DINO configs. Valid only if the weights are setted.")
    parser.add_argument('--use_dino_feats', action="store_true", default=False, help="If setted, we use the pre-extracted features of DINOv2")
    parser.add_argument('--im_proj', action="store_true", default=False, help="If setted, we use the projection on the input features")
    parser.add_argument('--pre_extract_features', action="store_true", default=False, help="If setted, the features will be extracted during the dataloading")
    parser.add_argument('--only_prefix', dest='only_prefix', action='store_true')
    parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer')
    parser.add_argument('--num_layers', type=int, default=8)
    parser.add_argument('--is_rn', dest='is_rn', action='store_true')
    parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true')
    parser.add_argument('--local-rank', type=int, default=-1, metavar='N', help='Local process rank.')
    parser.add_argument('--not-distributed', type=int, default=False, metavar='N', help='Not Distributed toggle.')
    parser.add_argument('--use-open-clip', action='store_true', default=False, help='Use OpenCLIP instead of CLIP')
    parser.add_argument('--mix-captions', action='store_true', default=False, help='Mix captions from the same image')
    parser.add_argument('--use-regionclip', action='store_true', default=False, help='Use RegionCLIP for text encoding')
    parser.add_argument('--regionclip-checkpoint', type=str, default='/raid/datasets/models_weights/regionclip/regionclip_pretrained-cc_rn50x4.pth', help='Path to RegionCLIP checkpoint file')
    parser.add_argument('--regionclip-config', type=str, default='pretrain/RegionCLIP_RN50x4.yaml', help='Path to RegionCLIP config file or config name')
    parser.add_argument('--denseclip-config', type=str, default=None, help='Path to DenseClip config file or config name')
    args = parser.parse_args()

    # Validate RegionCLIP arguments
    if args.use_regionclip and args.regionclip_checkpoint is None:
        parser.error("--regionclip-checkpoint is required when using --use-regionclip")

    if args.use_regionclip and args.use_open_clip:
        parser.error("Cannot use both --use-regionclip and --use-open-clip at the same time")

    # Validate DenseClip arguments
    if args.denseclip_config is not None and args.use_regionclip:
        parser.error("Cannot use both --denseclip-config and --use-regionclip at the same time")

    if args.denseclip_config is not None and args.use_open_clip:
        parser.error("Cannot use both --denseclip-config and --use-open-clip at the same time")

    train_decoder(args, output_dir=args.out_dir, output_prefix=args.prefix)


if __name__ == '__main__':
    main()
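The loop above packs its objective into a few dense lines. The following is a minimal, self-contained sketch of just that objective (teacher-forced cross-entropy with padding ignored, plus the optional Gaussian-noise augmentation), with random tensors standing in for the repo's dataset and DeCap decoder. The vocabulary size, sequence length, and 512-dimensional feature are hypothetical placeholders, and the sketch assumes the decoder prepends the projected text feature as a one-token prefix; in the real script the logits come from `model(feature_text, clip_tokens)`.

```python
import torch

vocab_size, batch_size, seq_len = 49408, 4, 40

# Stand-in for the normalized caption embedding produced by the text encoder.
feature_text = torch.randn(batch_size, 512)
feature_text = feature_text / feature_text.norm(dim=-1, keepdim=True)

# Optional augmentation from the script: add Gaussian noise, then re-normalize
# so the decoder always sees unit-norm prefixes (--gaussian_noise).
gaussian_noise = 0.016  # hypothetical value
feature_text = feature_text + gaussian_noise * torch.randn_like(feature_text)
feature_text = feature_text / feature_text.norm(dim=-1, keepdim=True)

# Stand-in for the decoder output: with a one-token prefix the decoder emits
# seq_len + 1 logit vectors; dropping the last one aligns position i with token i.
logits = torch.randn(batch_size, seq_len + 1, vocab_size)
clip_tokens = torch.randint(0, vocab_size, (batch_size, seq_len))  # 0 = padding

loss_ce = torch.nn.CrossEntropyLoss(ignore_index=0, label_smoothing=0.1)
flat_logits = logits[:, :-1].reshape(-1, vocab_size)
flat_tokens = clip_tokens.flatten()

loss = loss_ce(flat_logits, flat_tokens)
# Token accuracy over non-padding positions, mirroring the `ac` one-liner above.
acc = ((flat_logits.argmax(1) == flat_tokens) * (flat_tokens > 0)).sum() / (flat_tokens > 0).sum()
print(loss.item(), acc.item())
```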
src/decap/decoder_config.pkl
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb
size 1744
src/decap/im2txtprojection/im2txtprojection.py
DELETED
@@ -1,500 +0,0 @@
from enum import Enum
import numpy as np
import math
import json
import random
import torch
from tqdm import tqdm
import os
import h5py
from typing import Tuple
from dotenv import load_dotenv
from src.dinotxt_utils import get_tokenizer

load_dotenv()

class ProjectionType(Enum):
    COCO_CAPTIONS = 'coco_captions'
    MS_MARCO_QUERIES_A = 'ms_marco_queries_a'
    CC3M_BLIP = 'cc3m_blip_captions'
    VISUAL_GENOME = 'vg_captions'
    VISUAL_GENOME_TEST = "vg_dense_captions_test"
    ONLINE_TEXTS = "online_texts"

class Im2TxtProjector:
    """
    Im2TxtProjector creates and manages text embedding memory banks for different models:
    - Standard CLIP models
    - OpenCLIP models
    - RegionCLIP models
    - DenseClip models
    - Talk2DINO projected embeddings

    For RegionCLIP usage, pass regionclip_config as a dict with:
    {
        'checkpoint': '/path/to/regionclip_checkpoint.pth',
        'config_name': 'RegionCLIP_RN50.yaml'  # optional
    }

    For DenseClip usage, pass denseclip_config as a string with the config file name:
    'denseclip_vitb16'  # or other valid DenseClip config name
    """

    SUPPORT_MEMORY_SIZE = 500000

    __IM2TXT_MEMORY_PATH = os.getenv("IM2TXT_MEMORY_PATH")

    if __IM2TXT_MEMORY_PATH is None:
        default_path = "weights/im2txtmemories"  #os.path.join(os.path.dirname(__file__), "../../../im2txtmemories")
        print(f"[!] Warning: IM2TXT_MEMORY_PATH not set in environment variables, using '{default_path}' [!]")
        __IM2TXT_MEMORY_PATH = default_path

    __DECAP_FOLDER = os.path.join(os.path.dirname(__file__), "../")
    __TALK2DINO_CONFIG_WEIGHTS_PATH = __DECAP_FOLDER

    captions_dataType = 'train2017'
    ANNOTATIONS_CAPTION_FILE_PATH = os.path.join(__DECAP_FOLDER, 'captions_{}.json'.format(captions_dataType))
    VG_ANNOTATIONS_DENSE_CAPTIONS_FILE_PATH = '/raid/datasets/densecaptioning-annotations/data/vg/controlcap/vg1.2/train.json'
    VG_ANNOTATIONS_DENSE_CAPTIONS_TEST_FILE_PATH = '/raid/datasets/densecaptioning-annotations/data/vg/controlcap/vg1.2/test.json'

    CC3M_BLIP_FILE_PATH = os.path.join(__DECAP_FOLDER, "blipv2_captions.txt")
    MS_MARCO_QUERIES_FILE_PATH = '/raid/datasets/MSMarco/queries/queries.train.tsv'

    @staticmethod
    def create_regionclip_config(checkpoint_path: str, config_name: str = None):
        """
        Helper method to create RegionCLIP configuration dictionary.

        Args:
            checkpoint_path (str): Path to RegionCLIP checkpoint file
            config_name (str, optional): RegionCLIP config name (e.g., 'RegionCLIP_RN50.yaml')

        Returns:
            dict: Configuration dictionary for RegionCLIP
        """
        return {
            'checkpoint': checkpoint_path,
            'config_name': config_name
        }

    def __init__(self, type = ProjectionType.COCO_CAPTIONS, verbose : bool = True, device_str = "cpu", use_talk2dino : bool = True,
                 support_memory_size : int = SUPPORT_MEMORY_SIZE, batch_size=1000,
                 clip_modelname = None, linear_talk2dino : bool = False,
                 normalize_memory_embs : bool = False, talk2dino_attn_type='qkv', online_texts=None,
                 memory_bank_name = None, use_open_clip = False, regionclip_config=None, invite_config=None, denseclip_config=None) -> None:
        """
        - normalize_memory_embs -> normalizes the embeddings memory (required for projection in CLIP space)
        - type : ProjectionType -> the type of the support memory to be built. Can either be the path to the file containing the captions or the type of the support memory to be built
        """
        # check if hdf5 already exists, otherwise builds the support memory for that kind

        self.type = type
        self.device_str = device_str
        self.device = torch.device(self.device_str)
        self.use_talk2dino = use_talk2dino
        self.linear_talk2dino = linear_talk2dino
        self.talk2dino_attn_type = talk2dino_attn_type
        self.online_texts = online_texts
        self.use_open_clip = use_open_clip
        self.regionclip_config = regionclip_config
        self.invite_config = invite_config
        self.denseclip_config = denseclip_config

        if use_open_clip:
            assert use_talk2dino is False, "use_open_clip and use_talk2dino cannot be used together"

        if regionclip_config is not None:
            assert use_talk2dino is False, "regionclip_config and use_talk2dino cannot be used together"
            assert use_open_clip is False, "regionclip_config and use_open_clip cannot be used together"

        if invite_config is not None:
            # overwrite clip_modelname with invite_config['name'] if provided
            clip_modelname = invite_config.get('name', clip_modelname)
            assert use_talk2dino is False, "invite_config and use_talk2dino cannot be used together"

        if denseclip_config is not None:
            assert use_talk2dino is False, "denseclip_config and use_talk2dino cannot be used together"
            assert use_open_clip is False, "denseclip_config and use_open_clip cannot be used together"
            assert regionclip_config is None, "denseclip_config and regionclip_config cannot be used together"

        if clip_modelname is None:
            if self.use_talk2dino:
                clip_modelname = "ViT-B/16"
            elif regionclip_config is not None:
                # For RegionCLIP, we'll use a generic identifier since the model type is in the config
                clip_modelname = "RegionCLIP"
            elif denseclip_config is not None:
                # For DenseClip, we'll use a generic identifier since the model type is in the config
                clip_modelname = "DenseClip"
            else:
                clip_modelname = "ViT-B/32"
        self.clip_modelname = clip_modelname

        self.SUPPORT_MEMORY_SIZE = support_memory_size
        if use_talk2dino:
            prefix = ""
            postfix = '-B16' if use_talk2dino is True else use_talk2dino
            if linear_talk2dino:
                postfix += "-linear"
        elif regionclip_config is not None:
            prefix = "regionclip-"
            postfix = ""
        elif denseclip_config is not None:
            prefix = "denseclip-"
            postfix = ""
        else:
            prefix = "clip-"
            postfix = ""
        if talk2dino_attn_type != 'qkv':
            self.talk2dino_attn_type_str = f"_{talk2dino_attn_type}"
        else:
            self.talk2dino_attn_type_str = ''
        if isinstance(type, ProjectionType):
            dataset_name = type.value
        elif memory_bank_name is not None:
            dataset_name = memory_bank_name
        else:
            dataset_name = 'coco'

        if use_open_clip:
            postfix += "-open_clip"
        elif regionclip_config is not None:
            postfix += "-regionclip"
            # Add checkpoint identifier to make filename unique
            checkpoint_path = regionclip_config.get('checkpoint', '')
            checkpoint_name = os.path.basename(checkpoint_path).replace('.pth', '').replace('.pt', '')
            if checkpoint_name:
                postfix += f"-{checkpoint_name}"
        elif denseclip_config is not None:
            postfix += "-denseclip"
            # Add config identifier to make filename unique
            config_name = os.path.basename(denseclip_config).replace('.yaml', '').replace('.yml', '')
            if config_name:
                postfix += f"-{config_name}"

        self.H5PY_FILE_PATH = os.path.join( self.__IM2TXT_MEMORY_PATH, prefix + f'{dataset_name}{self.talk2dino_attn_type_str}_text_embeddings{postfix}-{clip_modelname.replace("/", ".")}-{self.SUPPORT_MEMORY_SIZE}.h5' )
        self.H5PY_EMBEDDINGS_DATASET_NAME = '{}-embeddings'.format(dataset_name)
        self.H5PY_TEXT_DATASET_NAME = '{}-text'.format(dataset_name)

        embs_dataset, text_dataset = self._load_support_memory()

        if text_dataset is None:
            if verbose:
                model_type = "RegionCLIP" if regionclip_config is not None else ("DenseClip" if denseclip_config is not None else ("OpenCLIP" if use_open_clip else "CLIP"))
                print(f"[+] Going to build support memory for the given data type: {type} using {model_type} [+]")
            embs_dataset, text_dataset = self._build_support_memory(batch_size)
            if verbose: print(f"[+] Done [+]")

            if self.type != ProjectionType.ONLINE_TEXTS:
                embs_dataset, text_dataset = self._load_support_memory()

        print(f"[-] loaded memory from {os.path.abspath( self.H5PY_FILE_PATH )} [-]")
        if regionclip_config is not None:
            print(f"[-] Using RegionCLIP text embeddings from checkpoint: {regionclip_config.get('checkpoint', 'Unknown')} [-]")
        elif denseclip_config is not None:
            print(f"[-] Using DenseClip text embeddings from config: {denseclip_config} [-]")

        self.text_dataset = text_dataset
        self.embs_dataset = torch.tensor(embs_dataset[:]).to(self.device)
        self.embs_dataset = self.embs_dataset[self.embs_dataset.norm(dim=-1) != 0]

        if normalize_memory_embs:
            self.embs_dataset /= self.embs_dataset.norm(dim=-1,keepdim=True).float()

    def project(self, image_embedding, temperature : float = 0.01, normalize : bool = False, return_argmax_text : bool = False, return_n_best_sims=None) -> torch.TensorType:
        if not isinstance(image_embedding, torch.Tensor):
            print(f"the type of image_embedding is '{type(image_embedding)}' converting it to torch tensor")
            image_embedding = torch.tensor(image_embedding, dtype=torch.float).to(self.device)

        orig_device = image_embedding.device

        if image_embedding.device != self.device:
            image_embedding = image_embedding.to(self.device)

        if image_embedding.dtype != float:
            #print(f"[-] image_embedding.dtype is {image_embedding.dtype}, converting it to float [-]")
            image_embedding = image_embedding.float()

        embs_dataset = self.embs_dataset / self.embs_dataset.norm(dim=-1, keepdim=True)
        image_embedding /= image_embedding.norm(dim=-1,keepdim=True)

        sim = image_embedding@embs_dataset.T.float()
        if return_argmax_text:
            argmax_texts = [self.text_dataset[idx].decode() for idx in sim.argmax(dim=-1)]
            if return_n_best_sims:
                return argmax_texts, sim.sort(dim=-1, descending=True).values[:, :return_n_best_sims].tolist()
            return argmax_texts
        softmax_sim = (sim / temperature).softmax(dim=-1)
        prefix_embedding = softmax_sim@self.embs_dataset.float()

        if normalize:
            prefix_embedding /= prefix_embedding.norm(dim=-1,keepdim=True)

        if return_n_best_sims:
            return prefix_embedding.to(orig_device), sim.sort(dim=-1, descending=True).values[:, :return_n_best_sims].tolist()

        return prefix_embedding.to(orig_device)

    def _load_support_memory(self) -> Tuple[np.ndarray, np.ndarray]:
        if self.type == ProjectionType.ONLINE_TEXTS:
            print(f"[-] _load_support_memory: support memory for provided texts will be constructed [-]")
            return None, None
        if not os.path.exists(self.H5PY_FILE_PATH):
            print(f"[-] _load_support_memory: the path '{self.H5PY_FILE_PATH}' does not exist [-]")
            return None, None

        with h5py.File(self.H5PY_FILE_PATH, 'r') as hf:
            if self.H5PY_EMBEDDINGS_DATASET_NAME in hf:
                embeddings_dataset = hf[self.H5PY_EMBEDDINGS_DATASET_NAME][:]
                text_dataset = hf[self.H5PY_TEXT_DATASET_NAME][:]
            else:
                embeddings_dataset = None
                text_dataset = None
        if 'DINO.txt' in self.clip_modelname:
            embeddings_dataset = embeddings_dataset[:, 1024:]  # Get patch-aligned text embeddings
        return embeddings_dataset, text_dataset

    def _build_support_memory(self, batch_size = 1000) -> Tuple[np.ndarray, np.ndarray]:
        ## construct the support memory

        self._load_models()

        if self.type == ProjectionType.COCO_CAPTIONS:
            from pycocotools.coco import COCO
            coco_obj = COCO(Im2TxtProjector.ANNOTATIONS_CAPTION_FILE_PATH)
            data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE)
            data = [ d['caption'] for d in data ]
        elif self.type == ProjectionType.VISUAL_GENOME:
            from pycocotools.coco import COCO
            coco_obj = COCO(Im2TxtProjector.VG_ANNOTATIONS_DENSE_CAPTIONS_FILE_PATH)
            # data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE)
            data = list(coco_obj.anns.values())[:self.SUPPORT_MEMORY_SIZE]
            data = [ d['caption'] for d in data ]
        elif self.type == ProjectionType.VISUAL_GENOME_TEST:
            from pycocotools.coco import COCO
            coco_obj = COCO(Im2TxtProjector.VG_ANNOTATIONS_DENSE_CAPTIONS_TEST_FILE_PATH)
            # data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE)
            data = list(coco_obj.anns.values())[:self.SUPPORT_MEMORY_SIZE]
            data = [ d['caption'] for d in data ]
        elif self.type == ProjectionType.MS_MARCO_QUERIES_A:
            print(f"Loading MSMarco queries from file ", Im2TxtProjector.MS_MARCO_QUERIES_FILE_PATH)
            with open(Im2TxtProjector.MS_MARCO_QUERIES_FILE_PATH, "r") as input_file:
                lines = input_file.readlines()
            data = random.sample(lines, k=self.SUPPORT_MEMORY_SIZE)
            data = [ d.split("\t")[1].replace("\n", "") for d in data ]
            print(f"Loaded from file '{self.SUPPORT_MEMORY_SIZE}' lines, example of line: '{data[0]}'")
        elif self.type == ProjectionType.CC3M_BLIP:
            print(f"Loading cc3m captions txt file ", Im2TxtProjector.CC3M_BLIP_FILE_PATH)
            with open(Im2TxtProjector.CC3M_BLIP_FILE_PATH, "r") as input_file:
                lines = input_file.readlines()
            data = random.sample(lines, k=self.SUPPORT_MEMORY_SIZE)
            data = [ d.replace("\n", "") for d in data ]
            print(f"Loaded from file '{len(data)}' lines, example of line: '{data[0]}'")
        elif self.type == ProjectionType.CC3M_BLIP:
            print(f"Loading cc3m captions txt file ", Im2TxtProjector.CC3M_BLIP_FILE_PATH)
            with open(Im2TxtProjector.CC3M_BLIP_FILE_PATH, "r") as input_file:
                lines = input_file.readlines()
            data = random.sample(lines, k=self.SUPPORT_MEMORY_SIZE)
            data = [ d.replace("\n", "") for d in data ]
            print(f"Loaded from file '{len(data)}' lines, example of line: '{data[0]}'")
        elif self.type == ProjectionType.ONLINE_TEXTS:
            data = self.online_texts
            print(f"Loaded online_texts '{len(data)}' lines, example of line: '{data[0]}'")
        elif type(self.type) == str:
            if os.path.exists(self.type):
                path = self.type
                from pycocotools.coco import COCO
                coco_obj = COCO(path)
                data = random.sample(list(coco_obj.anns.values()), k=min(self.SUPPORT_MEMORY_SIZE, len(coco_obj.anns)))
                data = [ d['caption'] for d in data ]
            else:
                #data = random.sample(data,500000)
                print(f"[!] Unimplemented data type '{self.type}'[!]")
                return None, None

        text_features = []
        captions = []

        self.clip_model.eval()

        n_txts = len(data)
        n_batch = math.ceil(n_txts / batch_size)
        for i in tqdm(range(n_batch)):
            start = i * batch_size
            end = start + batch_size if i < n_batch - 1 else n_txts

            texts = data[start:end]
            with torch.no_grad():
                texts_token = self.tokenizer(texts).to(self.device)
                text_feature = self.clip_model.encode_text(texts_token)
                if self.use_talk2dino:
                    text_feature = self.talk2dino.project_clip_txt(text_feature)
                text_features.append(text_feature)
            captions.extend(texts)

        text_features = torch.cat(text_features,dim=0)

        #text_features /= text_features.norm(dim=-1,keepdim=True).float()

        # store captions and text features in hdf5 dataset

        text_features_ndarray = text_features.cpu().numpy()

        assert len(text_features_ndarray) == len(captions), f"len(text_features_ndarray) = {len(text_features_ndarray)} != len(captions) = {len(captions)}"

        #if not os.path.exists(self.H5PY_FILE_PATH):
        #    print(f"os.path '{self.H5PY_FILE_PATH}' does not exists")

        EMBEDDINGS_DIMENSION = text_features_ndarray.shape[1]

        if self.type != ProjectionType.ONLINE_TEXTS:
            with h5py.File(self.H5PY_FILE_PATH, 'w') as hf:
                if self.H5PY_EMBEDDINGS_DATASET_NAME in hf:
                    embeddings_dataset = hf[self.H5PY_EMBEDDINGS_DATASET_NAME]
                    text_dataset = hf[self.H5PY_TEXT_DATASET_NAME]
                    print(f"[!] Dataset '{self.H5PY_EMBEDDINGS_DATASET_NAME}' already exists! Going to overwrite [!]")
                else:
                    embeddings_dataset = hf.create_dataset(self.H5PY_EMBEDDINGS_DATASET_NAME, shape=(self.SUPPORT_MEMORY_SIZE, EMBEDDINGS_DIMENSION), dtype='float32')
                    text_dataset = hf.create_dataset(self.H5PY_TEXT_DATASET_NAME, shape=(self.SUPPORT_MEMORY_SIZE, ), dtype=h5py.string_dtype(encoding='utf-8'))  #, dtype='str'

                for num_row in range(len(text_features_ndarray)):
                    embeddings_dataset[num_row] = text_features_ndarray[num_row]
                    text_dataset[num_row] = captions[num_row]
        else:
            embeddings_dataset = text_features_ndarray
            text_dataset = [x.encode() for x in captions]

        return embeddings_dataset, text_dataset

    clip_model = None
    def _load_models(self):

        if self.clip_model is not None:
            # case already done
            return

        if self.use_open_clip:
            print("[-] loading open_clip model [-]")
            assert self.clip_modelname is not None, "clip_modelname must be provided when using open_clip"
            from open_clip import create_model_and_transforms, tokenize
            self.clip_model, preprocess_train, preprocess_val = create_model_and_transforms(self.clip_modelname, pretrained="laion2b_s32b_b79k", device=self.device)
            self.preprocess = preprocess_train
            self.tokenizer = tokenize
            return

        if self.regionclip_config is not None:
            print("[-] loading RegionCLIP model [-]")
            from src.regionclip.loader import load_regionclip_from_checkpoint
            from src.regionclip.datasets.clip_prompt_utils import tokenize as regionclip_tokenize

            regionclip_checkpoint = self.regionclip_config.get('checkpoint', None)
            if regionclip_checkpoint is None:
                raise ValueError("RegionCLIP checkpoint not specified in the configuration")
            regionclip_config_name = self.regionclip_config.get('config_name', None)

            print(f"[-] Loading RegionCLIP from checkpoint: {regionclip_checkpoint} [-]")
            if regionclip_config_name:
                print(f"[-] Using RegionCLIP config: {regionclip_config_name} [-]")

            self.clip_model = load_regionclip_from_checkpoint(
                regionclip_checkpoint,
                device=self.device,
                config=regionclip_config_name
            )
            self.tokenizer = regionclip_tokenize
            self.preprocess = None  # RegionCLIP doesn't need preprocessing for text encoding

            # Test RegionCLIP text encoding to ensure it works
            try:
                test_text = ["A test sentence for RegionCLIP"]
                test_tokens = self.tokenizer(test_text)
                test_features = self.clip_model.encode_text(test_tokens.to(self.device))
                print(f"[-] RegionCLIP text encoding test successful. Output shape: {test_features.shape} [-]")
            except Exception as e:
                print(f"[!] Warning: RegionCLIP text encoding test failed: {e} [!]")
                raise e

            return

        if self.denseclip_config is not None:
            print("[-] loading DenseClip model [-]")
            from src.denseclip.loader import load_denseclip, DenseCLIP_tokenize

            print(f"[-] Loading DenseClip from config: {self.denseclip_config} [-]")

            # Load DenseClip model
            self.clip_model = load_denseclip(
                config_name=self.denseclip_config,
                device=self.device
            )

            # DenseClip should have encode_text method and a tokenizer
            # We need to check if DenseClip has a tokenizer method
            if DenseCLIP_tokenize is not None:
                self.tokenizer = DenseCLIP_tokenize
            else:
                # Fallback to CLIP tokenizer if DenseClip doesn't provide one
                import clip
                self.tokenizer = clip.tokenize
                print("[!] Warning: DenseClip model doesn't have tokenizer, using CLIP tokenizer [!]")

            self.preprocess = None  # DenseClip doesn't need preprocessing for text encoding

            # Test DenseClip text encoding to ensure it works
            try:
                test_text = ["A test sentence for DenseClip"]
                test_tokens = self.tokenizer(test_text)
                if hasattr(test_tokens, 'to'):
                    test_tokens = test_tokens.to(self.device)
                test_features = self.clip_model.encode_text(test_tokens)
                print(f"[-] DenseClip text encoding test successful. Output shape: {test_features.shape} [-]")
            except Exception as e:
                print(f"[!] Warning: DenseClip text encoding test failed: {e} [!]")
                raise e

            return

        import clip
        if self.clip_modelname is None:
            clip_model_name = "ViT-B/16" if self.use_talk2dino else "ViT-B/32"
        else:
            clip_model_name = self.clip_modelname
        if 'DINO.txt' not in clip_model_name:
            self.clip_model, self.preprocess = clip.load(clip_model_name, device=self.device, jit=False)
            self.tokenizer = clip.tokenize
            if self.use_talk2dino:
                # loading Talk2DINO
                if type(self.use_talk2dino) == str:
                    proj_name = self.use_talk2dino
                elif self.linear_talk2dino is False:
                    proj_name = 'vitb_mlp_infonce'
                else:
                    proj_name = 'vitb_linear_infonce'

                config = os.path.join(self.__TALK2DINO_CONFIG_WEIGHTS_PATH, "configs_talk2dino", proj_name + '.yaml')
                weights = os.path.join(self.__TALK2DINO_CONFIG_WEIGHTS_PATH, "weights_talk2dino", proj_name + self.talk2dino_attn_type_str + '.pth')
                #import sys
                #import os
                #add_path = os.path.abspath( os.path.dirname("../"))
                ##print(add_path)
                #sys.path.insert(1, add_path )
                from src.model import ProjectionLayer
                self.talk2dino = ProjectionLayer.from_config(config)
                self.talk2dino.load_state_dict(torch.load((weights), self.device))
                self.talk2dino.to(self.device)
        else:
            self.clip_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l').to(self.device)
            self.tokenizer = get_tokenizer().tokenize
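For reference, here is a minimal usage sketch of the Im2TxtProjector defined above (as it existed before this commit), using the ONLINE_TEXTS memory type so no annotation files are needed. The import path mirrors the deleted file's location, the caption list and the query embedding are made-up placeholders, and running it still requires the CLIP ViT-B/32 weights to be downloadable. The soft projection it exercises is a softmax-weighted average of the memory embeddings, softmax(sim / temperature) applied to the similarity row.

```python
import torch
from src.decap.im2txtprojection.im2txtprojection import Im2TxtProjector, ProjectionType

# Build a tiny in-memory text bank instead of a COCO/VG/CC3M memory file.
projector = Im2TxtProjector(
    type=ProjectionType.ONLINE_TEXTS,
    online_texts=["a dog running on the beach", "a plate of pasta", "a red sports car"],
    use_talk2dino=False,   # plain CLIP ViT-B/32 text space
    device_str="cpu",
)

# Pretend this is an image (or patch) embedding living in the same space as the memory.
query = torch.randn(1, 512)

# Soft projection: blend the memory captions' embeddings according to similarity.
projected = projector.project(query, temperature=0.01, normalize=True)

# Hard retrieval: return the nearest memory caption instead of a blended embedding.
nearest = projector.project(query, return_argmax_text=True)
print(projected.shape, nearest)
```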
src/denseclip/clip_loader/README.md
DELETED
|
@@ -1,233 +0,0 @@
|
|
| 1 |
-
# DenseCLIP to CLIP Loader
|
| 2 |
-
|
| 3 |
-
A simple interface for loading DenseCLIP checkpoints as CLIP-like models for text and image encoding.
|
| 4 |
-
|
| 5 |
-
## Overview
|
| 6 |
-
|
| 7 |
-
This module provides a clean API to load DenseCLIP models and use them like standard CLIP models for encoding text and images. It abstracts away the complexity of DenseCLIP's detection/segmentation components and exposes only the core vision-language encoding functionality.
|
| 8 |
-
|
| 9 |
-
## Features
|
| 10 |
-
|
| 11 |
-
- โ
**Simple API**: Load DenseCLIP models with a single function call
|
| 12 |
-
- โ
**CLIP-like Interface**: Familiar `encode_text()` and `encode_image()` methods
|
| 13 |
-
- โ
**Flexible Configuration**: YAML-based configuration system
|
| 14 |
-
- โ
**Multiple Input Types**: Support for PIL Images, image tensors, strings, and text lists
|
| 15 |
-
- โ
**Automatic Preprocessing**: Built-in image preprocessing pipeline
|
| 16 |
-
- โ
**Device Management**: Automatic GPU/CPU detection and placement
|
| 17 |
-
|
| 18 |
-
## Quick Start
|
| 19 |
-
|
| 20 |
-
```python
|
| 21 |
-
from clip_loader import load_clip
|
| 22 |
-
|
| 23 |
-
# Load DenseCLIP model with default configuration
|
| 24 |
-
model = load_clip('denseclip_segmentation_vitb16')
|
| 25 |
-
|
| 26 |
-
# Encode text
|
| 27 |
-
texts = ["a photo of a cat", "a photo of a dog"]
|
| 28 |
-
text_features = model.encode_text(texts)
|
| 29 |
-
|
| 30 |
-
# Encode images (PIL Images)
|
| 31 |
-
from PIL import Image
|
| 32 |
-
images = [Image.open("cat.jpg"), Image.open("dog.jpg")]
|
| 33 |
-
image_features = model.encode_image(images)
|
| 34 |
-
|
| 35 |
-
# Compute similarities
|
| 36 |
-
similarities = model.compute_similarity(image_features, text_features)
|
| 37 |
-
print(f"Image-text similarities: {similarities}")
|
| 38 |
-
```
|
| 39 |
-
|
| 40 |
-
## Configuration
|
| 41 |
-
|
| 42 |
-
Models are configured using YAML files in the `configs/` directory. The main configuration for DenseCLIP ViT-B/16 is in `configs/denseclip_vitb16.yaml`.
|
| 43 |
-
|
| 44 |
-
### Configuration Structure
|
| 45 |
-
|
| 46 |
-
```yaml
|
| 47 |
-
model:
|
| 48 |
-
name: "denseclip_vitb16"
|
| 49 |
-
type: "vit"
|
| 50 |
-
|
| 51 |
-
vision:
|
| 52 |
-
image_resolution: 224
|
| 53 |
-
vision_layers: 12
|
| 54 |
-
vision_width: 768
|
| 55 |
-
vision_patch_size: 16
|
| 56 |
-
embed_dim: 512
|
| 57 |
-
|
| 58 |
-
text:
|
| 59 |
-
context_length: 13 # DenseCLIP uses shorter context
|
| 60 |
-
vocab_size: 49408
|
| 61 |
-
transformer_width: 512
|
| 62 |
-
transformer_heads: 8
|
| 63 |
-
transformer_layers: 12
|
| 64 |
-
embed_dim: 512
|
| 65 |
-
|
| 66 |
-
checkpoint:
|
| 67 |
-
path: "/path/to/denseclip/checkpoint.pth"
|
| 68 |
-
format: "denseclip"
|
| 69 |
-
|
| 70 |
-
preprocessing:
|
| 71 |
-
image_mean: [0.48145466, 0.4578275, 0.40821073]
|
| 72 |
-
image_std: [0.26862954, 0.26130258, 0.27577711]
|
| 73 |
-
normalize: true
|
| 74 |
-
```
|
| 75 |
-
|
| 76 |
-
## API Reference
|
| 77 |
-
|
| 78 |
-
### Core Functions
|
| 79 |
-
|
| 80 |
-
#### `load_clip(config_name, checkpoint_path=None, device='auto')`
|
| 81 |
-
|
| 82 |
-
Load a DenseCLIP model with the specified configuration.
|
| 83 |
-
|
| 84 |
-
**Parameters:**
|
| 85 |
-
- `config_name` (str): Name of config file (without .yaml extension)
|
| 86 |
-
- `checkpoint_path` (str, optional): Path to checkpoint file (overrides config)
|
| 87 |
-
- `device` (str): Device to load on ('auto', 'cpu', 'cuda')
|
| 88 |
-
|
| 89 |
-
**Returns:**
|
| 90 |
-
- `DenseCLIPModel`: Loaded model ready for inference
|
| 91 |
-
|
| 92 |
-
#### `load_denseclip_model(config_path, checkpoint_path=None, device='auto')`
|
| 93 |
-
|
| 94 |
-
Load a DenseCLIP model from configuration file path.
|
| 95 |
-
|
| 96 |
-
### DenseCLIPModel Methods
|
| 97 |
-
|
| 98 |
-
#### `encode_text(texts)`
|
| 99 |
-
|
| 100 |
-
Encode text into feature vectors.
|
| 101 |
-
|
| 102 |
-
**Parameters:**
|
| 103 |
-
- `texts` (str or List[str]): Text string(s) to encode
|
| 104 |
-
|
| 105 |
-
**Returns:**
|
| 106 |
-
- `torch.Tensor`: Normalized text features [batch_size, embed_dim]
|
| 107 |
-
|
| 108 |
-
#### `encode_image(images)`
|
| 109 |
-
|
| 110 |
-
Encode images into feature vectors.
|
| 111 |
-
|
| 112 |
-
**Parameters:**
|
| 113 |
-
- `images`: PIL Image, List[PIL.Image], or preprocessed tensor
|
| 114 |
-
|
| 115 |
-
**Returns:**
|
| 116 |
-
- `torch.Tensor`: Normalized image features [batch_size, embed_dim]
|
| 117 |
-
|
| 118 |
-
#### `compute_similarity(image_features, text_features, temperature=1.0)`
|
| 119 |
-
|
| 120 |
-
Compute similarity between image and text features.
|
| 121 |
-
|
| 122 |
-
**Parameters:**
|
| 123 |
-
- `image_features` (torch.Tensor): Image features [N, embed_dim]
|
| 124 |
-
- `text_features` (torch.Tensor): Text features [M, embed_dim]
|
| 125 |
-
- `temperature` (float): Temperature scaling factor
|
| 126 |
-
|
| 127 |
-
**Returns:**
|
| 128 |
-
- `torch.Tensor`: Similarity matrix [N, M]
|
| 129 |
-
|
| 130 |
-
## Examples

### Basic Text-Image Retrieval

```python
from clip_loader import load_clip
from PIL import Image

# Load model
model = load_clip('denseclip_vitb16')

# Load and encode images
images = [
    Image.open("cat.jpg"),
    Image.open("dog.jpg"),
    Image.open("car.jpg")
]
image_features = model.encode_image(images)

# Encode text queries
queries = [
    "a cute cat",
    "a happy dog",
    "a red car"
]
text_features = model.encode_text(queries)

# Find the best image for each query
# (similarities has shape [num_images, num_queries])
similarities = model.compute_similarity(image_features, text_features)
best_matches = similarities.argmax(dim=0)

for i, query in enumerate(queries):
    best_image_idx = best_matches[i].item()
    score = similarities[best_image_idx, i].item()
    print(f"Query '{query}' -> Image {best_image_idx} (score: {score:.3f})")
```
### Zero-Shot Classification

```python
from clip_loader import load_clip
from PIL import Image

model = load_clip('denseclip_vitb16')

# Load test image
image = Image.open("test_image.jpg")
image_features = model.encode_image(image)

# Define class labels
class_labels = [
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a bird",
    "a photo of a car",
    "a photo of a house"
]

# Encode labels
text_features = model.encode_text(class_labels)

# Classify
similarities = model.compute_similarity(image_features, text_features)
probabilities = similarities.softmax(dim=-1)

# Show results
for i, label in enumerate(class_labels):
    prob = probabilities[0, i].item()
    print(f"{label}: {prob:.3f}")
```
### Custom Configuration

```python
from clip_loader import load_denseclip_model

# Load with custom config
model = load_denseclip_model(
    config_path='configs/custom_config.yaml',
    checkpoint_path='/path/to/custom/checkpoint.pth',
    device='cuda:1'
)
```
## Requirements

- PyTorch >= 1.9.0
- torchvision >= 0.10.0
- Pillow >= 8.0.0
- PyYAML >= 5.4.0

## Notes

- DenseCLIP uses a shorter text context length (13) compared to standard CLIP (77)
- The model preserves ~98% similarity with original CLIP text representations
- Image preprocessing follows CLIP's standard normalization
- All features are L2-normalized for cosine similarity computation

## Supported Models

Currently supported:
- `denseclip_vitb16`: DenseCLIP with ViT-B/16 backbone

To add support for other DenseCLIP variants, create new configuration files in the `configs/` directory.
src/denseclip/clip_loader/SUMMARY.md DELETED
@@ -1,78 +0,0 @@
# DenseCLIP to CLIP Loader - Quick Start

## Successfully Created!

The `clip_loader` module provides a simple interface to load DenseCLIP checkpoints as CLIP-like models for text and image encoding.

## Structure

```
/raid/homes/giacomo.pacini/DenseCLIP/clip_loader/
├── __init__.py              # Module initialization
├── denseclip_loader.py      # Main loader implementation
├── example_usage.py         # Example script
├── requirements.txt         # Dependencies
├── README.md                # Full documentation
└── configs/
    └── denseclip_vitb16.yaml  # Configuration for ViT-B/16 model
```

## Quick Usage

```python
from clip_loader import load_clip

# Load DenseCLIP model with default configuration
model = load_clip('denseclip_vitb16')

# Encode text
texts = ["a photo of a cat", "a photo of a dog"]
text_features = model.encode_text(texts)  # Shape: [2, 512]

# Encode images (if you have PIL Images)
# image_features = model.encode_image(images)

# Compute similarities
similarities = model.compute_similarity(text_features, text_features)
print(f"Cat-Dog similarity: {similarities[0, 1]:.3f}")
```
## Test Results

- **Model loads successfully** from the DenseCLIP checkpoint
- **Text encoding works** (shape: [batch_size, 512])
- **Features are normalized** (L2 norm = 1.0)
- **Similarities make sense** (Cat-Dog: 0.872, Car-Person: lower)
- **Zero-shot classification** shows logical patterns
- **Model has 157M parameters** (94M vision + 63M text)

## Key Features

- **Simple API**: Just call `load_clip()` and start encoding
- **Handles DenseCLIP specifics**: Automatically extracts the CLIP weights from the segmentation checkpoint (see the sketch after this list)
- **CLIP-compatible**: Same interface as OpenAI CLIP
- **Flexible configuration**: YAML-based configuration system
- **GPU ready**: Automatic device detection and placement
- **Context length**: Uses DenseCLIP's shorter context (13 vs 77)
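
A minimal sketch of that extraction step (the checkpoint path is illustrative; the key prefixes mirror `load_denseclip_weights` in `denseclip_loader.py`):

```python
import torch

# DenseCLIP segmentation checkpoints store the CLIP encoders under
# 'backbone.' (vision) and 'text_encoder.' prefixes inside 'state_dict'.
ckpt = torch.load("/path/to/ViT-B-DenseCLIP.pth", map_location="cpu")
state_dict = ckpt["state_dict"]

vision_weights = {k[len("backbone."):]: v
                  for k, v in state_dict.items() if k.startswith("backbone.")}
text_weights = {k[len("text_encoder."):]: v
                for k, v in state_dict.items() if k.startswith("text_encoder.")}

print(f"{len(vision_weights)} vision tensors, {len(text_weights)} text tensors")
```
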
## Use Cases

1. **Text-Image Retrieval**: Encode both and compute similarities
2. **Zero-Shot Classification**: Encode class descriptions and compare
3. **Text Similarity**: Compare text representations
4. **Feature Extraction**: Get dense vector representations
## Configuration

The model uses `/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP.pth` by default. You can override this by modifying `configs/denseclip_vitb16.yaml` or passing a custom checkpoint path.
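
For example, a checkpoint override without touching the YAML file (the path is illustrative):

```python
from clip_loader import load_clip

# Load the default config but read weights from a different checkpoint.
model = load_clip('denseclip_vitb16',
                  checkpoint_path='/path/to/another/ViT-B-DenseCLIP.pth')
```
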
## What's Different from Standard CLIP

- **Shorter context length**: 13 tokens vs 77
- **Higher image resolution**: 640px vs 224px
- **Fine-tuned weights**: Adapted for dense prediction tasks
- **High text similarity**: ~98% similarity with original CLIP representations

## Ready to Use!

The loader is fully functional and ready for use in your projects. See `README.md` for detailed documentation and more examples.
src/denseclip/clip_loader/__init__.py DELETED
@@ -1,21 +0,0 @@
"""
DenseCLIP to CLIP Loader

A simple interface for loading DenseCLIP checkpoints as CLIP-like models
for text and image encoding.
"""

from .denseclip_loader import (
    DenseCLIPModel,
    load_denseclip_model,
    load_clip,
    load_config
)

__version__ = "1.0.0"
__all__ = [
    "DenseCLIPModel",
    "load_denseclip_model",
    "load_clip",
    "load_config"
]
src/denseclip/clip_loader/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
size 1356917
src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16.yaml DELETED
@@ -1,41 +0,0 @@
# DenseCLIP ViT-B/16 Configuration
# Configuration for loading DenseCLIP checkpoint as a CLIP-like model

model:
  name: "denseclip_vitb16"
  type: "vit"  # vision transformer

  # Vision encoder configuration
  vision:
    image_resolution: 640
    vision_layers: 12
    vision_width: 768
    vision_patch_size: 16
    embed_dim: 512

  # Text encoder configuration
  text:
    context_length: 13  # DenseCLIP uses shorter context
    vocab_size: 49408
    transformer_width: 512
    transformer_heads: 8
    transformer_layers: 12
    embed_dim: 512

# Checkpoint information
checkpoint:
  path: "/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP.pth"
  format: "denseclip"  # vs "openai_clip"

# Processing configuration
preprocessing:
  image_mean: [0.48145466, 0.4578275, 0.40821073]
  image_std: [0.26862954, 0.26130258, 0.27577711]
  normalize: true

# Optional overrides
overrides:
  # Set to true to use OpenAI CLIP tokenizer instead of DenseCLIP's
  use_openai_tokenizer: false
  # Set custom context length (will resize positional embeddings if needed)
  custom_context_length: null
src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16_long_ctx.yaml DELETED
@@ -1,41 +0,0 @@
# DenseCLIP ViT-B/16 Configuration
# Configuration for loading DenseCLIP checkpoint as a CLIP-like model

model:
  name: "denseclip_vitb16"
  type: "vit"  # vision transformer

  # Vision encoder configuration
  vision:
    image_resolution: 640
    vision_layers: 12
    vision_width: 768
    vision_patch_size: 16
    embed_dim: 512

  # Text encoder configuration
  text:
    context_length: 77  # long-context variant (standard CLIP context length)
    vocab_size: 49408
    transformer_width: 512
    transformer_heads: 8
    transformer_layers: 12
    embed_dim: 512

# Checkpoint information
checkpoint:
  path: "/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP_long_ctx.pth"
  format: "denseclip"  # vs "openai_clip"

# Processing configuration
preprocessing:
  image_mean: [0.48145466, 0.4578275, 0.40821073]
  image_std: [0.26862954, 0.26130258, 0.27577711]
  normalize: true

# Optional overrides
overrides:
  # Set to true to use OpenAI CLIP tokenizer instead of DenseCLIP's
  use_openai_tokenizer: false
  # Set custom context length (will resize positional embeddings if needed)
  custom_context_length: null
src/denseclip/clip_loader/denseclip_loader.py DELETED
@@ -1,316 +0,0 @@
#!/usr/bin/env python3
"""
DenseCLIP to CLIP Loader

A simple interface for loading DenseCLIP checkpoints as CLIP-like models
for text and image encoding.
"""

import os
import sys
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Union, List, Tuple, Optional, Dict, Any
from PIL import Image
import torchvision.transforms as transforms

# Import local model components
try:
    from .models import CLIPVisionTransformer, CLIPTextEncoder, ResidualAttentionBlock, LayerNorm, QuickGELU
    from .tokenizer import tokenize
except ImportError:
    # Fallback for direct execution
    from models import CLIPVisionTransformer, CLIPTextEncoder, ResidualAttentionBlock, LayerNorm, QuickGELU
    from tokenizer import tokenize


class DenseCLIPModel(nn.Module):
    """
    A CLIP-like model loaded from DenseCLIP checkpoints.
    Provides simple text and image encoding functionality.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()

        self.config = config

        # Initialize vision encoder
        vision_config = config['model']['vision']
        self.visual = CLIPVisionTransformer(
            input_resolution=vision_config['image_resolution'],
            patch_size=vision_config['vision_patch_size'],
            width=vision_config['vision_width'],
            layers=vision_config['vision_layers'],
            heads=vision_config['vision_width'] // 64,
            output_dim=vision_config['embed_dim']
        )

        # Initialize text encoder
        text_config = config['model']['text']
        self.text_encoder = CLIPTextEncoder(
            context_length=text_config['context_length'],
            vocab_size=text_config['vocab_size'],
            transformer_width=text_config['transformer_width'],
            transformer_heads=text_config['transformer_heads'],
            transformer_layers=text_config['transformer_layers'],
            embed_dim=text_config['embed_dim']
        )

        # Store configuration for preprocessing
        self.context_length = text_config['context_length']
        self.image_resolution = vision_config['image_resolution']

        # Initialize preprocessing
        self._setup_preprocessing()

    def _setup_preprocessing(self):
        """Setup image preprocessing pipeline"""
        preprocess_config = self.config['preprocessing']

        self.preprocess = transforms.Compose([
            transforms.Resize(self.image_resolution, interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.CenterCrop(self.image_resolution),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=preprocess_config['image_mean'],
                std=preprocess_config['image_std']
            )
        ])

    def encode_image(self, images: Union[torch.Tensor, List[Image.Image], Image.Image]) -> torch.Tensor:
        """
        Encode images into feature vectors

        Args:
            images: PIL Images, list of PIL Images, or preprocessed tensor

        Returns:
            Normalized image features [batch_size, embed_dim]
        """
        if isinstance(images, (list, tuple)):
            # List of PIL Images
            image_tensors = torch.stack([self.preprocess(img) for img in images])
        elif isinstance(images, Image.Image):
            # Single PIL Image
            image_tensors = self.preprocess(images).unsqueeze(0)
        elif isinstance(images, torch.Tensor):
            # Already preprocessed tensor
            image_tensors = images
        else:
            raise ValueError(f"Unsupported image type: {type(images)}")

        # Move to same device as model
        device = next(self.parameters()).device
        image_tensors = image_tensors.to(device)

        # Encode
        with torch.no_grad():
            image_features = self.visual(image_tensors)
            image_features = F.normalize(image_features, dim=-1)

        return image_features

    def encode_text(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """
        Encode texts into feature vectors

        Args:
            texts: Single text string or list of text strings

        Returns:
            Normalized text features [batch_size, embed_dim]
        """
        if isinstance(texts, str):
            texts = [texts]

        # Tokenize if necessary
        if isinstance(texts, list):
            tokens = tokenize(texts, context_length=self.context_length)
        elif isinstance(texts, torch.Tensor):
            if texts.dim() == 1:
                # Single tokenized text
                tokens = texts.unsqueeze(0)
            else:
                tokens = texts
        else:
            raise ValueError(f"Unsupported text type: {type(texts)}")
        # Move to same device as model
        device = next(self.parameters()).device
        tokens = tokens.to(device)

        # Encode
        with torch.no_grad():
            text_features = self.text_encoder(tokens)
            text_features = F.normalize(text_features, dim=-1)

        return text_features

    def compute_similarity(self,
                           image_features: torch.Tensor,
                           text_features: torch.Tensor,
                           temperature: float = 1.0) -> torch.Tensor:
        """
        Compute similarity between image and text features

        Args:
            image_features: Normalized image features [N, embed_dim]
            text_features: Normalized text features [M, embed_dim]
            temperature: Temperature for scaling similarities

        Returns:
            Similarity matrix [N, M]
        """
        return (image_features @ text_features.t()) / temperature

    def forward(self, images: torch.Tensor, texts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass for both image and text encoding

        Args:
            images: Preprocessed image tensor [batch_size, 3, H, W]
            texts: Tokenized text tensor [batch_size, context_length]

        Returns:
            Tuple of (image_features, text_features)
        """
        image_features = self.visual(images)
        text_features = self.text_encoder(texts)

        # Normalize features
        image_features = F.normalize(image_features, dim=-1)
        text_features = F.normalize(text_features, dim=-1)

        return image_features, text_features


def load_config(config_path: str) -> Dict[str, Any]:
    """Load configuration from YAML file"""
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
    return config


def load_denseclip_weights(checkpoint_path: str) -> Dict[str, torch.Tensor]:
    """Load DenseCLIP checkpoint and extract relevant weights"""
    print(f"Loading DenseCLIP checkpoint from: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    if 'state_dict' not in checkpoint:
        raise ValueError("Checkpoint doesn't contain 'state_dict'")

    state_dict = checkpoint['state_dict']

    # Extract vision and text encoder weights
    vision_weights = {}
    text_weights = {}

    for key, value in state_dict.items():
        if key.startswith('backbone.'):
            # Remove 'backbone.' prefix for vision encoder
            new_key = key[len('backbone.'):]
            vision_weights[new_key] = value
        elif key.startswith('text_encoder.'):
            # Remove 'text_encoder.' prefix
            new_key = key[len('text_encoder.'):]
            text_weights[new_key] = value

    print(f"Extracted {len(vision_weights)} vision parameters")
    print(f"Extracted {len(text_weights)} text parameters")

    return {
        'vision': vision_weights,
        'text': text_weights,
        'full_state_dict': state_dict
    }


def load_denseclip_model(config_path: str,
                         checkpoint_path: Optional[str] = None,
                         device: str = 'auto') -> DenseCLIPModel:
    """
    Load a DenseCLIP model from configuration and checkpoint

    Args:
        config_path: Path to YAML configuration file
        checkpoint_path: Optional path to checkpoint (overrides config)
        device: Device to load model on ('auto', 'cpu', 'cuda')

    Returns:
        Loaded DenseCLIPModel ready for inference
    """
    # Load configuration
    config = load_config(config_path)

    # Override checkpoint path if provided
    if checkpoint_path is not None:
        config['checkpoint']['path'] = checkpoint_path

    # Create model
    model = DenseCLIPModel(config)

    # Load weights
    checkpoint_path = config['checkpoint']['path']
    if os.path.exists(checkpoint_path):
        weights = load_denseclip_weights(checkpoint_path)

        # Load vision encoder weights
        if weights['vision']:
            missing_v, unexpected_v = model.visual.load_state_dict(weights['vision'], strict=False)
            if missing_v:
                print(f"Missing vision keys: {len(missing_v)} (expected for FPN/post-norm components)")
            if unexpected_v:
                # Filter out expected mismatches
                important_unexpected = [k for k in unexpected_v if not any(x in k for x in ['fpn', 'ln_post', 'proj'])]
                if important_unexpected:
                    print(f"Unexpected vision keys: {important_unexpected}")
                else:
                    print(f"Vision weights loaded (ignoring {len(unexpected_v)} FPN/post-norm parameters)")

        # Load text encoder weights
        if weights['text']:
            missing_t, unexpected_t = model.text_encoder.load_state_dict(weights['text'], strict=False)
            if missing_t:
                print(f"Missing text keys: {len(missing_t)}")
            if unexpected_t:
                print(f"Unexpected text keys: {unexpected_t}")

        print("Model weights loaded successfully")
    else:
        print(f"Checkpoint not found at {checkpoint_path}, using random weights")

    # Setup device
    if device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model = model.to(device)
    model.eval()

    print(f"Model loaded on {device}")
    return model


# Convenience function
def load_clip(config_name: str = 'denseclip_vitb16',
              checkpoint_path: Optional[str] = None,
              device: str = 'auto') -> DenseCLIPModel:
    """
    Convenience function to load a DenseCLIP model

    Args:
        config_name: Name of config file (without .yaml extension)
        checkpoint_path: Optional path to checkpoint
        device: Device to load on

    Returns:
        Loaded DenseCLIPModel
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(current_dir, 'configs', f'{config_name}.yaml')

    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config file not found: {config_path}")

    return load_denseclip_model(config_path, checkpoint_path, device)
src/denseclip/clip_loader/example_usage.py DELETED
@@ -1,108 +0,0 @@
#!/usr/bin/env python3
"""
Example usage of the DenseCLIP to CLIP loader
"""

import sys
import os

# Add the clip_loader to path
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from denseclip_loader import load_clip
import torch


def main():
    print("DenseCLIP to CLIP Loader Example")
    print("=" * 50)

    # Load model
    print("Loading DenseCLIP model...")
    try:
        model = load_clip('denseclip_segmentation_vitb16')
        print("Model loaded successfully!")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    print(f"Model device: {next(model.parameters()).device}")
    print(f"Text context length: {model.context_length}")
    print(f"Image resolution: {model.image_resolution}")

    # Test text encoding
    print("\nTesting text encoding...")
    texts = [
        "a photo of a cat",
        "a photo of a dog",
        "a photo of a car",
        "a person walking",
        "a beautiful sunset"
    ]

    text_features = model.encode_text(texts)
    print(f"Text features shape: {text_features.shape}")
    print(f"Text features norm: {text_features.norm(dim=-1)}")  # Should be ~1.0 (normalized)

    # Test text-text similarities
    print("\nText-to-text similarities:")
    text_similarities = model.compute_similarity(text_features, text_features)

    print(f"{'Text':<20} {'Self-sim':<10} {'vs Cat':<10} {'vs Dog':<10}")
    print("-" * 50)
    for i, text in enumerate(texts):
        self_sim = text_similarities[i, i].item()
        cat_sim = text_similarities[i, 0].item()
        dog_sim = text_similarities[i, 1].item()
        print(f"{text:<20} {self_sim:<10.3f} {cat_sim:<10.3f} {dog_sim:<10.3f}")

    # Test zero-shot classification concepts
    print("\nZero-shot classification example:")
    test_queries = [
        "an animal",
        "a vehicle",
        "a person",
        "nature scene"
    ]

    query_features = model.encode_text(test_queries)
    classification_similarities = model.compute_similarity(text_features, query_features)

    print(f"{'Original Text':<20} {'Animal':<8} {'Vehicle':<8} {'Person':<8} {'Nature':<8}")
    print("-" * 60)
    for i, text in enumerate(texts):
        sims = classification_similarities[i]
        print(f"{text:<20} {sims[0]:<8.3f} {sims[1]:<8.3f} {sims[2]:<8.3f} {sims[3]:<8.3f}")

    # Test feature statistics
    print("\nFeature statistics:")
    print(f"Text feature mean: {text_features.mean():.6f}")
    print(f"Text feature std: {text_features.std():.6f}")
    print(f"Text feature min: {text_features.min():.6f}")
    print(f"Text feature max: {text_features.max():.6f}")

    # Test model components
    print("\nModel architecture:")
    print(f"Vision encoder: {type(model.visual).__name__}")
    print(f"Text encoder: {type(model.text_encoder).__name__}")

    # Count parameters
    vision_params = sum(p.numel() for p in model.visual.parameters())
    text_params = sum(p.numel() for p in model.text_encoder.parameters())
    total_params = vision_params + text_params

    print("\nParameter count:")
    print(f"Vision encoder: {vision_params:,}")
    print(f"Text encoder: {text_params:,}")
    print(f"Total: {total_params:,}")

    print("\nAll tests completed successfully!")
    print("\nUsage tip:")
    print("  from clip_loader import load_clip")
    print("  model = load_clip('denseclip_vitb16')")
    print("  features = model.encode_text(['your text here'])")


if __name__ == "__main__":
    main()