diff --git a/.gitignore b/.gitignore index 5f3574690accb596145c550530e1db3a5e928877..333c32e12863b6f282b1224d52d7fa3dc534653b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ *.pyc *.pyo *.pyd +venv/ +.gradio/ +.venv/ diff --git a/README.md b/README.md index c69f23e8582c7218714d4c7dafd1202517b58803..09e84fe072c3a02ec39ed8a62ed5b482a3dcb9bd 100644 --- a/README.md +++ b/README.md @@ -13,4 +13,8 @@ short_description: 'Repo for the Paper "One Patch to Caption Them All: ...' ArXiv: arxiv.org/abs/2510.02898 -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +Demo of the Patch-ioner framework, from the paper "One Patch to Caption Them All: A Unified Zero-shot Captioning Framework". + +The project page is at [paciosoft.com/Patch-ioner](https://paciosoft.com/Patch-ioner). + + diff --git a/app.py b/app.py index 1de0342ee6cdaaca331e7c8023837f2945bea90e..50dc665122a579c39f5e588d222dcbaf259d8a5d 100644 --- a/app.py +++ b/app.py @@ -24,8 +24,7 @@ from PIL import Image import numpy as np from typing import List, Dict -# Import the Patchioner model from the src directory -from src.model import Patchioner +from patchioner import Patchioner # Global variable to store the loaded model loaded_model = None @@ -33,7 +32,7 @@ model_config_path = None device = "cuda" if torch.cuda.is_available() else "cpu" # Default model configuration -DEFAULT_MODEL_CONFIG = "mlp.viecap.k.yaml" +DEFAULT_MODEL_CONFIG = "https://huggingface.co/Ruggero1912/Patch-ioner_talk2dino_decap_COCO_Captions" # Example images directory current_dir = os.path.dirname(__file__) @@ -50,13 +49,16 @@ def initialize_default_model() -> str: default_config_path = CONFIGS_DIR / DEFAULT_MODEL_CONFIG if not default_config_path.exists(): - return f"โŒ Default config file not found: {default_config_path}" + print( f"โŒ Default config file not found: {default_config_path}" ) + config = DEFAULT_MODEL_CONFIG # Assume it's a URL or model identifier + print( f"Attempting to load model from identifier: {config}" ) + + else: + config = default_config_path print(f"Loading default model: {DEFAULT_MODEL_CONFIG}") - # Load and parse the config - with open(default_config_path, 'r') as f: - config = yaml.safe_load(f) + # Load the model using the from_config class method model = Patchioner.from_config(config, device=device) @@ -553,7 +555,7 @@ def generate_bbox_caption(image_data, image) -> str: return error_msg -def create_gradio_interface(): +def create_gradio_interface(model_config_name : str): """Create and configure the Gradio interface.""" # Get example files @@ -593,7 +595,7 @@ def create_gradio_interface(): ) as demo: #gr.HTML(custom_js) # inject custom JS - gr.Markdown(""" + gr.Markdown(f""" # ๐ŸŽฏ Patchioner Trace Captioning Demo This demo allows you to: @@ -608,7 +610,7 @@ def create_gradio_interface(): 3. Use the appropriate tool to mark areas of interest in the image 4. Click "Generate Caption" to get AI-generated descriptions - **Model:** Using `mlp.karpathy.yaml` configuration (automatically loaded) + **Model:** Using `{model_config_name}` configuration (automatically loaded) """) # Initialize model status @@ -730,7 +732,7 @@ def create_gradio_interface(): outputs=[image_editor, image_annotator] ) - gr.Markdown(""" + gr.Markdown(f""" ### ๐Ÿ’ก Tips: - **Mode Selection**: Switch between trace and bounding box modes based on your needs - **Trace Mode**: Draw continuous lines over areas you want to describe @@ -741,7 +743,7 @@ def create_gradio_interface(): ### ๐Ÿ”ง Technical Details: - **Trace Mode**: Converts drawings to normalized (x, y) coordinates with timestamps - **BBox Mode**: Uses bounding box coordinates for region-specific captioning - - **Model Architecture**: Uses `mlp.karpathy.yaml` configuration with CLIP and ViT components + - **Model Architecture**: Uses `{model_config_name}` configuration with CLIP and ViT components - **Processing**: Each trace/bbox is processed separately to generate corresponding captions """) @@ -762,7 +764,7 @@ if __name__ == "__main__": print(f"Example images directory: {EXAMPLE_IMAGES_DIR}") print(f"Configs directory: {CONFIGS_DIR}") - demo = create_gradio_interface() + demo = create_gradio_interface(DEFAULT_MODEL_CONFIG) if not args.local: demo.launch() else: diff --git a/configs/mlp.k.yaml b/configs/mlp.k.yaml deleted file mode 100644 index 6c36e44b7d8f4771690793e733ecdc16a93c1891..0000000000000000000000000000000000000000 --- a/configs/mlp.k.yaml +++ /dev/null @@ -1,8 +0,0 @@ -decap_weights: 'weights/decap-talk2dino-coco_karpathy-009.pt' -prefix_size: 768 -linear_talk2dino: False -support_memory_size: 591753 -dino_model: 'dinov2_vitb14_reg' -normalize: True -kkv_attention: False -projection_type: '/raid/datasets/im2txtmemories/coco_train_karpathy.json' diff --git a/configs/mlp.viecap.k.yaml b/configs/mlp.viecap.k.yaml deleted file mode 100644 index 99deff334027cd1ddfe1337803c32bc2d87aac1e..0000000000000000000000000000000000000000 --- a/configs/mlp.viecap.k.yaml +++ /dev/null @@ -1,31 +0,0 @@ -decap_weights: null -prefix_size: 768 -linear_talk2dino: False -support_memory_size: 0 -dino_model: 'dinov2_vitb14_reg' -normalize: False -kkv_attention: False -use_talk2dino_project: False -clip_model_name: "ViT-B/16" - - -# nested config -viecap: - clip_hidden_size: 768 - suffix: ViT-B16_t2d_ - project_length: 10 - temperature: 0.01 - top_k: 3 - threshold: 0.4 - language_model: 'gpt2' - name_of_entities_text: coco_entities #vinvl_vgoi_entities - files_path: 'weights/viecap_files/' - prompt_ensemble: True - weight_path: 'weights/viecap-talk2dino-coco_karpathy-0014.pt' - using_hard_prompt: True - soft_prompt_first: True - only_hard_prompt: False - using_greedy_search: True #if false, use beam search - beam_width: 5 - text_prompt: None - diff --git a/requirements.txt b/requirements.txt index 2ef33d58a637c2fa4a4fbbada1cb323537c89b84..de1206e2e920773d14b08f3b3cf94adf8913f1e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,36 +1,3 @@ -# Core dependencies - absolutely required -torch -transformers==4.46.3 -gradio>=4.0.0 -gradio_image_annotation - -# Image processing - required for the demo -pillow -torchvision - -# Data handling - required -numpy -tqdm - -# CLIP - essential for the model -git+https://github.com/openai/CLIP.git - -# Model dependencies - needed for core functionality -timm - -# Hugging Face model hosting -huggingface_hub - -h5py - -# Optional: Only include if specifically needed -# h5py # Only needed for some data formats - can be installed conditionally -# scikit-learn # Only for evaluation - not needed for inference -# plotly # Only for plotting - not needed for basic demo -# pandas # Only for data analysis - not needed for basic demo -# matplotlib # Only for plotting - not needed for basic demo -# pycocotools # Only for COCO evaluation - not needed for basic demo -# nbformat # Only for notebooks - not needed for basic demo -# speaksee # Only for evaluation - not needed for basic demo -# munkres # Only for specific evaluation metrics - not needed for basic demo -# open_clip_torch # Only if using open_clip models - not needed for basic demo +git+https://github.com/Ruggero1912/Patch-ioner +gradio==5.48.0 +gradio_image_annotation \ No newline at end of file diff --git a/src/INViTE/clipfolder/__init__.py b/src/INViTE/clipfolder/__init__.py deleted file mode 100644 index dcc5619538c0f7c782508bdbd9587259d805e0d9..0000000000000000000000000000000000000000 --- a/src/INViTE/clipfolder/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .clip import * diff --git a/src/INViTE/clipfolder/bpe_simple_vocab_16e6.txt.gz b/src/INViTE/clipfolder/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/src/INViTE/clipfolder/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/src/INViTE/clipfolder/clip.py b/src/INViTE/clipfolder/clip.py deleted file mode 100644 index 918e8c6913ee6274dd375bcf682e1fabd276b572..0000000000000000000000000000000000000000 --- a/src/INViTE/clipfolder/clip.py +++ /dev/null @@ -1,238 +0,0 @@ -import hashlib -import os -import urllib -import warnings -from typing import Any, Union, List -from pkg_resources import packaging - -import torch -from PIL import Image -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from tqdm import tqdm - -from .model import build_model -from .simple_tokenizer import SimpleTokenizer as _Tokenizer - -try: - from torchvision.transforms import InterpolationMode - BICUBIC = InterpolationMode.BICUBIC -except ImportError: - BICUBIC = Image.BICUBIC - - -if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): - warnings.warn("PyTorch version 1.7.1 or higher is recommended") - - -__all__ = ["available_models", "load", "tokenize"] -_tokenizer = _Tokenizer() - -_MODELS = { - "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", - "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", - "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", - "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", - "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", - "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", - "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", -} - - -def _download(url: str, root: str): - os.makedirs(root, exist_ok=True) - filename = os.path.basename(url) - - expected_sha256 = url.split("/")[-2] - download_target = os.path.join(root, filename) - - if os.path.exists(download_target) and not os.path.isfile(download_target): - raise RuntimeError(f"{download_target} exists and is not a regular file") - - if os.path.isfile(download_target): - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: - return download_target - else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") - - with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: - raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") - - return download_target - - -def _convert_image_to_rgb(image): - return image.convert("RGB") - - -def _transform(n_px): - return Compose([ - Resize(n_px, interpolation=BICUBIC), - CenterCrop(n_px), - _convert_image_to_rgb, - ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - ]) - - -def available_models() -> List[str]: - """Returns the names of available CLIP models""" - return list(_MODELS.keys()) - - -def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, - download_root: str = None, extract_last_k_th_token: int = -1, viz: bool = False, image_resolution: int = None): - """Load a CLIP model - - Parameters - ---------- - name : str - A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict - - device : Union[str, torch.device] - The device to put the loaded model - - jit : bool - Whether to load the optimized JIT model or more hackable non-JIT model (default). - - download_root: str - path to download the model files; by default, it uses "~/.cache/clip" - - Returns - ------- - model : torch.nn.Module - The CLIP model - - preprocess : Callable[[PIL.Image], torch.Tensor] - A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input - """ - if name in _MODELS: - model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) - elif os.path.isfile(name): - model_path = name - else: - raise RuntimeError(f"Model {name} not found; available models = {available_models()}") - - with open(model_path, 'rb') as opened_file: - try: - # loading JIT archive - model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() - state_dict = None - except RuntimeError: - # loading saved state dict - if jit: - warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") - jit = False - state_dict = torch.load(opened_file, map_location="cpu") - - if not jit: - model = build_model(state_dict or model.state_dict(), extract_last_k_th_token, viz, image_resolution=image_resolution).to(device) - if str(device) == "cpu": - model.float() - return model, _transform(model.visual.input_resolution) - - # patch the device names - device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) - device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] - - def patch_device(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): - node.copyAttributes(device_node) - - model.apply(patch_device) - patch_device(model.encode_image) - patch_device(model.encode_text) - - # patch dtype to float32 on CPU - if str(device) == "cpu": - float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) - float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] - float_node = float_input.node() - - def patch_float(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("aten::to"): - inputs = list(node.inputs()) - for i in [1, 2]: # dtype can be the second or third argument to aten::to() - if inputs[i].node()["value"] == 5: - inputs[i].node().copyAttributes(float_node) - - model.apply(patch_float) - patch_float(model.encode_image) - patch_float(model.encode_text) - - model.float() - - return model, _transform(model.input_resolution.item()) - - -def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - - context_length : int - The context length to use; all CLIP models use 77 as the context length - - truncate: bool - Whether to truncate the text in case its encoding is longer than the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. - We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder["<|startoftext|>"] - eot_token = _tokenizer.encoder["<|endoftext|>"] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - else: - result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - if truncate: - tokens = tokens[:context_length] - tokens[-1] = eot_token - else: - raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") - result[i, :len(tokens)] = torch.tensor(tokens) - - return result diff --git a/src/INViTE/clipfolder/model.py b/src/INViTE/clipfolder/model.py deleted file mode 100644 index 51644225279acb87f30ec4322c13ec8d94e6104f..0000000000000000000000000000000000000000 --- a/src/INViTE/clipfolder/model.py +++ /dev/null @@ -1,515 +0,0 @@ -from collections import OrderedDict -from typing import Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.relu1 = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.relu2 = nn.ReLU(inplace=True) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.relu3 = nn.ReLU(inplace=True) - - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", nn.BatchNorm2d(planes * self.expansion)) - ])) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.relu1(self.bn1(self.conv1(x))) - out = self.relu2(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu3(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x[:1], key=x, value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False - ) - return x.squeeze(0) - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): - super().__init__() - self.output_dim = output_dim - self.input_resolution = input_resolution - - # the 3-layer stem - self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(width // 2) - self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(width // 2) - self.relu2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.relu3 = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(2) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def forward(self, x): - def stem(x): - x = self.relu1(self.bn1(self.conv1(x))) - x = self.relu2(self.bn2(self.conv2(x))) - x = self.relu3(self.bn3(self.conv3(x))) - x = self.avgpool(x) - return x - - x = x.type(self.conv1.weight.dtype) - x = stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - - return x - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -class QuickGELU(nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, viz: bool = False): - super().__init__() - - if viz: - self.attn = nn.MultiheadAttentionViz(d_model, n_head) - else: - self.attn = nn.MultiheadAttention(d_model, n_head) - - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)) - ])) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - """attn_mask โ€“ If specified, a 2D or 3D mask preventing attention to certain positions. - Must be of shape (L,S)(L, S)(L,S) or (Nโ‹…num_heads,L,S)(N\cdot\text{num\_heads}, L, S)(Nโ‹…num_heads,L,S), - where NNN is the batch size, LLL is the target sequence length, and SSS is the source sequence length. - A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry - in the batch. Binary, byte, and float masks are supported. For a binary mask, - a True value indicates that the corresponding position is not allowed to attend. - For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. - For a float mask, the mask values will be added to the attention weight.""" - - def attention(self, x: torch.Tensor): - self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] - - def forward(self, x: torch.Tensor): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - - -class Transformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, - extract_last_k_th_token: int=-1, viz: bool = False, num_tokens: int = 50): - super().__init__() - self.width = width - self.layers = layers - print('\n\n\n\n\ntransformer total layers', layers) - if extract_last_k_th_token>0: - start_mask_layer = layers - extract_last_k_th_token - - ans = [] - for cnt in range(layers): - if cnt < start_mask_layer: - ans.append(ResidualAttentionBlock(width, heads, attn_mask, viz)) - else: - print(' mask for layer {}'.format(cnt)) - mask = torch.empty(num_tokens, num_tokens) - mask.fill_(float("-inf")) - mask.fill_diagonal_(0) - ans.append(ResidualAttentionBlock(width, heads, mask.cuda(), viz)) - # TODO: here is hard coded 50 sequence length - # only attend to themselves - - self.resblocks = nn.Sequential(*ans) - else: - self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask, viz) for _ in range(layers)]) - - def forward(self, x: torch.Tensor): - return self.resblocks(x) - - -class VisionTransformer(nn.Module): - def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, - extract_last_k_th_token: int, viz: bool): - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) - self.ln_pre = LayerNorm(width) - - self.transformer = Transformer(width, layers, heads, extract_last_k_th_token=extract_last_k_th_token, viz=viz, num_tokens=(input_resolution // patch_size) ** 2 + 1) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def forward(self, x: torch.Tensor, get_all_last: bool): - # convert x to conv1 dtype - x = x.type(self.conv1.weight.dtype) - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - if get_all_last: - # take all tokens, x is of shape [*, grid ** 2 + 1, width] - # and we apply layer norm to each token separately - # x is of shape [*, grid ** 2 + 1, width] - x = torch.cat([self.ln_post(x[:, idx, :]).unsqueeze(1) for idx in range(x.size(1))], dim=1) - else: - # take the first token (CLS token), x is of shape [*, grid ** 2 + 1, width] - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - # the returned x is of shape [*, output_dim] where * is the batch size or if get_all_last is True, [*, grid ** 2 + 1, output_dim] - return x - - -class CLIP(nn.Module): - def __init__(self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int, - extract_last_k_th_token: int = -1, - viz: bool = False - ): - super().__init__() - - self.context_length = context_length - - if isinstance(vision_layers, (tuple, list)): - vision_heads = vision_width * 32 // 64 - self.visual = ModifiedResNet( - layers=vision_layers, - output_dim=embed_dim, - heads=vision_heads, - input_resolution=image_resolution, - width=vision_width - ) - else: - vision_heads = vision_width // 64 - self.visual = VisionTransformer( - input_resolution=image_resolution, - patch_size=vision_patch_size, - width=vision_width, - layers=vision_layers, - heads=vision_heads, - output_dim=embed_dim, - extract_last_k_th_token=extract_last_k_th_token, - viz=viz - ) - - self.transformer = Transformer( - width=transformer_width, - layers=transformer_layers, - heads=transformer_heads, - attn_mask=self.build_attention_mask() - ) - - self.vocab_size = vocab_size - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) - self.ln_final = LayerNorm(transformer_width) - - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - self.initialize_parameters() - - def initialize_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - if isinstance(self.visual, ModifiedResNet): - if self.visual.attnpool is not None: - std = self.visual.attnpool.c_proj.in_features ** -0.5 - nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - @property - def dtype(self): - return self.visual.conv1.weight.dtype - - def encode_image(self, image, get_all_last): - return self.visual(image.type(self.dtype), get_all_last) - - def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.type(self.dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - def forward(self, image, text, get_all_last=False): - image_features = self.encode_image(image, get_all_last) - text_features = self.encode_text(text) - - # normalized features - image_features = image_features / image_features.norm(dim=1, keepdim=True) - text_features = text_features / text_features.norm(dim=1, keepdim=True) - - # cosine similarity as logits - logit_scale = self.logit_scale.exp() - - if get_all_last: - return logit_scale * image_features, text_features - - - logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logits_per_image.t() - - # shape = [global_batch_size, global_batch_size] - return logits_per_image, logits_per_text - - -def convert_weights(model: nn.Module): - """Convert applicable model parameters to fp16""" - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - if isinstance(l, nn.MultiheadAttention): - for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.half() - - for name in ["text_projection", "proj"]: - if hasattr(l, name): - attr = getattr(l, name) - if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - -import torch -import torch.nn.functional as F - -def resize_pos_embed(old_pe: torch.Tensor, new_shape: int) -> torch.Tensor: - # old_pe: [old_num_patches + 1, C] - # new_shape: new_num_patches + 1 - cls_token = old_pe[:1] - patch_pe = old_pe[1:] - old_num = int(patch_pe.shape[0] ** 0.5) - new_num = int((new_shape - 1) ** 0.5) - - patch_pe = patch_pe.reshape(1, old_num, old_num, -1).permute(0, 3, 1, 2) # (1, C, H, W) - patch_pe = F.interpolate(patch_pe, size=(new_num, new_num), mode='bicubic', align_corners=False) - patch_pe = patch_pe.permute(0, 2, 3, 1).reshape(1, new_num * new_num, -1) - - return torch.cat([cls_token.unsqueeze(0), patch_pe], dim=1).squeeze(0) - -def build_model(state_dict: dict, extract_last_k_th_token, viz, image_resolution: int = None) -> CLIP: - vit = "visual.proj" in state_dict - - if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) - if image_resolution is None: - image_resolution = vision_patch_size * grid_size - else: - counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] - vision_layers = tuple(counts) - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) - vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] - if image_resolution is None: - image_resolution = output_width * 32 - - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) - - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, extract_last_k_th_token, viz - ) - - for key in ["input_resolution", "context_length", "vocab_size"]: - if key in state_dict: - del state_dict[key] - - convert_weights(model) - - pretrained_pe = state_dict['visual.positional_embedding'] - model_pe = model.visual.positional_embedding - - if vit and (pretrained_pe.shape != model_pe.shape): - print(f"Interpolating positional embedding from {pretrained_pe.shape} to {model_pe.shape}") - state_dict['visual.positional_embedding'] = resize_pos_embed(pretrained_pe, model_pe.shape[0]) - - model.load_state_dict(state_dict) - return model.eval() diff --git a/src/INViTE/clipfolder/simple_tokenizer.py b/src/INViTE/clipfolder/simple_tokenizer.py deleted file mode 100644 index 0a66286b7d5019c6e221932a813768038f839c91..0000000000000000000000000000000000000000 --- a/src/INViTE/clipfolder/simple_tokenizer.py +++ /dev/null @@ -1,132 +0,0 @@ -import gzip -import html -import os -from functools import lru_cache - -import ftfy -import regex as re - - -@lru_cache() -def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - return bpe_tokens - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') - return text diff --git a/src/INViTE/loader.py b/src/INViTE/loader.py deleted file mode 100644 index e4f2e7d85974dc7c8b771902eec2104ace51417d..0000000000000000000000000000000000000000 --- a/src/INViTE/loader.py +++ /dev/null @@ -1,72 +0,0 @@ -import torch -from typing import Union -from .clipfolder.clip import load as invite_clip_load, tokenize as invite_clip_tokenize - - -def load_invite_clip(config: dict, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu"): - """ - Load an INViTE CLIP model based on the provided configuration. - - This method loads an INViTE CLIP model similar to how RegionCLIP is loaded in the Patchioner class. - - Args: - config (dict): Configuration dictionary containing the following keys: - - name (str): Model name listed by `clip.available_models()`, or path to a model checkpoint - - jit (bool, optional): Whether to load the optimized JIT model. Defaults to False - - download_root (str, optional): Path to download model files. Defaults to '/raid/datasets/models_weights/INViTE' - - extract_last_k_th_token (int, optional): Extract last k-th token. Defaults to -1 - - viz (bool, optional): Visualization flag. Defaults to False - device (Union[str, torch.device], optional): Device to load the model on. - Defaults to "cuda" if available, else "cpu" - - Returns: - tuple: (model, preprocess_transform, tokenize_fn) - - model: The loaded INViTE CLIP model - - preprocess_transform: Torchvision transform for preprocessing images - - tokenize_fn: Tokenization function for text processing - - Raises: - KeyError: If required 'name' key is missing from config - RuntimeError: If model loading fails - - Example: - config = { - 'name': 'ViT-B/32', - 'jit': False, - 'download_root': '/raid/datasets/models_weights/INViTE', # optional, this is the default - 'extract_last_k_th_token': -1, - 'viz': False - } - model, preprocess, tokenize = load_invite_clip(config, device='cuda') - """ - - # Validate required parameters - if 'name' not in config: - raise KeyError("'name' key is required in config dictionary") - - # Extract parameters with defaults - name = config['name'] - jit = config.get('jit', False) - download_root = config.get('download_root', '/raid/datasets/models_weights/INViTE') - extract_last_k_th_token = config.get('extract_last_k_th_token', -1) - viz = config.get('viz', False) - - image_resolution = config.get('resolution', None) # Default resolution if not specified - - # Load the INViTE CLIP model using the clip.load function - try: - model, preprocess_transform = invite_clip_load( - name=name, - device=device, - jit=jit, - download_root=download_root, - extract_last_k_th_token=extract_last_k_th_token, - viz=viz, - image_resolution=image_resolution - ) - - # Return model, preprocess transform, and tokenize function - return model, preprocess_transform, invite_clip_tokenize - - except Exception as e: - raise RuntimeError(f"Failed to load INViTE CLIP model '{name}': {str(e)}") diff --git a/src/alphaclip/INSTALL.md b/src/alphaclip/INSTALL.md deleted file mode 100644 index 6397f2cb054a6d731e8a79bcac63d7913941f422..0000000000000000000000000000000000000000 --- a/src/alphaclip/INSTALL.md +++ /dev/null @@ -1,113 +0,0 @@ -# AlphaCLIP Standalone - Installation Guide - -## Quick Installation - -### Prerequisites -- Python 3.7 or higher -- pip package manager - -### Step 1: Install Dependencies - -```bash -cd alphaclip-standalone -pip install -r requirements.txt -``` - -### Step 2: Install the Package - -```bash -# Install in development mode (recommended for testing) -pip install -e . - -# OR install normally -pip install . -``` - -### Step 3: Test Installation - -```bash -python test_installation.py -``` - -### Step 4: Run Example - -```bash -python example.py -``` - -## Manual Dependency Installation - -If you encounter issues with the requirements.txt, install dependencies manually: - -```bash -# Core PyTorch (choose appropriate version for your system) -pip install torch torchvision torchaudio - -# Text processing -pip install ftfy regex tqdm - -# LoRA support -pip install loralib - -# Image processing -pip install Pillow - -# Utilities -pip install numpy packaging -``` - -## GPU Support - -For CUDA support, make sure you install PyTorch with CUDA: - -```bash -# For CUDA 11.8 -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 - -# For CUDA 12.1 -pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 - -# Check your CUDA version with: nvidia-smi -``` - -## Verification - -After installation, verify everything works: - -```python -from alphaclip_loader import AlphaCLIPLoader - -# This should work without errors -loader = AlphaCLIPLoader() -models = loader.available_models() -print("Available models:", models) -``` - -## Troubleshooting - -### Common Issues - -1. **ImportError: No module named 'loralib'** - ```bash - pip install loralib - ``` - -2. **CUDA out of memory** - - Use CPU: `AlphaCLIPLoader(default_device="cpu")` - - Or use a smaller model like "ViT-B/32" - -3. **Model download fails** - - Check internet connection - - Ensure you have enough disk space (~1GB per model) - - Models are cached in `~/.cache/clip/` - -4. **Permission errors** - - Use `--user` flag: `pip install --user -e .` - -### Getting Help - -If you encounter issues: -1. Check that all dependencies are properly installed -2. Run the test script: `python test_installation.py` -3. Check CUDA compatibility if using GPU -4. Ensure Python version is 3.7+ diff --git a/src/alphaclip/LICENSE b/src/alphaclip/LICENSE deleted file mode 100644 index 284803df68c514bdd81477e9248b9a0b4e769533..0000000000000000000000000000000000000000 --- a/src/alphaclip/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [Zeyi Sun] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file diff --git a/src/alphaclip/MANIFEST.in b/src/alphaclip/MANIFEST.in deleted file mode 100644 index 8df77511f19983ffb187b54738b3b2b316249a10..0000000000000000000000000000000000000000 --- a/src/alphaclip/MANIFEST.in +++ /dev/null @@ -1,7 +0,0 @@ -include README.md -include requirements.txt -include LICENSE -recursive-include alpha_clip *.py -recursive-include alpha_clip *.gz -include example.py -include test_installation.py diff --git a/src/alphaclip/README.md b/src/alphaclip/README.md deleted file mode 100644 index c1bd6c355bfa2bc6e0218c2509d3f57602322db0..0000000000000000000000000000000000000000 --- a/src/alphaclip/README.md +++ /dev/null @@ -1,266 +0,0 @@ -# AlphaCLIP Standalone - -A standalone, easy-to-use version of AlphaCLIP that can be integrated into any project without complex dependencies or setup. - -## Overview - -AlphaCLIP is an enhanced version of OpenAI's CLIP model that provides improved vision-language understanding capabilities. This standalone package makes it easy to use AlphaCLIP in your projects with minimal setup. - -## Features - -- **Easy Installation**: Simple pip install with minimal dependencies -- **Clean API**: Intuitive interface for loading models and processing data -- **Device Flexibility**: Automatic CUDA/CPU detection with manual override options -- **Model Variety**: Support for multiple AlphaCLIP model variants -- **Preprocessing Included**: Built-in image preprocessing and text tokenization - -## Installation - -### Requirements - -- Python 3.7 or higher -- PyTorch 1.7.1 or higher -- CUDA (optional, for GPU acceleration) - -### Install from source - -```bash -# Clone or download this standalone package -cd alphaclip-standalone - -# Install dependencies -pip install -r requirements.txt - -# Install the package -pip install -e . -``` - -### Core Dependencies - -The package requires the following core dependencies: - -``` -torch>=1.7.1 -torchvision -ftfy -regex -tqdm -loralib -Pillow -numpy -packaging -``` - -## Quick Start - -### Basic Usage - -```python -from alphaclip_loader import AlphaCLIPLoader - -# Initialize the loader -loader = AlphaCLIPLoader() - -# Load a model (this will download the model if not cached) -model, preprocess = loader.load_model("ViT-B/16") - -# Tokenize text -text_tokens = loader.tokenize("A photo of a cat") - -# Get text embeddings -text_features = loader.encode_text(model, "A photo of a cat") - -print(f"Text features shape: {text_features.shape}") -``` - -### Advanced Usage - -```python -import torch -from PIL import Image -from alphaclip_loader import AlphaCLIPLoader - -# Initialize with specific device -loader = AlphaCLIPLoader(default_device="cuda") - -# Load model with custom options -model, preprocess = loader.load_model( - "ViT-B/16", - device="cuda", - lora_adapt=False, - rank=16 -) - -# Process an image -image = Image.open("your_image.jpg") -image_tensor = preprocess(image).unsqueeze(0) # Add batch dimension - -# Get embeddings -with torch.no_grad(): - image_features = loader.encode_image(model, image_tensor) - text_features = loader.encode_text(model, ["A photo of a cat", "A dog playing"]) - -# Compute similarities -similarities = loader.get_similarity(text_features, image_features) -print(f"Similarities: {similarities}") -``` - -### One-line Model Loading - -```python -from alphaclip_loader import load_alphaclip - -# Quick loading function -loader, model, preprocess = load_alphaclip("ViT-B/16", device="cuda") -``` - -## Available Models - -You can check available models using: - -```python -from alphaclip_loader import AlphaCLIPLoader - -loader = AlphaCLIPLoader() -models = loader.available_models() -print("Available models:", models) -``` - -Typically includes: -- `ViT-B/32` -- `ViT-B/16` -- `ViT-L/14` -- `ViT-L/14@336px` -- `RN50`, `RN101`, `RN50x4`, `RN50x16`, `RN50x64` - -## API Reference - -### AlphaCLIPLoader Class - -#### Methods - -- **`__init__(default_device=None)`**: Initialize loader with optional default device -- **`available_models()`**: Get list of available model names -- **`load_model(name, **kwargs)`**: Load a model with preprocessing function -- **`tokenize(texts, context_length=77, truncate=True)`**: Tokenize text input -- **`encode_text(model, texts)`**: Encode text to embeddings -- **`encode_image(model, images)`**: Encode images to embeddings -- **`get_similarity(text_features, image_features)`**: Compute cosine similarity - -#### load_model Parameters - -- `name`: Model name or checkpoint path -- `alpha_vision_ckpt_pth`: Additional vision checkpoint path (default: "None") -- `device`: Device to load on (default: auto-detect) -- `jit`: Use JIT compilation (default: False) -- `download_root`: Model download directory (default: ~/.cache/clip) -- `lora_adapt`: Use LoRA adaptation (default: False) -- `rank`: LoRA rank if enabled (default: 16) - -## Example Use Cases - -### Image-Text Similarity - -```python -from alphaclip_loader import load_alphaclip -from PIL import Image -import torch - -loader, model, preprocess = load_alphaclip() - -# Load and preprocess image -image = Image.open("cat.jpg") -image_input = preprocess(image).unsqueeze(0) - -# Define candidate texts -texts = ["a cat", "a dog", "a bird", "a car"] - -# Get features -image_features = loader.encode_image(model, image_input) -text_features = loader.encode_text(model, texts) - -# Calculate similarities -similarities = loader.get_similarity(text_features, image_features) - -# Find best match -best_match_idx = similarities.argmax() -print(f"Best match: {texts[best_match_idx]} (score: {similarities[best_match_idx]:.3f})") -``` - -### Batch Processing - -```python -from alphaclip_loader import AlphaCLIPLoader -import torch - -loader = AlphaCLIPLoader() -model, preprocess = loader.load_model("ViT-B/16") - -# Process multiple texts at once -texts = [ - "A red apple on a table", - "A dog running in the park", - "A beautiful sunset" -] - -# Batch tokenization and encoding -text_features = loader.encode_text(model, texts) -print(f"Batch text features shape: {text_features.shape}") # [3, 512] -``` - -## Performance Tips - -1. **GPU Usage**: Use CUDA for better performance with larger models -2. **Batch Processing**: Process multiple texts/images together when possible -3. **Model Caching**: Models are automatically cached after first download -4. **Memory Management**: Use `torch.no_grad()` during inference to save memory - -## Troubleshooting - -### Common Issues - -1. **CUDA Out of Memory**: Reduce batch size or use CPU -2. **Model Download Fails**: Check internet connection and disk space -3. **Import Errors**: Ensure all dependencies are installed - -### Dependencies Issues - -If you encounter import errors, try: - -```bash -pip install --upgrade torch torchvision -pip install ftfy regex tqdm loralib -``` - -## File Structure - -``` -alphaclip-standalone/ -โ”œโ”€โ”€ __init__.py # Package initialization -โ”œโ”€โ”€ alphaclip_loader.py # Main loader class -โ”œโ”€โ”€ requirements.txt # Dependencies -โ”œโ”€โ”€ setup.py # Package setup -โ”œโ”€โ”€ README.md # This file -โ””โ”€โ”€ alpha_clip/ # Core AlphaCLIP modules - โ”œโ”€โ”€ __init__.py - โ”œโ”€โ”€ alpha_clip.py # Main AlphaCLIP functions - โ”œโ”€โ”€ model.py # Model architectures - โ”œโ”€โ”€ simple_tokenizer.py # Text tokenization - โ””โ”€โ”€ bpe_simple_vocab_16e6.txt.gz # Tokenizer vocabulary -``` - -## License - -This standalone package maintains the same license as the original AlphaCLIP project. - -## Contributing - -This is a standalone distribution. For contributions to the core AlphaCLIP model, please refer to the main AlphaCLIP repository. - -## Changelog - -### Version 1.0.0 -- Initial standalone release -- Clean API with AlphaCLIPLoader class -- Comprehensive documentation and examples -- Easy installation and setup diff --git a/src/alphaclip/__init__.py b/src/alphaclip/__init__.py deleted file mode 100644 index cdece856b7aebbae39b2ada408c81de6954d74c2..0000000000000000000000000000000000000000 --- a/src/alphaclip/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -AlphaCLIP Standalone Package - -A standalone version of AlphaCLIP that can be used independently. -""" - -from .alphaclip_loader import AlphaCLIPLoader, load_alphaclip - -# Version info -__version__ = "1.0.0" -__author__ = "AlphaCLIP Team" - -# Make main classes available at package level -__all__ = ['AlphaCLIPLoader', 'load_alphaclip'] diff --git a/src/alphaclip/alpha_clip/__init__.py b/src/alphaclip/alpha_clip/__init__.py deleted file mode 100644 index 3d5b643bb3da8fde1fcadedb6919a36fb544cf97..0000000000000000000000000000000000000000 --- a/src/alphaclip/alpha_clip/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .alpha_clip import * diff --git a/src/alphaclip/alpha_clip/alpha_clip.py b/src/alphaclip/alpha_clip/alpha_clip.py deleted file mode 100644 index 57e7caefd7e925b4c3f0bd7096d101a8a32417aa..0000000000000000000000000000000000000000 --- a/src/alphaclip/alpha_clip/alpha_clip.py +++ /dev/null @@ -1,254 +0,0 @@ -import hashlib -import os -import urllib -import warnings -from typing import Any, Union, List -from pkg_resources import packaging - -import torch -from PIL import Image -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from tqdm import tqdm - -from .model import build_model -from .simple_tokenizer import SimpleTokenizer as _Tokenizer - -try: - from torchvision.transforms import InterpolationMode - BICUBIC = InterpolationMode.BICUBIC -except ImportError: - BICUBIC = Image.BICUBIC - - -if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): - warnings.warn("PyTorch version 1.7.1 or higher is recommended") - - -__all__ = ["available_models", "load", "tokenize"] -_tokenizer = _Tokenizer() - -_MODELS = { - "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", - "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", - "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", - "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", - "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", - "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", - "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", -} - - -def _download(url: str, root: str): - os.makedirs(root, exist_ok=True) - filename = os.path.basename(url) - - expected_sha256 = url.split("/")[-2] - download_target = os.path.join(root, filename) - - if os.path.exists(download_target) and not os.path.isfile(download_target): - raise RuntimeError(f"{download_target} exists and is not a regular file") - - if os.path.isfile(download_target): - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: - return download_target - else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") - - with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: - raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") - - return download_target - - -def _convert_image_to_rgb(image): - return image.convert("RGB") - - -def _transform(n_px): - return Compose([ - Resize(n_px, interpolation=BICUBIC), - CenterCrop(n_px), - _convert_image_to_rgb, - ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - ]) - - -def available_models() -> List[str]: - """Returns the names of available CLIP models""" - return list(_MODELS.keys()) - - -def load(name: str, alpha_vision_ckpt_pth="None", device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, lora_adapt=False, rank=16): - """Load a CLIP model - - Parameters - ---------- - name : str - A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict - - alpha_vision_ckpt_pth: str - only changed when inferencing model instead of training - - device : Union[str, torch.device] - The device to put the loaded model - - jit : bool - Whether to load the optimized JIT model or more hackable non-JIT model (default). - - download_root: str - path to download the model files; by default, it uses "~/.cache/clip" - - Returns - ------- - model : torch.nn.Module - The CLIP model - - preprocess : Callable[[PIL.Image], torch.Tensor] - A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input - """ - if name in _MODELS: - model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) - elif os.path.isfile(name): - model_path = name - else: - raise RuntimeError(f"Model {name} not found; available models = {available_models()}") - - with open(model_path, 'rb') as opened_file: - try: - # loading JIT archive - model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() - state_dict = None - except RuntimeError: - # loading saved state dict - if jit: - warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") - jit = False - state_dict = torch.load(opened_file, map_location="cpu") - - if not jit: - model = build_model(state_dict or model.state_dict(), lora_adapt=lora_adapt, rank=rank).to(device) - if str(device) == "cpu": - model.float() - # If a separate checkpoint is provided for the visual encoder (e.g., CLIP), load it - if alpha_vision_ckpt_pth != "None": - # Load the visual encoder weights from the given checkpoint path - model.visual.load_state_dict(torch.load(alpha_vision_ckpt_pth)) - # Set the model to evaluation mode - # Note: If LoRA is used, it may merge LoRA weights into the base model here for inference - model.eval() # merge lora params if exists (for inference only) - return model, _transform(model.visual.input_resolution) - - # patch the device names - device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) - device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] - - def _node_get(node: torch._C.Node, key: str): - """Gets attributes of a node which is polymorphic over return type. - - From https://github.com/pytorch/pytorch/pull/82628 - """ - sel = node.kindOf(key) - return getattr(node, sel)(key) - - def patch_device(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"): - node.copyAttributes(device_node) - - model.apply(patch_device) - patch_device(model.encode_image) - patch_device(model.encode_text) - - # patch dtype to float32 on CPU - if str(device) == "cpu": - float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) - float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] - float_node = float_input.node() - - def patch_float(module): - try: - graphs = [module.graph] if hasattr(module, "graph") else [] - except RuntimeError: - graphs = [] - - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("aten::to"): - inputs = list(node.inputs()) - for i in [1, 2]: # dtype can be the second or third argument to aten::to() - if _node_get(inputs[i].node(), "value") == 5: - inputs[i].node().copyAttributes(float_node) - - model.apply(patch_float) - patch_float(model.encode_image) - patch_float(model.encode_text) - - model.float() - return model, _transform(model.input_resolution.item()) - - -def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = True) -> Union[torch.IntTensor, torch.LongTensor]: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - - context_length : int - The context length to use; all CLIP models use 77 as the context length - - truncate: bool - Whether to truncate the text in case its encoding is longer than the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. - We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder["<|startoftext|>"] - eot_token = _tokenizer.encoder["<|endoftext|>"] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - else: - result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - if truncate: - tokens = tokens[:context_length] - tokens[-1] = eot_token - else: - raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") - result[i, :len(tokens)] = torch.tensor(tokens) - - return result diff --git a/src/alphaclip/alpha_clip/bpe_simple_vocab_16e6.txt.gz b/src/alphaclip/alpha_clip/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/src/alphaclip/alpha_clip/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/src/alphaclip/alpha_clip/model.py b/src/alphaclip/alpha_clip/model.py deleted file mode 100644 index dfcebd59b1ecd553c58142f09fd058a7e8702339..0000000000000000000000000000000000000000 --- a/src/alphaclip/alpha_clip/model.py +++ /dev/null @@ -1,609 +0,0 @@ -from collections import OrderedDict -from typing import Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -import loralib as lora -import math -import collections - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.relu1 = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.relu2 = nn.ReLU(inplace=True) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.relu3 = nn.ReLU(inplace=True) - - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", nn.BatchNorm2d(planes * self.expansion)) - ])) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.relu1(self.bn1(self.conv1(x))) - out = self.relu2(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu3(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x[:1], key=x, value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False - ) - return x.squeeze(0) - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): - super().__init__() - self.output_dim = output_dim - self.input_resolution = input_resolution - - # the 3-layer stem - self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) - self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width // 2, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(width // 2) - self.relu1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(width // 2) - self.relu2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.relu3 = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(2) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def forward(self, x, alpha=None): - def stem(x): - x = self.relu1(self.bn1(self.conv1(x) + self.conv1_alpha(alpha))) - x = self.relu2(self.bn2(self.conv2(x))) - x = self.relu3(self.bn3(self.conv3(x))) - x = self.avgpool(x) - return x - - x = x.type(self.conv1.weight.dtype) - x = stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - - return x - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -class QuickGELU(nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - -class Attention(nn.Module): - def __init__( - self, - dim, - num_heads=8, - qkv_bias=True, - scaled_cosine=False, - scale_heads=False, - logit_scale_max=math.log(1. / 0.01), - attn_drop=0., - proj_drop=0., - lora_adapt=False, - rank=16 - ): - super().__init__() - self.scaled_cosine = scaled_cosine - self.scale_heads = scale_heads - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.logit_scale_max = logit_scale_max - - # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original - if lora_adapt: - print("!!!!!!!!!!using lora for qkv projection!!!!!!!!!!") - self.in_proj = lora.MergedLinear(dim, 3*dim, r=rank, enable_lora=[True, False, True]) - else: - self.in_proj = nn.Linear(dim, dim * 3) - # self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) - # if qkv_bias: - # self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) - # else: - # self.in_proj_bias = None - - if self.scaled_cosine: - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) - else: - self.logit_scale = None - self.attn_drop = nn.Dropout(attn_drop) - if self.scale_heads: - self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) - else: - self.head_scale = None - self.out_proj = nn.Linear(dim, dim) if not lora_adapt else lora.Linear(dim, dim, r=rank) - self.out_drop = nn.Dropout(proj_drop) - - def forward(self, x, attn_mask = None): - L, N, C = x.shape - q, k, v = self.in_proj(x).chunk(3, dim=-1) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - - if self.logit_scale is not None: - attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) - logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() - attn = attn.view(N, self.num_heads, L, L) * logit_scale - attn = attn.view(-1, L, L) - else: - q = q * self.scale - attn = torch.bmm(q, k.transpose(-2, -1)) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, float("-inf")) - attn_mask = new_attn_mask - attn += attn_mask - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = torch.bmm(attn, v) - if self.head_scale is not None: - x = x.view(N, self.num_heads, L, C) * self.head_scale - x = x.view(-1, L, C) - x = x.transpose(0, 1).reshape(L, N, C) - x = self.out_proj(x) - x = self.out_drop(x) - return x, attn - - -class CustomResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16): - super().__init__() - - self.attn = Attention(d_model, n_head, lora_adapt=lora_adapt, rank=rank) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, d_model * 4) if not lora_adapt else lora.Linear(d_model, d_model*4, r=rank)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model) if not lora_adapt else lora.Linear(d_model*4, d_model, r=rank)) - ])) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None - return self.attn(x, attn_mask=self.attn_mask) - - def forward(self, x: torch.Tensor, return_attn=False): - attn_out, attn = self.attention(self.ln_1(x)) - x = x + attn_out - x = x + self.mlp(self.ln_2(x)) - if return_attn: - return x, attn - else: - return x - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)) - ])) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None - return self.attn(x, x, x, attn_mask=self.attn_mask)[0] - - def forward(self, x: torch.Tensor): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - -class Transformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) - - def forward(self, x: torch.Tensor): - return self.resblocks(x) - -class CustomTransformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None, lora_adapt=False, rank=16): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.Sequential(*[CustomResidualAttentionBlock(width, heads, attn_mask, lora_adapt=lora_adapt, rank=rank) for _ in range(layers)]) - - def forward(self, x: torch.Tensor, return_attn=False): - if return_attn: - for i, block in enumerate(self.resblocks): - if i == len(self.resblocks) - 1: - return block(x, return_attn=True) - else: - x = block(x) - assert False - return self.resblocks(x) - -class VisionTransformer(nn.Module): - def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, lora_adapt=False, rank=16): - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - self.conv1_alpha = nn.Conv2d(in_channels=1, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) - self.ln_pre = LayerNorm(width) - - self.transformer = CustomTransformer(width, layers, heads, lora_adapt=lora_adapt, rank=rank) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def forward(self, x: torch.Tensor, alpha=None, return_attn=False, return_patches=False): - # if x dtype is different from conv1, convert it - if x.dtype != self.conv1.weight.dtype: - x = x.type(self.conv1.weight.dtype) - - if alpha.dtype != self.conv1_alpha.weight.dtype: - alpha = alpha.type(self.conv1_alpha.weight.dtype) - - x = self.conv1(x) # shape = [*, width, grid, grid] - # ASSUME alpha is always not None! - x = x + self.conv1_alpha(alpha) - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - if return_attn: - x, attn_last = self.transformer(x, return_attn=True) - else: - x = self.transformer(x, return_attn=False) - x = x.permute(1, 0, 2) # LND -> NLD - - if not return_patches: - x = self.ln_post(x[:, 0, :]) - else: - x = self.ln_post(x) - - if self.proj is not None: - x = x @ self.proj - if return_attn: - return x, attn_last - else: - return x - - -class CLIP(nn.Module): - def __init__(self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int, - lora_adapt = False, - rank = 16, - ): - super().__init__() - - self.context_length = context_length - - if isinstance(vision_layers, (tuple, list)): - vision_heads = vision_width * 32 // 64 - self.visual = ModifiedResNet( - layers=vision_layers, - output_dim=embed_dim, - heads=vision_heads, - input_resolution=image_resolution, - width=vision_width - ) - else: - vision_heads = vision_width // 64 - self.visual = VisionTransformer( - input_resolution=image_resolution, - patch_size=vision_patch_size, - width=vision_width, - layers=vision_layers, - heads=vision_heads, - output_dim=embed_dim, - lora_adapt=lora_adapt, - rank=rank - ) - - self.transformer = Transformer( - width=transformer_width, - layers=transformer_layers, - heads=transformer_heads, - attn_mask=self.build_attention_mask() - ) - - self.vocab_size = vocab_size - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) - self.ln_final = LayerNorm(transformer_width) - - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - self.initialize_parameters() - - def initialize_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - if isinstance(self.visual, ModifiedResNet): - if self.visual.attnpool is not None: - std = self.visual.attnpool.c_proj.in_features ** -0.5 - nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - @property - def dtype(self): - if not hasattr(self.visual, "conv1"): - return self.visual.module.conv1.weight.dtype - return self.visual.conv1.weight.dtype - - def encode_image(self, image, alpha): - assert alpha is not None - return self.visual(image.type(self.dtype), alpha.type(self.dtype)) - - def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.type(self.dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - def forward(self, image, text, alpha): - - image_features = self.encode_image(image, alpha) - text_features = self.encode_text(text) - - # normalized features - image_features = image_features / image_features.norm(dim=1, keepdim=True) - text_features = text_features / text_features.norm(dim=1, keepdim=True) - - # cosine similarity as logits - logit_scale = self.logit_scale.exp() - logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logits_per_image.t() - - # shape = [global_batch_size, global_batch_size] - return logits_per_image, logits_per_text - - -def convert_weights(model: nn.Module): - """Convert applicable model parameters to fp16""" - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - if isinstance(l, nn.MultiheadAttention): - for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.half() - - for name in ["text_projection", "proj"]: - if hasattr(l, name): - attr = getattr(l, name) - if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - - -def build_model(state_dict: dict, lora_adapt=False, rank=16): - vit = "visual.proj" in state_dict - - if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) - image_resolution = vision_patch_size * grid_size - else: - counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] - vision_layers = tuple(counts) - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) - vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] - image_resolution = output_width * 32 - - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) - - # always load lora version - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, - lora_adapt=lora_adapt, rank=rank, - ) - - for key in ["input_resolution", "context_length", "vocab_size"]: - if key in state_dict: - del state_dict[key] - # para_wb to linear - new_state_dict = collections.OrderedDict() - for k, v in state_dict.items(): - if 'visual' in k: - if 'in_proj_weight' in k: - new_state_dict[k.replace('in_proj_weight', 'in_proj.weight')] = v - elif 'in_proj_bias' in k: - new_state_dict[k.replace('in_proj_bias', 'in_proj.bias')] = v - else: - new_state_dict[k] = v - else: - new_state_dict[k] = v - - state_dict = new_state_dict - # add rgba_conv_weight - if 'visual.conv1_alpha.weight' not in state_dict.keys(): # zero initialization on alpha channel - rgb_weight = state_dict['visual.conv1.weight'].clone().detach() - rgba_weigth = torch.zeros_like(rgb_weight)[:, 0:1, :, :] - state_dict['visual.conv1_alpha.weight'] = rgba_weigth - convert_weights(model) - model.load_state_dict(state_dict, strict=False) - return model.eval() diff --git a/src/alphaclip/alpha_clip/simple_tokenizer.py b/src/alphaclip/alpha_clip/simple_tokenizer.py deleted file mode 100644 index 0a66286b7d5019c6e221932a813768038f839c91..0000000000000000000000000000000000000000 --- a/src/alphaclip/alpha_clip/simple_tokenizer.py +++ /dev/null @@ -1,132 +0,0 @@ -import gzip -import html -import os -from functools import lru_cache - -import ftfy -import regex as re - - -@lru_cache() -def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - return bpe_tokens - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') - return text diff --git a/src/alphaclip/alpha_mask_utils.py b/src/alphaclip/alpha_mask_utils.py deleted file mode 100644 index 4665743b0fbaeb1e75eb6104e1270e7d7807652d..0000000000000000000000000000000000000000 --- a/src/alphaclip/alpha_mask_utils.py +++ /dev/null @@ -1,111 +0,0 @@ -""" -Utility functions for converting bboxes and traces to alpha masks for AlphaClip. -""" - -import torch -import math - - -def bbox_to_alpha_mask(bbox, grid_size, patch_size, crop_dim): - """ - Convert a single bounding box to an alpha mask for AlphaClip. - - Args: - bbox: [x_min, y_min, w, h] format in original coordinates - grid_size: Number of patches per side (e.g., 37 for 518/14) - patch_size: Size of each patch in pixels - crop_dim: Size of the cropped image - - Returns: - alpha_mask: Binary mask of shape (grid_size, grid_size) - """ - alpha_mask = torch.zeros((grid_size, grid_size)) - - # Convert bbox to patch coordinates - x_min, y_min, w, h = bbox - x_max = x_min + w - y_max = y_min + h - - # Scale to patch grid coordinates - x1_patch = int(x_min // patch_size) - y1_patch = int(y_min // patch_size) - x2_patch = int(x_max // patch_size) - y2_patch = int(y_max // patch_size) - - # Clamp to grid bounds - x1_patch = max(0, min(x1_patch, grid_size - 1)) - y1_patch = max(0, min(y1_patch, grid_size - 1)) - x2_patch = max(0, min(x2_patch, grid_size)) # Allow up to grid_size for exclusive end - y2_patch = max(0, min(y2_patch, grid_size)) - - # Set the region to 1 (using slice notation for proper indexing) - if x2_patch > x1_patch and y2_patch > y1_patch: - alpha_mask[y1_patch:y2_patch, x1_patch:x2_patch] = 1.0 - - return alpha_mask - - -def bboxes_to_alpha_mask(bboxes, grid_size, patch_size, crop_dim): - """ - Convert multiple bboxes to a single OR-ed alpha mask. - - Args: - bboxes: Tensor of bboxes in [x_min, y_min, w, h] format, shape [n_boxes, 4] - grid_size: Number of patches per side - patch_size: Size of each patch in pixels - crop_dim: Size of the cropped image - - Returns: - alpha_mask: Binary mask of shape (grid_size, grid_size) - """ - alpha_mask = torch.zeros((grid_size, grid_size)) - - for bbox in bboxes: - # Skip dummy boxes (negative values) - if bbox.sum().item() < 0: - continue - - bbox_mask = bbox_to_alpha_mask(bbox, grid_size, patch_size, crop_dim) - alpha_mask = torch.logical_or(alpha_mask, bbox_mask).float() - - return alpha_mask - - -def trace_to_alpha_mask(trace, grid_size): - """ - Convert a trace to an alpha mask using the existing map_traces_to_grid function. - - Args: - trace: List of trace points with 'x' and 'y' coordinates (normalized 0-1) - grid_size: Number of patches per side - - Returns: - alpha_mask: Binary mask of shape (grid_size, grid_size) - """ - from src.bbox_utils import map_traces_to_grid - - alpha_mask = map_traces_to_grid(trace, grid_size) - # Convert to binary (any value > 0 becomes 1) - alpha_mask = (alpha_mask > 0).float() - - return alpha_mask - - -def traces_to_alpha_mask(traces, grid_size): - """ - Convert multiple traces to a single OR-ed alpha mask. - - Args: - traces: List of traces - grid_size: Number of patches per side - - Returns: - alpha_mask: Binary mask of shape (grid_size, grid_size) - """ - alpha_mask = torch.zeros((grid_size, grid_size)) - - for trace in traces: - trace_mask = trace_to_alpha_mask(trace, grid_size) - alpha_mask = torch.logical_or(alpha_mask, trace_mask).float() - - return alpha_mask diff --git a/src/alphaclip/alphaclip_loader.py b/src/alphaclip/alphaclip_loader.py deleted file mode 100644 index f0c10e8abb8f116505d4fd0df2351b5f42f44a97..0000000000000000000000000000000000000000 --- a/src/alphaclip/alphaclip_loader.py +++ /dev/null @@ -1,233 +0,0 @@ -""" -AlphaCLIP Standalone Loader - -This module provides a simple interface to load and use AlphaCLIP models. -It exposes the core functionality of AlphaCLIP in a standalone package. - -Usage: - from alphaclip_loader import AlphaCLIPLoader - - # Initialize the loader - loader = AlphaCLIPLoader() - - # Load a model - model, preprocess = loader.load_model("ViT-B/16") - - # Tokenize text - tokens = loader.tokenize("A photo of a cat") - - # Get available models - models = loader.available_models() -""" - -import os -import sys -from typing import Union, List, Tuple, Optional - -# Check for critical dependencies -missing_deps = [] -try: - import torch -except ImportError: - missing_deps.append("torch") - -try: - from PIL import Image -except ImportError: - missing_deps.append("Pillow") - -if missing_deps: - raise ImportError(f"Missing required dependencies: {', '.join(missing_deps)}. " - f"Please install them with: pip install {' '.join(missing_deps)}") - -# Add the alpha_clip directory to the path -_current_dir = os.path.dirname(os.path.abspath(__file__)) -_alpha_clip_dir = os.path.join(_current_dir, 'alpha_clip') -if _alpha_clip_dir not in sys.path: - sys.path.insert(0, _alpha_clip_dir) - -# Import the alpha_clip modules -try: - #import .alpha_clip - from .alpha_clip import available_models, load, tokenize -except ImportError as e: - raise ImportError(f"Failed to import alpha_clip modules: {e}. Please ensure all dependencies are installed.") - - -class AlphaCLIPLoader: - """ - A convenience wrapper for AlphaCLIP functionality. - - This class provides a clean interface to load AlphaCLIP models and - perform text tokenization. - """ - - def __init__(self, default_device: Optional[str] = None): - """ - Initialize the AlphaCLIP loader. - - Args: - default_device: Default device to load models on. If None, will use - CUDA if available, otherwise CPU. - """ - if default_device is None: - self.default_device = "cuda" if torch.cuda.is_available() else "cpu" - else: - self.default_device = default_device - - def available_models(self) -> List[str]: - """ - Get list of available AlphaCLIP model names. - - Returns: - List of model names that can be used with load_model() - """ - return available_models() - - def load_model( - self, - name: str, - alpha_vision_ckpt_pth: str = "None", - device: Optional[Union[str, torch.device]] = None, - jit: bool = False, - download_root: Optional[str] = None, - lora_adapt: bool = False, - rank: int = 16 - ) -> Tuple[torch.nn.Module, callable]: - """ - Load an AlphaCLIP model. - - Args: - name: Model name (e.g., "ViT-B/16") or path to checkpoint - alpha_vision_ckpt_pth: Path to additional vision checkpoint - device: Device to load model on (defaults to self.default_device) - jit: Whether to load JIT optimized model - download_root: Directory to download models to - lora_adapt: Whether to use LoRA adaptation - rank: LoRA rank if lora_adapt is True - - Returns: - Tuple of (model, preprocess_function) - """ - if device is None: - device = self.default_device - - return load( - name=name, - alpha_vision_ckpt_pth=alpha_vision_ckpt_pth, - device=device, - jit=jit, - download_root=download_root, - lora_adapt=lora_adapt, - rank=rank - ) - - def tokenize( - self, - texts: Union[str, List[str]], - context_length: int = 77, - truncate: bool = True - ) -> torch.Tensor: - """ - Tokenize text for use with AlphaCLIP models. - - Args: - texts: String or list of strings to tokenize - context_length: Maximum token length (default 77) - truncate: Whether to truncate long texts - - Returns: - Tensor of tokenized text - """ - return tokenize(texts, context_length, truncate) - - def encode_text(self, model: torch.nn.Module, texts: Union[str, List[str]]) -> torch.Tensor: - """ - Convenience method to tokenize and encode text. - - Args: - model: Loaded AlphaCLIP model - texts: Text(s) to encode - - Returns: - Text embeddings tensor - """ - tokens = self.tokenize(texts) - if hasattr(model, 'token_embedding'): - # Move tokens to same device as model - device = next(model.parameters()).device - tokens = tokens.to(device) - - with torch.no_grad(): - text_features = model.encode_text(tokens) - - return text_features - - def encode_image(self, model: torch.nn.Module, images: torch.Tensor) -> torch.Tensor: - """ - Convenience method to encode images. - - Args: - model: Loaded AlphaCLIP model - images: Preprocessed image tensor - - Returns: - Image embeddings tensor - """ - with torch.no_grad(): - image_features = model.encode_image(images) - - return image_features - - def get_similarity(self, text_features: torch.Tensor, image_features: torch.Tensor) -> torch.Tensor: - """ - Compute cosine similarity between text and image features. - - Args: - text_features: Text embedding tensor - image_features: Image embedding tensor - - Returns: - Similarity scores tensor - """ - # Normalize features - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - image_features = image_features / image_features.norm(dim=-1, keepdim=True) - - # Compute similarity - similarity = (text_features @ image_features.T) - return similarity - - -# Convenience function for quick model loading -def load_alphaclip( - model_name: str = "ViT-B/16", - device: Optional[str] = None, - alpha_vision_ckpt_pth: str = "None", - download_root = '/raid/datasets/models_weights/alphaclip', - **kwargs -) -> Tuple[AlphaCLIPLoader, torch.nn.Module, callable]: - """ - Quick function to load AlphaCLIP with a loader instance. - - Args: - model_name: Name of the model to load - device: Device to use - **kwargs: Additional arguments for model loading - - Returns: - Tuple of (loader, model, preprocess_function) - """ - loader = AlphaCLIPLoader(default_device=device) - model, preprocess = loader.load_model(model_name, **kwargs) - return loader, model, preprocess - - -# Make key functions available at module level -__all__ = [ - 'AlphaCLIPLoader', - 'load_alphaclip', - 'available_models', - 'load', - 'tokenize' -] diff --git a/src/alphaclip/example.py b/src/alphaclip/example.py deleted file mode 100644 index b4ff9d99ecc5687e8b87824760c6dcde891dc3b5..0000000000000000000000000000000000000000 --- a/src/alphaclip/example.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python3 -""" -Example usage of AlphaCLIP Standalone - -This script demonstrates basic usage of the AlphaCLIP standalone package. -""" - -import torch -import numpy as np -from alphaclip_loader import AlphaCLIPLoader, load_alphaclip - -def main(): - print("AlphaCLIP Standalone Example") - print("=" * 40) - - # Check if CUDA is available - device = "cuda" if torch.cuda.is_available() else "cpu" - print(f"Using device: {device}") - - # Method 1: Using the loader class - print("\n1. Using AlphaCLIPLoader class:") - loader = AlphaCLIPLoader(default_device=device) - - # Show available models - models = loader.available_models() - print(f"Available models: {models}") - - # Load a model - print("\nLoading ViT-B/16 model...") - model, preprocess = loader.load_model("ViT-B/16") - print(f"Model loaded successfully!") - - # Test text encoding - test_texts = [ - "a photo of a cat", - "a dog running in the park", - "a beautiful sunset over the ocean" - ] - - print(f"\nEncoding {len(test_texts)} texts...") - text_features = loader.encode_text(model, test_texts) - print(f"Text features shape: {text_features.shape}") - - # Compute similarities between texts - print("\nComputing text-to-text similarities:") - similarities = loader.get_similarity(text_features, text_features) - - for i, text1 in enumerate(test_texts): - for j, text2 in enumerate(test_texts): - if i <= j: # Only show upper triangle - sim = similarities[i, j].item() - print(f" '{text1}' <-> '{text2}': {sim:.3f}") - - # Method 2: Using the quick loader function - print("\n\n2. Using quick loader function:") - loader2, model2, preprocess2 = load_alphaclip("ViT-B/16", device=device) - - # Test single text - single_text = "a red apple on a wooden table" - single_features = loader2.encode_text(model2, single_text) - print(f"Single text '{single_text}' encoded to shape: {single_features.shape}") - - # Test tokenization - print("\n3. Tokenization example:") - tokens = loader.tokenize(test_texts) - print(f"Tokenized {len(test_texts)} texts to shape: {tokens.shape}") - - # Show some token examples - print("First few tokens for each text:") - for i, text in enumerate(test_texts): - print(f" '{text}': {tokens[i][:10].tolist()}...") - - print("\nExample completed successfully!") - -if __name__ == "__main__": - main() diff --git a/src/alphaclip/requirements.txt b/src/alphaclip/requirements.txt deleted file mode 100644 index f2d8cb2b2ad48549df3bcba64ab2aeac7680bbcb..0000000000000000000000000000000000000000 --- a/src/alphaclip/requirements.txt +++ /dev/null @@ -1,10 +0,0 @@ -# Core dependencies for AlphaCLIP standalone -torch>=1.7.1 -torchvision -ftfy -regex -tqdm -loralib -Pillow -numpy -packaging diff --git a/src/alphaclip/setup.py b/src/alphaclip/setup.py deleted file mode 100644 index 5f2c2652e1ca28d0e1ab0e0dbec0287202ccc8b4..0000000000000000000000000000000000000000 --- a/src/alphaclip/setup.py +++ /dev/null @@ -1,47 +0,0 @@ -""" -Setup script for AlphaCLIP Standalone -""" - -from setuptools import setup, find_packages -import os - -# Read requirements -with open('requirements.txt', 'r') as f: - requirements = [line.strip() for line in f if line.strip() and not line.startswith('#')] - -# Read README if it exists -readme_content = "" -if os.path.exists('README.md'): - with open('README.md', 'r', encoding='utf-8') as f: - readme_content = f.read() - -setup( - name="alphaclip-standalone", - version="1.0.0", - author="AlphaCLIP Team", - description="Standalone version of AlphaCLIP for easy integration", - long_description=readme_content, - long_description_content_type="text/markdown", - packages=find_packages(), - package_data={ - 'alpha_clip': ['*.gz'], # Include the tokenizer vocabulary file - }, - include_package_data=True, - install_requires=requirements, - python_requires=">=3.7", - classifiers=[ - "Development Status :: 4 - Beta", - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.7", - "Programming Language :: Python :: 3.8", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules", - ], - keywords="clip, vision, language, deep learning, pytorch", -) diff --git a/src/alphaclip/test_installation.py b/src/alphaclip/test_installation.py deleted file mode 100644 index 5fd389c2fdee2e76f0e86cd50f6d5b9b95adab4c..0000000000000000000000000000000000000000 --- a/src/alphaclip/test_installation.py +++ /dev/null @@ -1,149 +0,0 @@ -#!/usr/bin/env python3 -""" -Test script for AlphaCLIP Standalone - -This script tests the basic functionality of the standalone package -to ensure everything is working correctly. -""" - -import sys -import os - -def test_imports(): - """Test that all required modules can be imported.""" - print("Testing imports...") - - try: - import torch - print(f"โœ“ PyTorch {torch.__version__} imported successfully") - except ImportError as e: - print(f"โœ— Failed to import PyTorch: {e}") - return False - - try: - import torchvision - print(f"โœ“ Torchvision imported successfully") - except ImportError as e: - print(f"โœ— Failed to import torchvision: {e}") - return False - - try: - from alphaclip_loader import AlphaCLIPLoader - print("โœ“ AlphaCLIPLoader imported successfully") - except ImportError as e: - print(f"โœ— Failed to import AlphaCLIPLoader: {e}") - return False - - try: - import loralib - print("โœ“ LoraLib imported successfully") - except ImportError as e: - print(f"โœ— Failed to import loralib: {e}") - return False - - return True - -def test_model_loading(): - """Test loading a model.""" - print("\nTesting model loading...") - - try: - from alphaclip_loader import AlphaCLIPLoader - - loader = AlphaCLIPLoader(default_device="cpu") # Use CPU for testing - models = loader.available_models() - print(f"โœ“ Available models: {models}") - - # Try to load the smallest model for testing - print("Loading ViT-B/32 model (this may take a while for first download)...") - model, preprocess = loader.load_model("ViT-B/32", device="cpu") - print("โœ“ Model loaded successfully") - - return True - - except Exception as e: - print(f"โœ— Failed to load model: {e}") - return False - -def test_tokenization(): - """Test text tokenization.""" - print("\nTesting tokenization...") - - try: - from alphaclip_loader import AlphaCLIPLoader - - loader = AlphaCLIPLoader() - test_text = "a photo of a cat" - tokens = loader.tokenize(test_text) - print(f"โœ“ Tokenized '{test_text}' to shape {tokens.shape}") - - # Test batch tokenization - test_texts = ["a cat", "a dog", "a bird"] - batch_tokens = loader.tokenize(test_texts) - print(f"โœ“ Batch tokenized {len(test_texts)} texts to shape {batch_tokens.shape}") - - return True - - except Exception as e: - print(f"โœ— Failed tokenization test: {e}") - return False - -def test_text_encoding(): - """Test text encoding with a loaded model.""" - print("\nTesting text encoding...") - - try: - from alphaclip_loader import AlphaCLIPLoader - - loader = AlphaCLIPLoader(default_device="cpu") - model, preprocess = loader.load_model("ViT-B/32", device="cpu") - - test_text = "a photo of a cat" - features = loader.encode_text(model, test_text) - print(f"โœ“ Encoded text to features with shape {features.shape}") - - # Test batch encoding - test_texts = ["a cat", "a dog"] - batch_features = loader.encode_text(model, test_texts) - print(f"โœ“ Batch encoded {len(test_texts)} texts to shape {batch_features.shape}") - - return True - - except Exception as e: - print(f"โœ— Failed text encoding test: {e}") - return False - -def main(): - """Run all tests.""" - print("AlphaCLIP Standalone Test Suite") - print("=" * 40) - - tests = [ - test_imports, - test_tokenization, - test_model_loading, - test_text_encoding, - ] - - passed = 0 - total = len(tests) - - for test in tests: - try: - if test(): - passed += 1 - except Exception as e: - print(f"โœ— Test {test.__name__} failed with exception: {e}") - - print(f"\n{'='*40}") - print(f"Test Results: {passed}/{total} tests passed") - - if passed == total: - print("๐ŸŽ‰ All tests passed! AlphaCLIP Standalone is working correctly.") - return 0 - else: - print("โŒ Some tests failed. Please check the error messages above.") - return 1 - -if __name__ == "__main__": - sys.exit(main()) diff --git a/src/bbox_utils.py b/src/bbox_utils.py deleted file mode 100644 index 9d74df407c57416ae7e8416bc683ab00d76b7f00..0000000000000000000000000000000000000000 --- a/src/bbox_utils.py +++ /dev/null @@ -1,421 +0,0 @@ -import torch -from copy import deepcopy -from PIL import ImageDraw -import itertools -import random - - -def extract_bboxes_feats(patch_embeddings, bboxes, gaussian_avg=False, - gaussian_bbox_variance=0.5, get_single_embedding_per_image=False, - patch_size=14, attention_map=None): - """ - if get_single_embedding_per_image is True, the weights of all the bounding boxes patches on an image will be summed and the function will return the patch weights depending on this map - """ - N = patch_embeddings.shape[0] - N_boxes = bboxes.shape[1] - grid_size = int(patch_embeddings.shape[1]**0.5) - device = patch_embeddings.device - - bboxes //= patch_size - bboxes = bboxes.int() - - # Reshape patches to grid - patch_embeddings = patch_embeddings.view(N, grid_size, grid_size, -1) # Shape (N, grid_size, grid_size, embed_dim) - if attention_map is not None: - attention_map = attention_map.view(N, grid_size, grid_size) # Shape (N, grid_size, grid_size) - # Grid of the sum of the gaussian weights - total_patch_weights = torch.zeros(N, grid_size, grid_size) - - # Extract boxes - x1, y1, w, h = bboxes.unbind(-1) # Separate box dimensions (N, N_boxes) - - # Create mesh grid for slicing - x2 = x1 + w # Exclusive end x - y2 = y1 + h # Exclusive end y - - means = [] - for i in range(N): - image_means = [] - for j in range(N_boxes): - if bboxes[i, j].sum().item() < 0 and get_single_embedding_per_image: - # this is the case where we receive a dummy box - continue - # Extract the region for each box - region_patches = patch_embeddings[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1, :] # (h, w, embed_dim) - - if attention_map is not None: - patch_weights = attention_map[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1] - patch_weights /= patch_weights.sum() - total_patch_weights[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1] += patch_weights - - weighted_patches = region_patches * patch_weights.to(device).unsqueeze(-1) # (h, w, embed_dim) - region_mean = weighted_patches.sum(dim=(0, 1)) # Weighted mean - - elif gaussian_avg: - # Create Gaussian weights - h_span, w_span = region_patches.shape[:2] - y_coords, x_coords = torch.meshgrid( - torch.linspace(-1, 1, h_span), - torch.linspace(-1, 1, w_span), - indexing="ij" - ) - if gaussian_bbox_variance == 0: - patch_weights = torch.zeros((h_span, w_span)) - # Determine central indices - center_y = [h_span // 2] if h_span % 2 == 1 else [h_span // 2 - 1, h_span // 2] - center_x = [w_span // 2] if w_span % 2 == 1 else [w_span // 2 - 1, w_span // 2] - # Randomly select one of the central elements in even case - center_y = random.choice(center_y) - center_x = random.choice(center_x) - # Set the selected central element to 1 - patch_weights[center_y, center_x] = 1.0 - else: - distances = x_coords**2 + y_coords**2 - patch_weights = torch.exp(-distances / gaussian_bbox_variance) - patch_weights = patch_weights / patch_weights.sum() # Normalize to sum to 1 - - # Apply Gaussian weights to region patches - weighted_patches = region_patches * patch_weights.to(device).unsqueeze(-1) # (h, w, embed_dim) - region_mean = weighted_patches.sum(dim=(0, 1)) # Weighted mean - - # Recording the bbox weight inside the image patch weight map - total_patch_weights[i, y1[i, j]:y2[i, j] + 1, x1[i, j]:x2[i, j] + 1] += patch_weights - else: - # Mean pooling case: create uniform weights - h_span, w_span = region_patches.shape[:2] - uniform_weights = torch.ones(h_span, w_span) / (h_span * w_span) - - # Update total_patch_weights for mean pooling - total_patch_weights[i, y1[i,j]:y2[i,j]+1, x1[i,j]:x2[i,j]+1] += uniform_weights - - # Compute mean of the region - region_mean = region_patches.mean(dim=(0, 1)) - - # Store the mean - image_means.append(region_mean) - if not get_single_embedding_per_image: - means.append(torch.stack(image_means)) - - # Normalizing the weight map so the sum is equal to 1 - total_patch_weights /= total_patch_weights.sum(dim=(1,2), keepdim=True) - if not get_single_embedding_per_image: - return torch.stack(means) # Shape (N, N_boxes, embed_dim) - else: - # Expand dimensions to match embeddings - total_patch_weights = total_patch_weights.unsqueeze(-1).to(device) - - # Compute weighted sum - weighted_patch_mean = (total_patch_weights * patch_embeddings).sum(dim=(1, 2)) - return weighted_patch_mean -# Shape (N, embed_dim) - -#def adjust_bbox_for_transform(image, bbox, resize_dim, crop_dim): -# """ -# Adjusts the bounding box for a resized and center-cropped image. -# -# Args: -# image (PIL.Image): The input image. -# bbox (list): The bounding box in [x1, y1, w, h] format. -# resize_dim (int): The dimension of the shortest side after resizing. -# crop_dim (int): The size of the square crop. -# -# Returns: -# list: The adjusted bounding box in [x1, y1, w, h] format. -# """ -# x1, y1, w, h = bbox -# orig_width, orig_height = image.size -# -# # Calculate resize scale for the shortest side -# if orig_width < orig_height: -# scale = resize_dim / orig_width -# resized_width, resized_height = resize_dim, int(orig_height * scale) -# else: -# scale = resize_dim / orig_height -# resized_width, resized_height = int(orig_width * scale), resize_dim -# -# # Scale the bounding box -# x1 *= scale -# y1 *= scale -# w *= scale -# h *= scale -# -# # Calculate cropping offsets -# crop_x = (resized_width - crop_dim) // 2 -# crop_y = (resized_height - crop_dim) // 2 -# -# # Adjust bounding box for cropping -# x1 -= crop_x -# y1 -= crop_y -# -# # Clamp the bounding box to the cropped area -# x1 = max(0, x1) -# y1 = max(0, y1) -# w = min(w, crop_dim - x1) -# h = min(h, crop_dim - y1) -# -# return [x1, y1, w, h] - -def map_traces_to_grid(traces, n_patch): - grid = torch.zeros((n_patch, n_patch)) - patch_size = 1.0 / n_patch - - for trace in traces: - x, y = trace['x'], trace['y'] - if 0 <= x <= 1 and 0 <= y <= 1: - grid_x, grid_y = int(x / patch_size), int(y / patch_size) - grid[min(grid_y, n_patch - 1), min(grid_x, n_patch - 1)] += 1 - - return grid - -def adjust_bbox_for_transform(image, bbox, resize_dim, crop_dim): - """ - Adjusts the bounding box for a resized and center-cropped image. - - Args: - image (PIL.Image): The input image. - bbox (list): The bounding box in [x1, y1, w, h] format. - resize_dim (int): The dimension of the shortest side after resizing. - crop_dim (int): The size of the square crop. - - Returns: - list: The adjusted bounding box in [x1, y1, w, h] format. - """ - x1, y1, w, h = bbox - orig_width, orig_height = image.size - - # Scale factors for resizing - if orig_width < orig_height: - scale_w = resize_dim / orig_width - scale_h = (resize_dim * orig_height) / orig_width / orig_height - else: - scale_h = resize_dim / orig_height - scale_w = (resize_dim * orig_width) / orig_height / orig_width - - # New dimensions after resize - new_width = int(orig_width * scale_w) - new_height = int(orig_height * scale_h) - - # Update bounding box for resizing - x1 = x1 * scale_w - y1 = y1 * scale_h - w = w * scale_w - h = h * scale_h - - # Compute cropping offsets - crop_x_offset = max(0, (new_width - crop_dim) // 2) - crop_y_offset = max(0, (new_height - crop_dim) // 2) - - # Adjust bounding box for cropping - x1 -= crop_x_offset - y1 -= crop_y_offset - - # Clip bounding box to crop dimensions - x1 = max(0, min(x1, crop_dim - 1)) - y1 = max(0, min(y1, crop_dim - 1)) - w = max(0, min(w, crop_dim - x1)) - h = max(0, min(h, crop_dim - y1)) - - return [x1, y1, w, h] - - - -def adjust_bbox_for_transform_no_scale(image, bbox, target_width, target_height): - """ - - Does not preserve the image scale. - Adjusts the bounding box for an image resized to a fixed width and height. - - Args: - image (PIL.Image): The original image. - bbox (list): The bounding box in [x1, y1, w, h] format. - target_width (int): The width of the resized image. - target_height (int): The height of the resized image. - - Returns: - list: The adjusted bounding box in [x1, y1, w, h] format. - """ - x1, y1, w, h = bbox - orig_width, orig_height = image.size - - # Calculate scale factors for width and height - scale_w = target_width / orig_width - scale_h = target_height / orig_height - - # Adjust the bounding box - x1 = x1 * scale_w - y1 = y1 * scale_h - w = w * scale_w - h = h * scale_h - - # Return the adjusted bounding box - return [x1, y1, w, h] - - -def draw_bounding_boxes(input_image, bounding_boxes, captions=[""], color="red", width=2, text_background=True, boxes_to_show = None): - """ - Draws bounding boxes on an image. - - Args: - image (PIL.Image): The image to draw on. - bounding_boxes (list): A list of bounding boxes, each as [x1, y1, x2, y2]. - color (str): The color of the bounding boxes (default is red). - width (int): The width of the bounding box lines (default is 2). - - Returns: - PIL.Image: The image with bounding boxes drawn. - """ - # Create a drawing context - image = deepcopy(input_image) - draw = ImageDraw.Draw( image ) - - #scale = 720.0 / max(image.size) - if boxes_to_show is not None: - if isinstance(boxes_to_show, int): - indexes_to_show = random.sample(range(len(bounding_boxes)), boxes_to_show) - else: - indexes_to_show = boxes_to_show - - for i, (bbox, cap ) in enumerate(itertools.zip_longest(bounding_boxes, captions, fillvalue="")): - - if boxes_to_show is not None: - if i not in indexes_to_show: continue - #bbox = [ i / scale for i in bbox ] - #x1, y1, w, h = bbox - x1, y1, x2, y2 = bbox - - #x2, y2 = x1 + w, y1 + h # Convert width/height to bottom-right corner - try: - draw.rectangle([x1, y1, x2, y2], outline=color, width=width) - if cap != "": - if text_background: - left,top,right,bottom = draw.multiline_textbbox((x1,y1), cap) #textbbox - draw.rectangle((left-5, top-5, right+5, bottom+5), fill="white") - draw.multiline_text((x1,y1), cap, fill=color) #text - - except Exception as e: - print("exception, i: ", i, f"{x1 = } {y1 = } {x2 = }, {y2 = }") - print(e) - - return image - -def extract_bboxes_feats_double_dino(dino_model, patch_embeddings, bboxes, cls_token, registers_tokens, patch_size, return_type="cls", gaussian_bbox_variance=0.5): - """ - Perform a forward pass of the last DINO layer with selected features, batched. - - Args: - dino_model: The DINO model. - patch_embeddings: Patch embeddings before the last layer. - bboxes: Bounding boxes for each image in the batch (BS x N_BOX_MAX x 4). - cls_token: CLS token embedding. - return_type: Type of feature to return ('cls', 'avg', 'gaussian_avg'). - gaussian_bbox_variance: Variance for Gaussian averaging. - - Returns: - bbox_features: Features for each bounding box based on return_type. - """ - N = patch_embeddings.shape[0] # Batch size - N_boxes = bboxes.shape[1] # Number of bounding boxes - grid_size = int(patch_embeddings.shape[1] ** 0.5) # Assuming square grid - embed_dim = patch_embeddings.shape[-1] - - bboxes_patch_indexes = bboxes.clone() - bboxes_patch_indexes //= patch_size # Scale down bbox coordinates to match patch grid - bboxes_patch_indexes = bboxes_patch_indexes.int() - - # Reshape patches to grid - patch_embeddings = patch_embeddings.view(N, grid_size, grid_size, embed_dim) # (N, grid_size, grid_size, embed_dim) - - if cls_token is not None: - cls_tokens = cls_token.view(N, embed_dim) - if registers_tokens is not None: - patches_offset = 5 - else: - patches_offset = 1 - else: - assert return_type != "cls" - patches_offset = 0 - batch_outputs = [] - - #batch_inputs = [] - - means = [] - for i in range(N): # Iterate over batch - image_means = [] - - if cls_token is not None: - cls_cur_img = cls_tokens[i].reshape(1, 1, embed_dim) - if registers_tokens is not None: - cur_img_register_tokens = registers_tokens[i].reshape(1, 4, embed_dim) - - for j in range(N_boxes): # Iterate over bounding boxes - # Extract the region for the bounding box - region_patches_xy = patch_embeddings[i, bboxes_patch_indexes[i, j, 1]:bboxes_patch_indexes[i, j, 3] + 1, bboxes_patch_indexes[i, j, 0]:bboxes_patch_indexes[i, j, 2] + 1, :] - #region_patches = region_patches.reshape(-1, embed_dim) # Flatten to (num_patches, embed_dim) - - #region_patches = region_patches.view(-1, embed_dim) # Flatten to (num_patches, embed_dim) - #cls_cur_img = cls_tokens[i].unsqueeze(0) # Add batch dimension (1, embed_dim) - #region_patches = region_patches.unsqueeze(0) # Add batch dimension (1, num_patches, embed_dim) - region_patches = region_patches_xy.reshape(1,-1, embed_dim) - if cls_token is not None: - inputs = torch.cat([cls_cur_img, region_patches], dim=1) # Concatenate along the token dimension (1, num_patches + 1, embed_dim) - if registers_tokens is not None: - inputs = torch.cat([cls_cur_img, cur_img_register_tokens, region_patches], dim=1) # Concatenate along the token dimension (1, num_patches + 5, embed_dim) - else: - inputs = torch.cat([region_patches], dim=1) # Concatenate along the token dimension (1, num_patches + 1, embed_dim) - - outputs = dino_model.blocks[-1](inputs) # Forward pass - # shape (1, 1 + len(region_patches), 768) - #cls_cur_img = cls_tokens[i] - #cls_cur_img = cls_cur_img.reshape(1, embed_dim) - #inputs = torch.cat([cls_cur_img, region_patches], dim=0) # Add CLS token to inputs - #outputs = dino_model.blocks[-1](inputs) # Forward pass - - batch_outputs.append(outputs) - - region_patches = outputs[0, patches_offset: ,] #(1,45,768) -> (1,1,768) - - if return_type == "gaussian_avg": - #region_patches = outputs[5: ,] - h_span, w_span = region_patches_xy.shape[:2] - y_coords, x_coords = torch.meshgrid( - torch.linspace(-1, 1, h_span), - torch.linspace(-1, 1, w_span), - indexing="ij" - ) - distances = x_coords**2 + y_coords**2 - gaussian_weights = torch.exp(-distances / gaussian_bbox_variance) # Adjust 0.1 for variance control - gaussian_weights = gaussian_weights / gaussian_weights.sum() # Normalize to sum to 1 - - # Apply Gaussian weights to region patches - weighted_patches = region_patches_xy * gaussian_weights.to(next(dino_model.parameters()).device).unsqueeze(-1) # (h, w, embed_dim) - region_mean = weighted_patches.sum(dim=(0,1)) # Weighted mean - #image_means.append(region_mean) - elif return_type == "avg": - # Compute mean of the region - region_mean = region_patches.mean(dim=(0)) # Mean over h, w - elif return_type == "cls": - region_mean = outputs[0, 0, ] - image_means.append(region_mean) - - means.append(torch.stack(image_means)) - - stacked_means = torch.stack(means) - #stacked_means = stacked_means.reshape(-1, embed_dim) - return stacked_means - - -def process_bboxes(imgs, bboxes, transform): - transformed_bboxes = [] - bboxes = bboxes.tolist() - for img, img_bboxes in zip(imgs, bboxes): - for bbox in img_bboxes: - # Crop the region defined by bbox - x_min, y_min, w, h = bbox - x_max = x_min + w - y_max = y_min + h - cropped_region = img.crop((x_min, y_min, x_max, y_max)) - - # Apply the transform to the cropped region - transformed_region = transform(cropped_region) - transformed_bboxes.append(transformed_region) - - return torch.stack(transformed_bboxes) \ No newline at end of file diff --git a/src/clipcap/CLIPCAP_INTEGRATION.md b/src/clipcap/CLIPCAP_INTEGRATION.md deleted file mode 100644 index 970820bd97d1b9aa75946b5d7e4b2112609200f8..0000000000000000000000000000000000000000 --- a/src/clipcap/CLIPCAP_INTEGRATION.md +++ /dev/null @@ -1,206 +0,0 @@ -# ClipCap Integration with Patchioner Class - -This document describes how ClipCap models have been integrated into the Patchioner class for DINO feature-based image captioning. - -## Overview - -ClipCap support has been added to the Patchioner class following the same pattern as other captioning models (VieCap, MeaCap, etc.). This integration allows you to use trained ClipCap models with DINO features for image captioning tasks. - -## Architecture - -### Files Added/Modified - -1. **`src/clipcap/entrypoint.py`** - Main ClipCap integration module - - `ClipCapModel` class for DINO feature-based captioning - - Model classes: `ClipCaptionModel`, `ClipCaptionPrefix`, `MLP`, `TransformerMapper` - - Text generation utilities - -2. **`src/model.py`** - Modified Patchioner class - - Added `clipcap_config` parameter to constructor - - Added ClipCap initialization logic - - Added ClipCap support to `caption_tokens` method - -3. **Configuration Files** - - `configs/clipcap_dino_vitb14.k.yaml` - DINOv2-B/14 configuration - - `configs/clipcap_dino_vitl14.k.yaml` - DINOv2-L/14 configuration - -## Configuration - -### YAML Configuration Format - -```yaml -decap_weights: '/path/to/decap/weights.pt' -prefix_size: 768 # DINO feature dimension -support_memory_size: 0 -dino_model: 'dinov2_vitb14' -normalize: True -resize_dim: 518 -crop_dim: 518 -use_talk2dino_project: False - -# ClipCap configuration -clipcap: - language_model: 'gpt2' - prefix_length: 10 # Sequence length for prefix - clip_length: 10 # CLIP sequence length (for transformer mapping) - num_layers: 8 # Number of transformer layers (for transformer mapping) - mapping_type: 'mlp' # 'mlp' or 'transformer' - only_prefix: True # Train only prefix mapping vs full model - temperature: 1.0 # Sampling temperature - top_p: 0.8 # Nucleus sampling parameter - entry_length: 67 # Maximum caption length - stop_token: '.' # Stop token for generation - weight_path: '/path/to/trained/clipcap/model.pt' -``` - -### Supported DINO Models - -The integration automatically detects DINO feature dimensions: - -- **DINOv2-S/14**: 384 dimensions (`dinov2_vits14`) -- **DINOv2-B/14**: 768 dimensions (`dinov2_vitb14`) -- **DINOv2-L/14**: 1024 dimensions (`dinov2_vitl14`) -- **DINOv2-G/14**: 1536 dimensions (`dinov2_vitg14`) - -## Usage - -### 1. Training ClipCap Models - -First, train your ClipCap model with DINO features: - -```bash -# Extract DINO features -python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14 - -# Train ClipCap model -python clipcapTraining.py \ - --use_dino \ - --dino_model_type dinov2_vitb14 \ - --prefix_length 10 \ - --mapping_type mlp \ - --only_prefix \ - --epochs 10 -``` - -### 2. Using ClipCap with Patchioner - -```python -import torch -from src.model import Patchioner - -# Load model with ClipCap configuration -device = torch.device('cuda') -model = Patchioner.from_config('configs/clipcap_dino_vitb14.k.yaml', device) - -# Generate captions from images -imgs = torch.randn(2, 3, 518, 518).to(device) # Example batch -results = model.forward(imgs, get_cls_capt=True) -captions = results['cls_capt'] - -print("Generated captions:") -for i, caption in enumerate(captions): - print(f"Image {i+1}: {caption}") -``` - -### 3. Using ClipCap Directly - -```python -from src.clipcap.entrypoint import ClipCapModel -import torch - -# Configuration -config = { - 'language_model': 'gpt2', - 'prefix_length': 10, - 'mapping_type': 'mlp', - 'only_prefix': True, - 'weight_path': '/path/to/trained/model.pt' -} - -# Initialize model -device = torch.device('cuda') -clipcap = ClipCapModel(config, device, dino_feature_dim=768) - -# Generate captions from DINO features -dino_features = torch.randn(2, 768).to(device) -captions = clipcap.forward(dino_features) - -print(captions) -``` - -## Performance Improvements - -### Batched Text Generation - -The ClipCap integration includes an efficient batched text generation implementation: - -- **`generate_batched()`**: Processes entire batches simultaneously -- **Significant speedup**: 2-8x faster than sequential processing -- **Memory efficient**: Optimized for GPU memory usage -- **Configurable**: Can fallback to sequential mode if needed - -### Configuration Options - -```yaml -clipcap: - use_batched_generation: True # Enable batched generation (recommended) - temperature: 1.0 # Sampling temperature - top_p: 0.8 # Nucleus sampling parameter - entry_length: 67 # Maximum sequence length -``` - -## Model Architecture Details - -### ClipCap Model Structure - -1. **Input**: DINO features (384/768/1024/1536 dimensions) -2. **Mapping Layer**: - - **MLP**: `DINO_dim โ†’ GPT2_dim * prefix_length` - - **Transformer**: Multi-layer transformer mapping -3. **GPT-2 Decoder**: Pretrained GPT-2 for text generation -4. **Output**: Natural language captions - -### Key Components - -- **`ClipCapModel`**: Main class for DINO-to-text captioning -- **`MLP`/`TransformerMapper`**: Feature mapping from DINO to GPT-2 space -- **Text Generation**: Nucleus sampling with configurable parameters - -## Integration with Existing Pipeline - -The ClipCap integration follows the established pattern: - -1. **Configuration**: YAML-based configuration like other models -2. **Initialization**: Automatic DINO dimension detection -3. **Forward Pass**: Seamless integration with existing forward methods -4. **Scoring**: Optional confidence scoring support - -## Testing - -Run the integration test: - -```bash -python test_clipcap_integration.py -``` - -This test verifies: -- Configuration loading from YAML -- Model instantiation with ClipCap -- Caption generation with dummy DINO features -- Score computation functionality - - -## Troubleshooting - -### Common Issues - -1. **Dimension Mismatch**: Ensure `prefix_size` matches DINO model dimension -2. **Missing Weights**: Verify `weight_path` points to trained ClipCap model -3. **Memory Issues**: Use `only_prefix=True` for lower memory usage -4. **Generation Quality**: Tune `temperature`, `top_p`, and `entry_length` - -## References - -- [ClipCap Paper](https://arxiv.org/abs/2111.09734) -- [DINO Paper](https://arxiv.org/abs/2104.14294) -- [DINOv2 Paper](https://arxiv.org/abs/2304.07193) \ No newline at end of file diff --git a/src/clipcap/clipcapTrainREADME.md b/src/clipcap/clipcapTrainREADME.md deleted file mode 100644 index d7e1ea0e2869befa3807583090cc5c7389807caf..0000000000000000000000000000000000000000 --- a/src/clipcap/clipcapTrainREADME.md +++ /dev/null @@ -1,301 +0,0 @@ -# ClipCap Training with DINO Features - README - -This guide provides instructions for training ClipCap with DINO visual features instead of CLIP features. - -## Prerequisites - -1. Ensure you have the required dependencies installed: - - PyTorch - - torchvision - - transformers - - tqdm - - Pillow - - scikit-image - -2. Prepare your COCO dataset with the following structure: - ``` - ./data/coco/ - โ”œโ”€โ”€ annotations/ - โ”‚ โ””โ”€โ”€ train_caption.json - โ”œโ”€โ”€ train2014/ - โ”‚ โ””โ”€โ”€ COCO_train2014_*.jpg - โ””โ”€โ”€ val2014/ - โ””โ”€โ”€ COCO_val2014_*.jpg - ``` - -## Required Files for DINO Feature Extraction - -To start the DINO feature extraction for the COCO dataset, you need: - -### 1. **COCO Dataset Structure**: -``` -/raid/datasets/coco/ # Main COCO directory (default) -โ”œโ”€โ”€ train2014/ # REQUIRED: Training images -โ”‚ โ””โ”€โ”€ COCO_train2014_*.jpg # Image files -โ”œโ”€โ”€ val2014/ # REQUIRED: Validation images -โ”‚ โ””โ”€โ”€ COCO_val2014_*.jpg # Image files -โ””โ”€โ”€ train_split_karpathy.json # REQUIRED: Karpathy format annotations (default) -``` - -### 2. **Required Files**: -- **`train_split_karpathy.json`**: COCO caption annotations in Karpathy format (default) -- **Training images**: COCO 2014 training set (COCO_train2014_*.jpg) -- **Validation images**: COCO 2014 validation set (COCO_val2014_*.jpg) - -### 3. **Annotation Format Support**: - -The script supports two annotation formats: - -#### **A. Karpathy Format** (default, recommended): -```json -{ - "images": [ - {"id": 522418, "file_name": "COCO_val2014_000000522418.jpg"} - ], - "annotations": [ - {"image_id": 522418, "id": 0, "caption": "A woman wearing a net..."} - ] -} -``` - -#### **B. ClipCap Format** (legacy): -```json -[ - {"image_id": 522418, "caption": "A woman wearing a net..."} -] -``` - -### 3. **Specifying Custom Input/Output Paths**: - -You can customize the paths using command-line arguments: - -```bash -python clipcap_dino_parse_coco.py \ - --dino_model_type dinov2_vitb14 \ - --coco_images_dir "/path/to/your/coco/dataset" \ - --captions_file "/path/to/your/train_caption.json" \ - --output_file "/path/to/output/dino_features.pkl" -``` - -**Available path arguments**: -- `--coco_images_dir`: Path to COCO images directory (should contain `train2014/` and `val2014/` subdirs) - **Default: `/raid/datasets/coco`** -- `--captions_file`: Path to COCO captions JSON file (supports both Karpathy and ClipCap formats) - **Default: `/raid/datasets/coco/train_split_karpathy.json`** -- `--output_file`: Custom output file path (optional, auto-generated if not specified) - -### 4. **Default Behavior** (if no paths specified): -```bash -# This will use default paths for your setup: -python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14 - -# Equivalent to: -python clipcap_dino_parse_coco.py \ - --dino_model_type dinov2_vitb14 \ - --coco_images_dir "/raid/datasets/coco" \ - --captions_file "/raid/datasets/coco/train_split_karpathy.json" \ - --output_file "/raid/datasets/coco/coco_karpathy_split_dinov2_vitb14_train.pkl" -``` - -## Step 1: Extract DINO Features - -First, extract DINO features from the COCO images using the modified feature extraction script: - -### For DINOv2-B/14 (768-dim features): -```bash -# Default paths (uses /raid/datasets/coco and Karpathy annotations) -python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14 --resize_dim 518 --crop_dim 518 - -# Custom paths -python clipcap_dino_parse_coco.py \ - --dino_model_type dinov2_vitb14 \ - --coco_images_dir "/your/coco/path" \ - --captions_file "/your/coco/train_split_karpathy.json" \ - --output_file "/your/output/dino_vitb14_features.pkl" -``` - -### For DINOv2-L/14 (1024-dim features): -```bash -# Default paths -python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitl14 --resize_dim 518 --crop_dim 518 - -# Custom paths -python clipcap_dino_parse_coco.py \ - --dino_model_type dinov2_vitl14 \ - --coco_images_dir "/your/coco/path" \ - --output_file "/your/output/dino_vitl14_features.pkl" -``` - -### For DINOv2-S/14 (384-dim features): -```bash -python clipcap_dino_parse_coco.py --dino_model_type dinov2_vits14 --resize_dim 518 --crop_dim 518 -``` - -### For DINOv2-G/14 (1536-dim features): -```bash -python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitg14 --resize_dim 518 --crop_dim 518 -``` - -**Output**: This will create a file like `/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl` (or your custom path) containing the DINO features and captions. - -### Check Available Arguments: -```bash -python clipcap_dino_parse_coco.py --help -``` - -## Step 2: Train ClipCap with DINO Features - -### Basic Training Command (MLP with sequence length 10): - -For **DINOv2-B/14** with **MLP mapping** and **prefix length 10**: -```bash -python clipcapTraining.py \ - --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \ - --out_dir ./checkpoints_dino_vitb14_mlp_len10 \ - --prefix dino_vitb14_mlp_len10 \ - --epochs 10 \ - --save_every 2 \ - --prefix_length 10 \ - --bs 32 \ - --mapping_type mlp \ - --use_dino \ - --dino_model_type dinov2_vitb14 \ - --only_prefix -``` - -### Training Options for Different DINO Models: - -#### DINOv2-L/14 (1024-dim): -```bash -python clipcapTraining.py \ - --data ./data/coco/coco_karpathy_split_dinov2_vitl14_train.pkl \ - --out_dir ./checkpoints_dino_vitl14_mlp_len10 \ - --prefix dino_vitl14_mlp_len10 \ - --epochs 10 \ - --save_every 2 \ - --prefix_length 10 \ - --bs 32 \ - --mapping_type mlp \ - --use_dino \ - --dino_model_type dinov2_vitl14 \ - --only_prefix -``` - -#### DINOv2-S/14 (384-dim): -```bash -python clipcapTraining.py \ - --data ./data/coco/coco_karpathy_split_dinov2_vits14_train.pkl \ - --out_dir ./checkpoints_dino_vits14_mlp_len10 \ - --prefix dino_vits14_mlp_len10 \ - --epochs 10 \ - --save_every 2 \ - --prefix_length 10 \ - --bs 32 \ - --mapping_type mlp \ - --use_dino \ - --dino_model_type dinov2_vits14 \ - --only_prefix -``` - -### Advanced Training Options: - -#### Train both prefix and GPT (full model): -```bash -python clipcapTraining.py \ - --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \ - --out_dir ./checkpoints_dino_vitb14_mlp_len10_full \ - --prefix dino_vitb14_mlp_len10_full \ - --epochs 10 \ - --save_every 2 \ - --prefix_length 10 \ - --bs 16 \ - --mapping_type mlp \ - --use_dino \ - --dino_model_type dinov2_vitb14 -``` - -#### Use Transformer mapping instead of MLP: -```bash -python clipcapTraining.py \ - --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \ - --out_dir ./checkpoints_dino_vitb14_transformer_len10 \ - --prefix dino_vitb14_transformer_len10 \ - --epochs 10 \ - --save_every 2 \ - --prefix_length 10 \ - --bs 32 \ - --mapping_type transformer \ - --num_layers 8 \ - --use_dino \ - --dino_model_type dinov2_vitb14 \ - --only_prefix -``` - -#### Custom feature dimension (if needed): -```bash -python clipcapTraining.py \ - --data ./data/coco/coco_karpathy_split_dinov2_vitb14_train.pkl \ - --out_dir ./checkpoints_dino_custom \ - --prefix dino_custom \ - --epochs 10 \ - --prefix_length 10 \ - --bs 32 \ - --mapping_type mlp \ - --use_dino \ - --dino_model_type dinov2_vitb14 \ - --dino_feature_dim 768 \ - --only_prefix -``` - -## Key Parameters Explanation: - -- `--use_dino`: Enable DINO mode (required for DINO training) -- `--dino_model_type`: Specify which DINO model was used for feature extraction -- `--dino_feature_dim`: Override automatic feature dimension detection -- `--prefix_length`: Number of prefix tokens (set to 10 as requested) -- `--mapping_type`: Choose between 'mlp' or 'transformer' mapping -- `--only_prefix`: Train only the mapping layer, freeze GPT-2 -- `--bs`: Batch size (adjust based on GPU memory) -- `--epochs`: Number of training epochs -- `--save_every`: Save checkpoint every N epochs - -## Expected Feature Dimensions: - -- **DINOv2-S/14**: 384 dimensions -- **DINOv2-B/14**: 768 dimensions -- **DINOv2-L/14**: 1024 dimensions -- **DINOv2-G/14**: 1536 dimensions - -## Training Tips: - -1. **Memory Usage**: DINO features are typically larger than CLIP features, so you might need to reduce batch size -2. **Convergence**: DINO-based models may require different learning rates or longer training -3. **Prefix Length**: Experiment with different prefix lengths (5, 10, 20) for optimal performance -4. **Mapping Type**: MLP is faster, Transformer might give better results but requires more memory - -## Output: - -The training will save checkpoints in the specified output directory: -- `{prefix}-{epoch:03d}.pt`: Model checkpoint for each epoch -- `{prefix}_latest.pt`: Latest model checkpoint (updated every 10k iterations) -- `{prefix}.json`: Training configuration - -## Example Full Workflow: - -```bash -# 1. Extract DINO features -python clipcap_dino_parse_coco.py --dino_model_type dinov2_vitb14 - -# 2. Train ClipCap with DINO features (MLP, length 10, prefix-only) -python clipcapTraining.py \ - --data /raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_dinov2_vitb14_train.pkl \ - --out_dir ./checkpoints_dino_vitb14_mlp_len10 \ - --prefix dino_vitb14_mlp_len10 \ - --epochs 10 \ - --prefix_length 10 \ - --bs 32 \ - --mapping_type mlp \ - --use_dino \ - --dino_model_type dinov2_vitb14 \ - --only_prefix -``` - -This will train a ClipCap model using DINO features with MLP mapping and sequence length 10 as requested. \ No newline at end of file diff --git a/src/clipcap/clipcapTraining.py b/src/clipcap/clipcapTraining.py deleted file mode 100644 index 20bcd502dcc9062fa8f625be2ea3bdaee94dd928..0000000000000000000000000000000000000000 --- a/src/clipcap/clipcapTraining.py +++ /dev/null @@ -1,405 +0,0 @@ -import torch -import torch.nn as nn -from torch.nn import functional as nnf -from torch.utils.data import Dataset, DataLoader -from enum import Enum -from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup -from tqdm import tqdm -import os -import pickle -import sys -import argparse -import json -from typing import Tuple, Optional, Union - - -class MappingType(Enum): - MLP = 'mlp' - Transformer = 'transformer' - - -class ClipCocoDataset(Dataset): - - def __len__(self) -> int: - return len(self.captions_tokens) - - def pad_tokens(self, item: int): - tokens = self.captions_tokens[item] - padding = self.max_seq_len - tokens.shape[0] - if padding > 0: - tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64) - 1)) - self.captions_tokens[item] = tokens - elif padding < 0: - tokens = tokens[:self.max_seq_len] - self.captions_tokens[item] = tokens - mask = tokens.ge(0) # mask is zero where we out of sequence - tokens[~mask] = 0 - mask = mask.float() - mask = torch.cat((torch.ones(self.prefix_length), mask), dim=0) # adding prefix mask - return tokens, mask - - def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]: - tokens, mask = self.pad_tokens(item) - prefix = self.prefixes[self.caption2embedding[item]] - if self.normalize_prefix: - prefix = prefix.float() - prefix = prefix / prefix.norm(2, -1) - return tokens, mask, prefix - - def __init__(self, data_path: str, prefix_length: int, gpt2_type: str = "gpt2", - normalize_prefix=False): - self.tokenizer = GPT2Tokenizer.from_pretrained(gpt2_type) - self.prefix_length = prefix_length - self.normalize_prefix = normalize_prefix - with open(data_path, 'rb') as f: - all_data = pickle.load(f) - print("Data size is %0d" % len(all_data["clip_embedding"])) - sys.stdout.flush() - self.prefixes = all_data["clip_embedding"] - captions_raw = all_data["captions"] - self.image_ids = [caption["image_id"] for caption in captions_raw] - self.captions = [caption['caption'] for caption in captions_raw] - if os.path.isfile(f"{data_path[:-4]}_tokens.pkl"): - with open(f"{data_path[:-4]}_tokens.pkl", 'rb') as f: - self.captions_tokens, self.caption2embedding, self.max_seq_len = pickle.load(f) - else: - self.captions_tokens = [] - self.caption2embedding = [] - max_seq_len = 0 - for caption in captions_raw: - self.captions_tokens.append(torch.tensor(self.tokenizer.encode(caption['caption']), dtype=torch.int64)) - self.caption2embedding.append(caption["clip_embedding"]) - max_seq_len = max(max_seq_len, self.captions_tokens[-1].shape[0]) - # self.max_seq_len = max_seq_len - with open(f"{data_path[:-4]}_tokens.pkl", 'wb') as f: - pickle.dump([self.captions_tokens, self.caption2embedding, max_seq_len], f) - all_len = torch.tensor([len(self.captions_tokens[i]) for i in range(len(self))]).float() - self.max_seq_len = min(int(all_len.mean() + all_len.std() * 10), int(all_len.max())) - - -class MLP(nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.model(x) - - def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh): - super(MLP, self).__init__() - layers = [] - for i in range(len(sizes) - 1): - layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias)) - if i < len(sizes) - 2: - layers.append(act()) - self.model = nn.Sequential(*layers) - - -class MlpTransformer(nn.Module): - def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.): - super().__init__() - out_d = out_d if out_d is not None else in_dim - self.fc1 = nn.Linear(in_dim, h_dim) - self.act = act - self.fc2 = nn.Linear(h_dim, out_d) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - -class MultiHeadAttention(nn.Module): - - def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.): - super().__init__() - self.num_heads = num_heads - head_dim = dim_self // num_heads - self.scale = head_dim ** -0.5 - self.to_queries = nn.Linear(dim_self, dim_self, bias=bias) - self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias) - self.project = nn.Linear(dim_self, dim_self) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, y=None, mask=None): - y = y if y is not None else x - b, n, c = x.shape - _, m, d = y.shape - # b n h dh - queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads) - # b m 2 h dh - keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads) - keys, values = keys_values[:, :, 0], keys_values[:, :, 1] - attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale - if mask is not None: - if mask.dim() == 2: - mask = mask.unsqueeze(1) - attention = attention.masked_fill(mask.unsqueeze(3), float("-inf")) - attention = attention.softmax(dim=2) - out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c) - out = self.project(out) - return out, attention - - -class TransformerLayer(nn.Module): - - def forward_with_attention(self, x, y=None, mask=None): - x_, attention = self.attn(self.norm1(x), y, mask) - x = x + x_ - x = x + self.mlp(self.norm2(x)) - return x, attention - - def forward(self, x, y=None, mask=None): - x = x + self.attn(self.norm1(x), y, mask)[0] - x = x + self.mlp(self.norm2(x)) - return x - - def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu, - norm_layer: nn.Module = nn.LayerNorm): - super().__init__() - self.norm1 = norm_layer(dim_self) - self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout) - self.norm2 = norm_layer(dim_self) - self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout) - - -class Transformer(nn.Module): - - def forward_with_attention(self, x, y=None, mask=None): - attentions = [] - for layer in self.layers: - x, att = layer.forward_with_attention(x, y, mask) - attentions.append(att) - return x, attentions - - def forward(self, x, y=None, mask=None): - for i, layer in enumerate(self.layers): - if i % 2 == 0 and self.enc_dec: # cross - x = layer(x, y) - elif self.enc_dec: # self - x = layer(x, x, mask) - else: # self or cross - x = layer(x, y, mask) - return x - - def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None, - mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False): - super(Transformer, self).__init__() - dim_ref = dim_ref if dim_ref is not None else dim_self - self.enc_dec = enc_dec - if enc_dec: - num_layers = num_layers * 2 - layers = [] - for i in range(num_layers): - if i % 2 == 0 and enc_dec: # cross - layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) - elif enc_dec: # self - layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) - else: # self or cross - layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) - self.layers = nn.ModuleList(layers) - - -class TransformerMapper(nn.Module): - - def forward(self, x): - x = self.linear(x).view(x.shape[0], self.clip_length, -1) - prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape) - prefix = torch.cat((x, prefix), dim=1) - out = self.transformer(prefix)[:, self.clip_length:] - return out - - def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8): - super(TransformerMapper, self).__init__() - self.clip_length = clip_length - self.transformer = Transformer(dim_embedding, 8, num_layers) - self.linear = nn.Linear(dim_clip, clip_length * dim_embedding) - self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True) - - -class ClipCaptionModel(nn.Module): - - def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: - return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) - - def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None): - embedding_text = self.gpt.transformer.wte(tokens) - prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size) - embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1) - if labels is not None: - dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device) - labels = torch.cat((dummy_token, tokens), dim=1) - out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask) - return out - - def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512, - num_layers: int = 8, mapping_type: MappingType = MappingType.MLP): - super(ClipCaptionModel, self).__init__() - self.prefix_length = prefix_length - self.gpt = GPT2LMHeadModel.from_pretrained('gpt2') - self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1] - if mapping_type == MappingType.MLP: - self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, - self.gpt_embedding_size * prefix_length)) - else: - self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length, - clip_length, num_layers) - - -class ClipCaptionPrefix(ClipCaptionModel): - - def parameters(self, recurse: bool = True): - return self.clip_project.parameters() - - def train(self, mode: bool = True): - super(ClipCaptionPrefix, self).train(mode) - self.gpt.eval() - return self - - -def save_config(args: argparse.Namespace): - config = {} - for key, item in args._get_kwargs(): - config[key] = item - out_path = os.path.join(args.out_dir, f"{args.prefix}.json") - with open(out_path, 'w') as outfile: - json.dump(config, outfile) - - -def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'): - with open(config_path) as f: - config = json.load(f) - parser = argparse.ArgumentParser() - parser.set_defaults(**config) - args = parser.parse_args() - if type(epoch_or_latest) is int: - epoch_or_latest = f"-{epoch_or_latest:03d}" - model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt") - if args.only_prefix: - model = ClipCaptionPrefix(args.prefix_length) - else: - model = ClipCaptionModel(args.prefix_length) - if os.path.isfile(model_path): - print(f"loading model from {model_path}") - model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) - else: - print(f"{model_path} is not exist") - return model, parser - - -def train(dataset: ClipCocoDataset, model: ClipCaptionModel, args, - lr: float = 2e-5, warmup_steps: int = 5000, output_dir: str = ".", output_prefix: str = "", device = torch.device('cuda:0')): - - batch_size = args.bs - epochs = args.epochs - if not os.path.exists(output_dir): - os.makedirs(output_dir) - model = model.to(device) - model.train() - optimizer = AdamW(model.parameters(), lr=lr) - train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) - scheduler = get_linear_schedule_with_warmup( - optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader) - ) - # save_config(args) - for epoch in range(epochs): - print(f">>> Training epoch {epoch}") - sys.stdout.flush() - progress = tqdm(total=len(train_dataloader), desc=output_prefix) - for idx, (tokens, mask, prefix) in enumerate(train_dataloader): - model.zero_grad() - tokens, mask, prefix = tokens.to(device), mask.to(device), prefix.to(device, dtype=torch.float32) - outputs = model(tokens, prefix, mask) - logits = outputs.logits[:, dataset.prefix_length - 1: -1] - loss = nnf.cross_entropy(logits.reshape(-1, logits.shape[-1]), tokens.flatten(), ignore_index=0) - loss.backward() - optimizer.step() - scheduler.step() - optimizer.zero_grad() - progress.set_postfix({"loss": loss.item()}) - progress.update() - if (idx + 1) % 10000 == 0: - torch.save( - model.state_dict(), - os.path.join(output_dir, f"{output_prefix}_latest.pt"), - ) - progress.close() - if epoch % args.save_every == 0 or epoch == epochs - 1: - torch.save( - model.state_dict(), - os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"), - ) - return model - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--data', default='/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_train.pkl') - parser.add_argument('--out_dir', default='/raid/datasets/models_weights/clipcap/checkpoints/dinov2b14/') - parser.add_argument('--prefix', default='coco_prefix', help='prefix for saved filenames') - parser.add_argument('--epochs', type=int, default=10) - parser.add_argument('--save_every', type=int, default=1) - parser.add_argument('--prefix_length', type=int, default=10) - parser.add_argument('--prefix_length_clip', type=int, default=10) - parser.add_argument('--bs', type=int, default=40) - parser.add_argument('--only_prefix', dest='only_prefix', action='store_true') - parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer') - parser.add_argument('--num_layers', type=int, default=8) - parser.add_argument('--is_rn', dest='is_rn', action='store_true') - parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true') - # DINO-specific arguments - parser.add_argument('--use_dino', action='store_true', default=False, help='Use DINO features instead of CLIP') - parser.add_argument('--dino_model_type', type=str, default='dinov2_vitb14', - choices=['dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14'], - help='DINO model type') - parser.add_argument('--dino_feature_dim', type=int, default=None, - help='DINO feature dimension (auto-detected if None)') - parser.add_argument('--device', type=str, default='cuda:0', help='Device to use for training') - args = parser.parse_args() - - if isinstance(args.device, str): - if not args.device.startswith('cuda') and not args.device.startswith('cpu'): - # if it is an integer index, convert to f'cuda:{args.device}' - if args.device.isdigit(): - args.device = f'cuda:{args.device}' - else: - raise ValueError(f"Invalid device string: {args.device}") - args.device = torch.device(args.device) - - prefix_length = args.prefix_length - dataset = ClipCocoDataset(args.data, prefix_length, normalize_prefix=args.normalize_prefix) - - # Determine prefix dimension based on model type - if args.use_dino: - if args.dino_feature_dim is not None: - prefix_dim = args.dino_feature_dim - else: - # Auto-detect DINO feature dimensions - dino_dims = { - 'dinov2_vits14': 384, - 'dinov2_vitb14': 768, - 'dinov2_vitl14': 1024, - 'dinov2_vitg14': 1536 - } - prefix_dim = dino_dims.get(args.dino_model_type, 768) - print(f"Using DINO features with dimension: {prefix_dim}") - else: - prefix_dim = 640 if args.is_rn else 512 - print(f"Using CLIP features with dimension: {prefix_dim}") - - args.mapping_type = {'mlp': MappingType.MLP, 'transformer': MappingType.Transformer}[args.mapping_type] - if args.only_prefix: - model = ClipCaptionPrefix(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim, - num_layers=args.num_layers, mapping_type=args.mapping_type) - print("Train only prefix") - else: - model = ClipCaptionModel(prefix_length, clip_length=args.prefix_length_clip, prefix_size=prefix_dim, - num_layers=args.num_layers, mapping_type=args.mapping_type) - print("Train both prefix and GPT") - sys.stdout.flush() - train(dataset, model, args, output_dir=args.out_dir, output_prefix=args.prefix, device=args.device) - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/src/clipcap/clipcap_dino_parse_coco.py b/src/clipcap/clipcap_dino_parse_coco.py deleted file mode 100644 index 14476848f8a26fbe2ba3cab8ab32b0682eee125c..0000000000000000000000000000000000000000 --- a/src/clipcap/clipcap_dino_parse_coco.py +++ /dev/null @@ -1,613 +0,0 @@ -import torch -import torch.nn.functional as F -import skimage.io as io -from PIL import Image -import pickle -import json -import os -from tqdm import tqdm -import argparse -import torchvision.transforms as T -import numpy as np -import yaml -import clip -import sys - -# Add the src directory to the path so we can import ProjectionLayer -sys.path.append(os.path.join(os.path.dirname(__file__), '../..', 'src')) - - -# Container to store intermediate outputs for feature extraction -feats = {} - -def get_self_attention(module, input, output): - """Hook to capture self-attention weights""" - global qkv_attention_out - qkv_attention_out = output - -def get_layer_n_output(module, input, output): - """Hook to capture intermediate layer output""" - feats['intermediate_output'] = output - -def transform_to_standard_dino_out(x, model, num_global_tokens=1): - """Transform raw DINO output to standardized format""" - x_norm = model.norm(x) - if num_global_tokens == 1: - # Standard model without registers - return { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": None, - "x_norm_patchtokens": x_norm[:, 1:], - "x_prenorm": x, - } - else: - # Model with registers (num_global_tokens = 5) - return { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1:num_global_tokens], - "x_norm_patchtokens": x_norm[:, num_global_tokens:], - "x_prenorm": x, - } - -def process_self_attention(output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False): - """Process self-attention output to compute attention weights""" - qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0] * scale, qkv[1], qkv[2] - attn = q @ k.transpose(-2, -1) - self_attn_maps = attn[:, :, 0, num_global_tokens:] # CLS token attention to patches - self_attn = self_attn_maps.mean(dim=1) # Average over attention heads - self_attn = self_attn.softmax(dim=-1) - if ret_self_attn_maps: - return self_attn, self_attn_maps - else: - return self_attn - - -# Global variables to store hook outputs -dino_layer_n_output = None -qkv_attention_out = None - -def get_layer_n_output(module, input, output): - """Hook to capture intermediate layer output""" - global dino_layer_n_output - dino_layer_n_output = output - - -def select_most_significant_patch(dino_outs, self_attn, criteria, cls_token=None, caption_embedding=None): - """ - Select the most significant patch token based on different criteria. - - Args: - dino_outs: Dictionary containing normalized DINO outputs - self_attn: Self-attention weights from CLS to patches [batch_size, num_patches] - criteria: Selection criteria ('max_attention', 'most_similar_to_cls', etc.) - cls_token: CLS token embeddings [batch_size, embed_dim] - caption_embedding: Text caption embeddings [batch_size, embed_dim] - - Returns: - selected_patches: [batch_size, embed_dim] - Selected patch embeddings - """ - patch_tokens = dino_outs['x_norm_patchtokens'] # [batch_size, num_patches, embed_dim] - batch_size, num_patches, embed_dim = patch_tokens.shape - - if criteria == "max_attention": - # Select patch with highest attention weight from CLS token - if self_attn is None: - raise ValueError("self_attn required for max_attention criteria") - max_attn_indices = self_attn.argmax(dim=1) # [batch_size] - selected_patches = patch_tokens[torch.arange(batch_size), max_attn_indices] - - elif criteria == "most_similar_to_cls": - # Select patch most similar to CLS token using cosine similarity - if cls_token is None: - raise ValueError("cls_token required for most_similar_to_cls criteria") - # Compute cosine similarity between CLS and all patches - cls_normalized = F.normalize(cls_token, p=2, dim=1) # [batch_size, embed_dim] - patches_normalized = F.normalize(patch_tokens, p=2, dim=2) # [batch_size, num_patches, embed_dim] - similarities = torch.bmm(patches_normalized, cls_normalized.unsqueeze(2)).squeeze(2) # [batch_size, num_patches] - max_sim_indices = similarities.argmax(dim=1) # [batch_size] - selected_patches = patch_tokens[torch.arange(batch_size), max_sim_indices] - - elif criteria == "most_similar_to_caption": - # Select patch most similar to caption embedding - if caption_embedding is None: - raise ValueError("caption_embedding required for most_similar_to_caption criteria") - caption_normalized = F.normalize(caption_embedding, p=2, dim=1) # [batch_size, embed_dim] - patches_normalized = F.normalize(patch_tokens, p=2, dim=2) # [batch_size, num_patches, embed_dim] - similarities = torch.bmm(patches_normalized, caption_normalized.unsqueeze(2)).squeeze(2) # [batch_size, num_patches] - max_sim_indices = similarities.argmax(dim=1) # [batch_size] - selected_patches = patch_tokens[torch.arange(batch_size), max_sim_indices] - - elif criteria == "max_norm": - # Select patch with highest L2 norm - patch_norms = torch.norm(patch_tokens, p=2, dim=2) # [batch_size, num_patches] - max_norm_indices = patch_norms.argmax(dim=1) # [batch_size] - selected_patches = patch_tokens[torch.arange(batch_size), max_norm_indices] - - elif criteria == "centroid_distance": - # Select patch farthest from the centroid of all patches - centroid = patch_tokens.mean(dim=1, keepdim=True) # [batch_size, 1, embed_dim] - distances = torch.norm(patch_tokens - centroid, p=2, dim=2) # [batch_size, num_patches] - max_dist_indices = distances.argmax(dim=1) # [batch_size] - selected_patches = patch_tokens[torch.arange(batch_size), max_dist_indices] - - else: - raise ValueError(f"Unknown patch selection criteria: {criteria}") - - return selected_patches - - -def load_text_encoder(text_encoder_path, device, config_path=None): - """ - Load a text encoder model for caption similarity. - Supports Talk2Dino, CLIP, and DINO.txt-based text encoders. - """ - if text_encoder_path is None: - return None - - print(f"Loading text encoder from: {text_encoder_path}") - - # Check for DINO.txt model - if text_encoder_path.lower() == 'dinotxt' or text_encoder_path.lower() == 'dino.txt': - # Load DINO.txt model - try: - from src.dinotxt_utils import get_tokenizer - - print("Loading DINO.txt model...") - dinotxt_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l') - dinotxt_model.eval() - dinotxt_model.to(device) - - tokenizer = get_tokenizer() - - return { - 'type': 'dinotxt', - 'model': dinotxt_model, - 'tokenizer': tokenizer - } - - except ImportError: - raise ImportError("Could not import dinotxt_utils. Make sure src/dinotxt_utils.py is accessible.") - except Exception as e: - raise RuntimeError(f"Failed to load DINO.txt model: {e}") - - # Check if it's a Talk2Dino model (expect config and weights) - elif text_encoder_path.endswith('.pth') or text_encoder_path.endswith('.pt'): - # Use provided config or auto-find - if config_path is None: - # Look for corresponding config file - base_path = text_encoder_path.rsplit('.', 1)[0] - config_path = base_path + '.yaml' - - # Alternative config path patterns - if not os.path.exists(config_path): - # Try configs_talk2dino directory - config_name = os.path.basename(base_path) + '.yaml' - config_path = os.path.join(os.path.dirname(__file__), 'configs_talk2dino', config_name) - - if not os.path.exists(config_path): - raise FileNotFoundError(f"Could not find config file for {text_encoder_path}. " - f"Expected at {config_path} or specify --text_encoder_config.") - - # Load Talk2Dino model - try: - from src.model import ProjectionLayer - - print(f"Using config: {config_path}") - - # Load the projection layer - talk2dino = ProjectionLayer.from_config(config_path) - talk2dino.load_state_dict(torch.load(text_encoder_path, map_location=device)) - talk2dino.to(device) - talk2dino.eval() - - # Load CLIP model for text encoding - clip_model, _ = clip.load("ViT-B/32", device=device) - clip_model.eval() - - return { - 'type': 'talk2dino', - 'talk2dino': talk2dino, - 'clip_model': clip_model, - 'config_path': config_path - } - - except ImportError: - raise ImportError("Could not import ProjectionLayer. Make sure src/model.py is accessible.") - - else: - # Assume it's a direct model path (CLIP or other) - try: - # Try loading as a CLIP model - clip_model, _ = clip.load(text_encoder_path, device=device) - clip_model.eval() - - return { - 'type': 'clip', - 'clip_model': clip_model - } - except: - raise ValueError(f"Could not load text encoder from {text_encoder_path}. " - f"Supported formats: 1) 'dinotxt' or 'dino.txt' for DINO.txt model, " - f"2) Talk2Dino (.pth/.pt), 3) CLIP model names.") - - -def encode_caption(caption, text_encoder, device): - """ - Encode a text caption using the loaded text encoder. - """ - if text_encoder is None: - return None - - if text_encoder['type'] == 'dinotxt': - # Use DINO.txt pipeline: tokenize + encode + extract patch-aligned features - with torch.no_grad(): - # Tokenize with DINO.txt tokenizer - text_tokens = text_encoder['tokenizer'].tokenize([caption]).to(device) - - # Encode with DINO.txt model - dinotxt_features = text_encoder['model'].encode_text(text_tokens) - - # Extract patch-aligned text embeddings (dimensions 1024:) - # DINO.txt concatenates standard text features [0:1024] and patch-aligned features [1024:] - patch_aligned_features = dinotxt_features[:, 1024:] - - # Normalize the features to match DINO feature space - patch_aligned_features = F.normalize(patch_aligned_features, p=2, dim=-1) - return patch_aligned_features - - elif text_encoder['type'] == 'talk2dino': - # Use Talk2Dino pipeline: CLIP text encoding + Talk2Dino projection - with torch.no_grad(): - # Tokenize and encode with CLIP - text_tokens = clip.tokenize([caption]).to(device) - clip_text_features = text_encoder['clip_model'].encode_text(text_tokens) - - # Project through Talk2Dino to DINO space - dino_text_features = text_encoder['talk2dino'].project_clip_txt(clip_text_features) - - # Normalize the encoded text to match DINO feature space - dino_text_features = F.normalize(dino_text_features, p=2, dim=-1) - return dino_text_features - - elif text_encoder['type'] == 'clip': - # Use CLIP directly - with torch.no_grad(): - text_tokens = clip.tokenize([caption]).to(device) - clip_text_features = text_encoder['clip_model'].encode_text(text_tokens) - - # Normalize the features - clip_text_features = F.normalize(clip_text_features, p=2, dim=-1) - return clip_text_features - - else: - raise ValueError(f"Unknown text encoder type: {text_encoder['type']}") - - -def main(dino_model_type: str, resize_dim: int = 518, crop_dim: int = 518, - coco_images_dir: str = "/raid/datasets/coco/", captions_file: str = "/raid/datasets/coco/train_split_karpathy.json", - output_file: str = None, feature_type: str = "cls", extract_attention: bool = False, - patch_selection_criteria: str = "max_attention", text_encoder_path: str = None, text_encoder_config: str = None): - """ - Extract DINO features from COCO images for ClipCap training. - - Args: - feature_type: Type of features to extract - - "cls": CLS token features (default) - - "avg_patch": Mean pooled patch token features - - "avg_self_attn": Self-attention weighted patch token features - - "most_significant_patch": Single most important patch token - extract_attention: Whether to extract self-attention weights (required for avg_self_attn) - patch_selection_criteria: Criteria for selecting most significant patch - text_encoder_path: Path to text encoder for caption similarity - """ - device = torch.device('cuda:0') - dino_model_name = dino_model_type.replace('/', '_') - - # Determine model properties - num_global_tokens = 1 if "reg" not in dino_model_type else 5 - patch_size = 14 # DINOv2 uses 14x14 patches - num_patch_tokens = (crop_dim // patch_size) * (crop_dim // patch_size) - num_tokens = num_global_tokens + num_patch_tokens - - # Get embedding dimension based on model type - if 'vitl' in dino_model_type: - embed_dim = 1024 - num_attn_heads = 16 - elif 'vitb' in dino_model_type: - embed_dim = 768 - num_attn_heads = 12 - elif 'vits' in dino_model_type: - embed_dim = 384 - num_attn_heads = 6 - elif 'vitg' in dino_model_type: - embed_dim = 1536 - num_attn_heads = 24 - else: - raise ValueError(f"Unknown model type: {dino_model_type}") - - scale = (embed_dim // num_attn_heads) ** -0.5 - - # Set default output path if not specified - if output_file is None: - if feature_type == "cls": - feature_suffix = "" - elif feature_type == "most_significant_patch": - if patch_selection_criteria == "most_similar_to_caption" and text_encoder_path is not None: - # Determine text encoder type from path to create unique filename - if text_encoder_path.lower() in ['dinotxt', 'dino.txt']: - text_encoder_suffix = "_dinotxt" - elif text_encoder_path.endswith('.pth') or text_encoder_path.endswith('.pt'): - text_encoder_suffix = "_t2d" # Talk2Dino - else: - # CLIP or other models - text_encoder_suffix = "_clip" - feature_suffix = f"_{feature_type}_{patch_selection_criteria}{text_encoder_suffix}" - else: - feature_suffix = f"_{feature_type}_{patch_selection_criteria}" - else: - feature_suffix = f"_{feature_type}" - output_file = f"/raid/datasets/models_weights/clipcap/training-features/coco_karpathy_split_{dino_model_name}{feature_suffix}_train.pkl" - - # Create output directory if it doesn't exist - os.makedirs(os.path.dirname(output_file), exist_ok=True) - - # Load DINO model - print(f"Loading DINO model: {dino_model_type}") - print(f"Feature type: {feature_type}") - if feature_type == "most_significant_patch": - print(f"Patch selection criteria: {patch_selection_criteria}") - print(f"Model properties: embed_dim={embed_dim}, num_heads={num_attn_heads}, num_global_tokens={num_global_tokens}") - - if 'dinov2' in dino_model_type: - model_family = 'facebookresearch/dinov2' - dino_model = torch.hub.load(model_family, dino_model_type) - else: - raise ValueError(f"Unsupported DINO model type: {dino_model_type}") - - # Setup transforms for DINO - image_transforms = T.Compose([ - T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC), - T.CenterCrop(crop_dim), - T.ToTensor(), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ]) - - dino_model.eval() - dino_model.to(device) - - # Register hooks if we need attention or intermediate outputs - if feature_type == "avg_self_attn" or extract_attention or \ - (feature_type == "most_significant_patch" and patch_selection_criteria in ["max_attention", "most_similar_to_caption"]): - print("Registering hooks for attention extraction...") - dino_model.blocks[-1].attn.qkv.register_forward_hook(get_self_attention) - - if feature_type in ["avg_patch", "avg_self_attn", "most_significant_patch"]: - print("Registering hooks for intermediate output extraction...") - dino_model.blocks[-1].register_forward_hook(get_layer_n_output) - - # Load caption data - print(f"Loading captions from: {captions_file}") - with open(captions_file, 'r') as f: - data = json.load(f) - - # Handle different annotation formats - if isinstance(data, list): - # Original ClipCap format: list of dicts with 'image_id' and 'caption' - annotations = data - print(f"{len(annotations)} captions loaded from json (ClipCap format)") - elif isinstance(data, dict) and 'annotations' in data: - # Karpathy format: dict with 'annotations' key - annotations = data['annotations'] - print(f"{len(annotations)} captions loaded from json (Karpathy format)") - - # Create image ID to filename mapping for faster lookup - if 'images' in data: - image_id_to_filename = {img['id']: img['file_name'] for img in data['images']} - else: - image_id_to_filename = {} - else: - raise ValueError("Unsupported annotation format") - - # Load text encoder if needed for caption similarity - text_encoder = None - if feature_type == "most_significant_patch" and patch_selection_criteria == "most_similar_to_caption": - if text_encoder_path is None: - raise ValueError("text_encoder_path required for most_similar_to_caption criteria") - text_encoder = load_text_encoder(text_encoder_path, device, text_encoder_config) - print(f"Loaded text encoder from: {text_encoder_path}") - - all_embeddings = [] - all_captions = [] - - print(f"Processing images from: {coco_images_dir}") - print(f"Output will be saved to: {output_file}") - - for i, annotation in enumerate(tqdm(annotations)): - img_id = annotation["image_id"] - - # Determine filename based on format - if isinstance(data, list): - # Original format: construct filename from image_id - filename = os.path.join(coco_images_dir, "train2014", f"COCO_train2014_{int(img_id):012d}.jpg") - if not os.path.isfile(filename): - filename = os.path.join(coco_images_dir, "val2014", f"COCO_val2014_{int(img_id):012d}.jpg") - else: - # Karpathy format: use filename from images mapping or construct it - if img_id in image_id_to_filename: - if 'train' in image_id_to_filename[img_id]: - fold = "train2014" - else: - fold = "val2014" - filename = os.path.join(coco_images_dir, fold, image_id_to_filename[img_id]) - else: - # Fallback: try to construct filename - filename = os.path.join(coco_images_dir, "train2014", f"COCO_train2014_{int(img_id):012d}.jpg") - if not os.path.isfile(filename): - filename = os.path.join(coco_images_dir, "val2014", f"COCO_val2014_{int(img_id):012d}.jpg") - - if not os.path.isfile(filename): - print(f"Warning: Image not found: {filename}") - continue - - # Load and process image - try: - image = io.imread(filename) - if len(image.shape) == 2: # grayscale - image = Image.fromarray(image).convert('RGB') - else: - image = Image.fromarray(image) - except Exception as e: - print(f"Warning: Failed to load image {filename}: {e}") - continue - - # Apply DINO transforms - image_tensor = image_transforms(image).unsqueeze(0).to(device) - - with torch.no_grad(): - # Clear any previous stored data - global dino_layer_n_output, qkv_attention_out - dino_layer_n_output = None - qkv_attention_out = None - - # Extract DINO features - if feature_type == "cls": - # Standard CLS token extraction - features = dino_model(image_tensor) - # For DINOv2, the output is the CLS token by default - if len(features.shape) == 3: # If we get [batch, seq_len, dim] - features = features[:, 0, :] # Take CLS token - prefix = features.cpu() - else: - # For patch-based features, we need intermediate outputs - _ = dino_model(image_tensor) # Forward pass to trigger hooks - - if dino_layer_n_output is None: - raise RuntimeError("No intermediate output captured. Check hook registration.") - - # Transform to standard format - dino_outs = transform_to_standard_dino_out(dino_layer_n_output, dino_model, num_global_tokens) - - if feature_type == "avg_patch": - # Average of patch tokens (excluding global tokens) - prefix = dino_outs['x_norm_patchtokens'].mean(dim=1) # [B, D] - elif feature_type == "avg_self_attn": - # Self-attention weighted average of patch tokens - if qkv_attention_out is None: - raise RuntimeError("No attention output captured. Check hook registration.") - - # Process self-attention to get attention weights - batch_size = qkv_attention_out.shape[0] - self_attn = process_self_attention( - qkv_attention_out, - batch_size, - num_tokens, - num_attn_heads, - embed_dim, - scale, - num_global_tokens - ) - - # Compute attention-weighted average - prefix = (self_attn.unsqueeze(-1) * dino_outs['x_norm_patchtokens']).mean(dim=1) - elif feature_type == "most_significant_patch": - # Select single most significant patch based on criteria - self_attn = None - cls_token = None - caption_embedding = None - - # Prepare required inputs based on criteria - if patch_selection_criteria in ["max_attention", "most_similar_to_caption"]: - if qkv_attention_out is None: - raise RuntimeError("No attention output captured. Check hook registration.") - batch_size = qkv_attention_out.shape[0] - self_attn = process_self_attention( - qkv_attention_out, - batch_size, - num_tokens, - num_attn_heads, - embed_dim, - scale, - num_global_tokens - ) - - if patch_selection_criteria == "most_similar_to_cls": - cls_token = dino_outs['x_norm_clstoken'] - - if patch_selection_criteria == "most_similar_to_caption": - if text_encoder is not None: - caption_embedding = encode_caption(annotation["caption"], text_encoder, device) - - # Select the most significant patch - prefix = select_most_significant_patch( - dino_outs, - self_attn, - patch_selection_criteria, - cls_token=cls_token, - caption_embedding=caption_embedding - ) - else: - raise ValueError(f"Unknown feature type: {feature_type}") - - prefix = prefix.cpu() - - # Create annotation in ClipCap format for compatibility - caption_entry = { - "image_id": img_id, - "caption": annotation["caption"], - "clip_embedding": i # Index for the embedding - } - - all_embeddings.append(prefix) - all_captions.append(caption_entry) - - if (i + 1) % 10000 == 0: - # Create output directory if it doesn't exist - os.makedirs(os.path.dirname(output_file), exist_ok=True) - with open(output_file, 'wb') as f: - pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f) - - # Create output directory if it doesn't exist - os.makedirs(os.path.dirname(output_file), exist_ok=True) - with open(output_file, 'wb') as f: - pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f) - - print('Done') - print("%0d embeddings saved " % len(all_embeddings)) - print(f"Feature dimension: {all_embeddings[0].shape[-1]}") - return 0 - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='Extract DINO features from COCO images for ClipCap training') - parser.add_argument('--dino_model_type', default="dinov2_vitb14", - choices=('dinov2_vits14', 'dinov2_vitb14', 'dinov2_vitl14', 'dinov2_vitg14', - 'dinov2_vits14_reg', 'dinov2_vitb14_reg', 'dinov2_vitl14_reg', 'dinov2_vitg14_reg'), - help='DINO model type to use for feature extraction') - parser.add_argument('--feature_type', default="cls", - choices=('cls', 'avg_patch', 'avg_self_attn', 'most_significant_patch'), - help='Type of features to extract: cls (CLS token), avg_patch (mean pooled patches), avg_self_attn (attention-weighted patches), most_significant_patch (single most important patch)') - parser.add_argument('--patch_selection_criteria', default="max_attention", - choices=('max_attention', 'most_similar_to_cls', 'most_similar_to_caption', 'max_norm', 'centroid_distance'), - help='Criteria for selecting the most significant patch (only used with most_significant_patch feature_type)') - parser.add_argument('--text_encoder_path', type=str, default=None, - help='Path to text encoder for caption similarity. Supports: 1) "dinotxt" or "dino.txt" for DINO.txt model, 2) Talk2Dino weights (.pth/.pt) - will auto-find config, 3) CLIP model names (e.g., "ViT-B/32")') - parser.add_argument('--text_encoder_config', type=str, default=None, - help='Optional: explicit config path for Talk2Dino models (if not auto-found)') - parser.add_argument('--resize_dim', type=int, default=518, help='Resize dimension for images') - parser.add_argument('--crop_dim', type=int, default=518, help='Crop dimension for images') - parser.add_argument('--coco_images_dir', type=str, default="/raid/datasets/coco", - help='Path to COCO images directory (should contain train2014/ and val2014/ subdirs)') - parser.add_argument('--captions_file', type=str, default="/raid/datasets/coco/train_split_karpathy.json", - help='Path to COCO captions JSON file (supports both Karpathy and ClipCap formats)') - parser.add_argument('--output_file', type=str, default=None, - help='Output pickle file path (default: auto-generated based on model and feature type)') - parser.add_argument('--extract_attention', action='store_true', - help='Extract attention weights (automatically enabled for avg_self_attn feature type)') - - args = parser.parse_args() - - main(args.dino_model_type, args.resize_dim, args.crop_dim, - args.coco_images_dir, args.captions_file, args.output_file, - args.feature_type, args.extract_attention, - args.patch_selection_criteria, args.text_encoder_path, args.text_encoder_config) \ No newline at end of file diff --git a/src/clipcap/clipcap_parse_coco.py b/src/clipcap/clipcap_parse_coco.py deleted file mode 100644 index 8f9497d2884efaece87e4d86098dd090a6b75a52..0000000000000000000000000000000000000000 --- a/src/clipcap/clipcap_parse_coco.py +++ /dev/null @@ -1,51 +0,0 @@ -import torch -import skimage.io as io -import clip -from PIL import Image -import pickle -import json -import os -from tqdm import tqdm -import argparse - - -def main(clip_model_type: str): - device = torch.device('cuda:0') - clip_model_name = clip_model_type.replace('/', '_') - out_path = f"./data/coco/oscar_split_{clip_model_name}_train.pkl" - clip_model, preprocess = clip.load(clip_model_type, device=device, jit=False) - with open('./data/coco/annotations/train_caption.json', 'r') as f: - data = json.load(f) - print("%0d captions loaded from json " % len(data)) - all_embeddings = [] - all_captions = [] - for i in tqdm(range(len(data))): - d = data[i] - img_id = d["image_id"] - filename = f"./data/coco/train2014/COCO_train2014_{int(img_id):012d}.jpg" - if not os.path.isfile(filename): - filename = f"./data/coco/val2014/COCO_val2014_{int(img_id):012d}.jpg" - image = io.imread(filename) - image = preprocess(Image.fromarray(image)).unsqueeze(0).to(device) - with torch.no_grad(): - prefix = clip_model.encode_image(image).cpu() - d["clip_embedding"] = i - all_embeddings.append(prefix) - all_captions.append(d) - if (i + 1) % 10000 == 0: - with open(out_path, 'wb') as f: - pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f) - - with open(out_path, 'wb') as f: - pickle.dump({"clip_embedding": torch.cat(all_embeddings, dim=0), "captions": all_captions}, f) - - print('Done') - print("%0d embeddings saved " % len(all_embeddings)) - return 0 - - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--clip_model_type', default="ViT-B/32", choices=('RN50', 'RN101', 'RN50x4', 'ViT-B/32')) - args = parser.parse_args() - exit(main(args.clip_model_type)) \ No newline at end of file diff --git a/src/clipcap/entrypoint.py b/src/clipcap/entrypoint.py deleted file mode 100644 index ad039dd73bc8ea72e2aab2f3d6af8b563194ceb8..0000000000000000000000000000000000000000 --- a/src/clipcap/entrypoint.py +++ /dev/null @@ -1,564 +0,0 @@ -import torch -from torch import nn -import json -import os -from transformers import GPT2Tokenizer, GPT2LMHeadModel -from typing import List, Optional, Tuple, Union -from argparse import Namespace -from enum import Enum - -import torch.nn.functional as nnf - -class MappingType(Enum): - MLP = 'mlp' - Transformer = 'transformer' - - -class MLP(nn.Module): - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.model(x) - - def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh): - super(MLP, self).__init__() - layers = [] - for i in range(len(sizes) - 1): - layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias)) - if i < len(sizes) - 2: - layers.append(act()) - self.model = nn.Sequential(*layers) - - -class MlpTransformer(nn.Module): - def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.): - super().__init__() - out_d = out_d if out_d is not None else in_dim - self.fc1 = nn.Linear(in_dim, h_dim) - self.act = act - self.fc2 = nn.Linear(h_dim, out_d) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - -class MultiHeadAttention(nn.Module): - - def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.): - super().__init__() - self.num_heads = num_heads - head_dim = dim_self // num_heads - self.scale = head_dim ** -0.5 - self.to_queries = nn.Linear(dim_self, dim_self, bias=bias) - self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias) - self.project = nn.Linear(dim_self, dim_self) - self.dropout = nn.Dropout(dropout) - - def forward(self, x, y=None, mask=None): - y = y if y is not None else x - b, n, c = x.shape - _, m, d = y.shape - # b n h dh - queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads) - # b m 2 h dh - keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads) - keys, values = keys_values[:, :, 0], keys_values[:, :, 1] - attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale - if mask is not None: - if mask.dim() == 2: - mask = mask.unsqueeze(1) - attention = attention.masked_fill(mask.unsqueeze(3), float("-inf")) - attention = attention.softmax(dim=2) - out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c) - out = self.project(out) - return out, attention - - -class TransformerLayer(nn.Module): - - def forward_with_attention(self, x, y=None, mask=None): - x_, attention = self.attn(self.norm1(x), y, mask) - x = x + x_ - x = x + self.mlp(self.norm2(x)) - return x, attention - - def forward(self, x, y=None, mask=None): - x = x + self.attn(self.norm1(x), y, mask)[0] - x = x + self.mlp(self.norm2(x)) - return x - - def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu, - norm_layer: nn.Module = nn.LayerNorm): - super().__init__() - self.norm1 = norm_layer(dim_self) - self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout) - self.norm2 = norm_layer(dim_self) - self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout) - - -class Transformer(nn.Module): - - def forward_with_attention(self, x, y=None, mask=None): - attentions = [] - for layer in self.layers: - x, att = layer.forward_with_attention(x, y, mask) - attentions.append(att) - return x, attentions - - def forward(self, x, y=None, mask=None): - for i, layer in enumerate(self.layers): - if i % 2 == 0 and self.enc_dec: # cross - x = layer(x, y) - elif self.enc_dec: # self - x = layer(x, x, mask) - else: # self or cross - x = layer(x, y, mask) - return x - - def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None, - mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False): - super(Transformer, self).__init__() - dim_ref = dim_ref if dim_ref is not None else dim_self - self.enc_dec = enc_dec - if enc_dec: - num_layers = num_layers * 2 - layers = [] - for i in range(num_layers): - if i % 2 == 0 and enc_dec: # cross - layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) - elif enc_dec: # self - layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) - else: # self or cross - layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer)) - self.layers = nn.ModuleList(layers) - -class TransformerMapper(nn.Module): - def forward(self, x): - x = self.linear(x).view(x.shape[0], self.clip_length, -1) - prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape) - prefix = torch.cat((x, prefix), dim=1) - out = self.transformer(prefix)[:, self.clip_length:] - return out - - def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8): - super(TransformerMapper, self).__init__() - self.clip_length = clip_length - self.transformer = Transformer(dim_embedding, 8, num_layers) #nn.Transformer(d_model=dim_embedding, nhead=8, num_encoder_layers=num_layers) - self.linear = nn.Linear(dim_clip, clip_length * dim_embedding) - self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True) - - -class ClipCaptionModel(nn.Module): - - def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor: - return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device) - - def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None, - labels: Optional[torch.Tensor] = None): - embedding_text = self.gpt.transformer.wte(tokens) - prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size) - embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1) - if labels is not None: - dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device) - labels = torch.cat((dummy_token, tokens), dim=1) - out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask) - return out - - def __init__(self, prefix_length: int, clip_length: Optional[int] = None, prefix_size: int = 512, - num_layers: int = 8, mapping_type: MappingType = MappingType.MLP): - super(ClipCaptionModel, self).__init__() - self.prefix_length = prefix_length - self.gpt = GPT2LMHeadModel.from_pretrained('gpt2') - self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1] - if mapping_type == MappingType.MLP: - self.clip_project = MLP((prefix_size, (self.gpt_embedding_size * prefix_length) // 2, - self.gpt_embedding_size * prefix_length)) - else: - self.clip_project = TransformerMapper(prefix_size, self.gpt_embedding_size, prefix_length, - clip_length, num_layers) - - -class ClipCaptionPrefix(ClipCaptionModel): - - def parameters(self, recurse: bool = True): - return self.clip_project.parameters() - - def train(self, mode: bool = True): - super(ClipCaptionPrefix, self).train(mode) - self.gpt.eval() - return self - - -def generate_batched( - model, - tokenizer, - prefix_embeds, - entry_length=67, - top_p=0.8, - temperature=1.0, - stop_token: str = '.', -): - """ - Batched text generation for ClipCap models. - - Args: - model: ClipCap model - tokenizer: GPT2 tokenizer - prefix_embeds: (batch_size, prefix_length, embedding_dim) - prefix embeddings - entry_length: Maximum sequence length to generate - top_p: Nucleus sampling parameter - temperature: Sampling temperature - stop_token: Token to stop generation - - Returns: - List[str]: Generated captions for each item in batch - """ - model.eval() - device = next(model.parameters()).device - batch_size = prefix_embeds.shape[0] - - # Initialize - stop_token_index = tokenizer.encode(stop_token)[0] - filter_value = -float("Inf") - - # Track which sequences are still generating - active_sequences = torch.ones(batch_size, dtype=torch.bool, device=device) - - # Initialize token sequences - start with None - tokens = None - generated_embeds = prefix_embeds # Start with prefix embeddings - - with torch.no_grad(): - for step in range(entry_length): - # Forward pass for all active sequences - outputs = model.gpt(inputs_embeds=generated_embeds) - logits = outputs.logits[:, -1, :] # Get logits for last token: (batch_size, vocab_size) - - # Apply temperature - logits = logits / (temperature if temperature > 0 else 1.0) - - # Apply nucleus sampling for each sequence in batch - for i in range(batch_size): - if not active_sequences[i]: - continue - - # Sort logits for this sequence - sorted_logits, sorted_indices = torch.sort(logits[i], descending=True) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - - # Find indices to remove (above top_p threshold) - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].clone() - sorted_indices_to_remove[0] = 0 - - # Set logits to -inf for tokens to remove - indices_to_remove = sorted_indices[sorted_indices_to_remove] - logits[i, indices_to_remove] = filter_value - - # Clamp logits to avoid extreme values - logits = torch.clamp(logits, min=-1e9, max=1e9) # keep values bounded - # Sample next tokens for all sequences - probs = torch.softmax(logits, dim=-1) - - # if some sequences probs tensor contains NaNs (e.g. all logits were -inf), set stop_token_index prob to 1 - for i in range(batch_size): - if torch.isnan(probs[i]).all(): #if not torch.isfinite(probs[i]).any() or probs[i].sum() == 0: - probs[i] = torch.zeros_like(probs[i]) - probs[i, stop_token_index] = 1.0 - - next_tokens = torch.multinomial(probs, num_samples=1) # (batch_size, 1) - - # Get embeddings for next tokens - next_token_embeds = model.gpt.transformer.wte(next_tokens) # (batch_size, 1, embed_dim) - - # Update token sequences - if tokens is None: - tokens = next_tokens - else: - tokens = torch.cat((tokens, next_tokens), dim=1) - - # Update generated embeddings - generated_embeds = torch.cat((generated_embeds, next_token_embeds), dim=1) - - # Check for stop tokens and update active sequences - for i in range(batch_size): - if active_sequences[i] and next_tokens[i].item() == stop_token_index: - active_sequences[i] = False - - # If all sequences have stopped, break early - if not active_sequences.any(): - break - - # Decode all sequences - captions = [] - for i in range(batch_size): - if tokens is not None: - token_list = tokens[i].cpu().numpy().tolist() - # Remove padding and decode - caption = tokenizer.decode(token_list) - # Clean up the caption - caption = caption.split(stop_token)[0] + stop_token - captions.append(caption) - else: - captions.append("") - - return captions - - -def generate2( - model, - tokenizer, - tokens=None, - prompt=None, - embed=None, - entry_count=1, - entry_length=67, # maximum number of words - top_p=0.8, - temperature=1., - stop_token: str = '.', -): - """ - Legacy single-sequence generation function. - For new code, use generate_batched instead. - """ - model.eval() - generated_num = 0 - generated_list = [] - stop_token_index = tokenizer.encode(stop_token)[0] - filter_value = -float("Inf") - device = next(model.parameters()).device - - with torch.no_grad(): - - for entry_idx in range(entry_count): - if embed is not None: - generated = embed - else: - if tokens is None: - tokens = torch.tensor(tokenizer.encode(prompt)) - tokens = tokens.unsqueeze(0).to(device) - - generated = model.gpt.transformer.wte(tokens) - - for i in range(entry_length): - - outputs = model.gpt(inputs_embeds=generated) - logits = outputs.logits - logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ - ..., :-1 - ].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = sorted_indices[sorted_indices_to_remove] - logits[:, indices_to_remove] = filter_value - next_token = torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1) - next_token_embed = model.gpt.transformer.wte(next_token) - if tokens is None: - tokens = next_token - else: - tokens = torch.cat((tokens, next_token), dim=1) - generated = torch.cat((generated, next_token_embed), dim=1) - if stop_token_index == next_token.item(): - break - - output_list = list(tokens.squeeze().cpu().numpy()) - output_text = tokenizer.decode(output_list) - generated_list.append(output_text) - - return generated_list[0] - - -class ClipCapModel(torch.nn.Module): - """ - ClipCap integration for the Patchioner class. - """ - - def __init__(self, args, device, dino_feature_dim=768): - super(ClipCapModel, self).__init__() - args_dict = args.copy() - self.args = args = self.load_config(args) - self.device = device - self.dino_feature_dim = dino_feature_dim - - # Initialize tokenizer - self.tokenizer = GPT2Tokenizer.from_pretrained(args.language_model) - if self.tokenizer.pad_token_id is None: - self.tokenizer.pad_token_id = self.tokenizer.eos_token_id - - # Determine mapping type - mapping_type = MappingType.MLP if args.mapping_type.lower() == 'mlp' else MappingType.Transformer - - # Initialize model with DINO feature dimensions - if args.only_prefix: - self.model = ClipCaptionPrefix( - prefix_length=args.prefix_length, - clip_length=args.clip_length, - prefix_size=dino_feature_dim, - num_layers=args.num_layers, - mapping_type=mapping_type - ) - else: - self.model = ClipCaptionModel( - prefix_length=args.prefix_length, - clip_length=args.clip_length, - prefix_size=dino_feature_dim, - num_layers=args.num_layers, - mapping_type=mapping_type - ) - - # Load trained weights - print(f"Loading ClipCap weights from: {args.weight_path}") - checkpoint = torch.load(args.weight_path, map_location=device) - self.model.load_state_dict(checkpoint, strict=False) - self.model.to(device) - self.model.eval() - - defaults = { - "language_model": "gpt2", - "prefix_length": 10, - "clip_length": 10, - "num_layers": 8, - "mapping_type": "mlp", - "only_prefix": True, - "temperature": 1.0, - "top_p": 0.8, - "entry_length": 67, - "stop_token": ".", - "use_batched_generation": True, # Use batched generation by default - "normalize_prefix": False, # Whether to L2 normalize the input features - "weight_path": "/raid/datasets/models_weights/clipcap/training-features/clipcap_dino_vitb14_len10_mlp.pt" - } - - def load_config(self, args_dict: dict) -> Namespace: - def dict_to_namespace(d): - if isinstance(d, dict): - return Namespace(**{k: dict_to_namespace(v) for k, v in d.items()}) - return d - - # Apply defaults - for key, value in self.defaults.items(): - if isinstance(value, dict): - for sub_key, sub_value in value.items(): - args_dict.setdefault(key, {}).setdefault(sub_key, sub_value) - else: - args_dict.setdefault(key, value) - - args = dict_to_namespace(args_dict) - return args - - def forward(self, dino_features, compute_scores: bool = False) -> List[str]: - """ - DINO Features: (batch_size, dino_feature_dim) - - returns: List[str] of generated captions - """ - if self.args.use_batched_generation: - return self.forward_batched(dino_features, compute_scores) - else: - return self.forward_sequential(dino_features, compute_scores) - - def forward_batched(self, dino_features, compute_scores: bool = False) -> List[str]: - """ - Efficient batched generation for multiple sequences. - """ - batch_size = dino_features.shape[0] - - # Apply normalization if specified (to match training) - if self.args.normalize_prefix: - dino_features = dino_features / dino_features.norm(dim=-1, keepdim=True) - - # Generate prefix embeddings for entire batch - with torch.no_grad(): - prefix_embeds = self.model.clip_project(dino_features).view( - batch_size, self.args.prefix_length, -1 - ) - - # Generate captions for entire batch - captions = generate_batched( - model=self.model, - tokenizer=self.tokenizer, - prefix_embeds=prefix_embeds, - entry_length=self.args.entry_length, - temperature=self.args.temperature, - top_p=self.args.top_p, - stop_token=self.args.stop_token - ) - - if compute_scores: - # Compute perplexity scores for generated captions - scores = self.compute_perplexity_scores(captions) - return captions, scores - else: - return captions - - def forward_sequential(self, dino_features, compute_scores: bool = False) -> List[str]: - """ - Sequential generation for backward compatibility or debugging. - """ - batch_size = dino_features.shape[0] - captions = [] - scores = [] - - # Process each feature in the batch sequentially - for i in range(batch_size): - feature = dino_features[i:i+1] # Keep batch dimension - - # Apply normalization if enabled - if self.args.normalize_prefix: - feature = feature / feature.norm(dim=-1, keepdim=True) - - # Generate prefix embeddings - with torch.no_grad(): - prefix_embed = self.model.clip_project(feature).view(1, self.args.prefix_length, -1) - - # Generate caption using legacy function - caption = generate2( - model=self.model, - tokenizer=self.tokenizer, - embed=prefix_embed, - entry_length=self.args.entry_length, - temperature=self.args.temperature, - top_p=self.args.top_p, - stop_token=self.args.stop_token - ) - - captions.append(caption) - if compute_scores: - # Compute perplexity for this caption - score = self.compute_perplexity_scores([caption])[0] - scores.append(score) - - return captions if not compute_scores else (captions, scores) - - def compute_perplexity_scores(self, captions: List[str]) -> List[float]: - """ - Compute perplexity scores for generated captions. - """ - scores = [] - self.model.eval() - - with torch.no_grad(): - for caption in captions: - try: - # Tokenize caption - tokens = self.tokenizer.encode(caption, return_tensors='pt').to(self.device) - - # Compute loss (negative log-likelihood) - outputs = self.model.gpt(input_ids=tokens, labels=tokens) - loss = outputs.loss - - # Convert to perplexity (lower is better, but we'll use 1/perplexity as score) - perplexity = torch.exp(loss).item() - score = 1.0 / perplexity if perplexity > 0 else 1.0 - scores.append(score) - except: - # Fallback score if computation fails - scores.append(1.0) - - return scores \ No newline at end of file diff --git a/src/clipcap/predict.py b/src/clipcap/predict.py deleted file mode 100644 index a627250c27b24e32cb71901463a4bb2d664bdbcc..0000000000000000000000000000000000000000 --- a/src/clipcap/predict.py +++ /dev/null @@ -1,302 +0,0 @@ -# Prediction interface for Cog โš™๏ธ -# Reference: https://github.com/replicate/cog/blob/main/docs/python.md - -import clip -import os -from torch import nn -import numpy as np -import torch -import torch.nn.functional as nnf -import sys -from typing import Tuple, List, Union, Optional -from transformers import ( - GPT2Tokenizer, - GPT2LMHeadModel, - AdamW, - get_linear_schedule_with_warmup, -) -import skimage.io as io -import PIL.Image - -import cog - -# import torch - -N = type(None) -V = np.array -ARRAY = np.ndarray -ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]] -VS = Union[Tuple[V, ...], List[V]] -VN = Union[V, N] -VNS = Union[VS, N] -T = torch.Tensor -TS = Union[Tuple[T, ...], List[T]] -TN = Optional[T] -TNS = Union[Tuple[TN, ...], List[TN]] -TSN = Optional[TS] -TA = Union[T, ARRAY] - -WEIGHTS_PATHS = { - "coco": "coco_weights.pt", - "conceptual-captions": "conceptual_weights.pt", -} - -D = torch.device -CPU = torch.device("cpu") - - -class Predictor(cog.Predictor): - def setup(self): - """Load the model into memory to make running multiple predictions efficient""" - self.device = torch.device("cuda") - self.clip_model, self.preprocess = clip.load( - "ViT-B/32", device=self.device, jit=False - ) - self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - - self.models = {} - self.prefix_length = 10 - for key, weights_path in WEIGHTS_PATHS.items(): - model = ClipCaptionModel(self.prefix_length) - model.load_state_dict(torch.load(weights_path, map_location=CPU)) - model = model.eval() - model = model.to(self.device) - self.models[key] = model - - @cog.input("image", type=cog.Path, help="Input image") - @cog.input( - "model", - type=str, - options=WEIGHTS_PATHS.keys(), - default="coco", - help="Model to use", - ) - @cog.input( - "use_beam_search", - type=bool, - default=False, - help="Whether to apply beam search to generate the output text", - ) - def predict(self, image, model, use_beam_search): - """Run a single prediction on the model""" - image = io.imread(image) - model = self.models[model] - pil_image = PIL.Image.fromarray(image) - image = self.preprocess(pil_image).unsqueeze(0).to(self.device) - with torch.no_grad(): - prefix = self.clip_model.encode_image(image).to( - self.device, dtype=torch.float32 - ) - prefix_embed = model.clip_project(prefix).reshape(1, self.prefix_length, -1) - if use_beam_search: - return generate_beam(model, self.tokenizer, embed=prefix_embed)[0] - else: - return generate2(model, self.tokenizer, embed=prefix_embed) - - -class MLP(nn.Module): - def forward(self, x: T) -> T: - return self.model(x) - - def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh): - super(MLP, self).__init__() - layers = [] - for i in range(len(sizes) - 1): - layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias)) - if i < len(sizes) - 2: - layers.append(act()) - self.model = nn.Sequential(*layers) - - -class ClipCaptionModel(nn.Module): - - # @functools.lru_cache #FIXME - def get_dummy_token(self, batch_size: int, device: D) -> T: - return torch.zeros( - batch_size, self.prefix_length, dtype=torch.int64, device=device - ) - - def forward( - self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None - ): - embedding_text = self.gpt.transformer.wte(tokens) - prefix_projections = self.clip_project(prefix).view( - -1, self.prefix_length, self.gpt_embedding_size - ) - # print(embedding_text.size()) #torch.Size([5, 67, 768]) - # print(prefix_projections.size()) #torch.Size([5, 1, 768]) - embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1) - if labels is not None: - dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device) - labels = torch.cat((dummy_token, tokens), dim=1) - out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask) - return out - - def __init__(self, prefix_length: int, prefix_size: int = 512): - super(ClipCaptionModel, self).__init__() - self.prefix_length = prefix_length - self.gpt = GPT2LMHeadModel.from_pretrained("gpt2") - self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1] - if prefix_length > 10: # not enough memory - self.clip_project = nn.Linear( - prefix_size, self.gpt_embedding_size * prefix_length - ) - else: - self.clip_project = MLP( - ( - prefix_size, - (self.gpt_embedding_size * prefix_length) // 2, - self.gpt_embedding_size * prefix_length, - ) - ) - - -class ClipCaptionPrefix(ClipCaptionModel): - def parameters(self, recurse: bool = True): - return self.clip_project.parameters() - - def train(self, mode: bool = True): - super(ClipCaptionPrefix, self).train(mode) - self.gpt.eval() - return self - - -def generate_beam( - model, - tokenizer, - beam_size: int = 5, - prompt=None, - embed=None, - entry_length=67, - temperature=1.0, - stop_token: str = ".", -): - - model.eval() - stop_token_index = tokenizer.encode(stop_token)[0] - tokens = None - scores = None - device = next(model.parameters()).device - seq_lengths = torch.ones(beam_size, device=device) - is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool) - with torch.no_grad(): - if embed is not None: - generated = embed - else: - if tokens is None: - tokens = torch.tensor(tokenizer.encode(prompt)) - tokens = tokens.unsqueeze(0).to(device) - generated = model.gpt.transformer.wte(tokens) - for i in range(entry_length): - outputs = model.gpt(inputs_embeds=generated) - logits = outputs.logits - logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) - logits = logits.softmax(-1).log() - if scores is None: - scores, next_tokens = logits.topk(beam_size, -1) - generated = generated.expand(beam_size, *generated.shape[1:]) - next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) - if tokens is None: - tokens = next_tokens - else: - tokens = tokens.expand(beam_size, *tokens.shape[1:]) - tokens = torch.cat((tokens, next_tokens), dim=1) - else: - logits[is_stopped] = -float(np.inf) - logits[is_stopped, 0] = 0 - scores_sum = scores[:, None] + logits - seq_lengths[~is_stopped] += 1 - scores_sum_average = scores_sum / seq_lengths[:, None] - scores_sum_average, next_tokens = scores_sum_average.view(-1).topk( - beam_size, -1 - ) - next_tokens_source = next_tokens // scores_sum.shape[1] - seq_lengths = seq_lengths[next_tokens_source] - next_tokens = next_tokens % scores_sum.shape[1] - next_tokens = next_tokens.unsqueeze(1) - tokens = tokens[next_tokens_source] - tokens = torch.cat((tokens, next_tokens), dim=1) - generated = generated[next_tokens_source] - scores = scores_sum_average * seq_lengths - is_stopped = is_stopped[next_tokens_source] - next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view( - generated.shape[0], 1, -1 - ) - generated = torch.cat((generated, next_token_embed), dim=1) - is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze() - if is_stopped.all(): - break - scores = scores / seq_lengths - output_list = tokens.cpu().numpy() - output_texts = [ - tokenizer.decode(output[: int(length)]) - for output, length in zip(output_list, seq_lengths) - ] - order = scores.argsort(descending=True) - output_texts = [output_texts[i] for i in order] - return output_texts - - -def generate2( - model, - tokenizer, - tokens=None, - prompt=None, - embed=None, - entry_count=1, - entry_length=67, # maximum number of words - top_p=0.8, - temperature=1.0, - stop_token: str = ".", -): - model.eval() - generated_num = 0 - generated_list = [] - stop_token_index = tokenizer.encode(stop_token)[0] - filter_value = -float("Inf") - device = next(model.parameters()).device - - with torch.no_grad(): - - for entry_idx in range(entry_count): - if embed is not None: - generated = embed - else: - if tokens is None: - tokens = torch.tensor(tokenizer.encode(prompt)) - tokens = tokens.unsqueeze(0).to(device) - - generated = model.gpt.transformer.wte(tokens) - - for i in range(entry_length): - - outputs = model.gpt(inputs_embeds=generated) - logits = outputs.logits - logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) - sorted_logits, sorted_indices = torch.sort(logits, descending=True) - cumulative_probs = torch.cumsum( - nnf.softmax(sorted_logits, dim=-1), dim=-1 - ) - sorted_indices_to_remove = cumulative_probs > top_p - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ - ..., :-1 - ].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = sorted_indices[sorted_indices_to_remove] - logits[:, indices_to_remove] = filter_value - next_token = torch.argmax(logits, -1).unsqueeze(0) - next_token_embed = model.gpt.transformer.wte(next_token) - if tokens is None: - tokens = next_token - else: - tokens = torch.cat((tokens, next_token), dim=1) - generated = torch.cat((generated, next_token_embed), dim=1) - if stop_token_index == next_token.item(): - break - - output_list = list(tokens.squeeze().cpu().numpy()) - output_text = tokenizer.decode(output_list) - generated_list.append(output_text) - - return generated_list[0] \ No newline at end of file diff --git a/src/dataset.py b/src/dataset.py deleted file mode 100644 index 9305cd140db97ab24bb61b7db70bc585217b43c9..0000000000000000000000000000000000000000 --- a/src/dataset.py +++ /dev/null @@ -1,94 +0,0 @@ -import torch -from torch.utils.data import Dataset - -from tqdm import tqdm -import json -from typing import Tuple -import clip -import random -import json -import random -from tqdm import tqdm - -class ClipCocoDataset(Dataset): - - def __len__(self) -> int: - return len(self.captions_tokens) - - def pad_tokens(self, item: int): - tokens = self.captions_tokens[item] - padding = self.max_seq_len - tokens.shape[0] - if padding > 0: - tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64))) - elif padding < 0: - tokens = tokens[:self.max_seq_len] - return tokens - - - def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]: - # tokens = self.captions_tokens[item] - - clip_tokens = self.pad_tokens(item) - if self.feats is None: - clip_tokens_77 = self.captions_tokens[item] - return clip_tokens, clip_tokens_77 - else: - return clip_tokens, self.feats[item] - - def __init__(self, data_path: str, clip_model=None, talk2dino=None, use_dino_feats=False, tokenizer=None): - if tokenizer is not None: - self.clip_tokenizer = tokenizer - else: - print(f"Using default tokenizer") - self.clip_tokenizer = clip.tokenize - self.prefix_length = 10 - self.max_seq_len = 20 - self.feats = None - - if clip_model is not None: - device = next(clip_model.parameters()).device - print("Pre-extracting features...") - - if not use_dino_feats: - with open(data_path, 'r') as f: - self.captions = [ann['caption'] for ann in json.load(f)['annotations']] - else: - data = torch.load(data_path) - self.captions = [ann['caption'] for ann in data['annotations']] - self.feats = [ann['features'] for ann in data['annotations']] - - - random.shuffle(self.captions) - self.captions_tokens = [] - - batch_size = 64 - batched_captions = [self.captions[i:i + batch_size] for i in range(0, len(self.captions), batch_size)] - - for batch in tqdm(batched_captions): - try: - # Tokenize the batch of captions - batch_tokens = [torch.tensor(self.clip_tokenizer(caption)[0], dtype=torch.int64) for caption in batch] - - # Pad tokens to the same length for batching - batch_tokens_padded = torch.nn.utils.rnn.pad_sequence(batch_tokens, batch_first=True) - self.captions_tokens.extend(batch_tokens) - - if clip_model is not None: - with torch.no_grad(): - # Encode the text batch - feats = clip_model.encode_text(batch_tokens_padded.to(device)) - - if talk2dino is not None: - # Project to desired feature space - feats = talk2dino.project_clip_txt(feats).to('cpu') - - # Concatenate features - if self.feats is None: - self.feats = feats - else: - self.feats = torch.cat((self.feats, feats)) - except Exception as e: - print(f"Error processing batch: {e}") - print(len(self.captions_tokens)) - - diff --git a/src/datasetMix.py b/src/datasetMix.py deleted file mode 100644 index c86da8c71d799ba73d11ca0b85435a828320ec2f..0000000000000000000000000000000000000000 --- a/src/datasetMix.py +++ /dev/null @@ -1,153 +0,0 @@ -import torch -from torch.utils.data import Dataset - -from tqdm import tqdm -import json -from typing import Tuple -import clip -import random -import json -import random -from tqdm import tqdm - -from pycocotools.coco import COCO - -class ClipCocoDatasetMix(Dataset): - - def __len__(self) -> int: - return len(self.image_index_list) - - def _pad_tokens(self, tokens: torch.Tensor) -> torch.Tensor: - padding = self.max_seq_len - tokens.shape[0] - if padding > 0: - tokens = torch.cat((tokens, torch.zeros(padding, dtype=torch.int64))) - elif padding < 0: - tokens = tokens[:self.max_seq_len] - return tokens - - - def __getitem__(self, item: int) -> Tuple[torch.Tensor, ...]: - - # get the image index for the item - img_idx = self.image_index_list[item] - # get the caption index for that image - first_caption_idx = self.image_index_list.index(img_idx) - - # the caption index is the item - the first caption index - caption_idx = item - first_caption_idx - - # how many captions are there for that image? - num_captions = len(self.captions_list_of_lists[img_idx]) - try: - tokens = self.captions_tokens_list_of_lists[img_idx][caption_idx] #self.captions_list_of_lists[img_idx][caption_idx] - except IndexError: - print(f"{len(self.captions_tokens_list_of_lists)= } - {len(self.captions_tokens_list_of_lists[img_idx])= }") - print(f"IndexError: {img_idx}, {caption_idx}, {num_captions}") - raise - padded_tokens = self._pad_tokens(tokens) - - feats_same_img = self.feats[img_idx][random.choice(range(num_captions))] - - if self.feats is None or len(self.feats) == 0: - raise Exception("Precomputed features required") - else: - return padded_tokens, feats_same_img - - def __init__(self, data_path: str, clip_model=None, talk2dino=None, use_precomputed_feats=False, tokenizer=None): - - batch_size = 64 - self.max_seq_len = 20 - - if use_precomputed_feats: - raise Exception("Precomputed features not supported") - - if tokenizer is not None: - self.clip_tokenizer = tokenizer - else: - print(f"Using default tokenizer") - self.clip_tokenizer = clip.tokenize - - coco_data = COCO(data_path) - # I want to load the captions from the json file in a list of lists, - # where each list contains the captions for a single image - - self.captions_list_of_lists = [] - - self.image_index_list = [] - - max_seq_len = 20 - - for img_idx, (img_id, image) in enumerate(list(coco_data.imgs.items())): - # get the captions for that image - captions = coco_data.imgToAnns[img_id] - # get the texts of the captions - captions = [cap['caption'] for cap in captions] #[coco_data.anns[cap]['caption'] for cap in captions] - self.captions_list_of_lists.append(captions) - self.image_index_list.append([img_idx] * len(captions)) - - #max_seq_len = max(max_seq_len, max([len(caption) for caption in captions])) - - self.max_seq_len = max_seq_len - print(f"Computed Max seq len: {max_seq_len}") - - if clip_model is not None: - device = next(clip_model.parameters()).device - print("Pre-extracting features...") - - #random.shuffle(self.captions_list_of_lists) - # should shuffle in the same way self.image_index_list and self.captions_list_of_lists - # Combine captions and image indices into a list of pairs - combined = list(zip(self.captions_list_of_lists, self.image_index_list, range(len(self.captions_list_of_lists)))) - - # Shuffle them together - random.shuffle(combined) - - # Unzip the shuffled pairs back into two separate lists - self.captions_list_of_lists, self.image_index_list, img_idxes_shuffled = zip(*combined) - # Convert back to lists (zip returns tuples) - self.captions_list_of_lists = list(self.captions_list_of_lists) - self.image_index_list = list(self.image_index_list) - img_idxes_shuffled = list(img_idxes_shuffled) - - # self.image_index_list is a list of lists, where each list contains the image index for each caption, - # so we need to flatten it - self.image_index_list = [img_idxes_shuffled.index(item) for sublist in self.image_index_list for item in sublist] - - - self.captions_tokens_list_of_lists = [] - self.feats = [] # feats will be a list of tensors, each tensor will be (num_captions, embedding_dimension) - #ignore. # feats shape will be (num_images, num_captions, embedding_dimension) - - #batched_captions = [self.captions[i:i + batch_size] for i in range(0, len(self.captions), batch_size)] - - for captions_list in tqdm(self.captions_list_of_lists, dynamic_ncols=True): - try: - # Tokenize the batch of captions - batch_tokens = [torch.tensor(self.clip_tokenizer(caption)[0], dtype=torch.int64) for caption in captions_list] - - # Pad tokens to the same length for batching - batch_tokens_padded = torch.nn.utils.rnn.pad_sequence(batch_tokens, batch_first=True) - self.captions_tokens_list_of_lists.append(batch_tokens) - - # alternative: - # tokens = self.clip_tokenizer(captions_list, truncate=True).to(device) # shape: (num_captions, context_length) - - - if clip_model is not None: - with torch.no_grad(): - # Encode the text batch - feats = clip_model.encode_text(batch_tokens_padded.to(device)) - - if talk2dino is not None: - # Project to desired feature space - feats = talk2dino.project_clip_txt(feats).to('cpu') - - self.feats.append(feats.cpu()) # store (num_captions, embed_dim) for each image - - except Exception as e: - print(f"Error processing batch: {e}") - - print(f"Dataset loaded with {len(self.captions_list_of_lists)} images") - print(f"Max seq len: {max_seq_len}") - print(f"Number of captions: {len(self.image_index_list)}") - diff --git a/src/decap/decap.py b/src/decap/decap.py deleted file mode 100755 index b7f7e594bd86ca07545c5a72c552c52078d0e58e..0000000000000000000000000000000000000000 --- a/src/decap/decap.py +++ /dev/null @@ -1,193 +0,0 @@ -import os -from torch import nn -import numpy as np -import torch -import torch.nn.functional as nnf -import sys -from typing import Tuple, List, Union, Optional -from tqdm import tqdm, trange -import pickle -import PIL.Image as Image -import json -import random -import sys -import clip -import PIL -import random - -from torch.utils.data import Dataset, DataLoader -from enum import Enum -from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup -from tqdm import tqdm -import os -import pickle -import sys -import argparse -import json -from typing import Tuple, Optional, Union - -import os -from dotenv import load_dotenv - -load_dotenv() - - -DECAP_DECODER_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "decoder_config.pkl") -DECAP_COCO_WEIGHTS_PATH = None#'../../thesis-data/decap/coco_model/coco_prefix-009.pt' - -class MappingType(Enum): - MLP = 'mlp' - Transformer = 'transformer' - - -class MLP(nn.Module): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.model(x) - - def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh): - super(MLP, self).__init__() - layers = [] - for i in range(len(sizes) - 1): - layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias)) - if i < len(sizes) - 2: - layers.append(act()) - self.model = nn.Sequential(*layers) - - -class DeCap(nn.Module): - - def __init__(self,prefix_size: int = 512): - super(DeCap, self).__init__() - # decoder: 4 layers transformer with 4 attention heads - # the decoder is not pretrained - with open(DECAP_DECODER_CONFIG_PATH,'rb') as f: - config = pickle.load(f) - self.decoder = GPT2LMHeadModel(config) - self.embedding_size = self.decoder.transformer.wte.weight.shape[1] - self.clip_project = MLP((prefix_size,self.embedding_size)) - - def forward(self, clip_features,tokens): - embedding_text = self.decoder.transformer.wte(tokens) - embedding_clip = self.clip_project(clip_features) - embedding_clip = embedding_clip.reshape(-1,1,self.embedding_size) - embedding_cat = torch.cat([embedding_clip,embedding_text],dim=1) - out = self.decoder(inputs_embeds=embedding_cat) - return out - -from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer -_Tokenizer = _Tokenizer() - -def Decoding(model,clip_features): - model.eval() - embedding_cat = model.clip_project(clip_features).reshape(1,1,-1) - entry_length = 30 - temperature = 1 - tokens = None - for i in range(entry_length): - # print(location_token.shape) - outputs = model.decoder(inputs_embeds=embedding_cat) - - logits = outputs.logits - logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) - logits_max = logits.max() - logits = torch.nn.functional.softmax(logits, -1) - next_token = torch.argmax(logits, -1).unsqueeze(0) - next_token_embed = model.decoder.transformer.wte(next_token) - - if tokens is None: - tokens = next_token - - else: - tokens = torch.cat((tokens, next_token), dim=1) - if next_token.item()==49407: - break - embedding_cat = torch.cat((embedding_cat, next_token_embed), dim=1) - try: - output_list = list(tokens.squeeze().cpu().numpy()) - output = _Tokenizer.decode(output_list) - except: - output = 'None' - return output - -def decoding_batched(model, clip_features, compute_scores : bool = False, decoding_method : callable = None, return_start_end_tokens : bool = False): - """ - Returns the generated sequences for a batch of clip features. - - if compute_scores is True, also returns the scores of the generated sequences. - - returns a list of strings if compute_scores is False, otherwise a tuple of a list of strings and a list of floats. - """ - - model.eval() - embedding_cat = model.clip_project(clip_features).view(clip_features.shape[0], 1, -1) - entry_length = 30 - temperature = 1 - tokens = None - sequence_log_probs = None - - for i in range(entry_length): - outputs = model.decoder(inputs_embeds=embedding_cat) - - logits = outputs.logits[:, -1, :] - logits = logits / (temperature if temperature > 0 else 1.0) - - probs = torch.nn.functional.softmax(logits, -1) - - if compute_scores: - log_probs = torch.log(probs) # Convert to log-probabilities - - next_token = torch.argmax(probs, -1).unsqueeze(1) - next_token_embed = model.decoder.transformer.wte(next_token) - - if tokens is None: - tokens = next_token - if compute_scores: - sequence_log_probs = log_probs.gather(1, next_token) # Store log-prob of first token - else: - tokens = torch.cat((tokens, next_token), dim=1) - if compute_scores: - token_log_probs = log_probs.gather(1, next_token) # Get log-prob of chosen token - sequence_log_probs = torch.cat((sequence_log_probs, token_log_probs), dim=1) # Append - - # Append new token embedding to input - embedding_cat = torch.cat((embedding_cat, next_token_embed), dim=1) - - if compute_scores: - # Compute total sequence scores - sequence_scores = sequence_log_probs.sum(dim=-1) # Sum log-probs over sequence - final_scores = torch.exp(sequence_scores) # Convert log-sum-prob to probability-like score - - try: - outputs = [] - for tokens_elem in tokens: - output_list = list(tokens_elem.squeeze().cpu().numpy()) - if decoding_method is not None: - output = decoding_method(output_list) - else: - output = _Tokenizer.decode(output_list) - - - - output = output.split('<|endoftext|>')[0] - if not return_start_end_tokens: - output = output.replace('<|startoftext|>', '') - else: - output += '<|endoftext|>' - - outputs.append(output) - except: - outputs = None - - return (outputs, final_scores.cpu().numpy().tolist()) if compute_scores else outputs - - -decap_model = None - -def get_decap_model(device, weights_path = DECAP_COCO_WEIGHTS_PATH, prefix_size=512): - #global decap_model - #if decap_model is not None: - # return decap_model - decap_model = DeCap(prefix_size) - decap_model.load_state_dict(torch.load(weights_path,map_location= torch.device('cpu')), strict=False) - decap_model = decap_model.to(device) - decap_model = decap_model.eval() - return decap_model diff --git a/src/decap/decoderTraining.py b/src/decap/decoderTraining.py deleted file mode 100755 index 85eed24b1fd3355f8fbd261530d8d831d3fba0cf..0000000000000000000000000000000000000000 --- a/src/decap/decoderTraining.py +++ /dev/null @@ -1,464 +0,0 @@ -import torch -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from torch.nn.parallel import DistributedDataParallel as DDP -import torch.distributed as dist - -from im2txtprojection.im2txtprojection import Im2TxtProjector, ProjectionType -from transformers import get_linear_schedule_with_warmup -from torch.optim import AdamW -from tqdm import tqdm -from decap import get_decap_model -import os -import sys -import argparse -import json -from typing import Union -import sys -import clip -import json - -import csv - - -from src.dataset import ClipCocoDataset -from src.datasetMix import ClipCocoDatasetMix -from src.model import DeCap, ProjectionLayer - -DECAP_DECODER_CONFIG_PATH = os.path.join("./decoder_config.pkl") - -def save_config(args: argparse.Namespace): - config = {} - for key, item in args._get_kwargs(): - config[key] = item - out_path = os.path.join(args.out_dir, f"{args.prefix}.json") - with open(out_path, 'w') as outfile: - json.dump(config, outfile) - - -def load_model(config_path: str, epoch_or_latest: Union[str, int] = '_latest'): - with open(config_path) as f: - config = json.load(f) - parser = argparse.ArgumentParser() - parser.set_defaults(**config) - args = parser.parse_args() - if type(epoch_or_latest) is int: - epoch_or_latest = f"-{epoch_or_latest:03d}" - model_path = os.path.join(args.out_dir, f"{args.prefix}{epoch_or_latest}.pt") - if args.only_prefix: - model = ClipCaptionPrefix(args.prefix_length) - else: - model = ClipCaptionModel(args.prefix_length) - if os.path.isfile(model_path): - print(f"loading model from {model_path}") - model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) - else: - print(f"{model_path} is not exist") - return model, parser - - - - -def train_decoder(args, - lr: float = 1e-5, warmup_steps: int = 1000, output_dir: str = ".", output_prefix: str = ""): - - # device = torch.device('cuda:1') - batch_size = args.bs - epochs = args.epochs - if not os.path.exists(output_dir): - os.makedirs(output_dir) - args.is_master = ( args.local_rank == 0 or args.not_distributed != False) - - # set the device - #torch.cuda.set_device(args.local_rank) - #device = torch.device('cuda:'+str(args.local_rank)) - if args.not_distributed == False: - torch.cuda.set_device(args.local_rank) - device = torch.device('cuda:'+str(args.local_rank)) - dist.init_process_group(backend='nccl', init_method='env://') - else: - device = torch.device('cuda:'+str(args.local_rank)) - print(f"NOT DISTRIBUTED") - print(f"Using device {device}") - SEED=42 - torch.cuda.manual_seed_all(SEED) - - if args.use_regionclip: - # RegionCLIP typically uses 1024 dimensions for ResNet-50 or 512 for ViT - # We'll determine this from the loaded model - prefix_size = 1024 # Default for RegionCLIP ResNet-50, but will be adjusted if needed - elif args.denseclip_config is not None: - # DenseClip typically uses 512 dimensions (similar to CLIP ViT-B) - from src.denseclip.loader import load_denseclip_config - denseclip_config_dict = load_denseclip_config(args.denseclip_config) - prefix_size = denseclip_config_dict.get('model', {}).get('text', {}).get('embed_dim', None) - if prefix_size is None: - print(f"Warning: Could not determine prefix_size from DenseClip config {args.denseclip_config}. Defaulting to 512.") - prefix_size = 512 # Fallback to a common size) - - elif 'H' in args.clip_model or args.use_dinotxt: - prefix_size = 1024 - elif args.talk2dino_weights is not None or args.use_dino_feats: - prefix_size = 768 - else: - prefix_size = 512 - - if args.im_proj: - memory_bank_path = os.path.abspath(args.dataset) - print(f"Using Im2TxtProjector with {memory_bank_path = }") - im_proj = Im2TxtProjector( - type=memory_bank_path, - use_talk2dino=True, - linear_talk2dino=False, - memory_bank_name='coco_karpathy', - device_str=device) - - if args.use_regionclip: - from src.regionclip.loader import load_regionclip_from_checkpoint - from src.regionclip.datasets.clip_prompt_utils import tokenize as regionclip_tokenize - - print("Using RegionCLIP for text encoding.") - if args.regionclip_checkpoint is None: - raise ValueError("RegionCLIP checkpoint path must be provided when using --use-regionclip") - - clip_model = load_regionclip_from_checkpoint( - args.regionclip_checkpoint, - device=device, - config=args.regionclip_config - ) - tokenizer = regionclip_tokenize - preprocess = None # RegionCLIP doesn't need preprocessing for text-only training - - # Determine the actual embedding dimension from the loaded model - if hasattr(clip_model, 'text_projection'): - actual_prefix_size = clip_model.text_projection.shape[1] - print(f"RegionCLIP text embedding dimension: {actual_prefix_size}") - if actual_prefix_size != prefix_size: - print(f"Updating prefix_size from {prefix_size} to {actual_prefix_size}") - prefix_size = actual_prefix_size - - # Test RegionCLIP text encoding to ensure it works - try: - test_text = ["A test sentence"] - test_tokens = tokenizer(test_text) - test_features = clip_model.encode_text(test_tokens.to(device)) - print(f"RegionCLIP test encoding successful. Output shape: {test_features.shape}") - except Exception as e: - print(f"Warning: RegionCLIP test encoding failed: {e}") - print("This might cause issues during training.") - - elif args.denseclip_config is not None: - from src.denseclip.loader import load_denseclip - - print(f"Using DenseClip for text encoding with config: {args.denseclip_config}") - - try: - clip_model = load_denseclip( - config_name=args.denseclip_config, - device=device - ) - - # Try to use DenseClip's tokenizer first - try: - from src.denseclip.loader import DenseCLIP_tokenize - tokenizer = DenseCLIP_tokenize - print("Using DenseClip tokenizer") - except ImportError: - # Fallback to CLIP tokenizer if DenseClip tokenizer is not available - import clip - tokenizer = clip.tokenize - print("Warning: DenseClip tokenizer not available, using CLIP tokenizer") - - preprocess = None # DenseClip doesn't need preprocessing for text-only training - - # Determine the actual embedding dimension from the loaded model - if hasattr(clip_model, 'text_encoder') and hasattr(clip_model.text_encoder, 'embed_dim'): - actual_prefix_size = clip_model.text_encoder.embed_dim - print(f"DenseClip text embedding dimension: {actual_prefix_size}") - if actual_prefix_size != prefix_size: - print(f"Updating prefix_size from {prefix_size} to {actual_prefix_size}") - prefix_size = actual_prefix_size - - # Test DenseClip text encoding to ensure it works - test_text = ["A test sentence"] - test_tokens = tokenizer(test_text) - if hasattr(test_tokens, 'to'): - test_tokens = test_tokens.to(device) - test_features = clip_model.encode_text(test_tokens) - print(f"DenseClip test encoding successful. Output shape: {test_features.shape}") - - except Exception as e: - print(f"Error loading DenseClip model: {e}") - raise e - - elif args.use_open_clip: - from open_clip import create_model_and_transforms, tokenize - print("Using open_clip for model loading.") - clip_model, preprocess_train, preprocess_val = create_model_and_transforms(model_name=args.clip_model, pretrained="laion2b_s32b_b79k", device=device) - preprocess = preprocess_train - tokenizer = tokenize - - elif args.use_dinotxt: - from src.dinotxt_utils import get_tokenizer - clip_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l') - tokenizer = get_tokenizer().tokenize - else: - clip_model, preprocess = clip.load(args.clip_model, device=device, jit=False) - tokenizer = clip.tokenize - clip_model.eval() - clip_model.to(device) - - # Create model after determining the correct prefix_size - if args.decap_weights is None: - model = DeCap(prefix_size) - else: - model = get_decap_model(device, args.decap_weights, prefix_size) - - if args.talk2dino_weights is not None: - # loading Talk2DINO - print(f"Loading Talk2DINO weights from {args.talk2dino_weights}") - talk2dino = ProjectionLayer.from_config(args.talk2dino_config) - talk2dino.load_state_dict(torch.load(args.talk2dino_weights, device)) - talk2dino.to(device) - talk2dino.eval() - - else: - talk2dino = None - - - loss_ce = torch.nn.CrossEntropyLoss(ignore_index=0,label_smoothing=0.1) - model.to(device) - - if args.not_distributed == False: - model = DDP( - model, - device_ids=[args.local_rank], - output_device=args.local_rank, - find_unused_parameters=True - ) - - if not args.pre_extract_features: - print("Features pre-extraction de-activated") - if args.mix_captions: - print("Using mix captions") - dataset = ClipCocoDatasetMix(args.dataset, use_precomputed_feats=args.use_dino_feats, tokenizer=tokenizer) - else: - dataset = ClipCocoDataset(args.dataset, use_dino_feats=args.use_dino_feats, tokenizer=tokenizer) - else: - if args.mix_captions: - print("Using mix captions") - dataset = ClipCocoDatasetMix(args.dataset, clip_model=clip_model, talk2dino=talk2dino, tokenizer=tokenizer) - else: - dataset = ClipCocoDataset(args.dataset, clip_model=clip_model, talk2dino=talk2dino, tokenizer=tokenizer) - - - optimizer = AdamW(model.parameters(),lr=lr) - - print(f"Going to construct DataLoader with {len(dataset)} samples") - if args.not_distributed == False: - sampler = DistributedSampler(dataset) - train_dataloader = DataLoader(dataset, sampler=sampler, batch_size=batch_size, drop_last=True) - else: - train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True) - - print("DataLoader constructed") - scheduler = get_linear_schedule_with_warmup( - optimizer, num_warmup_steps=warmup_steps, num_training_steps=epochs * len(train_dataloader) - ) - - - for epoch in range(epochs): - - epoch_loss = 0.0 - epoch_acc = 0.0 - num_batches = 0 - - loss_token_save,ac_save= 0,0 - sys.stdout.flush() - if args.is_master: - print(f">>> Training epoch {epoch}") - progress = tqdm(total=int(len(train_dataloader)/10), desc=output_prefix, dynamic_ncols=True) - - if args.not_distributed == False: - dist.barrier() - - for idx,(clip_tokens, pipeline_input) in enumerate(train_dataloader): - - - clip_tokens, pipeline_input = clip_tokens.to(device), pipeline_input.to(device) - - with torch.no_grad(): - if not args.pre_extract_features and not args.use_dino_feats: - if args.use_regionclip: - # RegionCLIP text encoding - feature_text = clip_model.encode_text(pipeline_input) - elif args.denseclip_config is not None: - # DenseClip text encoding - feature_text = clip_model.encode_text(pipeline_input) - else: - # Standard CLIP or OpenCLIP text encoding - feature_text = clip_model.encode_text(pipeline_input) - - if args.use_dinotxt: - feature_text = feature_text[:, 1024:] # patch-aligned text embedding - - if args.talk2dino_weights is not None: - feature_text = talk2dino.project_clip_txt(feature_text) - else: - feature_text = pipeline_input - if args.im_proj: - feature_text = im_proj.project(feature_text, normalize=True) - - feature_text /= feature_text.norm(dim=-1, keepdim=True) - - if args.gaussian_noise != 0: - feature_text += args.gaussian_noise * torch.randn(feature_text.shape).to(device) - feature_text /= feature_text.norm(dim=-1, keepdim=True) - - - outputs = model(feature_text.float(),clip_tokens) - logits = outputs - - logits = logits.logits - - logits = logits[:,: -1] - clip_tokens = clip_tokens.flatten() - logits = logits.reshape(-1, logits.shape[-1]) - - loss_token = loss_ce(logits, clip_tokens) - ac=((logits.argmax(1)==clip_tokens)*(clip_tokens>0)).sum()/(clip_tokens>0).sum() - optimizer.zero_grad() - loss_all = loss_token - loss_all.backward() - optimizer.step() - scheduler.step() - - epoch_loss += loss_token.item() - epoch_acc += ac.item() - num_batches += 1 - - if args.is_master: - - if(idx+1) %10 == 0: - progress.set_postfix({"loss_token": loss_token_save/10.0,"acc_token":ac_save/10.0}) - progress.update() - loss_token_save,ac_save= 0,0 - else: - loss_token_save += loss_token.item() - ac_save += ac.item() - - if args.is_master: - log_dir = os.path.join('./log', f"{args.dataset}.txt")#'./log/'+args.dataset+'.txt' - with open(log_dir,'w') as f: - f.writelines('epoch ' +str(epoch) +': '+ progress.postfix+'\r\n') - progress.close() - if epoch % args.save_every == 0 or epoch == epochs - 1: - torch.save( - model.state_dict(), - os.path.join(output_dir, f"{output_prefix}-{epoch:03d}.pt"), - ) - - # after the epoch, we need to synchronize the loss and accuracy across all processes - loss_tensor = torch.tensor(epoch_loss, device=device) - acc_tensor = torch.tensor(epoch_acc, device=device) - count_tensor = torch.tensor(num_batches, device=device) - - if args.not_distributed == False: - # sum on all processes - torch.distributed.all_reduce(loss_tensor, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce(acc_tensor, op=torch.distributed.ReduceOp.SUM) - torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM) - - # compute global mean - avg_loss = loss_tensor.item() / count_tensor.item() - avg_acc = acc_tensor.item() / count_tensor.item() - - if args.is_master: - epoch_loss_current = {'epoch': epoch, 'loss': avg_loss, 'accuracy': avg_acc} - #epoch_losses.append(epoch_loss_current) - print(f"Epoch {epoch} loss: {avg_loss}, accuracy: {avg_acc}") - - loss_csv_path = os.path.join(output_dir, f"{output_prefix}_epoch_losses.csv") - with open(loss_csv_path, 'a', newline='') as csvfile: - writer = csv.DictWriter(csvfile, fieldnames=['epoch', 'loss', 'accuracy']) - # Write the header only if the file is empty - if os.stat(loss_csv_path).st_size == 0: - writer.writeheader() - writer.writerow(epoch_loss_current) - return model - -# DeCap CLIP B16 karpathy train split: -#python decapTraining.py --out_dir weights_clip_b16_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy -# DECAP with proj -> ma in realtร  non serve. -#python decapTraining.py --out_dir weights_clip_b16_proj_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --im_proj - -# Patchioner DINOv2 karpathy train split with proj: -#python decapTraining.py --out_dir weights_dino_b14_proj_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml --pre_extract_features --im_proj -# Patchioner DINOv2 karpathy train split -#python decapTraining.py --out_dir weights_dino_b14_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml -#python decapTraining.py --out_dir weights_dino_b14_karpathy --not-distributed 1 --local-rank 1 --dataset coco_train_karpathy.json --prefix coco_karpathy --talk2dino_weights weights_talk2dino/vitb_mlp_infonce.pth --talk2dino_config configs_talk2dino/vitb_mlp_infonce.yaml --use_dino_feats --pre_extract_features - -# DeCap CLIP B32 karpathy train split: -#python decapTraining.py --out_dir weights_clip_b32_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --clip_model ViT-B/32 - -# DeCap with RegionCLIP text encoder: -#python decoderTraining.py --out_dir weights_regionclip_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --use-regionclip - -# DeCap with DenseClip text encoder: -#python decoderTraining.py --out_dir weights_denseclip_segmentation_vitb16_karpathy --not-distributed 1 --local-rank 0 --dataset coco_train_karpathy.json --prefix coco_karpathy --denseclip-config denseclip_segmentation_vitb16 - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--decap_weights', type=str, default=None, help="If setted the Decap initialization is not random") - parser.add_argument('--clip_model', type=str, default='ViT-B/16', help="CLIP configuration") - parser.add_argument('--use_dinotxt', default=None, action='store_true', help="CLIP configuration") - parser.add_argument('--gaussian_noise', type=float, default=0, help="Standard deviation of the Gaussian noise to apply to the text input") - parser.add_argument('--out_dir', default='./coco_model') - parser.add_argument('--prefix', default='./coco_prefix', help='prefix for saved filenames') - parser.add_argument('--dataset', default='coco', help='coco or cc3m or bookcorpus') - parser.add_argument('--epochs', type=int, default=10) - parser.add_argument('--save_every', type=int, default=1) - parser.add_argument('--prefix_length', type=int, default=1) - parser.add_argument('--prefix_length_clip', type=int, default=1) - parser.add_argument('--bs', type=int, default=64) - parser.add_argument('--talk2dino_weights', type=str, default=None, help="Talk2DINO weights. If None, the training will be performed without Talk2DINO.") - parser.add_argument('--talk2dino_config', type=str, default=None, help="Talk2DINO configs. Valid only if the weights are setted.") - parser.add_argument('--use_dino_feats', action="store_true", default=False, help="If setted, we use the pre-extracted features of DINOv2") - parser.add_argument('--im_proj', action="store_true", default=False, help="If setted, we use the projection on the input features") - parser.add_argument('--pre_extract_features', action="store_true", default=False, help="If setted, the features will be extracted during the dataloading") - parser.add_argument('--only_prefix', dest='only_prefix', action='store_true') - parser.add_argument('--mapping_type', type=str, default='mlp', help='mlp/transformer') - parser.add_argument('--num_layers', type=int, default=8) - parser.add_argument('--is_rn', dest='is_rn', action='store_true') - parser.add_argument('--normalize_prefix', dest='normalize_prefix', action='store_true') - parser.add_argument('--local-rank', type=int, default=-1, metavar='N', help='Local process rank.') - parser.add_argument('--not-distributed', type=int, default=False, metavar='N', help='Not Distributed toggle.') - parser.add_argument('--use-open-clip', action='store_true', default=False, help='Use OpenCLIP instead of CLIP') - parser.add_argument('--mix-captions', action='store_true', default=False, help='Mix captions from the same image') - parser.add_argument('--use-regionclip', action='store_true', default=False, help='Use RegionCLIP for text encoding') - parser.add_argument('--regionclip-checkpoint', type=str, default='/raid/datasets/models_weights/regionclip/regionclip_pretrained-cc_rn50x4.pth', help='Path to RegionCLIP checkpoint file') - parser.add_argument('--regionclip-config', type=str, default='pretrain/RegionCLIP_RN50x4.yaml', help='Path to RegionCLIP config file or config name') - parser.add_argument('--denseclip-config', type=str, default=None, help='Path to DenseClip config file or config name') - args = parser.parse_args() - - # Validate RegionCLIP arguments - if args.use_regionclip and args.regionclip_checkpoint is None: - parser.error("--regionclip-checkpoint is required when using --use-regionclip") - - if args.use_regionclip and args.use_open_clip: - parser.error("Cannot use both --use-regionclip and --use-open-clip at the same time") - - # Validate DenseClip arguments - if args.denseclip_config is not None and args.use_regionclip: - parser.error("Cannot use both --denseclip-config and --use-regionclip at the same time") - - if args.denseclip_config is not None and args.use_open_clip: - parser.error("Cannot use both --denseclip-config and --use-open-clip at the same time") - - - train_decoder(args, output_dir=args.out_dir, output_prefix=args.prefix) - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/src/decap/decoder_config.pkl b/src/decap/decoder_config.pkl deleted file mode 100644 index 31293e585081fe4de2004432722e04100ab9fa9c..0000000000000000000000000000000000000000 --- a/src/decap/decoder_config.pkl +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c454e6bddb15af52c82734f1796391bf3a10a6c5533ea095de06f661ebb858bb -size 1744 diff --git a/src/decap/im2txtprojection/im2txtprojection.py b/src/decap/im2txtprojection/im2txtprojection.py deleted file mode 100644 index 22f8adc6725da6343b66a0149b278bb0ab1fb7ef..0000000000000000000000000000000000000000 --- a/src/decap/im2txtprojection/im2txtprojection.py +++ /dev/null @@ -1,500 +0,0 @@ -from enum import Enum -import numpy as np -import math -import json -import random -import torch -from tqdm import tqdm -import os -import h5py -from typing import Tuple -from dotenv import load_dotenv -from src.dinotxt_utils import get_tokenizer - -load_dotenv() - -class ProjectionType(Enum): - COCO_CAPTIONS = 'coco_captions' - MS_MARCO_QUERIES_A = 'ms_marco_queries_a' - CC3M_BLIP = 'cc3m_blip_captions' - VISUAL_GENOME = 'vg_captions' - VISUAL_GENOME_TEST = "vg_dense_captions_test" - ONLINE_TEXTS = "online_texts" - -class Im2TxtProjector: - """ - Im2TxtProjector creates and manages text embedding memory banks for different models: - - Standard CLIP models - - OpenCLIP models - - RegionCLIP models - - DenseClip models - - Talk2DINO projected embeddings - - For RegionCLIP usage, pass regionclip_config as a dict with: - { - 'checkpoint': '/path/to/regionclip_checkpoint.pth', - 'config_name': 'RegionCLIP_RN50.yaml' # optional - } - - For DenseClip usage, pass denseclip_config as a string with the config file name: - 'denseclip_vitb16' # or other valid DenseClip config name - """ - - SUPPORT_MEMORY_SIZE = 500000 - - __IM2TXT_MEMORY_PATH = os.getenv("IM2TXT_MEMORY_PATH") - - if __IM2TXT_MEMORY_PATH is None: - default_path = "weights/im2txtmemories" #os.path.join(os.path.dirname(__file__), "../../../im2txtmemories") - print(f"[!] Warning: IM2TXT_MEMORY_PATH not set in environment variables, using '{default_path}' [!]") - __IM2TXT_MEMORY_PATH = default_path - - __DECAP_FOLDER = os.path.join(os.path.dirname(__file__), "../") - __TALK2DINO_CONFIG_WEIGHTS_PATH = __DECAP_FOLDER - - captions_dataType = 'train2017' - ANNOTATIONS_CAPTION_FILE_PATH = os.path.join(__DECAP_FOLDER, 'captions_{}.json'.format(captions_dataType)) - VG_ANNOTATIONS_DENSE_CAPTIONS_FILE_PATH = '/raid/datasets/densecaptioning-annotations/data/vg/controlcap/vg1.2/train.json' - VG_ANNOTATIONS_DENSE_CAPTIONS_TEST_FILE_PATH = '/raid/datasets/densecaptioning-annotations/data/vg/controlcap/vg1.2/test.json' - - CC3M_BLIP_FILE_PATH = os.path.join(__DECAP_FOLDER, "blipv2_captions.txt") - MS_MARCO_QUERIES_FILE_PATH = '/raid/datasets/MSMarco/queries/queries.train.tsv' - - @staticmethod - def create_regionclip_config(checkpoint_path: str, config_name: str = None): - """ - Helper method to create RegionCLIP configuration dictionary. - - Args: - checkpoint_path (str): Path to RegionCLIP checkpoint file - config_name (str, optional): RegionCLIP config name (e.g., 'RegionCLIP_RN50.yaml') - - Returns: - dict: Configuration dictionary for RegionCLIP - """ - return { - 'checkpoint': checkpoint_path, - 'config_name': config_name - } - - def __init__(self, type = ProjectionType.COCO_CAPTIONS, verbose : bool = True, device_str = "cpu", use_talk2dino : bool = True, - support_memory_size : int = SUPPORT_MEMORY_SIZE, batch_size=1000, - clip_modelname = None, linear_talk2dino : bool = False, - normalize_memory_embs : bool = False, talk2dino_attn_type='qkv', online_texts=None, - memory_bank_name = None, use_open_clip = False, regionclip_config=None, invite_config=None, denseclip_config=None) -> None: - """ - - normalize_memory_embs -> normalizes the embeddings memory (required for projection in CLIP space) - - type : ProjectionType -> the type of the support memory to be built . Can either be the path to the file containing the captions or the type of the support memory to be built - - """ - # check if hdf5 already exists, otherwhise builds the support memory for that kind - - #if type not in ProjectionType.mro() - - self.type = type - self.device_str = device_str - self.device = torch.device(self.device_str) - self.use_talk2dino = use_talk2dino - self.linear_talk2dino = linear_talk2dino - self.talk2dino_attn_type = talk2dino_attn_type - self.online_texts = online_texts - self.use_open_clip = use_open_clip - self.regionclip_config = regionclip_config - self.invite_config = invite_config - self.denseclip_config = denseclip_config - - if use_open_clip: - assert use_talk2dino is False, "use_open_clip and use_talk2dino cannot be used together" - - if regionclip_config is not None: - assert use_talk2dino is False, "regionclip_config and use_talk2dino cannot be used together" - assert use_open_clip is False, "regionclip_config and use_open_clip cannot be used together" - - if invite_config is not None: - # overwrite clip_modelname with invite_config['name'] if provided - clip_modelname = invite_config.get('name', clip_modelname) - assert use_talk2dino is False, "invite_config and use_talk2dino cannot be used together" - - if denseclip_config is not None: - assert use_talk2dino is False, "denseclip_config and use_talk2dino cannot be used together" - assert use_open_clip is False, "denseclip_config and use_open_clip cannot be used together" - assert regionclip_config is None, "denseclip_config and regionclip_config cannot be used together" - - - if clip_modelname is None: - if self.use_talk2dino: - clip_modelname = "ViT-B/16" - elif regionclip_config is not None: - # For RegionCLIP, we'll use a generic identifier since the model type is in the config - clip_modelname = "RegionCLIP" - elif denseclip_config is not None: - # For DenseClip, we'll use a generic identifier since the model type is in the config - clip_modelname = "DenseClip" - else: - clip_modelname = "ViT-B/32" - self.clip_modelname = clip_modelname - - self.SUPPORT_MEMORY_SIZE = support_memory_size - if use_talk2dino: - prefix = "" - postfix = '-B16' if use_talk2dino is True else use_talk2dino - if linear_talk2dino: - postfix += "-linear" - elif regionclip_config is not None: - prefix = "regionclip-" - postfix = "" - elif denseclip_config is not None: - prefix = "denseclip-" - postfix = "" - else: - prefix = "clip-" - postfix = "" - if talk2dino_attn_type != 'qkv': - self.talk2dino_attn_type_str = f"_{talk2dino_attn_type}" - else: - self.talk2dino_attn_type_str = '' - if isinstance(type, ProjectionType): - dataset_name = type.value - elif memory_bank_name is not None: - dataset_name = memory_bank_name - else: - dataset_name = 'coco' - - if use_open_clip: - postfix += "-open_clip" - elif regionclip_config is not None: - postfix += "-regionclip" - # Add checkpoint identifier to make filename unique - checkpoint_path = regionclip_config.get('checkpoint', '') - checkpoint_name = os.path.basename(checkpoint_path).replace('.pth', '').replace('.pt', '') - if checkpoint_name: - postfix += f"-{checkpoint_name}" - elif denseclip_config is not None: - postfix += "-denseclip" - # Add config identifier to make filename unique - config_name = os.path.basename(denseclip_config).replace('.yaml', '').replace('.yml', '') - if config_name: - postfix += f"-{config_name}" - - self.H5PY_FILE_PATH = os.path.join( self.__IM2TXT_MEMORY_PATH, prefix + f'{dataset_name}{self.talk2dino_attn_type_str}_text_embeddings{postfix}-{clip_modelname.replace("/", ".")}-{self.SUPPORT_MEMORY_SIZE}.h5' ) - self.H5PY_EMBEDDINGS_DATASET_NAME = '{}-embeddings'.format(dataset_name) - self.H5PY_TEXT_DATASET_NAME = '{}-text'.format(dataset_name) - - embs_dataset, text_dataset = self._load_support_memory() - - if text_dataset is None: - if verbose: - model_type = "RegionCLIP" if regionclip_config is not None else ("DenseClip" if denseclip_config is not None else ("OpenCLIP" if use_open_clip else "CLIP")) - print(f"[+] Going to build support memory for the given data type: {type} using {model_type} [+]") - embs_dataset, text_dataset = self._build_support_memory(batch_size) - if verbose: print(f"[+] Done [+]") - - if self.type != ProjectionType.ONLINE_TEXTS: - embs_dataset, text_dataset = self._load_support_memory() - - print(f"[-] loaded memory from {os.path.abspath( self.H5PY_FILE_PATH )} [-]") - if regionclip_config is not None: - print(f"[-] Using RegionCLIP text embeddings from checkpoint: {regionclip_config.get('checkpoint', 'Unknown')} [-]") - elif denseclip_config is not None: - print(f"[-] Using DenseClip text embeddings from config: {denseclip_config} [-]") - - self.text_dataset = text_dataset - self.embs_dataset = torch.tensor(embs_dataset[:]).to(self.device) - self.embs_dataset = self.embs_dataset[self.embs_dataset.norm(dim=-1) != 0] - - - if normalize_memory_embs: - self.embs_dataset /= self.embs_dataset.norm(dim=-1,keepdim=True).float() - - - - def project(self, image_embedding, temperature : float = 0.01, normalize : bool = False, return_argmax_text : bool = False, return_n_best_sims=None) -> torch.TensorType: - if not isinstance(image_embedding, torch.Tensor): - print(f"the type of image_embedding is '{type(image_embedding)}' converting it to torch tensor") - image_embedding = torch.tensor(image_embedding, dtype=torch.float).to(self.device) - - orig_device = image_embedding.device - - if image_embedding.device != self.device: - image_embedding = image_embedding.to(self.device) - - if image_embedding.dtype != float: - #print(f"[-] image_embedding.dtype is {image_embedding.dtype}, converting it to float [-]") - image_embedding = image_embedding.float() - - embs_dataset = self.embs_dataset / self.embs_dataset.norm(dim=-1, keepdim=True) - image_embedding /= image_embedding.norm(dim=-1,keepdim=True) - - sim = image_embedding@embs_dataset.T.float() - if return_argmax_text: - argmax_texts = [self.text_dataset[idx].decode() for idx in sim.argmax(dim=-1)] - if return_n_best_sims: - return argmax_texts, sim.sort(dim=-1, descending=True).values[:, :return_n_best_sims].tolist() - return argmax_texts - softmax_sim = (sim / temperature).softmax(dim=-1) - prefix_embedding = softmax_sim@self.embs_dataset.float() - - if normalize: - prefix_embedding /= prefix_embedding.norm(dim=-1,keepdim=True) - - if return_n_best_sims: - return prefix_embedding.to(orig_device), sim.sort(dim=-1, descending=True).values[:, :return_n_best_sims].tolist() - - return prefix_embedding.to(orig_device) - - def _load_support_memory(self) -> Tuple[np.ndarray, np.ndarray]: - if self.type == ProjectionType.ONLINE_TEXTS: - print(f"[-] _load_support_memory: support memory for provided texts will be constructed [-]") - return None, None - if not os.path.exists(self.H5PY_FILE_PATH): - print(f"[-] _load_support_memory: the path '{self.H5PY_FILE_PATH}' does not exist [-]") - return None, None - - with h5py.File(self.H5PY_FILE_PATH, 'r') as hf: - - if self.H5PY_EMBEDDINGS_DATASET_NAME in hf: - embeddings_dataset = hf[self.H5PY_EMBEDDINGS_DATASET_NAME][:] - text_dataset = hf[self.H5PY_TEXT_DATASET_NAME][:] - else: - embeddings_dataset = None - text_dataset = None - if 'DINO.txt' in self.clip_modelname: - embeddings_dataset = embeddings_dataset[:, 1024:] # Get patch-aligned text embeddings - return embeddings_dataset, text_dataset - - - - def _build_support_memory(self, batch_size = 1000) -> Tuple[np.ndarray, np.ndarray]: - ## construct the support memory - - self._load_models() - - if self.type == ProjectionType.COCO_CAPTIONS: - from pycocotools.coco import COCO - coco_obj = COCO(Im2TxtProjector.ANNOTATIONS_CAPTION_FILE_PATH) - data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE) - data = [ d['caption'] for d in data ] - elif self.type == ProjectionType.VISUAL_GENOME: - from pycocotools.coco import COCO - coco_obj = COCO(Im2TxtProjector.VG_ANNOTATIONS_DENSE_CAPTIONS_FILE_PATH) - # data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE) - data = list(coco_obj.anns.values())[:self.SUPPORT_MEMORY_SIZE] - data = [ d['caption'] for d in data ] - elif self.type == ProjectionType.VISUAL_GENOME_TEST: - from pycocotools.coco import COCO - coco_obj = COCO(Im2TxtProjector.VG_ANNOTATIONS_DENSE_CAPTIONS_TEST_FILE_PATH) - # data = random.sample(list(coco_obj.anns.values()), k=self.SUPPORT_MEMORY_SIZE) - data = list(coco_obj.anns.values())[:self.SUPPORT_MEMORY_SIZE] - data = [ d['caption'] for d in data ] - elif self.type == ProjectionType.MS_MARCO_QUERIES_A: - print(f"Loading MSMarco queries from file ", Im2TxtProjector.MS_MARCO_QUERIES_FILE_PATH) - with open(Im2TxtProjector.MS_MARCO_QUERIES_FILE_PATH, "r") as input_file: - lines = input_file.readlines() - data = random.sample(lines, k=self.SUPPORT_MEMORY_SIZE) - data = [ d.split("\t")[1].replace("\n", "") for d in data ] - print(f"Loaded from file '{self.SUPPORT_MEMORY_SIZE}' lines, example of line: '{data[0]}'") - elif self.type == ProjectionType.CC3M_BLIP: - print(f"Loading cc3m captions txt file ", Im2TxtProjector.CC3M_BLIP_FILE_PATH) - with open(Im2TxtProjector.CC3M_BLIP_FILE_PATH, "r") as input_file: - lines = input_file.readlines() - data = random.sample(lines, k=self.SUPPORT_MEMORY_SIZE) - data = [ d.replace("\n", "") for d in data ] - print(f"Loaded from file '{len(data)}' lines, example of line: '{data[0]}'") - elif self.type == ProjectionType.CC3M_BLIP: - print(f"Loading cc3m captions txt file ", Im2TxtProjector.CC3M_BLIP_FILE_PATH) - with open(Im2TxtProjector.CC3M_BLIP_FILE_PATH, "r") as input_file: - lines = input_file.readlines() - data = random.sample(lines, k=self.SUPPORT_MEMORY_SIZE) - data = [ d.replace("\n", "") for d in data ] - print(f"Loaded from file '{len(data)}' lines, example of line: '{data[0]}'") - elif self.type == ProjectionType.ONLINE_TEXTS: - data = self.online_texts - print(f"Loaded online_texts '{len(data)}' lines, example of line: '{data[0]}'") - elif type(self.type) == str: - if os.path.exists(self.type): - path = self.type - from pycocotools.coco import COCO - coco_obj = COCO(path) - data = random.sample(list(coco_obj.anns.values()), k=min(self.SUPPORT_MEMORY_SIZE, len(coco_obj.anns))) - data = [ d['caption'] for d in data ] - else: - #data = random.sample(data,500000) - print(f"[!] Unimplemented data type '{self.type}'[!]") - return None, None - - text_features = [] - captions = [] - - self.clip_model.eval() - - n_txts = len(data) - n_batch = math.ceil(n_txts / batch_size) - for i in tqdm(range(n_batch)): - start = i * batch_size - end = start + batch_size if i < n_batch - 1 else n_txts - - texts = data[start:end] - with torch.no_grad(): - texts_token = self.tokenizer(texts).to(self.device) - text_feature = self.clip_model.encode_text(texts_token) - if self.use_talk2dino: - text_feature = self.talk2dino.project_clip_txt(text_feature) - text_features.append(text_feature) - captions.extend(texts) - - text_features = torch.cat(text_features,dim=0) - - - #text_features /= text_features.norm(dim=-1,keepdim=True).float() - - # store captions and text features in hdf5 dataset - - text_features_ndarray = text_features.cpu().numpy() - - assert len(text_features_ndarray) == len(captions), f"len(text_features_ndarray) = {len(text_features_ndarray)} != len(captions) = {len(captions)}" - - #if not os.path.exists(self.H5PY_FILE_PATH): - # print(f"os.path '{self.H5PY_FILE_PATH}' does not exists") - - EMBEDDINGS_DIMENSION = text_features_ndarray.shape[1] - - if self.type != ProjectionType.ONLINE_TEXTS: - with h5py.File(self.H5PY_FILE_PATH, 'w') as hf: - - if self.H5PY_EMBEDDINGS_DATASET_NAME in hf: - embeddings_dataset = hf[self.H5PY_EMBEDDINGS_DATASET_NAME] - text_dataset = hf[self.H5PY_TEXT_DATASET_NAME] - print(f"[!] Dataset '{self.H5PY_EMBEDDINGS_DATASET_NAME}' already exists! Going to overwrite [!]") - else: - embeddings_dataset = hf.create_dataset(self.H5PY_EMBEDDINGS_DATASET_NAME, shape=(self.SUPPORT_MEMORY_SIZE, EMBEDDINGS_DIMENSION), dtype='float32') - text_dataset = hf.create_dataset(self.H5PY_TEXT_DATASET_NAME, shape=(self.SUPPORT_MEMORY_SIZE, ), dtype=h5py.string_dtype(encoding='utf-8')) #, dtype='str' - - for num_row in range(len(text_features_ndarray)): - embeddings_dataset[num_row] = text_features_ndarray[num_row] - text_dataset[num_row] = captions[num_row] - else: - embeddings_dataset = text_features_ndarray - text_dataset = [x.encode() for x in captions] - - return embeddings_dataset, text_dataset - - clip_model = None - def _load_models(self): - - if self.clip_model is not None: - # case already done - return - - if self.use_open_clip: - print("[-] loading open_clip model [-]") - assert self.clip_modelname is not None, "clip_modelname must be provided when using open_clip" - from open_clip import create_model_and_transforms, tokenize - self.clip_model, preprocess_train, preprocess_val = create_model_and_transforms(self.clip_modelname, pretrained="laion2b_s32b_b79k", device=self.device) - self.preprocess = preprocess_train - self.tokenizer = tokenize - return - - if self.regionclip_config is not None: - print("[-] loading RegionCLIP model [-]") - from src.regionclip.loader import load_regionclip_from_checkpoint - from src.regionclip.datasets.clip_prompt_utils import tokenize as regionclip_tokenize - - regionclip_checkpoint = self.regionclip_config.get('checkpoint', None) - if regionclip_checkpoint is None: - raise ValueError("RegionCLIP checkpoint not specified in the configuration") - regionclip_config_name = self.regionclip_config.get('config_name', None) - - print(f"[-] Loading RegionCLIP from checkpoint: {regionclip_checkpoint} [-]") - if regionclip_config_name: - print(f"[-] Using RegionCLIP config: {regionclip_config_name} [-]") - - self.clip_model = load_regionclip_from_checkpoint( - regionclip_checkpoint, - device=self.device, - config=regionclip_config_name - ) - self.tokenizer = regionclip_tokenize - self.preprocess = None # RegionCLIP doesn't need preprocessing for text encoding - - # Test RegionCLIP text encoding to ensure it works - try: - test_text = ["A test sentence for RegionCLIP"] - test_tokens = self.tokenizer(test_text) - test_features = self.clip_model.encode_text(test_tokens.to(self.device)) - print(f"[-] RegionCLIP text encoding test successful. Output shape: {test_features.shape} [-]") - except Exception as e: - print(f"[!] Warning: RegionCLIP text encoding test failed: {e} [!]") - raise e - - return - - if self.denseclip_config is not None: - print("[-] loading DenseClip model [-]") - from src.denseclip.loader import load_denseclip, DenseCLIP_tokenize - - print(f"[-] Loading DenseClip from config: {self.denseclip_config} [-]") - - # Load DenseClip model - self.clip_model = load_denseclip( - config_name=self.denseclip_config, - device=self.device - ) - - # DenseClip should have encode_text method and a tokenizer - # We need to check if DenseClip has a tokenizer method - if DenseCLIP_tokenize is not None: - self.tokenizer = DenseCLIP_tokenize - else: - # Fallback to CLIP tokenizer if DenseClip doesn't provide one - import clip - self.tokenizer = clip.tokenize - print("[!] Warning: DenseClip model doesn't have tokenizer, using CLIP tokenizer [!]") - - self.preprocess = None # DenseClip doesn't need preprocessing for text encoding - - # Test DenseClip text encoding to ensure it works - try: - test_text = ["A test sentence for DenseClip"] - test_tokens = self.tokenizer(test_text) - if hasattr(test_tokens, 'to'): - test_tokens = test_tokens.to(self.device) - test_features = self.clip_model.encode_text(test_tokens) - print(f"[-] DenseClip text encoding test successful. Output shape: {test_features.shape} [-]") - except Exception as e: - print(f"[!] Warning: DenseClip text encoding test failed: {e} [!]") - raise e - - return - - import clip - if self.clip_modelname is None: - clip_model_name = "ViT-B/16" if self.use_talk2dino else "ViT-B/32" - else: - clip_model_name = self.clip_modelname - if 'DINO.txt' not in clip_model_name: - self.clip_model, self.preprocess = clip.load(clip_model_name, device=self.device, jit=False) - self.tokenizer = clip.tokenize - if self.use_talk2dino: - # loading Talk2DINO - if type(self.use_talk2dino) == str: - proj_name = self.use_talk2dino - elif self.linear_talk2dino is False: - proj_name = 'vitb_mlp_infonce' - else: - proj_name = 'vitb_linear_infonce' - - - config = os.path.join(self.__TALK2DINO_CONFIG_WEIGHTS_PATH, "configs_talk2dino", proj_name + '.yaml') - weights = os.path.join(self.__TALK2DINO_CONFIG_WEIGHTS_PATH, "weights_talk2dino", proj_name + self.talk2dino_attn_type_str + '.pth') - #import sys - #import os - #add_path = os.path.abspath( os.path.dirname("../")) - ##print(add_path) - #sys.path.insert(1, add_path ) - from src.model import ProjectionLayer - self.talk2dino = ProjectionLayer.from_config(config) - self.talk2dino.load_state_dict(torch.load((weights), self.device)) - self.talk2dino.to(self.device) - else: - self.clip_model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_reg4_dinotxt_tet1280d20h24l').to(self.device) - self.tokenizer = get_tokenizer().tokenize diff --git a/src/denseclip/clip_loader/README.md b/src/denseclip/clip_loader/README.md deleted file mode 100644 index 1264ac75f187a1224662c377b35ab40873a9aeb9..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/README.md +++ /dev/null @@ -1,233 +0,0 @@ -# DenseCLIP to CLIP Loader - -A simple interface for loading DenseCLIP checkpoints as CLIP-like models for text and image encoding. - -## Overview - -This module provides a clean API to load DenseCLIP models and use them like standard CLIP models for encoding text and images. It abstracts away the complexity of DenseCLIP's detection/segmentation components and exposes only the core vision-language encoding functionality. - -## Features - -- โœ… **Simple API**: Load DenseCLIP models with a single function call -- โœ… **CLIP-like Interface**: Familiar `encode_text()` and `encode_image()` methods -- โœ… **Flexible Configuration**: YAML-based configuration system -- โœ… **Multiple Input Types**: Support for PIL Images, image tensors, strings, and text lists -- โœ… **Automatic Preprocessing**: Built-in image preprocessing pipeline -- โœ… **Device Management**: Automatic GPU/CPU detection and placement - -## Quick Start - -```python -from clip_loader import load_clip - -# Load DenseCLIP model with default configuration -model = load_clip('denseclip_segmentation_vitb16') - -# Encode text -texts = ["a photo of a cat", "a photo of a dog"] -text_features = model.encode_text(texts) - -# Encode images (PIL Images) -from PIL import Image -images = [Image.open("cat.jpg"), Image.open("dog.jpg")] -image_features = model.encode_image(images) - -# Compute similarities -similarities = model.compute_similarity(image_features, text_features) -print(f"Image-text similarities: {similarities}") -``` - -## Configuration - -Models are configured using YAML files in the `configs/` directory. The main configuration for DenseCLIP ViT-B/16 is in `configs/denseclip_vitb16.yaml`. - -### Configuration Structure - -```yaml -model: - name: "denseclip_vitb16" - type: "vit" - - vision: - image_resolution: 224 - vision_layers: 12 - vision_width: 768 - vision_patch_size: 16 - embed_dim: 512 - - text: - context_length: 13 # DenseCLIP uses shorter context - vocab_size: 49408 - transformer_width: 512 - transformer_heads: 8 - transformer_layers: 12 - embed_dim: 512 - -checkpoint: - path: "/path/to/denseclip/checkpoint.pth" - format: "denseclip" - -preprocessing: - image_mean: [0.48145466, 0.4578275, 0.40821073] - image_std: [0.26862954, 0.26130258, 0.27577711] - normalize: true -``` - -## API Reference - -### Core Functions - -#### `load_clip(config_name, checkpoint_path=None, device='auto')` - -Load a DenseCLIP model with the specified configuration. - -**Parameters:** -- `config_name` (str): Name of config file (without .yaml extension) -- `checkpoint_path` (str, optional): Path to checkpoint file (overrides config) -- `device` (str): Device to load on ('auto', 'cpu', 'cuda') - -**Returns:** -- `DenseCLIPModel`: Loaded model ready for inference - -#### `load_denseclip_model(config_path, checkpoint_path=None, device='auto')` - -Load a DenseCLIP model from configuration file path. - -### DenseCLIPModel Methods - -#### `encode_text(texts)` - -Encode text into feature vectors. - -**Parameters:** -- `texts` (str or List[str]): Text string(s) to encode - -**Returns:** -- `torch.Tensor`: Normalized text features [batch_size, embed_dim] - -#### `encode_image(images)` - -Encode images into feature vectors. - -**Parameters:** -- `images`: PIL Image, List[PIL.Image], or preprocessed tensor - -**Returns:** -- `torch.Tensor`: Normalized image features [batch_size, embed_dim] - -#### `compute_similarity(image_features, text_features, temperature=1.0)` - -Compute similarity between image and text features. - -**Parameters:** -- `image_features` (torch.Tensor): Image features [N, embed_dim] -- `text_features` (torch.Tensor): Text features [M, embed_dim] -- `temperature` (float): Temperature scaling factor - -**Returns:** -- `torch.Tensor`: Similarity matrix [N, M] - -## Examples - -### Basic Text-Image Retrieval - -```python -from clip_loader import load_clip -from PIL import Image - -# Load model -model = load_clip('denseclip_vitb16') - -# Load and encode images -images = [ - Image.open("cat.jpg"), - Image.open("dog.jpg"), - Image.open("car.jpg") -] -image_features = model.encode_image(images) - -# Encode text queries -queries = [ - "a cute cat", - "a happy dog", - "a red car" -] -text_features = model.encode_text(queries) - -# Find best matches -similarities = model.compute_similarity(image_features, text_features) -best_matches = similarities.argmax(dim=1) - -for i, query in enumerate(queries): - best_image_idx = best_matches[i] - score = similarities[best_image_idx, i].item() - print(f"Query '{query}' -> Image {best_image_idx} (score: {score:.3f})") -``` - -### Zero-Shot Classification - -```python -from clip_loader import load_clip -from PIL import Image - -model = load_clip('denseclip_vitb16') - -# Load test image -image = Image.open("test_image.jpg") -image_features = model.encode_image(image) - -# Define class labels -class_labels = [ - "a photo of a cat", - "a photo of a dog", - "a photo of a bird", - "a photo of a car", - "a photo of a house" -] - -# Encode labels -text_features = model.encode_text(class_labels) - -# Classify -similarities = model.compute_similarity(image_features, text_features) -probabilities = similarities.softmax(dim=-1) - -# Show results -for i, label in enumerate(class_labels): - prob = probabilities[0, i].item() - print(f"{label}: {prob:.3f}") -``` - -### Custom Configuration - -```python -from clip_loader import load_denseclip_model - -# Load with custom config -model = load_denseclip_model( - config_path='configs/custom_config.yaml', - checkpoint_path='/path/to/custom/checkpoint.pth', - device='cuda:1' -) -``` - -## Requirements - -- PyTorch >= 1.9.0 -- torchvision >= 0.10.0 -- Pillow >= 8.0.0 -- PyYAML >= 5.4.0 - -## Notes - -- DenseCLIP uses a shorter text context length (13) compared to standard CLIP (77) -- The model preserves ~98% similarity with original CLIP text representations -- Image preprocessing follows CLIP's standard normalization -- All features are L2-normalized for cosine similarity computation - -## Supported Models - -Currently supported: -- `denseclip_vitb16`: DenseCLIP with ViT-B/16 backbone - -To add support for other DenseCLIP variants, create new configuration files in the `configs/` directory. diff --git a/src/denseclip/clip_loader/SUMMARY.md b/src/denseclip/clip_loader/SUMMARY.md deleted file mode 100644 index 0a0b3d4a6feb65999bc817308adefc189bc33b65..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/SUMMARY.md +++ /dev/null @@ -1,78 +0,0 @@ -# DenseCLIP to CLIP Loader - Quick Start - -## โœ… Successfully Created! - -The `clip_loader` module provides a simple interface to load DenseCLIP checkpoints as CLIP-like models for text and image encoding. - -## ๐Ÿ“ Structure - -``` -/raid/homes/giacomo.pacini/DenseCLIP/clip_loader/ -โ”œโ”€โ”€ __init__.py # Module initialization -โ”œโ”€โ”€ denseclip_loader.py # Main loader implementation -โ”œโ”€โ”€ example_usage.py # Example script -โ”œโ”€โ”€ requirements.txt # Dependencies -โ”œโ”€โ”€ README.md # Full documentation -โ””โ”€โ”€ configs/ - โ””โ”€โ”€ denseclip_vitb16.yaml # Configuration for ViT-B/16 model -``` - -## ๐Ÿš€ Quick Usage - -```python -from clip_loader import load_clip - -# Load DenseCLIP model with default configuration -model = load_clip('denseclip_vitb16') - -# Encode text -texts = ["a photo of a cat", "a photo of a dog"] -text_features = model.encode_text(texts) # Shape: [2, 512] - -# Encode images (if you have PIL Images) -# image_features = model.encode_image(images) - -# Compute similarities -similarities = model.compute_similarity(text_features, text_features) -print(f"Cat-Dog similarity: {similarities[0, 1]:.3f}") -``` - -## โœ… Test Results - -- **โœ… Model loads successfully** from DenseCLIP checkpoint -- **โœ… Text encoding works** (shape: [batch_size, 512]) -- **โœ… Features are normalized** (L2 norm = 1.0) -- **โœ… Similarities make sense** (Cat-Dog: 0.872, Car-Person: lower) -- **โœ… Zero-shot classification** shows logical patterns -- **โœ… Model has 157M parameters** (94M vision + 63M text) - -## ๐Ÿ”ง Key Features - -- **Simple API**: Just call `load_clip()` and start encoding -- **Handles DenseCLIP specifics**: Automatically extracts weights from segmentation checkpoint -- **CLIP-compatible**: Same interface as OpenAI CLIP -- **Flexible configuration**: YAML-based configuration system -- **GPU ready**: Automatic device detection and placement -- **Context length**: Uses DenseCLIP's shorter context (13 vs 77) - -## ๐ŸŽฏ Use Cases - -1. **Text-Image Retrieval**: Encode both and compute similarities -2. **Zero-Shot Classification**: Encode class descriptions and compare -3. **Text Similarity**: Compare text representations -4. **Feature Extraction**: Get dense vector representations - -## ๐Ÿ“ Configuration - -The model uses `/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP.pth` by default. You can override this by modifying `configs/denseclip_vitb16.yaml` or passing a custom checkpoint path. - -## ๐Ÿ” What's Different from Standard CLIP - -- **Shorter context length**: 13 tokens vs 77 -- **Higher image resolution**: 640px vs 224px -- **Fine-tuned weights**: Adapted for dense prediction tasks -- **High text similarity**: ~98% similarity with original CLIP representations - -## ๐ŸŽ‰ Ready to Use! - -The loader is fully functional and ready for use in your projects. See `README.md` for detailed documentation and more examples. diff --git a/src/denseclip/clip_loader/__init__.py b/src/denseclip/clip_loader/__init__.py deleted file mode 100644 index 6432dd0289941f085bff7389a07c5306244f315e..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -""" -DenseCLIP to CLIP Loader - -A simple interface for loading DenseCLIP checkpoints as CLIP-like models -for text and image encoding. -""" - -from .denseclip_loader import ( - DenseCLIPModel, - load_denseclip_model, - load_clip, - load_config -) - -__version__ = "1.0.0" -__all__ = [ - "DenseCLIPModel", - "load_denseclip_model", - "load_clip", - "load_config" -] diff --git a/src/denseclip/clip_loader/bpe_simple_vocab_16e6.txt.gz b/src/denseclip/clip_loader/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16.yaml b/src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16.yaml deleted file mode 100644 index 658adbc7907d69217ba8f5a5521b364b6aad8e78..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# DenseCLIP ViT-B/16 Configuration -# Configuration for loading DenseCLIP checkpoint as a CLIP-like model - -model: - name: "denseclip_vitb16" - type: "vit" # vision transformer - - # Vision encoder configuration - vision: - image_resolution: 640 - vision_layers: 12 - vision_width: 768 - vision_patch_size: 16 - embed_dim: 512 - - # Text encoder configuration - text: - context_length: 13 # DenseCLIP uses shorter context - vocab_size: 49408 - transformer_width: 512 - transformer_heads: 8 - transformer_layers: 12 - embed_dim: 512 - -# Checkpoint information -checkpoint: - path: "/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP.pth" - format: "denseclip" # vs "openai_clip" - -# Processing configuration -preprocessing: - image_mean: [0.48145466, 0.4578275, 0.40821073] - image_std: [0.26862954, 0.26130258, 0.27577711] - normalize: true - -# Optional overrides -overrides: - # Set to true to use OpenAI CLIP tokenizer instead of DenseCLIP's - use_openai_tokenizer: false - # Set custom context length (will resize positional embeddings if needed) - custom_context_length: null diff --git a/src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16_long_ctx.yaml b/src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16_long_ctx.yaml deleted file mode 100644 index 6cf19fe317d8720d15cca2f62f5dc168252b29f5..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/configs/denseclip_segmentation_vitb16_long_ctx.yaml +++ /dev/null @@ -1,41 +0,0 @@ -# DenseCLIP ViT-B/16 Configuration -# Configuration for loading DenseCLIP checkpoint as a CLIP-like model - -model: - name: "denseclip_vitb16" - type: "vit" # vision transformer - - # Vision encoder configuration - vision: - image_resolution: 640 - vision_layers: 12 - vision_width: 768 - vision_patch_size: 16 - embed_dim: 512 - - # Text encoder configuration - text: - context_length: 77 # DenseCLIP uses shorter context - vocab_size: 49408 - transformer_width: 512 - transformer_heads: 8 - transformer_layers: 12 - embed_dim: 512 - -# Checkpoint information -checkpoint: - path: "/raid/datasets/models_weights/denseclip/segmentation/semanticFPN/ViT-B-DenseCLIP_long_ctx.pth" - format: "denseclip" # vs "openai_clip" - -# Processing configuration -preprocessing: - image_mean: [0.48145466, 0.4578275, 0.40821073] - image_std: [0.26862954, 0.26130258, 0.27577711] - normalize: true - -# Optional overrides -overrides: - # Set to true to use OpenAI CLIP tokenizer instead of DenseCLIP's - use_openai_tokenizer: false - # Set custom context length (will resize positional embeddings if needed) - custom_context_length: null diff --git a/src/denseclip/clip_loader/denseclip_loader.py b/src/denseclip/clip_loader/denseclip_loader.py deleted file mode 100644 index ec4194998ab5c38e15069313d8085cf404a7985c..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/denseclip_loader.py +++ /dev/null @@ -1,316 +0,0 @@ -#!/usr/bin/env python3 -""" -DenseCLIP to CLIP Loader - -A simple interface for loading DenseCLIP checkpoints as CLIP-like models -for text and image encoding. -""" - -import os -import sys -import yaml -import torch -import torch.nn as nn -import torch.nn.functional as F -from typing import Union, List, Tuple, Optional, Dict, Any -from PIL import Image -import torchvision.transforms as transforms - -# Import local model components -try: - from .models import CLIPVisionTransformer, CLIPTextEncoder, ResidualAttentionBlock, LayerNorm, QuickGELU - from .tokenizer import tokenize -except ImportError: - # Fallback for direct execution - from models import CLIPVisionTransformer, CLIPTextEncoder, ResidualAttentionBlock, LayerNorm, QuickGELU - from tokenizer import tokenize - - -class DenseCLIPModel(nn.Module): - """ - A CLIP-like model loaded from DenseCLIP checkpoints. - Provides simple text and image encoding functionality. - """ - - def __init__(self, config: Dict[str, Any]): - super().__init__() - - self.config = config - - # Initialize vision encoder - vision_config = config['model']['vision'] - self.visual = CLIPVisionTransformer( - input_resolution=vision_config['image_resolution'], - patch_size=vision_config['vision_patch_size'], - width=vision_config['vision_width'], - layers=vision_config['vision_layers'], - heads=vision_config['vision_width'] // 64, - output_dim=vision_config['embed_dim'] - ) - - # Initialize text encoder - text_config = config['model']['text'] - self.text_encoder = CLIPTextEncoder( - context_length=text_config['context_length'], - vocab_size=text_config['vocab_size'], - transformer_width=text_config['transformer_width'], - transformer_heads=text_config['transformer_heads'], - transformer_layers=text_config['transformer_layers'], - embed_dim=text_config['embed_dim'] - ) - - # Store configuration for preprocessing - self.context_length = text_config['context_length'] - self.image_resolution = vision_config['image_resolution'] - - # Initialize preprocessing - self._setup_preprocessing() - - def _setup_preprocessing(self): - """Setup image preprocessing pipeline""" - preprocess_config = self.config['preprocessing'] - - self.preprocess = transforms.Compose([ - transforms.Resize(self.image_resolution, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(self.image_resolution), - transforms.ToTensor(), - transforms.Normalize( - mean=preprocess_config['image_mean'], - std=preprocess_config['image_std'] - ) - ]) - - def encode_image(self, images: Union[torch.Tensor, List[Image.Image], Image.Image]) -> torch.Tensor: - """ - Encode images into feature vectors - - Args: - images: PIL Images, list of PIL Images, or preprocessed tensor - - Returns: - Normalized image features [batch_size, embed_dim] - """ - if isinstance(images, (list, tuple)): - # List of PIL Images - image_tensors = torch.stack([self.preprocess(img) for img in images]) - elif isinstance(images, Image.Image): - # Single PIL Image - image_tensors = self.preprocess(images).unsqueeze(0) - elif isinstance(images, torch.Tensor): - # Already preprocessed tensor - image_tensors = images - else: - raise ValueError(f"Unsupported image type: {type(images)}") - - # Move to same device as model - device = next(self.parameters()).device - image_tensors = image_tensors.to(device) - - # Encode - with torch.no_grad(): - image_features = self.visual(image_tensors) - image_features = F.normalize(image_features, dim=-1) - - return image_features - - def encode_text(self, texts: Union[str, List[str]]) -> torch.Tensor: - """ - Encode texts into feature vectors - - Args: - texts: Single text string or list of text strings - - Returns: - Normalized text features [batch_size, embed_dim] - """ - if isinstance(texts, str): - texts = [texts] - - # Tokenize if necessary - if isinstance(texts, list): - tokens = tokenize(texts, context_length=self.context_length) - elif isinstance(texts, torch.Tensor): - if texts.dim() == 1: - # Single tokenized text - tokens = texts.unsqueeze(0) - else: - tokens = texts - else: - raise ValueError(f"Unsupported text type: {type(texts)}") - # Move to same device as model - device = next(self.parameters()).device - tokens = tokens.to(device) - - # Encode - with torch.no_grad(): - text_features = self.text_encoder(tokens) - text_features = F.normalize(text_features, dim=-1) - - return text_features - - def compute_similarity(self, - image_features: torch.Tensor, - text_features: torch.Tensor, - temperature: float = 1.0) -> torch.Tensor: - """ - Compute similarity between image and text features - - Args: - image_features: Normalized image features [N, embed_dim] - text_features: Normalized text features [M, embed_dim] - temperature: Temperature for scaling similarities - - Returns: - Similarity matrix [N, M] - """ - return (image_features @ text_features.t()) / temperature - - def forward(self, images: torch.Tensor, texts: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Forward pass for both image and text encoding - - Args: - images: Preprocessed image tensor [batch_size, 3, H, W] - texts: Tokenized text tensor [batch_size, context_length] - - Returns: - Tuple of (image_features, text_features) - """ - image_features = self.visual(images) - text_features = self.text_encoder(texts) - - # Normalize features - image_features = F.normalize(image_features, dim=-1) - text_features = F.normalize(text_features, dim=-1) - - return image_features, text_features - - -def load_config(config_path: str) -> Dict[str, Any]: - """Load configuration from YAML file""" - with open(config_path, 'r') as f: - config = yaml.safe_load(f) - return config - - -def load_denseclip_weights(checkpoint_path: str) -> Dict[str, torch.Tensor]: - """Load DenseCLIP checkpoint and extract relevant weights""" - print(f"Loading DenseCLIP checkpoint from: {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location='cpu') - - if 'state_dict' not in checkpoint: - raise ValueError("Checkpoint doesn't contain 'state_dict'") - - state_dict = checkpoint['state_dict'] - - # Extract vision and text encoder weights - vision_weights = {} - text_weights = {} - - for key, value in state_dict.items(): - if key.startswith('backbone.'): - # Remove 'backbone.' prefix for vision encoder - new_key = key[len('backbone.'):] - vision_weights[new_key] = value - elif key.startswith('text_encoder.'): - # Remove 'text_encoder.' prefix - new_key = key[len('text_encoder.'):] - text_weights[new_key] = value - - print(f"Extracted {len(vision_weights)} vision parameters") - print(f"Extracted {len(text_weights)} text parameters") - - return { - 'vision': vision_weights, - 'text': text_weights, - 'full_state_dict': state_dict - } - - -def load_denseclip_model(config_path: str, - checkpoint_path: Optional[str] = None, - device: str = 'auto') -> DenseCLIPModel: - """ - Load a DenseCLIP model from configuration and checkpoint - - Args: - config_path: Path to YAML configuration file - checkpoint_path: Optional path to checkpoint (overrides config) - device: Device to load model on ('auto', 'cpu', 'cuda') - - Returns: - Loaded DenseCLIPModel ready for inference - """ - # Load configuration - config = load_config(config_path) - - # Override checkpoint path if provided - if checkpoint_path is not None: - config['checkpoint']['path'] = checkpoint_path - - # Create model - model = DenseCLIPModel(config) - - # Load weights - checkpoint_path = config['checkpoint']['path'] - if os.path.exists(checkpoint_path): - weights = load_denseclip_weights(checkpoint_path) - - # Load vision encoder weights - if weights['vision']: - missing_v, unexpected_v = model.visual.load_state_dict(weights['vision'], strict=False) - if missing_v: - print(f"Missing vision keys: {len(missing_v)} (expected for FPN/post-norm components)") - if unexpected_v: - # Filter out expected mismatches - important_unexpected = [k for k in unexpected_v if not any(x in k for x in ['fpn', 'ln_post', 'proj'])] - if important_unexpected: - print(f"Unexpected vision keys: {important_unexpected}") - else: - print(f"โœ“ Vision weights loaded (ignoring {len(unexpected_v)} FPN/post-norm parameters)") - - # Load text encoder weights - if weights['text']: - missing_t, unexpected_t = model.text_encoder.load_state_dict(weights['text'], strict=False) - if missing_t: - print(f"Missing text keys: {len(missing_t)}") - if unexpected_t: - print(f"Unexpected text keys: {unexpected_t}") - - print("โœ“ Model weights loaded successfully") - else: - print(f"โš  Checkpoint not found at {checkpoint_path}, using random weights") - - # Setup device - if device == 'auto': - device = 'cuda' if torch.cuda.is_available() else 'cpu' - - model = model.to(device) - model.eval() - - print(f"โœ“ Model loaded on {device}") - return model - - -# Convenience function -def load_clip(config_name: str = 'denseclip_vitb16', - checkpoint_path: Optional[str] = None, - device: str = 'auto') -> DenseCLIPModel: - """ - Convenience function to load a DenseCLIP model - - Args: - config_name: Name of config file (without .yaml extension) - checkpoint_path: Optional path to checkpoint - device: Device to load on - - Returns: - Loaded DenseCLIPModel - """ - current_dir = os.path.dirname(os.path.abspath(__file__)) - config_path = os.path.join(current_dir, 'configs', f'{config_name}.yaml') - - if not os.path.exists(config_path): - raise FileNotFoundError(f"Config file not found: {config_path}") - - return load_denseclip_model(config_path, checkpoint_path, device) diff --git a/src/denseclip/clip_loader/example_usage.py b/src/denseclip/clip_loader/example_usage.py deleted file mode 100644 index 9ca22e0d09b02d964fb2a56ffa547ac6a89f0381..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/example_usage.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python3 -""" -Example usage of the DenseCLIP to CLIP loader -""" - -import sys -import os - -# Add the clip_loader to path -current_dir = os.path.dirname(os.path.abspath(__file__)) -sys.path.append(current_dir) - -from denseclip_loader import load_clip -import torch - - -def main(): - print("๐Ÿš€ DenseCLIP to CLIP Loader Example") - print("=" * 50) - - # Load model - print("Loading DenseCLIP model...") - try: - model = load_clip('denseclip_segmentation_vitb16') - print("โœ… Model loaded successfully!") - except Exception as e: - print(f"โŒ Error loading model: {e}") - return - - print(f"Model device: {next(model.parameters()).device}") - print(f"Text context length: {model.context_length}") - print(f"Image resolution: {model.image_resolution}") - - # Test text encoding - print("\n๐Ÿ“ Testing text encoding...") - texts = [ - "a photo of a cat", - "a photo of a dog", - "a photo of a car", - "a person walking", - "a beautiful sunset" - ] - - text_features = model.encode_text(texts) - print(f"Text features shape: {text_features.shape}") - print(f"Text features norm: {text_features.norm(dim=-1)}") # Should be ~1.0 (normalized) - - # Test text-text similarities - print("\n๐Ÿ” Text-to-text similarities:") - text_similarities = model.compute_similarity(text_features, text_features) - - print(f"{'Text':<20} {'Self-sim':<10} {'vs Cat':<10} {'vs Dog':<10}") - print("-" * 50) - for i, text in enumerate(texts): - self_sim = text_similarities[i, i].item() - cat_sim = text_similarities[i, 0].item() - dog_sim = text_similarities[i, 1].item() - print(f"{text:<20} {self_sim:<10.3f} {cat_sim:<10.3f} {dog_sim:<10.3f}") - - # Test zero-shot classification concepts - print("\n๐ŸŽฏ Zero-shot classification example:") - test_queries = [ - "an animal", - "a vehicle", - "a person", - "nature scene" - ] - - query_features = model.encode_text(test_queries) - classification_similarities = model.compute_similarity(text_features, query_features) - - print(f"{'Original Text':<20} {'Animal':<8} {'Vehicle':<8} {'Person':<8} {'Nature':<8}") - print("-" * 60) - for i, text in enumerate(texts): - sims = classification_similarities[i] - print(f"{text:<20} {sims[0]:<8.3f} {sims[1]:<8.3f} {sims[2]:<8.3f} {sims[3]:<8.3f}") - - # Test feature statistics - print("\n๐Ÿ“Š Feature statistics:") - print(f"Text feature mean: {text_features.mean():.6f}") - print(f"Text feature std: {text_features.std():.6f}") - print(f"Text feature min: {text_features.min():.6f}") - print(f"Text feature max: {text_features.max():.6f}") - - # Test model components - print("\n๐Ÿ”ง Model architecture:") - print(f"Vision encoder: {type(model.visual).__name__}") - print(f"Text encoder: {type(model.text_encoder).__name__}") - - # Count parameters - vision_params = sum(p.numel() for p in model.visual.parameters()) - text_params = sum(p.numel() for p in model.text_encoder.parameters()) - total_params = vision_params + text_params - - print(f"\n๐Ÿ“ˆ Parameter count:") - print(f"Vision encoder: {vision_params:,}") - print(f"Text encoder: {text_params:,}") - print(f"Total: {total_params:,}") - - print("\nโœ… All tests completed successfully!") - print("\n๐Ÿ’ก Usage tip:") - print(" from clip_loader import load_clip") - print(" model = load_clip('denseclip_vitb16')") - print(" features = model.encode_text(['your text here'])") - - -if __name__ == "__main__": - main() diff --git a/src/denseclip/clip_loader/models.py b/src/denseclip/clip_loader/models.py deleted file mode 100644 index 81cfadfe03fdd4dc5f0da19b0fa1f1cef9cae05b..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/models.py +++ /dev/null @@ -1,178 +0,0 @@ -""" -Standalone model components for DenseCLIP to CLIP loader. -Contains all necessary classes without external dependencies. -""" - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -from collections import OrderedDict -from typing import Tuple, Union - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -class QuickGELU(nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model, n_head, attn_mask=None): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)) - ])) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x): - # Create attention mask for current sequence length - seq_len = x.shape[0] - if self.attn_mask is not None: - if seq_len <= self.attn_mask.shape[0]: - attn_mask = self.attn_mask[:seq_len, :seq_len].to(dtype=x.dtype, device=x.device) - else: - # Extend mask for longer sequences - attn_mask = torch.empty(seq_len, seq_len, device=x.device, dtype=x.dtype) - attn_mask.fill_(float("-inf")) - attn_mask.triu_(1) - else: - attn_mask = None - - return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0] - - def forward(self, x): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class CLIPTextEncoder(nn.Module): - """ - Standard CLIP text encoder implementation for comparison. - This matches the original CLIP text encoder architecture. - """ - - def __init__(self, context_length=77, vocab_size=49408, transformer_width=512, - transformer_heads=8, transformer_layers=12, embed_dim=1024, **kwargs): - super().__init__() - - self.context_length = context_length - self.vocab_size = vocab_size - self.transformer_width = transformer_width - self.embed_dim = embed_dim - - # Build the transformer layers with proper naming - self.transformer = self._build_transformer( - transformer_width, transformer_layers, transformer_heads, context_length - ) - - # Text processing components - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter(torch.empty(context_length, transformer_width)) - self.ln_final = LayerNorm(transformer_width) - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - - self._initialize_parameters() - - def _build_transformer(self, width, layers, heads, context_length): - """Build transformer layers with causal attention mask""" - # Create causal attention mask - mask = torch.empty(context_length, context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - - # Build transformer blocks with proper naming (resblocks) - resblocks = nn.Sequential(*[ - ResidualAttentionBlock(width, heads, mask) for _ in range(layers) - ]) - - # Create a module that matches CLIP's naming convention - transformer = nn.Module() - transformer.resblocks = resblocks - return transformer - - def _initialize_parameters(self): - """Initialize parameters following CLIP initialization""" - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - nn.init.normal_(self.text_projection, std=self.transformer_width ** -0.5) - - def forward(self, text): - """ - Forward pass for text encoding - - Args: - text: Tokenized text tensor of shape [batch_size, context_length] - - Returns: - Text features tensor of shape [batch_size, embed_dim] - """ - x = self.token_embedding(text) # [batch_size, n_ctx, d_model] - x = x + self.positional_embedding - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer.resblocks(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) - - # Take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - -class CLIPVisionTransformer(nn.Module): - def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) - self.ln_pre = LayerNorm(width) - - self.transformer = nn.Module() - self.transformer.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads) for _ in range(layers)]) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def forward(self, x: torch.Tensor, get_patches : bool = False): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer.resblocks(x) - x = x.permute(1, 0, 2) # LND -> NLD - - if not get_patches: - # returns only the CLS token - x = self.ln_post(x[:, 0, :]) - else: - # returns all patch tokens AND the CLS token - x = self.ln_post(x[:, :, :]) - - if self.proj is not None: - x = x @ self.proj - - return x diff --git a/src/denseclip/clip_loader/requirements.txt b/src/denseclip/clip_loader/requirements.txt deleted file mode 100644 index d831bc24083af44b7dd9ddfcff359ff294ce4bb3..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/requirements.txt +++ /dev/null @@ -1,7 +0,0 @@ -# Requirements for DenseCLIP to CLIP Loader -torch>=1.9.0 -torchvision>=0.10.0 -Pillow>=8.0.0 -PyYAML>=5.4.0 -numpy>=1.20.0 -regex>=2021.4.4 diff --git a/src/denseclip/clip_loader/tokenizer.py b/src/denseclip/clip_loader/tokenizer.py deleted file mode 100644 index f45c0ba749daf13987dd7e14f5df9f201dac19f8..0000000000000000000000000000000000000000 --- a/src/denseclip/clip_loader/tokenizer.py +++ /dev/null @@ -1,232 +0,0 @@ -""" -Simple tokenizer for CLIP text processing. -This is a simplified version that doesn't require the full CLIP dependencies. -""" - -import gzip -import html -import os -import regex as re -from functools import lru_cache -from typing import List, Union - -try: - import torch -except ImportError: - print("Warning: PyTorch not available. Please install PyTorch.") - torch = None - -try: - import ftfy -except ImportError: - print("Warning: ftfy not available. Using basic text cleaning.") - ftfy = None - - -@lru_cache() -def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word.""" - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - if ftfy: - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - - # Load BPE merges if file exists, otherwise use simple word-level tokenization - if os.path.exists(bpe_path): - with gzip.open(bpe_path) as f: - merges = f.read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) - - else: - # Fallback: simple vocabulary - print("BPE file not found, using simple tokenization") - self.encoder = {'<|startoftext|>': 0, '<|endoftext|>': 1} - self.decoder = {0: '<|startoftext|>', 1: '<|endoftext|>'} - self.cache = {} - self.pat = re.compile(r'\S+') - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - - if not hasattr(self, 'bpe_ranks'): - # Simple tokenization fallback - return token - - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder.get(bpe_token, 1) for bpe_token in self.bpe(token).split(' ')) - return bpe_tokens - - def decode(self, tokens): - text = ''.join([self.decoder.get(token, '') for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') - return text - - -# Global tokenizer instance -_tokenizer = None - -def get_tokenizer(): - global _tokenizer - if _tokenizer is None: - # Try to find BPE file in current directory - bpe_path = os.path.join(os.path.dirname(__file__), "bpe_simple_vocab_16e6.txt.gz") - _tokenizer = SimpleTokenizer(bpe_path) - return _tokenizer - - -def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False): - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - truncate: bool - Whether to truncate the text in case its encoding is longer than the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - tokenizer = get_tokenizer() - - if torch is None: - raise RuntimeError("PyTorch is required for tokenization") - - if isinstance(texts, str): - texts = [texts] - - sot_token = tokenizer.encoder.get("<|startoftext|>", 0) - eot_token = tokenizer.encoder.get("<|endoftext|>", 1) - all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts] - - if any(len(s) > context_length for s in all_tokens): - if truncate: - for tokens in all_tokens: - if len(tokens) > context_length: - tokens[:] = tokens[:context_length] - tokens[-1] = eot_token - else: - raise RuntimeError(f"Input text is too long for context length {context_length}") - - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - if truncate: - tokens = tokens[:context_length] - tokens[-1] = eot_token - else: - raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") - result[i, :len(tokens)] = torch.tensor(tokens) - - return result - - -if __name__ == "__main__": - # Test tokenizer - text = "a photo of a cat" - tokens = tokenize([text]) - print(f"Text: {text}") - print(f"Tokens shape: {tokens.shape}") - print(f"Tokens: {tokens}") diff --git a/src/denseclip/loader.py b/src/denseclip/loader.py deleted file mode 100644 index 9448d2e6759e3898dbfcd56ce838c7cf32171853..0000000000000000000000000000000000000000 --- a/src/denseclip/loader.py +++ /dev/null @@ -1,33 +0,0 @@ -from .clip_loader.denseclip_loader import load_clip, load_config -from .clip_loader.denseclip_loader import DenseCLIPModel -from .clip_loader.tokenizer import tokenize as DenseCLIP_tokenize -import os - -def load_denseclip(config_name: str, device: str = "cuda") -> DenseCLIPModel: - """ - Load a DenseCLIP model. - - Args: - model_name (str): The name of the DenseCLIP model to load. - device (str): The device to load the model onto, default is "cuda". - - Returns: - The loaded DenseCLIP model. - """ - return load_clip(config_name=config_name, device=device) - -def load_denseclip_config(config_name: str) -> dict: - """ - Load the configuration for a DenseCLIP model. - - Args: - config_name (str): The name of the DenseCLIP configuration to load. - - Returns: - dict: The loaded configuration dictionary. - """ - config_name = config_name + '.yaml' if not config_name.endswith('.yaml') else config_name - config_path = os.path.join(os.path.dirname(__file__), 'clip_loader/configs', f'{config_name}') - if not os.path.exists(config_path): - raise FileNotFoundError(f"DenseClip configuration file {config_path} does not exist.") - return load_config(config_path=config_path) \ No newline at end of file diff --git a/src/dino_extraction.py b/src/dino_extraction.py deleted file mode 100644 index 11f356a2e3b2f1ea4c46417b4502405682dcab5f..0000000000000000000000000000000000000000 --- a/src/dino_extraction.py +++ /dev/null @@ -1,98 +0,0 @@ -import argparse -import torch -import torchvision.transforms as T - -from PIL import Image - -feats = {} -def get_self_attention(module, input, output): - feats['self_attn'] = output - -def get_layer_n_output(module, input, output): - feats['intermediate_output'] = output - -def transform_to_standard_dino_out(x, model): - x_norm = model.norm(x) - return { - "x_norm_clstoken": x_norm[:, 0], - "x_norm_regtokens": x_norm[:, 1 : 4 + 1], - "x_norm_patchtokens": x_norm[:, 4 + 1 :], - "x_prenorm": x, - # "masks": masks, - } - -def process_self_attention(output, batch_size, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False): - qkv = output.reshape(batch_size, num_tokens, 3, num_attn_heads, embed_dim // num_attn_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0] * scale, qkv[1], qkv[2] - attn = q @ k.transpose(-2, -1) - self_attn_maps = attn[:, : , 0, num_global_tokens:] - self_attn = self_attn_maps.mean(dim=1) - self_attn = self_attn.softmax(dim=-1) - if ret_self_attn_maps: - return self_attn, self_attn_maps - else: - return self_attn - - -def run_dinov2_extraction(model_name, resize_dim=518, crop_dim=518, img_path='cat.jpeg'): - device = 'cuda' if torch.cuda.is_available else 'cpu' - - num_global_tokens = 1 if "reg" not in model_name else 5 - num_patch_tokens = crop_dim // 14 * crop_dim // 14 - num_tokens = num_global_tokens + num_patch_tokens - if 'vitl' in model_name or 'vit_large' in model_name or 'ViT-L' in model_name: - embed_dim = 1024 - elif 'vitb' in model_name or 'vit_base' in model_name or 'ViT-B' in model_name: - embed_dim = 768 - elif 'vits' in model_name or 'vit_small' in model_name: - embed_dim = 384 - else: - raise Exception("Unknown ViT model") - - num_attn_heads = 16 if not 'vits' in model_name else 6 - scale = 0.125 - - # loading the model - if 'dinov2' in model_name: - model_family = 'facebookresearch/dinov2' - model = torch.hub.load(model_family, model_name) - image_transforms = T.Compose([ - T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC), - T.CenterCrop(crop_dim), - T.ToTensor(), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ]) - - model.eval() - model.to(device) - model.blocks[-1].attn.qkv.register_forward_hook(get_self_attention) - model.blocks[-1].register_forward_hook(get_layer_n_output) - - pil_img = Image.open(img_path) - - if pil_img.mode != 'RGB': - pil_img = pil_img.convert('RGB') - - batch_imgs = image_transforms(pil_img).unsqueeze(0).to(device) - - with torch.no_grad(): - outs = model(batch_imgs, is_training=True) - outs_layer_n = transform_to_standard_dino_out(feats['intermediate_output'], model) - - self_attn = process_self_attention(feats['self_attn'], 1, num_tokens, num_attn_heads, embed_dim, scale, num_global_tokens, ret_self_attn_maps=False) - avg_self_attn_token = (self_attn.unsqueeze(-1) * outs['x_norm_patchtokens']).mean(dim=1) - - print(avg_self_attn_token) - print(avg_self_attn_token.shape) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--model', type=str, default="dinov2_vitb14_reg", help="Model configuration to extract features from") - parser.add_argument('--resize_dim', type=int, default=518, help="Resize dimension") - parser.add_argument('--crop_dim', type=int, default=518, help="Crop dimension") - args = parser.parse_args() - - run_dinov2_extraction(args.model, args.resize_dim, args.crop_dim) -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/src/dinotxt_utils.py b/src/dinotxt_utils.py deleted file mode 100644 index 7c58250d19b4e3096e89ba5210fc6e51c6f43435..0000000000000000000000000000000000000000 --- a/src/dinotxt_utils.py +++ /dev/null @@ -1,99 +0,0 @@ -import torch - -from typing import List, Union -from clip.simple_tokenizer import SimpleTokenizer -from torchvision import transforms -from typing import Sequence - - -_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" - -class Tokenizer(SimpleTokenizer): - def __init__(self, vocab_path: str): - SimpleTokenizer.__init__(self, bpe_path=vocab_path) - - def tokenize(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - sot_token = self.encoder["<|startoftext|>"] - eot_token = self.encoder["<|endoftext|>"] - all_tokens = [[sot_token] + self.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - tokens = tokens[:context_length] # Truncate - tokens[-1] = eot_token - result[i, : len(tokens)] = torch.tensor(tokens) - - return result - - -def get_tokenizer(): - import requests - from io import BytesIO - - url = _DINOV2_BASE_URL + "/thirdparty/bpe_simple_vocab_16e6.txt.gz" - try: - response = requests.get(url) - response.raise_for_status() - file_buf = BytesIO(response.content) - return Tokenizer(vocab_path=file_buf) - except Exception as e: - raise FileNotFoundError(f"Failed to download file from url {url} with error last: {e}") - -# Use timm's names -IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) -IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) - -class MaybeToTensor(transforms.ToTensor): - """ - Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. - """ - - def __call__(self, pic): - """ - Args: - pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. - Returns: - Tensor: Converted image. - """ - if isinstance(pic, torch.Tensor): - return pic - return super().__call__(pic) - -def make_normalize_transform( - mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, - std: Sequence[float] = IMAGENET_DEFAULT_STD, -) -> transforms.Normalize: - return transforms.Normalize(mean=mean, std=std) - -def make_classification_eval_transform( - *, - resize_size: int = 256, - interpolation=transforms.InterpolationMode.BICUBIC, - crop_size: int = 224, - mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, - std: Sequence[float] = IMAGENET_DEFAULT_STD, -) -> transforms.Compose: - transforms_list = [ - transforms.Resize(resize_size, interpolation=interpolation), - transforms.CenterCrop(crop_size), - MaybeToTensor(), - make_normalize_transform(mean=mean, std=std), - ] - return transforms.Compose(transforms_list) diff --git a/src/embedding_utils.py b/src/embedding_utils.py deleted file mode 100644 index b5699d9a69056bab5881eae7cbd1fe792a56f2b9..0000000000000000000000000000000000000000 --- a/src/embedding_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch - -def get_pseudo_inverse(A): - # Perform SVD - U, S, Vh = torch.linalg.svd(A, full_matrices=False) - - # Compute the pseudo-inverse of the singular values - S_pinv = torch.zeros_like(S) - non_zero = S > 1e-10 # Tolerance for considering a singular value as zero - S_pinv[non_zero] = 1.0 / S[non_zero] - - # Construct the pseudo-inverse - A_pinv = Vh.T @ torch.diag(S_pinv) @ U.T - - return A_pinv - -def revert_transformation(features, linear_layer=None, A_pinv=None, b=None): - assert linear_layer is not None or (A_pinv is not None and b is not None), "revert_transformation needs either the pseudo inverse od the linear layer to calculate the pseudo inverse from" - if A_pinv is None: - W = linear_layer.weight - b = linear_layer.bias - - A_pinv = get_pseudo_inverse(W) - - return (features - b) @ A_pinv.t() \ No newline at end of file diff --git a/src/meacap/args.py b/src/meacap/args.py deleted file mode 100644 index 610dc10c3f1febca6435da5b39c0d730b7d7910f..0000000000000000000000000000000000000000 --- a/src/meacap/args.py +++ /dev/null @@ -1,84 +0,0 @@ -import argparse - - -def get_args(): - parser = argparse.ArgumentParser(description="args for Memory augmented zero-shot image captioning.") - - # HYPERPARAMETERS ## - parser.add_argument('--seed', type=int, default=0) - parser.add_argument('--gpu', type=str, default='3') - parser.add_argument('--batch_size', type=str, default=1,help='only support batch_size=1 now') - parser.add_argument('--conzic_sample', type=bool, default=True, help='conzic sample means a way to process logits by conzic method' - 'https://arxiv.org/abs/2303.02437') - parser.add_argument('--conzic_top_k', type=int, default=200) - parser.add_argument("--alpha", type=float, default=0.1, help="weight for fluency") - parser.add_argument("--beta", type=float, default=0.8, help="weight for image-matching degree") - parser.add_argument("--gamma", type=float, default=0.2, help="weight for fluency") - - parser.add_argument("--use_prompt", action='store_true', default=False) - parser.add_argument("--prompt", type=list, default=['The image depicts that']) - parser.add_argument("--prompt_ensembling", action='store_true', default=False) - - ## MEMORY ## - parser.add_argument("--use_memory", type=bool, default=True) - parser.add_argument("--memory_id", type=str, default=r"coco_B16",help="memory name") - #parser.add_argument("--memory_caption_path", type=str, default='data/memory/coco/memory_captions.json') - parser.add_argument("--memory_caption_num", type=int, default=5) - - parser.add_argument("--memory_path", type=str, default="/raid/datasets/meacap_files/memory/coco/memory_captions.json") - - ## DATA/MODEL PATH ## - parser.add_argument('--img_path', type=str, default=r'./image_example') - parser.add_argument('--output_path', type=str, default=r'./outputs') - #vl_model : "openai/clip-vit-base-patch16" - #wte_model_path : "sentence-transformers/all-MiniLM-L6-v2" - #parser_checkpoint : "lizhuang144/flan-t5-base-VG-factual-sg" - parser.add_argument('--vl_model', type=str, default=r'openai/clip-vit-base-patch16') - parser.add_argument('--use_t2d', type=bool, default=False, help='whether to use talk2dino for textual feature extraction') - parser.add_argument("--talk2dino_config_path", type=str, default=r'../../configs_talk2dino/vitb_mlp_infonce.yaml') - parser.add_argument("--talk2dino_weights_path", type=str, default=r'../../weights_talk2dino/vitb_mlp_infonce.pth') - parser.add_argument("--parser_checkpoint", type=str, default=r'lizhuang144/flan-t5-base-VG-factual-sg') - parser.add_argument("--wte_model_path", type=str, default=r'sentence-transformers/all-MiniLM-L6-v2') - parser.add_argument("--lm_model_path", type=str, default=r'F:/ImageText/MeaCap-family/pretrain_model/CBART_COCO') - - parser.add_argument("--memory_base_path", type=str, default="/raid/datasets/meacap_files/") - - ## lANGUAGE MODEL CBART ## - parser.add_argument('--bart', type=str, default='large', choices=['base', 'large']) - - parser.add_argument('--refinement_steps', type=int, default=10, help='The number of refinements for each input.') - parser.add_argument('--adaptive', type=bool, default=False, help='The number of refinements is on the fly but ' - 'no bigger than max_refinement_steps') - parser.add_argument('--max_refinement_steps', type=int, default=30, help='The maximum number of refinements for each input.') - parser.add_argument('--max_len', type=int, default=20, help='The maximum length of the generated sentence.') - parser.add_argument('--min_len', type=int, default=10, help='The minimum length of the generated sentence.') - parser.add_argument('--temperature', type=float, default=1, - help='The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.') - parser.add_argument('--repetition_penalty', type=float, default=2, - help='Between 1.0 and infinity.1.0 means no penalty.Default to 1.0.') - parser.add_argument('--threshold', type=float, default=0, - help='Between 0 and 1. 0 means no threshold for copy action. Default to 0.') - parser.add_argument('--top_k', type=int, default=0, - help='The number of highest probability vocabulary tokens to keep for top-k-filtering. Between 1 and infinity.') - parser.add_argument('--top_p', type=float, default=0.9, - help='The cumulative probability of parameter highest probability vocabulary tokens to keep for nucleus sampling. ' - 'Must be between 0 and 1.') - parser.add_argument('--decoder_chain', type=int, default=1, - help='the number of parallel chains for decoder, each chain refers to an unique token sequence.') - parser.add_argument('--do_sample', type=int, default=0, - help='if 0 decode with greedy method, otherwise decode with top_k or top_p.') - parser.add_argument('--encoder_loss_type', type=int, default=0, help='0 is classification loss, 1 is regression loss') - parser.add_argument('--insert_mode', type=int, default=0, choices=[0, 1, 2, 3, 4], - help='0 means using the left part, 1 means using the middle part, 2 means using the right part,' - '3 means randomly selecting, 4 means selecting the tokens with highest weight') - parser.add_argument('--max_insert_label', type=int, default=1, help='the maximum number of tokens to be inserted before a token.') - parser.add_argument('--num_labels', type=int, default=3, - help='0 for copy, 1 for replace, 2-5 means insert 1-4 tokens') - parser.add_argument('--generate_mode', type=int, default=0, choices=[0, 1, 2, 3], - help='0 for random, 1 for lm, 2 for combination') - parser.add_argument('--full_mask', type=float, default=0, help='0 for using casual mask attention for decoder, ' - '1 for without using casual mask attention for decoder.') - parser.add_argument('--w', type=float, default=1.0, help='The weight for the encoder loss') - args = parser.parse_args() - - return args \ No newline at end of file diff --git a/src/meacap/entrypoint.py b/src/meacap/entrypoint.py deleted file mode 100644 index ca5238956084091735b221392b50280588d6215d..0000000000000000000000000000000000000000 --- a/src/meacap/entrypoint.py +++ /dev/null @@ -1,190 +0,0 @@ -import os, sys - - -if os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) not in sys.path: - sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -from viecap.entrypoint import VieCap -from viecap.utils import compose_discrete_prompts -from viecap.search import greedy_search, beam_search, opt_search - -from .models.clip_utils import CLIP - -from sentence_transformers import SentenceTransformer -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM -import torch, json -from torch.nn.utils.rnn import pad_sequence - -from typing import List - -from .utils.detect_utils import retrieve_concepts - -class MeaCap(VieCap): - - retrieve_on_CPU = False - - def __init__(self, args, device, clip_name): - super(MeaCap, self).__init__(args, device, clip_name) - - args = self.args - - self.vl_model = CLIP(args.meacap.vl_model) - self.vl_model = self.vl_model.to(self.device) - print('[MeaCap] Loaded CLIP vl_model from the checkpoint {}.'.format(args.meacap.vl_model)) - - self.wte_model = SentenceTransformer(args.meacap.wte_model_path, device=self.device) - print('[MeaCap] Load sentenceBERT from the checkpoint {}.'.format(args.meacap.wte_model_path)) - - with torch.cuda.device(self.device): - self.parser_tokenizer = AutoTokenizer.from_pretrained(args.meacap.parser_checkpoint) - self.parser_model = AutoModelForSeq2SeqLM.from_pretrained(args.meacap.parser_checkpoint) - self.parser_model.eval() - self.parser_model.to(self.device) - print('[MeaCap] Load Textual Scene Graph parser from the checkpoint {}.'.format(args.meacap.parser_checkpoint)) - - memory_id = args.meacap.memory_id - memory_base_path = args.meacap.memory_base_path - memory_caption_path = os.path.join(memory_base_path, f"memory/{memory_id}", "memory_captions.json") - memory_clip_embedding_file = os.path.join(memory_base_path, f"memory/{memory_id}", "memory_clip_embeddings.pt") - memory_wte_embedding_file = os.path.join(memory_base_path, f"memory/{memory_id}", "memory_wte_embeddings.pt") - - self.memory_clip_embeddings = torch.load(memory_clip_embedding_file, map_location=self.device).to(self.device) - self.memory_wte_embeddings = torch.load(memory_wte_embedding_file, map_location=self.device).to(self.device) - with open(memory_caption_path, 'r') as f: - self.memory_captions = json.load(f) - print('[MeaCap] Loaded memory bank for memory_id {}.'.format(memory_id)) - - self.vl_model_retrieve = self.vl_model - - self.eval() - - def get_viecap_texts_embeddings(self, args, clip_name): - return None, None - - def load_config(self, args_dict): - default = { - "meacap" : { - "memory_caption_num" : 5, - "vl_model" : "openai/clip-vit-base-patch32", - "wte_model_path" : "sentence-transformers/all-MiniLM-L6-v2", - "parser_checkpoint" : "lizhuang144/flan-t5-base-VG-factual-sg", - "memory_id" : "coco", - "memory_base_path" : "/raid/datasets/meacap_files/" - } - } - - def deep_merge(dict1, dict2): - """ - Recursively merges the contents of dict2 into dict1. - - the value from dict2 overwrites the value in dict1 - For each key in dict2: - - If the key exists in dict1 and both values are dictionaries, merge them recursively. - - Otherwise, the value from dict2 overwrites the value in dict1. - Parameters: - dict1 (dict): The dictionary to be updated in place. - dict2 (dict): The dictionary whose values will be merged into dict1. - """ - for key, value in dict2.items(): - if ( - key in dict1 and isinstance(dict1[key], dict) - and isinstance(value, dict) - ): - deep_merge(dict1[key], value) - else: - dict1[key] = value - return dict1 - - deep_merge(default, args_dict) # the priority of the default config is lower than the user input - args_dict = default - - return super().load_config(args_dict) - - def forward(self, image_features, compute_scores : bool = False, eval_mode : bool = True) -> List[str]: - - if eval_mode: - self.eval() - - pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - - image_features /= image_features.norm(2, dim = -1, keepdim = True) - - continuous_embeddings = self.model.mapping_network(image_features).view(-1, self.args.continuous_prompt_length, self.model.gpt_hidden_size) - - if self.args.using_hard_prompt: - - #batch_image_embeds = self.vl_model.compute_image_representation_from_image_path(self.args.image_path) - batch_image_embeds = image_features - - if self.retrieve_on_CPU != True: - #batch _size = batch_image_ embeds.sha pe[0] - #memory_clip_em beds_batched = self.memory_c lip_embeddings.unsq ueeze(0).repeat(batch_size, 1, 1) - #clip_sc ore, cli p_ref = self.vl_model_r etrieve.compute_image _text_similarity_via_embeddings( - # batch_image_e mbeds, memory_clip _embeds_batched) - clip_score, clip_ref = self.vl_model_retrieve.compute_image_text_similarity_via_embeddings_new( - batch_image_embeds, self.memory_clip_embeddings) - else: - - raise Exception("retrieve_on_CPU is not supported in this version.") - #batch_image_embeds_cpu = batch_image_embeds.to(cpu_device) - #clip_score_cpu, clip_ref_cpu = vl_model_retrieve.compute_image_text_similarity_via_embeddings( - # batch_image_embeds_cpu, - # memory_clip_embeddings) - #clip_score = clip_score_cpu.to(device) - #clip_ref = clip_ref_cpu.to(device) - - select_memory_ids_batch = clip_score.topk(self.args.meacap.memory_caption_num, dim=-1)[1]#.squeeze(0) - - all_discrete_tokens = [] - - for select_memory_ids in select_memory_ids_batch: - select_memory_captions = [self.memory_captions[id] for id in select_memory_ids] - select_memory_wte_embeddings = self.memory_wte_embeddings[select_memory_ids] - detected_objects = retrieve_concepts(parser_model=self.parser_model, parser_tokenizer=self.parser_tokenizer, - wte_model=self.wte_model, - select_memory_captions=select_memory_captions, - image_embeds=batch_image_embeds, - device=self.device) - - #print("memory concepts:", detected_objects) - discrete_tokens = compose_discrete_prompts(self.tokenizer, detected_objects).to(self.device) #.unsqueeze(dim = 0) - all_discrete_tokens.append(discrete_tokens) - - all_discrete_tokens = [t.to(self.device) for t in all_discrete_tokens] - discrete_tokens = pad_sequence(all_discrete_tokens, batch_first=True, padding_value=pad_id) - #discrete_tokens = torch.stack(all_discrete_tokens).to(self.device) - - discrete_embeddings = self.model.word_embed(discrete_tokens) - - - if self.args.only_hard_prompt: - embeddings = discrete_embeddings - elif self.args.soft_prompt_first: - embeddings = torch.cat((continuous_embeddings, discrete_embeddings), dim = 1) - else: - embeddings = torch.cat((discrete_embeddings, continuous_embeddings), dim = 1) - else: - embeddings = continuous_embeddings - - if 'gpt' in self.args.language_model: - if not self.args.using_greedy_search: - #sentences = beam_search(embeddings = embeddings, tokenizer = self.tokenizer, beam_width = self.args.beam_width, model = self.model.gpt) # List[str] - sentences = [] - for i in range(embeddings.shape[0]): - sentence = beam_search(embeddings = embeddings[i:i+1], tokenizer = self.tokenizer, beam_width = self.args.beam_width, model = self.model.gpt) - sentences.append(sentence[0]) - - else: - sentences = greedy_search(embeddings = embeddings, tokenizer = self.tokenizer, model = self.model.gpt) - else: - sentences = opt_search(prompts=self.args.text_prompt, embeddings = embeddings, tokenizer = self.tokenizer, beam_width = self.args.beam_width, model = self.model.gpt) - - if compute_scores: - perplexities = self.compute_perplexity( - sentences, - tokenizer=self.tokenizer, - model=self.model.gpt, - device=self.device, - ) - return sentences, perplexities - else: - return sentences \ No newline at end of file diff --git a/src/meacap/models/clip_utils.py b/src/meacap/models/clip_utils.py deleted file mode 100644 index 0602d27c9d19bd64351a26e1eab24264c899f8e4..0000000000000000000000000000000000000000 --- a/src/meacap/models/clip_utils.py +++ /dev/null @@ -1,188 +0,0 @@ -import torch -import requests -from torch import nn -from PIL import Image -from typing import Tuple - -class CLIP(nn.Module): - def __init__(self, model_name): - super(CLIP, self).__init__() - # model name: e.g. openai/vl_models-vit-base-patch32 - print('Initializing CLIP model...') - from transformers import CLIPProcessor, CLIPModel - self.model = CLIPModel.from_pretrained(model_name) - self.model.eval() - self.processor = CLIPProcessor.from_pretrained(model_name) - from transformers import CLIPTokenizerFast - self.tokenizer = CLIPTokenizerFast.from_pretrained(model_name) - self.cuda_has_been_checked = False - print('CLIP model initialized.') - - def check_cuda(self): - self.cuda_available = next(self.model.parameters()).is_cuda - self.device = next(self.model.parameters()).get_device() - if self.cuda_available: - print('Cuda is available.') - print('Device is {}'.format(self.device)) - else: - print('Cuda is not available.') - print('Device is {}'.format(self.device)) - - def device_convert(self, tar_device): - self.device = tar_device - self.model.to(self.device) - print(f'CLIP Model moved to {tar_device}!') - - @torch.no_grad() - def compute_image_representation_from_image_path(self, image_path): - if not self.cuda_has_been_checked: - self.check_cuda() - self.cuda_has_been_checked = True - else: - pass - # image_path: the path of the image - image = Image.open(image_path) - inputs = self.processor(images=image, return_tensors="pt") - pixel_values = inputs['pixel_values'] - # if self.cuda_available: - pixel_values = pixel_values.to(self.device) - visual_outputs = self.model.vision_model(pixel_values=pixel_values) - image_embeds = visual_outputs[1] - image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] - return image_embeds - - def compute_image_representation_from_image_instance(self, image): - if not self.cuda_has_been_checked: - self.check_cuda() - self.cuda_has_been_checked = True - else: - pass - # image_path: the path of the image - inputs = self.processor(images=image, return_tensors="pt") - pixel_values = inputs['pixel_values'] - # if self.cuda_available: - pixel_values = pixel_values.to(self.device) - visual_outputs = self.model.vision_model(pixel_values=pixel_values) - image_embeds = visual_outputs[1] - image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] - return image_embeds - - def compute_frame_representation_from_tensor(self, pixel_values): - if not self.cuda_has_been_checked: - self.check_cuda() - self.cuda_has_been_checked = True - else: - pass - - # if self.cuda_available: - pixel_values = pixel_values.to(self.device) - visual_outputs = self.model.vision_model(pixel_values=pixel_values) - image_embeds = visual_outputs[1] - image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] - return image_embeds - - def compute_text_representation(self, text_list): - if not self.cuda_has_been_checked: - self.check_cuda() - self.cuda_has_been_checked = True - else: - pass - # text_list: a list of text - text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt", - max_length=50, truncation=True) - # self.tokenizer.max_len_single_sentence + 2 = 77 - input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask'] - # if self.cuda_available: - input_ids = input_ids.to(self.device) - attention_mask = attention_mask.to(self.device) - text_outputs = self.model.text_model( - input_ids=input_ids, - attention_mask=attention_mask - ) - text_embeds = text_outputs[1] - text_embeds = self.model.text_projection(text_embeds) - return text_embeds - - def compute_image_text_similarity_via_embeddings(self, image_embeds, text_embeds) -> Tuple[torch.Tensor, torch.Tensor]: - ''' - image_embeds: batch x embed_dim - text_embeds: batch x len(text_list) x embed_dim - ''' - text_embeds = text_embeds.view(image_embeds.shape[0], -1, text_embeds.shape[-1]) - image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) - text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) - image_embeds = image_embeds.unsqueeze(-1) - logit_scale = self.model.logit_scale.exp() - logits_per_text = torch.matmul(text_embeds, image_embeds) * logit_scale - logits_per_image = logits_per_text.squeeze(-1) - return logits_per_image.softmax(dim=1), logits_per_image/logit_scale # batch x len(text_list) - - def compute_image_text_similarity_via_raw_text(self, image_embeds, text_list): - text_embeds = self.compute_text_representation(text_list) - return self.compute_image_text_similarity_via_embeddings(image_embeds, text_embeds) - - def compute_image_text_similarity_via_Image_text(self, image, text_list): - image_embeds = self.compute_image_representation_from_image_instance(image) - text_embeds = self.compute_text_representation(text_list) - return self.compute_image_text_similarity_via_embeddings(image_embeds, text_embeds) - - def compute_image_text_similarity_via_embeddings_new(self, image_embeds, text_embeds) -> Tuple[torch.Tensor, torch.Tensor]: - ''' - image_embeds: [B, D] - text_embeds: [N, D] - ''' - image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) - text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) - - # Compute cosine similarity - logits_per_image = image_embeds @ text_embeds.T # [B, N] - - logit_scale = self.model.logit_scale.exp() - logits_per_image = logits_per_image * logit_scale - - return logits_per_image.softmax(dim=1), logits_per_image / logit_scale - - ### -------------------- functions for building index ---------------------- ### - def compute_batch_index_image_features(self, image_list): - ''' - # list of image instances - ''' - if not self.cuda_has_been_checked: - self.check_cuda() - self.cuda_has_been_checked = True - else: - pass - # image_path: the path of the image - inputs = self.processor(images=image_list, return_tensors="pt") - pixel_values = inputs['pixel_values'] - # if self.cuda_available: - pixel_values = pixel_values.to(self.device) - visual_outputs = self.model.vision_model(pixel_values=pixel_values) - image_embeds = visual_outputs[1] - image_embeds = self.model.visual_projection(image_embeds) # [1 x embed_dim] - return image_embeds # len(image_list) x embed_dim - - def compute_batch_index_text_representation(self, text_list): - if not self.cuda_has_been_checked: - self.check_cuda() - self.cuda_has_been_checked = True - else: - pass - # text_list: a list of text - #text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt") - text_inputs = self.tokenizer(text_list, padding=True, return_tensors="pt", - max_length=self.tokenizer.max_len_single_sentence + 2, truncation=True) - input_ids, attention_mask = text_inputs['input_ids'], text_inputs['attention_mask'] - # if self.cuda_available: - input_ids = input_ids.to(self.device) - attention_mask = attention_mask.to(self.device) - text_outputs = self.model.text_model( - input_ids=input_ids, - attention_mask=attention_mask - ) - text_embeds = text_outputs[1] - text_embeds = self.model.text_projection(text_embeds) - return text_embeds - #logit_scale = self.model.logit_scale.exp() - #text_embeds = text_embeds * logit_scale - #return text_embeds \ No newline at end of file diff --git a/src/meacap/prepare_embeddings.py b/src/meacap/prepare_embeddings.py deleted file mode 100644 index 77bcab472e945a1e397deee1c5cc74ecf887054a..0000000000000000000000000000000000000000 --- a/src/meacap/prepare_embeddings.py +++ /dev/null @@ -1,93 +0,0 @@ -# -*- coding: utf-8 -*- - -import torch -#from torch.utils.data import DataLoader - -import time -import os -import sys -from args import get_args -#from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel -from sentence_transformers import SentenceTransformer - -from models.clip_utils import CLIP -import json -import copy -from tqdm import tqdm -import shutil - -#dir_path = os.path.dirname(os.path.realpath(__file__)) -#parent_dir_path = os.path.abspath(os.path.join(dir_path, os.pardir)) -#sys.path.insert(0, parent_dir_path) - -#from .utils.log import Logger -#from language_models.language_model import LanguageModel -#from transformers import GPT2Tokenizer, GPT2LMHeadModel -#from dataset.ImgDataset import Imgdata, collate_img -#from dataset.ImgDataset_img_return import Imgdata_img_return, collate_img_img_return -# -from utils.some_utils import set_seed, update_args_logger -#from utils.detect_utils import detect_keyword -#from utils.generate_utils_ import Get_shuffle_score, filter_text - -from talk2dino import ProjectionLayer - -if __name__ == "__main__": - args = get_args() - - - input_text_corpus_path = args.memory_path - - save_path = os.path.join(args.memory_base_path, f"memory/{args.memory_id}") - - print(F"Will use '{args.memory_id = } and the input text corpus is at {input_text_corpus_path = }'") - print(F"Will save the outputs in '{save_path = }'") - - set_seed(args) - os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - cpu_device = torch.device("cpu") - - print(f"Going to use '{args.vl_model}' as CLIP model") - vl_model = CLIP(args.vl_model) - vl_model = vl_model.to(device) - - if args.use_t2d: - # loading Talk2DINO - print(f"Loading Talk2DINO weights from {args.talk2dino_weights_path}") - talk2dino = ProjectionLayer.from_config(args.talk2dino_config_path) - talk2dino.load_state_dict(torch.load(args.talk2dino_weights_path, device)) - talk2dino.to(device) - talk2dino.eval() - else: - talk2dino = None - - #sim_func = torch.nn.CosineSimilarity(dim=1, eps=1e-6) - wte_model = SentenceTransformer(args.wte_model_path) - - clip_embed_list = [] - wte_embed_list = [] - - with open(input_text_corpus_path,'r') as json_f: - textual_data = json.load(json_f) - batch_size = 128 - for idx in tqdm(range(0,len(textual_data),batch_size), dynamic_ncols=True): - text_list = textual_data[idx:idx+batch_size] - clip_embeds = vl_model.compute_text_representation(text_list) - if args.use_t2d: - with torch.no_grad(): - clip_embeds = talk2dino.project_clip_txt(clip_embeds) - clip_embeds = clip_embeds.detach().cpu() - clip_embed_list.append(clip_embeds) - wte_embeds = wte_model.encode(text_list, convert_to_tensor=True, normalize_embeddings=True).detach().cpu() - wte_embed_list.append(wte_embeds) - # if idx >= 200000: - # break - all_clip_embeds = torch.cat(clip_embed_list) - all_wte_embeds = torch.cat(wte_embed_list) - - if os.path.exists(save_path) == False: - os.makedirs(save_path) - shutil.copy(args.memory_path, os.path.join(save_path, "memory_captions.json")) - torch.save(all_clip_embeds, os.path.join(save_path, "memory_clip_embeddings.pt")) - torch.save(all_wte_embeds, os.path.join(save_path, "memory_wte_embeddings.pt")) \ No newline at end of file diff --git a/src/meacap/readme.md b/src/meacap/readme.md deleted file mode 100644 index 133f00cfa0993ab14b0216ade6c29ff6ead7eac3..0000000000000000000000000000000000000000 --- a/src/meacap/readme.md +++ /dev/null @@ -1,48 +0,0 @@ -pip install sentence-transformers - - -(3.2.1) - - -package for the usage of meacap invlm version (based on viecap decoder) - -pip install nltk - -huggingface-cli snapshot-download \ - --repo-id JoeyZoZ/MeaCap \ - --allow-pattern "memory/*" \ - --local-dir /raid/datasets/meacap_files/data/memory - -python -c " -from huggingface_hub import snapshot_download -snapshot_download( - repo_id='JoeyZoZ/MeaCap', - local_dir='/raid/datasets/meacap_files', - local_dir_use_symlinks=False -) -" - -# Preparing Embeddings for T2D Space - -python prepare_embeddings.py --memory_id coco_B16_t2d --use_t2d True - - -# inference example on one image - -python viecap_inference.py --memory_id coco --image_path "/raid/datasets/coco/train2017/000000000064.jpg" - - -memory concepts: ['clock', 'tree', 'sidewalk', 'city'] -the generated caption: clock on sidewalk in city with trees and sidewalk in background. -![http://images.cocodataset.org/train2017/000000000064.jpg](http://images.cocodataset.org/train2017/000000000064.jpg) - - - -python viecap_inference.py --memory_id coco --image_path "/raid/datasets/coco/train2017/000000000071.jpg" - - -memory concepts: ['trains'] -the generated caption: Blue and yellow trains passing each other on the tracks. -![http://images.cocodataset.org/train2017/000000000071.jpg](http://images.cocodataset.org/train2017/000000000071.jpg) - - diff --git a/src/meacap/talk2dino.py b/src/meacap/talk2dino.py deleted file mode 100644 index 87cfc0d5d251c1a6ea46958db05f93b709ea7cda..0000000000000000000000000000000000000000 --- a/src/meacap/talk2dino.py +++ /dev/null @@ -1,97 +0,0 @@ - -import torch -import torch.nn as nn -import yaml - -class ProjectionLayer(nn.Module): - """ - Creates a projection layer on top of the CLIP-text encoder. - The forward method calculate the similarity between the DINO CLS token and the projected CLIP textual CLS token. - """ - def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, dino_embed_dim=1024, clip_embed_dim=512, num_attn_head=16, weight_attn_heads=None, - alignment_strategy='max_score', alpha=0.6, keep_cls=False, keep_end_seq=False): - # mlp_dims list of mlp dimensions - super().__init__() - self.num_attn_head = num_attn_head - - self.linear_layer = nn.Linear(clip_embed_dim, dino_embed_dim) - if hidden_layer: - hidden_layer = 1 if hidden_layer is True else hidden_layer # ensuring compatibility with old code - # self.linear_layer2 = nn.Linear(dino_embed_dim, dino_embed_dim) - self.hidden_layers = nn.ModuleList([nn.Linear(dino_embed_dim, dino_embed_dim) for _ in range(hidden_layer)]) - self.act = act - self.cosine = cosine - - self.weight_attn_heads = weight_attn_heads - if weight_attn_heads == 'static': - self.attn_weights = nn.Parameter(torch.rand(self.num_attn_head)) - elif weight_attn_heads == 'conditioned': - self.weight_layer1 = nn.Linear(dino_embed_dim, dino_embed_dim) - self.weight_layer2 = nn.Linear(dino_embed_dim, self.num_attn_head) - - self.alignment_strategy = alignment_strategy # relevant only if we use disentangled_self_attn - self.keep_cls = keep_cls # relevant only if we use clip_txt_tokens_out - self.keep_end_seq = keep_end_seq # relevant only if we use clip_txt_tokens_out - self.alpha = alpha - - @classmethod - def from_config(cls, config): - if type(config) is str: - # if the configuration is a string, we treat it as a file path - with open(config, 'r') as f: - config = yaml.safe_load(f)['model'] - - # loading the activation function - act = config.get('act', None) - if act == 'tanh': - act = nn.Tanh() - elif act == 'relu': - act = nn.ReLU() - elif act == 'sigmoid': - act = nn.Sigmoid() - elif act is not None: - raise Exception("Unknown activation function") - - model = cls( - act=act, - hidden_layer=config.get('hidden_layer', False), - cosine=config.get('cosine', True), - dino_embed_dim=config.get('dino_embed_dim', 1024), - num_attn_head=config.get('num_attn_head', 16), - clip_embed_dim=config.get('clip_embed_dim', 512), - weight_attn_heads=config.get('weight_attn_heads', None), - alignment_strategy=config.get('alignment_strategy', 'max_score'), - alpha=config.get('alpha', 0.6), - keep_cls=config.get('keep_cls', None), - keep_end_seq=config.get('keep_end_seq', None), - ) - if config.get('starting_checkpoint', None) is not None: - model.load_state_dict(torch.load(config['starting_checkpoint'], 'cpu')) - - return model - - def project_clip_txt(self, textual_embedding): - textual_embedding = textual_embedding.float() - x = self.linear_layer(textual_embedding) - - if hasattr(self, 'hidden_layers'): - for hidden_layer in self.hidden_layers: - if self.act: - x = self.act(x) - x = hidden_layer(x) - - return x - def load_state_dict(self, state_dict, strict=True): - # compatibility with old code - if 'linear_layer2.weight' in state_dict: - state_dict['hidden_layers.0.weight'] = state_dict.pop('linear_layer2.weight') - state_dict['hidden_layers.0.bias'] = state_dict.pop('linear_layer2.bias') - # Call the parent class's load_state_dict with the modified state_dict - super(ProjectionLayer, self).load_state_dict(state_dict, strict) - - def set_alignment_strategy(self, alignment_strategy): - self.alignment_strategy = alignment_strategy - return - - def __len__(self): - return sum(p.numel() for p in self.parameters()) \ No newline at end of file diff --git a/src/meacap/utils/detect_utils.py b/src/meacap/utils/detect_utils.py deleted file mode 100644 index 3039b5428f34e2cbf505c8b535de0b5b295463bd..0000000000000000000000000000000000000000 --- a/src/meacap/utils/detect_utils.py +++ /dev/null @@ -1,69 +0,0 @@ -import os -import math -import torch -import json -from collections import OrderedDict -import traceback - -from .parse_tool import parse, get_entitys, get_graph_dict, merge_graph_dict - - -def add_prompt(word_list: list = None, - prompt: str = 'Image of '): - sentence_list = [] - for word in word_list: - sentence = prompt + word + '.' - sentence_list.append(sentence) - return sentence_list - - -def retrieve_concepts(parser_model=None, parser_tokenizer=None, wte_model=None, select_memory_captions=None,image_embeds=None, - device=None, logger=None, args=None,verbose=False): - ''' - memory-based key concepts extracting - ''' - torch.set_printoptions(sci_mode=False) - - scene_graphs = parse(parser_model, parser_tokenizer, - text_input=select_memory_captions, - device=device) - type_dict = {} - count_dict = OrderedDict() - attribute_dict = {} - entities_, count_dict_, entire_graph_dict = get_graph_dict(wte_model, scene_graphs, type_dict, attribute_dict) - concepts, count_dict, filtered_graph_dict = merge_graph_dict(wte_model, entities_, count_dict_, entire_graph_dict, select_memory_captions) - # concepts, count_dict = merge_sim_entities(args.wte_model, entities_, count_dict_, attribute_dict) - if logger is not None: - logger.logger.info(f"********************************************") - logger.logger.info(f"Memory captions: {select_memory_captions}") - logger.logger.info(f"Memory scene graphs: {scene_graphs}") - logger.logger.info(f"Memory concepts: {concepts}") - logger.logger.info(f"********************************************") - - return concepts[:4] - -def retrieve_concepts_from_image(parser_model=None, parser_tokenizer=None, wte_model=None, select_memory_captions=None,image_path=None, - device=None, logger=None, args=None): - ''' - memory-based key concepts extracting - ''' - - - torch.set_printoptions(sci_mode=False) - logger.logger.info(f"********************************************") - logger.logger.info(f"Memory captions: {select_memory_captions}") - scene_graphs = parse(parser_model, parser_tokenizer, - text_input=select_memory_captions, - device=device) - logger.logger.info(f"Memory scene graphs: {scene_graphs}") - type_dict = {} - count_dict = OrderedDict() - attribute_dict = {} - entities_, count_dict_, entire_graph_dict = get_graph_dict(wte_model, scene_graphs, type_dict, attribute_dict) - concepts, count_dict, filtered_graph_dict = merge_graph_dict(wte_model, entities_, count_dict_, entire_graph_dict, select_memory_captions) - # concepts, count_dict = merge_sim_entities(args.wte_model, entities_, count_dict_, attribute_dict) - - logger.logger.info(f"Memory concepts: {concepts}") - logger.logger.info(f"********************************************") - - return concepts[:4] \ No newline at end of file diff --git a/src/meacap/utils/log.py b/src/meacap/utils/log.py deleted file mode 100644 index eda55c7c46972b37e186f3dc549f3f6fd9560817..0000000000000000000000000000000000000000 --- a/src/meacap/utils/log.py +++ /dev/null @@ -1,33 +0,0 @@ -# -*- coding: utf-8 -*- -# @Time : 2019/12/30 8:06 PM -# @Author : He Xingwei - -import logging - -class Logger(object): - level_relations = { - 'debug':logging.DEBUG, - 'info':logging.INFO, - 'warning':logging.WARNING, - 'error':logging.ERROR, - 'crit':logging.CRITICAL - }#ๆ—ฅๅฟ—็บงๅˆซๅ…ณ็ณปๆ˜ ๅฐ„ - - def __init__(self,filename,level='info',fmt='%(asctime)s - %(levelname)s: %(message)s'): - self.logger = logging.getLogger(filename) - format_str = logging.Formatter(fmt)#่ฎพ็ฝฎๆ—ฅๅฟ—ๆ ผๅผ - self.logger.setLevel(self.level_relations.get(level))#่ฎพ็ฝฎๆ—ฅๅฟ—็บงๅˆซ - sh = logging.StreamHandler()#ๅพ€ๅฑๅน•ไธŠ่พ“ๅ‡บ - # sh.setFormatter(format_str) #่ฎพ็ฝฎๅฑๅน•ไธŠๆ˜พ็คบ็š„ๆ ผๅผ - fh = logging.FileHandler(filename=filename)#ๅพ€ๆ–‡ไปถ้‡Œๅ†™ๅ…ฅ - fh.setFormatter(format_str)#่ฎพ็ฝฎๆ–‡ไปถ้‡Œๅ†™ๅ…ฅ็š„ๆ ผๅผ - self.logger.addHandler(sh) #ๆŠŠๅฏน่ฑกๅŠ ๅˆฐlogger้‡Œ - self.logger.addHandler(fh) - -if __name__ == '__main__': - log = Logger('all.log',level='debug') - log.logger.debug('debug') - log.logger.info('info') - log.logger.warning('่ญฆๅ‘Š') - log.logger.error('ๆŠฅ้”™') - log.logger.critical('ไธฅ้‡') \ No newline at end of file diff --git a/src/meacap/utils/parse_tool.py b/src/meacap/utils/parse_tool.py deleted file mode 100644 index c702d5443fadbb763f428f0fd4f55284f7cbe42e..0000000000000000000000000000000000000000 --- a/src/meacap/utils/parse_tool.py +++ /dev/null @@ -1,613 +0,0 @@ -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel -import torch -import nltk -from collections import OrderedDict -import numpy as np - -NUMBER_DICT = {'2':"two","3":"three","4":"four","5":"five",'6':"six",'7':"seven","8":"eight","9":"nine"} - -def merge_sim_node(entire_graph_dict, x, y): - entire_graph_dict[x]["Relation"].update(entire_graph_dict[y]["Relation"]) - entire_graph_dict[x]["count"] += entire_graph_dict[y]["count"] - for attr_key in list(entire_graph_dict[y]["Attribute"].keys()): - if attr_key not in entire_graph_dict[x]["Attribute"]: - entire_graph_dict[x]["Attribute"][attr_key] = entire_graph_dict[y]["Attribute"][attr_key] - else: - entire_graph_dict[x]["Attribute"][attr_key] += entire_graph_dict[y]["Attribute"][attr_key] - -def filter_relation(graph_dict,sim_entity_dict ,remove_map, sentences, attribute_thresh=3): - res_dict = {} - nodes = list(graph_dict.keys()) - for node in nodes: - pos_list = [] - for sentence in sentences: - pos = sentence.find(node)/len(sentence) - if pos > 0: - pos_list.append(pos) - final_pos = np.mean(pos_list) if pos_list else 1 - if node not in res_dict: - res_dict[node] = {} - res_dict[node]["rating"] = 0 - res_dict[node]["relative_pos"] = final_pos - res_dict[node]["Attribute"] = graph_dict[node]["Attribute"] - res_dict[node]["count"] = graph_dict[node]["count"] - res_dict[node]["Relation"] = {} - for obj in graph_dict[node]["Relation"]: - if obj in nodes: #copy - if obj in res_dict[node]["Relation"]: - res_dict[node]["Relation"][obj] += graph_dict[node]["Relation"][obj] - else: - res_dict[node]["Relation"][obj] = graph_dict[node]["Relation"][obj] - if obj not in res_dict: - res_dict[obj] = {} - res_dict[obj]["rating"] = 1 - else: - res_dict[obj]["rating"] += 1 - res_dict[node]["rating"] += 2 - elif obj in list(remove_map.keys()) and remove_map[obj] in nodes: # merge - if remove_map[obj] in res_dict[node]["Relation"]: - res_dict[node]["Relation"][remove_map[obj]] += graph_dict[node]["Relation"][obj] - else: - res_dict[node]["Relation"][remove_map[obj]] = graph_dict[node]["Relation"][obj] - if remove_map[obj] not in res_dict: - res_dict[remove_map[obj]] = {} - res_dict[remove_map[obj]]["rating"] = 1 - else: - res_dict[remove_map[obj]]["rating"] += 1 - res_dict[node]["rating"] += 2 - else: # pass - pass - # res_dict[node]["rating"] += len(res_dict[node]["Relation"]) * 5 - - # res_dict_sorted = OrderedDict(sorted(res_dict.items(), key=lambda item: item[1]["rating"], reverse=True)) - res_dict_sorted = OrderedDict(sorted(res_dict.items(), key=lambda item: item[1]["relative_pos"])) - entities = [] - for entity in res_dict_sorted: - flag = 0 - for attribute in res_dict_sorted[entity]["Attribute"]: - if res_dict_sorted[entity]["Attribute"][attribute] >= attribute_thresh: - entities.append(attribute +' '+ entity) - flag = 1 - break - if flag==0: - entities.append(entity) - # entities = list(res_dict_sorted.keys()) - - return res_dict_sorted, entities - - - - -# def merge_sim_entities(model, entities, count_dict, attribute_dict): -# entity_embeddings = model.encode(entities, convert_to_tensor=True, normalize_embeddings=True) -# entity_correlation = torch.mm(entity_embeddings, entity_embeddings.T) -# for idx in range(len(entity_correlation)): -# entity_correlation[idx, idx] = 0 -# sim_index = torch.where(entity_correlation > 0.6) -# sim_entity_dict = {} -# -# remove_list = [] -# for ids, (x, y) in enumerate(zip(sim_index[0], sim_index[1])): -# if entities[x] not in sim_entity_dict: -# sim_entity_dict[entities[x]] = [entities[y]] -# else: -# sim_entity_dict[entities[x]].append(entities[y]) -# if entities[y] not in sim_entity_dict: -# remove_list.append(entities[y]) -# count_dict[entities[x]] = count_dict[entities[x]] + count_dict[entities[y]] -# if entities[y] in attribute_dict: -# if entities[x] in attribute_dict: -# attribute_dict[entities[x]] = attribute_dict[entities[x]] + attribute_dict[entities[y]] -# else: -# attribute_dict[entities[x]] = attribute_dict[entities[y]] -# new_count_dict = OrderedDict() -# -# for key in list(count_dict.keys()): -# if key in remove_list or count_dict[key] <= 2: -# continue -# new_count_dict[key] = count_dict[key] -# new_count_dict = OrderedDict(sorted(new_count_dict.items(), key=lambda item: item[1], reverse=True)) -# entities = list(new_count_dict.keys()) -# -# return entities, new_count_dict - -def merge_graph_dict(model, entities, count_dict, entire_graph_dict, sentences): - # compute similarity - entity_embeddings = model.encode(entities, convert_to_tensor=True, normalize_embeddings=True) - entity_correlation = torch.mm(entity_embeddings, entity_embeddings.T) - for idx in range(len(entity_correlation)): - entity_correlation[idx, idx] = 0 - sim_index = torch.where(entity_correlation > 0.55) # TODO:xieyan - sim_entity_dict = {} - remove_entity_dict = {} - remove_list = [] - for ids, (x, y) in enumerate(zip(sim_index[0], sim_index[1])): - if entities[x] in remove_list: - if entities[x] not in remove_entity_dict: - remove_entity_dict[entities[x]] = [entities[y]] - else: - remove_entity_dict[entities[x]].append(entities[y]) - else: - if entities[x] not in sim_entity_dict: - sim_entity_dict[entities[x]] = [entities[y]] - else: - sim_entity_dict[entities[x]].append(entities[y]) - count_dict[entities[x]] = count_dict[entities[x]] + count_dict[entities[y]] - if entities[y] not in sim_entity_dict: - remove_list.append(entities[y]) - - # if entities[y] in attribute_dict: - # if entities[x] in attribute_dict: - # attribute_dict[entities[x]] = attribute_dict[entities[x]] + attribute_dict[entities[y]] - # else: - # attribute_dict[entities[x]] = attribute_dict[entities[y]] - merge_sim_node(entire_graph_dict, entities[x], entities[y]) - new_count_dict = OrderedDict() - filterd_graph_dict = {} - # update remove_list - removed_map = {} - remove_list = [] - for ent in sim_entity_dict: - remove_list += sim_entity_dict[ent] - for remove_wd in remove_list: - try: - removed_map[remove_wd] = [wd for wd in remove_entity_dict[remove_wd] if wd not in remove_list][0] - except: - print("remove wrong!") - - for key in list(count_dict.keys()): - if key in remove_list or count_dict[key] <= 2: # TODO: xieyan - continue - new_count_dict[key] = count_dict[key] - filterd_graph_dict[key] = entire_graph_dict[key] - if filterd_graph_dict: # >1 - filterd_graph_dict_final, entities = filter_relation(filterd_graph_dict, sim_entity_dict, removed_map, sentences) - else: - # get the first one - filterd_graph_dict_final = {} - entities = [] - # key = next(iter(entire_graph_dict)) - # filterd_graph_dict_final[key] = entire_graph_dict[key] - # entities = [key] - - new_count_dict = OrderedDict(sorted(new_count_dict.items(), key=lambda item: item[1], reverse=True)) - # entities = list(new_count_dict.keys()) - - return entities, new_count_dict, filterd_graph_dict_final - -def add_node_graph(scene_graph, subject, new_edge): - # new_edge: (object, relation) or (attribute) - if subject not in scene_graph: - scene_graph[subject] = { - "Relation":{}, - "Attribute":{}, - "count":1, - } - if len(new_edge)==2: # add relation - scene_graph[subject]["Relation"][new_edge[0]] = [new_edge[1]] - elif len(new_edge)==1: # add attribute - scene_graph[subject]["Attribute"][new_edge[0]] = 1 - elif len(new_edge)==0: # only subject - pass - else: - raise KeyError(f"{new_edge} is wrong") - - else: - if len(new_edge)==2: # add relation - if new_edge[0] not in scene_graph[subject]["Relation"]: - scene_graph[subject]["Relation"][new_edge[0]] = [new_edge[1]] - else: - - scene_graph[subject]["Relation"][new_edge[0]] += [new_edge[1]] - elif len(new_edge) == 1: # add attribute - scene_graph[subject]["Attribute"][new_edge[0]] = 1 - elif len(new_edge) == 0: # only subject - pass - else: - raise KeyError(f"{new_edge} is wrong") - return scene_graph - -def merge_seperate_graph(scene_graph, new_graph): - for key in list(new_graph.keys()): - if key in scene_graph: - scene_graph[key]["Relation"].update(new_graph[key]["Relation"]) - scene_graph[key]["count"]+= new_graph[key]["count"] - for attr_key in list(new_graph[key]["Attribute"].keys()): - if attr_key not in scene_graph[key]["Attribute"]: - scene_graph[key]["Attribute"][attr_key] = new_graph[key]["Attribute"][attr_key] - else: - scene_graph[key]["Attribute"][attr_key] += new_graph[key]["Attribute"][attr_key] - else: - scene_graph[key] = new_graph[key] - return scene_graph - - - -def format_scene_graph(graph_str): - return " ".join([item for item in graph_str.replace('(', ' ( ').replace(')', ' ) ').replace(',', ' , ').split() if item != '']) - - -def get_seg_list(graphs): - if isinstance(graphs, str): - seg_list = [scene_seg.replace('(', '').replace(')', '').strip() for scene_seg in format_scene_graph(graphs).split(') , (')] - elif isinstance(graphs, list): - seg_list = [] - for graph in graphs: - seg_list.extend([scene_seg.replace('(', '').replace(')', '').strip() for scene_seg in format_scene_graph(graph).split(') , (')]) - else: - raise ValueError('input should be either a string or a list of strings') - return list(set(seg_list)) - -def get_seg_list_seperate(graphs): - if isinstance(graphs, str): - seg_list = [scene_seg.replace('(', '').replace(')', '').strip() for scene_seg in format_scene_graph(graphs).split(') , (')] - elif isinstance(graphs, list): - seg_list = [] - for graph in graphs: - cur_list = [] - cur_list.extend([scene_seg.replace('(', '').replace(')', '').strip() for scene_seg in format_scene_graph(graph).split(') , (')]) - seg_list.append(cur_list) - else: - raise ValueError('input should be either a string or a list of strings') - return list(seg_list) - - -def parse(parser, parser_tokenizer, text_input, - max_input_length=128, max_output_length=128, beam_size=1, device="cuda:0"): - ''' - :param text_input: one or a list of textual image descriptions - :return: corresponding scene graphs of the input descriptions - ''' - - if isinstance(text_input, str): - text_input = [text_input] - - # breakpoint() - text_input = ['Generate Scene Graph: ' + text for text in text_input] - with torch.no_grad(): - encoded_text = parser_tokenizer( - text_input, - max_length=max_input_length, - truncation=True, - padding=True, - return_tensors='pt' - ) - text_tokens = encoded_text['input_ids'].to(device) - text_mask = encoded_text['attention_mask'].to(device) - - generated_ids = parser.generate( - text_tokens, - attention_mask=text_mask, - use_cache=True, - decoder_start_token_id=parser_tokenizer.pad_token_id, - num_beams=beam_size, - max_length=max_output_length, - early_stopping=True - ) - - # output to text - output_text = parser_tokenizer.batch_decode(generated_ids, skip_special_tokens=True, - clean_up_tokenization_spaces=True) - output_text = [format_scene_graph(text.replace('Generate Scene Graph:', '').strip()) for text in output_text] - return output_text - - -def get_graph_phrases(graph_str_list, type_dict): - seg_list = get_seg_list(graph_str_list) - #breakpoint() - new_pairs = [] - for seg in seg_list: - new_seg = [item.strip() for item in seg.split(',')] - try: - if len(new_seg) == 1 and len(seg_list) == 1: - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - continue - - if len(new_seg) == 2: - new_pairs.append(new_seg[1] + " " + new_seg[0]) - type_dict[new_seg[1] + " " + new_seg[0]] = "attribute" - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - continue - elif len(new_seg) == 3: - sentence = new_seg[0] + " " + new_seg[1] + " " + new_seg[2] - sentence_word = nltk.word_tokenize(sentence) - pos_type = nltk.pos_tag(sentence_word) - if new_seg[1] == 'is' and pos_type[-1][1] == 'JJ': - new_pairs.append(new_seg[2] + " " + new_seg[0]) - type_dict[new_seg[2] + " " + new_seg[0]] = "attribute" - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - else: - # new_pairs.append(new_seg[0] + " " + new_seg[1] + " " + new_seg[2]) - type_dict[new_seg[0] + " " + new_seg[1] + " " + new_seg[2]] = "fact" - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[1] == 'is': - continue - else: - new_pairs.append(new_seg[2]) - type_dict[new_seg[2]] = "object" - elif len(new_seg) > 3: - # new_pairs.append(new_seg[0] + " ".join(new_seg[1:-1]) + new_seg[-1]) - type_dict[new_seg[0] + " ".join(new_seg[1:-1]) + new_seg[-1]] = "fact" - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - new_pairs.append(new_seg[-1]) - type_dict[new_seg[-1]] = "object" - except IndexError: - print(seg_list) - continue - - return list(set(new_pairs)) - -def get_graph_dict(model, graph_str_list,type_dict, attribute_dict): - seg_lists = get_seg_list_seperate(graph_str_list) - count_dict = OrderedDict() - total_entity_lists = [] - total_graph_dicts = [] - # process graphs - for seg_list in seg_lists: - #breakpoint() - entity_list = [] - cur_sg = dict() - for seg in seg_list: - new_seg = [item.strip() for item in seg.split(',')] - try: - if len(new_seg) == 1 and len(seg_list) == 1: - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - add_node_graph(cur_sg, new_seg[0], []) - continue - - if len(new_seg) == 2: - # entity_list.append(new_seg[1] + " " + new_seg[0]) - type_dict[new_seg[1] + " " + new_seg[0]] = "attribute" - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in attribute_dict: - attribute_dict[new_seg[0]] = [new_seg[1]] - else: - attribute_dict[new_seg[0]].append(new_seg[1]) - add_node_graph(cur_sg, new_seg[0], [new_seg[1]]) - continue - elif len(new_seg) == 3: - if new_seg[2] in list(NUMBER_DICT.keys()): - new_seg[2] = NUMBER_DICT[new_seg[2]] - sentence = new_seg[0] + " " + new_seg[1] + " " + new_seg[2] - # sentence_word = nltk.word_tokenize(sentence) - # pos_type = nltk.pos_tag(sentence_word) - if new_seg[1] == 'is': - # entity_list.append(new_seg[2] + " " + new_seg[0]) - type_dict[new_seg[2] + " " + new_seg[0]] = "attribute" - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in attribute_dict: - attribute_dict[new_seg[0]] = [new_seg[2]] - else: - attribute_dict[new_seg[0]].append(new_seg[2]) - add_node_graph(cur_sg, new_seg[0], [new_seg[2]]) - else: - # entity_list.append(new_seg[0] + " " + new_seg[1] + " " + new_seg[2]) - type_dict[new_seg[0] + " " + new_seg[1] + " " + new_seg[2]] = "fact" - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[1] == 'is': - continue - else: - entity_list.append(new_seg[2]) - type_dict[new_seg[2]] = "object" - add_node_graph(cur_sg, new_seg[0], [new_seg[2],new_seg[1]]) - add_node_graph(cur_sg, new_seg[2], []) - elif len(new_seg) > 3: - # entity_list.append(new_seg[0] + " ".join(new_seg[1:-1]) + new_seg[-1]) - type_dict[new_seg[0] + " ".join(new_seg[1:-1]) + new_seg[-1]] = "fact" - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - entity_list.append(new_seg[-1]) - type_dict[new_seg[-1]] = "object" - add_node_graph(cur_sg, new_seg[0], [new_seg[-1], new_seg[1:-1]]) - add_node_graph(cur_sg, new_seg[-1], []) - except IndexError: - print(seg_list) - continue - entity_list = list(set(entity_list)) - for entity in entity_list: - if entity not in count_dict: - count_dict[entity] = 1 - else: - count_dict[entity] += 1 - total_entity_lists.append(entity_list) - total_graph_dicts.append(cur_sg) - sorted_count_dict = OrderedDict(sorted(count_dict.items(), key=lambda item: item[1], reverse=True)) - entitys = list(sorted_count_dict.keys()) - entire_graph_dict = {} - for graph_dict in total_graph_dicts: - merge_seperate_graph(entire_graph_dict, graph_dict) - - - return entitys, sorted_count_dict, entire_graph_dict - - -def get_entitys(graph_str_list,type_dict, attribute_dict): - seg_lists = get_seg_list_seperate(graph_str_list) - count_dict = OrderedDict() - total_entity_lists = [] - for seg_list in seg_lists: - #breakpoint() - entity_list = [] - for seg in seg_list: - new_seg = [item.strip() for item in seg.split(',')] - try: - if len(new_seg) == 1 and len(seg_list) == 1: - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - continue - - if len(new_seg) == 2: - # entity_list.append(new_seg[1] + " " + new_seg[0]) - type_dict[new_seg[1] + " " + new_seg[0]] = "attribute" - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in attribute_dict: - attribute_dict[new_seg[0]] = [new_seg[1]] - else: - attribute_dict[new_seg[0]].append(new_seg[1]) - continue - elif len(new_seg) == 3: - if new_seg[2] in list(NUMBER_DICT.keys()): - new_seg[2] = NUMBER_DICT[new_seg[2]] - sentence = new_seg[0] + " " + new_seg[1] + " " + new_seg[2] - # sentence_word = nltk.word_tokenize(sentence) - # pos_type = nltk.pos_tag(sentence_word) - if new_seg[1] == 'is': - # entity_list.append(new_seg[2] + " " + new_seg[0]) - type_dict[new_seg[2] + " " + new_seg[0]] = "attribute" - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in attribute_dict: - attribute_dict[new_seg[0]] = [new_seg[2]] - else: - attribute_dict[new_seg[0]].append(new_seg[2]) - else: - # entity_list.append(new_seg[0] + " " + new_seg[1] + " " + new_seg[2]) - type_dict[new_seg[0] + " " + new_seg[1] + " " + new_seg[2]] = "fact" - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[1] == 'is': - continue - else: - entity_list.append(new_seg[2]) - type_dict[new_seg[2]] = "object" - elif len(new_seg) > 3: - # entity_list.append(new_seg[0] + " ".join(new_seg[1:-1]) + new_seg[-1]) - type_dict[new_seg[0] + " ".join(new_seg[1:-1]) + new_seg[-1]] = "fact" - entity_list.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - entity_list.append(new_seg[-1]) - type_dict[new_seg[-1]] = "object" - except IndexError: - print(seg_list) - continue - entity_list = list(set(entity_list)) - for entity in entity_list: - if entity not in count_dict: - count_dict[entity] = 1 - else: - count_dict[entity] += 1 - total_entity_lists.append(entity_list) - sorted_count_dict = OrderedDict(sorted(count_dict.items(), key=lambda item: item[1], reverse=True)) - - entitys = list(sorted_count_dict.keys()) - - return entitys, sorted_count_dict - -def get_graph_phrases_new(graph_str_list, type_dict, count_dict): - seg_lists = get_seg_list_seperate(graph_str_list) - - - total_pairs = [] - for seg_list in seg_lists: - #breakpoint() - new_pairs = [] - for seg in seg_list: - new_seg = [item.strip() for item in seg.split(',')] - try: - if len(new_seg) == 1 and len(seg_list) == 1: - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in count_dict: - count_dict[new_seg[0]] = 1 - else: - count_dict[new_seg[0]] += 1 - continue - - if len(new_seg) == 2: - new_pairs.append(new_seg[1] + " " + new_seg[0]) - type_dict[new_seg[1] + " " + new_seg[0]] = "attribute" - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in count_dict: - count_dict[new_seg[0]] = 1 - else: - count_dict[new_seg[0]] += 1 - continue - elif len(new_seg) == 3: - sentence = new_seg[0] + " " + new_seg[1] + " " + new_seg[2] - sentence_word = nltk.word_tokenize(sentence) - pos_type = nltk.pos_tag(sentence_word) - if new_seg[1] == 'is' and pos_type[-1][1] == 'JJ': - new_pairs.append(new_seg[2] + " " + new_seg[0]) - type_dict[new_seg[2] + " " + new_seg[0]] = "attribute" - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in count_dict: - count_dict[new_seg[0]] = 1 - else: - count_dict[new_seg[0]] += 1 - else: - # new_pairs.append(new_seg[0] + " " + new_seg[1] + " " + new_seg[2]) - type_dict[new_seg[0] + " " + new_seg[1] + " " + new_seg[2]] = "fact" - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in count_dict: - count_dict[new_seg[0]] = 1 - else: - count_dict[new_seg[0]] += 1 - if new_seg[1] == 'is': - continue - else: - new_pairs.append(new_seg[2]) - type_dict[new_seg[2]] = "object" - if new_seg[2] not in count_dict: - count_dict[new_seg[2]] = 1 - else: - count_dict[new_seg[2]] += 1 - elif len(new_seg) > 3: - # new_pairs.append(new_seg[0] + " ".join(new_seg[1:-1]) + new_seg[-1]) - type_dict[new_seg[0] + " ".join(new_seg[1:-1]) + new_seg[-1]] = "fact" - new_pairs.append(new_seg[0]) - type_dict[new_seg[0]] = "object" - if new_seg[0] not in count_dict: - count_dict[new_seg[0]] = 1 - else: - count_dict[new_seg[0]] += 1 - new_pairs.append(new_seg[-1]) - type_dict[new_seg[-1]] = "object" - if new_seg[0] not in count_dict: - count_dict[new_seg[-1]] = 1 - else: - count_dict[new_seg[-1]] += 1 - except IndexError: - print(seg_list) - continue - total_pairs.append(new_pairs) - - all_pairs = [pair for pairs in total_pairs for pair in pairs] - - return list(set(all_pairs)) - - -if __name__ == "__main__": - device = "cuda" - parser_checkpoint = "/media/xieyan/Hard Disk2/pretrain_model/flan-t5-base-VG-factual-sg" - parser_tokenizer = AutoTokenizer.from_pretrained(parser_checkpoint) - parser = AutoModelForSeq2SeqLM.from_pretrained(parser_checkpoint) - - parser.eval() - parser.to(device) - scene_graphs = parse(parser, parser_tokenizer, - ["A young girl inhales with the intent of blowing out a candle.", - "A young girl is preparing to blow out her candle.", - "A kid is to blow out the single candle in a bowl of birthday goodness.", - "Girl blowing out the candle on an ice-cream", - "A little girl is getting ready to blow out a candle on a small dessert."], - device=device) - - # scene_graphs = parse(parser, parser_tokenizer, - # ["People talk to each other."], - # device=device) - type_dict = {} - concepts = get_graph_phrases(scene_graphs, type_dict) - print(concepts) \ No newline at end of file diff --git a/src/meacap/utils/some_utils.py b/src/meacap/utils/some_utils.py deleted file mode 100644 index fcf67004c4989eef607f350f0fc2f28dc38d7ab6..0000000000000000000000000000000000000000 --- a/src/meacap/utils/some_utils.py +++ /dev/null @@ -1,66 +0,0 @@ -import torch -import numpy as np -import random -import os -from .log import Logger - -def set_seed(args): - torch.manual_seed(args.seed) - torch.cuda.manual_seed_all(args.seed) - np.random.seed(args.seed) - random.seed(args.seed) - torch.backends.cudnn.deterministic = True - - -def update_args_logger(prefix, args): - if args.random_init == 1: - prefix = 'random_initialization_'+prefix - prefix = f"cbart-{args.bart}_{prefix}" - - if args.do_sample: - if args.top_k > 0: - prefix += f'_sample_top_k_{args.top_k}' - else: - prefix += f'_sample_top_p_{args.top_p}' - if args.decoder_chain > 1: - prefix += f'_decoder_chain{args.decoder_chain}' - if args.threshold > 0: - prefix += f'_threshold{args.threshold}' - - prefix += f'_{args.num_keywords}keywords' - - log_path = f'outputs/generate_keywords' - if not os.path.exists(log_path): - os.makedirs(log_path) - - output_path = f'outputs' - if not os.path.exists(output_path): - os.makedirs(output_path) - output_file = '{}/{}.txt'.format(output_path, prefix) - args.output_file = output_file - args.log_path = log_path - - if args.conzic_sample: - log_file = f'outputs/{args.dataset}_alpha{args.alpha}_beta{args.beta}.log' - else: - log_file = '{}/{}.log'.format(log_path, prefix) - logger = Logger(log_file) - logger.logger.info(f'The log file is {log_file}') - logger.logger.info(f'output file is {args.output_file}') - logger.logger.info(args) - - return args, logger - -PROMPT_ENSEMBLING = [ - 'Attention! There is', - 'Attention! There are', - 'There is', - 'There are', - 'A picture showing', - 'The picture shows', - 'A photo of', - 'An image of', - 'See! There is', - 'See! There are', - 'The image depicts', - 'The image depicts that'] \ No newline at end of file diff --git a/src/meacap/viecap_inference.py b/src/meacap/viecap_inference.py deleted file mode 100644 index ae83dfd1a846ee8c3225acdbe838ece309c7de0d..0000000000000000000000000000000000000000 --- a/src/meacap/viecap_inference.py +++ /dev/null @@ -1,157 +0,0 @@ -import clip -import torch -import argparse -from PIL import Image -from transformers import AutoTokenizer -from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel - -import json -import copy -import sys, os - -if os.path.abspath(os.path.join(os.path.dirname(__file__), '..')) not in sys.path: - sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) - -from viecap.ClipCap import ClipCaptionModel -from viecap.utils import compose_discrete_prompts -from viecap.search import greedy_search, beam_search, opt_search - -from sentence_transformers import SentenceTransformer -from utils.detect_utils import retrieve_concepts -from models.clip_utils import CLIP - -@torch.no_grad() -def main(args) -> None: - # initializing - device = args.device - clip_name = args.clip_model.replace('/', '') - clip_hidden_size = 640 if 'RN' in args.clip_model else 512 - - # loading model - tokenizer = AutoTokenizer.from_pretrained(args.language_model) - model = ClipCaptionModel(args.continuous_prompt_length, args.clip_project_length, clip_hidden_size, gpt_type = args.language_model) - model.load_state_dict(torch.load(args.weight_path, map_location = device), strict = False) - model.to(device) - encoder, preprocess = clip.load(args.clip_model, device = device) - - vl_model = CLIP(args.vl_model) - vl_model = vl_model.to(device) - print('Load CLIP from the checkpoint {}.'.format(args.clip_model)) - - sim_func = torch.nn.CosineSimilarity(dim=1, eps=1e-6) - wte_model = SentenceTransformer(args.wte_model_path) - print('Load sentenceBERT from the checkpoint {}.'.format(args.wte_model_path)) - - # parser model for memory concepts extracting - parser_tokenizer = AutoTokenizer.from_pretrained(args.parser_checkpoint) - parser_model = AutoModelForSeq2SeqLM.from_pretrained(args.parser_checkpoint) - parser_model.eval() - parser_model.to(device) - print('Load Textual Scene Graph parser from the checkpoint {}.'.format(args.parser_checkpoint)) - - # prepare memory bank - memory_id = args.memory_id - memory_base_path = args.memory_base_path - memory_caption_path = os.path.join(memory_base_path, f"memory/{memory_id}", "memory_captions.json") - memory_clip_embedding_file = os.path.join(memory_base_path, f"memory/{memory_id}", "memory_clip_embeddings.pt") - memory_wte_embedding_file = os.path.join(memory_base_path, f"memory/{memory_id}", "memory_wte_embeddings.pt") - memory_clip_embeddings = torch.load(memory_clip_embedding_file) - memory_wte_embeddings = torch.load(memory_wte_embedding_file) - with open(memory_caption_path, 'r') as f: - memory_captions = json.load(f) - - # huge memeory bank cannot load on GPU - if memory_id == 'cc3m' or memory_id == 'ss1m': - retrieve_on_CPU = True - print('CC3M/SS1M Memory is too big to compute on RTX 3090, Moving to CPU...') - vl_model_retrieve = copy.deepcopy(vl_model).to(cpu_device) - memory_clip_embeddings = memory_clip_embeddings.to(cpu_device) - else: - vl_model_retrieve = vl_model - retrieve_on_CPU = False - - image = preprocess(Image.open(args.image_path)).unsqueeze(dim = 0).to(device) - image_features = encoder.encode_image(image).float() - image_features /= image_features.norm(2, dim = -1, keepdim = True) - - continuous_embeddings = model.mapping_network(image_features).view(-1, args.continuous_prompt_length, model.gpt_hidden_size) - if args.using_hard_prompt: - batch_image_embeds = vl_model.compute_image_representation_from_image_path(args.image_path) - - if retrieve_on_CPU != True: - clip_score, clip_ref = vl_model_retrieve.compute_image_text_similarity_via_embeddings( - batch_image_embeds, memory_clip_embeddings) - else: - batch_image_embeds_cpu = batch_image_embeds.to(cpu_device) - clip_score_cpu, clip_ref_cpu = vl_model_retrieve.compute_image_text_similarity_via_embeddings( - batch_image_embeds_cpu, - memory_clip_embeddings) - clip_score = clip_score_cpu.to(device) - clip_ref = clip_ref_cpu.to(device) - select_memory_ids = clip_score.topk(args.memory_caption_num, dim=-1)[1].squeeze(0) - select_memory_captions = [memory_captions[id] for id in select_memory_ids] - select_memory_wte_embeddings = memory_wte_embeddings[select_memory_ids] - detected_objects = retrieve_concepts(parser_model=parser_model, parser_tokenizer=parser_tokenizer, - wte_model=wte_model, - select_memory_captions=select_memory_captions, - image_embeds=batch_image_embeds, - device=device) - - print("memory concepts:", detected_objects) - discrete_tokens = compose_discrete_prompts(tokenizer, detected_objects).unsqueeze(dim = 0).to(args.device) - - discrete_embeddings = model.word_embed(discrete_tokens) - if args.only_hard_prompt: - embeddings = discrete_embeddings - elif args.soft_prompt_first: - embeddings = torch.cat((continuous_embeddings, discrete_embeddings), dim = 1) - else: - embeddings = torch.cat((discrete_embeddings, continuous_embeddings), dim = 1) - else: - embeddings = continuous_embeddings - - if 'gpt' in args.language_model: - if not args.using_greedy_search: - sentence = beam_search(embeddings = embeddings, tokenizer = tokenizer, beam_width = args.beam_width, model = model.gpt) # List[str] - sentence = sentence[0] # selected top 1 - else: - sentence = greedy_search(embeddings = embeddings, tokenizer = tokenizer, model = model.gpt) - else: - sentence = opt_search(prompts=args.text_prompt, embeddings = embeddings, tokenizer = tokenizer, beam_width = args.beam_width, model = model.gpt) - sentence=sentence[0] - - print(f'the generated caption: {sentence}') - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--device', default = 'cuda:0') - parser.add_argument('--clip_model', default = 'ViT-B/32') - parser.add_argument('--language_model', default = 'openai-community/gpt2') - parser.add_argument('--vl_model', type=str, default=r'openai/clip-vit-base-patch32') - parser.add_argument("--parser_checkpoint", type=str, default=r'lizhuang144/flan-t5-base-VG-factual-sg') - parser.add_argument("--wte_model_path", type=str, default=r'sentence-transformers/all-MiniLM-L6-v2') - parser.add_argument('--continuous_prompt_length', type = int, default = 10) - parser.add_argument('--clip_project_length', type = int, default = 10) - parser.add_argument('--temperature', type = float, default = 0.01) - parser.add_argument('--top_k', type = int, default = 3) - parser.add_argument('--threshold', type = float, default = 0.2) - parser.add_argument('--disable_all_entities', action = 'store_true', default = False, help = 'whether to use entities with a single word only') - parser.add_argument('--name_of_entities_text', default = 'coco_entities', choices = ('visual_genome_entities', 'coco_entities', 'open_image_entities', 'vinvl_vg_entities', 'vinvl_vgoi_entities')) - parser.add_argument('--prompt_ensemble', action = 'store_true', default = False) - parser.add_argument('--weight_path', default = '/raid/datasets/viecap_files/checkpoints/train_coco/coco_prefix-0014.pt') - parser.add_argument('--image_path', default = 'image_example/COCO_val2014_000000027440.jpg') - parser.add_argument('--using_hard_prompt', action = 'store_true', default = True) - parser.add_argument('--soft_prompt_first', action = 'store_true', default = False) - parser.add_argument('--only_hard_prompt', action = 'store_true', default = False) - parser.add_argument('--using_greedy_search', action = 'store_true', default = False, help = 'greedy search or beam search') - parser.add_argument('--beam_width', type = int, default = 5, help = 'width of beam') - parser.add_argument('--text_prompt', type = str, default = None) - parser.add_argument("--memory_id", type=str, default=r"coco",help="memory name") - parser.add_argument("--memory_base_path", type=str, default="/raid/datasets/meacap_files/") - #parser.add_argument("--memory_caption_path", type=str, default='data/memory/coco/memory_captions.json') # unused - parser.add_argument("--memory_caption_num", type=int, default=5) - args = parser.parse_args() - print('args: {}\n'.format(vars(args))) - - main(args) \ No newline at end of file diff --git a/src/model.py b/src/model.py deleted file mode 100644 index 3921c768a505a023a71556a1bf0369e593b8b058..0000000000000000000000000000000000000000 --- a/src/model.py +++ /dev/null @@ -1,1565 +0,0 @@ -import timm -import torch -import torch.nn as nn -import math -import yaml -import os -import pickle -import random -import torchvision.transforms as T - - -from src.decap.decap import decoding_batched, DeCap, MLP -from src.decap.decap import get_decap_model -from src.dino_extraction import get_self_attention, process_self_attention, transform_to_standard_dino_out, get_layer_n_output, feats -from src.decap.im2txtprojection.im2txtprojection import Im2TxtProjector, ProjectionType -from src.bbox_utils import extract_bboxes_feats, extract_bboxes_feats_double_dino, process_bboxes, map_traces_to_grid -from src.talk2dino.talk2dino import ProjectionLayer -from transformers import GPT2LMHeadModel -from src.embedding_utils import get_pseudo_inverse, revert_transformation - - -import math -from tqdm import tqdm - -import torch.nn.functional as F - - -# Container to store outputs -patch_embeddings = {} - -# Hook function -def save_patch_embeddings(module, input, output): - """ - module: the module being hooked (the transformer) - input: input to the module - output: output from the module - """ - # output shape: (batch_size, 1 + num_patches, embedding_dim) - patch_tokens = output[:, 1:, :] # remove the CLS token - patch_embeddings['tokens'] = patch_tokens - patch_embeddings['cls'] = output[:, 0, :] - patch_embeddings['full'] = output - -def compute_region_means(patch_embeddings, variance): - """ - Compute weighted region means for a batch of patch embeddings. - - Args: - patch_embeddings (torch.Tensor): Tensor of shape (N, H, W, embed_dim). - variance (float): Variance for the Gaussian weighting. If 0, select the center patch. - If variance > 100, use uniform weights. - - Returns: - region_means (torch.Tensor): Weighted means for each region, shape (N, embed_dim). - patch_weights (torch.Tensor): The weights applied, shape (N, H, W). - """ - N = patch_embeddings.shape[0] - grid_size = int(patch_embeddings.shape[1]**0.5) - - W = H = grid_size - - patch_embeddings = patch_embeddings.view(N, grid_size, grid_size, -1) # Shape (N, grid_size, grid_size, embed_dim) - device = patch_embeddings.device - - # Create coordinate grid once - y = torch.linspace(-1, 1, grid_size, device=device) - x = torch.linspace(-1, 1, grid_size, device=device) - yy, xx = torch.meshgrid(y, x, indexing="ij") - - if variance == 0: - # One-hot weight at the center - patch_weights = torch.zeros(N, H, W, device=device) - center_y_options = [grid_size // 2] if grid_size % 2 == 1 else [grid_size // 2 - 1, grid_size // 2] - center_x_options = [grid_size // 2] if grid_size % 2 == 1 else [grid_size // 2 - 1, grid_size // 2] - for i in range(N): - cy = random.choice(center_y_options) - cx = random.choice(center_x_options) - patch_weights[i, cy, cx] = 1.0 - elif variance >= 100: - # Uniform weights - patch_weights = torch.full((N, H, W), 1 / (H * W), device=device) - else: - # Gaussian weights - distances = xx**2 + yy**2 - weights = torch.exp(-distances / variance) - weights = weights / weights.sum() # Normalize - patch_weights = weights.unsqueeze(0).expand(N, -1, -1) - - # Compute the weighted sum (i.e., the weighted mean) - weighted_patches = patch_embeddings * patch_weights.unsqueeze(-1) - region_means = weighted_patches.sum(dim=(1, 2)) - - return region_means - -class Patchioner(nn.Module): - - def __init__(self, projection_type, decoder_weights, device, prefix_size, linear_talk2dino, support_memory_size, - dino_model=None, proxyclip_clipmodel=None, proxyclip_vfm=None, use_talk2dino_project=True, normalize=True, attention_type='qkv', talk2dino_config=None, - talk2dino_weights=None, resize_dim=518, crop_dim=518, talk2dino_attn_type='qkv', calculate_argmax_text=False, - online_texts=None, clip_model_name=None, use_open_clip=False, viecap_config=None, regionclip_config=None, invite_config=None, denseclip_config=None, alphaclip_config=None, clipcap_config=None): - super().__init__() - - self.decoding_method = None - - if viecap_config is not None: - if viecap_config.get('meacap', False): - from src.meacap.entrypoint import MeaCap - self.viecap = MeaCap(viecap_config, device, clip_model_name) - else: - from src.viecap.entrypoint import VieCap - self.viecap = VieCap(viecap_config, device, clip_model_name) - else: - self.viecap = None - - if clipcap_config is not None: - # Determine DINO feature dimension based on model type - dino_feature_dim = prefix_size # Use prefix_size as DINO feature dimension - if dino_model is not None: - if 'dinov2_vits14' in dino_model: - dino_feature_dim = 384 - elif 'dinov2_vitb14' in dino_model: - dino_feature_dim = 768 - elif 'dinov2_vitl14' in dino_model: - dino_feature_dim = 1024 - elif 'dinov2_vitg14' in dino_model: - dino_feature_dim = 1536 - - - from src.clipcap.entrypoint import ClipCapModel - self.clipcap = ClipCapModel(clipcap_config, device, dino_feature_dim) - else: - self.clipcap = None - - if dino_model is not None and 'dinotxt' in dino_model: - clip_model_name = 'DINO.txt' - - if alphaclip_config is not None: - print(f"Using AlphaCLIP model {alphaclip_config}") - # AlphaClip will be loaded later after determining patch sizes - - # decoder initialization - if online_texts is not None: - projection_type_enum = ProjectionType.ONLINE_TEXTS - elif projection_type == 'coco': - projection_type_enum = ProjectionType.COCO_CAPTIONS - elif projection_type == 'msmarco': - projection_type_enum = ProjectionType.MS_MARCO_QUERIES_A - elif projection_type == 'blip': - projection_type_enum = ProjectionType.CC3M_BLIP - elif projection_type == 'vg': - projection_type_enum = ProjectionType.VISUAL_GENOME - elif projection_type == 'vg-test': - projection_type_enum = ProjectionType.VISUAL_GENOME_TEST - elif os.path.exists(projection_type): - print(f"Loading memory bank from {projection_type}") - projection_type_enum = projection_type - else: - raise Exception("The projection_type field must be 'coco', 'msmarco', 'blip' or 'vg'") - - self.calculate_argmax_text = calculate_argmax_text - if not self.calculate_argmax_text and decoder_weights is not None: - self.decoder = get_decap_model(device, decoder_weights, prefix_size) - if support_memory_size > 0: - self.im_proj = Im2TxtProjector( - type=projection_type_enum, - use_talk2dino=use_talk2dino_project, - linear_talk2dino=linear_talk2dino, - support_memory_size=support_memory_size, - device_str=device, - normalize_memory_embs=(dino_model is not None) and ('dinov2' not in dino_model), - talk2dino_attn_type=talk2dino_attn_type, - online_texts=online_texts, - clip_modelname=clip_model_name, - use_open_clip=use_open_clip, - regionclip_config=regionclip_config, - invite_config=invite_config, - denseclip_config=denseclip_config, - - ) - else: - self.im_proj = None - - self.normalize = normalize - # ProxyCLIP initialization - if proxyclip_clipmodel: - from src.proxyclip.proxyclip import ProxyCLIP - self.proxyclip = ProxyCLIP(clip_type='openai', model_type=proxyclip_clipmodel, vfm_model=proxyclip_vfm, device=device) - self.patch_size = self.proxyclip.vfm.patch_embed.patch_size - if isinstance(self.patch_size, tuple): - self.patch_size = self.patch_size[0] - # DINOv2 initialization - self.resize_dim=resize_dim - self.crop_dim=crop_dim - self.num_global_tokens = 1 if dino_model is None or "reg" not in dino_model else 5 - - if dino_model is not None: - if 'dinov2' in dino_model: - patch_size = 14 - elif 'patch16' in dino_model: - patch_size = 16 - elif 'patch14' in dino_model: - patch_size = 14 - elif 'patch32' in dino_model: - patch_size = 32 - elif dino_model is None: - pass - elif use_open_clip: - patch_size = int(dino_model.split('/')[-1]) - assert patch_size > 0, "Patch size must be a positive integer, got {}".format(patch_size) - elif regionclip_config is not None: - # For RegionCLIP ResNet, use effective patch size of 32 (spatial downsampling factor) - patch_size = regionclip_config.get('patch_size', 32) - elif invite_config is not None: - # For INViTE CLIP ViT, extract patch size from model name - model_name = invite_config.get('name', 'ViT-B/32') - if 'ViT-B/32' in model_name: - patch_size = 32 - elif 'ViT-B/16' in model_name: - patch_size = 16 - elif 'ViT-L/14' in model_name: - patch_size = 14 - else: - # Default patch size for ViT models - print(f"Unknown INViTE model {model_name}, using default patch size 32") - patch_size = 32 - elif denseclip_config is not None: - # For DenseClip ViT, extract patch size from config - from src.denseclip.loader import load_denseclip_config - denseclip_config_dict = load_denseclip_config(denseclip_config) - patch_size = denseclip_config_dict.get('model', {}).get('vision', {}).get('vision_patch_size', 16) - elif alphaclip_config is not None: - # For AlphaClip, extract patch size from model name - model_name = alphaclip_config.get('name', 'ViT-B/16') - patch_size = alphaclip_config.get('patch_size', None) - if patch_size is None: - if 'ViT-B/32' in model_name: - patch_size = 32 - elif 'ViT-B/16' in model_name: - patch_size = 16 - elif 'ViT-L/14' in model_name: - patch_size = 14 - else: - print(f"Unknown AlphaClip model {model_name}, using default patch size 16") - patch_size = 16 - elif clip_model_name == 'ResNet50x4' and dino_model == 'RN50x4': - patch_size = 32 # Effective patch size for ResNet50x4 - else: - raise Exception("Unknown patch size") - - if regionclip_config is not None: - # For RegionCLIP ResNet, calculate spatial dimensions differently - # ResNet reduces input by factor of 32, so for crop_dim=224, final spatial size is 7x7 - spatial_size = crop_dim // patch_size - self.num_patch_tokens = spatial_size * spatial_size - self.num_tokens = self.num_global_tokens + self.num_patch_tokens - - # RegionCLIP ResNet typically uses different embedding dimensions - # This should be determined from the loaded model - self.embed_dim = regionclip_config.get('embed_dim', 1024) # Common for ResNet-50 CLIP models, but should be verified - elif invite_config is not None: - # For INViTE CLIP ViT, calculate patch dimensions like standard ViT - self.num_patch_tokens = (crop_dim // patch_size) * (crop_dim // patch_size) - self.num_tokens = self.num_global_tokens + self.num_patch_tokens - - # INViTE CLIP ViT embedding dimensions based on model architecture - model_name = invite_config.get('name', 'ViT-B/32') - if 'ViT-L' in model_name: - self.embed_dim = 768 # ViT-L/14 uses 768-dim embeddings in CLIP - elif 'ViT-B' in model_name: - self.embed_dim = 512 # ViT-B uses 512-dim embeddings in CLIP - else: - self.embed_dim = 512 # Default for CLIP ViT models - elif denseclip_config is not None: - # For DenseClip ViT, calculate patch dimensions like standard ViT - self.num_patch_tokens = (crop_dim // patch_size) * (crop_dim // patch_size) - self.num_tokens = self.num_global_tokens + self.num_patch_tokens - - # DenseClip embedding dimensions from config - self.embed_dim = denseclip_config_dict.get('model', {}).get('vision', {}).get('embed_dim', 512) - else: - self.num_patch_tokens = crop_dim // patch_size * crop_dim // patch_size - self.num_tokens = self.num_global_tokens + self.num_patch_tokens - - if regionclip_config is not None: - # RegionCLIP ResNet typically uses different embedding dimensions - # This should be determined from the loaded model - self.embed_dim = regionclip_config.get('embed_dim', 1024) # Common for ResNet-50 CLIP models, but should be verified - elif invite_config is not None: - # INViTE CLIP ViT embedding dimensions based on model architecture - model_name = invite_config.get('name', 'ViT-B/32') - if 'ViT-L' in model_name: - self.embed_dim = 768 # ViT-L/14 uses 768-dim embeddings in CLIP - elif 'ViT-B' in model_name: - self.embed_dim = 512 # ViT-B uses 512-dim embeddings in CLIP - else: - print(f"Unknown INViTE model {model_name}, using default embedding dimension 512") - self.embed_dim = 512 # Default for CLIP ViT models - elif denseclip_config is not None: - # DenseClip embedding dimensions from config - self.embed_dim = denseclip_config_dict.get('model', {}).get('vision', {}).get('embed_dim', 512) - elif alphaclip_config is not None: - # AlphaClip embedding dimensions based on model architecture - model_name = alphaclip_config.get('name', 'ViT-B/16') - embed_dim = alphaclip_config.get('embed_dim', None) - if embed_dim is not None: - self.embed_dim = embed_dim - else: - if 'ViT-L' in model_name: - self.embed_dim = 768 # ViT-L uses 768-dim embeddings in CLIP - elif 'ViT-B' in model_name: - self.embed_dim = 512 # ViT-B uses 512-dim embeddings in CLIP - else: - print(f"Unknown AlphaClip model {model_name}, using default embedding dimension 512") - self.embed_dim = 512 - - # For AlphaClip, calculate patch dimensions - self.num_patch_tokens = (crop_dim // patch_size) * (crop_dim // patch_size) - self.num_tokens = self.num_global_tokens + self.num_patch_tokens - elif 'vitl' in dino_model or 'vit_large' in dino_model or 'ViT-L' in dino_model or 'ViT-H' in dino_model: - self.embed_dim = 1024 - elif 'vitb' in dino_model or 'vit_base' in dino_model or 'ViT-B' in dino_model: - self.embed_dim = 768 - elif 'vits' in dino_model or 'vit_small' in dino_model: - self.embed_dim = 384 - elif prefix_size is not None: - print("[FALLBACK] Using prefix_size as embed_dim:", prefix_size) - self.embed_dim = prefix_size - else: - raise Exception("Unknown ViT model") - - self.model_name = dino_model if dino_model is not None else 'proxyclip' - self.num_attn_heads = 16 if dino_model is not None and not 'vits' in dino_model else 6 - self.scale = 0.125 - if dino_model is not None: - if 'dinov2' in dino_model: - self.num_global_tokens = 1 if "reg" not in dino_model else 5 - - model_family = 'facebookresearch/dinov2' - self.dino = torch.hub.load(model_family, dino_model) - - if 'dinotxt' in dino_model: - self.dino = self.dino.visual_model.backbone.model - self.image_transforms = T.Compose([ - T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC), - T.CenterCrop(crop_dim), - T.ToTensor(), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ]) - self.image_transforms_no_crop = T.Compose([ - T.Resize((resize_dim, resize_dim), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ]) - elif 'openai' in dino_model: - print(f"Loading OpenAI model {dino_model} using timm.create_model...") - # we use this case to test DeCap original architecture (with CLIP instead of DINOv2) - - # timm uses GELU while original OpenAI model uses QuickGELU - # https://github.com/huggingface/pytorch-image-models/issues/1754 - # we fix the activation function because DeCap is trained using OpenAI interface - class QuickGELU(torch.nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - print(f"timm.list_models(pretrained=True) contains dino_model ({dino_model}):", dino_model in timm.list_models(pretrained=True)) - - timm_model = timm.create_model(dino_model, pretrained=True, act_layer=QuickGELU, img_size=resize_dim) - - print(f"timm_model is instance of {type(timm_model)}") - - self.dino = timm_model.to(device) - - assert hasattr(self.dino, 'blocks'), f"The model does not have 'blocks' attribute. dino is instance of {type(self.dino)}" - - self.image_transforms = T.Compose([ - T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC), - T.CenterCrop(crop_dim), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - - self.image_transforms_no_crop = T.Compose([ - T.Resize((resize_dim, resize_dim), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - elif use_open_clip: - - print(f""" - ------------------------------------------- - Using OpenCLIP model {dino_model} - ------------------------------------------- - """) - # load open clip weights - from open_clip import create_model_and_transforms, get_tokenizer - open_clip, preprocess_train, preprocess_val = create_model_and_transforms( - model_name=dino_model, - pretrained="laion2b_s32b_b79k", - device=device, - #image_size=224, - #context_length=77, - #vocab_size=49408, - ) - tokenizer = get_tokenizer(dino_model.replace("/", "-")) - - - open_clip.eval() - - image_transforms_open_clip = preprocess_train - - self.dino = open_clip - self.image_transforms = image_transforms_open_clip - self.image_transforms_no_crop = T.Compose([ - T.Resize((resize_dim, resize_dim), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - - self.decoding_method = tokenizer.decode - elif regionclip_config is not None: - # load regionclip model - from src.regionclip.loader import load_regionclip_from_checkpoint - - regionclip_checkpoint = regionclip_config.get('checkpoint', None) - if regionclip_checkpoint is None: - raise Exception("RegionCLIP checkpoint not specified in the configuration") - regionclip_config_name = regionclip_config.get('config_name', None) - - self.dino = load_regionclip_from_checkpoint(regionclip_checkpoint, device=device, config=regionclip_config_name, override_config=regionclip_config) - - # use standard clip preprocessing transforms - self.image_transforms = T.Compose([ - T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC), - T.CenterCrop(crop_dim), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - self.image_transforms_no_crop = T.Compose([ - T.Resize((resize_dim, resize_dim), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - - # For RegionCLIP ResNet, compute effective patch size based on spatial downsampling - # ResNet reduces spatial resolution by factor of 32 (2^5: stem avgpool + 4 layers with stride 2) - # The final feature map from res4 (before attnpool) has resolution input_size // 32 - # So effective patch size is 32 for mapping between image coordinates and feature map coordinates - self.patch_size = regionclip_config.get('patch_size', 32) - - elif invite_config is not None: - # load INViTE CLIP model - from src.INViTE.loader import load_invite_clip - - # Load INViTE CLIP model using the config - self.dino, preprocess_transform, tokenize_method = load_invite_clip(invite_config, device=device) - - # Use the preprocess transform from INViTE CLIP - self.image_transforms = preprocess_transform - self.image_transforms_no_crop = T.Compose([ - T.Resize((resize_dim, resize_dim), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - - # Extract patch size from model name for coordinate mapping - model_name = invite_config.get('name', 'ViT-B/32') - if 'ViT-B/32' in model_name: - self.patch_size = 32 - elif 'ViT-B/16' in model_name: - self.patch_size = 16 - elif 'ViT-L/14' in model_name: - self.patch_size = 14 - else: - print(f"Unknown INViTE model {model_name}, using default patch size 32") - self.patch_size = 32 - - elif denseclip_config is not None: - # load DenseClip model - from src.denseclip.loader import load_denseclip - - # Load DenseClip model using the config - checkpoint_path = denseclip_config_dict.get('checkpoint_path', None) - config_name = denseclip_config_dict.get('config_name', 'denseclip_vitb16') - - self.dino = load_denseclip(config_name=denseclip_config, device=device) - - # Use standard CLIP preprocessing transforms - self.image_transforms = T.Compose([ - T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC), - T.CenterCrop(crop_dim), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - self.image_transforms_no_crop = T.Compose([ - T.Resize((resize_dim, resize_dim), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - - # Extract patch size from config for coordinate mapping - self.patch_size = denseclip_config_dict.get('model', {}).get('vision', {}).get('vision_patch_size', 16) - - elif alphaclip_config is not None: - # load AlphaClip model - from src.alphaclip.alphaclip_loader import load_alphaclip - - # Load AlphaClip model using the config - model_name = alphaclip_config.get('name', None) - alpha_vision_checkpoint = alphaclip_config.get('alpha_vision_checkpoint', None) - loader, self.dino, preprocess_transform = load_alphaclip(model_name=model_name, device=device, alpha_vision_ckpt_pth=alpha_vision_checkpoint) - - # Use standard CLIP preprocessing transforms - self.image_transforms = T.Compose([ - T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC), - T.CenterCrop(crop_dim), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - self.image_transforms_no_crop = T.Compose([ - T.Resize((resize_dim, resize_dim), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), - std=(0.26862954, 0.26130258, 0.27577711)), - ]) - - # Extract patch size from model name for coordinate mapping - patch_size = alphaclip_config.get('patch_size', None) - if patch_size is not None: - self.patch_size = patch_size - else: - if 'ViT-B/32' in model_name: - self.patch_size = 32 - elif 'ViT-B/16' in model_name: - self.patch_size = 16 - elif 'ViT-L/14' in model_name: - self.patch_size = 14 - else: - print(f"Unknown AlphaClip model {model_name}, using default patch size 16") - self.patch_size = 16 - - else: - raise Exception("Model family unsupported") - else: - self.image_transforms = T.Compose([ - T.Resize(resize_dim, interpolation=T.InterpolationMode.BICUBIC), - T.CenterCrop(crop_dim), - T.ToTensor(), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ]) - self.image_transforms_no_crop = T.Compose([ - T.Resize((resize_dim, resize_dim), interpolation=T.InterpolationMode.BICUBIC), - T.ToTensor(), - T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), - ]) - - if attention_type != 'qkv': - # in case kkv_attention is True, we perform the attention of the last block using Keys as Queries - original_qkv = self.dino.blocks[-1].attn.qkv - embed_dim = original_qkv.in_features - - weights = {} - biases = {} - weights['q'], weights['k'], weights['v'] = original_qkv.weight.reshape(3, embed_dim, embed_dim) - biases['q'], biases['k'], biases['v'] = original_qkv.bias.reshape(3, embed_dim) - - new_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True) - new_qkv.weight.data.copy_(torch.cat([weights[x] for x in attention_type], dim=0)) - new_qkv.bias.data.copy_(torch.cat([biases[x] for x in attention_type], dim=0)) - self.dino.blocks[-1].attn.qkv = new_qkv - - if dino_model is not None: - - if self.dino is not None: - self.dino.eval() - - if hasattr(self.dino, 'blocks'): - self.dino.blocks[-1].attn.qkv.register_forward_hook(get_self_attention) - # elif hasattr(self.dino, 'visual_model'): - # self.dino.visual_model.backbone.model.blocks[-1].attn.qkv.register_forward_hook(get_self_attention) - # need patch_size - if 'dino' in dino_model: - self.patch_size = self.dino.patch_size - elif 'openai' in dino_model: - # in the case self.dino is a timm model, we need to get patch_size from - # the model's configuration - # should get patch size from dino_model, which is a string with the following format: - # 'vit_base_patch32_clip_224.openai' - self.patch_size = int(dino_model.split('_')[2].replace('patch', '')) - elif regionclip_config is not None: - # For RegionCLIP ResNet, patch_size was already set above to 32 - pass # self.patch_size = 32 was set earlier - elif invite_config is not None: - # For INViTE CLIP ViT, patch_size was already set above based on model name - pass # self.patch_size was set during model loading - elif denseclip_config is not None: - # For DenseClip ViT, patch_size was already set above from config - pass # self.patch_size was set during model loading - elif alphaclip_config is not None: - # AlphaClip initialization - if self.dino is not None: - self.dino.eval() - # AlphaClip patch_size was already set during model loading - # No attention hooks needed for AlphaClip since we don't access self-attention - - if talk2dino_weights is not None: - # Talk2DINO initialization - talk2dino = ProjectionLayer.from_config(talk2dino_config) - talk2dino.load_state_dict(torch.load((talk2dino_weights), device)) - - self.embed_inversion = True - self.talk2dino_A_pinv = get_pseudo_inverse(talk2dino.linear_layer.weight).to(device) - self.talk2dino_b = talk2dino.linear_layer.bias.to(device) - else: - self.embed_inversion = False - else: - self.embed_inversion = False - - # Determine backbone type based on configuration - if proxyclip_clipmodel is not None: - self.backbone_type = 'CLIP' # ProxyCLIP uses CLIP - elif regionclip_config is not None: - self.backbone_type = 'RegionCLIP' - self.regionclip_config = regionclip_config.copy() # Store config for later use - elif invite_config is not None: - self.backbone_type = 'INViTE' - elif denseclip_config is not None: - self.backbone_type = 'DenseClip' - self.denseclip_config = denseclip_config_dict.copy() # Store config for later use - elif alphaclip_config is not None: - self.backbone_type = 'AlphaClip' - self.alphaclip_config = alphaclip_config.copy() # Store config for later use - elif use_open_clip and dino_model is not None: - self.backbone_type = 'OpenCLIP' - elif dino_model is not None: - if 'dinotxt' in dino_model: - self.backbone_type = 'DINO.txt' - elif 'dinov2' in dino_model: - self.backbone_type = 'DINO' - elif 'openai' in dino_model: - - self.backbone_type = 'CLIP' - else: - self.backbone_type = 'DINO' # Default for other DINO variants - else: - self.backbone_type = 'CLIP' # Default fallback - - if not hasattr(self, 'dino'): - print(f"Warning: No DINO model loaded!") - self.dino = None - - - - @classmethod - def from_config(cls, config, device='cpu', online_texts=None): - if type(config) is str: - # if the configuration is a string, we treat it as a file path - with open(config, 'r') as f: - config = yaml.safe_load(f) - model = cls( - projection_type=config.get('projection_type', 'coco'), - decoder_weights=config.get('decap_weights', None), - device=device, - prefix_size=config['prefix_size'], - linear_talk2dino=config.get('linear_talk2dino', False), - support_memory_size=config['support_memory_size'], - dino_model=config.get('dino_model', None), - proxyclip_clipmodel=config.get('proxyclip_clipmodel', None), - proxyclip_vfm=config.get('proxyclip_vfm', None), - use_talk2dino_project=config.get('use_talk2dino_project', True), - normalize=config.get('normalize', True), - attention_type=config.get('attention_type', 'qkv'), - talk2dino_config=config.get('talk2dino_config', None), - talk2dino_weights=config.get('talk2dino_weights', None), - resize_dim=config.get('resize_dim', 518), - crop_dim=config.get('crop_dim', 518), - talk2dino_attn_type=config.get('talk2dino_attn_type', 'qkv'), - calculate_argmax_text=config.get('calculate_argmax_text', False), - clip_model_name=config.get('clip_model_name', None), - online_texts=online_texts, - use_open_clip=config.get('use_open_clip', False), - viecap_config=config.get('viecap', None), - regionclip_config=config.get('regionclip_config', None), - invite_config=config.get('invite_config', None), - denseclip_config=config.get('denseclip_config', None), - alphaclip_config=config.get('alphaclip_config', None), - clipcap_config=config.get('clipcap', None), - ) - model.to(device) - return model - - - def forward(self, imgs, - get_cls_capt=True, - get_avg_self_attn_capt=False, - get_attn_heads_capt=False, - get_patch_capts=False, - get_register_capts=False, - bboxes=None, - traces=None, - get_controllable_capts=False, - bs_factor=4, - gaussian_avg=False, - gaussian_bbox_variance=0.5, - get_avg_patch_capt=False, - gaussian_img_variance=1, - use_attn_map_for_bboxes=False, - use_attention_tracing=False, - double_DINO_for_bboxes=False, - double_DINO_for_bboxes_return_type="avg", - double_DINO_use_cls=False, - cleaning_type=None, - clean_after_projection=True, - alpha=1.0, - clean_from="cls", - caption_bboxes_type : str = None, - return_n_best_sims=None, - compute_scores : bool = False - ): - """ - bboxes: [BS x N_BOX_MAX x 4] - - double_DINO_for_bboxes_return_type : "cls" | "avg" | "gaussian_avg" - - caption_bboxes_type = None | capt_type : str either 'avg_self_attn_capt' or 'cls_capt' if we want to compute the image caption of each bounding box as the caption of the cropped image - - cleaning_type : None | "orthogonal_projection" | "contrastive_mask" - - clean_after_projection : bool - if True, it first projects the patch embeddings and general token in textual space and then apply cleaning - - alpha : between 0.0 and 1.0, used for "orthogonal_projection", weights the projection to subtract - - clean_from : "cls" | "avg_self_attn" - """ - assert clean_from in ["cls", "avg_self_attn"] - assert cleaning_type in [None, "orthogonal_projection", "contrastive_mask"] - - outs = {} - bs = imgs.shape[0] - - if self.dino is not None and bboxes is not None and double_DINO_for_bboxes: - if self.backbone_type == 'AlphaClip': - raise ValueError("double_DINO_for_bboxes is not supported with AlphaClip. AlphaClip processes regions differently.") - self.dino.blocks[-1].attn.qkv.register_forward_hook(get_self_attention) - self.dino.blocks[-1].register_forward_hook(get_layer_n_output) - - if self.dino is not None and hasattr(self.dino, 'visual') and hasattr(self.dino.visual, 'transformer'): - # Attach hook to the visual transformer - hook_handle = self.dino.visual.transformer.register_forward_hook(save_patch_embeddings) - - if caption_bboxes_type is not None: - return self.caption_bboxes(imgs, bboxes, caption_bboxes_type, compute_scores=compute_scores) - - # Special handling for AlphaClip: process each bbox/trace separately - if self.backbone_type == 'AlphaClip' and (bboxes is not None or traces is not None): - return self.forward_alphaclip_with_regions(imgs, bboxes, traces, get_cls_capt, get_avg_self_attn_capt, - get_attn_heads_capt, get_patch_capts, get_register_capts, - get_controllable_capts, get_avg_patch_capt, - gaussian_avg, gaussian_bbox_variance, gaussian_img_variance, - compute_scores, return_n_best_sims) - - # Forward pass based on backbone type - if 'DINO' in self.backbone_type: - dino_outs = self.dino(imgs, is_training=True) - elif self.backbone_type == 'CLIP' and self.model_name == 'proxyclip': - dino_outs = self.proxyclip(imgs) - elif self.backbone_type == 'CLIP' and 'openai' in self.model_name: - # Using timm interface for OpenAI CLIP models - output = self.dino.forward_features(imgs) - # Projecting 768 -> 512 - output = self.dino.head(output) - - # Reporting output in DINOv2 format - dino_outs = { - 'x_norm_clstoken': output[:, 0, :], - 'x_norm_patchtokens': output[:, 1:, :], - } - elif self.backbone_type == 'AlphaClip': - # AlphaClip ViT case - standard forward for whole images - # alphaclip always needs the alpha mask, so we pass it dummy masks made of ones, one per image - # the shape is [BS, 1, H, W] where H and W are grid dimensions - grid_size = self.crop_dim // self.patch_size - alpha_mask = torch.ones((imgs.shape[0], 1, grid_size, grid_size), device=imgs.device) - # upscale alpha_mask to match the input image size - alpha_mask = F.interpolate(alpha_mask, size=(self.crop_dim, self.crop_dim), - mode='nearest') # using nearest, so that the mask is made only of ones - output = self.dino.visual(imgs, alpha=alpha_mask, return_patches=True) - - # Reporting output in DINOv2 format - dino_outs = { - 'x_norm_clstoken': output[:, 0, :], # CLS token - 'x_norm_patchtokens': output[:, 1:, :], # Patch tokens - 'x_norm_regtokens': None, # AlphaClip doesn't have register tokens - } - elif self.backbone_type == 'RegionCLIP': - # RegionCLIP ResNet case - use_attnpool_for_spatial_feats = self.regionclip_config.get('use_attnpool_for_spatial_feats', True) - dino_outs = self.dino.visual.forward_return_spatial_feats(imgs, use_attnpool_for_spatial_feats=use_attnpool_for_spatial_feats) - elif self.backbone_type == 'INViTE': - # INViTE CLIP ViT case - get_all_last = True - output = self.dino.visual(imgs, get_all_last=get_all_last) - - # Reporting output in DINOv2 format - dino_outs = { - 'x_norm_clstoken': output[:, 0, :], # CLS token - 'x_norm_patchtokens': output[:, 1:, :], # Patch tokens - } - - elif self.backbone_type == 'DenseClip': - # DenseClip ViT case - # DenseClip model has encode_image method that returns features - output = self.dino.visual.forward(imgs, get_patches=True) - - # DenseClip returns features in format compatible with CLIP - # We need to extract the visual features and structure them properly - if hasattr(output, 'shape') and len(output.shape) == 3: - # Output is [batch_size, num_tokens, embed_dim] - # First token is CLS, rest are patch tokens - dino_outs = { - 'x_norm_clstoken': output[:, 0, :], # CLS token - 'x_norm_patchtokens': output[:, 1:, :], # Patch tokens - } - else: - # If output format is different, handle accordingly - # This might need adjustment based on actual DenseClip output format - raise ValueError(f"Unexpected DenseClip output format: {output.shape if hasattr(output, 'shape') else type(output)}") - - elif self.backbone_type == 'OpenCLIP': - # Using open_clip interface - output = self.dino.visual(imgs) - output = patch_embeddings['full'] - - output = output @ self.dino.visual.proj # shape (B, N_patches, output_dim) - - # Reporting output in DINOv2 format - dino_outs = { - 'x_norm_clstoken': output[:, 0, :], - 'x_norm_patchtokens': output[:, 1:, :], - } - else: - raise ValueError(f"Unsupported backbone type: {self.backbone_type}") - - # Handle self-attention processing (only for models that have attention mechanisms) - has_attention = (('DINO' in self.backbone_type or self.backbone_type == 'DenseClip') and - 'self_attn' in feats) - - if has_attention: - self_attn, self_attn_maps = process_self_attention(feats['self_attn'], imgs.shape[0], self.num_tokens, self.num_attn_heads, self.embed_dim, self.scale, self.num_global_tokens, ret_self_attn_maps=True) - avg_self_attn_token = (self_attn.unsqueeze(-1) * dino_outs['x_norm_patchtokens']).mean(dim=1) - - self_attn_maps = self_attn_maps.softmax(dim=-1) - disentangled_self_attn = (dino_outs['x_norm_patchtokens'].unsqueeze(1) * self_attn_maps.unsqueeze(-1)).mean(dim=2) - #else: - # # For models without accessible self-attention (RegionCLIP ResNet, INViTE CLIP), create fallback - # avg_self_attn_token = dino_outs['x_norm_patchtokens'].mean(dim=1) - # # Create dummy attention heads (just repeat the average) - # disentangled_self_attn = avg_self_attn_token.unsqueeze(1).repeat(1, self.num_attn_heads, 1) - - if cleaning_type is not None: - batch_patchtokens = dino_outs['x_norm_patchtokens'] - if clean_from == "cls": - batch_clean_from_token = dino_outs['x_norm_clstoken'] - else: # clean_from == "avg_self_attn" - if has_attention: - batch_clean_from_token = avg_self_attn_token - else: - # Fallback to cls token if self-attention not available - batch_clean_from_token = dino_outs['x_norm_clstoken'] - - dino_outs['x_norm_patchtokens'] = None - - # Loop over the batch size and apply ctx_cleaner per element - for i in range(bs): - # Extract the patch tokens and class token for the current batch element - patchtokens_i = batch_patchtokens[i:i+1] # Shape: [1, seq_len, embed_dim] - clean_from_token_i = batch_clean_from_token[i:i+1] # Shape: [1, embed_dim] - - # Apply ctx_cleaner to each batch element - if clean_after_projection: - cleaned_patchtokens = self.ctx_cleaner( - self.im_proj.project(patchtokens_i, normalize=True), - self.im_proj.project(clean_from_token_i, normalize=True), - cleaning_type=cleaning_type, - alpha=alpha - ) - else: - cleaned_patchtokens = self.im_proj.project( \ - self.ctx_cleaner( - patchtokens_i / patchtokens_i.norm(dim=-1,keepdim=True), - clean_from_token_i / clean_from_token_i.norm(dim=-1,keepdim=True), - cleaning_type=cleaning_type, - alpha=alpha - ), normalize=True - ) - - # Store the cleaned patch tokens in the output dictionary - if 'x_norm_patchtokens' not in dino_outs or dino_outs['x_norm_patchtokens'] is None: - dino_outs['x_norm_patchtokens'] = cleaned_patchtokens - else: - dino_outs['x_norm_patchtokens'] = torch.cat( - (dino_outs['x_norm_patchtokens'], cleaned_patchtokens), dim=0 - ) - - - embed_dim = dino_outs['x_norm_patchtokens'].shape[-1] - if get_cls_capt: - ret = self.caption_tokens(dino_outs['x_norm_clstoken'], compute_scores=compute_scores) - if compute_scores is True: - outs['cls_capt'], outs['cls_capt_scores'] = ret - else: - outs['cls_capt'] = ret - if get_avg_self_attn_capt: - ret = self.caption_tokens(avg_self_attn_token, compute_scores=compute_scores) - if compute_scores is True: - outs['avg_self_attn_capt'], outs['avg_self_attn_capt_scores'] = ret - else: - outs['avg_self_attn_capt'] = ret - if get_avg_patch_capt: - ret = self.caption_tokens(compute_region_means(dino_outs['x_norm_patchtokens'], gaussian_img_variance), compute_scores=compute_scores) - if compute_scores is True: - outs['avg_patch_capt'], outs['avg_patch_capt_scores'] = ret - else: - outs['avg_patch_capt'] = ret - - - if get_attn_heads_capt: - - ret = self.caption_tokens(disentangled_self_attn.view(-1, embed_dim), compute_scores=compute_scores) - - if compute_scores is True: - attn_heads_capt_unrolled, attn_heads_scores_unrolled = ret - outs['attn_heads_capts'] = [attn_heads_capt_unrolled[i * self.num_attn_heads:(i + 1) * self.num_attn_heads] for i in range(bs)] - outs['attn_heads_scores'] = [attn_heads_scores_unrolled[i * self.num_attn_heads:(i + 1) * self.num_attn_heads] for i in range(bs)] - else: - attn_heads_capt_unrolled = ret - outs['attn_heads_capts'] = [attn_heads_capt_unrolled[i * self.num_attn_heads:(i + 1) * self.num_attn_heads] for i in range(bs)] - if get_patch_capts: - n_patches = dino_outs['x_norm_patchtokens'].shape[1] - - ret = self.caption_tokens(dino_outs['x_norm_patchtokens'].reshape(-1, embed_dim), project=cleaning_type is None, compute_scores=compute_scores) - - if compute_scores is True: - patch_tokens_capts_unrolled, patch_tokens_scores_unrolled = ret - outs['patch_tokens_capts'] = [patch_tokens_capts_unrolled[i * n_patches:(i + 1) * n_patches] for i in range(bs)] - outs['patch_tokens_scores'] = [patch_tokens_scores_unrolled[i * n_patches:(i + 1) * n_patches] for i in range(bs)] - else: - patch_tokens_capts_unrolled = ret - outs['patch_tokens_capts'] = [patch_tokens_capts_unrolled[i * n_patches:(i + 1) * n_patches] for i in range(bs)] - if get_register_capts: - - ret = self.caption_tokens(dino_outs['x_norm_regtokens'].view(-1, embed_dim), compute_scores=compute_scores) - - if compute_scores is True: - register_capt_unrolled, register_scores_unrolled = ret - outs['register_capts'] = [register_capt_unrolled[i * 4:(i + 1) * 4] for i in range(bs)] - outs['register_scores'] = [register_scores_unrolled[i * 4:(i + 1) * 4] for i in range(bs)] - else: - register_capt_unrolled = ret - outs['register_capts'] = [register_capt_unrolled[i * 4:(i + 1) * 4] for i in range(bs)] - if bboxes is not None and not get_controllable_capts: - bbox_bs = bs * bs_factor - n_boxes = bboxes.shape[1] - if double_DINO_for_bboxes: - outs_layer_n = transform_to_standard_dino_out(feats['intermediate_output'], self.dino) - if double_DINO_use_cls: - cls_layer_n = outs_layer_n['x_norm_clstoken'] - registers_layer_n = outs_layer_n['x_norm_regtokens'] - else: - cls_layer_n = None - registers_layer_n = None - patches_layer_n = outs_layer_n['x_norm_patchtokens'] - bbox_feats = extract_bboxes_feats_double_dino(self.dino, patches_layer_n, bboxes, cls_layer_n, registers_layer_n, self.patch_size, return_type=double_DINO_for_bboxes_return_type, gaussian_bbox_variance=gaussian_bbox_variance)#.view(-1, self.embed_dim) - else: - bbox_attn_maps = self_attn.cpu() if (use_attn_map_for_bboxes and has_attention) else None - bbox_feats = extract_bboxes_feats(dino_outs['x_norm_patchtokens'], bboxes, gaussian_avg=gaussian_avg, - gaussian_bbox_variance=gaussian_bbox_variance, - patch_size=self.patch_size, attention_map=bbox_attn_maps)#.view(-1, self.embed_dim) - - - bbox_feats = bbox_feats.view(-1, embed_dim) - n_batch = math.ceil(bbox_feats.shape[0] / bbox_bs) - outs['bbox_capts'] = [] - if compute_scores is True: - outs['bbox_scores'] = [] - if return_n_best_sims is not None: - outs['bbox_sims'] = [] - #print(f"{n_batch = }, {bs = }, {bbox_bs = }") - for i in range(n_batch): - start = i * bbox_bs - end = start + bbox_bs if i < n_batch - 1 else bbox_feats.shape[0] - #cur_bbox_feats = bbox_feats[start:end] - if return_n_best_sims is None: - - ret = self.caption_tokens(bbox_feats[start:end], project=(cleaning_type is None), compute_scores=compute_scores) - - if compute_scores is True: - bbox_capts, bbox_scores = ret - outs['bbox_capts'].extend(bbox_capts) - outs['bbox_scores'].extend(bbox_scores) - else: - bbox_capts = ret - outs['bbox_capts'].extend(bbox_capts) - else: - - ret = self.caption_tokens(bbox_feats[start:end], project=(cleaning_type is None), return_n_best_sims=return_n_best_sims, compute_scores=compute_scores) - - if compute_scores is True: - (bbox_capts, bbox_sims), bbox_scores = ret - outs['bbox_capts'].extend(bbox_capts) - outs['bbox_sims'].extend(bbox_sims) - outs['bbox_scores'].extend(bbox_scores) - else: - bbox_capts, bbox_sims = ret - outs['bbox_capts'].extend(bbox_capts) - outs['bbox_sims'].extend(bbox_sims) - - outs['bbox_capts'] = [outs['bbox_capts'][i * n_boxes:(i + 1) * n_boxes] for i in range(bs)] - if compute_scores is True: - outs['bbox_scores'] = [outs['bbox_scores'][i * n_boxes:(i + 1) * n_boxes] for i in range(bs)] - if return_n_best_sims is not None: - outs['bbox_sims'] = [outs['bbox_sims'][i * n_boxes:(i + 1) * n_boxes] for i in range(bs)] - elif bboxes is not None and get_controllable_capts and self.backbone_type != 'AlphaClip': - bbox_attn_maps = self_attn.cpu() if (use_attn_map_for_bboxes and has_attention) else None - n_boxes = bboxes.shape[1] - bbox_feats = extract_bboxes_feats(dino_outs['x_norm_patchtokens'], bboxes, gaussian_avg=gaussian_avg, gaussian_bbox_variance=gaussian_bbox_variance, get_single_embedding_per_image=True, patch_size=self.patch_size, attention_map=bbox_attn_maps) - - outs['set_controllable_capts'] = self.caption_tokens(bbox_feats) - - if traces is not None and self.backbone_type != 'AlphaClip': - n_patches = int(dino_outs['x_norm_patchtokens'].shape[1] ** 0.5) - relevant_patches = torch.stack([map_traces_to_grid(trace, n_patches) for trace in traces], dim=0).to(next(self.parameters()).device) - if use_attention_tracing and has_attention: - relevant_patches = (self_attn.view(relevant_patches.shape) * relevant_patches) - trace_embeds = (relevant_patches.unsqueeze(-1) * dino_outs['x_norm_patchtokens'].view(bs, n_patches, n_patches, embed_dim)).mean(dim=(1,2)) - - outs['trace_capts'] = self.caption_tokens(trace_embeds) - - return outs - - def forward_alphaclip_with_regions(self, imgs, bboxes=None, traces=None, get_cls_capt=True, - get_avg_self_attn_capt=False, get_attn_heads_capt=False, - get_patch_capts=False, get_register_capts=False, - get_controllable_capts=False, get_avg_patch_capt=False, - gaussian_avg=False, gaussian_bbox_variance=0.5, gaussian_img_variance=1, - compute_scores=False, return_n_best_sims=None): - """ - Special forward method for AlphaClip that processes each bbox/trace separately. - This is required because AlphaClip processes regions differently than other backbones. - - AlphaClip's visual forward accepts an alpha parameter which is a binary mask - indicating which patches should be attended to. - """ - from src.alphaclip.alpha_mask_utils import ( - bbox_to_alpha_mask, bboxes_to_alpha_mask, - trace_to_alpha_mask, traces_to_alpha_mask - ) - - outs = {} - bs = imgs.shape[0] - device = next(self.parameters()).device - - # Calculate grid size for alpha masks - crop_dim = self.crop_dim - patch_size = 1 - grid_size = crop_dim // patch_size - - effective_grid_size = self.crop_dim // self.patch_size - - # Check if we should use CLS token or aggregate patches - use_cls_for_localized = self.alphaclip_config.get('use_cls_for_localized_captions', True) - - def extract_alphaclip_features(output, alpha_mask=None): - """ - Extract features from AlphaClip output based on configuration. - - Args: - output: AlphaClip visual output [batch, num_tokens, embed_dim] - alpha_mask: Optional alpha mask used for processing [batch, grid_size, grid_size] - - Returns: - features: Extracted features [batch, embed_dim] - """ - if use_cls_for_localized: - # Use CLS token (original implementation) - return output[:, 0, :] # [batch, embed_dim] - else: - # Aggregate patches like standard forward method - patch_tokens = output[:, 1:, :] # [batch, num_patches, embed_dim] - - if alpha_mask is not None: - # Weight patches by alpha mask - # Flatten alpha mask to match patch dimensions - alpha_flat = alpha_mask.view(alpha_mask.shape[0], -1) # [batch, num_patches] - - # Apply mask weights to patches - weighted_patches = patch_tokens * alpha_flat.unsqueeze(-1) # [batch, num_patches, embed_dim] - - # Normalize by sum of weights to get average - mask_sum = alpha_flat.sum(dim=1, keepdim=True) + 1e-8 # Avoid division by zero - aggregated_features = weighted_patches.sum(dim=1) / mask_sum.unsqueeze(-1) # [batch, embed_dim] - else: - # Simple average of all patches - aggregated_features = patch_tokens.mean(dim=1) # [batch, embed_dim] - - return aggregated_features - - # Handle controllable captions case (OR all regions into single mask per image) - if get_controllable_capts and (bboxes is not None or traces is not None): - controllable_capts = [] - - for img_idx in range(bs): - img = imgs[img_idx:img_idx+1] # [1, C, H, W] - - # Create combined alpha mask for this image - alpha_mask = torch.zeros((grid_size, grid_size)) - - alpha_mask_patches = torch.zeros((effective_grid_size, effective_grid_size), device=device) - - if bboxes is not None: - img_bboxes = bboxes[img_idx] # [n_boxes, 4] - alpha_mask = bboxes_to_alpha_mask(img_bboxes, grid_size, patch_size, crop_dim) - alpha_mask_patches = bboxes_to_alpha_mask(img_bboxes, effective_grid_size, self.patch_size, self.crop_dim) - - if traces is not None: - img_traces = traces[img_idx] # List of traces - trace_mask = traces_to_alpha_mask(img_traces, grid_size) - alpha_mask = torch.logical_or(alpha_mask, trace_mask).float() - alpha_mask_patches = torch.logical_or(alpha_mask_patches, traces_to_alpha_mask(img_traces, effective_grid_size)).float() - - # Add batch dimension and move to device: [1, grid_size, grid_size] - alpha_mask = alpha_mask.unsqueeze(0).to(device) - - alpha_mask_patches = alpha_mask_patches.unsqueeze(0).to(device) # [1, grid_size, grid_size] - - # Process with AlphaClip using the combined mask - output = self.dino.visual(img, alpha=alpha_mask, return_patches=True) - - # Extract features based on configuration - features = extract_alphaclip_features(output, alpha_mask_patches) - - # Caption the extracted features - ret = self.caption_tokens(features, compute_scores=compute_scores) - if compute_scores: - capt, score = ret - controllable_capts.extend(capt) - else: - controllable_capts.extend(ret) - - outs['set_controllable_capts'] = controllable_capts - return outs - - # Handle standard bboxes case (separate caption for each bbox) - if bboxes is not None: - n_boxes = bboxes.shape[1] - - # Process each image and each bbox separately - all_bbox_capts = [] - all_bbox_scores = [] if compute_scores else None - all_bbox_sims = [] if return_n_best_sims is not None else None - - for img_idx in range(bs): - img = imgs[img_idx:img_idx+1] # [1, C, H, W] - img_bboxes = bboxes[img_idx] # [n_boxes, 4] - - bbox_capts_for_img = [] - bbox_scores_for_img = [] if compute_scores else None - bbox_sims_for_img = [] if return_n_best_sims is not None else None - - for box_idx in range(n_boxes): - bbox = img_bboxes[box_idx] # [4] - - # Skip dummy boxes (negative values) - if bbox.sum().item() < 0: - bbox_capts_for_img.append("") # Empty caption for dummy box - if compute_scores: - bbox_scores_for_img.append(0.0) - if return_n_best_sims is not None: - bbox_sims_for_img.append([]) - continue - - # Create alpha mask for this bbox - alpha_mask = bbox_to_alpha_mask(bbox, grid_size, patch_size, self.crop_dim) - alpha_mask = alpha_mask.unsqueeze(0).to(device) # [1, crop_dim, crop_dim] - - alpha_mask_patches = bbox_to_alpha_mask(bbox, effective_grid_size, self.patch_size, self.crop_dim) - alpha_mask_patches = alpha_mask_patches.unsqueeze(0).to(device) # [1, effective_grid_size, effective_grid_size] - - # Process with AlphaClip using the bbox mask - output = self.dino.visual(img, alpha=alpha_mask, return_patches=True) - - # Extract features based on configuration - features = extract_alphaclip_features(output, alpha_mask_patches) - - # Caption the extracted features - if return_n_best_sims is None: - ret = self.caption_tokens(features, compute_scores=compute_scores) - if compute_scores: - capt, score = ret - bbox_capts_for_img.extend(capt) - bbox_scores_for_img.extend(score) - else: - bbox_capts_for_img.extend(ret) - else: - ret = self.caption_tokens(features, return_n_best_sims=return_n_best_sims, compute_scores=compute_scores) - if compute_scores: - (capt, sim), score = ret - bbox_capts_for_img.extend(capt) - bbox_sims_for_img.extend(sim) - bbox_scores_for_img.extend(score) - else: - capt, sim = ret - bbox_capts_for_img.extend(capt) - bbox_sims_for_img.extend(sim) - - all_bbox_capts.append(bbox_capts_for_img) - if compute_scores: - all_bbox_scores.append(bbox_scores_for_img) - if return_n_best_sims is not None: - all_bbox_sims.append(bbox_sims_for_img) - - outs['bbox_capts'] = all_bbox_capts - if compute_scores: - outs['bbox_scores'] = all_bbox_scores - if return_n_best_sims is not None: - outs['bbox_sims'] = all_bbox_sims - - # Handle traces case (separate caption for each trace) - if traces is not None: - trace_capts = [] - trace_scores = [] if compute_scores else None - - for img_idx in range(bs): - img = imgs[img_idx:img_idx+1] # [1, C, H, W] - trace = traces[img_idx] # List of traces for this image - - # Create alpha mask for this trace - alpha_mask = trace_to_alpha_mask(trace, grid_size) - alpha_mask = alpha_mask.unsqueeze(0).to(device) # [1, crop_size, crop_size] - - alpha_mask_patches = trace_to_alpha_mask(trace, effective_grid_size) - alpha_mask_patches = alpha_mask_patches.unsqueeze(0).to(device) # [1, effective_grid_size, effective_grid_size] - - # Process with AlphaClip using the trace mask - output = self.dino.visual(img, alpha=alpha_mask, return_patches=True) - - # Extract features based on configuration - features = extract_alphaclip_features(output, alpha_mask_patches) - - # Caption the extracted features - ret = self.caption_tokens(features, compute_scores=compute_scores) - if compute_scores: - capt, score = ret - else: - capt = ret - - trace_capts.extend(capt) - if compute_scores: - trace_scores.extend(score) - - outs['trace_capts'] = trace_capts - if compute_scores: - outs['trace_scores'] = trace_scores - - # If no bboxes or traces, do standard processing for other caption types - if bboxes is None and traces is None: - # Standard AlphaClip processing without alpha mask (whole image) - output = self.dino.visual(imgs, return_patches=True) # No alpha parameter = whole image attention - - if get_cls_capt: - # For CLS captions, always use CLS token regardless of config - cls_token = output[:, 0, :] - ret = self.caption_tokens(cls_token, compute_scores=compute_scores) - if compute_scores: - outs['cls_capt'], outs['cls_capt_scores'] = ret - else: - outs['cls_capt'] = ret - - if get_avg_patch_capt: - patch_tokens = output[:, 1:, :] - avg_patch = compute_region_means(patch_tokens, gaussian_img_variance) - ret = self.caption_tokens(avg_patch, compute_scores=compute_scores) - if compute_scores: - outs['avg_patch_capt'], outs['avg_patch_capt_scores'] = ret - else: - outs['avg_patch_capt'] = ret - - # For get_avg_self_attn_capt, get_attn_heads_capt, get_patch_capts, get_register_capts - # AlphaClip doesn't provide self-attention access, so we'll use fallback behavior - if get_avg_self_attn_capt: - # Use average of patch tokens as fallback - patch_tokens = output[:, 1:, :] - avg_self_attn_token = patch_tokens.mean(dim=1) - ret = self.caption_tokens(avg_self_attn_token, compute_scores=compute_scores) - if compute_scores: - outs['avg_self_attn_capt'], outs['avg_self_attn_capt_scores'] = ret - else: - outs['avg_self_attn_capt'] = ret - - if get_attn_heads_capt: - # Use repeated average patch token as fallback for attention heads - patch_tokens = output[:, 1:, :] - avg_patch_token = patch_tokens.mean(dim=1) # [bs, embed_dim] - # Repeat for each attention head - repeated_tokens = avg_patch_token.unsqueeze(1).repeat(1, self.num_attn_heads, 1) # [bs, num_heads, embed_dim] - - ret = self.caption_tokens(repeated_tokens.view(-1, self.embed_dim), compute_scores=compute_scores) - if compute_scores: - attn_heads_capt_unrolled, attn_heads_scores_unrolled = ret - outs['attn_heads_capts'] = [attn_heads_capt_unrolled[i * self.num_attn_heads:(i + 1) * self.num_attn_heads] for i in range(bs)] - outs['attn_heads_scores'] = [attn_heads_scores_unrolled[i * self.num_attn_heads:(i + 1) * self.num_attn_heads] for i in range(bs)] - else: - attn_heads_capt_unrolled = ret - outs['attn_heads_capts'] = [attn_heads_capt_unrolled[i * self.num_attn_heads:(i + 1) * self.num_attn_heads] for i in range(bs)] - - if get_patch_capts: - patch_tokens = output[:, 1:, :] # [bs, num_patches, embed_dim] - n_patches = patch_tokens.shape[1] - - ret = self.caption_tokens(patch_tokens.reshape(-1, self.embed_dim), compute_scores=compute_scores) - if compute_scores: - patch_tokens_capts_unrolled, patch_tokens_scores_unrolled = ret - outs['patch_tokens_capts'] = [patch_tokens_capts_unrolled[i * n_patches:(i + 1) * n_patches] for i in range(bs)] - outs['patch_tokens_scores'] = [patch_tokens_scores_unrolled[i * n_patches:(i + 1) * n_patches] for i in range(bs)] - else: - patch_tokens_capts_unrolled = ret - outs['patch_tokens_capts'] = [patch_tokens_capts_unrolled[i * n_patches:(i + 1) * n_patches] for i in range(bs)] - - if get_register_capts: - # AlphaClip doesn't have register tokens, return empty lists - outs['register_capts'] = [[] for _ in range(bs)] - if compute_scores: - outs['register_scores'] = [[] for _ in range(bs)] - - return outs - - def caption_bboxes(self, imgs, bboxes, capt_type='cls_capt', crop_boxes=False, compute_scores=False): - """ - - capt_type : str either 'avg_self_attn_capt' or 'cls_capt' - """ - device = next(self.parameters()).device - bs = len(imgs) - n_bboxes = bboxes.shape[1] - if not crop_boxes: - crops = process_bboxes(imgs, bboxes, self.image_transforms_no_crop).to(device) - else: - crops = process_bboxes(imgs, bboxes, self.image_transforms).to(device) - - n_batch = n_bboxes - capts = [] - scores = [] - # batching the inference of crops - for i in range(n_batch): - start = i * bs - end = start + bs if i < n_batch - 1 else crops.shape[0] - forward_out = self.forward(crops[start:end], - get_cls_capt=capt_type == 'cls_capt', - get_avg_self_attn_capt=capt_type == 'avg_self_attn_capt') - capts += forward_out[capt_type] - if compute_scores: - scores += forward_out[f"{capt_type}_scores"] - - # rearranging the captions ensuring shape BS x N_BBOXES - capts = [capts[i * n_bboxes:(i + 1) * n_bboxes] for i in range(bs)] - - ret = {'bbox_capts' : capts} - - if compute_scores: - scores = [scores[i * n_bboxes:(i + 1) * n_bboxes] for i in range(bs)] - ret['bbox_scores'] = scores - return ret - - def caption_tokens(self, dino_tokens, project=True, return_n_best_sims=None, compute_scores : bool = False): - - if self.viecap is not None: - if return_n_best_sims: - raise Exception("return_n_best_sims is not supported with viecap") - outs = self.viecap.forward(dino_tokens, compute_scores=compute_scores) - return outs - - if self.clipcap is not None: - if return_n_best_sims: - raise Exception("return_n_best_sims is not supported with clipcap") - outs = self.clipcap.forward(dino_tokens, compute_scores=compute_scores) - return outs - - if self.im_proj is None: - project = False - if self.calculate_argmax_text: - # if calculate_argmax_text we return the argmax of the similarities between tokens and memory without using the decoder - captions = self.im_proj.project(dino_tokens, normalize=self.normalize, return_argmax_text=True, return_n_best_sims=return_n_best_sims) - return captions if compute_scores is False else (captions, [1.0] * len(captions)) # we return a list of 1.0s as scores - if not self.embed_inversion: - # classical decoder forward - if project: - projected_outs = self.im_proj.project(dino_tokens, normalize=self.normalize) - else: - projected_outs = dino_tokens - outs = decoding_batched(self.decoder, projected_outs, compute_scores=compute_scores, decoding_method=self.decoding_method) - else: - # DINOv2 embedding inversion - clip_tokens = revert_transformation(self.im_proj.project(dino_tokens, normalize=self.normalize), A_pinv=self.talk2dino_A_pinv, b=self.talk2dino_b) - outs = decoding_batched(self.decoder, clip_tokens, compute_scores=compute_scores, decoding_method=self.decoding_method) - return outs - - def ctx_cleaner(self, dirty_embeds : torch.Tensor, ctx_embed : torch.Tensor, cleaning_type='orthogonal_projection', alpha=1.0, epsilon=1e-6): - if cleaning_type == 'orthogonal_projection': - #return dirty_embeds - (alpha * (dirty_embeds @ ctx_embed.t() / (torch.norm(ctx_embed, p=2) ** 2))) * ctx_embed - ctx_embed = ctx_embed.unsqueeze(1) # [batch_size, 1, embed_dim] - projection = (dirty_embeds @ ctx_embed.transpose(-1, -2)) / (torch.norm(ctx_embed, dim=-1, keepdim=True) ** 2) - return dirty_embeds - alpha * projection * ctx_embed - if cleaning_type == "contrastive_mask": - ctx_embed = ctx_embed.unsqueeze(1) # [batch_size, 1, embed_dim] - ctx_embed_norm = torch.norm(ctx_embed, p=2, dim=2, keepdim=True) + epsilon - mask = 1 - (ctx_embed / ctx_embed_norm) - specific_embedding = dirty_embeds * mask - return specific_embedding - - def analyze_feature_compatibility(self, imgs, analyze_layers=True): - """ - Analyze compatibility between different layer features and textual embeddings. - - Args: - imgs: Input images tensor - analyze_layers: Whether to compare layer3 vs layer4 features - - Returns: - Dictionary with compatibility metrics - """ - device = imgs.device - results = {} - - if self.dino is not None and not (hasattr(self.dino, 'visual') and hasattr(self.dino.visual, 'attnpool')): - print("Feature compatibility analysis only available for RegionCLIP ResNet models") - return results - - original_patch_size = self.patch_size - - with torch.no_grad(): - # Test both layer3 and layer4 if requested - layer_configs = [] - if analyze_layers: - layer_configs = [ - {'patch_size': 16, 'use_layer3': True, 'name': 'layer3'}, - {'patch_size': 32, 'use_layer3': False, 'name': 'layer4'} - ] - else: - # Use current configuration - use_layer3 = (self.patch_size == 16) - layer_configs = [{'patch_size': self.patch_size, 'use_layer3': use_layer3, 'name': f'layer{"3" if use_layer3 else "4"}'}] - - for config in layer_configs: - self.patch_size = config['patch_size'] - - # Get features from specified layer - dino_outs = self.dino.visual.forward_return_spatial_feats(imgs, use_layer3=config['use_layer3']) - features = dino_outs['x_norm_patchtokens'] # [B, num_patches, embed_dim] - cls_features = dino_outs['x_norm_clstoken'] # [B, embed_dim] - - layer_results = { - 'spatial_resolution': f"{int(features.shape[1]**0.5)}x{int(features.shape[1]**0.5)}", - 'embed_dim': features.shape[-1], - 'num_patches': features.shape[1] - } - - if self.im_proj is not None: - # Analyze patch features - patch_mean = features.mean(dim=1) # [B, embed_dim] - average across patches - projected_patches = self.im_proj.project(patch_mean, normalize=True) - - # Analyze CLS features - if cls_features is not None: - projected_cls = self.im_proj.project(cls_features, normalize=True) - - # Measure similarity to memory bank - cls_sims = torch.mm(projected_cls, self.im_proj.embs_dataset.T) - patch_sims = torch.mm(projected_patches, self.im_proj.embs_dataset.T) - - layer_results.update({ - 'cls_max_similarity': torch.max(cls_sims, dim=1)[0].mean().item(), - 'cls_mean_similarity': torch.mean(cls_sims).item(), - 'patch_max_similarity': torch.max(patch_sims, dim=1)[0].mean().item(), - 'patch_mean_similarity': torch.mean(patch_sims).item(), - 'cls_feature_norm': torch.norm(cls_features, dim=1).mean().item(), - 'patch_feature_norm': torch.norm(patch_mean, dim=1).mean().item(), - 'cls_projected_norm': torch.norm(projected_cls, dim=1).mean().item(), - 'patch_projected_norm': torch.norm(projected_patches, dim=1).mean().item() - }) - - # Feature distribution analysis - feature_std = torch.std(features.reshape(-1, features.shape[-1]), dim=0).mean().item() - projection_std = torch.std(projected_patches, dim=0).mean().item() - - layer_results.update({ - 'feature_variability': feature_std, - 'projection_variability': projection_std, - 'projection_efficiency': projection_std / (feature_std + 1e-8) # How well projection preserves variability - }) - - results[config['name']] = layer_results - - # Restore original configuration - self.patch_size = original_patch_size - - return results - - def print_compatibility_analysis(self, analysis_results): - """Print formatted compatibility analysis results.""" - print("\n" + "="*60) - print("REGIONCLIP LAYER COMPATIBILITY ANALYSIS") - print("="*60) - - for layer_name, metrics in analysis_results.items(): - print(f"\n{layer_name.upper()} FEATURES:") - print("-" * 30) - - # Basic info - print(f"Spatial Resolution: {metrics['spatial_resolution']}") - print(f"Embedding Dimension: {metrics['embed_dim']}") - print(f"Number of Patches: {metrics['num_patches']}") - - if 'cls_max_similarity' in metrics: - print(f"\nSimilarity to Text Memory Bank:") - print(f" CLS Token - Max: {metrics['cls_max_similarity']:.4f}, Mean: {metrics['cls_mean_similarity']:.4f}") - print(f" Patch Avg - Max: {metrics['patch_max_similarity']:.4f}, Mean: {metrics['patch_mean_similarity']:.4f}") - - print(f"\nFeature Norms:") - print(f" CLS Features: {metrics['cls_feature_norm']:.4f}") - print(f" Patch Features: {metrics['patch_feature_norm']:.4f}") - print(f" CLS Projected: {metrics['cls_projected_norm']:.4f}") - print(f" Patch Projected: {metrics['patch_projected_norm']:.4f}") - - print(f"\nProjection Quality:") - print(f" Feature Variability: {metrics['feature_variability']:.4f}") - print(f" Projection Variability: {metrics['projection_variability']:.4f}") - print(f" Projection Efficiency: {metrics['projection_efficiency']:.4f}") - - if len(analysis_results) == 2: - layer3_metrics = analysis_results.get('layer3', {}) - layer4_metrics = analysis_results.get('layer4', {}) - - if 'cls_max_similarity' in layer3_metrics and 'cls_max_similarity' in layer4_metrics: - print(f"\n{'COMPARISON (Layer3 vs Layer4)':^60}") - print("-" * 60) - - # Similarity comparison - l3_sim = layer3_metrics['patch_max_similarity'] - l4_sim = layer4_metrics['patch_max_similarity'] - better_sim = "Layer3" if l3_sim > l4_sim else "Layer4" - print(f"Better Text Similarity: {better_sim} ({max(l3_sim, l4_sim):.4f} vs {min(l3_sim, l4_sim):.4f})") - - # Projection efficiency comparison - l3_eff = layer3_metrics['projection_efficiency'] - l4_eff = layer4_metrics['projection_efficiency'] - better_eff = "Layer3" if l3_eff > l4_eff else "Layer4" - print(f"Better Projection Efficiency: {better_eff} ({max(l3_eff, l4_eff):.4f} vs {min(l3_eff, l4_eff):.4f})") - - # Spatial resolution comparison - print(f"Spatial Resolution: Layer3 ({layer3_metrics['spatial_resolution']}) vs Layer4 ({layer4_metrics['spatial_resolution']})") - - - def __len__(self): - return sum(p.numel() for p in self.parameters()) \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/__init__.py b/src/proxyclip/open_clip_proxy/__init__.py deleted file mode 100644 index 23856a3f13d8ae592b343131345108b3432e43a3..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from .coca_model import CoCa -from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_loss -from .factory import list_models, add_model_config, get_model_config, load_checkpoint -from .loss import ClipLoss, DistillClipLoss, CoCaLoss -from .model import CLIP, CustomTextCLIP, CLIPTextCfg, CLIPVisionCfg, \ - convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype, get_input_dtype, \ - get_model_tokenize_cfg, get_model_preprocess_cfg, set_model_preprocess_cfg -from .openai import load_openai_model, list_openai_models -from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model, \ - get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained -from .push_to_hf_hub import push_pretrained_to_hf_hub, push_to_hf_hub -from .tokenizer import SimpleTokenizer, tokenize, decode -from .transform import image_transform, AugmentationCfg -from .zero_shot_classifier import build_zero_shot_classifier, build_zero_shot_classifier_legacy -from .zero_shot_metadata import OPENAI_IMAGENET_TEMPLATES, SIMPLE_IMAGENET_TEMPLATES, IMAGENET_CLASSNAMES diff --git a/src/proxyclip/open_clip_proxy/big_vision.py b/src/proxyclip/open_clip_proxy/big_vision.py deleted file mode 100644 index 0d7eaf3fa543dba7d7517ac566c6364a5a893796..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/big_vision.py +++ /dev/null @@ -1,136 +0,0 @@ -import torch -import numpy as np - -from .model import CustomTextCLIP -from .transformer import TextTransformer, Transformer - - -@torch.no_grad() -def load_big_vision_weights(model: CustomTextCLIP, checkpoint_path: str): - """ Load weights from .npz checkpoints for official Google big_vision image-text models - - Currently the SigLIP source models are supported and a CustomTextCLIP destination model - w/ timm image encoder. - """ - from timm.layers import resample_patch_embed, resample_abs_pos_embed - - def _n2p(w, t=True): - if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1: - w = w.flatten() - if t: - if w.ndim == 4: - w = w.transpose([3, 2, 0, 1]) - elif w.ndim == 3: - w = w.transpose([2, 0, 1]) - elif w.ndim == 2: - w = w.transpose([1, 0]) - return torch.from_numpy(w) - - w = np.load(checkpoint_path) - interpolation = 'bilinear' - antialias = False - - def _convert_timm_img(module, prefix): - embed_conv_w = _n2p(w[f'{prefix}embedding/kernel']) - if embed_conv_w.shape[-2:] != module.patch_embed.proj.weight.shape[-2:]: - embed_conv_w = resample_patch_embed( - embed_conv_w, - module.patch_embed.proj.weight.shape[-2:], - interpolation=interpolation, - antialias=antialias, - verbose=True, - ) - module.patch_embed.proj.weight.copy_(embed_conv_w) - module.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) - - if module.cls_token is not None: - module.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) - - pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False) - if pos_embed_w.shape != module.pos_embed.shape: - assert False, f'{pos_embed_w.shape}, {module.pos_embed.shape}' - num_prefix_tokens = 0 if getattr(module, 'no_embed_class', False) else getattr(module, 'num_prefix_tokens', 1) - pos_embed_w = resample_abs_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, - new_size=module.patch_embed.grid_size, - num_prefix_tokens=num_prefix_tokens, - interpolation=interpolation, - antialias=antialias, - verbose=True, - ) - module.pos_embed.copy_(pos_embed_w) - - mha_sub, b_sub, ln1_sub = (0, 0, 1) - for i, block in enumerate(module.blocks.children()): - block_prefix = f'{prefix}Transformer/encoderblock_{i}/' - mha_prefix = block_prefix + f'MultiHeadDotProductAttention_{mha_sub}/' - block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) - block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - block.attn.qkv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) - block.attn.qkv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) - block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - for r in range(2): - getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/kernel'])) - getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_{b_sub}/Dense_{r}/bias'])) - block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/scale'])) - block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_{ln1_sub}/bias'])) - - module.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) - module.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) - - if module.attn_pool is not None: - block_prefix = f'{prefix}MAPHead_0/' - mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' - module.attn_pool.latent.copy_(_n2p(w[f'{block_prefix}probe'], t=False)) - module.attn_pool.q.weight.copy_(_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T) - module.attn_pool.q.bias.copy_(_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1)) - module.attn_pool.kv.weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('key', 'value')])) - module.attn_pool.kv.bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('key', 'value')])) - module.attn_pool.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - module.attn_pool.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - module.attn_pool.norm.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) - module.attn_pool.norm.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - for r in range(2): - getattr(module.attn_pool.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/kernel'])) - getattr(module.attn_pool.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_{r}/bias'])) - - def _convert_openclip_transformer(module: Transformer, prefix): - for i, block in enumerate(module.resblocks.children()): - block_prefix = f'{prefix}encoderblock_{i}/' - mha_prefix = block_prefix + f'MultiHeadDotProductAttention_0/' - block.ln_1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) - block.ln_1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) - block.attn.in_proj_weight.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')])) - block.attn.in_proj_bias.copy_(torch.cat([ - _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')])) - block.attn.out_proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) - block.attn.out_proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) - block.ln_2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/scale'])) - block.ln_2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_1/bias'])) - block.mlp.c_fc.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/kernel'])) - block.mlp.c_fc.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_0/bias'])) - block.mlp.c_proj.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/kernel'])) - block.mlp.c_proj.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_0/Dense_1/bias'])) - - def _convert_openclip_txt(module: TextTransformer, prefix): - module.token_embedding.weight.copy_(_n2p(w[f'{prefix}Embed_0/embedding'], t=False)) - pos_embed_w = _n2p(w[f'{prefix}pos_embedding'], t=False).squeeze(0) - module.positional_embedding.copy_(pos_embed_w) - _convert_openclip_transformer(module.transformer, prefix=prefix + 'Encoder_0/') - module.ln_final.weight.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/scale'])) - module.ln_final.bias.copy_(_n2p(w[f'{prefix}Encoder_0/encoder_norm/bias'])) - module.text_projection.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) - module.text_projection.bias.copy_(_n2p(w[f'{prefix}head/bias'])) - - _convert_timm_img(model.visual.trunk, 'params/img/') - _convert_openclip_txt(model.text, 'params/txt/') - model.logit_bias.copy_(_n2p(w['params/b'])[0]) - model.logit_scale.copy_(_n2p(w['params/t'])[0]) - - diff --git a/src/proxyclip/open_clip_proxy/bpe_simple_vocab_16e6.txt.gz b/src/proxyclip/open_clip_proxy/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/src/proxyclip/open_clip_proxy/coca_model.py b/src/proxyclip/open_clip_proxy/coca_model.py deleted file mode 100644 index 272b2cc065a774307802634333c166327af6cf90..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/coca_model.py +++ /dev/null @@ -1,478 +0,0 @@ -from typing import Optional - -import torch -from torch import nn -from torch.nn import functional as F -import numpy as np -from dataclasses import dataclass - -from .transformer import ( - LayerNormFp32, - LayerNorm, - QuickGELU, - MultimodalTransformer, -) -from .model import CLIPTextCfg, CLIPVisionCfg, _build_vision_tower, _build_text_tower - -try: - from transformers import ( - BeamSearchScorer, - LogitsProcessorList, - TopPLogitsWarper, - TopKLogitsWarper, - RepetitionPenaltyLogitsProcessor, - MinLengthLogitsProcessor, - MaxLengthCriteria, - StoppingCriteriaList - ) - - GENERATION_TYPES = { - "top_k": TopKLogitsWarper, - "top_p": TopPLogitsWarper, - "beam_search": "beam_search" - } - _has_transformers = True -except ImportError as e: - GENERATION_TYPES = { - "top_k": None, - "top_p": None, - "beam_search": "beam_search" - } - _has_transformers = False - - -@dataclass -class MultimodalCfg(CLIPTextCfg): - mlp_ratio: int = 4 - dim_head: int = 64 - heads: int = 8 - n_queries: int = 256 - attn_pooler_heads: int = 8 - - -def _build_text_decoder_tower( - embed_dim, - multimodal_cfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, -): - multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg - act_layer = QuickGELU if quick_gelu else nn.GELU - norm_layer = ( - LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm - ) - - decoder = MultimodalTransformer( - context_length=multimodal_cfg.context_length, - width=multimodal_cfg.width, - heads=multimodal_cfg.heads, - layers=multimodal_cfg.layers, - ls_init_value=multimodal_cfg.ls_init_value, - output_dim=embed_dim, - act_layer=act_layer, - norm_layer=norm_layer, - ) - - return decoder - - -class CoCa(nn.Module): - def __init__( - self, - embed_dim, - multimodal_cfg: MultimodalCfg, - text_cfg: CLIPTextCfg, - vision_cfg: CLIPVisionCfg, - quick_gelu: bool = False, - init_logit_scale: float = np.log(1 / 0.07), - init_logit_bias: Optional[float] = None, - cast_dtype: Optional[torch.dtype] = None, - pad_id: int = 0, - ): - super().__init__() - multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg - text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg - vision_cfg = CLIPVisionCfg(**vision_cfg) if isinstance(vision_cfg, dict) else vision_cfg - - self.text = _build_text_tower( - embed_dim=embed_dim, - text_cfg=text_cfg, - quick_gelu=quick_gelu, - cast_dtype=cast_dtype, - ) - - vocab_size = ( - text_cfg.vocab_size # for hf models - if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None - else text_cfg.vocab_size - ) - - self.visual = _build_vision_tower( - embed_dim=embed_dim, - vision_cfg=vision_cfg, - quick_gelu=quick_gelu, - cast_dtype=cast_dtype, - ) - - self.text_decoder = _build_text_decoder_tower( - vocab_size, - multimodal_cfg=multimodal_cfg, - quick_gelu=quick_gelu, - cast_dtype=cast_dtype, - ) - - self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) - if init_logit_bias is not None: - self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) - else: - self.logit_bias = None - self.pad_id = pad_id - - self.context_length = multimodal_cfg.context_length - - @torch.jit.ignore - def set_grad_checkpointing(self, enable: bool = True): - self.visual.set_grad_checkpointing(enable) - self.text.set_grad_checkpointing(enable) - self.text_decoder.set_grad_checkpointing(enable) - - def _encode_image(self, images, normalize: bool = True): - image_latent, tokens_embs = self.visual(images) - image_latent = F.normalize(image_latent, dim=-1) if normalize else image_latent - return image_latent, tokens_embs - - def _encode_text(self, text, normalize: bool = True): - text_latent, token_emb = self.text(text) - text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent - return text_latent, token_emb - - def encode_image(self, images, normalize: bool = True): - image_latent, _ = self._encode_image(images, normalize=normalize) - return image_latent - - def encode_text(self, text, normalize: bool = True): - text_latent, _ = self._encode_text(text, normalize=normalize) - return text_latent - - def forward( - self, - image, - text: Optional[torch.Tensor] = None, - image_latent: Optional[torch.Tensor] = None, - image_embs: Optional[torch.Tensor] = None, - ): - if image_latent is None or image_embs is None: - image_latent, image_embs = self._encode_image(image) - - if text is None: - return {"image_features": image_latent, "image_embs": image_embs} - - text_latent, token_embs = self._encode_text(text) - - # TODO: add assertion to avoid bugs? - labels = text[:, -token_embs.shape[1]:] - - logits = self.text_decoder(image_embs, token_embs) - out_dict = { - "image_features": image_latent, - "text_features": text_latent, - "logits": logits, - "labels": labels, - "logit_scale": self.logit_scale.exp() - } - if self.logit_bias is not None: - out_dict["logit_bias"] = self.logit_bias - return out_dict - - def generate( - self, - image, - text=None, - seq_len=30, - max_seq_len=77, - temperature=1., - generation_type="beam_search", - top_p=0.1, # keep tokens in the 1 - top_p quantile - top_k=1, # keeps the top_k most probable tokens - pad_token_id=None, - eos_token_id=None, - sot_token_id=None, - num_beams=6, - num_beam_groups=3, - min_seq_len=5, - stopping_criteria=None, - repetition_penalty=1.0, - fixed_output_length=False # if True output.shape == (batch_size, seq_len) - ): - # taking many ideas and components from HuggingFace GenerationMixin - # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation - assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`." - assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len" - - with torch.no_grad(): - sot_token_id = 49406 if sot_token_id is None else sot_token_id - eos_token_id = 49407 if eos_token_id is None else eos_token_id - pad_token_id = self.pad_id if pad_token_id is None else pad_token_id - logit_processor = LogitsProcessorList( - [ - MinLengthLogitsProcessor(min_seq_len, eos_token_id), - RepetitionPenaltyLogitsProcessor(repetition_penalty), - ] - ) - - if stopping_criteria is None: - stopping_criteria = [MaxLengthCriteria(max_length=seq_len)] - - stopping_criteria = StoppingCriteriaList( - stopping_criteria - ) - - device = image.device - - if generation_type == "beam_search": - output = self._generate_beamsearch( - image_inputs=image, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - sot_token_id=sot_token_id, - num_beams=num_beams, - num_beam_groups=num_beam_groups, - min_seq_len=min_seq_len, - stopping_criteria=stopping_criteria, - logit_processor=logit_processor, - ) - if fixed_output_length and output.shape[1] < seq_len: - return torch.cat( - (output, torch.ones(output.shape[0], seq_len-output.shape[1], device=device, dtype=output.dtype) * self.pad_id), - dim=1 - ) - return output - - elif generation_type == "top_p": - logit_warper = GENERATION_TYPES[generation_type](top_p) - elif generation_type == "top_k": - logit_warper = GENERATION_TYPES[generation_type](top_k) - else: - raise ValueError( - f"generation_type has to be one of " - f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}." - ) - - image_latent, image_embs = self._encode_image(image) - - if text is None: - text = torch.ones((image.shape[0], 1), device=device, dtype=torch.long) * sot_token_id - - was_training = self.training - num_dims = len(text.shape) - - if num_dims == 1: - text = text[None, :] - - cur_len = text.shape[1] - self.eval() - out = text - - while True: - x = out[:, -max_seq_len:] - cur_len = x.shape[1] - logits = self(image, x, image_latent=image_latent, image_embs=image_embs)["logits"][:, -1] - mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id) - sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id - - if mask.all(): - if not fixed_output_length: - break - else: - logits = logits[~mask, :] - filtered_logits = logit_processor(x[~mask, :], logits) - filtered_logits = logit_warper(x[~mask, :], filtered_logits) - probs = F.softmax(filtered_logits / temperature, dim=-1) - - if (cur_len + 1 == seq_len): - sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id - else: - sample[~mask, :] = torch.multinomial(probs, 1) - - out = torch.cat((out, sample), dim=-1) - - cur_len += 1 - - if stopping_criteria(out, None): - break - - if num_dims == 1: - out = out.squeeze(0) - - self.train(was_training) - return out - - def _generate_beamsearch( - self, - image_inputs, - pad_token_id=None, - eos_token_id=None, - sot_token_id=None, - num_beams=6, - num_beam_groups=3, - min_seq_len=5, - stopping_criteria=None, - logit_processor=None, - logit_warper=None, - ): - device = image_inputs.device - batch_size = image_inputs.shape[0] - image_inputs = torch.repeat_interleave(image_inputs, num_beams, dim=0) - image_latent, image_embs = self._encode_image(image_inputs) - - input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long) - input_ids = input_ids * sot_token_id - beam_scorer = BeamSearchScorer( - batch_size=batch_size, - num_beams=num_beams, - device=device, - num_beam_groups=num_beam_groups, - ) - # instantiate logits processors - logits_processor = ( - LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)]) - if logit_processor is None - else logit_processor - ) - - num_beams = beam_scorer.num_beams - num_beam_groups = beam_scorer.num_beam_groups - num_sub_beams = num_beams // num_beam_groups - batch_size = len(beam_scorer._beam_hyps) // num_beam_groups - batch_beam_size, cur_len = input_ids.shape - beam_indices = None - - if num_beams * batch_size != batch_beam_size: - raise ValueError( - f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." - ) - - beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device) - # initialise score of first beam of each group with 0 and the rest with 1e-9. This ensures that the beams in - # the same group don't produce same tokens everytime. - beam_scores[:, ::num_sub_beams] = 0 - beam_scores = beam_scores.view((batch_size * num_beams,)) - - while True: - - # predicted tokens in cur_len step - current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device) - - # indices which will form the beams in the next time step - reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device) - - # do one decoder step on all beams of all sentences in batch - model_inputs = prepare_inputs_for_generation(input_ids=input_ids, image_inputs=image_inputs) - outputs = self( - model_inputs['images'], - model_inputs['text'], - image_latent=image_latent, - image_embs=image_embs - ) - - for beam_group_idx in range(num_beam_groups): - group_start_idx = beam_group_idx * num_sub_beams - group_end_idx = min(group_start_idx + num_sub_beams, num_beams) - group_size = group_end_idx - group_start_idx - - # indices of beams of current group among all sentences in batch - batch_group_indices = [] - - for batch_idx in range(batch_size): - batch_group_indices.extend( - [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)] - ) - group_input_ids = input_ids[batch_group_indices] - - # select outputs of beams of currentg group only - next_token_logits = outputs['logits'][batch_group_indices, -1, :] - vocab_size = next_token_logits.shape[-1] - - next_token_scores_processed = logits_processor( - group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx - ) - next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1) - next_token_scores = next_token_scores.expand_as(next_token_scores_processed) - - # reshape for beam search - next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size) - - next_token_scores, next_tokens = torch.topk( - next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True - ) - - next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor") - next_tokens = next_tokens % vocab_size - - # stateless - process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None - beam_outputs = beam_scorer.process( - group_input_ids, - next_token_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - beam_indices=process_beam_indices, - group_index=beam_group_idx, - ) - beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"] - beam_next_tokens = beam_outputs["next_beam_tokens"] - beam_idx = beam_outputs["next_beam_indices"] - - input_ids[batch_group_indices] = group_input_ids[beam_idx] - group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) - current_tokens[batch_group_indices] = group_input_ids[:, -1] - - # (beam_idx // group_size) -> batch_idx - # (beam_idx % group_size) -> offset of idx inside the group - reordering_indices[batch_group_indices] = ( - num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size) - ) - - input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1) - - # increase cur_len - cur_len = cur_len + 1 - if beam_scorer.is_done or stopping_criteria(input_ids, None): - break - - final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None - sequence_outputs = beam_scorer.finalize( - input_ids, - beam_scores, - next_tokens, - next_indices, - pad_token_id=pad_token_id, - eos_token_id=eos_token_id, - max_length=stopping_criteria.max_length, - beam_indices=final_beam_indices, - ) - return sequence_outputs['sequences'] - - -def prepare_inputs_for_generation(input_ids, image_inputs, past=None, **kwargs): - if past: - input_ids = input_ids[:, -1].unsqueeze(-1) - - attention_mask = kwargs.get("attention_mask", None) - position_ids = kwargs.get("position_ids", None) - - if attention_mask is not None and position_ids is None: - # create position_ids on the fly for batch generation - position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) - else: - position_ids = None - return { - "text": input_ids, - "images": image_inputs, - "past_key_values": past, - "position_ids": position_ids, - "attention_mask": attention_mask, - } diff --git a/src/proxyclip/open_clip_proxy/constants.py b/src/proxyclip/open_clip_proxy/constants.py deleted file mode 100644 index 599c48c03f7a1ed97af20cbc482db27984514622..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/constants.py +++ /dev/null @@ -1,6 +0,0 @@ -OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) -OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) -IMAGENET_MEAN = (0.485, 0.456, 0.406) -IMAGENET_STD = (0.229, 0.224, 0.225) -INCEPTION_MEAN = (0.5, 0.5, 0.5) -INCEPTION_STD = (0.5, 0.5, 0.5) diff --git a/src/proxyclip/open_clip_proxy/factory.py b/src/proxyclip/open_clip_proxy/factory.py deleted file mode 100644 index cf62d5a1c9e8556216f33b002eb94260218a6f8c..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/factory.py +++ /dev/null @@ -1,460 +0,0 @@ -import json -import logging -import os -import re -from copy import deepcopy -from dataclasses import asdict -from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union - -import torch - -from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .model import CLIP, CustomTextCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\ - resize_pos_embed, get_cast_dtype, resize_text_pos_embed, set_model_preprocess_cfg -from .coca_model import CoCa -from .loss import ClipLoss, DistillClipLoss, CoCaLoss, SigLipLoss -from .openai import load_openai_model -from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained,\ - list_pretrained_tags_by_model, download_pretrained_from_hf -from .transform import image_transform_v2, AugmentationCfg, PreprocessCfg, merge_preprocess_dict, merge_preprocess_kwargs -from .tokenizer import HFTokenizer, SimpleTokenizer, DEFAULT_CONTEXT_LENGTH - -HF_HUB_PREFIX = 'hf-hub:' -_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"] -_MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs - - -def _natural_key(string_): - return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] - - -def _rescan_model_configs(): - global _MODEL_CONFIGS - - config_ext = ('.json',) - config_files = [] - for config_path in _MODEL_CONFIG_PATHS: - if config_path.is_file() and config_path.suffix in config_ext: - config_files.append(config_path) - elif config_path.is_dir(): - for ext in config_ext: - config_files.extend(config_path.glob(f'*{ext}')) - - for cf in config_files: - with open(cf, 'r') as f: - model_cfg = json.load(f) - if all(a in model_cfg for a in ('embed_dim', 'vision_cfg', 'text_cfg')): - _MODEL_CONFIGS[cf.stem] = model_cfg - - _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))} - - -_rescan_model_configs() # initial populate of model config registry - - -def list_models(): - """ enumerate available model architectures based on config files """ - return list(_MODEL_CONFIGS.keys()) - - -def add_model_config(path): - """ add model config path or file and update registry """ - if not isinstance(path, Path): - path = Path(path) - _MODEL_CONFIG_PATHS.append(path) - _rescan_model_configs() - - -def get_model_config(model_name): - if model_name in _MODEL_CONFIGS: - return deepcopy(_MODEL_CONFIGS[model_name]) - else: - return None - - -def _get_hf_config(model_id, cache_dir=None): - config_path = download_pretrained_from_hf(model_id, filename='open_clip_config.json', cache_dir=cache_dir) - with open(config_path, 'r', encoding='utf-8') as f: - config = json.load(f) - return config - - -def get_tokenizer( - model_name: str = '', - context_length: Optional[int] = None, - **kwargs, -): - if model_name.startswith(HF_HUB_PREFIX): - model_name = model_name[len(HF_HUB_PREFIX):] - try: - config = _get_hf_config(model_name)['model_cfg'] - except Exception: - tokenizer = HFTokenizer( - model_name, - context_length=context_length or DEFAULT_CONTEXT_LENGTH, - **kwargs, - ) - return tokenizer - else: - config = get_model_config(model_name) - assert config is not None, f"No valid model config found for {model_name}." - - text_config = config.get('text_cfg', {}) - if 'tokenizer_kwargs' in text_config: - tokenizer_kwargs = dict(text_config['tokenizer_kwargs'], **kwargs) - else: - tokenizer_kwargs = kwargs - - if context_length is None: - context_length = text_config.get('context_length', DEFAULT_CONTEXT_LENGTH) - - if 'hf_tokenizer_name' in text_config: - tokenizer = HFTokenizer( - text_config['hf_tokenizer_name'], - context_length=context_length, - **tokenizer_kwargs, - ) - else: - tokenizer = SimpleTokenizer( - context_length=context_length, - **tokenizer_kwargs, - ) - - return tokenizer - - -def load_state_dict(checkpoint_path: str, map_location='cpu'): - checkpoint = torch.load(checkpoint_path, map_location=map_location) - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - elif isinstance(checkpoint, torch.jit.ScriptModule): - state_dict = checkpoint.state_dict() - for key in ["input_resolution", "context_length", "vocab_size"]: - state_dict.pop(key, None) - else: - state_dict = checkpoint - if next(iter(state_dict.items()))[0].startswith('module'): - state_dict = {k[7:]: v for k, v in state_dict.items()} - return state_dict - - -def load_checkpoint(model, checkpoint_path, strict=True): - if Path(checkpoint_path).suffix in ('.npz', '.npy'): - from .big_vision import load_big_vision_weights - load_big_vision_weights(model, checkpoint_path) - return {} - - state_dict = load_state_dict(checkpoint_path) - # detect old format and make compatible with new format - if 'positional_embedding' in state_dict and not hasattr(model, 'positional_embedding'): - state_dict = convert_to_custom_text_state_dict(state_dict) - # If loading a non-SigLIP model for SigLIP training. See https://github.com/mlfoundations/open_clip/issues/712 - if 'logit_bias' not in state_dict and model.logit_bias is not None: - state_dict["logit_bias"] = torch.zeros_like(state_dict["logit_scale"]) - # Certain text transformers no longer expect position_ids after transformers==4.31 - position_id_key = 'text.transformer.embeddings.position_ids' - if position_id_key in state_dict and not hasattr(model, position_id_key): - del state_dict[position_id_key] - resize_pos_embed(state_dict, model) - resize_text_pos_embed(state_dict, model) - incompatible_keys = model.load_state_dict(state_dict, strict=strict) - return incompatible_keys - - -def create_model( - model_name: str, - pretrained: Optional[str] = None, - precision: str = 'fp32', - device: Union[str, torch.device] = 'cpu', - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_patch_dropout: Optional[float] = None, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - force_preprocess_cfg: Optional[Dict[str, Any]] = None, - pretrained_image: bool = False, - pretrained_hf: bool = True, - cache_dir: Optional[str] = None, - output_dict: Optional[bool] = None, - require_pretrained: bool = False, - **model_kwargs, -): - force_preprocess_cfg = force_preprocess_cfg or {} - preprocess_cfg = asdict(PreprocessCfg()) - has_hf_hub_prefix = model_name.startswith(HF_HUB_PREFIX) - if has_hf_hub_prefix: - model_id = model_name[len(HF_HUB_PREFIX):] - checkpoint_path = download_pretrained_from_hf(model_id, cache_dir=cache_dir) - config = _get_hf_config(model_id, cache_dir) - preprocess_cfg = merge_preprocess_dict(preprocess_cfg, config['preprocess_cfg']) - model_cfg = config['model_cfg'] - pretrained_hf = False # override, no need to load original HF text weights - else: - model_name = model_name.replace('/', '-') # for callers using old naming with / in ViT names - checkpoint_path = None - model_cfg = None - - if isinstance(device, str): - device = torch.device(device) - - if pretrained and pretrained.lower() == 'openai': - logging.info(f'Loading pretrained {model_name} from OpenAI.') - model = load_openai_model( - model_name, - precision=precision, - device=device, - cache_dir=cache_dir, - ) - else: - model_cfg = model_cfg or get_model_config(model_name) - if model_cfg is not None: - logging.info(f'Loaded {model_name} model config.') - else: - logging.error(f'Model config for {model_name} not found; available models {list_models()}.') - raise RuntimeError(f'Model config for {model_name} not found.') - - if force_quick_gelu: - # override for use of QuickGELU on non-OpenAI transformer models - model_cfg["quick_gelu"] = True - - if force_patch_dropout is not None: - # override the default patch dropout value - model_cfg["vision_cfg"]["patch_dropout"] = force_patch_dropout - - if force_image_size is not None: - # override model config's image size - model_cfg["vision_cfg"]["image_size"] = force_image_size - - is_timm_model = 'timm_model_name' in model_cfg.get('vision_cfg', {}) - if pretrained_image: - if is_timm_model: - # pretrained weight loading for timm models set via vision_cfg - model_cfg['vision_cfg']['timm_model_pretrained'] = True - else: - assert False, 'pretrained image towers currently only supported for timm models' - - # cast_dtype set for fp16 and bf16 (manual mixed-precision), not set for 'amp' or 'pure' modes - cast_dtype = get_cast_dtype(precision) - is_hf_model = 'hf_model_name' in model_cfg.get('text_cfg', {}) - if is_hf_model: - # load pretrained weights for HF text model IFF no CLIP weights being loaded - model_cfg['text_cfg']['hf_model_pretrained'] = pretrained_hf and not pretrained - custom_text = model_cfg.pop('custom_text', False) or force_custom_text or is_hf_model - - model_cfg = dict(model_cfg, **model_kwargs) # merge cfg dict w/ kwargs (kwargs overrides cfg) - if custom_text: - if "multimodal_cfg" in model_cfg: - model = CoCa(**model_cfg, cast_dtype=cast_dtype) - else: - model = CustomTextCLIP(**model_cfg, cast_dtype=cast_dtype) - else: - model = CLIP(**model_cfg, cast_dtype=cast_dtype) - - if precision in ("fp16", "bf16"): - dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 - # manual mixed precision that matches original OpenAI behaviour - if is_timm_model: - # FIXME this is a bit janky, create timm based model in low-precision and - # then cast only LayerNormFp32 instances back to float32 so they don't break. - # Why? The convert_weights_to_lp fn only works with native models. - model.to(device=device, dtype=dtype) - from .transformer import LayerNormFp32 - - def _convert_ln(m): - if isinstance(m, LayerNormFp32): - m.weight.data = m.weight.data.to(torch.float32) - m.bias.data = m.bias.data.to(torch.float32) - model.apply(_convert_ln) - else: - model.to(device=device) - convert_weights_to_lp(model, dtype=dtype) - elif precision in ("pure_fp16", "pure_bf16"): - dtype = torch.float16 if 'fp16' in precision else torch.bfloat16 - model.to(device=device, dtype=dtype) - else: - model.to(device=device) - - pretrained_loaded = False - if pretrained: - checkpoint_path = '' - pretrained_cfg = get_pretrained_cfg(model_name, pretrained) - if pretrained_cfg: - checkpoint_path = download_pretrained(pretrained_cfg, cache_dir=cache_dir) - preprocess_cfg = merge_preprocess_dict(preprocess_cfg, pretrained_cfg) - elif os.path.exists(pretrained): - checkpoint_path = pretrained - - if checkpoint_path: - logging.info(f'Loading pretrained {model_name} weights ({pretrained}).') - load_checkpoint(model, checkpoint_path) - else: - error_str = ( - f'Pretrained weights ({pretrained}) not found for model {model_name}.' - f' Available pretrained tags ({list_pretrained_tags_by_model(model_name)}.') - logging.warning(error_str) - raise RuntimeError(error_str) - pretrained_loaded = True - elif has_hf_hub_prefix: - logging.info(f'Loading pretrained {model_name} weights ({checkpoint_path}).') - load_checkpoint(model, checkpoint_path) - pretrained_loaded = True - - if require_pretrained and not pretrained_loaded: - # callers of create_model_from_pretrained always expect pretrained weights - raise RuntimeError( - f'Pretrained weights were required for (model: {model_name}, pretrained: {pretrained}) but not loaded.') - - if output_dict and hasattr(model, "output_dict"): - model.output_dict = True - - if jit: - model = torch.jit.script(model) - - # set image preprocessing configuration in model attributes for convenience - if getattr(model.visual, 'image_size', None) is not None: - # use image_size set on model creation (via config or force_image_size arg) - force_preprocess_cfg['size'] = model.visual.image_size - set_model_preprocess_cfg(model, merge_preprocess_dict(preprocess_cfg, force_preprocess_cfg)) - - return model - - -def create_loss(args): - if args.distill: - return DistillClipLoss( - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod, - ) - elif "coca" in args.model.lower(): - return CoCaLoss( - caption_loss_weight=args.coca_caption_loss_weight, - clip_loss_weight=args.coca_contrastive_loss_weight, - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod, - ) - elif args.siglip: - assert not args.horovod, "Horovod not currently supported for SigLip" - return SigLipLoss( - rank=args.rank, - world_size=args.world_size, - ) - return ClipLoss( - local_loss=args.local_loss, - gather_with_grad=args.gather_with_grad, - cache_labels=True, - rank=args.rank, - world_size=args.world_size, - use_horovod=args.horovod, - ) - - -def create_model_and_transforms( - model_name: str, - pretrained: Optional[str] = None, - precision: str = 'fp32', - device: Union[str, torch.device] = 'cpu', - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_patch_dropout: Optional[float] = None, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - image_mean: Optional[Tuple[float, ...]] = None, - image_std: Optional[Tuple[float, ...]] = None, - image_interpolation: Optional[str] = None, - image_resize_mode: Optional[str] = None, # only effective for inference - aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, - pretrained_image: bool = False, - pretrained_hf: bool = True, - cache_dir: Optional[str] = None, - output_dict: Optional[bool] = None, - **model_kwargs, -): - force_preprocess_cfg = merge_preprocess_kwargs( - {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) - - model = create_model( - model_name, - pretrained, - precision=precision, - device=device, - jit=jit, - force_quick_gelu=force_quick_gelu, - force_custom_text=force_custom_text, - force_patch_dropout=force_patch_dropout, - force_image_size=force_image_size, - force_preprocess_cfg=force_preprocess_cfg, - pretrained_image=pretrained_image, - pretrained_hf=pretrained_hf, - cache_dir=cache_dir, - output_dict=output_dict, - **model_kwargs, - ) - - pp_cfg = PreprocessCfg(**model.visual.preprocess_cfg) - - preprocess_train = image_transform_v2( - pp_cfg, - is_train=True, - aug_cfg=aug_cfg, - ) - preprocess_val = image_transform_v2( - pp_cfg, - is_train=False, - ) - - return model, preprocess_train, preprocess_val - - -def create_model_from_pretrained( - model_name: str, - pretrained: Optional[str] = None, - precision: str = 'fp32', - device: Union[str, torch.device] = 'cpu', - jit: bool = False, - force_quick_gelu: bool = False, - force_custom_text: bool = False, - force_image_size: Optional[Union[int, Tuple[int, int]]] = None, - image_mean: Optional[Tuple[float, ...]] = None, - image_std: Optional[Tuple[float, ...]] = None, - image_interpolation: Optional[str] = None, - image_resize_mode: Optional[str] = None, # only effective for inference - return_transform: bool = True, - cache_dir: Optional[str] = None, - **model_kwargs, -): - force_preprocess_cfg = merge_preprocess_kwargs( - {}, mean=image_mean, std=image_std, interpolation=image_interpolation, resize_mode=image_resize_mode) - - model = create_model( - model_name, - pretrained, - precision=precision, - device=device, - jit=jit, - force_quick_gelu=force_quick_gelu, - force_custom_text=force_custom_text, - force_image_size=force_image_size, - force_preprocess_cfg=force_preprocess_cfg, - cache_dir=cache_dir, - require_pretrained=True, - **model_kwargs, - ) - - if not return_transform: - return model - - preprocess = image_transform_v2( - PreprocessCfg(**model.visual.preprocess_cfg), - is_train=False, - ) - - return model, preprocess diff --git a/src/proxyclip/open_clip_proxy/hf_configs.py b/src/proxyclip/open_clip_proxy/hf_configs.py deleted file mode 100644 index 3d2067476500a7c16511af18696fc5e23b066aff..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/hf_configs.py +++ /dev/null @@ -1,67 +0,0 @@ -# HF architecture dict: -arch_dict = { - # https://huggingface.co/docs/transformers/model_doc/roberta#roberta - "roberta": { - "config_names": { - "context_length": "max_position_embeddings", - "vocab_size": "vocab_size", - "width": "hidden_size", - "heads": "num_attention_heads", - "layers": "num_hidden_layers", - "layer_attr": "layer", - "token_embeddings_attr": "embeddings" - }, - "pooler": "mean_pooler", - }, - # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig - "xlm-roberta": { - "config_names": { - "context_length": "max_position_embeddings", - "vocab_size": "vocab_size", - "width": "hidden_size", - "heads": "num_attention_heads", - "layers": "num_hidden_layers", - "layer_attr": "layer", - "token_embeddings_attr": "embeddings" - }, - "pooler": "mean_pooler", - }, - # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 - "mt5": { - "config_names": { - # unlimited seqlen - # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 - # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 - "context_length": "", - "vocab_size": "vocab_size", - "width": "d_model", - "heads": "num_heads", - "layers": "num_layers", - "layer_attr": "block", - "token_embeddings_attr": "embed_tokens" - }, - "pooler": "mean_pooler", - }, - # https://huggingface.co/docs/transformers/model_doc/bert - "bert": { - "config_names": { - "context_length": "max_position_embeddings", - "vocab_size": "vocab_size", - "width": "hidden_size", - "heads": "num_attention_heads", - "layers": "num_hidden_layers", - }, - "pooler": "cls_pooler", - }, - # https://huggingface.co/docs/transformers/model_doc/m2m_100 - "m2m_100": { - "config_names": { - "context_length": "max_position_embeddings", - "vocab_size": "vocab_size", - "width": "d_model", - "heads": "encoder_attention_heads", - "layers": "encoder_layers", - }, - "pooler": "cls_pooler", - }, -} diff --git a/src/proxyclip/open_clip_proxy/hf_model.py b/src/proxyclip/open_clip_proxy/hf_model.py deleted file mode 100644 index 281a06cc5f16f41e17ba0e6ea9b5b29fab5bc076..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/hf_model.py +++ /dev/null @@ -1,193 +0,0 @@ -""" huggingface model adapter - -Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. -""" -import re - -import torch -import torch.nn as nn -from torch import TensorType - -try: - import transformers - from transformers import AutoModel, AutoTokenizer, AutoConfig, PretrainedConfig - from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ - BaseModelOutputWithPoolingAndCrossAttentions -except ImportError as e: - transformers = None - - - class BaseModelOutput: - pass - - - class PretrainedConfig: - pass - -from .hf_configs import arch_dict - - -# utils -def _camel2snake(s): - return re.sub(r'(? torch.Tensor: - # calculated ground-truth and cache if enabled - if self.prev_num_logits != num_logits or device not in self.labels: - labels = torch.arange(num_logits, device=device, dtype=torch.long) - if self.world_size > 1 and self.local_loss: - labels = labels + num_logits * self.rank - if self.cache_labels: - self.labels[device] = labels - self.prev_num_logits = num_logits - else: - labels = self.labels[device] - return labels - - def get_logits(self, image_features, text_features, logit_scale): - if self.world_size > 1: - all_image_features, all_text_features = gather_features( - image_features, text_features, - self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) - - if self.local_loss: - logits_per_image = logit_scale * image_features @ all_text_features.T - logits_per_text = logit_scale * text_features @ all_image_features.T - else: - logits_per_image = logit_scale * all_image_features @ all_text_features.T - logits_per_text = logits_per_image.T - else: - logits_per_image = logit_scale * image_features @ text_features.T - logits_per_text = logit_scale * text_features @ image_features.T - - return logits_per_image, logits_per_text - - def forward(self, image_features, text_features, logit_scale, output_dict=False): - device = image_features.device - logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale) - - labels = self.get_ground_truth(device, logits_per_image.shape[0]) - - total_loss = ( - F.cross_entropy(logits_per_image, labels) + - F.cross_entropy(logits_per_text, labels) - ) / 2 - - return {"contrastive_loss": total_loss} if output_dict else total_loss - - -class CoCaLoss(ClipLoss): - def __init__( - self, - caption_loss_weight, - clip_loss_weight, - pad_id=0, # pad_token for open_clip custom tokenizer - local_loss=False, - gather_with_grad=False, - cache_labels=False, - rank=0, - world_size=1, - use_horovod=False, - ): - super().__init__( - local_loss=local_loss, - gather_with_grad=gather_with_grad, - cache_labels=cache_labels, - rank=rank, - world_size=world_size, - use_horovod=use_horovod - ) - - self.clip_loss_weight = clip_loss_weight - self.caption_loss_weight = caption_loss_weight - self.caption_loss = nn.CrossEntropyLoss(ignore_index=pad_id) - - def forward(self, image_features, text_features, logits, labels, logit_scale, output_dict=False): - - clip_loss = torch.tensor(0) - - if self.clip_loss_weight: - clip_loss = super().forward(image_features, text_features, logit_scale) - clip_loss = self.clip_loss_weight * clip_loss - - caption_loss = self.caption_loss( - logits.permute(0, 2, 1), - labels, - ) - caption_loss = caption_loss * self.caption_loss_weight - - if output_dict: - return {"contrastive_loss": clip_loss, "caption_loss": caption_loss} - - return clip_loss, caption_loss - - -class DistillClipLoss(ClipLoss): - - def dist_loss(self, teacher_logits, student_logits): - return -(teacher_logits.softmax(dim=1) * student_logits.log_softmax(dim=1)).sum(dim=1).mean(dim=0) - - def forward( - self, - image_features, - text_features, - logit_scale, - dist_image_features, - dist_text_features, - dist_logit_scale, - output_dict=False, - ): - logits_per_image, logits_per_text = \ - self.get_logits(image_features, text_features, logit_scale) - - dist_logits_per_image, dist_logits_per_text = \ - self.get_logits(dist_image_features, dist_text_features, dist_logit_scale) - - labels = self.get_ground_truth(image_features.device, logits_per_image.shape[0]) - - contrastive_loss = ( - F.cross_entropy(logits_per_image, labels) + - F.cross_entropy(logits_per_text, labels) - ) / 2 - - distill_loss = ( - self.dist_loss(dist_logits_per_image, logits_per_image) + - self.dist_loss(dist_logits_per_text, logits_per_text) - ) / 2 - - if output_dict: - return {"contrastive_loss": contrastive_loss, "distill_loss": distill_loss} - - return contrastive_loss, distill_loss - - -def neighbour_exchange(from_rank, to_rank, tensor, group=None): - tensor_recv = torch.zeros_like(tensor) - send_op = torch.distributed.P2POp( - torch.distributed.isend, - tensor, - to_rank, - group=group, - ) - recv_op = torch.distributed.P2POp( - torch.distributed.irecv, - tensor_recv, - from_rank, - group=group, - ) - reqs = torch.distributed.batch_isend_irecv([send_op, recv_op]) - for req in reqs: - req.wait() - return tensor_recv - - -def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): - tensor_from_left = torch.zeros_like(tensor_to_right) - tensor_from_right = torch.zeros_like(tensor_to_left) - send_op_left = torch.distributed.P2POp( - torch.distributed.isend, - tensor_to_left, - left_rank, - group=group, - ) - send_op_right = torch.distributed.P2POp( - torch.distributed.isend, - tensor_to_right, - right_rank, - group=group, - ) - recv_op_left = torch.distributed.P2POp( - torch.distributed.irecv, - tensor_from_left, - left_rank, - group=group, - ) - recv_op_right = torch.distributed.P2POp( - torch.distributed.irecv, - tensor_from_right, - right_rank, - group=group, - ) - reqs = torch.distributed.batch_isend_irecv([send_op_right, send_op_left, recv_op_right, recv_op_left]) - for req in reqs: - req.wait() - return tensor_from_right, tensor_from_left - - -class NeighbourExchange(torch.autograd.Function): - @staticmethod - def forward(ctx, from_rank, to_rank, group, tensor): - ctx.group = group - ctx.from_rank = from_rank - ctx.to_rank = to_rank - return neighbour_exchange(from_rank, to_rank, tensor, group=group) - - @staticmethod - def backward(ctx, grad_output): - return (None, None, None) + (NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output),) - - -def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None): - return NeighbourExchange.apply(from_rank, to_rank, group, tensor) - - -class NeighbourExchangeBidir(torch.autograd.Function): - @staticmethod - def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right): - ctx.group = group - ctx.left_rank = left_rank - ctx.right_rank = right_rank - return neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=group) - - @staticmethod - def backward(ctx, *grad_outputs): - return (None, None, None) + \ - NeighbourExchangeBidir.apply(ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs) - - -def neighbour_exchange_bidir_with_grad(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None): - return NeighbourExchangeBidir.apply(left_rank, right_rank, group, tensor_to_left, tensor_to_right) - - -class SigLipLoss(nn.Module): - """ Sigmoid Loss for Language Image Pre-Training (SigLIP) - https://arxiv.org/abs/2303.15343 - - @article{zhai2023sigmoid, - title={Sigmoid loss for language image pre-training}, - author={Zhai, Xiaohua and Mustafa, Basil and Kolesnikov, Alexander and Beyer, Lucas}, - journal={arXiv preprint arXiv:2303.15343}, - year={2023} - } - """ - def __init__( - self, - cache_labels=False, - rank=0, - world_size=1, - bidir=True, - use_horovod=False, - ): - super().__init__() - self.cache_labels = cache_labels - self.rank = rank - self.world_size = world_size - assert not use_horovod # FIXME need to look at hvd ops for ring transfers - self.use_horovod = use_horovod - self.bidir = bidir - - # cache state FIXME cache not currently used, worthwhile? - self.prev_num_logits = 0 - self.labels = {} - - def get_ground_truth(self, device, dtype, num_logits, negative_only=False) -> torch.Tensor: - labels = -torch.ones((num_logits, num_logits), device=device, dtype=dtype) - if not negative_only: - labels = 2 * torch.eye(num_logits, device=device, dtype=dtype) + labels - return labels - - def get_logits(self, image_features, text_features, logit_scale, logit_bias=None): - logits = logit_scale * image_features @ text_features.T - if logit_bias is not None: - logits += logit_bias - return logits - - def _loss(self, image_features, text_features, logit_scale, logit_bias=None, negative_only=False): - logits = self.get_logits(image_features, text_features, logit_scale, logit_bias) - labels = self.get_ground_truth( - image_features.device, - image_features.dtype, - image_features.shape[0], - negative_only=negative_only, - ) - loss = -F.logsigmoid(labels * logits).sum() / image_features.shape[0] - return loss - - def forward(self, image_features, text_features, logit_scale, logit_bias, output_dict=False): - loss = self._loss(image_features, text_features, logit_scale, logit_bias) - - if self.world_size > 1: - # exchange text features w/ neighbour world_size - 1 times - right_rank = (self.rank + 1) % self.world_size - left_rank = (self.rank - 1 + self.world_size) % self.world_size - if self.bidir: - text_features_to_right = text_features_to_left = text_features - num_bidir, remainder = divmod(self.world_size - 1, 2) - for i in range(num_bidir): - text_features_recv = neighbour_exchange_bidir_with_grad( - left_rank, - right_rank, - text_features_to_left, - text_features_to_right, - ) - - for f in text_features_recv: - loss += self._loss( - image_features, - f, - logit_scale, - logit_bias, - negative_only=True, - ) - text_features_to_left, text_features_to_right = text_features_recv - - if remainder: - text_features_recv = neighbour_exchange_with_grad( - left_rank, right_rank, text_features_to_right) - - loss += self._loss( - image_features, - text_features_recv, - logit_scale, - logit_bias, - negative_only=True, - ) - else: - text_features_to_right = text_features - for i in range(self.world_size - 1): - text_features_from_left = neighbour_exchange_with_grad( - left_rank, right_rank, text_features_to_right) - - loss += self._loss( - image_features, - text_features_from_left, - logit_scale, - logit_bias, - negative_only=True, - ) - text_features_to_right = text_features_from_left - - return {"contrastive_loss": loss} if output_dict else loss diff --git a/src/proxyclip/open_clip_proxy/model.py b/src/proxyclip/open_clip_proxy/model.py deleted file mode 100644 index 089d912051851d7beecee646c580da60d526ccb6..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model.py +++ /dev/null @@ -1,624 +0,0 @@ -""" CLIP Model - -Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. -""" -import copy -import logging -import math -from dataclasses import dataclass -from typing import Any, Dict, Optional, Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn -from torch.utils.checkpoint import checkpoint -from functools import partial - -from .hf_model import HFTextEncoder -from .modified_resnet import ModifiedResNet -from .timm_model import TimmModel -from .transformer import LayerNormFp32, LayerNorm, QuickGELU, Attention, VisionTransformer, TextTransformer,\ - text_global_pool -from .utils import to_2tuple - - -@dataclass -class CLIPVisionCfg: - layers: Union[Tuple[int, int, int, int], int] = 12 - width: int = 768 - head_width: int = 64 - mlp_ratio: float = 4.0 - patch_size: int = 16 - image_size: Union[Tuple[int, int], int] = 224 - - ls_init_value: Optional[float] = None # layer scale initial value - patch_dropout: float = 0. # what fraction of patches to dropout during training (0 would mean disabled and no patches dropped) - 0.5 to 0.75 recommended in the paper for optimal results - attentional_pool: bool = False # whether to use attentional pooler in the last embedding layer (overrides pool_type) - attn_pooler_queries: int = 256 # n_queries for attentional pooler - attn_pooler_heads: int = 8 # n heads for attentional_pooling - no_ln_pre: bool = False # disable pre transformer LayerNorm - pos_embed_type: str = 'learnable' - final_ln_after_pool: bool = False # apply final LayerNorm after pooling - pool_type: str = 'tok' - output_tokens: bool = False - act_kwargs: Optional[dict] = None - norm_kwargs: Optional[dict] = None - - timm_model_name: Optional[str] = None # a valid model name overrides layers, width, patch_size - timm_model_pretrained: bool = False # use (imagenet) pretrained weights for named model - timm_pool: str = 'avg' # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '') - timm_proj: str = 'linear' # linear projection for timm model output ('linear', 'mlp', '') - timm_proj_bias: bool = False # enable bias final projection - timm_drop: float = 0. # head dropout - timm_drop_path: Optional[float] = None # backbone stochastic depth - - -@dataclass -class CLIPTextCfg: - context_length: int = 77 - vocab_size: int = 49408 - hf_tokenizer_name: Optional[str] = None - tokenizer_kwargs: Optional[dict] = None - - width: int = 512 - heads: int = 8 - layers: int = 12 - mlp_ratio: float = 4.0 - ls_init_value: Optional[float] = None # layer scale initial value - embed_cls: bool = False - pad_id: int = 0 - no_causal_mask: bool = False # disable causal masking - final_ln_after_pool: bool = False # apply final LayerNorm after pooling - pool_type: str = 'argmax' - proj_bias: bool = False - output_tokens: bool = False - act_kwargs: dict = None - norm_kwargs: dict = None - - # HuggingFace specific text tower config - hf_model_name: Optional[str] = None - hf_model_pretrained: bool = True - hf_proj_type: str = 'mlp' - hf_pooler_type: str = 'mean_pooler' # attentional pooling for HF models - - -def get_cast_dtype(precision: str): - cast_dtype = None - if precision == 'bf16': - cast_dtype = torch.bfloat16 - elif precision == 'fp16': - cast_dtype = torch.float16 - return cast_dtype - - -def get_input_dtype(precision: str): - input_dtype = None - if precision in ('bf16', 'pure_bf16'): - input_dtype = torch.bfloat16 - elif precision in ('fp16', 'pure_fp16'): - input_dtype = torch.float16 - return input_dtype - - -def _build_vision_tower( - embed_dim: int, - vision_cfg: CLIPVisionCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None -): - if isinstance(vision_cfg, dict): - vision_cfg = CLIPVisionCfg(**vision_cfg) - - # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more - # memory efficient in recent PyTorch releases (>= 1.10). - # NOTE: timm models always use native GELU regardless of quick_gelu flag. - act_layer = QuickGELU if quick_gelu else nn.GELU - - if vision_cfg.timm_model_name: - visual = TimmModel( - vision_cfg.timm_model_name, - pretrained=vision_cfg.timm_model_pretrained, - pool=vision_cfg.timm_pool, - proj=vision_cfg.timm_proj, - proj_bias=vision_cfg.timm_proj_bias, - drop=vision_cfg.timm_drop, - drop_path=vision_cfg.timm_drop_path, - patch_drop=vision_cfg.patch_dropout if vision_cfg.patch_dropout > 0 else None, - embed_dim=embed_dim, - image_size=vision_cfg.image_size, - ) - elif isinstance(vision_cfg.layers, (tuple, list)): - vision_heads = vision_cfg.width * 32 // vision_cfg.head_width - visual = ModifiedResNet( - layers=vision_cfg.layers, - output_dim=embed_dim, - heads=vision_heads, - image_size=vision_cfg.image_size, - width=vision_cfg.width, - ) - else: - vision_heads = vision_cfg.width // vision_cfg.head_width - norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm - if vision_cfg.norm_kwargs: - norm_layer = partial(norm_layer, **vision_cfg.norm_kwargs) - if vision_cfg.act_kwargs is not None: - act_layer = partial(act_layer, **vision_cfg.act_kwargs) - - visual = VisionTransformer( - image_size=vision_cfg.image_size, - patch_size=vision_cfg.patch_size, - width=vision_cfg.width, - layers=vision_cfg.layers, - heads=vision_heads, - mlp_ratio=vision_cfg.mlp_ratio, - ls_init_value=vision_cfg.ls_init_value, - patch_dropout=vision_cfg.patch_dropout, - attentional_pool=vision_cfg.attentional_pool, - attn_pooler_queries=vision_cfg.attn_pooler_queries, - attn_pooler_heads=vision_cfg.attn_pooler_heads, - pos_embed_type=vision_cfg.pos_embed_type, - no_ln_pre=vision_cfg.no_ln_pre, - final_ln_after_pool=vision_cfg.final_ln_after_pool, - pool_type=vision_cfg.pool_type, - output_tokens=vision_cfg.output_tokens, - output_dim=embed_dim, - act_layer=act_layer, - norm_layer=norm_layer, - ) - - return visual - - -def _build_text_tower( - embed_dim: int, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - cast_dtype: Optional[torch.dtype] = None, -): - if isinstance(text_cfg, dict): - text_cfg = CLIPTextCfg(**text_cfg) - - if text_cfg.hf_model_name: - text = HFTextEncoder( - text_cfg.hf_model_name, - output_dim=embed_dim, - proj_type=text_cfg.hf_proj_type, - pooler_type=text_cfg.hf_pooler_type, - pretrained=text_cfg.hf_model_pretrained, - output_tokens=text_cfg.output_tokens, - ) - else: - act_layer = QuickGELU if quick_gelu else nn.GELU - norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm - if text_cfg.norm_kwargs: - norm_layer = partial(norm_layer, **text_cfg.norm_kwargs) - if text_cfg.act_kwargs is not None: - act_layer = partial(act_layer, **text_cfg.act_kwargs) - - text = TextTransformer( - context_length=text_cfg.context_length, - vocab_size=text_cfg.vocab_size, - width=text_cfg.width, - heads=text_cfg.heads, - layers=text_cfg.layers, - mlp_ratio=text_cfg.mlp_ratio, - ls_init_value=text_cfg.ls_init_value, - output_dim=embed_dim, - embed_cls=text_cfg.embed_cls, - no_causal_mask=text_cfg.no_causal_mask, - pad_id=text_cfg.pad_id, - pool_type=text_cfg.pool_type, - proj_bias=text_cfg.proj_bias, - output_tokens=text_cfg.output_tokens, - act_layer=act_layer, - norm_layer=norm_layer, - ) - return text - - -class CLIP(nn.Module): - output_dict: torch.jit.Final[bool] - - def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - init_logit_scale: float = np.log(1 / 0.07), - init_logit_bias: Optional[float] = None, - cast_dtype: Optional[torch.dtype] = None, - output_dict: bool = False, - ): - super().__init__() - self.output_dict = output_dict - - self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) - - text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) - self.transformer = text.transformer - self.context_length = text.context_length - self.vocab_size = text.vocab_size - self.token_embedding = text.token_embedding - self.positional_embedding = text.positional_embedding - self.ln_final = text.ln_final - self.text_projection = text.text_projection - self.text_pool_type = text.pool_type - self.register_buffer('attn_mask', text.attn_mask, persistent=False) - - self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) - if init_logit_bias is not None: - self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) - else: - self.logit_bias = None - - def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): - # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 - self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.visual.set_grad_checkpointing(enable) - self.transformer.grad_checkpointing = enable - - def encode_image(self, image, external_feats, beta, gamma, normalize: bool = False): - features = self.visual(image, external_feats, beta, gamma) - return F.normalize(features, dim=-1) if normalize else features - - def encode_text(self, text, normalize: bool = False): - cast_dtype = self.transformer.get_cast_dtype() - - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.to(cast_dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=self.attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) # [batch_size, n_ctx, transformer.width] - x, _ = text_global_pool(x, text, self.text_pool_type) - if self.text_projection is not None: - if isinstance(self.text_projection, nn.Linear): - x = self.text_projection(x) - else: - x = x @ self.text_projection - - return F.normalize(x, dim=-1) if normalize else x - - def get_logits(self, image, text): - image_features = self.encode_image(image, normalize=True) - text_features = self.encode_text(text, normalize=True) - image_logits = self.logit_scale.exp() * image_features @ text_features.T - if self.logit_bias is not None: - image_logits += self.logit_bias - text_logits = image_logits.T - return image_logits, text_logits - - def forward( - self, - image: Optional[torch.Tensor] = None, - text: Optional[torch.Tensor] = None, - ): - image_features = self.encode_image(image, normalize=True) if image is not None else None - text_features = self.encode_text(text, normalize=True) if text is not None else None - - if self.output_dict: - out_dict = { - "image_features": image_features, - "text_features": text_features, - "logit_scale": self.logit_scale.exp() - } - if self.logit_bias is not None: - out_dict['logit_bias'] = self.logit_bias - return out_dict - - if self.logit_bias is not None: - return image_features, text_features, self.logit_scale.exp(), self.logit_bias - return image_features, text_features, self.logit_scale.exp() - - -class CustomTextCLIP(nn.Module): - output_dict: torch.jit.Final[bool] - - def __init__( - self, - embed_dim: int, - vision_cfg: CLIPVisionCfg, - text_cfg: CLIPTextCfg, - quick_gelu: bool = False, - init_logit_scale: float = np.log(1 / 0.07), - init_logit_bias: Optional[float] = None, - cast_dtype: Optional[torch.dtype] = None, - output_dict: bool = False, - ): - super().__init__() - self.output_dict = output_dict - self.visual = _build_vision_tower(embed_dim, vision_cfg, quick_gelu, cast_dtype) - self.text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype) - self.context_length = self.text.context_length - self.vocab_size = self.text.vocab_size - self.logit_scale = nn.Parameter(torch.ones([]) * init_logit_scale) - if init_logit_bias is not None: - self.logit_bias = nn.Parameter(torch.ones([]) * init_logit_bias) - else: - self.logit_bias = None - - def lock_image_tower(self, unlocked_groups=0, freeze_bn_stats=False): - # lock image tower as per LiT - https://arxiv.org/abs/2111.07991 - self.visual.lock(unlocked_groups=unlocked_groups, freeze_bn_stats=freeze_bn_stats) - - def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True): - self.text.lock(unlocked_layers, freeze_layer_norm) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.visual.set_grad_checkpointing(enable) - self.text.set_grad_checkpointing(enable) - - def encode_image(self, image, normalize: bool = False): - features = self.visual(image) - return F.normalize(features, dim=-1) if normalize else features - - def encode_text(self, text, normalize: bool = False): - features = self.text(text) - return F.normalize(features, dim=-1) if normalize else features - - def get_logits(self, image, text): - image_features = self.encode_image(image, normalize=True) - text_features = self.encode_text(text, normalize=True) - image_logits = self.logit_scale.exp() * image_features @ text_features.T - if self.logit_bias is not None: - image_logits += self.logit_bias - text_logits = image_logits.T - return image_logits, text_logits - - def forward( - self, - image: Optional[torch.Tensor] = None, - text: Optional[torch.Tensor] = None, - ): - image_features = self.encode_image(image, normalize=True) if image is not None else None - text_features = self.encode_text(text, normalize=True) if text is not None else None - - if self.output_dict: - out_dict = { - "image_features": image_features, - "text_features": text_features, - "logit_scale": self.logit_scale.exp() - } - if self.logit_bias is not None: - out_dict['logit_bias'] = self.logit_bias - return out_dict - - if self.logit_bias is not None: - return image_features, text_features, self.logit_scale.exp(), self.logit_bias - return image_features, text_features, self.logit_scale.exp() - - -def convert_weights_to_lp(model: nn.Module, dtype=torch.float16): - """Convert applicable model parameters to low-precision (bf16 or fp16)""" - - def _convert_weights(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.to(dtype) - if l.bias is not None: - l.bias.data = l.bias.data.to(dtype) - - if isinstance(l, (nn.MultiheadAttention, Attention)): - for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.to(dtype) - - if isinstance(l, (CLIP, TextTransformer)): - # convert text nn.Parameter projections - attr = getattr(l, "text_projection", None) - if attr is not None: - attr.data = attr.data.to(dtype) - - if isinstance(l, VisionTransformer): - # convert vision nn.Parameter projections - attr = getattr(l, "proj", None) - if attr is not None: - attr.data = attr.data.to(dtype) - - model.apply(_convert_weights) - - -convert_weights_to_fp16 = convert_weights_to_lp # backwards compat - - -# used to maintain checkpoint compatibility -def convert_to_custom_text_state_dict(state_dict: dict): - if 'text_projection' in state_dict: - # old format state_dict, move text tower -> .text - new_state_dict = {} - for k, v in state_dict.items(): - if any(k.startswith(p) for p in ( - 'text_projection', - 'positional_embedding', - 'token_embedding', - 'transformer', - 'ln_final', - )): - k = 'text.' + k - new_state_dict[k] = v - return new_state_dict - return state_dict - - -def build_model_from_openai_state_dict( - state_dict: dict, - quick_gelu=True, - cast_dtype=torch.float16, -): - vit = "visual.proj" in state_dict - - if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len( - [k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) - image_size = vision_patch_size * grid_size - else: - counts: list = [ - len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] - vision_layers = tuple(counts) - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) - vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] - image_size = output_width * 32 - - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) - - vision_cfg = CLIPVisionCfg( - layers=vision_layers, - width=vision_width, - patch_size=vision_patch_size, - image_size=image_size, - ) - text_cfg = CLIPTextCfg( - context_length=context_length, - vocab_size=vocab_size, - width=transformer_width, - heads=transformer_heads, - layers=transformer_layers, - ) - model = CLIP( - embed_dim, - vision_cfg=vision_cfg, - text_cfg=text_cfg, - quick_gelu=quick_gelu, # OpenAI models were trained with QuickGELU - cast_dtype=cast_dtype, - ) - - for key in ["input_resolution", "context_length", "vocab_size"]: - state_dict.pop(key, None) - convert_weights_to_fp16(model) # OpenAI state dicts are partially converted to float16 - model.load_state_dict(state_dict) - return model.eval() - - -def trace_model(model, batch_size=256, device=torch.device('cpu')): - model.eval() - image_size = model.visual.image_size - example_images = torch.ones((batch_size, 3, image_size, image_size), device=device) - example_text = torch.zeros((batch_size, model.context_length), dtype=torch.int, device=device) - model = torch.jit.trace_module( - model, - inputs=dict( - forward=(example_images, example_text), - encode_text=(example_text,), - encode_image=(example_images,) - )) - model.visual.image_size = image_size - return model - - -def resize_pos_embed(state_dict, model, interpolation: str = 'bicubic', antialias: bool = True): - # Rescale the grid of position embeddings when loading from state_dict - old_pos_embed = state_dict.get('visual.positional_embedding', None) - if old_pos_embed is None or not hasattr(model.visual, 'grid_size'): - return - grid_size = to_2tuple(model.visual.grid_size) - extra_tokens = 1 # FIXME detect different token configs (ie no class token, or more) - new_seq_len = grid_size[0] * grid_size[1] + extra_tokens - if new_seq_len == old_pos_embed.shape[0]: - return - - if extra_tokens: - pos_emb_tok, pos_emb_img = old_pos_embed[:extra_tokens], old_pos_embed[extra_tokens:] - else: - pos_emb_tok, pos_emb_img = None, old_pos_embed - old_grid_size = to_2tuple(int(math.sqrt(len(pos_emb_img)))) - - logging.info('Resizing position embedding grid-size from %s to %s', old_grid_size, grid_size) - pos_emb_img = pos_emb_img.reshape(1, old_grid_size[0], old_grid_size[1], -1).permute(0, 3, 1, 2) - pos_emb_img = F.interpolate( - pos_emb_img, - size=grid_size, - mode=interpolation, - antialias=antialias, - align_corners=False, - ) - pos_emb_img = pos_emb_img.permute(0, 2, 3, 1).reshape(1, grid_size[0] * grid_size[1], -1)[0] - if pos_emb_tok is not None: - new_pos_embed = torch.cat([pos_emb_tok, pos_emb_img], dim=0) - else: - new_pos_embed = pos_emb_img - state_dict['visual.positional_embedding'] = new_pos_embed - - -def resize_text_pos_embed(state_dict, model, interpolation: str = 'linear', antialias: bool = False): - old_pos_embed = state_dict.get('positional_embedding', None) - if old_pos_embed is None: - return - # FIXME add support for text cls_token - model_pos_embed = getattr(model, 'positional_embedding', None) - if model_pos_embed is None: - model_pos_embed = getattr(model.text, 'positional_embedding', None) - - old_num_pos = old_pos_embed.shape[0] - old_width = old_pos_embed.shape[1] - num_pos = model_pos_embed.shape[0] - width = model_pos_embed.shape[1] - assert old_width == width, 'text pos_embed width changed!' - if old_num_pos == num_pos: - return - - logging.info('Resizing text position embedding num_pos from %s to %s', old_num_pos, num_pos) - old_pos_embed = old_pos_embed.reshape(1, old_num_pos, old_width).permute(0, 2, 1) - old_pos_embed = F.interpolate( - old_pos_embed, - size=num_pos, - mode=interpolation, - antialias=antialias, - align_corners=False, - ) - old_pos_embed = old_pos_embed.permute(0, 2, 1)[0] - new_pos_embed = old_pos_embed - - state_dict['positional_embedding'] = new_pos_embed - - -def get_model_preprocess_cfg(model): - module = getattr(model, 'visual', model) - preprocess_cfg = getattr(module, 'preprocess_cfg', {}) - if not preprocess_cfg: - # use separate legacy attributes if preprocess_cfg dict not found - size = getattr(module, 'image_size') - if size is not None: - preprocess_cfg['size'] = size - mean = getattr(module, 'image_mean', None) - if mean is not None: - preprocess_cfg['mean'] = mean - std = getattr(module, 'image_std', None) - if std is not None: - preprocess_cfg['std'] = std - return preprocess_cfg - - -def set_model_preprocess_cfg(model, preprocess_cfg: Dict[str, Any]): - module = getattr(model, 'visual', model) - module.image_mean = preprocess_cfg['mean'] # legacy attribute, keeping for bwd compat - module.image_std = preprocess_cfg['std'] # legacy attribute, keeping for bwd compat - module.preprocess_cfg = copy.deepcopy(preprocess_cfg) # new attr, package all pp cfg as dict - - -def get_model_tokenize_cfg(model): - module = getattr(model, 'text', model) - cfg = {} - context_length = getattr(module, 'context_length', None) - if context_length is not None: - cfg['context_length'] = context_length - vocab_size = getattr(module, 'vocab_size', None) - if vocab_size is not None: - cfg['vocab_size'] = vocab_size - return cfg \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/EVA01-g-14-plus.json b/src/proxyclip/open_clip_proxy/model_configs/EVA01-g-14-plus.json deleted file mode 100644 index 73f46a71e664fce987218b8eb48903e7bd895f41..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/EVA01-g-14-plus.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "timm_model_name": "eva_giant_patch14_224", - "timm_model_pretrained": false, - "timm_pool": "token", - "timm_proj": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/EVA01-g-14.json b/src/proxyclip/open_clip_proxy/model_configs/EVA01-g-14.json deleted file mode 100644 index 9d0e80f290d9491b7c46fafd576201b1258165aa..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/EVA01-g-14.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "timm_model_name": "eva_giant_patch14_224", - "timm_model_pretrained": false, - "timm_pool": "token", - "timm_proj": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/EVA02-B-16.json b/src/proxyclip/open_clip_proxy/model_configs/EVA02-B-16.json deleted file mode 100644 index 3f92357287e1f6600da1e7f391cb6370d7f66de4..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/EVA02-B-16.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "timm_model_name": "eva02_base_patch16_clip_224", - "timm_model_pretrained": false, - "timm_pool": "token", - "timm_proj": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/EVA02-E-14-plus.json b/src/proxyclip/open_clip_proxy/model_configs/EVA02-E-14-plus.json deleted file mode 100644 index e250c2a404c86ff168c54cfcf71bc2492be1b74c..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/EVA02-E-14-plus.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "timm_model_name": "eva02_enormous_patch14_clip_224", - "timm_model_pretrained": false, - "timm_pool": "token", - "timm_proj": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1280, - "heads": 20, - "layers": 32 - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/EVA02-E-14.json b/src/proxyclip/open_clip_proxy/model_configs/EVA02-E-14.json deleted file mode 100644 index 4b6648e25092b151a9095e0a66956c7ebf835b16..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/EVA02-E-14.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "timm_model_name": "eva02_enormous_patch14_clip_224", - "timm_model_pretrained": false, - "timm_pool": "token", - "timm_proj": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/EVA02-L-14-336.json b/src/proxyclip/open_clip_proxy/model_configs/EVA02-L-14-336.json deleted file mode 100644 index 2bb07f3c082fd88c4e86131b272163aaacfaef9e..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/EVA02-L-14-336.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 336, - "timm_model_name": "eva02_large_patch14_clip_336", - "timm_model_pretrained": false, - "timm_pool": "token", - "timm_proj": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/EVA02-L-14.json b/src/proxyclip/open_clip_proxy/model_configs/EVA02-L-14.json deleted file mode 100644 index b4c7f377bc543aa92a145358f2630a58ae9be989..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/EVA02-L-14.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "timm_model_name": "eva02_large_patch14_clip_224", - "timm_model_pretrained": false, - "timm_pool": "token", - "timm_proj": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/RN50x16.json b/src/proxyclip/open_clip_proxy/model_configs/RN50x16.json deleted file mode 100644 index 3161e1a2c9a839161e652a4d729c2cdc971161db..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/RN50x16.json +++ /dev/null @@ -1,21 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 384, - "layers": [ - 6, - 8, - 18, - 8 - ], - "width": 96, - "patch_size": null - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-256.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-256.json deleted file mode 100644 index d7ad3acba6bd37701ff8f19ca5f791c6342b73d6..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-256.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "embed_dim": 768, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 256, - "timm_model_name": "vit_base_patch16_siglip_256", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 768, - "heads": 12, - "layers": 12, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-384.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-384.json deleted file mode 100644 index df9a25cdca5207a8954801c0f2cf28514c15a1cd..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-384.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "embed_dim": 768, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 384, - "timm_model_name": "vit_base_patch16_siglip_384", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 768, - "heads": 12, - "layers": 12, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-512.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-512.json deleted file mode 100644 index 88b018528b2e7806cd11b95d5808136786ea0f97..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-512.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "embed_dim": 768, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 512, - "timm_model_name": "vit_base_patch16_siglip_512", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 768, - "heads": 12, - "layers": 12, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-i18n-256.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-i18n-256.json deleted file mode 100644 index 7a28797a7e1487af986540872447a68da0dd69b2..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP-i18n-256.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "embed_dim": 768, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 256, - "timm_model_name": "vit_base_patch16_siglip_256", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 250000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP-i18n-256", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 768, - "heads": 12, - "layers": 12, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP.json deleted file mode 100644 index a9f2b654a671c9bd235f351b2a253ca889758549..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-SigLIP.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "embed_dim": 768, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 224, - "timm_model_name": "vit_base_patch16_siglip_224", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 768, - "heads": 12, - "layers": 12, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-plus-240.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-plus-240.json deleted file mode 100644 index 5bbd12bcd01f64d6d0a0aa8316b129327a0d169a..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-plus-240.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 640, - "vision_cfg": { - "image_size": 240, - "layers": 12, - "width": 896, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 640, - "heads": 10, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-plus.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-plus.json deleted file mode 100644 index 5dc1e09baccef2b15055c1bffeb9903e760101c6..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-plus.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 640, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 896, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 640, - "heads": 10, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-quickgelu.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-quickgelu.json deleted file mode 100644 index ff5431ea3065d18094de94d3c87d8814d3f651fe..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16-quickgelu.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 512, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16.json deleted file mode 100644 index 395eea77ec3907c0611531aba63459b193e67b9c..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-16.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-256.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-256.json deleted file mode 100644 index 80a2597d8f7d5d500df2aacbded9507196dad6da..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-256.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 256, - "layers": 12, - "width": 768, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-plus-256.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-plus-256.json deleted file mode 100644 index 2f09c857de9a4c01ae51297a7e2451984879f9de..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-plus-256.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 640, - "vision_cfg": { - "image_size": 256, - "layers": 12, - "width": 896, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 640, - "heads": 10, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-quickgelu.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-quickgelu.json deleted file mode 100644 index ce6bd923593293ed50dfcfb28b73ca7403bcf3c5..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32-quickgelu.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 512, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32.json deleted file mode 100644 index 07c8e28eb06fa1813ba932fe4eec668262d1c47f..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-B-32.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-378-quickgelu.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-378-quickgelu.json deleted file mode 100644 index e2b2ecf9ae278eeb4f6b20d16e17a6523f961580..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-378-quickgelu.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1024, - "quick_gelu": true, - "vision_cfg": { - "image_size": 378, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-CLIPA-336.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-CLIPA-336.json deleted file mode 100644 index 01fabb29db2bcbd9513e903064d61e3e1974d580..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-CLIPA-336.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 336, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 14, - "no_ln_pre": true, - "pool_type": "avg", - "final_ln_after_pool": true - }, - "text_cfg": { - "context_length": 32, - "vocab_size": 32000, - "hf_tokenizer_name": "bert-base-uncased", - "tokenizer_kwargs": { - "strip_sep_token": true - }, - "width": 1024, - "heads": 16, - "layers": 24, - "pool_type": "last", - "no_causal_mask": true - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-CLIPA.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-CLIPA.json deleted file mode 100644 index 7df0338844bfff4d30f3ca08711311f645dda866..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-CLIPA.json +++ /dev/null @@ -1,26 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 14, - "no_ln_pre": true, - "pool_type": "avg", - "final_ln_after_pool": true - }, - "text_cfg": { - "context_length": 32, - "vocab_size": 32000, - "hf_tokenizer_name": "bert-base-uncased", - "tokenizer_kwargs": { - "strip_sep_token": true - }, - "width": 1024, - "heads": 16, - "layers": 24, - "pool_type": "last", - "no_causal_mask": true - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-quickgelu.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-quickgelu.json deleted file mode 100644 index 41f22f65bb002c320111790e0cd0f2425a575df7..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14-quickgelu.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1024, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14.json deleted file mode 100644 index 3e3a7e934e7f02e41f4829996c4950e05f015a74..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-14.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-16.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-H-16.json deleted file mode 100644 index 588485455fdf8193ec16474450b94e31c91ea93c..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-H-16.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "layers": 32, - "width": 1280, - "head_width": 80, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-280.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-280.json deleted file mode 100644 index 2262deaefa82792d35d73c0d7c8e620525092581..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-280.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 280, - "layers": 24, - "width": 1024, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-336.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-336.json deleted file mode 100644 index 8d1f74c2639c3a3705df9865b9c08215675ddc97..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-336.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 336, - "layers": 24, - "width": 1024, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-CLIPA-336.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-CLIPA-336.json deleted file mode 100644 index 60a4df589b9e9ed269807204ec9788e613026382..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-CLIPA-336.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 336, - "layers": 24, - "width": 1024, - "patch_size": 14, - "no_ln_pre": true, - "pool_type": "avg", - "final_ln_after_pool": true - }, - "text_cfg": { - "context_length": 32, - "vocab_size": 32000, - "hf_tokenizer_name": "bert-base-uncased", - "tokenizer_kwargs": { - "strip_sep_token": true - }, - "width": 768, - "heads": 12, - "layers": 12, - "pool_type": "last", - "no_causal_mask": true - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-CLIPA.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-CLIPA.json deleted file mode 100644 index b4dde7b546b6c53d5c55f2abe50b599ff2519964..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-CLIPA.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 14, - "no_ln_pre": true, - "pool_type": "avg", - "final_ln_after_pool": true - }, - "text_cfg": { - "context_length": 32, - "vocab_size": 32000, - "hf_tokenizer_name": "bert-base-uncased", - "tokenizer_kwargs": { - "strip_sep_token": true - }, - "width": 768, - "heads": 12, - "layers": 12, - "pool_type": "last", - "no_causal_mask": true - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-quickgelu.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-quickgelu.json deleted file mode 100644 index d5a3fd36aa9cd9cc4a3dc29e362945cec13a02f3..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14-quickgelu.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 768, - "quick_gelu": true, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14.json deleted file mode 100644 index d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-14.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-320.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-320.json deleted file mode 100644 index fc2d13ca9ec7f0b56a886ddaf66c4a7ba7a442ba..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-320.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 320, - "layers": 24, - "width": 1024, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-SigLIP-256.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-SigLIP-256.json deleted file mode 100644 index 5ba8f7abb68e5a798d38f976a828c63f74b94ae8..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-SigLIP-256.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "embed_dim": 1024, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 256, - "timm_model_name": "vit_large_patch16_siglip_256", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 1024, - "heads": 16, - "layers": 24, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-SigLIP-384.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-SigLIP-384.json deleted file mode 100644 index fd2cc2e346f7110a5de01cfaf7eae8c94360de3a..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16-SigLIP-384.json +++ /dev/null @@ -1,29 +0,0 @@ -{ - "embed_dim": 1024, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 384, - "timm_model_name": "vit_large_patch16_siglip_384", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 1024, - "heads": 16, - "layers": 24, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16.json deleted file mode 100644 index 82a1cedfa290adacbbdc02bc5d589734c22d41d3..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-L-16.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-M-16-alt.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-M-16-alt.json deleted file mode 100644 index 1a317aad8e02d9c26d2decc7cc49a18dfdf9e0d8..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-M-16-alt.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 384, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 512, - "patch_size": 16, - "ls_init_value": 1e-4 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 384, - "heads": 6, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-M-16.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-M-16.json deleted file mode 100644 index f2f3225a46e09237730a151d161f70c86b985172..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-M-16.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 512, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-M-32-alt.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-M-32-alt.json deleted file mode 100644 index fd222aeac0f582ef6a1a33f1b3fec70a5b386ac0..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-M-32-alt.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 384, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 512, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 384, - "heads": 6, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-M-32.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-M-32.json deleted file mode 100644 index 4f718642821035d9776d1e006817d65ede074366..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-M-32.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 512, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-S-16-alt.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-S-16-alt.json deleted file mode 100644 index a8c056555e4da3ba0d1475a61fc316362ecce76f..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-S-16-alt.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 256, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 384, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 256, - "heads": 4, - "layers": 10 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-S-16.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-S-16.json deleted file mode 100644 index 1d8504e59658803f3093e5b05de45f30a09b8185..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-S-16.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 384, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 384, - "patch_size": 16 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 384, - "heads": 6, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-S-32-alt.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-S-32-alt.json deleted file mode 100644 index e1dfdec9824df09a2010e991ccfa1d9ee2f45807..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-S-32-alt.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 256, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 384, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 256, - "heads": 4, - "layers": 10 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-S-32.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-S-32.json deleted file mode 100644 index 9b8b4191b268de267268cfcb90fc01c6b9df07d8..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-S-32.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "embed_dim": 384, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 384, - "patch_size": 32 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 384, - "heads": 6, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-SO400M-14-SigLIP-384.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-SO400M-14-SigLIP-384.json deleted file mode 100644 index 4c527f581230938d7b39baf36b6bd749b0e7f169..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-SO400M-14-SigLIP-384.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "embed_dim": 1152, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 384, - "timm_model_name": "vit_so400m_patch14_siglip_384", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 64, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 1152, - "heads": 16, - "layers": 27, - "mlp_ratio": 3.7362, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-SO400M-14-SigLIP.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-SO400M-14-SigLIP.json deleted file mode 100644 index 564eb78a49c8ff31cac047277b9344bbe85fef40..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-SO400M-14-SigLIP.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "embed_dim": 1152, - "init_logit_bias": -10, - "custom_text": true, - "vision_cfg": { - "image_size": 224, - "timm_model_name": "vit_so400m_patch14_siglip_224", - "timm_model_pretrained": false, - "timm_pool": "map", - "timm_proj": "none" - }, - "text_cfg": { - "context_length": 16, - "vocab_size": 32000, - "hf_tokenizer_name": "timm/ViT-B-16-SigLIP", - "tokenizer_kwargs": { - "clean": "canonicalize" - }, - "width": 1152, - "heads": 16, - "layers": 27, - "mlp_ratio": 3.7362, - "no_causal_mask": true, - "proj_bias": true, - "pool_type": "last", - "norm_kwargs":{ - "eps": 1e-6 - } - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14-CLIPA-336.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14-CLIPA-336.json deleted file mode 100644 index 75ba7675c643cd482f06886e58ded6fb934233fc..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14-CLIPA-336.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "embed_dim": 1280, - "vision_cfg": { - "image_size": 336, - "layers": 48, - "width": 1664, - "head_width": 104, - "mlp_ratio": 4.9231, - "patch_size": 14, - "no_ln_pre": true, - "pool_type": "avg", - "final_ln_after_pool": true - }, - "text_cfg": { - "context_length": 32, - "vocab_size": 32000, - "hf_tokenizer_name": "bert-base-uncased", - "tokenizer_kwargs": { - "strip_sep_token": true - }, - "width": 1280, - "heads": 20, - "layers": 32, - "pool_type": "last", - "no_causal_mask": true - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14-CLIPA.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14-CLIPA.json deleted file mode 100644 index 83ec709f8b8362d892067adafde9a0d78ce4db14..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14-CLIPA.json +++ /dev/null @@ -1,27 +0,0 @@ -{ - "embed_dim": 1280, - "vision_cfg": { - "image_size": 224, - "layers": 48, - "width": 1664, - "head_width": 104, - "mlp_ratio": 4.9231, - "patch_size": 14, - "no_ln_pre": true, - "pool_type": "avg", - "final_ln_after_pool": true - }, - "text_cfg": { - "context_length": 32, - "vocab_size": 32000, - "hf_tokenizer_name": "bert-base-uncased", - "tokenizer_kwargs": { - "strip_sep_token": true - }, - "width": 1280, - "heads": 20, - "layers": 32, - "pool_type": "last", - "no_causal_mask": true - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14.json deleted file mode 100644 index 2cfba479a2e8f3737e71ce240732bf3bc743d8b7..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-bigG-14.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1280, - "vision_cfg": { - "image_size": 224, - "layers": 48, - "width": 1664, - "head_width": 104, - "mlp_ratio": 4.9231, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1280, - "heads": 20, - "layers": 32 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-e-14.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-e-14.json deleted file mode 100644 index 91a0fe14d25a107fb8ec48dd7faae313fd26ed7b..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-e-14.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1280, - "vision_cfg": { - "image_size": 224, - "layers": 56, - "width": 1792, - "head_width": 112, - "mlp_ratio": 8.5715, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1280, - "heads": 20, - "layers": 36 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/ViT-g-14.json b/src/proxyclip/open_clip_proxy/model_configs/ViT-g-14.json deleted file mode 100644 index 8c4b7325cc75b6112be7107d36ae2cb5762d9091..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/ViT-g-14.json +++ /dev/null @@ -1,18 +0,0 @@ -{ - "embed_dim": 1024, - "vision_cfg": { - "image_size": 224, - "layers": 40, - "width": 1408, - "head_width": 88, - "mlp_ratio": 4.3637, - "patch_size": 14 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 1024, - "heads": 16, - "layers": 24 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/coca_ViT-B-32.json b/src/proxyclip/open_clip_proxy/model_configs/coca_ViT-B-32.json deleted file mode 100644 index 7e7eb520a6a0096e5602d509ecd6186e278f4725..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/coca_ViT-B-32.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32, - "attentional_pool": true, - "attn_pooler_heads": 8, - "output_tokens": true - }, - "text_cfg": { - "context_length": 76, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12, - "embed_cls": true, - "output_tokens": true - }, - "multimodal_cfg": { - "context_length": 76, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12, - "attn_pooler_heads": 8 - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/coca_ViT-L-14.json b/src/proxyclip/open_clip_proxy/model_configs/coca_ViT-L-14.json deleted file mode 100644 index 3d5ca4ca2338540f06852df5ff35ea6277e64555..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/coca_ViT-L-14.json +++ /dev/null @@ -1,30 +0,0 @@ -{ - "embed_dim": 768, - "vision_cfg": { - "image_size": 224, - "layers": 24, - "width": 1024, - "patch_size": 14, - "attentional_pool": true, - "attn_pooler_heads": 8, - "output_tokens": true - }, - "text_cfg": { - "context_length": 76, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12, - "embed_cls": true, - "output_tokens": true - }, - "multimodal_cfg": { - "context_length": 76, - "vocab_size": 49408, - "width": 768, - "heads": 12, - "layers": 12, - "attn_pooler_heads": 12 - }, - "custom_text": true -} diff --git a/src/proxyclip/open_clip_proxy/model_configs/coca_base.json b/src/proxyclip/open_clip_proxy/model_configs/coca_base.json deleted file mode 100644 index cf8c6cecb78a49d7e7140145a0307cbd561077c2..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/coca_base.json +++ /dev/null @@ -1,31 +0,0 @@ -{ - "embed_dim": 512, - "multimodal_cfg": { - "width": 768, - "context_length": 76, - "vocab_size": 64000, - "mlp_ratio": 4, - "layers": 12, - "dim_head": 64, - "heads": 12, - "n_queries": 256, - "attn_pooler_heads": 8 - }, - "vision_cfg": { - "image_size": 288, - "layers": 12, - "width": 768, - "patch_size": 18, - "output_tokens": true - }, - "text_cfg": { - "context_length": 76, - "vocab_size": 64000, - "layers": 12, - "heads": 12, - "width": 768, - "embed_cls": true, - "output_tokens": true - }, - "custom_text": true -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/coca_roberta-ViT-B-32.json b/src/proxyclip/open_clip_proxy/model_configs/coca_roberta-ViT-B-32.json deleted file mode 100644 index aa9d3f562057f849e6ced8b495de2dd73387fe61..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/coca_roberta-ViT-B-32.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "image_size": 224, - "layers": 12, - "width": 768, - "patch_size": 32, - "output_tokens": true - }, - "text_cfg": { - "hf_model_name": "roberta-base", - "hf_tokenizer_name": "roberta-base", - "hf_proj_type": "linear", - "width": 768, - "output_tokens": true - }, - "multimodal_cfg": { - "context_length": 76, - "width": 768, - "heads": 8, - "layers": 12 - }, - "custom_text": true -} diff --git a/src/proxyclip/open_clip_proxy/model_configs/vit_medium_patch16_gap_256.json b/src/proxyclip/open_clip_proxy/model_configs/vit_medium_patch16_gap_256.json deleted file mode 100644 index 8843eaf08cad16c3e7b5f496fd650715c9573f65..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/vit_medium_patch16_gap_256.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "timm_model_name": "vit_medium_patch16_gap_256", - "timm_model_pretrained": false, - "timm_pool": "", - "timm_proj": "linear", - "image_size": 256 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/model_configs/vit_relpos_medium_patch16_cls_224.json b/src/proxyclip/open_clip_proxy/model_configs/vit_relpos_medium_patch16_cls_224.json deleted file mode 100644 index ed217b202d5e6071c5307f4547c97ff4cfe2abd1..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/model_configs/vit_relpos_medium_patch16_cls_224.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "embed_dim": 512, - "vision_cfg": { - "timm_model_name": "vit_relpos_medium_patch16_cls_224", - "timm_model_pretrained": false, - "timm_pool": "", - "timm_proj": "linear", - "image_size": 224 - }, - "text_cfg": { - "context_length": 77, - "vocab_size": 49408, - "width": 512, - "heads": 8, - "layers": 12 - } -} \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/modified_resnet.py b/src/proxyclip/open_clip_proxy/modified_resnet.py deleted file mode 100644 index f7c0b033a80e7d08a20a367050c5b1bc5d5292e7..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/modified_resnet.py +++ /dev/null @@ -1,181 +0,0 @@ -from collections import OrderedDict - -import torch -from torch import nn -from torch.nn import functional as F - -from open_clip.utils import freeze_batch_norm_2d - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - self.act1 = nn.ReLU(inplace=True) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - self.act2 = nn.ReLU(inplace=True) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - self.act3 = nn.ReLU(inplace=True) - - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", nn.BatchNorm2d(planes * self.expansion)) - ])) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.act1(self.bn1(self.conv1(x))) - out = self.act2(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.act3(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, key=x, value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0., - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, image_size=224, width=64): - super().__init__() - self.output_dim = output_dim - self.image_size = image_size - - # the 3-layer stem - self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(width // 2) - self.act1 = nn.ReLU(inplace=True) - self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(width // 2) - self.act2 = nn.ReLU(inplace=True) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.act3 = nn.ReLU(inplace=True) - self.avgpool = nn.AvgPool2d(2) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) - - self.init_parameters() - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def init_parameters(self): - if self.attnpool is not None: - std = self.attnpool.c_proj.in_features ** -0.5 - nn.init.normal_(self.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - assert unlocked_groups == 0, 'partial locking not currently supported for this model' - for param in self.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - # FIXME support for non-transformer - pass - - def stem(self, x): - x = self.act1(self.bn1(self.conv1(x))) - x = self.act2(self.bn2(self.conv2(x))) - x = self.act3(self.bn3(self.conv3(x))) - x = self.avgpool(x) - return x - - def forward(self, x): - x = self.stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - x = self.attnpool(x) - - return x diff --git a/src/proxyclip/open_clip_proxy/openai.py b/src/proxyclip/open_clip_proxy/openai.py deleted file mode 100644 index 6c2c0235245c2e4f1217b3b2bfaf2acf78e74981..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/openai.py +++ /dev/null @@ -1,90 +0,0 @@ -""" OpenAI pretrained model functions - -Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. -""" - -import os -import warnings -from typing import List, Optional, Union - -import torch - -from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype -from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url - -__all__ = ["list_openai_models", "load_openai_model"] - - -def list_openai_models() -> List[str]: - """Returns the names of available CLIP models""" - return list_pretrained_models_by_tag('openai') - - -def load_openai_model( - name: str, - precision: Optional[str] = None, - device: Optional[Union[str, torch.device]] = None, - cache_dir: Optional[str] = None, -): - """Load a CLIP model - - Parameters - ---------- - name : str - A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict - precision: str - Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. - device : Union[str, torch.device] - The device to put the loaded model - cache_dir : Optional[str] - The directory to cache the downloaded model weights - - Returns - ------- - model : torch.nn.Module - The CLIP model - preprocess : Callable[[PIL.Image], torch.Tensor] - A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input - """ - if device is None: - device = "cuda" if torch.cuda.is_available() else "cpu" - if precision is None: - precision = 'fp32' if device == 'cpu' else 'fp16' - - if get_pretrained_url(name, 'openai'): - model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) - elif os.path.isfile(name): - model_path = name - else: - raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") - - try: - # loading JIT archive - model = torch.jit.load(model_path, map_location="cpu").eval() - state_dict = None - except RuntimeError: - # loading saved state dict - state_dict = torch.load(model_path, map_location="cpu") - - # Build a non-jit model from the OpenAI jitted model state dict - cast_dtype = get_cast_dtype(precision) - try: - model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) - except KeyError: - sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} - model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) - - # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use - model = model.to(device) - # FIXME support pure fp16/bf16 precision modes - if precision != 'fp16': - model.float() - if precision == 'bf16': - # for bf16, convert back to low-precision - convert_weights_to_lp(model, dtype=torch.bfloat16) - - # add mean / std attributes for consistency with OpenCLIP models - model.visual.image_mean = OPENAI_DATASET_MEAN - model.visual.image_std = OPENAI_DATASET_STD - return model diff --git a/src/proxyclip/open_clip_proxy/pos_embed.py b/src/proxyclip/open_clip_proxy/pos_embed.py deleted file mode 100644 index 5c8082b34df2318dd25a4ec8346b3f9a888f38de..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/pos_embed.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -------------------------------------------------------- -# Position embedding utils -# -------------------------------------------------------- - -import numpy as np - -import torch - -# -------------------------------------------------------- -# 2D sine-cosine position embedding -# References: -# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py -# MoCo v3: https://github.com/facebookresearch/moco-v3 -# -------------------------------------------------------- -def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - grid_h = np.arange(grid_size, dtype=np.float32) - grid_w = np.arange(grid_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - grid = grid.reshape([2, 1, grid_size, grid_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) - out: (M, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=float) - omega /= embed_dim / 2. - omega = 1. / 10000**omega # (D/2,) - - pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product - - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - return emb - - -# -------------------------------------------------------- -# Interpolate position embeddings for high-resolution -# References: -# DeiT: https://github.com/facebookresearch/deit -# -------------------------------------------------------- -def interpolate_pos_embed(model, checkpoint_model): - if 'pos_embed' in checkpoint_model: - pos_embed_checkpoint = checkpoint_model['pos_embed'] - embedding_size = pos_embed_checkpoint.shape[-1] - num_patches = model.patch_embed.num_patches - num_extra_tokens = model.pos_embed.shape[-2] - num_patches - # height (== width) for the checkpoint position embedding - orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) - # height (== width) for the new position embedding - new_size = int(num_patches ** 0.5) - # class_token and dist_token are kept unchanged - if orig_size != new_size: - print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) - extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] - # only the position tokens are interpolated - pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) - pos_tokens = torch.nn.functional.interpolate( - pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) - pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) - new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) - checkpoint_model['pos_embed'] = new_pos_embed diff --git a/src/proxyclip/open_clip_proxy/pretrained.py b/src/proxyclip/open_clip_proxy/pretrained.py deleted file mode 100644 index e7cd74fe19f52c3c7a604d3ea48fd559e0c991ed..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/pretrained.py +++ /dev/null @@ -1,584 +0,0 @@ -import hashlib -import os -import urllib -import warnings -from functools import partial -from typing import Dict, Union - -from tqdm import tqdm - -from .constants import ( - IMAGENET_MEAN, - IMAGENET_STD, - INCEPTION_MEAN, - INCEPTION_STD, - OPENAI_DATASET_MEAN, - OPENAI_DATASET_STD, -) -from .version import __version__ - -try: - from huggingface_hub import hf_hub_download - hf_hub_download = partial(hf_hub_download, library_name="open_clip", library_version=__version__) - _has_hf_hub = True -except ImportError: - hf_hub_download = None - _has_hf_hub = False - - -def _pcfg(url='', hf_hub='', **kwargs): - # OpenAI / OpenCLIP defaults - return { - 'url': url, - 'hf_hub': hf_hub, - 'mean': OPENAI_DATASET_MEAN, - 'std': OPENAI_DATASET_STD, - 'interpolation': 'bicubic', - 'resize_mode': 'shortest', - **kwargs, - } - - -def _slpcfg(url='', hf_hub='', **kwargs): - # SiGLIP defaults - return { - 'url': url, - 'hf_hub': hf_hub, - 'mean': INCEPTION_MEAN, - 'std': INCEPTION_STD, - 'interpolation': 'bicubic', - 'resize_mode': 'squash', - **kwargs, - } - - -def _apcfg(url='', hf_hub='', **kwargs): - # CLIPA defaults - return { - 'url': url, - 'hf_hub': hf_hub, - 'mean': IMAGENET_MEAN, - 'std': IMAGENET_STD, - 'interpolation': 'bilinear', - 'resize_mode': 'squash', - **kwargs, - } - - -_RN50 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), - yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), - cc12m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), -) - -_RN50_quickgelu = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt"), - yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt"), - cc12m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt"), -) - -_RN101 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), - yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), -) - -_RN101_quickgelu = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt"), - yfcc15m=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt"), -) - -_RN50x4 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt"), -) - -_RN50x16 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt"), -) - -_RN50x64 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt"), -) - -_VITB32 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), - laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), - laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), - laion2b_e16=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), - laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/'), - # DataComp-XL models - datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.XL-s13B-b90K/'), - # DataComp-M models - datacomp_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.M-s128M-b4K/'), - commonpool_m_clip_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.clip-s128M-b4K/'), - commonpool_m_laion_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.laion-s128M-b4K/'), - commonpool_m_image_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.image-s128M-b4K/'), - commonpool_m_text_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.text-s128M-b4K/'), - commonpool_m_basic_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M.basic-s128M-b4K/'), - commonpool_m_s128m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.M-s128M-b4K/'), - # DataComp-S models - datacomp_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-DataComp.S-s13M-b4K/'), - commonpool_s_clip_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.clip-s13M-b4K/'), - commonpool_s_laion_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.laion-s13M-b4K/'), - commonpool_s_image_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.image-s13M-b4K/'), - commonpool_s_text_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.text-s13M-b4K/'), - commonpool_s_basic_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S.basic-s13M-b4K/'), - commonpool_s_s13m_b4k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-CommonPool.S-s13M-b4K/'), -) - -_VITB32_quickgelu = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), - laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), - laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), - metaclip_400m=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_400m.pt"), - metaclip_fullcc=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/b32_fullcc2.5b.pt"), -) - -_VITB32_256 = dict( - datacomp_s34b_b86k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-256x256-DataComp-s34B-b86K/'), -) - -_VITB16 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), - laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), - laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), - laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), - # DataComp-XL models - datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.XL-s13B-b90K/'), - # DataComp-L models - datacomp_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-DataComp.L-s1B-b8K/'), - commonpool_l_clip_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.clip-s1B-b8K/'), - commonpool_l_laion_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.laion-s1B-b8K/'), - commonpool_l_image_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.image-s1B-b8K/'), - commonpool_l_text_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.text-s1B-b8K/'), - commonpool_l_basic_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L.basic-s1B-b8K/'), - commonpool_l_s1b_b8k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-CommonPool.L-s1B-b8K/'), - # DFN - dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-B-16/') -) - -_VITB16_quickgelu = dict( - metaclip_400m=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_400m.pt"), - metaclip_fullcc=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/b16_fullcc2.5b.pt"), -) - -_VITB16_PLUS_240 = dict( - laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), - laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), -) - -_VITL14 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), - laion400m_e31=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), - laion400m_e32=_pcfg( - "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), - laion2b_s32b_b82k=_pcfg( - hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', - mean=INCEPTION_MEAN, std=INCEPTION_STD), - # DataComp-XL models - datacomp_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K/'), - commonpool_xl_clip_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.clip-s13B-b90K/'), - commonpool_xl_laion_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL.laion-s13B-b90K/'), - commonpool_xl_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-L-14-CommonPool.XL-s13B-b90K/'), -) - -_VITL14_quickgelu = dict( - metaclip_400m=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_400m.pt"), - metaclip_fullcc=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/l14_fullcc2.5b.pt"), - dfn2b=_pcfg(hf_hub='apple/DFN2B-CLIP-ViT-L-14/'), -) - -_VITL14_336 = dict( - openai=_pcfg( - "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), -) - -_VITH14 = dict( - laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), -) - -_VITH14_quickgelu = dict( - metaclip_fullcc=_pcfg( - "https://dl.fbaipublicfiles.com/MMPT/metaclip/h14_fullcc2.5b.pt"), - dfn5b=_pcfg( - hf_hub='apple/DFN5B-CLIP-ViT-H-14/', - interpolation="bicubic", - resize_mode="squash" - ), -) - -_VITH14_378_quickgelu = dict( - dfn5b=_pcfg( - hf_hub='apple/DFN5B-CLIP-ViT-H-14-378/', - interpolation="bicubic", - resize_mode="squash" - ), -) - -_VITg14 = dict( - laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), - laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), -) - -_VITbigG14 = dict( - laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), -) - -_robertaViTB32 = dict( - laion2b_s12b_b32k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-roberta-base-laion2B-s12B-b32k/'), -) - -_xlmRobertaBaseViTB32 = dict( - laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-xlm-roberta-base-laion5B-s13B-b90k/'), -) - -_xlmRobertaLargeFrozenViTH14 = dict( - frozen_laion5b_s13b_b90k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k/'), -) - -_convnext_base = dict( - laion400m_s13b_b51k=_pcfg(hf_hub='laion/CLIP-convnext_base-laion400M-s13B-b51K/'), -) - -_convnext_base_w = dict( - laion2b_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K/'), - laion2b_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg/'), - laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K/'), -) - -_convnext_base_w_320 = dict( - laion_aesthetic_s13b_b82k=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K/'), - laion_aesthetic_s13b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg/'), -) - -_convnext_large_d = dict( - laion2b_s26b_b102k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg/'), -) - -_convnext_large_d_320 = dict( - laion2b_s29b_b131k_ft=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft/'), - laion2b_s29b_b131k_ft_soup=_pcfg(hf_hub='laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup/'), -) - -_convnext_xxlarge = dict( - laion2b_s34b_b82k_augreg=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg/'), - laion2b_s34b_b82k_augreg_rewind=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-rewind/'), - laion2b_s34b_b82k_augreg_soup=_pcfg(hf_hub='laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup/'), -) - -_coca_VITB32 = dict( - laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-B-32-laion2B-s13B-b90k/'), - mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-B-32-laion2B-s13B-b90k/') -) - -_coca_VITL14 = dict( - laion2b_s13b_b90k=_pcfg(hf_hub='laion/CoCa-ViT-L-14-laion2B-s13B-b90k/'), - mscoco_finetuned_laion2b_s13b_b90k=_pcfg(hf_hub='laion/mscoco_finetuned_CoCa-ViT-L-14-laion2B-s13B-b90k/') -) - - -_PRETRAINED = { - "RN50": _RN50, - "RN50-quickgelu": _RN50_quickgelu, - "RN101": _RN101, - "RN101-quickgelu": _RN101_quickgelu, - "RN50x4": _RN50x4, - "RN50x16": _RN50x16, - "RN50x64": _RN50x64, - - "ViT-B-32": _VITB32, - "ViT-B-32-256": _VITB32_256, - "ViT-B-32-quickgelu": _VITB32_quickgelu, - "ViT-B-16": _VITB16, - "ViT-B-16-quickgelu": _VITB16_quickgelu, - "ViT-B-16-plus-240": _VITB16_PLUS_240, - "ViT-L-14": _VITL14, - "ViT-L-14-quickgelu": _VITL14_quickgelu, - "ViT-L-14-336": _VITL14_336, - "ViT-H-14": _VITH14, - "ViT-H-14-quickgelu": _VITH14_quickgelu, - "ViT-H-14-378-quickgelu": _VITH14_378_quickgelu, - "ViT-g-14": _VITg14, - "ViT-bigG-14": _VITbigG14, - - "roberta-ViT-B-32": _robertaViTB32, - "xlm-roberta-base-ViT-B-32": _xlmRobertaBaseViTB32, - "xlm-roberta-large-ViT-H-14": _xlmRobertaLargeFrozenViTH14, - - "convnext_base": _convnext_base, - "convnext_base_w": _convnext_base_w, - "convnext_base_w_320": _convnext_base_w_320, - "convnext_large_d": _convnext_large_d, - "convnext_large_d_320": _convnext_large_d_320, - "convnext_xxlarge": _convnext_xxlarge, - - "coca_ViT-B-32": _coca_VITB32, - "coca_ViT-L-14": _coca_VITL14, - - "EVA01-g-14": dict( - # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt - laion400m_s11b_b41k=_pcfg(hf_hub='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k/'), - ), - "EVA01-g-14-plus": dict( - # from QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt - merged2b_s11b_b114k=_pcfg(hf_hub='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k/'), - ), - "EVA02-B-16": dict( - # from QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt - merged2b_s8b_b131k=_pcfg(hf_hub='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k/'), - ), - "EVA02-L-14": dict( - # from QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt - merged2b_s4b_b131k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k/'), - ), - "EVA02-L-14-336": dict( - # from QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt - merged2b_s6b_b61k=_pcfg(hf_hub='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k/'), - ), - "EVA02-E-14": dict( - # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt - laion2b_s4b_b115k=_pcfg(hf_hub='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k/'), - ), - "EVA02-E-14-plus": dict( - # from QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt - laion2b_s9b_b144k=_pcfg(hf_hub='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k/'), - ), - - "ViT-B-16-SigLIP": dict( - webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP/'), - ), - "ViT-B-16-SigLIP-256": dict( - webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-256/'), - ), - "ViT-B-16-SigLIP-i18n-256": dict( - webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-i18n-256/'), - ), - "ViT-B-16-SigLIP-384": dict( - webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-384/'), - ), - "ViT-B-16-SigLIP-512": dict( - webli=_slpcfg(hf_hub='timm/ViT-B-16-SigLIP-512/'), - ), - "ViT-L-16-SigLIP-256": dict( - webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-256/'), - ), - "ViT-L-16-SigLIP-384": dict( - webli=_slpcfg(hf_hub='timm/ViT-L-16-SigLIP-384/'), - ), - "ViT-SO400M-14-SigLIP": dict( - webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP/'), - ), - "ViT-SO400M-14-SigLIP-384": dict( - webli=_slpcfg(hf_hub='timm/ViT-SO400M-14-SigLIP-384/'), - ), - - "ViT-L-14-CLIPA": dict( - datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-datacomp1B/'), - ), - "ViT-L-14-CLIPA-336": dict( - datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-L-14-CLIPA-336-datacomp1B/'), - ), - "ViT-H-14-CLIPA": dict( - datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-datacomp1B/'), - ), - "ViT-H-14-CLIPA-336": dict( - laion2b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-laion2B/'), - datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-H-14-CLIPA-336-datacomp1B/'), - ), - "ViT-bigG-14-CLIPA": dict( - datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-datacomp1B/'), - ), - "ViT-bigG-14-CLIPA-336": dict( - datacomp1b=_apcfg(hf_hub='UCSC-VLAA/ViT-bigG-14-CLIPA-336-datacomp1B/'), - ), - - "nllb-clip-base": dict( - v1=_pcfg(hf_hub='visheratin/nllb-clip-base-oc/'), - ), - "nllb-clip-large": dict( - v1=_pcfg(hf_hub='visheratin/nllb-clip-large-oc/'), - ), - - "nllb-clip-base-siglip": dict( - v1=_slpcfg(hf_hub='visheratin/nllb-clip-base-siglip/'), - ), - "nllb-clip-large-siglip": dict( - v1=_slpcfg(hf_hub='visheratin/nllb-clip-large-siglip/'), - ) -} - - -def _clean_tag(tag: str): - # normalize pretrained tags - return tag.lower().replace('-', '_') - - -def list_pretrained(as_str: bool = False): - """ returns list of pretrained models - Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True - """ - return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] - - -def list_pretrained_models_by_tag(tag: str): - """ return all models having the specified pretrain tag """ - models = [] - tag = _clean_tag(tag) - for k in _PRETRAINED.keys(): - if tag in _PRETRAINED[k]: - models.append(k) - return models - - -def list_pretrained_tags_by_model(model: str): - """ return all pretrain tags for the specified model architecture """ - tags = [] - if model in _PRETRAINED: - tags.extend(_PRETRAINED[model].keys()) - return tags - - -def is_pretrained_cfg(model: str, tag: str): - if model not in _PRETRAINED: - return False - return _clean_tag(tag) in _PRETRAINED[model] - - -def get_pretrained_cfg(model: str, tag: str): - if model not in _PRETRAINED: - return {} - model_pretrained = _PRETRAINED[model] - return model_pretrained.get(_clean_tag(tag), {}) - - -def get_pretrained_url(model: str, tag: str): - cfg = get_pretrained_cfg(model, _clean_tag(tag)) - return cfg.get('url', '') - - -def download_pretrained_from_url( - url: str, - cache_dir: Union[str, None] = None, -): - if not cache_dir: - cache_dir = os.path.expanduser("~/.cache/clip") - os.makedirs(cache_dir, exist_ok=True) - filename = os.path.basename(url) - - if 'openaipublic' in url: - expected_sha256 = url.split("/")[-2] - elif 'mlfoundations' in url: - expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] - else: - expected_sha256 = '' - - download_target = os.path.join(cache_dir, filename) - - if os.path.exists(download_target) and not os.path.isfile(download_target): - raise RuntimeError(f"{download_target} exists and is not a regular file") - - if os.path.isfile(download_target): - if expected_sha256: - if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): - return download_target - else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") - else: - return download_target - - with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): - raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") - - return download_target - - -def has_hf_hub(necessary=False): - if not _has_hf_hub and necessary: - # if no HF Hub module installed, and it is necessary to continue, raise error - raise RuntimeError( - 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.') - return _has_hf_hub - - -def download_pretrained_from_hf( - model_id: str, - filename: str = 'open_clip_pytorch_model.bin', - revision=None, - cache_dir: Union[str, None] = None, -): - has_hf_hub(True) - cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) - return cached_file - - -def download_pretrained( - cfg: Dict, - force_hf_hub: bool = False, - cache_dir: Union[str, None] = None, -): - target = '' - if not cfg: - return target - - download_url = cfg.get('url', '') - download_hf_hub = cfg.get('hf_hub', '') - if download_hf_hub and force_hf_hub: - # use HF hub even if url exists - download_url = '' - - if download_url: - target = download_pretrained_from_url(download_url, cache_dir=cache_dir) - elif download_hf_hub: - has_hf_hub(True) - # we assume the hf_hub entries in pretrained config combine model_id + filename in - # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and - # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. - model_id, filename = os.path.split(download_hf_hub) - if filename: - target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) - else: - target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) - - return target diff --git a/src/proxyclip/open_clip_proxy/push_to_hf_hub.py b/src/proxyclip/open_clip_proxy/push_to_hf_hub.py deleted file mode 100644 index dcb8a78b587a585dcf3e3518d66cc00b371e4a82..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/push_to_hf_hub.py +++ /dev/null @@ -1,317 +0,0 @@ -import argparse -import json -import os -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Optional, Tuple, Union - -import torch - -try: - from huggingface_hub import ( - create_repo, - get_hf_file_metadata, - hf_hub_download, - hf_hub_url, - repo_type_and_id_from_hf_id, - upload_folder, - list_repo_files, - ) - from huggingface_hub.utils import EntryNotFoundError - _has_hf_hub = True -except ImportError: - _has_hf_hub = False - -try: - import safetensors.torch - _has_safetensors = True -except ImportError: - _has_safetensors = False - -from .factory import create_model_from_pretrained, get_model_config, get_tokenizer -from .tokenizer import HFTokenizer - -# Default name for a weights file hosted on the Huggingface Hub. -HF_WEIGHTS_NAME = "open_clip_pytorch_model.bin" # default pytorch pkl -HF_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors" # safetensors version -HF_CONFIG_NAME = 'open_clip_config.json' - - -def save_config_for_hf( - model, - config_path: str, - model_config: Optional[dict] -): - preprocess_cfg = { - 'mean': model.visual.image_mean, - 'std': model.visual.image_std, - } - other_pp = getattr(model.visual, 'preprocess_cfg', {}) - if 'interpolation' in other_pp: - preprocess_cfg['interpolation'] = other_pp['interpolation'] - if 'resize_mode' in other_pp: - preprocess_cfg['resize_mode'] = other_pp['resize_mode'] - hf_config = { - 'model_cfg': model_config, - 'preprocess_cfg': preprocess_cfg, - } - - with config_path.open('w') as f: - json.dump(hf_config, f, indent=2) - - -def save_for_hf( - model, - tokenizer: HFTokenizer, - model_config: dict, - save_directory: str, - safe_serialization: Union[bool, str] = 'both', - skip_weights : bool = False, -): - config_filename = HF_CONFIG_NAME - - save_directory = Path(save_directory) - save_directory.mkdir(exist_ok=True, parents=True) - - if not skip_weights: - tensors = model.state_dict() - if safe_serialization is True or safe_serialization == "both": - assert _has_safetensors, "`pip install safetensors` to use .safetensors" - safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) - if safe_serialization is False or safe_serialization == "both": - torch.save(tensors, save_directory / HF_WEIGHTS_NAME) - - tokenizer.save_pretrained(save_directory) - - config_path = save_directory / config_filename - save_config_for_hf(model, config_path, model_config=model_config) - - -def push_to_hf_hub( - model, - tokenizer, - model_config: Optional[dict], - repo_id: str, - commit_message: str = 'Add model', - token: Optional[str] = None, - revision: Optional[str] = None, - private: bool = False, - create_pr: bool = False, - model_card: Optional[dict] = None, - safe_serialization: Union[bool, str] = False, -): - if not isinstance(tokenizer, HFTokenizer): - # FIXME this makes it awkward to push models with new tokenizers, come up with better soln. - # default CLIP tokenizers use https://huggingface.co/openai/clip-vit-large-patch14 - tokenizer = HFTokenizer('openai/clip-vit-large-patch14') - - # Create repo if it doesn't exist yet - repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) - - # Infer complete repo_id from repo_url - # Can be different from the input `repo_id` if repo_owner was implicit - _, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url) - repo_id = f"{repo_owner}/{repo_name}" - - # Check if repo already exists and determine what needs updating - repo_exists = False - repo_files = {} - try: - repo_files = set(list_repo_files(repo_id)) - repo_exists = True - except Exception as e: - print('Repo does not exist', e) - - try: - get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision)) - has_readme = True - except EntryNotFoundError: - has_readme = False - - # Dump model and push to Hub - with TemporaryDirectory() as tmpdir: - # Save model weights and config. - save_for_hf( - model, - tokenizer=tokenizer, - model_config=model_config, - save_directory=tmpdir, - safe_serialization=safe_serialization, - ) - - # Add readme if it does not exist - if not has_readme: - model_card = model_card or {} - model_name = repo_id.split('/')[-1] - readme_path = Path(tmpdir) / "README.md" - readme_text = generate_readme(model_card, model_name) - readme_path.write_text(readme_text) - - # Upload model and return - return upload_folder( - repo_id=repo_id, - folder_path=tmpdir, - revision=revision, - create_pr=create_pr, - commit_message=commit_message, - ) - - -def push_pretrained_to_hf_hub( - model_name, - pretrained: str, - repo_id: str, - precision: str = 'fp32', - image_mean: Optional[Tuple[float, ...]] = None, - image_std: Optional[Tuple[float, ...]] = None, - image_interpolation: Optional[str] = None, - image_resize_mode: Optional[str] = None, # only effective for inference - commit_message: str = 'Add model', - token: Optional[str] = None, - revision: Optional[str] = None, - private: bool = False, - create_pr: bool = False, - model_card: Optional[dict] = None, - hf_tokenizer_self: bool = False, -): - model, preprocess_eval = create_model_from_pretrained( - model_name, - pretrained=pretrained, - precision=precision, - image_mean=image_mean, - image_std=image_std, - image_interpolation=image_interpolation, - image_resize_mode=image_resize_mode, - ) - model_config = get_model_config(model_name) - assert model_config - - tokenizer = get_tokenizer(model_name) - if hf_tokenizer_self: - # make hf tokenizer config in the uploaded model point to self instead of original location - model_config['text']['hf_tokenizer_name'] = repo_id - - push_to_hf_hub( - model=model, - tokenizer=tokenizer, - model_config=model_config, - repo_id=repo_id, - commit_message=commit_message, - token=token, - revision=revision, - private=private, - create_pr=create_pr, - model_card=model_card, - safe_serialization='both', - ) - - -def generate_readme(model_card: dict, model_name: str): - tags = model_card.pop('tags', ('clip',)) - pipeline_tag = model_card.pop('pipeline_tag', 'zero-shot-image-classification') - readme_text = "---\n" - if tags: - readme_text += "tags:\n" - for t in tags: - readme_text += f"- {t}\n" - readme_text += "library_name: open_clip\n" - readme_text += f"pipeline_tag: {pipeline_tag}\n" - readme_text += f"license: {model_card.get('license', 'mit')}\n" - if 'details' in model_card and 'Dataset' in model_card['details']: - readme_text += 'datasets:\n' - readme_text += f"- {model_card['details']['Dataset'].lower()}\n" - readme_text += "---\n" - readme_text += f"# Model card for {model_name}\n" - if 'description' in model_card: - readme_text += f"\n{model_card['description']}\n" - if 'details' in model_card: - readme_text += f"\n## Model Details\n" - for k, v in model_card['details'].items(): - if isinstance(v, (list, tuple)): - readme_text += f"- **{k}:**\n" - for vi in v: - readme_text += f" - {vi}\n" - elif isinstance(v, dict): - readme_text += f"- **{k}:**\n" - for ki, vi in v.items(): - readme_text += f" - {ki}: {vi}\n" - else: - readme_text += f"- **{k}:** {v}\n" - if 'usage' in model_card: - readme_text += f"\n## Model Usage\n" - readme_text += model_card['usage'] - readme_text += '\n' - - if 'comparison' in model_card: - readme_text += f"\n## Model Comparison\n" - readme_text += model_card['comparison'] - readme_text += '\n' - - if 'citation' in model_card: - readme_text += f"\n## Citation\n" - if not isinstance(model_card['citation'], (list, tuple)): - citations = [model_card['citation']] - else: - citations = model_card['citation'] - for c in citations: - readme_text += f"```bibtex\n{c}\n```\n" - - return readme_text - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Push to Hugging Face Hub") - parser.add_argument( - "--model", type=str, help="Name of the model to use.", - ) - parser.add_argument( - "--pretrained", type=str, - help="Use a pretrained CLIP model weights with the specified tag or file path.", - ) - parser.add_argument( - "--repo-id", type=str, - help="Destination HF Hub repo-id ie 'organization/model_id'.", - ) - parser.add_argument( - "--precision", type=str, default='fp32', - ) - parser.add_argument( - '--image-mean', type=float, nargs='+', default=None, metavar='MEAN', - help='Override default image mean value of dataset') - parser.add_argument( - '--image-std', type=float, nargs='+', default=None, metavar='STD', - help='Override default image std deviation of of dataset') - parser.add_argument( - '--image-interpolation', - default=None, type=str, choices=['bicubic', 'bilinear', 'random'], - help="image resize interpolation" - ) - parser.add_argument( - '--image-resize-mode', - default=None, type=str, choices=['shortest', 'longest', 'squash'], - help="image resize mode during inference" - ) - parser.add_argument( - "--hf-tokenizer-self", - default=False, - action="store_true", - help="make hf_tokenizer_name point in uploaded config point to itself" - ) - args = parser.parse_args() - - print(f'Saving model {args.model} with pretrained weights {args.pretrained} to Hugging Face Hub at {args.repo_id}') - - # FIXME add support to pass model_card json / template from file via cmd line - - push_pretrained_to_hf_hub( - args.model, - args.pretrained, - args.repo_id, - precision=args.precision, - image_mean=args.image_mean, # override image mean/std if trained w/ non defaults - image_std=args.image_std, - image_interpolation=args.image_interpolation, - image_resize_mode=args.image_resize_mode, - ) - - print(f'{args.model} saved.') diff --git a/src/proxyclip/open_clip_proxy/timm_model.py b/src/proxyclip/open_clip_proxy/timm_model.py deleted file mode 100644 index 5ddb9a76bf085feeb8c20f3a39a6cfa4c2b643b4..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/timm_model.py +++ /dev/null @@ -1,152 +0,0 @@ -""" timm model adapter - -Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. -""" -import logging -from collections import OrderedDict - -import torch -import torch.nn as nn - -try: - import timm - from timm.models.layers import Mlp, to_2tuple - try: - # old timm imports < 0.8.1 - from timm.models.layers.attention_pool2d import RotAttentionPool2d - from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d - except ImportError: - # new timm imports >= 0.8.1 - from timm.layers import RotAttentionPool2d - from timm.layers import AttentionPool2d as AbsAttentionPool2d -except ImportError: - timm = None - -from .utils import freeze_batch_norm_2d - - -class TimmModel(nn.Module): - """ timm model adapter - """ - - def __init__( - self, - model_name, - embed_dim, - image_size=224, - pool='avg', - proj='linear', - proj_bias=False, - drop=0., - drop_path=None, - patch_drop=None, - pretrained=False, - ): - super().__init__() - if timm is None: - raise RuntimeError("Please `pip install timm` to use timm models.") - self.image_size = to_2tuple(image_size) - - # setup kwargs that may not be common across all models - timm_kwargs = {} - if drop_path is not None: - timm_kwargs['drop_path_rate'] = drop_path - if patch_drop is not None: - timm_kwargs['patch_drop_rate'] = patch_drop - - custom_pool = pool in ('abs_attn', 'rot_attn') - if proj: - assert proj in ("linear", "mlp", "none") - extra_proj = proj in ("linear", "mlp") - if not extra_proj and not custom_pool: - # use network classifier head as projection if no proj specified and no custom pooling used - # if projection is explicitly set to "none" will be pass through from network trunk - proj_dim = 0 if proj == 'none' else embed_dim - self.trunk = timm.create_model( - model_name, - num_classes=proj_dim, - global_pool=pool, - pretrained=pretrained, - **timm_kwargs, - ) - prev_chs = embed_dim - else: - self.trunk = timm.create_model( - model_name, - pretrained=pretrained, - **timm_kwargs, - ) - feat_size = self.trunk.default_cfg.get('pool_size', None) - feature_ndim = 1 if not feat_size else 2 - if custom_pool: - assert feature_ndim == 2 - # if attn pooling used, remove both classifier and default pool - self.trunk.reset_classifier(0, global_pool='') - else: - # reset global pool if pool config set, otherwise leave as network default - reset_kwargs = dict(global_pool=pool) if pool else {} - self.trunk.reset_classifier(0, **reset_kwargs) - prev_chs = self.trunk.num_features - - head_layers = OrderedDict() - - # Add custom pooling to head - if pool == 'abs_attn': - head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) - prev_chs = embed_dim - elif pool == 'rot_attn': - head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) - prev_chs = embed_dim - - # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used - if proj == 'linear': - head_layers['drop'] = nn.Dropout(drop) - head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) - elif proj == 'mlp': - head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=(drop, 0), bias=(True, proj_bias)) - - self.head = nn.Sequential(head_layers) - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - """ lock modules - Args: - unlocked_groups (int): leave last n layer groups unlocked (default: 0) - """ - if not unlocked_groups: - # lock full model - for param in self.trunk.parameters(): - param.requires_grad = False - if freeze_bn_stats: - freeze_batch_norm_2d(self.trunk) - else: - # NOTE: partial freeze requires latest timm (master) branch and is subject to change - try: - # FIXME import here until API stable and in an official release - from timm.models.helpers import group_parameters, group_modules - except ImportError: - raise RuntimeError( - 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') - matcher = self.trunk.group_matcher() - gparams = group_parameters(self.trunk, matcher) - max_layer_id = max(gparams.keys()) - max_layer_id = max_layer_id - unlocked_groups - for group_idx in range(max_layer_id + 1): - group = gparams[group_idx] - for param in group: - self.trunk.get_parameter(param).requires_grad = False - if freeze_bn_stats: - gmodules = group_modules(self.trunk, matcher, reverse=True) - gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} - freeze_batch_norm_2d(self.trunk, gmodules) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - try: - self.trunk.set_grad_checkpointing(enable) - except Exception as e: - logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') - - def forward(self, x): - x = self.trunk(x) - x = self.head(x) - return x diff --git a/src/proxyclip/open_clip_proxy/tokenizer.py b/src/proxyclip/open_clip_proxy/tokenizer.py deleted file mode 100644 index 19fbab06b066d7b9a4c39097a017c3da12ecd95a..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/tokenizer.py +++ /dev/null @@ -1,510 +0,0 @@ -""" CLIP tokenizer - -Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. -""" -import gzip -import html -import os -import random -import string -from functools import lru_cache, partial -from typing import Callable, List, Optional, Union -import warnings - -import ftfy -import numpy as np -import regex as re -import torch - -# https://stackoverflow.com/q/62691279 -os.environ["TOKENIZERS_PARALLELISM"] = "false" -_nltk_init = False - -DEFAULT_CONTEXT_LENGTH = 77 # default context length for OpenAI CLIP - - -@lru_cache() -def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a significant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - - -def _clean_canonicalize(x): - # basic, remove whitespace, remove punctuation, lower case - return canonicalize_text(basic_clean(x)) - - -def _clean_lower(x): - # basic, remove whitespace, lower case - return whitespace_clean(basic_clean(x)).lower() - - -def _clean_whitespace(x): - # basic, remove whitespace - return whitespace_clean(basic_clean(x)) - - -def get_clean_fn(type: str): - if type == 'canonicalize': - return _clean_canonicalize - elif type == 'lower': - return _clean_lower - elif type == 'whitespace': - return _clean_whitespace - else: - assert False, f"Invalid clean function ({type})." - - -def canonicalize_text(text, *, keep_punctuation_exact_string=None): - """Returns canonicalized `text` (lowercase and punctuation removed). - - From: https://github.com/google-research/big_vision/blob/53f18caf27a9419231bbf08d3388b07671616d3d/big_vision/evaluators/proj/image_text/prompt_engineering.py#L94 - - Args: - text: string to be canonicalized. - keep_punctuation_exact_string: If provided, then this exact string kept. - For example providing '{}' will keep any occurrences of '{}' (but will - still remove '{' and '}' that appear separately). - """ - text = text.replace("_", " ") - if keep_punctuation_exact_string: - text = keep_punctuation_exact_string.join( - part.translate(str.maketrans("", "", string.punctuation)) - for part in text.split(keep_punctuation_exact_string)) - else: - text = text.translate(str.maketrans("", "", string.punctuation)) - text = text.lower() - text = re.sub(r"\s+", " ", text) - return text.strip() - - -class SimpleTokenizer(object): - def __init__( - self, - bpe_path: str = default_bpe(), - additional_special_tokens: Optional[List[str]] = None, - context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, - clean: str = 'lower', - reduction_mask: str = '' - ): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - special_tokens = ['', ''] - if additional_special_tokens: - special_tokens += additional_special_tokens - vocab.extend(special_tokens) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {t:t for t in special_tokens} - special = "|".join(special_tokens) - self.pat = re.compile( - special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", - re.IGNORECASE, - ) - self.vocab_size = len(self.encoder) - self.all_special_ids = [self.encoder[t] for t in special_tokens] - self.sot_token_id = self.all_special_ids[0] - self.eot_token_id = self.all_special_ids[1] - self.context_length = context_length - self.clean_fn = get_clean_fn(clean) - self.reduction_fn = get_reduction_mask_fn(reduction_mask) if reduction_mask else None - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = self.clean_fn(text) - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - return bpe_tokens - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') - return text - - def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.LongTensor: - """ Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - context_length = context_length or self.context_length - assert context_length, 'Please set a valid context length' - - if self.reduction_fn is not None: - # use reduction strategy for tokenize if set, otherwise default to truncation below - return self.reduction_fn( - texts, - context_length=context_length, - sot_token_id=self.sot_token_id, - eot_token_id=self.eot_token_id, - encode_fn=self.encode, - ) - - all_tokens = [[self.sot_token_id] + self.encode(text) + [self.eot_token_id] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - tokens = tokens[:context_length] # Truncate - tokens[-1] = self.eot_token_id - result[i, :len(tokens)] = torch.tensor(tokens) - - return result - - -_tokenizer = SimpleTokenizer() - - -def decode(output_ids: torch.Tensor): - output_ids = output_ids.cpu().numpy() - return _tokenizer.decode(output_ids) - - -def tokenize(texts: Union[str, List[str]], context_length: int = DEFAULT_CONTEXT_LENGTH) -> torch.LongTensor: - return _tokenizer(texts, context_length=context_length) - - -def random_mask_tokenize( - texts: Union[str, List[str]], - context_length: int, - sot_token_id: int, - eot_token_id: int, - encode_fn: Callable, - shuffle: bool = False, -): - all_tokens = [encode_fn(text) for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - tokens = torch.tensor(tokens) - num_tokens = len(tokens) - if num_tokens > context_length - 2: # 2 for sot and eot token - num_keep = context_length - 2 - indices = torch.randperm(len(tokens)) - indices = indices[:num_keep] - if not shuffle: - indices = indices.msort() - tokens = tokens[indices] - num_tokens = num_keep - result[i, 0] = sot_token_id - result[i, 1:num_tokens + 1] = tokens - result[i, num_tokens + 1] = eot_token_id - - return result - - -def simple_mask_tokenize( - texts: Union[str, List[str]], - context_length: int, - sot_token_id: int, - eot_token_id: int, - encode_fn: Callable, -): - all_tokens = [encode_fn(text) for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - num_tokens = len(tokens) - if num_tokens > context_length - 2: # 2 for sot and eot token - num_keep = context_length - 2 - start_index = random.randint(0, num_tokens - num_keep) # high is incl - tokens = tokens[start_index: start_index + num_keep] - tokens = [sot_token_id] + tokens + [eot_token_id] - result[i, :len(tokens)] = torch.tensor(tokens) - - return result - - -def syntax_mask_tokenize( - texts: Union[str, List[str]], - context_length: int, - sot_token_id: int, - eot_token_id: int, - encode_fn: Callable, -) -> torch.LongTensor: - """ Returns the tokenized representation of given input string(s). - Apply syntax masking before tokenize. - """ - import nltk - global _nltk_init - if not _nltk_init: - # run them for the first time - nltk.download('punkt') - nltk.download('averaged_perceptron_tagger') - _nltk_init = True - - def get_order(x): - if x.startswith('NN'): - return 1 - elif x.startswith('JJ'): - return 2 - elif x.startswith('VB'): - return 3 - else: - return 4 - - # syntax masking - new_texts = [] - for text in texts: - list_tokens = nltk.tokenize.word_tokenize(text) - pos_tags = nltk.pos_tag(list_tokens) - # sample the words by get_order method - order_list = [get_order(tag) for _, tag in pos_tags] - sorted_ids = np.argsort(np.array(order_list)) - sampled_ids = sorted(sorted_ids[:context_length - 2]) # need 2 slots for sot and eot tokens - sampled_tokens = np.take(np.array(list_tokens), sampled_ids, axis=0) # sample the tokens - - new_text = '' - for token in sampled_tokens: - new_text = new_text + str(token) + ' ' - new_text = new_text.strip() - new_texts.append(new_text) - texts = new_texts - - all_tokens = [[sot_token_id] + encode_fn(text) + [eot_token_id] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - # still need first truncate because some words produces two tokens - if len(tokens) > context_length: - tokens = tokens[:context_length] # Truncate - tokens[-1] = eot_token_id - result[i, :len(tokens)] = torch.tensor(tokens) - - return result - - -def get_reduction_mask_fn(type: str): - """ Choose strategy for dropping (masking) tokens to achieve target context length""" - assert type in ('simple', 'random', 'shuffle', 'syntax') - if type == 'simple': - return simple_mask_tokenize # randomly select block [start:end] - elif type == 'random': - return random_mask_tokenize # randomly drop tokens (keep order) - elif type == 'shuffle': - return partial(random_mask_tokenize, shuffle=True) # randomly drop tokens (shuffle order) - elif type == 'syntax': - return syntax_mask_tokenize # randomly drop prioritized by syntax - - -class HFTokenizer: - """HuggingFace tokenizer wrapper""" - - def __init__( - self, - tokenizer_name: str, - context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH, - clean: str = 'whitespace', - strip_sep_token: bool = False, - language: Optional[str] = None, - ): - from transformers import AutoTokenizer - self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - set_lang_fn = getattr(self.tokenizer, 'set_src_lang_special_tokens', None) - if callable(set_lang_fn): - self.set_lang_fn = set_lang_fn - if language is not None: - self.set_language(language) - self.context_length = context_length - self.clean_fn = get_clean_fn(clean) - self.strip_sep_token = strip_sep_token - - def save_pretrained(self, dest): - self.tokenizer.save_pretrained(dest) - - def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: - # same cleaning as for default tokenizer, except lowercasing - # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance - if isinstance(texts, str): - texts = [texts] - - context_length = context_length or self.context_length - assert context_length, 'Please set a valid context length in class init or call.' - - texts = [self.clean_fn(text) for text in texts] - input_ids = self.tokenizer.batch_encode_plus( - texts, - return_tensors='pt', - max_length=context_length, - padding='max_length', - truncation=True, - ).input_ids - - if self.strip_sep_token: - input_ids = torch.where( - input_ids == self.tokenizer.sep_token_id, - torch.zeros_like(input_ids), - input_ids, - ) - - return input_ids - - def set_language(self, src_lang): - if hasattr(self, 'set_lang_fn'): - self.set_lang_fn(src_lang) - else: - warnings.warn('Cannot set language for the tokenizer.') - - -class SigLipTokenizer: - """HuggingFace tokenizer wrapper for SigLIP T5 compatible sentencepiece vocabs - """ - VOCAB_FILES = { - # english, vocab_size=32_000 - "c4-en": "http://storage.googleapis.com/t5-data/vocabs/cc_en.32000/sentencepiece.model", - # used in multilingual models (mT5, PaLI), vocab_size=250_000 - "mc4": "http://storage.googleapis.com/t5-data/vocabs/mc4.250000.100extra/sentencepiece.model", - } - - def __init__( - self, - tokenizer_name: str, - context_length: Optional[int] = 64, - ): - from transformers import T5TokenizerFast - - if tokenizer_name in self.VOCAB_FILES: - # FIXME temporary hack? - import tempfile - - import fsspec - vocab_file = self.VOCAB_FILES[tokenizer_name] - with tempfile.NamedTemporaryFile('wb') as dst: - with fsspec.open(vocab_file, 'rb') as src: - dst.write(src.read()) - self.tokenizer = T5TokenizerFast(dst.name, legacy=False) - else: - self.tokenizer = T5TokenizerFast(tokenizer_name, legacy=False) - - self.tokenizer.pad_token_id = 1 - self.tokenizer.eos_token_id = 1 - self.context_length = context_length - - def save_pretrained(self, dest): - self.tokenizer.save_pretrained(dest) - - def __call__(self, texts: Union[str, List[str]], context_length: Optional[int] = None) -> torch.Tensor: - # same cleaning as for default tokenizer, except lowercasing - # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance - if isinstance(texts, str): - texts = [texts] - - context_length = context_length or self.context_length - assert context_length, 'Please set a valid context length in class init or call.' - - texts = [canonicalize_text(basic_clean(text)) for text in texts] - output = self.tokenizer( - texts, - return_tensors='pt', - max_length=context_length, - padding='max_length', - truncation=True, - ) - return output.input_ids diff --git a/src/proxyclip/open_clip_proxy/transform.py b/src/proxyclip/open_clip_proxy/transform.py deleted file mode 100644 index 521a203e3136587f7601325e09c244fc69238cfd..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/transform.py +++ /dev/null @@ -1,407 +0,0 @@ -import numbers -import random -import warnings -from dataclasses import dataclass, asdict -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union - -import torch -import torchvision.transforms.functional as F -from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ - CenterCrop, ColorJitter, Grayscale - -from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from .utils import to_2tuple - - -@dataclass -class PreprocessCfg: - size: Union[int, Tuple[int, int]] = 224 - mode: str = 'RGB' - mean: Tuple[float, ...] = OPENAI_DATASET_MEAN - std: Tuple[float, ...] = OPENAI_DATASET_STD - interpolation: str = 'bicubic' - resize_mode: str = 'shortest' - fill_color: int = 0 - - def __post_init__(self): - assert self.mode in ('RGB',) - - @property - def num_channels(self): - return 3 - - @property - def input_size(self): - return (self.num_channels,) + to_2tuple(self.size) - -_PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys()) - - -def merge_preprocess_dict( - base: Union[PreprocessCfg, Dict], - overlay: Dict, -): - """ Merge overlay key-value pairs on top of base preprocess cfg or dict. - Input dicts are filtered based on PreprocessCfg fields. - """ - if isinstance(base, PreprocessCfg): - base_clean = asdict(base) - else: - base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS} - if overlay: - overlay_clean = {k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None} - base_clean.update(overlay_clean) - return base_clean - - -def merge_preprocess_kwargs(base: PreprocessCfg, **kwargs): - return merge_preprocess_dict(base, kwargs) - - -@dataclass -class AugmentationCfg: - scale: Tuple[float, float] = (0.9, 1.0) - ratio: Optional[Tuple[float, float]] = None - color_jitter: Optional[Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]] = None - re_prob: Optional[float] = None - re_count: Optional[int] = None - use_timm: bool = False - - # params for simclr_jitter_gray - color_jitter_prob: float = None - gray_scale_prob: float = None - - -def _setup_size(size, error_msg): - if isinstance(size, numbers.Number): - return int(size), int(size) - - if isinstance(size, Sequence) and len(size) == 1: - return size[0], size[0] - - if len(size) != 2: - raise ValueError(error_msg) - - return size - - -class ResizeKeepRatio: - """ Resize and Keep Ratio - - Copy & paste from `timm` - """ - - def __init__( - self, - size, - longest=0., - interpolation=InterpolationMode.BICUBIC, - random_scale_prob=0., - random_scale_range=(0.85, 1.05), - random_aspect_prob=0., - random_aspect_range=(0.9, 1.11) - ): - if isinstance(size, (list, tuple)): - self.size = tuple(size) - else: - self.size = (size, size) - self.interpolation = interpolation - self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest - self.random_scale_prob = random_scale_prob - self.random_scale_range = random_scale_range - self.random_aspect_prob = random_aspect_prob - self.random_aspect_range = random_aspect_range - - @staticmethod - def get_params( - img, - target_size, - longest, - random_scale_prob=0., - random_scale_range=(0.85, 1.05), - random_aspect_prob=0., - random_aspect_range=(0.9, 1.11) - ): - """Get parameters - """ - source_size = img.size[::-1] # h, w - h, w = source_size - target_h, target_w = target_size - ratio_h = h / target_h - ratio_w = w / target_w - ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest) - if random_scale_prob > 0 and random.random() < random_scale_prob: - ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1]) - ratio_factor = (ratio_factor, ratio_factor) - else: - ratio_factor = (1., 1.) - if random_aspect_prob > 0 and random.random() < random_aspect_prob: - aspect_factor = random.uniform(random_aspect_range[0], random_aspect_range[1]) - ratio_factor = (ratio_factor[0] / aspect_factor, ratio_factor[1] * aspect_factor) - size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)] - return size - - def __call__(self, img): - """ - Args: - img (PIL Image): Image to be cropped and resized. - - Returns: - PIL Image: Resized, padded to at least target size, possibly cropped to exactly target size - """ - size = self.get_params( - img, self.size, self.longest, - self.random_scale_prob, self.random_scale_range, - self.random_aspect_prob, self.random_aspect_range - ) - img = F.resize(img, size, self.interpolation) - return img - - def __repr__(self): - format_string = self.__class__.__name__ + '(size={0}'.format(self.size) - format_string += f', interpolation={self.interpolation})' - format_string += f', longest={self.longest:.3f})' - return format_string - - -def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor: - """Center crops and/or pads the given image. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. - - Args: - img (PIL Image or Tensor): Image to be cropped. - output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int, - it is used for both directions. - fill (int, Tuple[int]): Padding color - - Returns: - PIL Image or Tensor: Cropped image. - """ - if isinstance(output_size, numbers.Number): - output_size = (int(output_size), int(output_size)) - elif isinstance(output_size, (tuple, list)) and len(output_size) == 1: - output_size = (output_size[0], output_size[0]) - - _, image_height, image_width = F.get_dimensions(img) - crop_height, crop_width = output_size - - if crop_width > image_width or crop_height > image_height: - padding_ltrb = [ - (crop_width - image_width) // 2 if crop_width > image_width else 0, - (crop_height - image_height) // 2 if crop_height > image_height else 0, - (crop_width - image_width + 1) // 2 if crop_width > image_width else 0, - (crop_height - image_height + 1) // 2 if crop_height > image_height else 0, - ] - img = F.pad(img, padding_ltrb, fill=fill) - _, image_height, image_width = F.get_dimensions(img) - if crop_width == image_width and crop_height == image_height: - return img - - crop_top = int(round((image_height - crop_height) / 2.0)) - crop_left = int(round((image_width - crop_width) / 2.0)) - return F.crop(img, crop_top, crop_left, crop_height, crop_width) - - -class CenterCropOrPad(torch.nn.Module): - """Crops the given image at the center. - If the image is torch Tensor, it is expected - to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions. - If image size is smaller than output size along any edge, image is padded with 0 and then center cropped. - - Args: - size (sequence or int): Desired output size of the crop. If size is an - int instead of sequence like (h, w), a square crop (size, size) is - made. If provided a sequence of length 1, it will be interpreted as (size[0], size[0]). - """ - - def __init__(self, size, fill=0): - super().__init__() - self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") - self.fill = fill - - def forward(self, img): - """ - Args: - img (PIL Image or Tensor): Image to be cropped. - - Returns: - PIL Image or Tensor: Cropped image. - """ - return center_crop_or_pad(img, self.size, fill=self.fill) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(size={self.size})" - - -def _convert_to_rgb(image): - return image.convert('RGB') - - -class color_jitter(object): - """ - Apply Color Jitter to the PIL image with a specified probability. - """ - def __init__(self, brightness=0., contrast=0., saturation=0., hue=0., p=0.8): - assert 0. <= p <= 1. - self.p = p - self.transf = ColorJitter(brightness=brightness, contrast=contrast, saturation=saturation, hue=hue) - - def __call__(self, img): - if random.random() < self.p: - return self.transf(img) - else: - return img - - -class gray_scale(object): - """ - Apply Gray Scale to the PIL image with a specified probability. - """ - def __init__(self, p=0.2): - assert 0. <= p <= 1. - self.p = p - self.transf = Grayscale(num_output_channels=3) - - def __call__(self, img): - if random.random() < self.p: - return self.transf(img) - else: - return img - - -def image_transform( - image_size: Union[int, Tuple[int, int]], - is_train: bool, - mean: Optional[Tuple[float, ...]] = None, - std: Optional[Tuple[float, ...]] = None, - resize_mode: Optional[str] = None, - interpolation: Optional[str] = None, - fill_color: int = 0, - aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, -): - mean = mean or OPENAI_DATASET_MEAN - if not isinstance(mean, (list, tuple)): - mean = (mean,) * 3 - - std = std or OPENAI_DATASET_STD - if not isinstance(std, (list, tuple)): - std = (std,) * 3 - - interpolation = interpolation or 'bicubic' - assert interpolation in ['bicubic', 'bilinear', 'random'] - # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for inference if set - interpolation_mode = InterpolationMode.BILINEAR if interpolation == 'bilinear' else InterpolationMode.BICUBIC - - resize_mode = resize_mode or 'shortest' - assert resize_mode in ('shortest', 'longest', 'squash') - - if isinstance(aug_cfg, dict): - aug_cfg = AugmentationCfg(**aug_cfg) - else: - aug_cfg = aug_cfg or AugmentationCfg() - - normalize = Normalize(mean=mean, std=std) - - if is_train: - aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None} - use_timm = aug_cfg_dict.pop('use_timm', False) - if use_timm: - from timm.data import create_transform # timm can still be optional - if isinstance(image_size, (tuple, list)): - assert len(image_size) >= 2 - input_size = (3,) + image_size[-2:] - else: - input_size = (3, image_size, image_size) - - aug_cfg_dict.setdefault('color_jitter', None) # disable by default - # drop extra non-timm items - aug_cfg_dict.pop('color_jitter_prob', None) - aug_cfg_dict.pop('gray_scale_prob', None) - - train_transform = create_transform( - input_size=input_size, - is_training=True, - hflip=0., - mean=mean, - std=std, - re_mode='pixel', - interpolation=interpolation, - **aug_cfg_dict, - ) - else: - train_transform = [ - RandomResizedCrop( - image_size, - scale=aug_cfg_dict.pop('scale'), - interpolation=InterpolationMode.BICUBIC, - ), - _convert_to_rgb, - ] - if aug_cfg.color_jitter_prob: - assert aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4 - train_transform.extend([ - color_jitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob) - ]) - if aug_cfg.gray_scale_prob: - train_transform.extend([ - gray_scale(aug_cfg.gray_scale_prob) - ]) - train_transform.extend([ - ToTensor(), - normalize, - ]) - train_transform = Compose(train_transform) - if aug_cfg_dict: - warnings.warn(f'Unused augmentation cfg items, specify `use_timm` to use ({list(aug_cfg_dict.keys())}).') - return train_transform - else: - if resize_mode == 'longest': - transforms = [ - ResizeKeepRatio(image_size, interpolation=interpolation_mode, longest=1), - CenterCropOrPad(image_size, fill=fill_color) - ] - elif resize_mode == 'squash': - if isinstance(image_size, int): - image_size = (image_size, image_size) - transforms = [ - Resize(image_size, interpolation=interpolation_mode), - ] - else: - assert resize_mode == 'shortest' - if not isinstance(image_size, (tuple, list)): - image_size = (image_size, image_size) - if image_size[0] == image_size[1]: - # simple case, use torchvision built-in Resize w/ shortest edge mode (scalar size arg) - transforms = [ - Resize(image_size[0], interpolation=interpolation_mode) - ] - else: - # resize shortest edge to matching target dim for non-square target - transforms = [ResizeKeepRatio(image_size)] - transforms += [CenterCrop(image_size)] - - transforms.extend([ - _convert_to_rgb, - ToTensor(), - normalize, - ]) - return Compose(transforms) - - -def image_transform_v2( - cfg: PreprocessCfg, - is_train: bool, - aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None, -): - return image_transform( - image_size=cfg.size, - is_train=is_train, - mean=cfg.mean, - std=cfg.std, - interpolation=cfg.interpolation, - resize_mode=cfg.resize_mode, - fill_color=cfg.fill_color, - aug_cfg=aug_cfg, - ) diff --git a/src/proxyclip/open_clip_proxy/transformer.py b/src/proxyclip/open_clip_proxy/transformer.py deleted file mode 100644 index 7eb65c5a8e0502a3cb3d1c3027197f77541de37c..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/transformer.py +++ /dev/null @@ -1,842 +0,0 @@ -from collections import OrderedDict -import math -from typing import Callable, Optional, Sequence, Tuple -from functools import partial - -import torch -from torch import nn -from torch.nn import functional as F -from torch.utils.checkpoint import checkpoint -import numpy as np - -from .utils import to_2tuple -from .pos_embed import get_2d_sincos_pos_embed - - -class LayerNormFp32(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back).""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - x = F.layer_norm(x.to(torch.float32), self.normalized_shape, self.weight, self.bias, self.eps) - return x.to(orig_type) - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm (with cast back to input dtype).""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - return x.to(orig_type) - - -class QuickGELU(nn.Module): - # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class LayerScale(nn.Module): - def __init__(self, dim, init_values=1e-5, inplace=False): - super().__init__() - self.inplace = inplace - self.gamma = nn.Parameter(init_values * torch.ones(dim)) - - def forward(self, x): - return x.mul_(self.gamma) if self.inplace else x * self.gamma - - -class PatchDropout(nn.Module): - """ - https://arxiv.org/abs/2212.00794 - """ - - def __init__(self, prob, exclude_first_token=True): - super().__init__() - assert 0 <= prob < 1. - self.prob = prob - self.exclude_first_token = exclude_first_token # exclude CLS token - - def forward(self, x): - if not self.training or self.prob == 0.: - return x - - if self.exclude_first_token: - cls_tokens, x = x[:, :1], x[:, 1:] - else: - cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1]) - - batch = x.size()[0] - num_tokens = x.size()[1] - - batch_indices = torch.arange(batch) - batch_indices = batch_indices[..., None] - - keep_prob = 1 - self.prob - num_patches_keep = max(1, int(num_tokens * keep_prob)) - - rand = torch.randn(batch, num_tokens) - patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices - - x = x[batch_indices, patch_indices_keep] - - if self.exclude_first_token: - x = torch.cat((cls_tokens, x), dim=1) - - return x - - -class Attention(nn.Module): - def __init__( - self, - dim, - num_heads=8, - qkv_bias=True, - scaled_cosine=False, - scale_heads=False, - logit_scale_max=math.log(1. / 0.01), - attn_drop=0., - proj_drop=0. - ): - super().__init__() - self.scaled_cosine = scaled_cosine - self.scale_heads = scale_heads - assert dim % num_heads == 0, 'dim should be divisible by num_heads' - self.num_heads = num_heads - self.head_dim = dim // num_heads - self.scale = self.head_dim ** -0.5 - self.logit_scale_max = logit_scale_max - - # keeping in_proj in this form (instead of nn.Linear) to match weight scheme of original - self.in_proj_weight = nn.Parameter(torch.randn((dim * 3, dim)) * self.scale) - if qkv_bias: - self.in_proj_bias = nn.Parameter(torch.zeros(dim * 3)) - else: - self.in_proj_bias = None - - if self.scaled_cosine: - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) - else: - self.logit_scale = None - self.attn_drop = nn.Dropout(attn_drop) - if self.scale_heads: - self.head_scale = nn.Parameter(torch.ones((num_heads, 1, 1))) - else: - self.head_scale = None - self.out_proj = nn.Linear(dim, dim) - self.out_drop = nn.Dropout(proj_drop) - - def forward(self, x, attn_mask: Optional[torch.Tensor] = None): - L, N, C = x.shape - q, k, v = F.linear(x, self.in_proj_weight, self.in_proj_bias).chunk(3, dim=-1) - q = q.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - k = k.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - v = v.contiguous().view(L, N * self.num_heads, -1).transpose(0, 1) - - if self.logit_scale is not None: - attn = torch.bmm(F.normalize(q, dim=-1), F.normalize(k, dim=-1).transpose(-1, -2)) - logit_scale = torch.clamp(self.logit_scale, max=self.logit_scale_max).exp() - attn = attn.view(N, self.num_heads, L, L) * logit_scale - attn = attn.view(-1, L, L) - else: - q = q * self.scale - attn = torch.bmm(q, k.transpose(-1, -2)) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - new_attn_mask = torch.zeros_like(attn_mask, dtype=q.dtype) - new_attn_mask.masked_fill_(attn_mask, float("-inf")) - attn_mask = new_attn_mask - attn += attn_mask - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = torch.bmm(attn, v) - if self.head_scale is not None: - x = x.view(N, self.num_heads, L, C) * self.head_scale - x = x.view(-1, L, C) - x = x.transpose(0, 1).reshape(L, N, C) - x = self.out_proj(x) - x = self.out_drop(x) - return x - - -class AttentionalPooler(nn.Module): - def __init__( - self, - d_model: int, - context_dim: int, - n_head: int = 8, - n_queries: int = 256, - norm_layer: Callable = LayerNorm - ): - super().__init__() - self.query = nn.Parameter(torch.randn(n_queries, d_model)) - self.attn = nn.MultiheadAttention(d_model, n_head, kdim=context_dim, vdim=context_dim) - self.ln_q = norm_layer(d_model) - self.ln_k = norm_layer(context_dim) - - def forward(self, x: torch.Tensor): - x = self.ln_k(x).permute(1, 0, 2) # NLD -> LND - N = x.shape[1] - q = self.ln_q(self.query) - out = self.attn(q.unsqueeze(1).expand(-1, N, -1), x, x, need_weights=False)[0] - return out.permute(1, 0, 2) # LND -> NLD - - -class ResidualAttentionBlock(nn.Module): - def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - is_cross_attention: bool = False, - ): - super().__init__() - - self.ln_1 = norm_layer(d_model) - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() - if is_cross_attention: - self.ln_1_kv = norm_layer(d_model) - - self.ln_2 = norm_layer(d_model) - mlp_width = int(d_model * mlp_ratio) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) - self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() - - def attention( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, - ): - k_x = k_x if k_x is not None else q_x - v_x = v_x if v_x is not None else q_x - - attn_mask = attn_mask.to(q_x.dtype) if attn_mask is not None else None - return self.attn( - q_x, k_x, v_x, need_weights=False, attn_mask=attn_mask - )[0] - - def forward( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, - ): - k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None - v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None - - x = q_x + self.ls_1(self.attention(q_x=self.ln_1(q_x), k_x=k_x, v_x=v_x, attn_mask=attn_mask)) - x = x + self.ls_2(self.mlp(self.ln_2(x))) - return x - - -class CustomResidualAttentionBlock(nn.Module): - def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - scale_cosine_attn: bool = False, - scale_heads: bool = False, - scale_attn: bool = False, - scale_fc: bool = False, - ): - super().__init__() - - self.ln_1 = norm_layer(d_model) - self.attn = Attention( - d_model, n_head, - scaled_cosine=scale_cosine_attn, - scale_heads=scale_heads, - ) - self.ln_attn = norm_layer(d_model) if scale_attn else nn.Identity() - self.ls_1 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() - - self.ln_2 = norm_layer(d_model) - mlp_width = int(d_model * mlp_ratio) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ("gelu", act_layer()), - ('ln', norm_layer(mlp_width) if scale_fc else nn.Identity()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) - self.ls_2 = LayerScale(d_model, ls_init_value) if ls_init_value is not None else nn.Identity() - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - x = x + self.ls_1(self.ln_attn(self.attn(self.ln_1(x), attn_mask=attn_mask))) - x = x + self.ls_2(self.mlp(self.ln_2(x))) - return x - - -def _expand_token(token, batch_size: int): - return token.view(1, 1, -1).expand(batch_size, -1, -1) - - -class Transformer(nn.Module): - def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - ): - super().__init__() - self.width = width - self.layers = layers - self.grad_checkpointing = False - - self.resblocks = nn.ModuleList([ - ResidualAttentionBlock( - width, heads, mlp_ratio, ls_init_value=ls_init_value, act_layer=act_layer, norm_layer=norm_layer) - for _ in range(layers) - ]) - - def get_cast_dtype(self) -> torch.dtype: - if hasattr(self.resblocks[0].mlp.c_fc, 'int8_original_dtype'): - return self.resblocks[0].mlp.c_fc.int8_original_dtype - return self.resblocks[0].mlp.c_fc.weight.dtype - - def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): - for r in self.resblocks: - if self.grad_checkpointing and not torch.jit.is_scripting(): - # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 - x = checkpoint(r, x, None, None, attn_mask) - else: - x = r(x, attn_mask=attn_mask) - return x - - -class VisionTransformer(nn.Module): - output_tokens: torch.jit.Final[bool] - - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - ls_init_value: float = None, - attentional_pool: bool = False, - attn_pooler_queries: int = 256, - attn_pooler_heads: int = 8, - output_dim: int = 512, - patch_dropout: float = 0., - no_ln_pre: bool = False, - pos_embed_type: str = 'learnable', - pool_type: str = 'tok', - final_ln_after_pool: bool = False, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - output_tokens: bool = False, - ): - super().__init__() - assert pool_type in ('tok', 'avg', 'none') - self.output_tokens = output_tokens - image_height, image_width = self.image_size = to_2tuple(image_size) - patch_height, patch_width = self.patch_size = to_2tuple(patch_size) - self.grid_size = (image_height // patch_height, image_width // patch_width) - self.final_ln_after_pool = final_ln_after_pool # currently ignored w/ attn pool enabled - self.output_dim = output_dim - - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - - # class embeddings and positional embeddings - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - if pos_embed_type == 'learnable': - self.positional_embedding = nn.Parameter( - scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width)) - elif pos_embed_type == 'sin_cos_2d': - # fixed sin-cos embedding - assert self.grid_size[0] == self.grid_size[1],\ - 'currently sin cos 2d pos embedding only supports square input' - self.positional_embedding = nn.Parameter( - torch.zeros(self.grid_size[0] * self.grid_size[1] + 1, width), requires_grad=False) - pos_embed_type = get_2d_sincos_pos_embed(width, self.grid_size[0], cls_token=True) - self.positional_embedding.data.copy_(torch.from_numpy(pos_embed_type).float()) - else: - raise ValueError - - # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn - self.patch_dropout = PatchDropout(patch_dropout) if patch_dropout > 0. else nn.Identity() - - self.ln_pre = nn.Identity() if no_ln_pre else norm_layer(width) - self.transformer = Transformer( - width, - layers, - heads, - mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - ) - - if attentional_pool: - if isinstance(attentional_pool, str): - self.attn_pool_type = attentional_pool - self.pool_type = 'none' - if attentional_pool in ('parallel', 'cascade'): - self.attn_pool = AttentionalPooler( - output_dim, - width, - n_head=attn_pooler_heads, - n_queries=attn_pooler_queries, - ) - self.attn_pool_contrastive = AttentionalPooler( - output_dim, - width, - n_head=attn_pooler_heads, - n_queries=1, - ) - else: - assert False - else: - self.attn_pool_type = '' - self.pool_type = pool_type - self.attn_pool = AttentionalPooler( - output_dim, - width, - n_head=attn_pooler_heads, - n_queries=attn_pooler_queries, - ) - self.attn_pool_contrastive = None - pool_dim = output_dim - else: - self.attn_pool = None - pool_dim = width - self.pool_type = pool_type - - self.ln_post = norm_layer(pool_dim) - self.proj = nn.Parameter(scale * torch.randn(pool_dim, output_dim)) - - self.init_parameters() - - def lock(self, unlocked_groups=0, freeze_bn_stats=False): - for param in self.parameters(): - param.requires_grad = False - - if unlocked_groups != 0: - groups = [ - [ - self.conv1, - self.class_embedding, - self.positional_embedding, - self.ln_pre, - ], - *self.transformer.resblocks[:-1], - [ - self.transformer.resblocks[-1], - self.ln_post, - ], - self.proj, - ] - - def _unlock(x): - if isinstance(x, Sequence): - for g in x: - _unlock(g) - else: - if isinstance(x, torch.nn.Parameter): - x.requires_grad = True - else: - for p in x.parameters(): - p.requires_grad = True - - _unlock(groups[-unlocked_groups:]) - - def init_parameters(self): - # FIXME OpenAI CLIP did not define an init for the VisualTransformer - # TODO experiment if default PyTorch init, below, or alternate init is best. - - # nn.init.normal_(self.class_embedding, std=self.scale) - # nn.init.normal_(self.positional_embedding, std=self.scale) - # - # proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - # attn_std = self.transformer.width ** -0.5 - # fc_std = (2 * self.transformer.width) ** -0.5 - # for block in self.transformer.resblocks: - # nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - # nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - # nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - # nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - # - # if self.text_projection is not None: - # nn.init.normal_(self.text_projection, std=self.scale) - pass - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.grad_checkpointing = enable - - def _global_pool(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - if self.pool_type == 'avg': - pooled, tokens = x[:, 1:].mean(dim=1), x[:, 1:] - elif self.pool_type == 'tok': - pooled, tokens = x[:, 0], x[:, 1:] - else: - pooled = tokens = x - - return pooled, tokens - - def forward(self, x: torch.Tensor, ex_feats: Optional[torch.Tensor] = None, beta=1.2, gamma=3.0): - B, nc, w, h = x.shape - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - - # class embeddings and positional embeddings - x = torch.cat([_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1) - # shape = [*, grid ** 2 + 1, width] - - if x.shape[1] != self.positional_embedding.shape[0]: - x = x + self.interpolate_pos_encoding(x, w, h).to(x.dtype) - else: - x = x + self.positional_embedding.to(x.dtype) - - x = self.patch_dropout(x) - x = self.ln_pre(x) - - token_size = h // self.patch_size[0], w // self.patch_size[1] - - x = x.permute(1, 0, 2) # NLD -> LND - for blk in self.transformer.resblocks[:-1]: - x = blk(x) - for blk in self.transformer.resblocks[-1:]: - if ex_feats is not None: - x = self.custom_attn(blk.attn, blk.ln_1(x), ex_feats=ex_feats, beta=beta, gamma=gamma, token_size=token_size) - else: - x = blk(x) - x = x[1:] - - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x) - x = x @ self.proj - - return x - - def interpolate_pos_encoding(self, x, w, h): - npatch = x.shape[1] - 1 - N = self.positional_embedding.shape[0] - 1 - if npatch == N and w == h: - return self.positional_embedding - class_pos_embed = self.positional_embedding[[0]] - patch_pos_embed = self.positional_embedding[1:] - dim = x.shape[-1] - w0 = w // self.patch_size[0] - h0 = h // self.patch_size[1] - w0, h0 = w0 + 0.1, h0 + 0.1 - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), - scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), - mode='bicubic', - ) - assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) - - def custom_attn(self, attn_layer, x, ex_feats=None, beta=1.2, gamma=3.0, token_size=(16, 16)): - - num_heads = attn_layer.num_heads - _, bsz, embed_dim = x.size() - head_dim = embed_dim // num_heads - - q, k, v = F.linear(x, attn_layer.in_proj_weight, attn_layer.in_proj_bias).chunk(3, dim=-1) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - B, C, H, W = ex_feats.shape - q_k = F.normalize(ex_feats.flatten(2, 3), dim=1) - similarity = torch.einsum("b c m, b c n -> b m n", q_k, q_k) - - similarity = (similarity - torch.mean(similarity) * beta) * gamma - similarity[similarity < 0.0] = float('-inf') - - mask = similarity.to(q.dtype).unsqueeze(1).repeat(1, num_heads, 1, 1) - mask = mask.reshape(bsz * num_heads, mask.shape[2], mask.shape[3]) - attn_weights = F.softmax(mask, dim=-1) - - v = v[:, 1:, :].reshape(bsz*num_heads, token_size[0], token_size[1], head_dim).permute(0, 3, 1, 2) - v = F.interpolate(v, size=(H, W), mode='bilinear', align_corners=False) - v = v.permute(0, 2, 3, 1).reshape(bsz*num_heads, H*W, head_dim) - - attn_output = torch.bmm(attn_weights, v) - attn_output = attn_output.transpose(0, 1).contiguous().view(-1, bsz, embed_dim) - attn_output = attn_layer.out_proj(attn_output) - return attn_output - -def text_global_pool(x, text: Optional[torch.Tensor] = None, pool_type: str = 'argmax'): - if pool_type == 'first': - pooled, tokens = x[:, 0], x[:, 1:] - elif pool_type == 'last': - pooled, tokens = x[:, -1], x[:, :-1] - elif pool_type == 'argmax': - # take features from the eot embedding (eot_token is the highest number in each sequence) - assert text is not None - pooled, tokens = x[torch.arange(x.shape[0]), text.argmax(dim=-1)], x - else: - pooled = tokens = x - - return pooled, tokens - - -class TextTransformer(nn.Module): - output_tokens: torch.jit.Final[bool] - - def __init__( - self, - context_length: int = 77, - vocab_size: int = 49408, - width: int = 512, - heads: int = 8, - layers: int = 12, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - output_dim: int = 512, - embed_cls: bool = False, - no_causal_mask: bool = False, - pad_id: int = 0, - pool_type: str = 'argmax', - proj_bias: bool = False, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - output_tokens: bool = False, - ): - super().__init__() - assert pool_type in ('first', 'last', 'argmax', 'none') - self.output_tokens = output_tokens - self.num_pos = self.context_length = context_length - self.vocab_size = vocab_size - self.width = width - self.output_dim = output_dim - self.heads = heads - self.pad_id = pad_id - self.pool_type = pool_type - - self.token_embedding = nn.Embedding(vocab_size, width) - if embed_cls: - self.cls_emb = nn.Parameter(torch.empty(width)) - self.num_pos += 1 - else: - self.cls_emb = None - self.positional_embedding = nn.Parameter(torch.empty(self.num_pos, width)) - self.transformer = Transformer( - width=width, - layers=layers, - heads=heads, - mlp_ratio=mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - ) - self.ln_final = norm_layer(width) - - if no_causal_mask: - self.attn_mask = None - else: - self.register_buffer('attn_mask', self.build_causal_mask(), persistent=False) - - if proj_bias: - self.text_projection = nn.Linear(width, output_dim) - else: - self.text_projection = nn.Parameter(torch.empty(width, output_dim)) - - self.init_parameters() - - def init_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - if self.cls_emb is not None: - nn.init.normal_(self.cls_emb, std=0.01) - - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - if isinstance(self.text_projection, nn.Linear): - nn.init.normal_(self.text_projection.weight, std=self.transformer.width ** -0.5) - if self.text_projection.bias is not None: - nn.init.zeros_(self.text_projection.bias) - else: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.transformer.grad_checkpointing = enable - - def build_causal_mask(self): - # lazily create causal attention mask, with full attention between the tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.num_pos, self.num_pos) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - def build_cls_mask(self, text, cast_dtype: torch.dtype): - cls_mask = (text != self.pad_id).unsqueeze(1) - cls_mask = F.pad(cls_mask, (1, 0, cls_mask.shape[2], 0), value=True) - additive_mask = torch.empty(cls_mask.shape, dtype=cast_dtype, device=cls_mask.device) - additive_mask.fill_(0) - additive_mask.masked_fill_(~cls_mask, float("-inf")) - additive_mask = torch.repeat_interleave(additive_mask, self.heads, 0) - return additive_mask - - def forward(self, text): - cast_dtype = self.transformer.get_cast_dtype() - seq_len = text.shape[1] - - x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model] - attn_mask = self.attn_mask - if self.cls_emb is not None: - seq_len += 1 - x = torch.cat([x, _expand_token(self.cls_emb, x.shape[0])], dim=1) - cls_mask = self.build_cls_mask(text, cast_dtype) - if attn_mask is not None: - attn_mask = attn_mask[None, :seq_len, :seq_len] + cls_mask[:, :seq_len, :seq_len] - - x = x + self.positional_embedding[:seq_len].to(cast_dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x, attn_mask=attn_mask) - x = x.permute(1, 0, 2) # LND -> NLD - - # x.shape = [batch_size, n_ctx, transformer.width] - if self.cls_emb is not None: - # presence of appended cls embed (CoCa) overrides pool_type, always take last token - pooled, tokens = text_global_pool(x, pool_type='last') - pooled = self.ln_final(pooled) # final LN applied after pooling in this case - else: - x = self.ln_final(x) - pooled, tokens = text_global_pool(x, text, pool_type=self.pool_type) - - if self.text_projection is not None: - if isinstance(self.text_projection, nn.Linear): - pooled = self.text_projection(pooled) - else: - pooled = pooled @ self.text_projection - - if self.output_tokens: - return pooled, tokens - - return pooled - - -class MultimodalTransformer(Transformer): - def __init__( - self, - width: int, - layers: int, - heads: int, - context_length: int = 77, - mlp_ratio: float = 4.0, - ls_init_value: float = None, - act_layer: Callable = nn.GELU, - norm_layer: Callable = LayerNorm, - output_dim: int = 512, - ): - - super().__init__( - width=width, - layers=layers, - heads=heads, - mlp_ratio=mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - ) - self.context_length = context_length - self.cross_attn = nn.ModuleList([ - ResidualAttentionBlock( - width, - heads, - mlp_ratio, - ls_init_value=ls_init_value, - act_layer=act_layer, - norm_layer=norm_layer, - is_cross_attention=True, - ) - for _ in range(layers) - ]) - - self.register_buffer('attn_mask', self.build_attention_mask(), persistent=False) - - self.ln_final = norm_layer(width) - self.text_projection = nn.Parameter(torch.empty(width, output_dim)) - - def init_parameters(self): - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - for block in self.transformer.cross_attn: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - def forward(self, image_embs, text_embs): - text_embs = text_embs.permute(1, 0, 2) # NLD -> LNDsq - image_embs = image_embs.permute(1, 0, 2) # NLD -> LND - seq_len = text_embs.shape[0] - - for resblock, cross_attn in zip(self.resblocks, self.cross_attn): - if self.grad_checkpointing and not torch.jit.is_scripting(): - # TODO: handle kwargs https://github.com/pytorch/pytorch/issues/79887#issuecomment-1161758372 - text_embs = checkpoint(resblock, text_embs, None, None, self.attn_mask[:seq_len, :seq_len]) - text_embs = checkpoint(cross_attn, text_embs, image_embs, image_embs, None) - else: - text_embs = resblock(text_embs, attn_mask=self.attn_mask[:seq_len, :seq_len]) - text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs) - - x = text_embs.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x) - - if self.text_projection is not None: - x = x @ self.text_projection - - return x - - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - self.grad_checkpointing = enable diff --git a/src/proxyclip/open_clip_proxy/utils.py b/src/proxyclip/open_clip_proxy/utils.py deleted file mode 100644 index bb0bb8868ae1f2d31493ca32b73accd6bf1d3cdb..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/utils.py +++ /dev/null @@ -1,89 +0,0 @@ -from itertools import repeat -import collections.abc - -import torch -from torch import nn as nn -from torchvision.ops.misc import FrozenBatchNorm2d - - -def freeze_batch_norm_2d(module, module_match={}, name=''): - """ - Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is - itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and - returned. Otherwise, the module is walked recursively and submodules are converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. - module_match (dict): Dictionary of full module names to freeze (all if empty) - name (str): Full module name (prefix) - - Returns: - torch.nn.Module: Resulting module - - Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 - """ - res = module - is_match = True - if module_match: - is_match = name in module_match - if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): - res = FrozenBatchNorm2d(module.num_features) - res.num_features = module.num_features - res.affine = module.affine - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for child_name, child in module.named_children(): - full_child_name = '.'.join([name, child_name]) if name else child_name - new_child = freeze_batch_norm_2d(child, module_match, full_child_name) - if new_child is not child: - res.add_module(child_name, new_child) - return res - - -# From PyTorch internals -def _ntuple(n): - def parse(x): - if isinstance(x, collections.abc.Iterable): - return x - return tuple(repeat(x, n)) - return parse - - -to_1tuple = _ntuple(1) -to_2tuple = _ntuple(2) -to_3tuple = _ntuple(3) -to_4tuple = _ntuple(4) -to_ntuple = lambda n, x: _ntuple(n)(x) - -# Replaces all linear layers with linear_replacement -# TODO: add int8 support for other linear layers including attn and convnets -def replace_linear(model, linear_replacement, include_modules=['c_fc', 'c_proj'], copy_weights=True): - for name, module in model.named_children(): - if len(list(module.children())) > 0: - replace_linear(module, linear_replacement, include_modules, copy_weights) - - if isinstance(module, torch.nn.Linear) and name in include_modules: - old_module = model._modules[name] - model._modules[name] = linear_replacement( - module.in_features, - module.out_features, - module.bias is not None, - ) - if copy_weights: - model._modules[name].weight.data.copy_(old_module.weight.data) - if model._modules[name].bias is not None: - model._modules[name].bias.data.copy_(old_module.bias) - - return model - -def convert_int8_model_to_inference_mode(model): - for m in model.modules(): - if hasattr(m, 'prepare_for_eval'): - int8_original_dtype = m.weight.dtype - m.prepare_for_eval() - m.int8_original_dtype = int8_original_dtype \ No newline at end of file diff --git a/src/proxyclip/open_clip_proxy/version.py b/src/proxyclip/open_clip_proxy/version.py deleted file mode 100644 index 78afda8502b16f06c6a1b8a9f97f48ee0db9f6ce..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '2.24.0' diff --git a/src/proxyclip/open_clip_proxy/zero_shot_classifier.py b/src/proxyclip/open_clip_proxy/zero_shot_classifier.py deleted file mode 100644 index 535ec9696d27a1dcbe2c43da18f5fd20b599cb9b..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/zero_shot_classifier.py +++ /dev/null @@ -1,110 +0,0 @@ -from functools import partial -from itertools import islice -from typing import Callable, List, Optional, Sequence, Union - -import torch -import torch.nn.functional as F - - -def batched(iterable, n): - """Batch data into lists of length *n*. The last batch may be shorter. - NOTE based on more-itertools impl, to be replaced by python 3.12 itertools.batched impl - """ - it = iter(iterable) - while True: - batch = list(islice(it, n)) - if not batch: - break - yield batch - - -def build_zero_shot_classifier( - model, - tokenizer, - classnames: Sequence[str], - templates: Sequence[Union[Callable, str]], - num_classes_per_batch: Optional[int] = 10, - device: Union[str, torch.device] = 'cpu', - use_tqdm: bool = False, -): - """ Build zero-shot classifier weights by iterating over class names in batches - Args: - model: CLIP model instance - tokenizer: CLIP tokenizer instance - classnames: A sequence of class (label) names - templates: A sequence of callables or format() friendly strings to produce templates per class name - num_classes_per_batch: The number of classes to batch together in each forward, all if None - device: Device to use. - use_tqdm: Enable TQDM progress bar. - """ - assert isinstance(templates, Sequence) and len(templates) > 0 - assert isinstance(classnames, Sequence) and len(classnames) > 0 - use_format = isinstance(templates[0], str) - num_templates = len(templates) - num_classes = len(classnames) - if use_tqdm: - import tqdm - num_iter = 1 if num_classes_per_batch is None else ((num_classes - 1) // num_classes_per_batch + 1) - iter_wrap = partial(tqdm.tqdm, total=num_iter, unit_scale=num_classes_per_batch) - else: - iter_wrap = iter - - def _process_batch(batch_classnames): - num_batch_classes = len(batch_classnames) - texts = [template.format(c) if use_format else template(c) for c in batch_classnames for template in templates] - texts = tokenizer(texts).to(device) - class_embeddings = model.encode_text(texts, normalize=True) - class_embeddings = class_embeddings.reshape(num_batch_classes, num_templates, -1).mean(dim=1) - class_embeddings = class_embeddings / class_embeddings.norm(dim=1, keepdim=True) - class_embeddings = class_embeddings.T - return class_embeddings - - with torch.no_grad(): - if num_classes_per_batch: - batched_embeds = [_process_batch(batch) for batch in iter_wrap(batched(classnames, num_classes_per_batch))] - zeroshot_weights = torch.cat(batched_embeds, dim=1) - else: - zeroshot_weights = _process_batch(classnames) - return zeroshot_weights - - -def build_zero_shot_classifier_legacy( - model, - tokenizer, - classnames: Sequence[str], - templates: Sequence[Union[Callable, str]], - device: Union[str, torch.device] = 'cpu', - use_tqdm: bool = False, -): - """ Build zero-shot classifier weights by iterating over class names 1 by 1 - Args: - model: CLIP model instance - tokenizer: CLIP tokenizer instance - classnames: A sequence of class (label) names - templates: A sequence of callables or format() friendly strings to produce templates per class name - device: Device to use. - use_tqdm: Enable TQDM progress bar. - """ - assert isinstance(templates, Sequence) and len(templates) > 0 - assert isinstance(classnames, Sequence) and len(classnames) > 0 - if use_tqdm: - import tqdm - iter_wrap = tqdm.tqdm - else: - iter_wrap = iter - - use_format = isinstance(templates[0], str) - - with torch.no_grad(): - zeroshot_weights = [] - for classname in iter_wrap(classnames): - texts = [template.format(classname) if use_format else template(classname) for template in templates] - texts = tokenizer(texts).to(device) # tokenize - class_embeddings = model.encode_text(texts) - class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) - class_embedding /= class_embedding.norm() - zeroshot_weights.append(class_embedding) - zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device) - - return zeroshot_weights - diff --git a/src/proxyclip/open_clip_proxy/zero_shot_metadata.py b/src/proxyclip/open_clip_proxy/zero_shot_metadata.py deleted file mode 100644 index ccb452bbb6e27b71cff1dd27e2bb263259b9363f..0000000000000000000000000000000000000000 --- a/src/proxyclip/open_clip_proxy/zero_shot_metadata.py +++ /dev/null @@ -1,266 +0,0 @@ - -OPENAI_IMAGENET_TEMPLATES = ( - lambda c: f'a bad photo of a {c}.', - lambda c: f'a photo of many {c}.', - lambda c: f'a sculpture of a {c}.', - lambda c: f'a photo of the hard to see {c}.', - lambda c: f'a low resolution photo of the {c}.', - lambda c: f'a rendering of a {c}.', - lambda c: f'graffiti of a {c}.', - lambda c: f'a bad photo of the {c}.', - lambda c: f'a cropped photo of the {c}.', - lambda c: f'a tattoo of a {c}.', - lambda c: f'the embroidered {c}.', - lambda c: f'a photo of a hard to see {c}.', - lambda c: f'a bright photo of a {c}.', - lambda c: f'a photo of a clean {c}.', - lambda c: f'a photo of a dirty {c}.', - lambda c: f'a dark photo of the {c}.', - lambda c: f'a drawing of a {c}.', - lambda c: f'a photo of my {c}.', - lambda c: f'the plastic {c}.', - lambda c: f'a photo of the cool {c}.', - lambda c: f'a close-up photo of a {c}.', - lambda c: f'a black and white photo of the {c}.', - lambda c: f'a painting of the {c}.', - lambda c: f'a painting of a {c}.', - lambda c: f'a pixelated photo of the {c}.', - lambda c: f'a sculpture of the {c}.', - lambda c: f'a bright photo of the {c}.', - lambda c: f'a cropped photo of a {c}.', - lambda c: f'a plastic {c}.', - lambda c: f'a photo of the dirty {c}.', - lambda c: f'a jpeg corrupted photo of a {c}.', - lambda c: f'a blurry photo of the {c}.', - lambda c: f'a photo of the {c}.', - lambda c: f'a good photo of the {c}.', - lambda c: f'a rendering of the {c}.', - lambda c: f'a {c} in a video game.', - lambda c: f'a photo of one {c}.', - lambda c: f'a doodle of a {c}.', - lambda c: f'a close-up photo of the {c}.', - lambda c: f'a photo of a {c}.', - lambda c: f'the origami {c}.', - lambda c: f'the {c} in a video game.', - lambda c: f'a sketch of a {c}.', - lambda c: f'a doodle of the {c}.', - lambda c: f'a origami {c}.', - lambda c: f'a low resolution photo of a {c}.', - lambda c: f'the toy {c}.', - lambda c: f'a rendition of the {c}.', - lambda c: f'a photo of the clean {c}.', - lambda c: f'a photo of a large {c}.', - lambda c: f'a rendition of a {c}.', - lambda c: f'a photo of a nice {c}.', - lambda c: f'a photo of a weird {c}.', - lambda c: f'a blurry photo of a {c}.', - lambda c: f'a cartoon {c}.', - lambda c: f'art of a {c}.', - lambda c: f'a sketch of the {c}.', - lambda c: f'a embroidered {c}.', - lambda c: f'a pixelated photo of a {c}.', - lambda c: f'itap of the {c}.', - lambda c: f'a jpeg corrupted photo of the {c}.', - lambda c: f'a good photo of a {c}.', - lambda c: f'a plushie {c}.', - lambda c: f'a photo of the nice {c}.', - lambda c: f'a photo of the small {c}.', - lambda c: f'a photo of the weird {c}.', - lambda c: f'the cartoon {c}.', - lambda c: f'art of the {c}.', - lambda c: f'a drawing of the {c}.', - lambda c: f'a photo of the large {c}.', - lambda c: f'a black and white photo of a {c}.', - lambda c: f'the plushie {c}.', - lambda c: f'a dark photo of a {c}.', - lambda c: f'itap of a {c}.', - lambda c: f'graffiti of the {c}.', - lambda c: f'a toy {c}.', - lambda c: f'itap of my {c}.', - lambda c: f'a photo of a cool {c}.', - lambda c: f'a photo of a small {c}.', - lambda c: f'a tattoo of the {c}.', -) - - -# a much smaller subset of above prompts -# from https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb -SIMPLE_IMAGENET_TEMPLATES = ( - lambda c: f'itap of a {c}.', - lambda c: f'a bad photo of the {c}.', - lambda c: f'a origami {c}.', - lambda c: f'a photo of the large {c}.', - lambda c: f'a {c} in a video game.', - lambda c: f'art of the {c}.', - lambda c: f'a photo of the small {c}.', -) - - -IMAGENET_CLASSNAMES = ( - "tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", - "stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", - "indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", - "kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", - "smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", - "tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", - "box turtle", "banded gecko", "green iguana", "Carolina anole", - "desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", - "Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", - "American alligator", "triceratops", "worm snake", "ring-necked snake", - "eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", - "vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", - "green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", - "sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", - "barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", - "tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", - "quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", - "coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", - "red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", - "koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", - "snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", - "fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", - "isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", - "great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", - "bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", - "pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", - "Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", - "Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", - "Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", - "English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", - "Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", - "Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", - "Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", - "Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", - "Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", - "Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", - "Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", - "Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", - "Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", - "Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", - "English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", - "English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", - "Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", - "Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", - "Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", - "Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", - "Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", - "French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", - "Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", - "Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", - "Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", - "Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", - "red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", - "kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", - "Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", - "cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", - "meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", - "dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", - "cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", - "lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", - "monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", - "starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", - "hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", - "zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", - "ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", - "gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", - "black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", - "gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", - "langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", - "howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", - "ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", - "giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", - "sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", - "accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", - "amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", - "backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", - "baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", - "wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", - "bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", - "beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", - "ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", - "bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", - "breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", - "high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", - "can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", - "car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", - "CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", - "storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", - "church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", - "coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", - "candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", - "cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", - "Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", - "rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", - "dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", - "drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", - "electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", - "feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", - "folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", - "freight car", "French horn", "frying pan", "fur coat", "garbage truck", - "gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", - "gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", - "hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", - "handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", - "holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", - "horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", - "T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", - "ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", - "lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", - "music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", - "mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", - "matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", - "microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", - "mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", - "moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", - "mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", - "neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", - "odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", - "oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", - "pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", - "parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", - "pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", - "picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", - "pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", - "plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", - "pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", - "printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", - "quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", - "recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", - "remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", - "rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", - "sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", - "CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", - "shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", - "shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", - "slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", - "solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", - "space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", - "stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", - "stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", - "submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", - "mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", - "table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", - "thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", - "toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", - "tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", - "triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", - "umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", - "velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", - "waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", - "washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", - "hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", - "wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", - "comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", - "plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", - "bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", - "cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", - "artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", - "lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", - "hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", - "red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", - "geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", - "bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", - "rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", - "earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper" -) - diff --git a/src/proxyclip/proxyclip.py b/src/proxyclip/proxyclip.py deleted file mode 100644 index 4660c21fbd3dd5702c3a468eaa43448a472d0b5a..0000000000000000000000000000000000000000 --- a/src/proxyclip/proxyclip.py +++ /dev/null @@ -1,86 +0,0 @@ -from torch import nn -import torch -from .open_clip_proxy import create_model, tokenizer -from torchvision import transforms as T - -class ProxyCLIP(nn.Module): - def __init__(self, clip_type, model_type, vfm_model, device=torch.device('cuda'), beta=1.2, gamma=3.0, slide_crop=336): - - super().__init__() - - self.clip = create_model(model_type, pretrained=clip_type, precision='fp16') - self.clip.eval().to(device) - self.tokenizer = tokenizer.tokenize - - self.vfm_model = vfm_model - - if vfm_model == 'dino': - # self.vfm = torch.hub.load('facebookresearch/dino:main', 'dino_vits16') - # self.vfm = torch.hub.load('facebookresearch/dino:main', 'dino_vits8') - # self.vfm = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16') - self.vfm = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8') - - elif vfm_model == 'dinov2': - # self.vfm = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_reg') - self.vfm = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg') - - self.vfm = self.vfm.half() - for p in self.vfm.parameters(): - p.requires_grad = False - self.vfm.eval().to(device) - - self.norm = T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - - self.slide_crop = slide_crop - self.beta = beta - self.gamma = gamma - - @torch.no_grad() - def forward(self, img): - if type(img) == list: - img = img[0] - - clip_token_size = img.shape[-2] // self.clip.visual.patch_size[0], img.shape[-1] // self.clip.visual.patch_size[1] - - # imgs_norm = [self.norm(self.unnorm(img[i])) for i in range(len(img))] - # imgs_norm = torch.stack(imgs_norm, dim=0) - imgs_norm = img - - imgs_norm = imgs_norm.half() - if self.vfm_model == 'dino': - feat_out = {} - def hook_fn_forward_qkv(module, input, output): - feat_out["qkv"] = output - self.vfm._modules["blocks"][-1]._modules["attn"]._modules["qkv"].register_forward_hook( - hook_fn_forward_qkv) - - # Forward pass in the model - feat = self.vfm.get_intermediate_layers(imgs_norm)[0] - - nb_im = feat.shape[0] # Batch size - - patch_size = self.vfm.patch_embed.patch_size - I, J = imgs_norm[0].shape[-2] // patch_size, imgs_norm[0].shape[-2] // patch_size - ex_feats = feat[:, 1:, :].reshape(nb_im, I, J, -1).permute(0, 3, 1, 2) - - elif self.vfm_model == 'dinov2': - patch_size = self.vfm.patch_embed.patch_size - I, J = imgs_norm.shape[-2] // patch_size[0], imgs_norm.shape[-2] // patch_size[1] - ex_feats = self.vfm.get_intermediate_layers(imgs_norm, reshape=True)[0] - - else: - I, J = clip_token_size - ex_feats = None - - image_features = self.clip.encode_image(img.half(), - external_feats=ex_feats, - beta=self.beta, - gamma=self.gamma) - - image_features /= image_features.norm(dim=-1, keepdim=True) - - - - return { - 'x_norm_patchtokens': image_features.float() - } \ No newline at end of file diff --git a/src/regionclip/backbone.py b/src/regionclip/backbone.py deleted file mode 100644 index de32e2117c1a953f89a34ae9e388d9857d776927..0000000000000000000000000000000000000000 --- a/src/regionclip/backbone.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -from abc import ABCMeta, abstractmethod -import torch.nn as nn - -from detectron2.layers import ShapeSpec - -__all__ = ["Backbone"] - - -class Backbone(nn.Module, metaclass=ABCMeta): - """ - Abstract base class for network backbones. - """ - - def __init__(self): - """ - The `__init__` method of any subclass can specify its own set of arguments. - """ - super().__init__() - - @abstractmethod - def forward(self): - """ - Subclasses must override this method, but adhere to the same return type. - - Returns: - dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor - """ - pass - - @property - def size_divisibility(self) -> int: - """ - Some backbones require the input height and width to be divisible by a - specific integer. This is typically true for encoder / decoder type networks - with lateral connection (e.g., FPN) for which feature maps need to match - dimension in the "bottom up" and "top down" paths. Set to 0 if no specific - input size divisibility is required. - """ - return 0 - - def output_shape(self): - """ - Returns: - dict[str->ShapeSpec] - """ - # this is a backward-compatible default - return { - name: ShapeSpec( - channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] - ) - for name in self._out_features - } \ No newline at end of file diff --git a/src/regionclip/build.py b/src/regionclip/build.py deleted file mode 100644 index 181e5387354049330af816c96e632e4d37c012c4..0000000000000000000000000000000000000000 --- a/src/regionclip/build.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -from detectron2.layers import ShapeSpec -from detectron2.utils.registry import Registry - -from .backbone import Backbone - -BACKBONE_REGISTRY = Registry("BACKBONE") -BACKBONE_REGISTRY.__doc__ = """ -Registry for backbones, which extract feature maps from images - -The registered object must be a callable that accepts two arguments: - -1. A :class:`detectron2.config.CfgNode` -2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification. - -Registered object must return instance of :class:`Backbone`. -""" - - -def build_backbone(cfg, input_shape=None): - """ - Build a backbone from `cfg.MODEL.BACKBONE.NAME`. - - Returns: - an instance of :class:`Backbone` - """ - if input_shape is None: - input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN)) - - backbone_name = cfg.MODEL.BACKBONE.NAME - backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape) - assert isinstance(backbone, Backbone) - return backbone \ No newline at end of file diff --git a/src/regionclip/clip_backbone.py b/src/regionclip/clip_backbone.py deleted file mode 100644 index c6faf3bce685169da4c32911bf855d7e19e71b36..0000000000000000000000000000000000000000 --- a/src/regionclip/clip_backbone.py +++ /dev/null @@ -1,960 +0,0 @@ -from collections import OrderedDict -from typing import Tuple, Union - -import numpy as np -import torch -import torch.nn.functional as F -from torch import nn - -from .backbone import Backbone -from .build import BACKBONE_REGISTRY -from detectron2.layers.blocks import FrozenBatchNorm2d -from detectron2.layers import ShapeSpec - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1, norm_type='FronzenBN'): - super().__init__() - - # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - if norm_type == 'FronzenBN': - self.bn1 = FrozenBatchNorm2d(planes) # nn.BatchNorm2d(planes) - elif norm_type == 'SyncBN': - self.bn1 = nn.SyncBatchNorm(planes) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - if norm_type == 'FronzenBN': - self.bn2 = FrozenBatchNorm2d(planes) # nn.BatchNorm2d(planes) - elif norm_type == 'SyncBN': - self.bn2 = nn.SyncBatchNorm(planes) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - if norm_type == 'FronzenBN': - self.bn3 = FrozenBatchNorm2d(planes * self.expansion) # nn.BatchNorm2d(planes * self.expansion) - elif norm_type == 'SyncBN': - self.bn3 = nn.SyncBatchNorm(planes * self.expansion) - - self.relu = nn.ReLU(inplace=True) - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - if norm_type == 'FronzenBN': - this_norm = FrozenBatchNorm2d(planes * self.expansion) #("1", nn.BatchNorm2d(planes * self.expansion)) - elif norm_type == 'SyncBN': - this_norm = nn.SyncBatchNorm(planes * self.expansion) - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", this_norm), #("1", nn.BatchNorm2d(planes * self.expansion)) - ])) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x, return_local_features=False): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, key=x, value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=self.c_proj.weight, - out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False - ) - - if return_local_features: return x[0], x[1:] - - return x[0] - - -class ModifiedResNet(Backbone): - """ - Extended from CLIP implementation. It contains following changes: - 1. change all nn.BatchNorm2d() to FrozenBatchNorm2d(), due to small batch size of detection training - 2. add self._out_feature_strides according to standard ResNet - 2. modify forward() to be compatible with Detectron2 - 3. add freeze() and output_shape() to be compatible with Detectron2 - 4. add build_clip_resnet_backbone() to build this ModifiedResNet - - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, input_resolution=224, width=64, - out_features=None, freeze_at=0, depth=None, pool_vec=True, create_att_pool=False, norm_type='FronzenBN'): - super().__init__() - self.output_dim = output_dim - self.input_resolution = input_resolution - self.norm_type = norm_type - - # the 3-layer stem - self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) - if norm_type == 'FronzenBN': - self.bn1 = FrozenBatchNorm2d(width // 2) # nn.BatchNorm2d(width // 2) - elif norm_type == 'SyncBN': - self.bn1 = nn.SyncBatchNorm(width // 2) - self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) - if norm_type == 'FronzenBN': - self.bn2 = FrozenBatchNorm2d(width // 2) # nn.BatchNorm2d(width // 2) - elif norm_type == 'SyncBN': - self.bn2 = nn.SyncBatchNorm(width // 2) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - if norm_type == 'FronzenBN': - self.bn3 = FrozenBatchNorm2d(width) # nn.BatchNorm2d(width) - elif norm_type == 'SyncBN': - self.bn3 = nn.SyncBatchNorm(width) - self.avgpool = nn.AvgPool2d(2) - self.relu = nn.ReLU(inplace=True) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - if 'res5' in out_features: # FPN - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - else: # C4, layer4 created here won't be used in backbone, but used in roi_head - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) # None - - self.pool_vec = pool_vec - self._out_features = out_features if out_features else [] - - if depth in [50,101]: # resnet50 or resnet 101 - # FPN: ["res2", "res3", "res4", "res5"]; C4: ["res4"] - self._out_feature_channels = {'stem': 64, 'res2': 256, 'res3': 512, 'res4': 1024, 'res5': 2048} if 'res5' in self._out_features \ - else {'stem': 64, 'res2': 256, 'res3': 512, 'res4': 1024} - self._out_feature_strides = {'stem': 4, 'res2': 4, 'res3': 8, 'res4': 16, 'res5': 32} if 'res5' in self._out_features \ - else {'stem': 4, 'res2': 4, 'res3': 8, 'res4': 16} # anti-aliasing strided conv??? - elif depth in [200]: # resnet50x4 - # FPN: ["res2", "res3", "res4", "res5"]; C4: ["res4"] - self._out_feature_channels = {'stem': 80, 'res2': 320, 'res3': 640, 'res4': 1280, 'res5': 2560} if 'res5' in self._out_features \ - else {'stem': 80, 'res2': 320, 'res3': 640, 'res4': 1280} - self._out_feature_strides = {'stem': 4, 'res2': 4, 'res3': 8, 'res4': 16, 'res5': 32} if 'res5' in self._out_features \ - else {'stem': 4, 'res2': 4, 'res3': 8, 'res4': 16} # anti-aliasing strided conv??? - - if self.pool_vec or create_att_pool: # pool a vector representation for an image - - #last_feat = 'res5' if 'res5' in self._out_features else 'res4' - #embed_dim = self._out_feature_channels[last_feat] - #stride = self._out_feature_strides[last_feat] - stride = 32 - embed_dim = width * stride - - - self.attnpool = AttentionPool2d(input_resolution // stride, embed_dim, heads, output_dim) - # if create_att_pool: # freeze attnpool layer - # for p in self.attnpool.parameters(): p.requires_grad = False - - self.freeze(freeze_at) - - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride, norm_type=self.norm_type)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes, norm_type=self.norm_type)) - - return nn.Sequential(*layers) - - def stem(self, x): - """ - Stem of the ResNet. - Computes the first 3 convolutions and the average pooling. - Do not call this method directly, use forward() instead. - """ - for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: - x = self.relu(bn(conv(x))) - x = self.avgpool(x) - return x - - def forward(self, x): - - - assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" - outputs = {} - x = x.type(self.conv1.weight.dtype) # det2 resnet50: [3, 800, 1216]; CLIP resnet50: [3, 224, 224] - x = self.stem(x) # det2 resnet50: [64, 200, 304]; CLIP resnet50: [64, 56, 56] - if "stem" in self._out_features: - outputs["stem"] = x - x = self.layer1(x) # det2 resnet50: [256, 200, 304]; CLIP resnet50: [256, 56, 56] - outputs['res2'] = x if "res2" in self._out_features else None - x = self.layer2(x) # det2 resnet50: [512, 100, 152]; CLIP resnet50: [512, 28, 28] - outputs['res3'] = x if "res3" in self._out_features else None - x = self.layer3(x) # det2 resnet50: [1024, 50, 76]; CLIP resnet50: [1024, 14, 14] - outputs['res4'] = x if "res4" in self._out_features else None - - x_5 = self.layer4(x) - - x = x_5 if "res5" in self._out_features else x # det2 resnet50: [2048, 25, 38]; CLIP resnet50: [2048, 7, 7] - - outputs['res5'] = x if "res5" in self._out_features else None - - if self.pool_vec: # pool a vector representation for an image, for global image classification - x = self.attnpool(x_5) # CLIP resnet50: [1024] - return x - else: # for FPN - return outputs - - def forward_return_spatial_feats(self, x, use_layer3=False, use_attnpool_for_spatial_feats=True): - """ - Forward pass that returns spatial features for each layer. - outputs is a dict with keys: - - x_norm_clstoken - - x_norm_patchtokens - The values are the normalized features for the class token and patch tokens. - Since this is a ResNet, the class token is the - - Args: - x: input tensor - use_layer3: if True, use features from layer3 instead of layer4 for spatial features - """ - outputs = {} - x = x.type(self.conv1.weight.dtype) - x = self.stem(x) - x = self.layer1(x) - x = self.layer2(x) - x_3 = self.layer3(x) - x_5 = self.layer4(x_3) - - # Choose which layer to use for spatial features - if use_layer3: - spatial_x = x_3 # Use layer3 features (1024-dim, higher spatial resolution) - else: - spatial_x = x_5 if "res5" in self._out_features else x_3 - - # Always use layer4 for global CLS token (via attention pooling) - if self.pool_vec: - if "res5" in self._out_features: - if use_attnpool_for_spatial_feats: - x_norm_cls_token, spatial_features = self.attnpool(x_5, return_local_features=True) - else: - x_norm_cls_token = self.attnpool(x_5, return_local_features=False) - x_5 = x_5.reshape(x_5.shape[0], x_5.shape[1], x_5.shape[2] * x_5.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC - spatial_features = self.attnpool.c_proj(x_5) # Use the projection layer directly - - # spatial_features has shape # (HW)NC - # we want it to be (B, H*W, C) - spatial_features = spatial_features.permute(1, 0, 2) # (B, H*W, C) - - else: - x_norm_cls_token = self.attnpool(x_5) - else: - x_norm_cls_token = None - spatial_features = None - - if spatial_features is None: # when res5 not in out_features, use layer3 features - B, C, H, W = spatial_x.shape - spatial_features = spatial_x.reshape(B, C, H * W).transpose(1, 2) # [B, H*W, C] - - # apply normalization - spatial_features = F.normalize(spatial_features, dim=-1) # Normalize along the feature dimension - x_norm_cls_token = F.normalize(x_norm_cls_token, dim=-1) if x_norm_cls_token is not None else None - outputs["x_norm_clstoken"] = x_norm_cls_token - outputs["x_norm_patchtokens"] = spatial_features - - return outputs - - def freeze(self, freeze_at=0): - """ - Freeze the first several stages of the ResNet. Commonly used in - fine-tuning. - - Layers that produce the same feature map spatial size are defined as one - "stage" by :paper:`FPN`. - - Args: - freeze_at (int): number of stages to freeze. - `1` means freezing the stem. `2` means freezing the stem and - one residual stage, etc. - - Returns: - nn.Module: this ResNet itself - """ - def cnnblockbase_freeze(nn_module): - """ - Make this block not trainable. - This method sets all parameters to `requires_grad=False`, - and convert all BatchNorm layers to FrozenBatchNorm - - Returns: - the block itself - """ - for p in nn_module.parameters(): - p.requires_grad = False - FrozenBatchNorm2d.convert_frozen_batchnorm(nn_module) - - if freeze_at >= 1: # stem - cnnblockbase_freeze(self.conv1) - cnnblockbase_freeze(self.bn1) - cnnblockbase_freeze(self.conv2) - cnnblockbase_freeze(self.bn2) - cnnblockbase_freeze(self.conv3) - cnnblockbase_freeze(self.bn3) - # each stage is a torch.nn.modules.container.Sequential - for idx, stage in enumerate([self.layer1, self.layer2, self.layer3, self.layer4], start=2): - if freeze_at >= idx: - for block in stage.children(): # each block is a Bottleneck - cnnblockbase_freeze(block) - return self - - def output_shape(self): - return { - name: ShapeSpec( - channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] - ) - for name in self._out_features - } - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -class QuickGELU(nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)) - ])) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] - - def forward(self, x: torch.Tensor): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) - - def forward(self, x: torch.Tensor): - return self.resblocks(x) - - -class VisualTransformer(nn.Module): - def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) - self.ln_pre = LayerNorm(width) - - self.transformer = Transformer(width, layers, heads) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - x = self.ln_post(x[:, 0, :]) - - if self.proj is not None: - x = x @ self.proj - - return x - - -class CLIP(Backbone): - def __init__(self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int, - out_features, - freeze_at, - depth=None - ): - super().__init__() - - self.context_length = context_length - - if isinstance(vision_layers, (tuple, list)): - vision_heads = vision_width * 32 // 64 - self.visual = ModifiedResNet( - layers=vision_layers, - output_dim=embed_dim, - heads=vision_heads, - input_resolution=image_resolution, - width=vision_width, - out_features=out_features, - freeze_at=freeze_at, - depth=depth - ) - else: - vision_heads = vision_width // 64 - self.visual = VisualTransformer( - input_resolution=image_resolution, - patch_size=vision_patch_size, - width=vision_width, - layers=vision_layers, - heads=vision_heads, - output_dim=embed_dim - ) - - self.transformer = Transformer( - width=transformer_width, - layers=transformer_layers, - heads=transformer_heads, - attn_mask=self.build_attention_mask() - ) - - self.vocab_size = vocab_size - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) - self.ln_final = LayerNorm(transformer_width) - - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - self.initialize_parameters() - - def initialize_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - if isinstance(self.visual, ModifiedResNet): - if self.visual.attnpool is not None: - std = self.visual.attnpool.c_proj.in_features ** -0.5 - nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - @property - def dtype(self): - return self.visual.conv1.weight.dtype - - def encode_image(self, image): - return self.visual(image.type(self.dtype)) - - def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.type(self.dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - def forward(self, image, text): - image_features = self.encode_image(image) - text_features = self.encode_text(text) - - # normalized features - image_features = image_features / image_features.norm(dim=-1, keepdim=True) - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - - # cosine similarity as logits - logit_scale = self.logit_scale.exp() - logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logit_scale * text_features @ image_features.t() - - # shape = [global_batch_size, global_batch_size] - return logits_per_image, logits_per_text - - -def convert_weights(model: nn.Module): - """Convert applicable model parameters to fp16""" - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - if isinstance(l, nn.MultiheadAttention): - for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.half() - - for name in ["text_projection", "proj"]: - if hasattr(l, name): - attr = getattr(l, name) - if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - - -def build_model(state_dict: dict): - vit = "visual.proj" in state_dict - - if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) - image_resolution = vision_patch_size * grid_size - else: - counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] - vision_layers = tuple(counts) - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) - vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] - image_resolution = output_width * 32 - - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) - - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers - ) - - for key in ["input_resolution", "context_length", "vocab_size"]: - if key in state_dict: - del state_dict[key] - - convert_weights(model) - model.load_state_dict(state_dict) - return model.eval() - - -@BACKBONE_REGISTRY.register() -def build_vit_clip(cfg, input_shape): - """ - Create the whole CLIP instance from config. - - Returns: - CLIP: a :class:`CLIP` instance. - """ - # port standard ResNet config to CLIP ModifiedResNet - freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT - out_features = ['res5'] # includes the whole ResNet # cfg.MODEL.RESNETS.OUT_FEATURES - depth = cfg.MODEL.RESNETS.DEPTH - - # num_blocks_per_stage = { - # 18: [2, 2, 2, 2], - # 34: [3, 4, 6, 3], - # 50: [3, 4, 6, 3], - # 101: [3, 4, 23, 3], - # 152: [3, 8, 36, 3], - # }[depth] - vision_layers = 12 # num_blocks_per_stage - vision_width = 768 # cfg.MODEL.RESNETS.STEM_OUT_CHANNELS - - # default configs of CLIP - embed_dim = 512 # 1024 - image_resolution = 224 - vision_patch_size = 32 # None - context_length = 77 - vocab_size = 49408 - transformer_width = 512 - transformer_heads = 8 - transformer_layers = 12 - - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, - out_features, freeze_at - ) - return model - -@BACKBONE_REGISTRY.register() -def build_resnet_clip(cfg, input_shape): - """ - Create the whole CLIP instance from config. - - Returns: - CLIP: a :class:`CLIP` instance. - """ - # port standard ResNet config to CLIP ModifiedResNet - freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT - out_features = ['res5'] # includes the whole ResNet # cfg.MODEL.RESNETS.OUT_FEATURES - depth = cfg.MODEL.RESNETS.DEPTH - - num_blocks_per_stage = { - 18: [2, 2, 2, 2], - 34: [3, 4, 6, 3], - 50: [3, 4, 6, 3], - 101: [3, 4, 23, 3], - 152: [3, 8, 36, 3], - 200: [4, 6, 10, 6], # flag for ResNet50x4 - }[depth] - vision_layers = num_blocks_per_stage - vision_width = { - 50: 64, - 101: 64, - 200: 80, # flag for ResNet50x4 - }[depth] # cfg.MODEL.RESNETS.STEM_OUT_CHANNELS - - # default configs of CLIP - embed_dim = { - 50: 1024, - 101: 512, - 200: 640, # flag for ResNet50x4 - }[depth] - vision_heads = vision_width * 32 // 64 - image_resolution = { - 50: 224, - 101: 224, - 200: 288, # flag for ResNet50x4 - }[depth] - vision_patch_size = None - context_length = 77 - vocab_size = 49408 - transformer_width = { - 50: 512, - 101: 512, - 200: 640, # flag for ResNet50x4 - }[depth] - transformer_heads = { - 50: 8, - 101: 8, - 200: 10, # flag for ResNet50x4 - }[depth] - transformer_layers = 12 - - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, - out_features, freeze_at - ) - return model - - -@BACKBONE_REGISTRY.register() -def build_clip_resnet_backbone(cfg, input_shape): - """ - Create a CLIP-version ResNet instance from config. - - Returns: - ModifiedResNet: a :class:`ModifiedResNet` instance. - """ - # port standard ResNet config to CLIP ModifiedResNet - freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT - out_features = cfg.MODEL.RESNETS.OUT_FEATURES - depth = cfg.MODEL.RESNETS.DEPTH - # num_groups = cfg.MODEL.RESNETS.NUM_GROUPS - # width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP - # bottleneck_channels = num_groups * width_per_group - # in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS - # out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS - # stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 - # res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION - # deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE - # deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED - # deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS - - num_blocks_per_stage = { - 18: [2, 2, 2, 2], - 34: [3, 4, 6, 3], - 50: [3, 4, 6, 3], - 101: [3, 4, 23, 3], - 152: [3, 8, 36, 3], - 200: [4, 6, 10, 6], # flag for ResNet50x4 - }[depth] - vision_layers = num_blocks_per_stage - vision_width = { - 50: 64, - 101: 64, - 200: 80, # flag for ResNet50x4 - }[depth] # cfg.MODEL.RESNETS.STEM_OUT_CHANNELS - - # default configs of CLIP ModifiedResNet, but not used if only building ModifiedResNet as backbone - embed_dim = { - 50: 1024, - 101: 512, - 200: 640, # flag for ResNet50x4 - }[depth] - vision_heads = vision_width * 32 // 64 - image_resolution = { - 50: 224, - 101: 224, - 200: 288, # flag for ResNet50x4 - }[depth] - - # if combine {ModifiedResNet of CLIP, C4, text emb as classifier}, then has to use att_pool to match dimension - create_att_pool = True if (cfg.MODEL.ROI_HEADS.NAME in ['CLIPRes5ROIHeads', 'CLIPStandardROIHeads'] and cfg.MODEL.CLIP.USE_TEXT_EMB_CLASSIFIER)\ - or cfg.MODEL.ROI_HEADS.NAME == 'PretrainRes5ROIHeads' else False - - return ModifiedResNet(layers=vision_layers, - output_dim=embed_dim, - heads=vision_heads, - input_resolution=image_resolution, - width=vision_width, - out_features=out_features, - freeze_at=freeze_at, - depth=depth, - pool_vec=False, - create_att_pool=create_att_pool, - ) - - -class CLIPLangEncoder(nn.Module): - def __init__(self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int, - out_features, - freeze_at, - ): - super().__init__() - - self.context_length = context_length - - self.transformer = Transformer( - width=transformer_width, - layers=transformer_layers, - heads=transformer_heads, - attn_mask=self.build_attention_mask() - ) - - self.vocab_size = vocab_size - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) - self.ln_final = LayerNorm(transformer_width) - - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - #self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - self.initialize_parameters() - - def initialize_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - @property - def dtype(self): - return self.transformer.resblocks[0].mlp[0].weight.dtype # torch.float32, not sure whether need to be fp16 in pretraining - - def encode_text(self, text, only_eot=True): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.type(self.dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - if only_eot: - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - return x - else: - # return embeddings for all tokens, instead of the eot embedding as CLIP implementation below - return x @ self.text_projection - - -def build_clip_language_encoder(cfg): - """ - Create the CLIP language encoder instance from config. - - Returns: - CLIP: a :class:`CLIP` instance. - """ - # port standard ResNet config to CLIP ModifiedResNet - freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT - out_features = ['res5'] # includes the whole ResNet # cfg.MODEL.RESNETS.OUT_FEATURES - depth = cfg.MODEL.RESNETS.DEPTH - - num_blocks_per_stage = { - 18: [2, 2, 2, 2], - 34: [3, 4, 6, 3], - 50: [3, 4, 6, 3], - 101: [3, 4, 23, 3], - 152: [3, 8, 36, 3], - 200: [4, 6, 10, 6], # flag for ResNet50x4 - }[depth] - vision_layers = num_blocks_per_stage - vision_width = { - 50: 64, - 101: 64, - 200: 80, # flag for ResNet50x4 - }[depth] # cfg.MODEL.RESNETS.STEM_OUT_CHANNELS - - # default configs of CLIP - embed_dim = { - 50: 1024, - 101: 512, - 200: 640, # flag for ResNet50x4 - }[depth] - vision_heads = vision_width * 32 // 64 - image_resolution = { - 50: 224, - 101: 224, - 200: 288, # flag for ResNet50x4 - }[depth] - vision_patch_size = None - context_length = 77 - vocab_size = 49408 - transformer_width = { - 50: 512, - 101: 512, - 200: 640, # flag for ResNet50x4 - }[depth] - transformer_heads = { - 50: 8, - 101: 8, - 200: 10, # flag for ResNet50x4 - }[depth] - transformer_layers = 12 - - model = CLIPLangEncoder( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, - out_features, freeze_at - ) - return model \ No newline at end of file diff --git a/src/regionclip/configs/pretrain/RegionCLIP_RN50.yaml b/src/regionclip/configs/pretrain/RegionCLIP_RN50.yaml deleted file mode 100644 index b88ce574e9c1a5f2a2e4798477857e9e76c73c29..0000000000000000000000000000000000000000 --- a/src/regionclip/configs/pretrain/RegionCLIP_RN50.yaml +++ /dev/null @@ -1,69 +0,0 @@ -_BASE_: "../Base-RCNN-C4.yaml" -MODEL: - META_ARCHITECTURE: "PretrainFastRCNN" - BACKBONE: - NAME: "build_clip_resnet_backbone" - FREEZE_AT: 2 - WEIGHTS: "" - MASK_ON: False - RESNETS: - DEPTH: 50 - OUT_FEATURES: ["res4"] - NORM: FrozenBN - STEM_OUT_CHANNELS: 64 - RES2_OUT_CHANNELS: 256 - ROI_HEADS: - NAME: "PretrainRes5ROIHeads" - IN_FEATURES: ["res4"] - PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] - PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] - CLIP: - CLSS_TEMP: 0.01 - CROP_REGION_TYPE: "RPN" - OFFLINE_RPN_NMS_THRESH: 0.5 - GATHER_GPUS: True - CONCEPT_THRES: 0.1 - PRETRAIN_RPN_REGIONS: 300 - PRETRAIN_SAMPLE_REGIONS: 100 - PRETRAIN_IMG_TXT_LEVEL: True - PRETRAIN_ONLY_EOT: True - TEACHER_RESNETS_DEPTH: 50 - TEACHER_POOLER_RESOLUTION: 14 -INPUT: - MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) -DATASETS: - TRAIN: ("imgtxtpairs",) - FACTORY_TRAIN: ("CLIPImgTxtPairTSVDataset",) - PATH_TRAIN: ("/home/v-yiwuzhong/projects/azureblobs/vlpdatasets/coco-caption/val2017",) # ("/tmp/datasets/CC3M",) - TEST: () -DATALOADER: - ASPECT_RATIO_GROUPING: False - NUM_WORKERS: 4 -TEST: - DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 - EVAL_PERIOD: 2500000 -SOLVER: - IMS_PER_BATCH: 96 # 32 gpus - BASE_LR: 0.002 - WEIGHT_DECAY: 0.0001 - STEPS: (300000, 525000) - MAX_ITER: 600000 - CLIP_GRADIENTS: - ENABLED: True - CLIP_TYPE: "norm" - CLIP_VALUE: 5.0 -INPUT: - MIN_SIZE_TRAIN_SAMPLING: choice - MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) - MAX_SIZE_TRAIN: 1333 - MIN_SIZE_TEST: 800 - MAX_SIZE_TEST: 1333 - FORMAT: "RGB" -AUG: # Data Augmentation from MSR-CLIP - TRAIN: - IMAGE_SIZE: [800,] - MAX_SIZE: 1333 - TEST: - IMAGE_SIZE: [800,] - MAX_SIZE: 1333 - INTERPOLATION: 3 \ No newline at end of file diff --git a/src/regionclip/configs/pretrain/RegionCLIP_RN50x4.yaml b/src/regionclip/configs/pretrain/RegionCLIP_RN50x4.yaml deleted file mode 100644 index eea8cfd02e7d0045ccbd6c92cb3c78aebde842d6..0000000000000000000000000000000000000000 --- a/src/regionclip/configs/pretrain/RegionCLIP_RN50x4.yaml +++ /dev/null @@ -1,71 +0,0 @@ -_BASE_: "../Base-RCNN-C4.yaml" -MODEL: - META_ARCHITECTURE: "PretrainFastRCNN" - BACKBONE: - NAME: "build_clip_resnet_backbone" - FREEZE_AT: 2 - WEIGHTS: "" - MASK_ON: False - RESNETS: - DEPTH: 200 - OUT_FEATURES: ["res4"] - NORM: FrozenBN - STEM_OUT_CHANNELS: 64 - RES2_OUT_CHANNELS: 256 - ROI_HEADS: - NAME: "PretrainRes5ROIHeads" - IN_FEATURES: ["res4"] - ROI_BOX_HEAD: - POOLER_RESOLUTION: 18 - PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] - PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] - CLIP: - CLSS_TEMP: 0.01 - CROP_REGION_TYPE: "RPN" - OFFLINE_RPN_NMS_THRESH: 0.5 - GATHER_GPUS: True - CONCEPT_THRES: 0.1 - PRETRAIN_RPN_REGIONS: 300 - PRETRAIN_SAMPLE_REGIONS: 100 - PRETRAIN_IMG_TXT_LEVEL: True - PRETRAIN_ONLY_EOT: True - TEACHER_RESNETS_DEPTH: 200 - TEACHER_POOLER_RESOLUTION: 18 -INPUT: - MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) -DATASETS: - TRAIN: ("imgtxtpairs",) - FACTORY_TRAIN: ("CLIPImgTxtPairTSVDataset",) - PATH_TRAIN: ("/home/v-yiwuzhong/projects/azureblobs/vlpdatasets/coco-caption/val2017",) # ("/tmp/datasets/CC3M",) - TEST: () -DATALOADER: - ASPECT_RATIO_GROUPING: False - NUM_WORKERS: 4 -TEST: - DETECTIONS_PER_IMAGE: 300 # LVIS allows up to 300 - EVAL_PERIOD: 2500000 -SOLVER: - IMS_PER_BATCH: 96 # 32 gpus - BASE_LR: 0.002 - WEIGHT_DECAY: 0.0001 - STEPS: (300000, 525000) - MAX_ITER: 600000 - CLIP_GRADIENTS: - ENABLED: True - CLIP_TYPE: "norm" - CLIP_VALUE: 5.0 -INPUT: - MIN_SIZE_TRAIN_SAMPLING: choice - MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) - MAX_SIZE_TRAIN: 1333 - MIN_SIZE_TEST: 800 - MAX_SIZE_TEST: 1333 - FORMAT: "RGB" -AUG: # Data Augmentation from MSR-CLIP - TRAIN: - IMAGE_SIZE: [800,] - MAX_SIZE: 1333 - TEST: - IMAGE_SIZE: [800,] - MAX_SIZE: 1333 - INTERPOLATION: 3 \ No newline at end of file diff --git a/src/regionclip/datasets/bpe_simple_vocab_16e6.txt.gz b/src/regionclip/datasets/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a15856e00a06a9fbed8cdd34d2393fea4a3113..0000000000000000000000000000000000000000 --- a/src/regionclip/datasets/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/src/regionclip/datasets/clip_prompt_utils.py b/src/regionclip/datasets/clip_prompt_utils.py deleted file mode 100644 index f8f72692caee5c490191e1541c80b10d036766b2..0000000000000000000000000000000000000000 --- a/src/regionclip/datasets/clip_prompt_utils.py +++ /dev/null @@ -1,441 +0,0 @@ -import gzip -import html -import os -from functools import lru_cache - -import ftfy -import regex as re -import torch -import numpy as np -from typing import Union, List - -from .lvis_v1_categories import LVIS_CATEGORIES as LVIS_V1_CATEGORIES -from .coco_zeroshot_categories import COCO_UNSEEN_CLS, COCO_SEEN_CLS, COCO_OVD_ALL_CLS, COCO_80_ALL_CLS - -# https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py -@lru_cache() -def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("ยก"), ord("ยฌ")+1))+list(range(ord("ยฎ"), ord("รฟ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - self.vocab = vocab - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text, return_link=False): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - str2id_links = [] # link original sentence word to the tokenized ids of its subwords - for token in re.findall(self.pat, text): - this_link = [token] - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - ids = [self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')] - bpe_tokens.extend(ids) - this_link.append(ids) - str2id_links.append(this_link) - if return_link: - return bpe_tokens, str2id_links - return bpe_tokens - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') - return text - - -# https://github.com/openai/CLIP/blob/main/clip/clip.py -_tokenizer = SimpleTokenizer() - -def tokenize(texts: Union[str, List[str]], context_length: int = 77): - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder["<|startoftext|>"] - eot_token = _tokenizer.encoder["<|endoftext|>"] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") - result[i, :len(tokens)] = torch.tensor(tokens) - - return result - - -# prompt_engineering.py -def get_prompt_templates(): - # prompt_templates = [ - # 'There is a {} in the scene.', - # 'There is the {} in the scene.', - # 'a photo of a {} in the scene.', - # 'a photo of the {} in the scene.', - # 'a photo of one {} in the scene.', - - # 'itap of a {}.', - # 'itap of my {}.', # itap: I took a picture of - # 'itap of the {}.', - # 'a photo of a {}.', - # 'a photo of my {}.', - # 'a photo of the {}.', - # 'a photo of one {}.', - # 'a photo of many {}.', - - # 'a good photo of a {}.', - # 'a good photo of the {}.', - # 'a bad photo of a {}.', - # 'a bad photo of the {}.', - # 'a photo of a nice {}.', - # 'a photo of the nice {}.', - # 'a photo of a cool {}.', - # 'a photo of the cool {}.', - # 'a photo of a weird {}.', - # 'a photo of the weird {}.', - - # 'a photo of a small {}.', - # 'a photo of the small {}.', - # 'a photo of a large {}.', - # 'a photo of the large {}.', - - # 'a photo of a clean {}.', - # 'a photo of the clean {}.', - # 'a photo of a dirty {}.', - # 'a photo of the dirty {}.', - - # 'a bright photo of a {}.', - # 'a bright photo of the {}.', - # 'a dark photo of a {}.', - # 'a dark photo of the {}.', - - # 'a photo of a hard to see {}.', - # 'a photo of the hard to see {}.', - # 'a low resolution photo of a {}.', - # 'a low resolution photo of the {}.', - # 'a cropped photo of a {}.', - # 'a cropped photo of the {}.', - # 'a close-up photo of a {}.', - # 'a close-up photo of the {}.', - # 'a jpeg corrupted photo of a {}.', - # 'a jpeg corrupted photo of the {}.', - # 'a blurry photo of a {}.', - # 'a blurry photo of the {}.', - # 'a pixelated photo of a {}.', - # 'a pixelated photo of the {}.', - - # 'a black and white photo of the {}.', - # 'a black and white photo of a {}.', - - # 'a plastic {}.', - # 'the plastic {}.', - - # 'a toy {}.', - # 'the toy {}.', - # 'a plushie {}.', - # 'the plushie {}.', - # 'a cartoon {}.', - # 'the cartoon {}.', - - # 'an embroidered {}.', - # 'the embroidered {}.', - - # 'a painting of the {}.', - # 'a painting of a {}.', - # ] - - prompt_templates = [ - '{}.', - 'a photo of a {}.', - 'a bad photo of a {}.', - 'a photo of many {}.', - 'a sculpture of a {}.', - 'a photo of the hard to see {}.', - 'a low resolution photo of the {}.', - 'a rendering of a {}.', - 'graffiti of a {}.', - 'a bad photo of the {}.', - 'a cropped photo of the {}.', - 'a tattoo of a {}.', - 'the embroidered {}.', - 'a photo of a hard to see {}.', - 'a bright photo of a {}.', - 'a photo of a clean {}.', - 'a photo of a dirty {}.', - 'a dark photo of the {}.', - 'a drawing of a {}.', - 'a photo of my {}.', - 'the plastic {}.', - 'a photo of the cool {}.', - 'a close-up photo of a {}.', - 'a black and white photo of the {}.', - 'a painting of the {}.', - 'a painting of a {}.', - 'a pixelated photo of the {}.', - 'a sculpture of the {}.', - 'a bright photo of the {}.', - 'a cropped photo of a {}.', - 'a plastic {}.', - 'a photo of the dirty {}.', - 'a jpeg corrupted photo of a {}.', - 'a blurry photo of the {}.', - 'a photo of the {}.', - 'a good photo of the {}.', - 'a rendering of the {}.', - 'a {} in a video game.', - 'a photo of one {}.', - 'a doodle of a {}.', - 'a close-up photo of the {}.', - 'the origami {}.', - 'the {} in a video game.', - 'a sketch of a {}.', - 'a doodle of the {}.', - 'a origami {}.', - 'a low resolution photo of a {}.', - 'the toy {}.', - 'a rendition of the {}.', - 'a photo of the clean {}.', - 'a photo of a large {}.', - 'a rendition of a {}.', - 'a photo of a nice {}.', - 'a photo of a weird {}.', - 'a blurry photo of a {}.', - 'a cartoon {}.', - 'art of a {}.', - 'a sketch of the {}.', - 'a embroidered {}.', - 'a pixelated photo of a {}.', - 'itap of the {}.', - 'a jpeg corrupted photo of the {}.', - 'a good photo of a {}.', - 'a plushie {}.', - 'a photo of the nice {}.', - 'a photo of the small {}.', - 'a photo of the weird {}.', - 'the cartoon {}.', - 'art of the {}.', - 'a drawing of the {}.', - 'a photo of the large {}.', - 'a black and white photo of a {}.', - 'the plushie {}.', - 'a dark photo of a {}.', - 'itap of a {}.', - 'graffiti of the {}.', - 'a toy {}.', - 'itap of my {}.', - 'a photo of a cool {}.', - 'a photo of a small {}.', - 'a tattoo of the {}.', - ] - return prompt_templates - -def prompt_engineering(classnames, template=""): - return template.replace('{}', classnames.replace(',', '').replace('+', ' ')) - -# clip_img_tsv.py -def convert_example_to_features_bpe(text, tokenizer, sot_token, eot_token, context_length=77): - """ - Convert a raw sample (pair of sentences as tokenized strings) into a proper training sample. - :param tokenizer: Tokenizer - :return: List, a list containing token id, padded by 0 - """ - assert isinstance(text, str) - input_ids = [sot_token] + tokenizer.encode(text) + [eot_token] - if len(input_ids) > context_length: - input_ids = input_ids[:context_length] - input_ids = np.array(input_ids) - - pad_input_ids = np.zeros(context_length) - pad_input_ids[:input_ids.shape[0]] = input_ids - - return pad_input_ids - -def get_cls_names(filter_novel=False, coco=None, from_file=False): - """ return a list of strings with each string as name of a class - """ - # the names are stored in a txt file - if from_file: - # coco_det_cls = {COCO_80_ALL_CLS[key]: key for key in COCO_80_ALL_CLS} - # # not found in nouns {'skis': 31, 'sports ball': 33, 'hot dog': 53, 'potted plant': 59, 'scissors': 77, 'hair drier': 79} - # coco_det_cls['ski'] = 81 - # coco_det_cls['scissor'] = 82 - # with open('/home/v-yiwuzhong/projects/azureblobs/vyiwuzhong_phillytools/trained_models/concept_pool/COCO_Caption_nouns_4688.txt','w') as g: - # with open(from_file, 'r') as f: - # cnt = 0 - # for row in f: - # if row.split(",")[0] not in coco_det_cls: - # g.write(row) - # cnt += 1 - # else: - # coco_det_cls.pop(row.split(",")[0]) - names = [] - with open(from_file, 'r') as f: - for row in f: - names.append(row.split(",")[0]) - return names - # classes' names - if coco == 'target': - return COCO_UNSEEN_CLS - elif coco == 'base': - return COCO_SEEN_CLS - elif coco == 'all': - return COCO_OVD_ALL_CLS - elif coco == 'all_80': - return [COCO_80_ALL_CLS[i+1] for i in range(80)] - assert len(LVIS_V1_CATEGORIES) == 1203 - cat_ids = [k["id"] for k in LVIS_V1_CATEGORIES] - assert min(cat_ids) == 1 and max(cat_ids) == len( - cat_ids - ), "Category ids are not in [1, #categories], as expected" - # Ensure that the category list is sorted by id - lvis_categories = sorted(LVIS_V1_CATEGORIES, key=lambda x: x["id"]) - if filter_novel: - class_names = [cls_meta['name'] for cls_meta in lvis_categories if cls_meta['frequency'] != 'r'] - else: - class_names = [cls_meta['name'] for cls_meta in lvis_categories] - - # remove or replace special symbols - class_names = [cls_n.replace("_", " ") for cls_n in class_names] - class_names = [cls_n.replace("(", "") for cls_n in class_names] - class_names = [cls_n.replace(")", "") for cls_n in class_names] - return class_names - -def pre_tokenize(class_names): - """ - pre-tokenize class names - :param class_names: List, a list of class names - :param tokenizer: Tokenizer, SimpleTokenizer() - :return: Tensor, containing all prompts for all classes, [#cls, #prompts, context_length] - """ - # tokenizer - tokenizer = SimpleTokenizer() - sot_token = tokenizer.encoder["<|startoftext|>"] - eot_token = tokenizer.encoder["<|endoftext|>"] - - # prompt engineering - prompt_templates = get_prompt_templates() - input_ids_all = [] - for k in range(len(class_names)): - v = class_names[k] - if isinstance(v, str): - vs = [v] - elif isinstance(v, list): - vs = v - t1s = [] - for v in vs: - for pt in prompt_templates: - t1s.append(prompt_engineering(v, template=pt)) - input_ids = [] - for t1 in t1s: - this_input_ids = convert_example_to_features_bpe(t1, tokenizer, sot_token, eot_token) - input_ids.append(torch.tensor(this_input_ids, dtype=torch.long)) - - input_ids_all.append(torch.stack(input_ids, 0)) - - input_ids_all_classes = torch.stack(input_ids_all, 0) - return input_ids_all_classes - - -if __name__ == "__main__": - flatten_input_ids = pre_tokenize() diff --git a/src/regionclip/datasets/coco_zeroshot_categories.py b/src/regionclip/datasets/coco_zeroshot_categories.py deleted file mode 100644 index baf2f4483292c1778432605fe8985d3134c26cb4..0000000000000000000000000000000000000000 --- a/src/regionclip/datasets/coco_zeroshot_categories.py +++ /dev/null @@ -1,208 +0,0 @@ -# COCO categories for zero-shot setting -# 65 categories in total, 48 base categories for training, 17 unseen categories are only used in testing -# from http://ankan.umiacs.io/files/mscoco_seen_classes.json, http://ankan.umiacs.io/files/mscoco_unseen_classes.json - -# 17 class names in order, obtained from load_coco_json() function -COCO_UNSEEN_CLS = ['airplane', 'bus', 'cat', 'dog', 'cow', 'elephant', 'umbrella', \ - 'tie', 'snowboard', 'skateboard', 'cup', 'knife', 'cake', 'couch', 'keyboard', \ - 'sink', 'scissors'] - -# 48 class names in order, obtained from load_coco_json() function -COCO_SEEN_CLS = ['person', 'bicycle', 'car', 'motorcycle', 'train', 'truck', \ - 'boat', 'bench', 'bird', 'horse', 'sheep', 'bear', 'zebra', 'giraffe', \ - 'backpack', 'handbag', 'suitcase', 'frisbee', 'skis', 'kite', 'surfboard', \ - 'bottle', 'fork', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', \ - 'broccoli', 'carrot', 'pizza', 'donut', 'chair', 'bed', 'toilet', 'tv', \ - 'laptop', 'mouse', 'remote', 'microwave', 'oven', 'toaster', \ - 'refrigerator', 'book', 'clock', 'vase', 'toothbrush'] - -# 65 class names in order, obtained from load_coco_json() function -COCO_OVD_ALL_CLS = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', \ - 'bus', 'train', 'truck', 'boat', 'bench', 'bird', 'cat', 'dog', 'horse', \ - 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', \ - 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'kite', 'skateboard', \ - 'surfboard', 'bottle', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', \ - 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'pizza', 'donut', 'cake', \ - 'chair', 'couch', 'bed', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', \ - 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', \ - 'scissors', 'toothbrush'] - -# 80 class names -COCO_80_ALL_CLS = {1: 'person', - 2: 'bicycle', - 3: 'car', - 4: 'motorcycle', - 5: 'airplane', - 6: 'bus', - 7: 'train', - 8: 'truck', - 9: 'boat', - 10: 'traffic light', - 11: 'fire hydrant', - 12: 'stop sign', - 13: 'parking meter', - 14: 'bench', - 15: 'bird', - 16: 'cat', - 17: 'dog', - 18: 'horse', - 19: 'sheep', - 20: 'cow', - 21: 'elephant', - 22: 'bear', - 23: 'zebra', - 24: 'giraffe', - 25: 'backpack', - 26: 'umbrella', - 27: 'handbag', - 28: 'tie', - 29: 'suitcase', - 30: 'frisbee', - 31: 'skis', - 32: 'snowboard', - 33: 'sports ball', - 34: 'kite', - 35: 'baseball bat', - 36: 'baseball glove', - 37: 'skateboard', - 38: 'surfboard', - 39: 'tennis racket', - 40: 'bottle', - 41: 'wine glass', - 42: 'cup', - 43: 'fork', - 44: 'knife', - 45: 'spoon', - 46: 'bowl', - 47: 'banana', - 48: 'apple', - 49: 'sandwich', - 50: 'orange', - 51: 'broccoli', - 52: 'carrot', - 53: 'hot dog', - 54: 'pizza', - 55: 'donut', - 56: 'cake', - 57: 'chair', - 58: 'couch', - 59: 'potted plant', - 60: 'bed', - 61: 'dining table', - 62: 'toilet', - 63: 'tv', - 64: 'laptop', - 65: 'mouse', - 66: 'remote', - 67: 'keyboard', - 68: 'cell phone', - 69: 'microwave', - 70: 'oven', - 71: 'toaster', - 72: 'sink', - 73: 'refrigerator', - 74: 'book', - 75: 'clock', - 76: 'vase', - 77: 'scissors', - 78: 'teddy bear', - 79: 'hair drier', - 80: 'toothbrush'} - -""" -if __name__ == "__main__": - # from https://github.com/alirezazareian/ovr-cnn/blob/master/ipynb/001.ipynb - # Create zero-shot setting data split in COCO - import json - import ipdb - - with open('./datasets/coco/annotations/instances_train2017.json', 'r') as fin: - coco_train_anno_all = json.load(fin) - - with open('./datasets/coco/annotations/instances_train2017.json', 'r') as fin: - coco_train_anno_seen = json.load(fin) - - with open('./datasets/coco/annotations/instances_train2017.json', 'r') as fin: - coco_train_anno_unseen = json.load(fin) - - with open('./datasets/coco/annotations/instances_val2017.json', 'r') as fin: - coco_val_anno_all = json.load(fin) - - with open('./datasets/coco/annotations/instances_val2017.json', 'r') as fin: - coco_val_anno_seen = json.load(fin) - - with open('./datasets/coco/annotations/instances_val2017.json', 'r') as fin: - coco_val_anno_unseen = json.load(fin) - - labels_seen = COCO_SEEN_CLS - labels_unseen = COCO_UNSEEN_CLS - labels_all = [item['name'] for item in coco_val_anno_all['categories']] # 80 class names - # len(labels_seen), len(labels_unseen) - # set(labels_seen) - set(labels_all) - # set(labels_unseen) - set(labels_all) - - class_id_to_split = {} # {1: 'seen', 2: 'seen', 3: 'seen', 4: 'seen', 5: 'unseen',...} - class_name_to_split = {} # {'person': 'seen', 'bicycle': 'seen', 'car': 'seen', 'motorcycle': 'seen', 'airplane': 'unseen',...} - for item in coco_val_anno_all['categories']: - if item['name'] in labels_seen: - class_id_to_split[item['id']] = 'seen' - class_name_to_split[item['name']] = 'seen' - elif item['name'] in labels_unseen: - class_id_to_split[item['id']] = 'unseen' - class_name_to_split[item['name']] = 'unseen' - - # class_name_to_emb = {} - # with open('../datasets/coco/zero-shot/glove.6B.300d.txt', 'r') as fin: - # for row in fin: - # row_tk = row.split() - # if row_tk[0] in class_name_to_split: - # class_name_to_emb[row_tk[0]] = [float(num) for num in row_tk[1:]] - # len(class_name_to_emb), len(class_name_to_split) - - def filter_annotation(anno_dict, split_name_list): - " "" - COCO annotations have fields: dict_keys(['info', 'licenses', 'images', 'annotations', 'categories']) - This function (1) filters the category metadata (list) in 'categories'; - (2) filter instance annotation in 'annotations'; (3) filter image metadata (list) in 'images - "" " - filtered_categories = [] - for item in anno_dict['categories']: - if class_id_to_split.get(item['id']) in split_name_list: - #item['embedding'] = class_name_to_emb[item['name']] - item['split'] = class_id_to_split.get(item['id']) - filtered_categories.append(item) - anno_dict['categories'] = filtered_categories - - filtered_images = [] - filtered_annotations = [] - useful_image_ids = set() - for item in anno_dict['annotations']: - if class_id_to_split.get(item['category_id']) in split_name_list: - filtered_annotations.append(item) - useful_image_ids.add(item['image_id']) - for item in anno_dict['images']: - if item['id'] in useful_image_ids: - filtered_images.append(item) - anno_dict['annotations'] = filtered_annotations - anno_dict['images'] = filtered_images - - filter_annotation(coco_train_anno_seen, ['seen']) - filter_annotation(coco_train_anno_unseen, ['unseen']) - filter_annotation(coco_train_anno_all, ['seen', 'unseen']) - filter_annotation(coco_val_anno_seen, ['seen']) - filter_annotation(coco_val_anno_unseen, ['unseen']) - filter_annotation(coco_val_anno_all, ['seen', 'unseen']) - - with open('./datasets/coco/annotations/ovd_ins_train2017_b.json', 'w') as fout: - json.dump(coco_train_anno_seen, fout) - with open('./datasets/coco/annotations/ovd_ins_train2017_t.json', 'w') as fout: - json.dump(coco_train_anno_unseen, fout) - with open('./datasets/coco/annotations/ovd_ins_train2017_all.json', 'w') as fout: - json.dump(coco_train_anno_all, fout) - with open('./datasets/coco/annotations/ovd_ins_val2017_b.json', 'w') as fout: - json.dump(coco_val_anno_seen, fout) - with open('./datasets/coco/annotations/ovd_ins_val2017_t.json', 'w') as fout: - json.dump(coco_val_anno_unseen, fout) - with open('./datasets/coco/annotations/ovd_ins_val2017_all.json', 'w') as fout: - json.dump(coco_val_anno_all, fout) -""" \ No newline at end of file diff --git a/src/regionclip/datasets/lvis_v1_categories.py b/src/regionclip/datasets/lvis_v1_categories.py deleted file mode 100644 index 7374e6968bb006f5d8c49e75d9d3b31ea3d77d05..0000000000000000000000000000000000000000 --- a/src/regionclip/datasets/lvis_v1_categories.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) Facebook, Inc. and its affiliates. -# Autogen with -# with open("lvis_v1_val.json", "r") as f: -# a = json.load(f) -# c = a["categories"] -# for x in c: -# del x["image_count"] -# del x["instance_count"] -# LVIS_CATEGORIES = repr(c) + " # noqa" -# with open("/tmp/lvis_categories.py", "wt") as f: -# f.write(f"LVIS_CATEGORIES = {LVIS_CATEGORIES}") -# Then paste the contents of that file below - -# fmt: off -LVIS_CATEGORIES = [{'frequency': 'c', 'synset': 'aerosol.n.02', 'synonyms': ['aerosol_can', 'spray_can'], 'id': 1, 'def': 'a dispenser that holds a substance under pressure', 'name': 'aerosol_can'}, {'frequency': 'f', 'synset': 'air_conditioner.n.01', 'synonyms': ['air_conditioner'], 'id': 2, 'def': 'a machine that keeps air cool and dry', 'name': 'air_conditioner'}, {'frequency': 'f', 'synset': 'airplane.n.01', 'synonyms': ['airplane', 'aeroplane'], 'id': 3, 'def': 'an aircraft that has a fixed wing and is powered by propellers or jets', 'name': 'airplane'}, {'frequency': 'f', 'synset': 'alarm_clock.n.01', 'synonyms': ['alarm_clock'], 'id': 4, 'def': 'a clock that wakes a sleeper at some preset time', 'name': 'alarm_clock'}, {'frequency': 'c', 'synset': 'alcohol.n.01', 'synonyms': ['alcohol', 'alcoholic_beverage'], 'id': 5, 'def': 'a liquor or brew containing alcohol as the active agent', 'name': 'alcohol'}, {'frequency': 'c', 'synset': 'alligator.n.02', 'synonyms': ['alligator', 'gator'], 'id': 6, 'def': 'amphibious reptiles related to crocodiles but with shorter broader snouts', 'name': 'alligator'}, {'frequency': 'c', 'synset': 'almond.n.02', 'synonyms': ['almond'], 'id': 7, 'def': 'oval-shaped edible seed of the almond tree', 'name': 'almond'}, {'frequency': 'c', 'synset': 'ambulance.n.01', 'synonyms': ['ambulance'], 'id': 8, 'def': 'a vehicle that takes people to and from hospitals', 'name': 'ambulance'}, {'frequency': 'c', 'synset': 'amplifier.n.01', 'synonyms': ['amplifier'], 'id': 9, 'def': 'electronic equipment that increases strength of signals', 'name': 'amplifier'}, {'frequency': 'c', 'synset': 'anklet.n.03', 'synonyms': ['anklet', 'ankle_bracelet'], 'id': 10, 'def': 'an ornament worn around the ankle', 'name': 'anklet'}, {'frequency': 'f', 'synset': 'antenna.n.01', 'synonyms': ['antenna', 'aerial', 'transmitting_aerial'], 'id': 11, 'def': 'an electrical device that sends or receives radio or television signals', 'name': 'antenna'}, {'frequency': 'f', 'synset': 'apple.n.01', 'synonyms': ['apple'], 'id': 12, 'def': 'fruit with red or yellow or green skin and sweet to tart crisp whitish flesh', 'name': 'apple'}, {'frequency': 'r', 'synset': 'applesauce.n.01', 'synonyms': ['applesauce'], 'id': 13, 'def': 'puree of stewed apples usually sweetened and spiced', 'name': 'applesauce'}, {'frequency': 'r', 'synset': 'apricot.n.02', 'synonyms': ['apricot'], 'id': 14, 'def': 'downy yellow to rosy-colored fruit resembling a small peach', 'name': 'apricot'}, {'frequency': 'f', 'synset': 'apron.n.01', 'synonyms': ['apron'], 'id': 15, 'def': 'a garment of cloth that is tied about the waist and worn to protect clothing', 'name': 'apron'}, {'frequency': 'c', 'synset': 'aquarium.n.01', 'synonyms': ['aquarium', 'fish_tank'], 'id': 16, 'def': 'a tank/pool/bowl filled with water for keeping live fish and underwater animals', 'name': 'aquarium'}, {'frequency': 'r', 'synset': 'arctic.n.02', 'synonyms': ['arctic_(type_of_shoe)', 'galosh', 'golosh', 'rubber_(type_of_shoe)', 'gumshoe'], 'id': 17, 'def': 'a waterproof overshoe that protects shoes from water or snow', 'name': 'arctic_(type_of_shoe)'}, {'frequency': 'c', 'synset': 'armband.n.02', 'synonyms': ['armband'], 'id': 18, 'def': 'a band worn around the upper arm', 'name': 'armband'}, {'frequency': 'f', 'synset': 'armchair.n.01', 'synonyms': ['armchair'], 'id': 19, 'def': 'chair with a support on each side for arms', 'name': 'armchair'}, {'frequency': 'r', 'synset': 'armoire.n.01', 'synonyms': ['armoire'], 'id': 20, 'def': 'a large wardrobe or cabinet', 'name': 'armoire'}, {'frequency': 'r', 'synset': 'armor.n.01', 'synonyms': ['armor', 'armour'], 'id': 21, 'def': 'protective covering made of metal and used in combat', 'name': 'armor'}, {'frequency': 'c', 'synset': 'artichoke.n.02', 'synonyms': ['artichoke'], 'id': 22, 'def': 'a thistlelike flower head with edible fleshy leaves and heart', 'name': 'artichoke'}, {'frequency': 'f', 'synset': 'ashcan.n.01', 'synonyms': ['trash_can', 'garbage_can', 'wastebin', 'dustbin', 'trash_barrel', 'trash_bin'], 'id': 23, 'def': 'a bin that holds rubbish until it is collected', 'name': 'trash_can'}, {'frequency': 'c', 'synset': 'ashtray.n.01', 'synonyms': ['ashtray'], 'id': 24, 'def': "a receptacle for the ash from smokers' cigars or cigarettes", 'name': 'ashtray'}, {'frequency': 'c', 'synset': 'asparagus.n.02', 'synonyms': ['asparagus'], 'id': 25, 'def': 'edible young shoots of the asparagus plant', 'name': 'asparagus'}, {'frequency': 'c', 'synset': 'atomizer.n.01', 'synonyms': ['atomizer', 'atomiser', 'spray', 'sprayer', 'nebulizer', 'nebuliser'], 'id': 26, 'def': 'a dispenser that turns a liquid (such as perfume) into a fine mist', 'name': 'atomizer'}, {'frequency': 'f', 'synset': 'avocado.n.01', 'synonyms': ['avocado'], 'id': 27, 'def': 'a pear-shaped fruit with green or blackish skin and rich yellowish pulp enclosing a single large seed', 'name': 'avocado'}, {'frequency': 'c', 'synset': 'award.n.02', 'synonyms': ['award', 'accolade'], 'id': 28, 'def': 'a tangible symbol signifying approval or distinction', 'name': 'award'}, {'frequency': 'f', 'synset': 'awning.n.01', 'synonyms': ['awning'], 'id': 29, 'def': 'a canopy made of canvas to shelter people or things from rain or sun', 'name': 'awning'}, {'frequency': 'r', 'synset': 'ax.n.01', 'synonyms': ['ax', 'axe'], 'id': 30, 'def': 'an edge tool with a heavy bladed head mounted across a handle', 'name': 'ax'}, {'frequency': 'r', 'synset': 'baboon.n.01', 'synonyms': ['baboon'], 'id': 31, 'def': 'large terrestrial monkeys having doglike muzzles', 'name': 'baboon'}, {'frequency': 'f', 'synset': 'baby_buggy.n.01', 'synonyms': ['baby_buggy', 'baby_carriage', 'perambulator', 'pram', 'stroller'], 'id': 32, 'def': 'a small vehicle with four wheels in which a baby or child is pushed around', 'name': 'baby_buggy'}, {'frequency': 'c', 'synset': 'backboard.n.01', 'synonyms': ['basketball_backboard'], 'id': 33, 'def': 'a raised vertical board with basket attached; used to play basketball', 'name': 'basketball_backboard'}, {'frequency': 'f', 'synset': 'backpack.n.01', 'synonyms': ['backpack', 'knapsack', 'packsack', 'rucksack', 'haversack'], 'id': 34, 'def': 'a bag carried by a strap on your back or shoulder', 'name': 'backpack'}, {'frequency': 'f', 'synset': 'bag.n.04', 'synonyms': ['handbag', 'purse', 'pocketbook'], 'id': 35, 'def': 'a container used for carrying money and small personal items or accessories', 'name': 'handbag'}, {'frequency': 'f', 'synset': 'bag.n.06', 'synonyms': ['suitcase', 'baggage', 'luggage'], 'id': 36, 'def': 'cases used to carry belongings when traveling', 'name': 'suitcase'}, {'frequency': 'c', 'synset': 'bagel.n.01', 'synonyms': ['bagel', 'beigel'], 'id': 37, 'def': 'glazed yeast-raised doughnut-shaped roll with hard crust', 'name': 'bagel'}, {'frequency': 'r', 'synset': 'bagpipe.n.01', 'synonyms': ['bagpipe'], 'id': 38, 'def': 'a tubular wind instrument; the player blows air into a bag and squeezes it out', 'name': 'bagpipe'}, {'frequency': 'r', 'synset': 'baguet.n.01', 'synonyms': ['baguet', 'baguette'], 'id': 39, 'def': 'narrow French stick loaf', 'name': 'baguet'}, {'frequency': 'r', 'synset': 'bait.n.02', 'synonyms': ['bait', 'lure'], 'id': 40, 'def': 'something used to lure fish or other animals into danger so they can be trapped or killed', 'name': 'bait'}, {'frequency': 'f', 'synset': 'ball.n.06', 'synonyms': ['ball'], 'id': 41, 'def': 'a spherical object used as a plaything', 'name': 'ball'}, {'frequency': 'r', 'synset': 'ballet_skirt.n.01', 'synonyms': ['ballet_skirt', 'tutu'], 'id': 42, 'def': 'very short skirt worn by ballerinas', 'name': 'ballet_skirt'}, {'frequency': 'f', 'synset': 'balloon.n.01', 'synonyms': ['balloon'], 'id': 43, 'def': 'large tough nonrigid bag filled with gas or heated air', 'name': 'balloon'}, {'frequency': 'c', 'synset': 'bamboo.n.02', 'synonyms': ['bamboo'], 'id': 44, 'def': 'woody tropical grass having hollow woody stems', 'name': 'bamboo'}, {'frequency': 'f', 'synset': 'banana.n.02', 'synonyms': ['banana'], 'id': 45, 'def': 'elongated crescent-shaped yellow fruit with soft sweet flesh', 'name': 'banana'}, {'frequency': 'c', 'synset': 'band_aid.n.01', 'synonyms': ['Band_Aid'], 'id': 46, 'def': 'trade name for an adhesive bandage to cover small cuts or blisters', 'name': 'Band_Aid'}, {'frequency': 'c', 'synset': 'bandage.n.01', 'synonyms': ['bandage'], 'id': 47, 'def': 'a piece of soft material that covers and protects an injured part of the body', 'name': 'bandage'}, {'frequency': 'f', 'synset': 'bandanna.n.01', 'synonyms': ['bandanna', 'bandana'], 'id': 48, 'def': 'large and brightly colored handkerchief; often used as a neckerchief', 'name': 'bandanna'}, {'frequency': 'r', 'synset': 'banjo.n.01', 'synonyms': ['banjo'], 'id': 49, 'def': 'a stringed instrument of the guitar family with a long neck and circular body', 'name': 'banjo'}, {'frequency': 'f', 'synset': 'banner.n.01', 'synonyms': ['banner', 'streamer'], 'id': 50, 'def': 'long strip of cloth or paper used for decoration or advertising', 'name': 'banner'}, {'frequency': 'r', 'synset': 'barbell.n.01', 'synonyms': ['barbell'], 'id': 51, 'def': 'a bar to which heavy discs are attached at each end; used in weightlifting', 'name': 'barbell'}, {'frequency': 'r', 'synset': 'barge.n.01', 'synonyms': ['barge'], 'id': 52, 'def': 'a flatbottom boat for carrying heavy loads (especially on canals)', 'name': 'barge'}, {'frequency': 'f', 'synset': 'barrel.n.02', 'synonyms': ['barrel', 'cask'], 'id': 53, 'def': 'a cylindrical container that holds liquids', 'name': 'barrel'}, {'frequency': 'c', 'synset': 'barrette.n.01', 'synonyms': ['barrette'], 'id': 54, 'def': "a pin for holding women's hair in place", 'name': 'barrette'}, {'frequency': 'c', 'synset': 'barrow.n.03', 'synonyms': ['barrow', 'garden_cart', 'lawn_cart', 'wheelbarrow'], 'id': 55, 'def': 'a cart for carrying small loads; has handles and one or more wheels', 'name': 'barrow'}, {'frequency': 'f', 'synset': 'base.n.03', 'synonyms': ['baseball_base'], 'id': 56, 'def': 'a place that the runner must touch before scoring', 'name': 'baseball_base'}, {'frequency': 'f', 'synset': 'baseball.n.02', 'synonyms': ['baseball'], 'id': 57, 'def': 'a ball used in playing baseball', 'name': 'baseball'}, {'frequency': 'f', 'synset': 'baseball_bat.n.01', 'synonyms': ['baseball_bat'], 'id': 58, 'def': 'an implement used in baseball by the batter', 'name': 'baseball_bat'}, {'frequency': 'f', 'synset': 'baseball_cap.n.01', 'synonyms': ['baseball_cap', 'jockey_cap', 'golf_cap'], 'id': 59, 'def': 'a cap with a bill', 'name': 'baseball_cap'}, {'frequency': 'f', 'synset': 'baseball_glove.n.01', 'synonyms': ['baseball_glove', 'baseball_mitt'], 'id': 60, 'def': 'the handwear used by fielders in playing baseball', 'name': 'baseball_glove'}, {'frequency': 'f', 'synset': 'basket.n.01', 'synonyms': ['basket', 'handbasket'], 'id': 61, 'def': 'a container that is usually woven and has handles', 'name': 'basket'}, {'frequency': 'c', 'synset': 'basketball.n.02', 'synonyms': ['basketball'], 'id': 62, 'def': 'an inflated ball used in playing basketball', 'name': 'basketball'}, {'frequency': 'r', 'synset': 'bass_horn.n.01', 'synonyms': ['bass_horn', 'sousaphone', 'tuba'], 'id': 63, 'def': 'the lowest brass wind instrument', 'name': 'bass_horn'}, {'frequency': 'c', 'synset': 'bat.n.01', 'synonyms': ['bat_(animal)'], 'id': 64, 'def': 'nocturnal mouselike mammal with forelimbs modified to form membranous wings', 'name': 'bat_(animal)'}, {'frequency': 'f', 'synset': 'bath_mat.n.01', 'synonyms': ['bath_mat'], 'id': 65, 'def': 'a heavy towel or mat to stand on while drying yourself after a bath', 'name': 'bath_mat'}, {'frequency': 'f', 'synset': 'bath_towel.n.01', 'synonyms': ['bath_towel'], 'id': 66, 'def': 'a large towel; to dry yourself after a bath', 'name': 'bath_towel'}, {'frequency': 'c', 'synset': 'bathrobe.n.01', 'synonyms': ['bathrobe'], 'id': 67, 'def': 'a loose-fitting robe of towelling; worn after a bath or swim', 'name': 'bathrobe'}, {'frequency': 'f', 'synset': 'bathtub.n.01', 'synonyms': ['bathtub', 'bathing_tub'], 'id': 68, 'def': 'a large open container that you fill with water and use to wash the body', 'name': 'bathtub'}, {'frequency': 'r', 'synset': 'batter.n.02', 'synonyms': ['batter_(food)'], 'id': 69, 'def': 'a liquid or semiliquid mixture, as of flour, eggs, and milk, used in cooking', 'name': 'batter_(food)'}, {'frequency': 'c', 'synset': 'battery.n.02', 'synonyms': ['battery'], 'id': 70, 'def': 'a portable device that produces electricity', 'name': 'battery'}, {'frequency': 'r', 'synset': 'beach_ball.n.01', 'synonyms': ['beachball'], 'id': 71, 'def': 'large and light ball; for play at the seaside', 'name': 'beachball'}, {'frequency': 'c', 'synset': 'bead.n.01', 'synonyms': ['bead'], 'id': 72, 'def': 'a small ball with a hole through the middle used for ornamentation, jewellery, etc.', 'name': 'bead'}, {'frequency': 'c', 'synset': 'bean_curd.n.01', 'synonyms': ['bean_curd', 'tofu'], 'id': 73, 'def': 'cheeselike food made of curdled soybean milk', 'name': 'bean_curd'}, {'frequency': 'c', 'synset': 'beanbag.n.01', 'synonyms': ['beanbag'], 'id': 74, 'def': 'a bag filled with dried beans or similar items; used in games or to sit on', 'name': 'beanbag'}, {'frequency': 'f', 'synset': 'beanie.n.01', 'synonyms': ['beanie', 'beany'], 'id': 75, 'def': 'a small skullcap; formerly worn by schoolboys and college freshmen', 'name': 'beanie'}, {'frequency': 'f', 'synset': 'bear.n.01', 'synonyms': ['bear'], 'id': 76, 'def': 'large carnivorous or omnivorous mammals with shaggy coats and claws', 'name': 'bear'}, {'frequency': 'f', 'synset': 'bed.n.01', 'synonyms': ['bed'], 'id': 77, 'def': 'a piece of furniture that provides a place to sleep', 'name': 'bed'}, {'frequency': 'r', 'synset': 'bedpan.n.01', 'synonyms': ['bedpan'], 'id': 78, 'def': 'a shallow vessel used by a bedridden patient for defecation and urination', 'name': 'bedpan'}, {'frequency': 'f', 'synset': 'bedspread.n.01', 'synonyms': ['bedspread', 'bedcover', 'bed_covering', 'counterpane', 'spread'], 'id': 79, 'def': 'decorative cover for a bed', 'name': 'bedspread'}, {'frequency': 'f', 'synset': 'beef.n.01', 'synonyms': ['cow'], 'id': 80, 'def': 'cattle/cow', 'name': 'cow'}, {'frequency': 'f', 'synset': 'beef.n.02', 'synonyms': ['beef_(food)', 'boeuf_(food)'], 'id': 81, 'def': 'meat from an adult domestic bovine', 'name': 'beef_(food)'}, {'frequency': 'r', 'synset': 'beeper.n.01', 'synonyms': ['beeper', 'pager'], 'id': 82, 'def': 'an device that beeps when the person carrying it is being paged', 'name': 'beeper'}, {'frequency': 'f', 'synset': 'beer_bottle.n.01', 'synonyms': ['beer_bottle'], 'id': 83, 'def': 'a bottle that holds beer', 'name': 'beer_bottle'}, {'frequency': 'c', 'synset': 'beer_can.n.01', 'synonyms': ['beer_can'], 'id': 84, 'def': 'a can that holds beer', 'name': 'beer_can'}, {'frequency': 'r', 'synset': 'beetle.n.01', 'synonyms': ['beetle'], 'id': 85, 'def': 'insect with hard wing covers', 'name': 'beetle'}, {'frequency': 'f', 'synset': 'bell.n.01', 'synonyms': ['bell'], 'id': 86, 'def': 'a hollow device made of metal that makes a ringing sound when struck', 'name': 'bell'}, {'frequency': 'f', 'synset': 'bell_pepper.n.02', 'synonyms': ['bell_pepper', 'capsicum'], 'id': 87, 'def': 'large bell-shaped sweet pepper in green or red or yellow or orange or black varieties', 'name': 'bell_pepper'}, {'frequency': 'f', 'synset': 'belt.n.02', 'synonyms': ['belt'], 'id': 88, 'def': 'a band to tie or buckle around the body (usually at the waist)', 'name': 'belt'}, {'frequency': 'f', 'synset': 'belt_buckle.n.01', 'synonyms': ['belt_buckle'], 'id': 89, 'def': 'the buckle used to fasten a belt', 'name': 'belt_buckle'}, {'frequency': 'f', 'synset': 'bench.n.01', 'synonyms': ['bench'], 'id': 90, 'def': 'a long seat for more than one person', 'name': 'bench'}, {'frequency': 'c', 'synset': 'beret.n.01', 'synonyms': ['beret'], 'id': 91, 'def': 'a cap with no brim or bill; made of soft cloth', 'name': 'beret'}, {'frequency': 'c', 'synset': 'bib.n.02', 'synonyms': ['bib'], 'id': 92, 'def': 'a napkin tied under the chin of a child while eating', 'name': 'bib'}, {'frequency': 'r', 'synset': 'bible.n.01', 'synonyms': ['Bible'], 'id': 93, 'def': 'the sacred writings of the Christian religions', 'name': 'Bible'}, {'frequency': 'f', 'synset': 'bicycle.n.01', 'synonyms': ['bicycle', 'bike_(bicycle)'], 'id': 94, 'def': 'a wheeled vehicle that has two wheels and is moved by foot pedals', 'name': 'bicycle'}, {'frequency': 'f', 'synset': 'bill.n.09', 'synonyms': ['visor', 'vizor'], 'id': 95, 'def': 'a brim that projects to the front to shade the eyes', 'name': 'visor'}, {'frequency': 'f', 'synset': 'billboard.n.01', 'synonyms': ['billboard'], 'id': 96, 'def': 'large outdoor signboard', 'name': 'billboard'}, {'frequency': 'c', 'synset': 'binder.n.03', 'synonyms': ['binder', 'ring-binder'], 'id': 97, 'def': 'holds loose papers or magazines', 'name': 'binder'}, {'frequency': 'c', 'synset': 'binoculars.n.01', 'synonyms': ['binoculars', 'field_glasses', 'opera_glasses'], 'id': 98, 'def': 'an optical instrument designed for simultaneous use by both eyes', 'name': 'binoculars'}, {'frequency': 'f', 'synset': 'bird.n.01', 'synonyms': ['bird'], 'id': 99, 'def': 'animal characterized by feathers and wings', 'name': 'bird'}, {'frequency': 'c', 'synset': 'bird_feeder.n.01', 'synonyms': ['birdfeeder'], 'id': 100, 'def': 'an outdoor device that supplies food for wild birds', 'name': 'birdfeeder'}, {'frequency': 'c', 'synset': 'birdbath.n.01', 'synonyms': ['birdbath'], 'id': 101, 'def': 'an ornamental basin (usually in a garden) for birds to bathe in', 'name': 'birdbath'}, {'frequency': 'c', 'synset': 'birdcage.n.01', 'synonyms': ['birdcage'], 'id': 102, 'def': 'a cage in which a bird can be kept', 'name': 'birdcage'}, {'frequency': 'c', 'synset': 'birdhouse.n.01', 'synonyms': ['birdhouse'], 'id': 103, 'def': 'a shelter for birds', 'name': 'birdhouse'}, {'frequency': 'f', 'synset': 'birthday_cake.n.01', 'synonyms': ['birthday_cake'], 'id': 104, 'def': 'decorated cake served at a birthday party', 'name': 'birthday_cake'}, {'frequency': 'r', 'synset': 'birthday_card.n.01', 'synonyms': ['birthday_card'], 'id': 105, 'def': 'a card expressing a birthday greeting', 'name': 'birthday_card'}, {'frequency': 'r', 'synset': 'black_flag.n.01', 'synonyms': ['pirate_flag'], 'id': 106, 'def': 'a flag usually bearing a white skull and crossbones on a black background', 'name': 'pirate_flag'}, {'frequency': 'c', 'synset': 'black_sheep.n.02', 'synonyms': ['black_sheep'], 'id': 107, 'def': 'sheep with a black coat', 'name': 'black_sheep'}, {'frequency': 'c', 'synset': 'blackberry.n.01', 'synonyms': ['blackberry'], 'id': 108, 'def': 'large sweet black or very dark purple edible aggregate fruit', 'name': 'blackberry'}, {'frequency': 'f', 'synset': 'blackboard.n.01', 'synonyms': ['blackboard', 'chalkboard'], 'id': 109, 'def': 'sheet of slate; for writing with chalk', 'name': 'blackboard'}, {'frequency': 'f', 'synset': 'blanket.n.01', 'synonyms': ['blanket'], 'id': 110, 'def': 'bedding that keeps a person warm in bed', 'name': 'blanket'}, {'frequency': 'c', 'synset': 'blazer.n.01', 'synonyms': ['blazer', 'sport_jacket', 'sport_coat', 'sports_jacket', 'sports_coat'], 'id': 111, 'def': 'lightweight jacket; often striped in the colors of a club or school', 'name': 'blazer'}, {'frequency': 'f', 'synset': 'blender.n.01', 'synonyms': ['blender', 'liquidizer', 'liquidiser'], 'id': 112, 'def': 'an electrically powered mixer that mix or chop or liquefy foods', 'name': 'blender'}, {'frequency': 'r', 'synset': 'blimp.n.02', 'synonyms': ['blimp'], 'id': 113, 'def': 'a small nonrigid airship used for observation or as a barrage balloon', 'name': 'blimp'}, {'frequency': 'f', 'synset': 'blinker.n.01', 'synonyms': ['blinker', 'flasher'], 'id': 114, 'def': 'a light that flashes on and off; used as a signal or to send messages', 'name': 'blinker'}, {'frequency': 'f', 'synset': 'blouse.n.01', 'synonyms': ['blouse'], 'id': 115, 'def': 'a top worn by women', 'name': 'blouse'}, {'frequency': 'f', 'synset': 'blueberry.n.02', 'synonyms': ['blueberry'], 'id': 116, 'def': 'sweet edible dark-blue berries of blueberry plants', 'name': 'blueberry'}, {'frequency': 'r', 'synset': 'board.n.09', 'synonyms': ['gameboard'], 'id': 117, 'def': 'a flat portable surface (usually rectangular) designed for board games', 'name': 'gameboard'}, {'frequency': 'f', 'synset': 'boat.n.01', 'synonyms': ['boat', 'ship_(boat)'], 'id': 118, 'def': 'a vessel for travel on water', 'name': 'boat'}, {'frequency': 'r', 'synset': 'bob.n.05', 'synonyms': ['bob', 'bobber', 'bobfloat'], 'id': 119, 'def': 'a small float usually made of cork; attached to a fishing line', 'name': 'bob'}, {'frequency': 'c', 'synset': 'bobbin.n.01', 'synonyms': ['bobbin', 'spool', 'reel'], 'id': 120, 'def': 'a thing around which thread/tape/film or other flexible materials can be wound', 'name': 'bobbin'}, {'frequency': 'c', 'synset': 'bobby_pin.n.01', 'synonyms': ['bobby_pin', 'hairgrip'], 'id': 121, 'def': 'a flat wire hairpin used to hold bobbed hair in place', 'name': 'bobby_pin'}, {'frequency': 'c', 'synset': 'boiled_egg.n.01', 'synonyms': ['boiled_egg', 'coddled_egg'], 'id': 122, 'def': 'egg cooked briefly in the shell in gently boiling water', 'name': 'boiled_egg'}, {'frequency': 'r', 'synset': 'bolo_tie.n.01', 'synonyms': ['bolo_tie', 'bolo', 'bola_tie', 'bola'], 'id': 123, 'def': 'a cord fastened around the neck with an ornamental clasp and worn as a necktie', 'name': 'bolo_tie'}, {'frequency': 'c', 'synset': 'bolt.n.03', 'synonyms': ['deadbolt'], 'id': 124, 'def': 'the part of a lock that is engaged or withdrawn with a key', 'name': 'deadbolt'}, {'frequency': 'f', 'synset': 'bolt.n.06', 'synonyms': ['bolt'], 'id': 125, 'def': 'a screw that screws into a nut to form a fastener', 'name': 'bolt'}, {'frequency': 'r', 'synset': 'bonnet.n.01', 'synonyms': ['bonnet'], 'id': 126, 'def': 'a hat tied under the chin', 'name': 'bonnet'}, {'frequency': 'f', 'synset': 'book.n.01', 'synonyms': ['book'], 'id': 127, 'def': 'a written work or composition that has been published', 'name': 'book'}, {'frequency': 'c', 'synset': 'bookcase.n.01', 'synonyms': ['bookcase'], 'id': 128, 'def': 'a piece of furniture with shelves for storing books', 'name': 'bookcase'}, {'frequency': 'c', 'synset': 'booklet.n.01', 'synonyms': ['booklet', 'brochure', 'leaflet', 'pamphlet'], 'id': 129, 'def': 'a small book usually having a paper cover', 'name': 'booklet'}, {'frequency': 'r', 'synset': 'bookmark.n.01', 'synonyms': ['bookmark', 'bookmarker'], 'id': 130, 'def': 'a marker (a piece of paper or ribbon) placed between the pages of a book', 'name': 'bookmark'}, {'frequency': 'r', 'synset': 'boom.n.04', 'synonyms': ['boom_microphone', 'microphone_boom'], 'id': 131, 'def': 'a pole carrying an overhead microphone projected over a film or tv set', 'name': 'boom_microphone'}, {'frequency': 'f', 'synset': 'boot.n.01', 'synonyms': ['boot'], 'id': 132, 'def': 'footwear that covers the whole foot and lower leg', 'name': 'boot'}, {'frequency': 'f', 'synset': 'bottle.n.01', 'synonyms': ['bottle'], 'id': 133, 'def': 'a glass or plastic vessel used for storing drinks or other liquids', 'name': 'bottle'}, {'frequency': 'c', 'synset': 'bottle_opener.n.01', 'synonyms': ['bottle_opener'], 'id': 134, 'def': 'an opener for removing caps or corks from bottles', 'name': 'bottle_opener'}, {'frequency': 'c', 'synset': 'bouquet.n.01', 'synonyms': ['bouquet'], 'id': 135, 'def': 'an arrangement of flowers that is usually given as a present', 'name': 'bouquet'}, {'frequency': 'r', 'synset': 'bow.n.04', 'synonyms': ['bow_(weapon)'], 'id': 136, 'def': 'a weapon for shooting arrows', 'name': 'bow_(weapon)'}, {'frequency': 'f', 'synset': 'bow.n.08', 'synonyms': ['bow_(decorative_ribbons)'], 'id': 137, 'def': 'a decorative interlacing of ribbons', 'name': 'bow_(decorative_ribbons)'}, {'frequency': 'f', 'synset': 'bow_tie.n.01', 'synonyms': ['bow-tie', 'bowtie'], 'id': 138, 'def': "a man's tie that ties in a bow", 'name': 'bow-tie'}, {'frequency': 'f', 'synset': 'bowl.n.03', 'synonyms': ['bowl'], 'id': 139, 'def': 'a dish that is round and open at the top for serving foods', 'name': 'bowl'}, {'frequency': 'r', 'synset': 'bowl.n.08', 'synonyms': ['pipe_bowl'], 'id': 140, 'def': 'a small round container that is open at the top for holding tobacco', 'name': 'pipe_bowl'}, {'frequency': 'c', 'synset': 'bowler_hat.n.01', 'synonyms': ['bowler_hat', 'bowler', 'derby_hat', 'derby', 'plug_hat'], 'id': 141, 'def': 'a felt hat that is round and hard with a narrow brim', 'name': 'bowler_hat'}, {'frequency': 'r', 'synset': 'bowling_ball.n.01', 'synonyms': ['bowling_ball'], 'id': 142, 'def': 'a large ball with finger holes used in the sport of bowling', 'name': 'bowling_ball'}, {'frequency': 'f', 'synset': 'box.n.01', 'synonyms': ['box'], 'id': 143, 'def': 'a (usually rectangular) container; may have a lid', 'name': 'box'}, {'frequency': 'r', 'synset': 'boxing_glove.n.01', 'synonyms': ['boxing_glove'], 'id': 144, 'def': 'large glove coverings the fists of a fighter worn for the sport of boxing', 'name': 'boxing_glove'}, {'frequency': 'c', 'synset': 'brace.n.06', 'synonyms': ['suspenders'], 'id': 145, 'def': 'elastic straps that hold trousers up (usually used in the plural)', 'name': 'suspenders'}, {'frequency': 'f', 'synset': 'bracelet.n.02', 'synonyms': ['bracelet', 'bangle'], 'id': 146, 'def': 'jewelry worn around the wrist for decoration', 'name': 'bracelet'}, {'frequency': 'r', 'synset': 'brass.n.07', 'synonyms': ['brass_plaque'], 'id': 147, 'def': 'a memorial made of brass', 'name': 'brass_plaque'}, {'frequency': 'c', 'synset': 'brassiere.n.01', 'synonyms': ['brassiere', 'bra', 'bandeau'], 'id': 148, 'def': 'an undergarment worn by women to support their breasts', 'name': 'brassiere'}, {'frequency': 'c', 'synset': 'bread-bin.n.01', 'synonyms': ['bread-bin', 'breadbox'], 'id': 149, 'def': 'a container used to keep bread or cake in', 'name': 'bread-bin'}, {'frequency': 'f', 'synset': 'bread.n.01', 'synonyms': ['bread'], 'id': 150, 'def': 'food made from dough of flour or meal and usually raised with yeast or baking powder and then baked', 'name': 'bread'}, {'frequency': 'r', 'synset': 'breechcloth.n.01', 'synonyms': ['breechcloth', 'breechclout', 'loincloth'], 'id': 151, 'def': 'a garment that provides covering for the loins', 'name': 'breechcloth'}, {'frequency': 'f', 'synset': 'bridal_gown.n.01', 'synonyms': ['bridal_gown', 'wedding_gown', 'wedding_dress'], 'id': 152, 'def': 'a gown worn by the bride at a wedding', 'name': 'bridal_gown'}, {'frequency': 'c', 'synset': 'briefcase.n.01', 'synonyms': ['briefcase'], 'id': 153, 'def': 'a case with a handle; for carrying papers or files or books', 'name': 'briefcase'}, {'frequency': 'f', 'synset': 'broccoli.n.01', 'synonyms': ['broccoli'], 'id': 154, 'def': 'plant with dense clusters of tight green flower buds', 'name': 'broccoli'}, {'frequency': 'r', 'synset': 'brooch.n.01', 'synonyms': ['broach'], 'id': 155, 'def': 'a decorative pin worn by women', 'name': 'broach'}, {'frequency': 'c', 'synset': 'broom.n.01', 'synonyms': ['broom'], 'id': 156, 'def': 'bundle of straws or twigs attached to a long handle; used for cleaning', 'name': 'broom'}, {'frequency': 'c', 'synset': 'brownie.n.03', 'synonyms': ['brownie'], 'id': 157, 'def': 'square or bar of very rich chocolate cake usually with nuts', 'name': 'brownie'}, {'frequency': 'c', 'synset': 'brussels_sprouts.n.01', 'synonyms': ['brussels_sprouts'], 'id': 158, 'def': 'the small edible cabbage-like buds growing along a stalk', 'name': 'brussels_sprouts'}, {'frequency': 'r', 'synset': 'bubble_gum.n.01', 'synonyms': ['bubble_gum'], 'id': 159, 'def': 'a kind of chewing gum that can be blown into bubbles', 'name': 'bubble_gum'}, {'frequency': 'f', 'synset': 'bucket.n.01', 'synonyms': ['bucket', 'pail'], 'id': 160, 'def': 'a roughly cylindrical vessel that is open at the top', 'name': 'bucket'}, {'frequency': 'r', 'synset': 'buggy.n.01', 'synonyms': ['horse_buggy'], 'id': 161, 'def': 'a small lightweight carriage; drawn by a single horse', 'name': 'horse_buggy'}, {'frequency': 'c', 'synset': 'bull.n.11', 'synonyms': ['horned_cow'], 'id': 162, 'def': 'a cow with horns', 'name': 'bull'}, {'frequency': 'c', 'synset': 'bulldog.n.01', 'synonyms': ['bulldog'], 'id': 163, 'def': 'a thickset short-haired dog with a large head and strong undershot lower jaw', 'name': 'bulldog'}, {'frequency': 'r', 'synset': 'bulldozer.n.01', 'synonyms': ['bulldozer', 'dozer'], 'id': 164, 'def': 'large powerful tractor; a large blade in front flattens areas of ground', 'name': 'bulldozer'}, {'frequency': 'c', 'synset': 'bullet_train.n.01', 'synonyms': ['bullet_train'], 'id': 165, 'def': 'a high-speed passenger train', 'name': 'bullet_train'}, {'frequency': 'c', 'synset': 'bulletin_board.n.02', 'synonyms': ['bulletin_board', 'notice_board'], 'id': 166, 'def': 'a board that hangs on a wall; displays announcements', 'name': 'bulletin_board'}, {'frequency': 'r', 'synset': 'bulletproof_vest.n.01', 'synonyms': ['bulletproof_vest'], 'id': 167, 'def': 'a vest capable of resisting the impact of a bullet', 'name': 'bulletproof_vest'}, {'frequency': 'c', 'synset': 'bullhorn.n.01', 'synonyms': ['bullhorn', 'megaphone'], 'id': 168, 'def': 'a portable loudspeaker with built-in microphone and amplifier', 'name': 'bullhorn'}, {'frequency': 'f', 'synset': 'bun.n.01', 'synonyms': ['bun', 'roll'], 'id': 169, 'def': 'small rounded bread either plain or sweet', 'name': 'bun'}, {'frequency': 'c', 'synset': 'bunk_bed.n.01', 'synonyms': ['bunk_bed'], 'id': 170, 'def': 'beds built one above the other', 'name': 'bunk_bed'}, {'frequency': 'f', 'synset': 'buoy.n.01', 'synonyms': ['buoy'], 'id': 171, 'def': 'a float attached by rope to the seabed to mark channels in a harbor or underwater hazards', 'name': 'buoy'}, {'frequency': 'r', 'synset': 'burrito.n.01', 'synonyms': ['burrito'], 'id': 172, 'def': 'a flour tortilla folded around a filling', 'name': 'burrito'}, {'frequency': 'f', 'synset': 'bus.n.01', 'synonyms': ['bus_(vehicle)', 'autobus', 'charabanc', 'double-decker', 'motorbus', 'motorcoach'], 'id': 173, 'def': 'a vehicle carrying many passengers; used for public transport', 'name': 'bus_(vehicle)'}, {'frequency': 'c', 'synset': 'business_card.n.01', 'synonyms': ['business_card'], 'id': 174, 'def': "a card on which are printed the person's name and business affiliation", 'name': 'business_card'}, {'frequency': 'f', 'synset': 'butter.n.01', 'synonyms': ['butter'], 'id': 175, 'def': 'an edible emulsion of fat globules made by churning milk or cream; for cooking and table use', 'name': 'butter'}, {'frequency': 'c', 'synset': 'butterfly.n.01', 'synonyms': ['butterfly'], 'id': 176, 'def': 'insect typically having a slender body with knobbed antennae and broad colorful wings', 'name': 'butterfly'}, {'frequency': 'f', 'synset': 'button.n.01', 'synonyms': ['button'], 'id': 177, 'def': 'a round fastener sewn to shirts and coats etc to fit through buttonholes', 'name': 'button'}, {'frequency': 'f', 'synset': 'cab.n.03', 'synonyms': ['cab_(taxi)', 'taxi', 'taxicab'], 'id': 178, 'def': 'a car that takes passengers where they want to go in exchange for money', 'name': 'cab_(taxi)'}, {'frequency': 'r', 'synset': 'cabana.n.01', 'synonyms': ['cabana'], 'id': 179, 'def': 'a small tent used as a dressing room beside the sea or a swimming pool', 'name': 'cabana'}, {'frequency': 'c', 'synset': 'cabin_car.n.01', 'synonyms': ['cabin_car', 'caboose'], 'id': 180, 'def': 'a car on a freight train for use of the train crew; usually the last car on the train', 'name': 'cabin_car'}, {'frequency': 'f', 'synset': 'cabinet.n.01', 'synonyms': ['cabinet'], 'id': 181, 'def': 'a piece of furniture resembling a cupboard with doors and shelves and drawers', 'name': 'cabinet'}, {'frequency': 'r', 'synset': 'cabinet.n.03', 'synonyms': ['locker', 'storage_locker'], 'id': 182, 'def': 'a storage compartment for clothes and valuables; usually it has a lock', 'name': 'locker'}, {'frequency': 'f', 'synset': 'cake.n.03', 'synonyms': ['cake'], 'id': 183, 'def': 'baked goods made from or based on a mixture of flour, sugar, eggs, and fat', 'name': 'cake'}, {'frequency': 'c', 'synset': 'calculator.n.02', 'synonyms': ['calculator'], 'id': 184, 'def': 'a small machine that is used for mathematical calculations', 'name': 'calculator'}, {'frequency': 'f', 'synset': 'calendar.n.02', 'synonyms': ['calendar'], 'id': 185, 'def': 'a list or register of events (appointments/social events/court cases, etc)', 'name': 'calendar'}, {'frequency': 'c', 'synset': 'calf.n.01', 'synonyms': ['calf'], 'id': 186, 'def': 'young of domestic cattle', 'name': 'calf'}, {'frequency': 'c', 'synset': 'camcorder.n.01', 'synonyms': ['camcorder'], 'id': 187, 'def': 'a portable television camera and videocassette recorder', 'name': 'camcorder'}, {'frequency': 'c', 'synset': 'camel.n.01', 'synonyms': ['camel'], 'id': 188, 'def': 'cud-chewing mammal used as a draft or saddle animal in desert regions', 'name': 'camel'}, {'frequency': 'f', 'synset': 'camera.n.01', 'synonyms': ['camera'], 'id': 189, 'def': 'equipment for taking photographs', 'name': 'camera'}, {'frequency': 'c', 'synset': 'camera_lens.n.01', 'synonyms': ['camera_lens'], 'id': 190, 'def': 'a lens that focuses the image in a camera', 'name': 'camera_lens'}, {'frequency': 'c', 'synset': 'camper.n.02', 'synonyms': ['camper_(vehicle)', 'camping_bus', 'motor_home'], 'id': 191, 'def': 'a recreational vehicle equipped for camping out while traveling', 'name': 'camper_(vehicle)'}, {'frequency': 'f', 'synset': 'can.n.01', 'synonyms': ['can', 'tin_can'], 'id': 192, 'def': 'airtight sealed metal container for food or drink or paint etc.', 'name': 'can'}, {'frequency': 'c', 'synset': 'can_opener.n.01', 'synonyms': ['can_opener', 'tin_opener'], 'id': 193, 'def': 'a device for cutting cans open', 'name': 'can_opener'}, {'frequency': 'f', 'synset': 'candle.n.01', 'synonyms': ['candle', 'candlestick'], 'id': 194, 'def': 'stick of wax with a wick in the middle', 'name': 'candle'}, {'frequency': 'f', 'synset': 'candlestick.n.01', 'synonyms': ['candle_holder'], 'id': 195, 'def': 'a holder with sockets for candles', 'name': 'candle_holder'}, {'frequency': 'r', 'synset': 'candy_bar.n.01', 'synonyms': ['candy_bar'], 'id': 196, 'def': 'a candy shaped as a bar', 'name': 'candy_bar'}, {'frequency': 'c', 'synset': 'candy_cane.n.01', 'synonyms': ['candy_cane'], 'id': 197, 'def': 'a hard candy in the shape of a rod (usually with stripes)', 'name': 'candy_cane'}, {'frequency': 'c', 'synset': 'cane.n.01', 'synonyms': ['walking_cane'], 'id': 198, 'def': 'a stick that people can lean on to help them walk', 'name': 'walking_cane'}, {'frequency': 'c', 'synset': 'canister.n.02', 'synonyms': ['canister', 'cannister'], 'id': 199, 'def': 'metal container for storing dry foods such as tea or flour', 'name': 'canister'}, {'frequency': 'c', 'synset': 'canoe.n.01', 'synonyms': ['canoe'], 'id': 200, 'def': 'small and light boat; pointed at both ends; propelled with a paddle', 'name': 'canoe'}, {'frequency': 'c', 'synset': 'cantaloup.n.02', 'synonyms': ['cantaloup', 'cantaloupe'], 'id': 201, 'def': 'the fruit of a cantaloup vine; small to medium-sized melon with yellowish flesh', 'name': 'cantaloup'}, {'frequency': 'r', 'synset': 'canteen.n.01', 'synonyms': ['canteen'], 'id': 202, 'def': 'a flask for carrying water; used by soldiers or travelers', 'name': 'canteen'}, {'frequency': 'f', 'synset': 'cap.n.01', 'synonyms': ['cap_(headwear)'], 'id': 203, 'def': 'a tight-fitting headwear', 'name': 'cap_(headwear)'}, {'frequency': 'f', 'synset': 'cap.n.02', 'synonyms': ['bottle_cap', 'cap_(container_lid)'], 'id': 204, 'def': 'a top (as for a bottle)', 'name': 'bottle_cap'}, {'frequency': 'c', 'synset': 'cape.n.02', 'synonyms': ['cape'], 'id': 205, 'def': 'a sleeveless garment like a cloak but shorter', 'name': 'cape'}, {'frequency': 'c', 'synset': 'cappuccino.n.01', 'synonyms': ['cappuccino', 'coffee_cappuccino'], 'id': 206, 'def': 'equal parts of espresso and steamed milk', 'name': 'cappuccino'}, {'frequency': 'f', 'synset': 'car.n.01', 'synonyms': ['car_(automobile)', 'auto_(automobile)', 'automobile'], 'id': 207, 'def': 'a motor vehicle with four wheels', 'name': 'car_(automobile)'}, {'frequency': 'f', 'synset': 'car.n.02', 'synonyms': ['railcar_(part_of_a_train)', 'railway_car_(part_of_a_train)', 'railroad_car_(part_of_a_train)'], 'id': 208, 'def': 'a wheeled vehicle adapted to the rails of railroad (mark each individual railcar separately)', 'name': 'railcar_(part_of_a_train)'}, {'frequency': 'r', 'synset': 'car.n.04', 'synonyms': ['elevator_car'], 'id': 209, 'def': 'where passengers ride up and down', 'name': 'elevator_car'}, {'frequency': 'r', 'synset': 'car_battery.n.01', 'synonyms': ['car_battery', 'automobile_battery'], 'id': 210, 'def': 'a battery in a motor vehicle', 'name': 'car_battery'}, {'frequency': 'c', 'synset': 'card.n.02', 'synonyms': ['identity_card'], 'id': 211, 'def': 'a card certifying the identity of the bearer', 'name': 'identity_card'}, {'frequency': 'c', 'synset': 'card.n.03', 'synonyms': ['card'], 'id': 212, 'def': 'a rectangular piece of paper used to send messages (e.g. greetings or pictures)', 'name': 'card'}, {'frequency': 'c', 'synset': 'cardigan.n.01', 'synonyms': ['cardigan'], 'id': 213, 'def': 'knitted jacket that is fastened up the front with buttons or a zipper', 'name': 'cardigan'}, {'frequency': 'r', 'synset': 'cargo_ship.n.01', 'synonyms': ['cargo_ship', 'cargo_vessel'], 'id': 214, 'def': 'a ship designed to carry cargo', 'name': 'cargo_ship'}, {'frequency': 'r', 'synset': 'carnation.n.01', 'synonyms': ['carnation'], 'id': 215, 'def': 'plant with pink to purple-red spice-scented usually double flowers', 'name': 'carnation'}, {'frequency': 'c', 'synset': 'carriage.n.02', 'synonyms': ['horse_carriage'], 'id': 216, 'def': 'a vehicle with wheels drawn by one or more horses', 'name': 'horse_carriage'}, {'frequency': 'f', 'synset': 'carrot.n.01', 'synonyms': ['carrot'], 'id': 217, 'def': 'deep orange edible root of the cultivated carrot plant', 'name': 'carrot'}, {'frequency': 'f', 'synset': 'carryall.n.01', 'synonyms': ['tote_bag'], 'id': 218, 'def': 'a capacious bag or basket', 'name': 'tote_bag'}, {'frequency': 'c', 'synset': 'cart.n.01', 'synonyms': ['cart'], 'id': 219, 'def': 'a heavy open wagon usually having two wheels and drawn by an animal', 'name': 'cart'}, {'frequency': 'c', 'synset': 'carton.n.02', 'synonyms': ['carton'], 'id': 220, 'def': 'a container made of cardboard for holding food or drink', 'name': 'carton'}, {'frequency': 'c', 'synset': 'cash_register.n.01', 'synonyms': ['cash_register', 'register_(for_cash_transactions)'], 'id': 221, 'def': 'a cashbox with an adding machine to register transactions', 'name': 'cash_register'}, {'frequency': 'r', 'synset': 'casserole.n.01', 'synonyms': ['casserole'], 'id': 222, 'def': 'food cooked and served in a casserole', 'name': 'casserole'}, {'frequency': 'r', 'synset': 'cassette.n.01', 'synonyms': ['cassette'], 'id': 223, 'def': 'a container that holds a magnetic tape used for recording or playing sound or video', 'name': 'cassette'}, {'frequency': 'c', 'synset': 'cast.n.05', 'synonyms': ['cast', 'plaster_cast', 'plaster_bandage'], 'id': 224, 'def': 'bandage consisting of a firm covering that immobilizes broken bones while they heal', 'name': 'cast'}, {'frequency': 'f', 'synset': 'cat.n.01', 'synonyms': ['cat'], 'id': 225, 'def': 'a domestic house cat', 'name': 'cat'}, {'frequency': 'f', 'synset': 'cauliflower.n.02', 'synonyms': ['cauliflower'], 'id': 226, 'def': 'edible compact head of white undeveloped flowers', 'name': 'cauliflower'}, {'frequency': 'c', 'synset': 'cayenne.n.02', 'synonyms': ['cayenne_(spice)', 'cayenne_pepper_(spice)', 'red_pepper_(spice)'], 'id': 227, 'def': 'ground pods and seeds of pungent red peppers of the genus Capsicum', 'name': 'cayenne_(spice)'}, {'frequency': 'c', 'synset': 'cd_player.n.01', 'synonyms': ['CD_player'], 'id': 228, 'def': 'electronic equipment for playing compact discs (CDs)', 'name': 'CD_player'}, {'frequency': 'f', 'synset': 'celery.n.01', 'synonyms': ['celery'], 'id': 229, 'def': 'widely cultivated herb with aromatic leaf stalks that are eaten raw or cooked', 'name': 'celery'}, {'frequency': 'f', 'synset': 'cellular_telephone.n.01', 'synonyms': ['cellular_telephone', 'cellular_phone', 'cellphone', 'mobile_phone', 'smart_phone'], 'id': 230, 'def': 'a hand-held mobile telephone', 'name': 'cellular_telephone'}, {'frequency': 'r', 'synset': 'chain_mail.n.01', 'synonyms': ['chain_mail', 'ring_mail', 'chain_armor', 'chain_armour', 'ring_armor', 'ring_armour'], 'id': 231, 'def': '(Middle Ages) flexible armor made of interlinked metal rings', 'name': 'chain_mail'}, {'frequency': 'f', 'synset': 'chair.n.01', 'synonyms': ['chair'], 'id': 232, 'def': 'a seat for one person, with a support for the back', 'name': 'chair'}, {'frequency': 'r', 'synset': 'chaise_longue.n.01', 'synonyms': ['chaise_longue', 'chaise', 'daybed'], 'id': 233, 'def': 'a long chair; for reclining', 'name': 'chaise_longue'}, {'frequency': 'r', 'synset': 'chalice.n.01', 'synonyms': ['chalice'], 'id': 234, 'def': 'a bowl-shaped drinking vessel; especially the Eucharistic cup', 'name': 'chalice'}, {'frequency': 'f', 'synset': 'chandelier.n.01', 'synonyms': ['chandelier'], 'id': 235, 'def': 'branched lighting fixture; often ornate; hangs from the ceiling', 'name': 'chandelier'}, {'frequency': 'r', 'synset': 'chap.n.04', 'synonyms': ['chap'], 'id': 236, 'def': 'leather leggings without a seat; worn over trousers by cowboys to protect their legs', 'name': 'chap'}, {'frequency': 'r', 'synset': 'checkbook.n.01', 'synonyms': ['checkbook', 'chequebook'], 'id': 237, 'def': 'a book issued to holders of checking accounts', 'name': 'checkbook'}, {'frequency': 'r', 'synset': 'checkerboard.n.01', 'synonyms': ['checkerboard'], 'id': 238, 'def': 'a board having 64 squares of two alternating colors', 'name': 'checkerboard'}, {'frequency': 'c', 'synset': 'cherry.n.03', 'synonyms': ['cherry'], 'id': 239, 'def': 'a red fruit with a single hard stone', 'name': 'cherry'}, {'frequency': 'r', 'synset': 'chessboard.n.01', 'synonyms': ['chessboard'], 'id': 240, 'def': 'a checkerboard used to play chess', 'name': 'chessboard'}, {'frequency': 'c', 'synset': 'chicken.n.02', 'synonyms': ['chicken_(animal)'], 'id': 241, 'def': 'a domestic fowl bred for flesh or eggs', 'name': 'chicken_(animal)'}, {'frequency': 'c', 'synset': 'chickpea.n.01', 'synonyms': ['chickpea', 'garbanzo'], 'id': 242, 'def': 'the seed of the chickpea plant; usually dried', 'name': 'chickpea'}, {'frequency': 'c', 'synset': 'chili.n.02', 'synonyms': ['chili_(vegetable)', 'chili_pepper_(vegetable)', 'chilli_(vegetable)', 'chilly_(vegetable)', 'chile_(vegetable)'], 'id': 243, 'def': 'very hot and finely tapering pepper of special pungency', 'name': 'chili_(vegetable)'}, {'frequency': 'r', 'synset': 'chime.n.01', 'synonyms': ['chime', 'gong'], 'id': 244, 'def': 'an instrument consisting of a set of bells that are struck with a hammer', 'name': 'chime'}, {'frequency': 'r', 'synset': 'chinaware.n.01', 'synonyms': ['chinaware'], 'id': 245, 'def': 'dishware made of high quality porcelain', 'name': 'chinaware'}, {'frequency': 'c', 'synset': 'chip.n.04', 'synonyms': ['crisp_(potato_chip)', 'potato_chip'], 'id': 246, 'def': 'a thin crisp slice of potato fried in deep fat', 'name': 'crisp_(potato_chip)'}, {'frequency': 'r', 'synset': 'chip.n.06', 'synonyms': ['poker_chip'], 'id': 247, 'def': 'a small disk-shaped counter used to represent money when gambling', 'name': 'poker_chip'}, {'frequency': 'c', 'synset': 'chocolate_bar.n.01', 'synonyms': ['chocolate_bar'], 'id': 248, 'def': 'a bar of chocolate candy', 'name': 'chocolate_bar'}, {'frequency': 'c', 'synset': 'chocolate_cake.n.01', 'synonyms': ['chocolate_cake'], 'id': 249, 'def': 'cake containing chocolate', 'name': 'chocolate_cake'}, {'frequency': 'r', 'synset': 'chocolate_milk.n.01', 'synonyms': ['chocolate_milk'], 'id': 250, 'def': 'milk flavored with chocolate syrup', 'name': 'chocolate_milk'}, {'frequency': 'r', 'synset': 'chocolate_mousse.n.01', 'synonyms': ['chocolate_mousse'], 'id': 251, 'def': 'dessert mousse made with chocolate', 'name': 'chocolate_mousse'}, {'frequency': 'f', 'synset': 'choker.n.03', 'synonyms': ['choker', 'collar', 'neckband'], 'id': 252, 'def': 'shirt collar, animal collar, or tight-fitting necklace', 'name': 'choker'}, {'frequency': 'f', 'synset': 'chopping_board.n.01', 'synonyms': ['chopping_board', 'cutting_board', 'chopping_block'], 'id': 253, 'def': 'a wooden board where meats or vegetables can be cut', 'name': 'chopping_board'}, {'frequency': 'f', 'synset': 'chopstick.n.01', 'synonyms': ['chopstick'], 'id': 254, 'def': 'one of a pair of slender sticks used as oriental tableware to eat food with', 'name': 'chopstick'}, {'frequency': 'f', 'synset': 'christmas_tree.n.05', 'synonyms': ['Christmas_tree'], 'id': 255, 'def': 'an ornamented evergreen used as a Christmas decoration', 'name': 'Christmas_tree'}, {'frequency': 'c', 'synset': 'chute.n.02', 'synonyms': ['slide'], 'id': 256, 'def': 'sloping channel through which things can descend', 'name': 'slide'}, {'frequency': 'r', 'synset': 'cider.n.01', 'synonyms': ['cider', 'cyder'], 'id': 257, 'def': 'a beverage made from juice pressed from apples', 'name': 'cider'}, {'frequency': 'r', 'synset': 'cigar_box.n.01', 'synonyms': ['cigar_box'], 'id': 258, 'def': 'a box for holding cigars', 'name': 'cigar_box'}, {'frequency': 'f', 'synset': 'cigarette.n.01', 'synonyms': ['cigarette'], 'id': 259, 'def': 'finely ground tobacco wrapped in paper; for smoking', 'name': 'cigarette'}, {'frequency': 'c', 'synset': 'cigarette_case.n.01', 'synonyms': ['cigarette_case', 'cigarette_pack'], 'id': 260, 'def': 'a small flat case for holding cigarettes', 'name': 'cigarette_case'}, {'frequency': 'f', 'synset': 'cistern.n.02', 'synonyms': ['cistern', 'water_tank'], 'id': 261, 'def': 'a tank that holds the water used to flush a toilet', 'name': 'cistern'}, {'frequency': 'r', 'synset': 'clarinet.n.01', 'synonyms': ['clarinet'], 'id': 262, 'def': 'a single-reed instrument with a straight tube', 'name': 'clarinet'}, {'frequency': 'c', 'synset': 'clasp.n.01', 'synonyms': ['clasp'], 'id': 263, 'def': 'a fastener (as a buckle or hook) that is used to hold two things together', 'name': 'clasp'}, {'frequency': 'c', 'synset': 'cleansing_agent.n.01', 'synonyms': ['cleansing_agent', 'cleanser', 'cleaner'], 'id': 264, 'def': 'a preparation used in cleaning something', 'name': 'cleansing_agent'}, {'frequency': 'r', 'synset': 'cleat.n.02', 'synonyms': ['cleat_(for_securing_rope)'], 'id': 265, 'def': 'a fastener (usually with two projecting horns) around which a rope can be secured', 'name': 'cleat_(for_securing_rope)'}, {'frequency': 'r', 'synset': 'clementine.n.01', 'synonyms': ['clementine'], 'id': 266, 'def': 'a variety of mandarin orange', 'name': 'clementine'}, {'frequency': 'c', 'synset': 'clip.n.03', 'synonyms': ['clip'], 'id': 267, 'def': 'any of various small fasteners used to hold loose articles together', 'name': 'clip'}, {'frequency': 'c', 'synset': 'clipboard.n.01', 'synonyms': ['clipboard'], 'id': 268, 'def': 'a small writing board with a clip at the top for holding papers', 'name': 'clipboard'}, {'frequency': 'r', 'synset': 'clipper.n.03', 'synonyms': ['clippers_(for_plants)'], 'id': 269, 'def': 'shears for cutting grass or shrubbery (often used in the plural)', 'name': 'clippers_(for_plants)'}, {'frequency': 'r', 'synset': 'cloak.n.02', 'synonyms': ['cloak'], 'id': 270, 'def': 'a loose outer garment', 'name': 'cloak'}, {'frequency': 'f', 'synset': 'clock.n.01', 'synonyms': ['clock', 'timepiece', 'timekeeper'], 'id': 271, 'def': 'a timepiece that shows the time of day', 'name': 'clock'}, {'frequency': 'f', 'synset': 'clock_tower.n.01', 'synonyms': ['clock_tower'], 'id': 272, 'def': 'a tower with a large clock visible high up on an outside face', 'name': 'clock_tower'}, {'frequency': 'c', 'synset': 'clothes_hamper.n.01', 'synonyms': ['clothes_hamper', 'laundry_basket', 'clothes_basket'], 'id': 273, 'def': 'a hamper that holds dirty clothes to be washed or wet clothes to be dried', 'name': 'clothes_hamper'}, {'frequency': 'c', 'synset': 'clothespin.n.01', 'synonyms': ['clothespin', 'clothes_peg'], 'id': 274, 'def': 'wood or plastic fastener; for holding clothes on a clothesline', 'name': 'clothespin'}, {'frequency': 'r', 'synset': 'clutch_bag.n.01', 'synonyms': ['clutch_bag'], 'id': 275, 'def': "a woman's strapless purse that is carried in the hand", 'name': 'clutch_bag'}, {'frequency': 'f', 'synset': 'coaster.n.03', 'synonyms': ['coaster'], 'id': 276, 'def': 'a covering (plate or mat) that protects the surface of a table', 'name': 'coaster'}, {'frequency': 'f', 'synset': 'coat.n.01', 'synonyms': ['coat'], 'id': 277, 'def': 'an outer garment that has sleeves and covers the body from shoulder down', 'name': 'coat'}, {'frequency': 'c', 'synset': 'coat_hanger.n.01', 'synonyms': ['coat_hanger', 'clothes_hanger', 'dress_hanger'], 'id': 278, 'def': "a hanger that is shaped like a person's shoulders", 'name': 'coat_hanger'}, {'frequency': 'c', 'synset': 'coatrack.n.01', 'synonyms': ['coatrack', 'hatrack'], 'id': 279, 'def': 'a rack with hooks for temporarily holding coats and hats', 'name': 'coatrack'}, {'frequency': 'c', 'synset': 'cock.n.04', 'synonyms': ['cock', 'rooster'], 'id': 280, 'def': 'adult male chicken', 'name': 'cock'}, {'frequency': 'r', 'synset': 'cockroach.n.01', 'synonyms': ['cockroach'], 'id': 281, 'def': 'any of numerous chiefly nocturnal insects; some are domestic pests', 'name': 'cockroach'}, {'frequency': 'r', 'synset': 'cocoa.n.01', 'synonyms': ['cocoa_(beverage)', 'hot_chocolate_(beverage)', 'drinking_chocolate'], 'id': 282, 'def': 'a beverage made from cocoa powder and milk and sugar; usually drunk hot', 'name': 'cocoa_(beverage)'}, {'frequency': 'c', 'synset': 'coconut.n.02', 'synonyms': ['coconut', 'cocoanut'], 'id': 283, 'def': 'large hard-shelled brown oval nut with a fibrous husk', 'name': 'coconut'}, {'frequency': 'f', 'synset': 'coffee_maker.n.01', 'synonyms': ['coffee_maker', 'coffee_machine'], 'id': 284, 'def': 'a kitchen appliance for brewing coffee automatically', 'name': 'coffee_maker'}, {'frequency': 'f', 'synset': 'coffee_table.n.01', 'synonyms': ['coffee_table', 'cocktail_table'], 'id': 285, 'def': 'low table where magazines can be placed and coffee or cocktails are served', 'name': 'coffee_table'}, {'frequency': 'c', 'synset': 'coffeepot.n.01', 'synonyms': ['coffeepot'], 'id': 286, 'def': 'tall pot in which coffee is brewed', 'name': 'coffeepot'}, {'frequency': 'r', 'synset': 'coil.n.05', 'synonyms': ['coil'], 'id': 287, 'def': 'tubing that is wound in a spiral', 'name': 'coil'}, {'frequency': 'c', 'synset': 'coin.n.01', 'synonyms': ['coin'], 'id': 288, 'def': 'a flat metal piece (usually a disc) used as money', 'name': 'coin'}, {'frequency': 'c', 'synset': 'colander.n.01', 'synonyms': ['colander', 'cullender'], 'id': 289, 'def': 'bowl-shaped strainer; used to wash or drain foods', 'name': 'colander'}, {'frequency': 'c', 'synset': 'coleslaw.n.01', 'synonyms': ['coleslaw', 'slaw'], 'id': 290, 'def': 'basically shredded cabbage', 'name': 'coleslaw'}, {'frequency': 'r', 'synset': 'coloring_material.n.01', 'synonyms': ['coloring_material', 'colouring_material'], 'id': 291, 'def': 'any material used for its color', 'name': 'coloring_material'}, {'frequency': 'r', 'synset': 'combination_lock.n.01', 'synonyms': ['combination_lock'], 'id': 292, 'def': 'lock that can be opened only by turning dials in a special sequence', 'name': 'combination_lock'}, {'frequency': 'c', 'synset': 'comforter.n.04', 'synonyms': ['pacifier', 'teething_ring'], 'id': 293, 'def': 'device used for an infant to suck or bite on', 'name': 'pacifier'}, {'frequency': 'r', 'synset': 'comic_book.n.01', 'synonyms': ['comic_book'], 'id': 294, 'def': 'a magazine devoted to comic strips', 'name': 'comic_book'}, {'frequency': 'r', 'synset': 'compass.n.01', 'synonyms': ['compass'], 'id': 295, 'def': 'navigational instrument for finding directions', 'name': 'compass'}, {'frequency': 'f', 'synset': 'computer_keyboard.n.01', 'synonyms': ['computer_keyboard', 'keyboard_(computer)'], 'id': 296, 'def': 'a keyboard that is a data input device for computers', 'name': 'computer_keyboard'}, {'frequency': 'f', 'synset': 'condiment.n.01', 'synonyms': ['condiment'], 'id': 297, 'def': 'a preparation (a sauce or relish or spice) to enhance flavor or enjoyment', 'name': 'condiment'}, {'frequency': 'f', 'synset': 'cone.n.01', 'synonyms': ['cone', 'traffic_cone'], 'id': 298, 'def': 'a cone-shaped object used to direct traffic', 'name': 'cone'}, {'frequency': 'f', 'synset': 'control.n.09', 'synonyms': ['control', 'controller'], 'id': 299, 'def': 'a mechanism that controls the operation of a machine', 'name': 'control'}, {'frequency': 'r', 'synset': 'convertible.n.01', 'synonyms': ['convertible_(automobile)'], 'id': 300, 'def': 'a car that has top that can be folded or removed', 'name': 'convertible_(automobile)'}, {'frequency': 'r', 'synset': 'convertible.n.03', 'synonyms': ['sofa_bed'], 'id': 301, 'def': 'a sofa that can be converted into a bed', 'name': 'sofa_bed'}, {'frequency': 'r', 'synset': 'cooker.n.01', 'synonyms': ['cooker'], 'id': 302, 'def': 'a utensil for cooking', 'name': 'cooker'}, {'frequency': 'f', 'synset': 'cookie.n.01', 'synonyms': ['cookie', 'cooky', 'biscuit_(cookie)'], 'id': 303, 'def': "any of various small flat sweet cakes (`biscuit' is the British term)", 'name': 'cookie'}, {'frequency': 'r', 'synset': 'cooking_utensil.n.01', 'synonyms': ['cooking_utensil'], 'id': 304, 'def': 'a kitchen utensil made of material that does not melt easily; used for cooking', 'name': 'cooking_utensil'}, {'frequency': 'f', 'synset': 'cooler.n.01', 'synonyms': ['cooler_(for_food)', 'ice_chest'], 'id': 305, 'def': 'an insulated box for storing food often with ice', 'name': 'cooler_(for_food)'}, {'frequency': 'f', 'synset': 'cork.n.04', 'synonyms': ['cork_(bottle_plug)', 'bottle_cork'], 'id': 306, 'def': 'the plug in the mouth of a bottle (especially a wine bottle)', 'name': 'cork_(bottle_plug)'}, {'frequency': 'r', 'synset': 'corkboard.n.01', 'synonyms': ['corkboard'], 'id': 307, 'def': 'a sheet consisting of cork granules', 'name': 'corkboard'}, {'frequency': 'c', 'synset': 'corkscrew.n.01', 'synonyms': ['corkscrew', 'bottle_screw'], 'id': 308, 'def': 'a bottle opener that pulls corks', 'name': 'corkscrew'}, {'frequency': 'f', 'synset': 'corn.n.03', 'synonyms': ['edible_corn', 'corn', 'maize'], 'id': 309, 'def': 'ears or kernels of corn that can be prepared and served for human food (only mark individual ears or kernels)', 'name': 'edible_corn'}, {'frequency': 'r', 'synset': 'cornbread.n.01', 'synonyms': ['cornbread'], 'id': 310, 'def': 'bread made primarily of cornmeal', 'name': 'cornbread'}, {'frequency': 'c', 'synset': 'cornet.n.01', 'synonyms': ['cornet', 'horn', 'trumpet'], 'id': 311, 'def': 'a brass musical instrument with a narrow tube and a flared bell and many valves', 'name': 'cornet'}, {'frequency': 'c', 'synset': 'cornice.n.01', 'synonyms': ['cornice', 'valance', 'valance_board', 'pelmet'], 'id': 312, 'def': 'a decorative framework to conceal curtain fixtures at the top of a window casing', 'name': 'cornice'}, {'frequency': 'r', 'synset': 'cornmeal.n.01', 'synonyms': ['cornmeal'], 'id': 313, 'def': 'coarsely ground corn', 'name': 'cornmeal'}, {'frequency': 'c', 'synset': 'corset.n.01', 'synonyms': ['corset', 'girdle'], 'id': 314, 'def': "a woman's close-fitting foundation garment", 'name': 'corset'}, {'frequency': 'c', 'synset': 'costume.n.04', 'synonyms': ['costume'], 'id': 315, 'def': 'the attire characteristic of a country or a time or a social class', 'name': 'costume'}, {'frequency': 'r', 'synset': 'cougar.n.01', 'synonyms': ['cougar', 'puma', 'catamount', 'mountain_lion', 'panther'], 'id': 316, 'def': 'large American feline resembling a lion', 'name': 'cougar'}, {'frequency': 'r', 'synset': 'coverall.n.01', 'synonyms': ['coverall'], 'id': 317, 'def': 'a loose-fitting protective garment that is worn over other clothing', 'name': 'coverall'}, {'frequency': 'c', 'synset': 'cowbell.n.01', 'synonyms': ['cowbell'], 'id': 318, 'def': 'a bell hung around the neck of cow so that the cow can be easily located', 'name': 'cowbell'}, {'frequency': 'f', 'synset': 'cowboy_hat.n.01', 'synonyms': ['cowboy_hat', 'ten-gallon_hat'], 'id': 319, 'def': 'a hat with a wide brim and a soft crown; worn by American ranch hands', 'name': 'cowboy_hat'}, {'frequency': 'c', 'synset': 'crab.n.01', 'synonyms': ['crab_(animal)'], 'id': 320, 'def': 'decapod having eyes on short stalks and a broad flattened shell and pincers', 'name': 'crab_(animal)'}, {'frequency': 'r', 'synset': 'crab.n.05', 'synonyms': ['crabmeat'], 'id': 321, 'def': 'the edible flesh of any of various crabs', 'name': 'crabmeat'}, {'frequency': 'c', 'synset': 'cracker.n.01', 'synonyms': ['cracker'], 'id': 322, 'def': 'a thin crisp wafer', 'name': 'cracker'}, {'frequency': 'r', 'synset': 'crape.n.01', 'synonyms': ['crape', 'crepe', 'French_pancake'], 'id': 323, 'def': 'small very thin pancake', 'name': 'crape'}, {'frequency': 'f', 'synset': 'crate.n.01', 'synonyms': ['crate'], 'id': 324, 'def': 'a rugged box (usually made of wood); used for shipping', 'name': 'crate'}, {'frequency': 'c', 'synset': 'crayon.n.01', 'synonyms': ['crayon', 'wax_crayon'], 'id': 325, 'def': 'writing or drawing implement made of a colored stick of composition wax', 'name': 'crayon'}, {'frequency': 'r', 'synset': 'cream_pitcher.n.01', 'synonyms': ['cream_pitcher'], 'id': 326, 'def': 'a small pitcher for serving cream', 'name': 'cream_pitcher'}, {'frequency': 'c', 'synset': 'crescent_roll.n.01', 'synonyms': ['crescent_roll', 'croissant'], 'id': 327, 'def': 'very rich flaky crescent-shaped roll', 'name': 'crescent_roll'}, {'frequency': 'c', 'synset': 'crib.n.01', 'synonyms': ['crib', 'cot'], 'id': 328, 'def': 'baby bed with high sides made of slats', 'name': 'crib'}, {'frequency': 'c', 'synset': 'crock.n.03', 'synonyms': ['crock_pot', 'earthenware_jar'], 'id': 329, 'def': 'an earthen jar (made of baked clay) or a modern electric crockpot', 'name': 'crock_pot'}, {'frequency': 'f', 'synset': 'crossbar.n.01', 'synonyms': ['crossbar'], 'id': 330, 'def': 'a horizontal bar that goes across something', 'name': 'crossbar'}, {'frequency': 'r', 'synset': 'crouton.n.01', 'synonyms': ['crouton'], 'id': 331, 'def': 'a small piece of toasted or fried bread; served in soup or salads', 'name': 'crouton'}, {'frequency': 'c', 'synset': 'crow.n.01', 'synonyms': ['crow'], 'id': 332, 'def': 'black birds having a raucous call', 'name': 'crow'}, {'frequency': 'r', 'synset': 'crowbar.n.01', 'synonyms': ['crowbar', 'wrecking_bar', 'pry_bar'], 'id': 333, 'def': 'a heavy iron lever with one end forged into a wedge', 'name': 'crowbar'}, {'frequency': 'c', 'synset': 'crown.n.04', 'synonyms': ['crown'], 'id': 334, 'def': 'an ornamental jeweled headdress signifying sovereignty', 'name': 'crown'}, {'frequency': 'c', 'synset': 'crucifix.n.01', 'synonyms': ['crucifix'], 'id': 335, 'def': 'representation of the cross on which Jesus died', 'name': 'crucifix'}, {'frequency': 'c', 'synset': 'cruise_ship.n.01', 'synonyms': ['cruise_ship', 'cruise_liner'], 'id': 336, 'def': 'a passenger ship used commercially for pleasure cruises', 'name': 'cruise_ship'}, {'frequency': 'c', 'synset': 'cruiser.n.01', 'synonyms': ['police_cruiser', 'patrol_car', 'police_car', 'squad_car'], 'id': 337, 'def': 'a car in which policemen cruise the streets', 'name': 'police_cruiser'}, {'frequency': 'f', 'synset': 'crumb.n.03', 'synonyms': ['crumb'], 'id': 338, 'def': 'small piece of e.g. bread or cake', 'name': 'crumb'}, {'frequency': 'c', 'synset': 'crutch.n.01', 'synonyms': ['crutch'], 'id': 339, 'def': 'a wooden or metal staff that fits under the armpit and reaches to the ground', 'name': 'crutch'}, {'frequency': 'c', 'synset': 'cub.n.03', 'synonyms': ['cub_(animal)'], 'id': 340, 'def': 'the young of certain carnivorous mammals such as the bear or wolf or lion', 'name': 'cub_(animal)'}, {'frequency': 'c', 'synset': 'cube.n.05', 'synonyms': ['cube', 'square_block'], 'id': 341, 'def': 'a block in the (approximate) shape of a cube', 'name': 'cube'}, {'frequency': 'f', 'synset': 'cucumber.n.02', 'synonyms': ['cucumber', 'cuke'], 'id': 342, 'def': 'cylindrical green fruit with thin green rind and white flesh eaten as a vegetable', 'name': 'cucumber'}, {'frequency': 'c', 'synset': 'cufflink.n.01', 'synonyms': ['cufflink'], 'id': 343, 'def': 'jewelry consisting of linked buttons used to fasten the cuffs of a shirt', 'name': 'cufflink'}, {'frequency': 'f', 'synset': 'cup.n.01', 'synonyms': ['cup'], 'id': 344, 'def': 'a small open container usually used for drinking; usually has a handle', 'name': 'cup'}, {'frequency': 'c', 'synset': 'cup.n.08', 'synonyms': ['trophy_cup'], 'id': 345, 'def': 'a metal award or cup-shaped vessel with handles that is awarded as a trophy to a competition winner', 'name': 'trophy_cup'}, {'frequency': 'f', 'synset': 'cupboard.n.01', 'synonyms': ['cupboard', 'closet'], 'id': 346, 'def': 'a small room (or recess) or cabinet used for storage space', 'name': 'cupboard'}, {'frequency': 'f', 'synset': 'cupcake.n.01', 'synonyms': ['cupcake'], 'id': 347, 'def': 'small cake baked in a muffin tin', 'name': 'cupcake'}, {'frequency': 'r', 'synset': 'curler.n.01', 'synonyms': ['hair_curler', 'hair_roller', 'hair_crimper'], 'id': 348, 'def': 'a cylindrical tube around which the hair is wound to curl it', 'name': 'hair_curler'}, {'frequency': 'r', 'synset': 'curling_iron.n.01', 'synonyms': ['curling_iron'], 'id': 349, 'def': 'a cylindrical home appliance that heats hair that has been curled around it', 'name': 'curling_iron'}, {'frequency': 'f', 'synset': 'curtain.n.01', 'synonyms': ['curtain', 'drapery'], 'id': 350, 'def': 'hanging cloth used as a blind (especially for a window)', 'name': 'curtain'}, {'frequency': 'f', 'synset': 'cushion.n.03', 'synonyms': ['cushion'], 'id': 351, 'def': 'a soft bag filled with air or padding such as feathers or foam rubber', 'name': 'cushion'}, {'frequency': 'r', 'synset': 'cylinder.n.04', 'synonyms': ['cylinder'], 'id': 352, 'def': 'a cylindrical container', 'name': 'cylinder'}, {'frequency': 'r', 'synset': 'cymbal.n.01', 'synonyms': ['cymbal'], 'id': 353, 'def': 'a percussion instrument consisting of a concave brass disk', 'name': 'cymbal'}, {'frequency': 'r', 'synset': 'dagger.n.01', 'synonyms': ['dagger'], 'id': 354, 'def': 'a short knife with a pointed blade used for piercing or stabbing', 'name': 'dagger'}, {'frequency': 'r', 'synset': 'dalmatian.n.02', 'synonyms': ['dalmatian'], 'id': 355, 'def': 'a large breed having a smooth white coat with black or brown spots', 'name': 'dalmatian'}, {'frequency': 'c', 'synset': 'dartboard.n.01', 'synonyms': ['dartboard'], 'id': 356, 'def': 'a circular board of wood or cork used as the target in the game of darts', 'name': 'dartboard'}, {'frequency': 'r', 'synset': 'date.n.08', 'synonyms': ['date_(fruit)'], 'id': 357, 'def': 'sweet edible fruit of the date palm with a single long woody seed', 'name': 'date_(fruit)'}, {'frequency': 'f', 'synset': 'deck_chair.n.01', 'synonyms': ['deck_chair', 'beach_chair'], 'id': 358, 'def': 'a folding chair for use outdoors; a wooden frame supports a length of canvas', 'name': 'deck_chair'}, {'frequency': 'c', 'synset': 'deer.n.01', 'synonyms': ['deer', 'cervid'], 'id': 359, 'def': "distinguished from Bovidae by the male's having solid deciduous antlers", 'name': 'deer'}, {'frequency': 'c', 'synset': 'dental_floss.n.01', 'synonyms': ['dental_floss', 'floss'], 'id': 360, 'def': 'a soft thread for cleaning the spaces between the teeth', 'name': 'dental_floss'}, {'frequency': 'f', 'synset': 'desk.n.01', 'synonyms': ['desk'], 'id': 361, 'def': 'a piece of furniture with a writing surface and usually drawers or other compartments', 'name': 'desk'}, {'frequency': 'r', 'synset': 'detergent.n.01', 'synonyms': ['detergent'], 'id': 362, 'def': 'a surface-active chemical widely used in industry and laundering', 'name': 'detergent'}, {'frequency': 'c', 'synset': 'diaper.n.01', 'synonyms': ['diaper'], 'id': 363, 'def': 'garment consisting of a folded cloth drawn up between the legs and fastened at the waist', 'name': 'diaper'}, {'frequency': 'r', 'synset': 'diary.n.01', 'synonyms': ['diary', 'journal'], 'id': 364, 'def': 'yearly planner book', 'name': 'diary'}, {'frequency': 'r', 'synset': 'die.n.01', 'synonyms': ['die', 'dice'], 'id': 365, 'def': 'a small cube with 1 to 6 spots on the six faces; used in gambling', 'name': 'die'}, {'frequency': 'r', 'synset': 'dinghy.n.01', 'synonyms': ['dinghy', 'dory', 'rowboat'], 'id': 366, 'def': 'a small boat of shallow draft with seats and oars with which it is propelled', 'name': 'dinghy'}, {'frequency': 'f', 'synset': 'dining_table.n.01', 'synonyms': ['dining_table'], 'id': 367, 'def': 'a table at which meals are served', 'name': 'dining_table'}, {'frequency': 'r', 'synset': 'dinner_jacket.n.01', 'synonyms': ['tux', 'tuxedo'], 'id': 368, 'def': 'semiformal evening dress for men', 'name': 'tux'}, {'frequency': 'f', 'synset': 'dish.n.01', 'synonyms': ['dish'], 'id': 369, 'def': 'a piece of dishware normally used as a container for holding or serving food', 'name': 'dish'}, {'frequency': 'c', 'synset': 'dish.n.05', 'synonyms': ['dish_antenna'], 'id': 370, 'def': 'directional antenna consisting of a parabolic reflector', 'name': 'dish_antenna'}, {'frequency': 'c', 'synset': 'dishrag.n.01', 'synonyms': ['dishrag', 'dishcloth'], 'id': 371, 'def': 'a cloth for washing dishes or cleaning in general', 'name': 'dishrag'}, {'frequency': 'f', 'synset': 'dishtowel.n.01', 'synonyms': ['dishtowel', 'tea_towel'], 'id': 372, 'def': 'a towel for drying dishes', 'name': 'dishtowel'}, {'frequency': 'f', 'synset': 'dishwasher.n.01', 'synonyms': ['dishwasher', 'dishwashing_machine'], 'id': 373, 'def': 'a machine for washing dishes', 'name': 'dishwasher'}, {'frequency': 'r', 'synset': 'dishwasher_detergent.n.01', 'synonyms': ['dishwasher_detergent', 'dishwashing_detergent', 'dishwashing_liquid', 'dishsoap'], 'id': 374, 'def': 'dishsoap or dish detergent designed for use in dishwashers', 'name': 'dishwasher_detergent'}, {'frequency': 'f', 'synset': 'dispenser.n.01', 'synonyms': ['dispenser'], 'id': 375, 'def': 'a container so designed that the contents can be used in prescribed amounts', 'name': 'dispenser'}, {'frequency': 'r', 'synset': 'diving_board.n.01', 'synonyms': ['diving_board'], 'id': 376, 'def': 'a springboard from which swimmers can dive', 'name': 'diving_board'}, {'frequency': 'f', 'synset': 'dixie_cup.n.01', 'synonyms': ['Dixie_cup', 'paper_cup'], 'id': 377, 'def': 'a disposable cup made of paper; for holding drinks', 'name': 'Dixie_cup'}, {'frequency': 'f', 'synset': 'dog.n.01', 'synonyms': ['dog'], 'id': 378, 'def': 'a common domesticated dog', 'name': 'dog'}, {'frequency': 'f', 'synset': 'dog_collar.n.01', 'synonyms': ['dog_collar'], 'id': 379, 'def': 'a collar for a dog', 'name': 'dog_collar'}, {'frequency': 'f', 'synset': 'doll.n.01', 'synonyms': ['doll'], 'id': 380, 'def': 'a toy replica of a HUMAN (NOT AN ANIMAL)', 'name': 'doll'}, {'frequency': 'r', 'synset': 'dollar.n.02', 'synonyms': ['dollar', 'dollar_bill', 'one_dollar_bill'], 'id': 381, 'def': 'a piece of paper money worth one dollar', 'name': 'dollar'}, {'frequency': 'r', 'synset': 'dollhouse.n.01', 'synonyms': ['dollhouse', "doll's_house"], 'id': 382, 'def': "a house so small that it is likened to a child's plaything", 'name': 'dollhouse'}, {'frequency': 'c', 'synset': 'dolphin.n.02', 'synonyms': ['dolphin'], 'id': 383, 'def': 'any of various small toothed whales with a beaklike snout; larger than porpoises', 'name': 'dolphin'}, {'frequency': 'c', 'synset': 'domestic_ass.n.01', 'synonyms': ['domestic_ass', 'donkey'], 'id': 384, 'def': 'domestic beast of burden descended from the African wild ass; patient but stubborn', 'name': 'domestic_ass'}, {'frequency': 'f', 'synset': 'doorknob.n.01', 'synonyms': ['doorknob', 'doorhandle'], 'id': 385, 'def': "a knob used to open a door (often called `doorhandle' in Great Britain)", 'name': 'doorknob'}, {'frequency': 'c', 'synset': 'doormat.n.02', 'synonyms': ['doormat', 'welcome_mat'], 'id': 386, 'def': 'a mat placed outside an exterior door for wiping the shoes before entering', 'name': 'doormat'}, {'frequency': 'f', 'synset': 'doughnut.n.02', 'synonyms': ['doughnut', 'donut'], 'id': 387, 'def': 'a small ring-shaped friedcake', 'name': 'doughnut'}, {'frequency': 'r', 'synset': 'dove.n.01', 'synonyms': ['dove'], 'id': 388, 'def': 'any of numerous small pigeons', 'name': 'dove'}, {'frequency': 'r', 'synset': 'dragonfly.n.01', 'synonyms': ['dragonfly'], 'id': 389, 'def': 'slender-bodied non-stinging insect having iridescent wings that are outspread at rest', 'name': 'dragonfly'}, {'frequency': 'f', 'synset': 'drawer.n.01', 'synonyms': ['drawer'], 'id': 390, 'def': 'a boxlike container in a piece of furniture; made so as to slide in and out', 'name': 'drawer'}, {'frequency': 'c', 'synset': 'drawers.n.01', 'synonyms': ['underdrawers', 'boxers', 'boxershorts'], 'id': 391, 'def': 'underpants worn by men', 'name': 'underdrawers'}, {'frequency': 'f', 'synset': 'dress.n.01', 'synonyms': ['dress', 'frock'], 'id': 392, 'def': 'a one-piece garment for a woman; has skirt and bodice', 'name': 'dress'}, {'frequency': 'c', 'synset': 'dress_hat.n.01', 'synonyms': ['dress_hat', 'high_hat', 'opera_hat', 'silk_hat', 'top_hat'], 'id': 393, 'def': "a man's hat with a tall crown; usually covered with silk or with beaver fur", 'name': 'dress_hat'}, {'frequency': 'f', 'synset': 'dress_suit.n.01', 'synonyms': ['dress_suit'], 'id': 394, 'def': 'formalwear consisting of full evening dress for men', 'name': 'dress_suit'}, {'frequency': 'f', 'synset': 'dresser.n.05', 'synonyms': ['dresser'], 'id': 395, 'def': 'a cabinet with shelves', 'name': 'dresser'}, {'frequency': 'c', 'synset': 'drill.n.01', 'synonyms': ['drill'], 'id': 396, 'def': 'a tool with a sharp rotating point for making holes in hard materials', 'name': 'drill'}, {'frequency': 'r', 'synset': 'drone.n.04', 'synonyms': ['drone'], 'id': 397, 'def': 'an aircraft without a pilot that is operated by remote control', 'name': 'drone'}, {'frequency': 'r', 'synset': 'dropper.n.01', 'synonyms': ['dropper', 'eye_dropper'], 'id': 398, 'def': 'pipet consisting of a small tube with a vacuum bulb at one end for drawing liquid in and releasing it a drop at a time', 'name': 'dropper'}, {'frequency': 'c', 'synset': 'drum.n.01', 'synonyms': ['drum_(musical_instrument)'], 'id': 399, 'def': 'a musical percussion instrument; usually consists of a hollow cylinder with a membrane stretched across each end', 'name': 'drum_(musical_instrument)'}, {'frequency': 'r', 'synset': 'drumstick.n.02', 'synonyms': ['drumstick'], 'id': 400, 'def': 'a stick used for playing a drum', 'name': 'drumstick'}, {'frequency': 'f', 'synset': 'duck.n.01', 'synonyms': ['duck'], 'id': 401, 'def': 'small web-footed broad-billed swimming bird', 'name': 'duck'}, {'frequency': 'c', 'synset': 'duckling.n.02', 'synonyms': ['duckling'], 'id': 402, 'def': 'young duck', 'name': 'duckling'}, {'frequency': 'c', 'synset': 'duct_tape.n.01', 'synonyms': ['duct_tape'], 'id': 403, 'def': 'a wide silvery adhesive tape', 'name': 'duct_tape'}, {'frequency': 'f', 'synset': 'duffel_bag.n.01', 'synonyms': ['duffel_bag', 'duffle_bag', 'duffel', 'duffle'], 'id': 404, 'def': 'a large cylindrical bag of heavy cloth (does not include suitcases)', 'name': 'duffel_bag'}, {'frequency': 'r', 'synset': 'dumbbell.n.01', 'synonyms': ['dumbbell'], 'id': 405, 'def': 'an exercising weight with two ball-like ends connected by a short handle', 'name': 'dumbbell'}, {'frequency': 'c', 'synset': 'dumpster.n.01', 'synonyms': ['dumpster'], 'id': 406, 'def': 'a container designed to receive and transport and dump waste', 'name': 'dumpster'}, {'frequency': 'r', 'synset': 'dustpan.n.02', 'synonyms': ['dustpan'], 'id': 407, 'def': 'a short-handled receptacle into which dust can be swept', 'name': 'dustpan'}, {'frequency': 'c', 'synset': 'eagle.n.01', 'synonyms': ['eagle'], 'id': 408, 'def': 'large birds of prey noted for their broad wings and strong soaring flight', 'name': 'eagle'}, {'frequency': 'f', 'synset': 'earphone.n.01', 'synonyms': ['earphone', 'earpiece', 'headphone'], 'id': 409, 'def': 'device for listening to audio that is held over or inserted into the ear', 'name': 'earphone'}, {'frequency': 'r', 'synset': 'earplug.n.01', 'synonyms': ['earplug'], 'id': 410, 'def': 'a soft plug that is inserted into the ear canal to block sound', 'name': 'earplug'}, {'frequency': 'f', 'synset': 'earring.n.01', 'synonyms': ['earring'], 'id': 411, 'def': 'jewelry to ornament the ear', 'name': 'earring'}, {'frequency': 'c', 'synset': 'easel.n.01', 'synonyms': ['easel'], 'id': 412, 'def': "an upright tripod for displaying something (usually an artist's canvas)", 'name': 'easel'}, {'frequency': 'r', 'synset': 'eclair.n.01', 'synonyms': ['eclair'], 'id': 413, 'def': 'oblong cream puff', 'name': 'eclair'}, {'frequency': 'r', 'synset': 'eel.n.01', 'synonyms': ['eel'], 'id': 414, 'def': 'an elongate fish with fatty flesh', 'name': 'eel'}, {'frequency': 'f', 'synset': 'egg.n.02', 'synonyms': ['egg', 'eggs'], 'id': 415, 'def': 'oval reproductive body of a fowl (especially a hen) used as food', 'name': 'egg'}, {'frequency': 'r', 'synset': 'egg_roll.n.01', 'synonyms': ['egg_roll', 'spring_roll'], 'id': 416, 'def': 'minced vegetables and meat wrapped in a pancake and fried', 'name': 'egg_roll'}, {'frequency': 'c', 'synset': 'egg_yolk.n.01', 'synonyms': ['egg_yolk', 'yolk_(egg)'], 'id': 417, 'def': 'the yellow spherical part of an egg', 'name': 'egg_yolk'}, {'frequency': 'c', 'synset': 'eggbeater.n.02', 'synonyms': ['eggbeater', 'eggwhisk'], 'id': 418, 'def': 'a mixer for beating eggs or whipping cream', 'name': 'eggbeater'}, {'frequency': 'c', 'synset': 'eggplant.n.01', 'synonyms': ['eggplant', 'aubergine'], 'id': 419, 'def': 'egg-shaped vegetable having a shiny skin typically dark purple', 'name': 'eggplant'}, {'frequency': 'r', 'synset': 'electric_chair.n.01', 'synonyms': ['electric_chair'], 'id': 420, 'def': 'a chair-shaped instrument of execution by electrocution', 'name': 'electric_chair'}, {'frequency': 'f', 'synset': 'electric_refrigerator.n.01', 'synonyms': ['refrigerator'], 'id': 421, 'def': 'a refrigerator in which the coolant is pumped around by an electric motor', 'name': 'refrigerator'}, {'frequency': 'f', 'synset': 'elephant.n.01', 'synonyms': ['elephant'], 'id': 422, 'def': 'a common elephant', 'name': 'elephant'}, {'frequency': 'c', 'synset': 'elk.n.01', 'synonyms': ['elk', 'moose'], 'id': 423, 'def': 'large northern deer with enormous flattened antlers in the male', 'name': 'elk'}, {'frequency': 'c', 'synset': 'envelope.n.01', 'synonyms': ['envelope'], 'id': 424, 'def': 'a flat (usually rectangular) container for a letter, thin package, etc.', 'name': 'envelope'}, {'frequency': 'c', 'synset': 'eraser.n.01', 'synonyms': ['eraser'], 'id': 425, 'def': 'an implement used to erase something', 'name': 'eraser'}, {'frequency': 'r', 'synset': 'escargot.n.01', 'synonyms': ['escargot'], 'id': 426, 'def': 'edible snail usually served in the shell with a sauce of melted butter and garlic', 'name': 'escargot'}, {'frequency': 'r', 'synset': 'eyepatch.n.01', 'synonyms': ['eyepatch'], 'id': 427, 'def': 'a protective cloth covering for an injured eye', 'name': 'eyepatch'}, {'frequency': 'r', 'synset': 'falcon.n.01', 'synonyms': ['falcon'], 'id': 428, 'def': 'birds of prey having long pointed powerful wings adapted for swift flight', 'name': 'falcon'}, {'frequency': 'f', 'synset': 'fan.n.01', 'synonyms': ['fan'], 'id': 429, 'def': 'a device for creating a current of air by movement of a surface or surfaces', 'name': 'fan'}, {'frequency': 'f', 'synset': 'faucet.n.01', 'synonyms': ['faucet', 'spigot', 'tap'], 'id': 430, 'def': 'a regulator for controlling the flow of a liquid from a reservoir', 'name': 'faucet'}, {'frequency': 'r', 'synset': 'fedora.n.01', 'synonyms': ['fedora'], 'id': 431, 'def': 'a hat made of felt with a creased crown', 'name': 'fedora'}, {'frequency': 'r', 'synset': 'ferret.n.02', 'synonyms': ['ferret'], 'id': 432, 'def': 'domesticated albino variety of the European polecat bred for hunting rats and rabbits', 'name': 'ferret'}, {'frequency': 'c', 'synset': 'ferris_wheel.n.01', 'synonyms': ['Ferris_wheel'], 'id': 433, 'def': 'a large wheel with suspended seats that remain upright as the wheel rotates', 'name': 'Ferris_wheel'}, {'frequency': 'c', 'synset': 'ferry.n.01', 'synonyms': ['ferry', 'ferryboat'], 'id': 434, 'def': 'a boat that transports people or vehicles across a body of water and operates on a regular schedule', 'name': 'ferry'}, {'frequency': 'r', 'synset': 'fig.n.04', 'synonyms': ['fig_(fruit)'], 'id': 435, 'def': 'fleshy sweet pear-shaped yellowish or purple fruit eaten fresh or preserved or dried', 'name': 'fig_(fruit)'}, {'frequency': 'c', 'synset': 'fighter.n.02', 'synonyms': ['fighter_jet', 'fighter_aircraft', 'attack_aircraft'], 'id': 436, 'def': 'a high-speed military or naval airplane designed to destroy enemy targets', 'name': 'fighter_jet'}, {'frequency': 'f', 'synset': 'figurine.n.01', 'synonyms': ['figurine'], 'id': 437, 'def': 'a small carved or molded figure', 'name': 'figurine'}, {'frequency': 'c', 'synset': 'file.n.03', 'synonyms': ['file_cabinet', 'filing_cabinet'], 'id': 438, 'def': 'office furniture consisting of a container for keeping papers in order', 'name': 'file_cabinet'}, {'frequency': 'r', 'synset': 'file.n.04', 'synonyms': ['file_(tool)'], 'id': 439, 'def': 'a steel hand tool with small sharp teeth on some or all of its surfaces; used for smoothing wood or metal', 'name': 'file_(tool)'}, {'frequency': 'f', 'synset': 'fire_alarm.n.02', 'synonyms': ['fire_alarm', 'smoke_alarm'], 'id': 440, 'def': 'an alarm that is tripped off by fire or smoke', 'name': 'fire_alarm'}, {'frequency': 'f', 'synset': 'fire_engine.n.01', 'synonyms': ['fire_engine', 'fire_truck'], 'id': 441, 'def': 'large trucks that carry firefighters and equipment to the site of a fire', 'name': 'fire_engine'}, {'frequency': 'f', 'synset': 'fire_extinguisher.n.01', 'synonyms': ['fire_extinguisher', 'extinguisher'], 'id': 442, 'def': 'a manually operated device for extinguishing small fires', 'name': 'fire_extinguisher'}, {'frequency': 'c', 'synset': 'fire_hose.n.01', 'synonyms': ['fire_hose'], 'id': 443, 'def': 'a large hose that carries water from a fire hydrant to the site of the fire', 'name': 'fire_hose'}, {'frequency': 'f', 'synset': 'fireplace.n.01', 'synonyms': ['fireplace'], 'id': 444, 'def': 'an open recess in a wall at the base of a chimney where a fire can be built', 'name': 'fireplace'}, {'frequency': 'f', 'synset': 'fireplug.n.01', 'synonyms': ['fireplug', 'fire_hydrant', 'hydrant'], 'id': 445, 'def': 'an upright hydrant for drawing water to use in fighting a fire', 'name': 'fireplug'}, {'frequency': 'r', 'synset': 'first-aid_kit.n.01', 'synonyms': ['first-aid_kit'], 'id': 446, 'def': 'kit consisting of a set of bandages and medicines for giving first aid', 'name': 'first-aid_kit'}, {'frequency': 'f', 'synset': 'fish.n.01', 'synonyms': ['fish'], 'id': 447, 'def': 'any of various mostly cold-blooded aquatic vertebrates usually having scales and breathing through gills', 'name': 'fish'}, {'frequency': 'c', 'synset': 'fish.n.02', 'synonyms': ['fish_(food)'], 'id': 448, 'def': 'the flesh of fish used as food', 'name': 'fish_(food)'}, {'frequency': 'r', 'synset': 'fishbowl.n.02', 'synonyms': ['fishbowl', 'goldfish_bowl'], 'id': 449, 'def': 'a transparent bowl in which small fish are kept', 'name': 'fishbowl'}, {'frequency': 'c', 'synset': 'fishing_rod.n.01', 'synonyms': ['fishing_rod', 'fishing_pole'], 'id': 450, 'def': 'a rod that is used in fishing to extend the fishing line', 'name': 'fishing_rod'}, {'frequency': 'f', 'synset': 'flag.n.01', 'synonyms': ['flag'], 'id': 451, 'def': 'emblem usually consisting of a rectangular piece of cloth of distinctive design (do not include pole)', 'name': 'flag'}, {'frequency': 'f', 'synset': 'flagpole.n.02', 'synonyms': ['flagpole', 'flagstaff'], 'id': 452, 'def': 'a tall staff or pole on which a flag is raised', 'name': 'flagpole'}, {'frequency': 'c', 'synset': 'flamingo.n.01', 'synonyms': ['flamingo'], 'id': 453, 'def': 'large pink web-footed bird with down-bent bill', 'name': 'flamingo'}, {'frequency': 'c', 'synset': 'flannel.n.01', 'synonyms': ['flannel'], 'id': 454, 'def': 'a soft light woolen fabric; used for clothing', 'name': 'flannel'}, {'frequency': 'c', 'synset': 'flap.n.01', 'synonyms': ['flap'], 'id': 455, 'def': 'any broad thin covering attached at one edge, such as a mud flap next to a wheel or a flap on an airplane wing', 'name': 'flap'}, {'frequency': 'r', 'synset': 'flash.n.10', 'synonyms': ['flash', 'flashbulb'], 'id': 456, 'def': 'a lamp for providing momentary light to take a photograph', 'name': 'flash'}, {'frequency': 'c', 'synset': 'flashlight.n.01', 'synonyms': ['flashlight', 'torch'], 'id': 457, 'def': 'a small portable battery-powered electric lamp', 'name': 'flashlight'}, {'frequency': 'r', 'synset': 'fleece.n.03', 'synonyms': ['fleece'], 'id': 458, 'def': 'a soft bulky fabric with deep pile; used chiefly for clothing', 'name': 'fleece'}, {'frequency': 'f', 'synset': 'flip-flop.n.02', 'synonyms': ['flip-flop_(sandal)'], 'id': 459, 'def': 'a backless sandal held to the foot by a thong between two toes', 'name': 'flip-flop_(sandal)'}, {'frequency': 'c', 'synset': 'flipper.n.01', 'synonyms': ['flipper_(footwear)', 'fin_(footwear)'], 'id': 460, 'def': 'a shoe to aid a person in swimming', 'name': 'flipper_(footwear)'}, {'frequency': 'f', 'synset': 'flower_arrangement.n.01', 'synonyms': ['flower_arrangement', 'floral_arrangement'], 'id': 461, 'def': 'a decorative arrangement of flowers', 'name': 'flower_arrangement'}, {'frequency': 'c', 'synset': 'flute.n.02', 'synonyms': ['flute_glass', 'champagne_flute'], 'id': 462, 'def': 'a tall narrow wineglass', 'name': 'flute_glass'}, {'frequency': 'c', 'synset': 'foal.n.01', 'synonyms': ['foal'], 'id': 463, 'def': 'a young horse', 'name': 'foal'}, {'frequency': 'c', 'synset': 'folding_chair.n.01', 'synonyms': ['folding_chair'], 'id': 464, 'def': 'a chair that can be folded flat for storage', 'name': 'folding_chair'}, {'frequency': 'c', 'synset': 'food_processor.n.01', 'synonyms': ['food_processor'], 'id': 465, 'def': 'a kitchen appliance for shredding, blending, chopping, or slicing food', 'name': 'food_processor'}, {'frequency': 'c', 'synset': 'football.n.02', 'synonyms': ['football_(American)'], 'id': 466, 'def': 'the inflated oblong ball used in playing American football', 'name': 'football_(American)'}, {'frequency': 'r', 'synset': 'football_helmet.n.01', 'synonyms': ['football_helmet'], 'id': 467, 'def': 'a padded helmet with a face mask to protect the head of football players', 'name': 'football_helmet'}, {'frequency': 'c', 'synset': 'footstool.n.01', 'synonyms': ['footstool', 'footrest'], 'id': 468, 'def': 'a low seat or a stool to rest the feet of a seated person', 'name': 'footstool'}, {'frequency': 'f', 'synset': 'fork.n.01', 'synonyms': ['fork'], 'id': 469, 'def': 'cutlery used for serving and eating food', 'name': 'fork'}, {'frequency': 'c', 'synset': 'forklift.n.01', 'synonyms': ['forklift'], 'id': 470, 'def': 'an industrial vehicle with a power operated fork in front that can be inserted under loads to lift and move them', 'name': 'forklift'}, {'frequency': 'c', 'synset': 'freight_car.n.01', 'synonyms': ['freight_car'], 'id': 471, 'def': 'a railway car that carries freight', 'name': 'freight_car'}, {'frequency': 'c', 'synset': 'french_toast.n.01', 'synonyms': ['French_toast'], 'id': 472, 'def': 'bread slice dipped in egg and milk and fried', 'name': 'French_toast'}, {'frequency': 'c', 'synset': 'freshener.n.01', 'synonyms': ['freshener', 'air_freshener'], 'id': 473, 'def': 'anything that freshens air by removing or covering odor', 'name': 'freshener'}, {'frequency': 'f', 'synset': 'frisbee.n.01', 'synonyms': ['frisbee'], 'id': 474, 'def': 'a light, plastic disk propelled with a flip of the wrist for recreation or competition', 'name': 'frisbee'}, {'frequency': 'c', 'synset': 'frog.n.01', 'synonyms': ['frog', 'toad', 'toad_frog'], 'id': 475, 'def': 'a tailless stout-bodied amphibians with long hind limbs for leaping', 'name': 'frog'}, {'frequency': 'c', 'synset': 'fruit_juice.n.01', 'synonyms': ['fruit_juice'], 'id': 476, 'def': 'drink produced by squeezing or crushing fruit', 'name': 'fruit_juice'}, {'frequency': 'f', 'synset': 'frying_pan.n.01', 'synonyms': ['frying_pan', 'frypan', 'skillet'], 'id': 477, 'def': 'a pan used for frying foods', 'name': 'frying_pan'}, {'frequency': 'r', 'synset': 'fudge.n.01', 'synonyms': ['fudge'], 'id': 478, 'def': 'soft creamy candy', 'name': 'fudge'}, {'frequency': 'r', 'synset': 'funnel.n.02', 'synonyms': ['funnel'], 'id': 479, 'def': 'a cone-shaped utensil used to channel a substance into a container with a small mouth', 'name': 'funnel'}, {'frequency': 'r', 'synset': 'futon.n.01', 'synonyms': ['futon'], 'id': 480, 'def': 'a pad that is used for sleeping on the floor or on a raised frame', 'name': 'futon'}, {'frequency': 'r', 'synset': 'gag.n.02', 'synonyms': ['gag', 'muzzle'], 'id': 481, 'def': "restraint put into a person's mouth to prevent speaking or shouting", 'name': 'gag'}, {'frequency': 'r', 'synset': 'garbage.n.03', 'synonyms': ['garbage'], 'id': 482, 'def': 'a receptacle where waste can be discarded', 'name': 'garbage'}, {'frequency': 'c', 'synset': 'garbage_truck.n.01', 'synonyms': ['garbage_truck'], 'id': 483, 'def': 'a truck for collecting domestic refuse', 'name': 'garbage_truck'}, {'frequency': 'c', 'synset': 'garden_hose.n.01', 'synonyms': ['garden_hose'], 'id': 484, 'def': 'a hose used for watering a lawn or garden', 'name': 'garden_hose'}, {'frequency': 'c', 'synset': 'gargle.n.01', 'synonyms': ['gargle', 'mouthwash'], 'id': 485, 'def': 'a medicated solution used for gargling and rinsing the mouth', 'name': 'gargle'}, {'frequency': 'r', 'synset': 'gargoyle.n.02', 'synonyms': ['gargoyle'], 'id': 486, 'def': 'an ornament consisting of a grotesquely carved figure of a person or animal', 'name': 'gargoyle'}, {'frequency': 'c', 'synset': 'garlic.n.02', 'synonyms': ['garlic', 'ail'], 'id': 487, 'def': 'aromatic bulb used as seasoning', 'name': 'garlic'}, {'frequency': 'r', 'synset': 'gasmask.n.01', 'synonyms': ['gasmask', 'respirator', 'gas_helmet'], 'id': 488, 'def': 'a protective face mask with a filter', 'name': 'gasmask'}, {'frequency': 'c', 'synset': 'gazelle.n.01', 'synonyms': ['gazelle'], 'id': 489, 'def': 'small swift graceful antelope of Africa and Asia having lustrous eyes', 'name': 'gazelle'}, {'frequency': 'c', 'synset': 'gelatin.n.02', 'synonyms': ['gelatin', 'jelly'], 'id': 490, 'def': 'an edible jelly made with gelatin and used as a dessert or salad base or a coating for foods', 'name': 'gelatin'}, {'frequency': 'r', 'synset': 'gem.n.02', 'synonyms': ['gemstone'], 'id': 491, 'def': 'a crystalline rock that can be cut and polished for jewelry', 'name': 'gemstone'}, {'frequency': 'r', 'synset': 'generator.n.02', 'synonyms': ['generator'], 'id': 492, 'def': 'engine that converts mechanical energy into electrical energy by electromagnetic induction', 'name': 'generator'}, {'frequency': 'c', 'synset': 'giant_panda.n.01', 'synonyms': ['giant_panda', 'panda', 'panda_bear'], 'id': 493, 'def': 'large black-and-white herbivorous mammal of bamboo forests of China and Tibet', 'name': 'giant_panda'}, {'frequency': 'c', 'synset': 'gift_wrap.n.01', 'synonyms': ['gift_wrap'], 'id': 494, 'def': 'attractive wrapping paper suitable for wrapping gifts', 'name': 'gift_wrap'}, {'frequency': 'c', 'synset': 'ginger.n.03', 'synonyms': ['ginger', 'gingerroot'], 'id': 495, 'def': 'the root of the common ginger plant; used fresh as a seasoning', 'name': 'ginger'}, {'frequency': 'f', 'synset': 'giraffe.n.01', 'synonyms': ['giraffe'], 'id': 496, 'def': 'tall animal having a spotted coat and small horns and very long neck and legs', 'name': 'giraffe'}, {'frequency': 'c', 'synset': 'girdle.n.02', 'synonyms': ['cincture', 'sash', 'waistband', 'waistcloth'], 'id': 497, 'def': 'a band of material around the waist that strengthens a skirt or trousers', 'name': 'cincture'}, {'frequency': 'f', 'synset': 'glass.n.02', 'synonyms': ['glass_(drink_container)', 'drinking_glass'], 'id': 498, 'def': 'a container for holding liquids while drinking', 'name': 'glass_(drink_container)'}, {'frequency': 'c', 'synset': 'globe.n.03', 'synonyms': ['globe'], 'id': 499, 'def': 'a sphere on which a map (especially of the earth) is represented', 'name': 'globe'}, {'frequency': 'f', 'synset': 'glove.n.02', 'synonyms': ['glove'], 'id': 500, 'def': 'handwear covering the hand', 'name': 'glove'}, {'frequency': 'c', 'synset': 'goat.n.01', 'synonyms': ['goat'], 'id': 501, 'def': 'a common goat', 'name': 'goat'}, {'frequency': 'f', 'synset': 'goggles.n.01', 'synonyms': ['goggles'], 'id': 502, 'def': 'tight-fitting spectacles worn to protect the eyes', 'name': 'goggles'}, {'frequency': 'r', 'synset': 'goldfish.n.01', 'synonyms': ['goldfish'], 'id': 503, 'def': 'small golden or orange-red freshwater fishes used as pond or aquarium pets', 'name': 'goldfish'}, {'frequency': 'c', 'synset': 'golf_club.n.02', 'synonyms': ['golf_club', 'golf-club'], 'id': 504, 'def': 'golf equipment used by a golfer to hit a golf ball', 'name': 'golf_club'}, {'frequency': 'c', 'synset': 'golfcart.n.01', 'synonyms': ['golfcart'], 'id': 505, 'def': 'a small motor vehicle in which golfers can ride between shots', 'name': 'golfcart'}, {'frequency': 'r', 'synset': 'gondola.n.02', 'synonyms': ['gondola_(boat)'], 'id': 506, 'def': 'long narrow flat-bottomed boat propelled by sculling; traditionally used on canals of Venice', 'name': 'gondola_(boat)'}, {'frequency': 'c', 'synset': 'goose.n.01', 'synonyms': ['goose'], 'id': 507, 'def': 'loud, web-footed long-necked aquatic birds usually larger than ducks', 'name': 'goose'}, {'frequency': 'r', 'synset': 'gorilla.n.01', 'synonyms': ['gorilla'], 'id': 508, 'def': 'largest ape', 'name': 'gorilla'}, {'frequency': 'r', 'synset': 'gourd.n.02', 'synonyms': ['gourd'], 'id': 509, 'def': 'any of numerous inedible fruits with hard rinds', 'name': 'gourd'}, {'frequency': 'f', 'synset': 'grape.n.01', 'synonyms': ['grape'], 'id': 510, 'def': 'any of various juicy fruit with green or purple skins; grow in clusters', 'name': 'grape'}, {'frequency': 'c', 'synset': 'grater.n.01', 'synonyms': ['grater'], 'id': 511, 'def': 'utensil with sharp perforations for shredding foods (as vegetables or cheese)', 'name': 'grater'}, {'frequency': 'c', 'synset': 'gravestone.n.01', 'synonyms': ['gravestone', 'headstone', 'tombstone'], 'id': 512, 'def': 'a stone that is used to mark a grave', 'name': 'gravestone'}, {'frequency': 'r', 'synset': 'gravy_boat.n.01', 'synonyms': ['gravy_boat', 'gravy_holder'], 'id': 513, 'def': 'a dish (often boat-shaped) for serving gravy or sauce', 'name': 'gravy_boat'}, {'frequency': 'f', 'synset': 'green_bean.n.02', 'synonyms': ['green_bean'], 'id': 514, 'def': 'a common bean plant cultivated for its slender green edible pods', 'name': 'green_bean'}, {'frequency': 'f', 'synset': 'green_onion.n.01', 'synonyms': ['green_onion', 'spring_onion', 'scallion'], 'id': 515, 'def': 'a young onion before the bulb has enlarged', 'name': 'green_onion'}, {'frequency': 'r', 'synset': 'griddle.n.01', 'synonyms': ['griddle'], 'id': 516, 'def': 'cooking utensil consisting of a flat heated surface on which food is cooked', 'name': 'griddle'}, {'frequency': 'f', 'synset': 'grill.n.02', 'synonyms': ['grill', 'grille', 'grillwork', 'radiator_grille'], 'id': 517, 'def': 'a framework of metal bars used as a partition or a grate', 'name': 'grill'}, {'frequency': 'r', 'synset': 'grits.n.01', 'synonyms': ['grits', 'hominy_grits'], 'id': 518, 'def': 'coarsely ground corn boiled as a breakfast dish', 'name': 'grits'}, {'frequency': 'c', 'synset': 'grizzly.n.01', 'synonyms': ['grizzly', 'grizzly_bear'], 'id': 519, 'def': 'powerful brownish-yellow bear of the uplands of western North America', 'name': 'grizzly'}, {'frequency': 'c', 'synset': 'grocery_bag.n.01', 'synonyms': ['grocery_bag'], 'id': 520, 'def': "a sack for holding customer's groceries", 'name': 'grocery_bag'}, {'frequency': 'f', 'synset': 'guitar.n.01', 'synonyms': ['guitar'], 'id': 521, 'def': 'a stringed instrument usually having six strings; played by strumming or plucking', 'name': 'guitar'}, {'frequency': 'c', 'synset': 'gull.n.02', 'synonyms': ['gull', 'seagull'], 'id': 522, 'def': 'mostly white aquatic bird having long pointed wings and short legs', 'name': 'gull'}, {'frequency': 'c', 'synset': 'gun.n.01', 'synonyms': ['gun'], 'id': 523, 'def': 'a weapon that discharges a bullet at high velocity from a metal tube', 'name': 'gun'}, {'frequency': 'f', 'synset': 'hairbrush.n.01', 'synonyms': ['hairbrush'], 'id': 524, 'def': "a brush used to groom a person's hair", 'name': 'hairbrush'}, {'frequency': 'c', 'synset': 'hairnet.n.01', 'synonyms': ['hairnet'], 'id': 525, 'def': 'a small net that someone wears over their hair to keep it in place', 'name': 'hairnet'}, {'frequency': 'c', 'synset': 'hairpin.n.01', 'synonyms': ['hairpin'], 'id': 526, 'def': "a double pronged pin used to hold women's hair in place", 'name': 'hairpin'}, {'frequency': 'r', 'synset': 'halter.n.03', 'synonyms': ['halter_top'], 'id': 527, 'def': "a woman's top that fastens behind the back and neck leaving the back and arms uncovered", 'name': 'halter_top'}, {'frequency': 'f', 'synset': 'ham.n.01', 'synonyms': ['ham', 'jambon', 'gammon'], 'id': 528, 'def': 'meat cut from the thigh of a hog (usually smoked)', 'name': 'ham'}, {'frequency': 'c', 'synset': 'hamburger.n.01', 'synonyms': ['hamburger', 'beefburger', 'burger'], 'id': 529, 'def': 'a sandwich consisting of a patty of minced beef served on a bun', 'name': 'hamburger'}, {'frequency': 'c', 'synset': 'hammer.n.02', 'synonyms': ['hammer'], 'id': 530, 'def': 'a hand tool with a heavy head and a handle; used to deliver an impulsive force by striking', 'name': 'hammer'}, {'frequency': 'c', 'synset': 'hammock.n.02', 'synonyms': ['hammock'], 'id': 531, 'def': 'a hanging bed of canvas or rope netting (usually suspended between two trees)', 'name': 'hammock'}, {'frequency': 'r', 'synset': 'hamper.n.02', 'synonyms': ['hamper'], 'id': 532, 'def': 'a basket usually with a cover', 'name': 'hamper'}, {'frequency': 'c', 'synset': 'hamster.n.01', 'synonyms': ['hamster'], 'id': 533, 'def': 'short-tailed burrowing rodent with large cheek pouches', 'name': 'hamster'}, {'frequency': 'f', 'synset': 'hand_blower.n.01', 'synonyms': ['hair_dryer'], 'id': 534, 'def': 'a hand-held electric blower that can blow warm air onto the hair', 'name': 'hair_dryer'}, {'frequency': 'r', 'synset': 'hand_glass.n.01', 'synonyms': ['hand_glass', 'hand_mirror'], 'id': 535, 'def': 'a mirror intended to be held in the hand', 'name': 'hand_glass'}, {'frequency': 'f', 'synset': 'hand_towel.n.01', 'synonyms': ['hand_towel', 'face_towel'], 'id': 536, 'def': 'a small towel used to dry the hands or face', 'name': 'hand_towel'}, {'frequency': 'c', 'synset': 'handcart.n.01', 'synonyms': ['handcart', 'pushcart', 'hand_truck'], 'id': 537, 'def': 'wheeled vehicle that can be pushed by a person', 'name': 'handcart'}, {'frequency': 'r', 'synset': 'handcuff.n.01', 'synonyms': ['handcuff'], 'id': 538, 'def': 'shackle that consists of a metal loop that can be locked around the wrist', 'name': 'handcuff'}, {'frequency': 'c', 'synset': 'handkerchief.n.01', 'synonyms': ['handkerchief'], 'id': 539, 'def': 'a square piece of cloth used for wiping the eyes or nose or as a costume accessory', 'name': 'handkerchief'}, {'frequency': 'f', 'synset': 'handle.n.01', 'synonyms': ['handle', 'grip', 'handgrip'], 'id': 540, 'def': 'the appendage to an object that is designed to be held in order to use or move it', 'name': 'handle'}, {'frequency': 'r', 'synset': 'handsaw.n.01', 'synonyms': ['handsaw', "carpenter's_saw"], 'id': 541, 'def': 'a saw used with one hand for cutting wood', 'name': 'handsaw'}, {'frequency': 'r', 'synset': 'hardback.n.01', 'synonyms': ['hardback_book', 'hardcover_book'], 'id': 542, 'def': 'a book with cardboard or cloth or leather covers', 'name': 'hardback_book'}, {'frequency': 'r', 'synset': 'harmonium.n.01', 'synonyms': ['harmonium', 'organ_(musical_instrument)', 'reed_organ_(musical_instrument)'], 'id': 543, 'def': 'a free-reed instrument in which air is forced through the reeds by bellows', 'name': 'harmonium'}, {'frequency': 'f', 'synset': 'hat.n.01', 'synonyms': ['hat'], 'id': 544, 'def': 'headwear that protects the head from bad weather, sun, or worn for fashion', 'name': 'hat'}, {'frequency': 'r', 'synset': 'hatbox.n.01', 'synonyms': ['hatbox'], 'id': 545, 'def': 'a round piece of luggage for carrying hats', 'name': 'hatbox'}, {'frequency': 'c', 'synset': 'head_covering.n.01', 'synonyms': ['veil'], 'id': 546, 'def': 'a garment that covers the head OR face', 'name': 'veil'}, {'frequency': 'f', 'synset': 'headband.n.01', 'synonyms': ['headband'], 'id': 547, 'def': 'a band worn around or over the head', 'name': 'headband'}, {'frequency': 'f', 'synset': 'headboard.n.01', 'synonyms': ['headboard'], 'id': 548, 'def': 'a vertical board or panel forming the head of a bedstead', 'name': 'headboard'}, {'frequency': 'f', 'synset': 'headlight.n.01', 'synonyms': ['headlight', 'headlamp'], 'id': 549, 'def': 'a powerful light with reflector; attached to the front of an automobile or locomotive', 'name': 'headlight'}, {'frequency': 'c', 'synset': 'headscarf.n.01', 'synonyms': ['headscarf'], 'id': 550, 'def': 'a kerchief worn over the head and tied under the chin', 'name': 'headscarf'}, {'frequency': 'r', 'synset': 'headset.n.01', 'synonyms': ['headset'], 'id': 551, 'def': 'receiver consisting of a pair of headphones', 'name': 'headset'}, {'frequency': 'c', 'synset': 'headstall.n.01', 'synonyms': ['headstall_(for_horses)', 'headpiece_(for_horses)'], 'id': 552, 'def': "the band that is the part of a bridle that fits around a horse's head", 'name': 'headstall_(for_horses)'}, {'frequency': 'c', 'synset': 'heart.n.02', 'synonyms': ['heart'], 'id': 553, 'def': 'a muscular organ; its contractions move the blood through the body', 'name': 'heart'}, {'frequency': 'c', 'synset': 'heater.n.01', 'synonyms': ['heater', 'warmer'], 'id': 554, 'def': 'device that heats water or supplies warmth to a room', 'name': 'heater'}, {'frequency': 'c', 'synset': 'helicopter.n.01', 'synonyms': ['helicopter'], 'id': 555, 'def': 'an aircraft without wings that obtains its lift from the rotation of overhead blades', 'name': 'helicopter'}, {'frequency': 'f', 'synset': 'helmet.n.02', 'synonyms': ['helmet'], 'id': 556, 'def': 'a protective headgear made of hard material to resist blows', 'name': 'helmet'}, {'frequency': 'r', 'synset': 'heron.n.02', 'synonyms': ['heron'], 'id': 557, 'def': 'grey or white wading bird with long neck and long legs and (usually) long bill', 'name': 'heron'}, {'frequency': 'c', 'synset': 'highchair.n.01', 'synonyms': ['highchair', 'feeding_chair'], 'id': 558, 'def': 'a chair for feeding a very young child', 'name': 'highchair'}, {'frequency': 'f', 'synset': 'hinge.n.01', 'synonyms': ['hinge'], 'id': 559, 'def': 'a joint that holds two parts together so that one can swing relative to the other', 'name': 'hinge'}, {'frequency': 'r', 'synset': 'hippopotamus.n.01', 'synonyms': ['hippopotamus'], 'id': 560, 'def': 'massive thick-skinned animal living in or around rivers of tropical Africa', 'name': 'hippopotamus'}, {'frequency': 'r', 'synset': 'hockey_stick.n.01', 'synonyms': ['hockey_stick'], 'id': 561, 'def': 'sports implement consisting of a stick used by hockey players to move the puck', 'name': 'hockey_stick'}, {'frequency': 'c', 'synset': 'hog.n.03', 'synonyms': ['hog', 'pig'], 'id': 562, 'def': 'domestic swine', 'name': 'hog'}, {'frequency': 'f', 'synset': 'home_plate.n.01', 'synonyms': ['home_plate_(baseball)', 'home_base_(baseball)'], 'id': 563, 'def': '(baseball) a rubber slab where the batter stands; it must be touched by a base runner in order to score', 'name': 'home_plate_(baseball)'}, {'frequency': 'c', 'synset': 'honey.n.01', 'synonyms': ['honey'], 'id': 564, 'def': 'a sweet yellow liquid produced by bees', 'name': 'honey'}, {'frequency': 'f', 'synset': 'hood.n.06', 'synonyms': ['fume_hood', 'exhaust_hood'], 'id': 565, 'def': 'metal covering leading to a vent that exhausts smoke or fumes', 'name': 'fume_hood'}, {'frequency': 'f', 'synset': 'hook.n.05', 'synonyms': ['hook'], 'id': 566, 'def': 'a curved or bent implement for suspending or pulling something', 'name': 'hook'}, {'frequency': 'r', 'synset': 'hookah.n.01', 'synonyms': ['hookah', 'narghile', 'nargileh', 'sheesha', 'shisha', 'water_pipe'], 'id': 567, 'def': 'a tobacco pipe with a long flexible tube connected to a container where the smoke is cooled by passing through water', 'name': 'hookah'}, {'frequency': 'r', 'synset': 'hornet.n.01', 'synonyms': ['hornet'], 'id': 568, 'def': 'large stinging wasp', 'name': 'hornet'}, {'frequency': 'f', 'synset': 'horse.n.01', 'synonyms': ['horse'], 'id': 569, 'def': 'a common horse', 'name': 'horse'}, {'frequency': 'f', 'synset': 'hose.n.03', 'synonyms': ['hose', 'hosepipe'], 'id': 570, 'def': 'a flexible pipe for conveying a liquid or gas', 'name': 'hose'}, {'frequency': 'r', 'synset': 'hot-air_balloon.n.01', 'synonyms': ['hot-air_balloon'], 'id': 571, 'def': 'balloon for travel through the air in a basket suspended below a large bag of heated air', 'name': 'hot-air_balloon'}, {'frequency': 'r', 'synset': 'hot_plate.n.01', 'synonyms': ['hotplate'], 'id': 572, 'def': 'a portable electric appliance for heating or cooking or keeping food warm', 'name': 'hotplate'}, {'frequency': 'c', 'synset': 'hot_sauce.n.01', 'synonyms': ['hot_sauce'], 'id': 573, 'def': 'a pungent peppery sauce', 'name': 'hot_sauce'}, {'frequency': 'r', 'synset': 'hourglass.n.01', 'synonyms': ['hourglass'], 'id': 574, 'def': 'a sandglass timer that runs for sixty minutes', 'name': 'hourglass'}, {'frequency': 'r', 'synset': 'houseboat.n.01', 'synonyms': ['houseboat'], 'id': 575, 'def': 'a barge that is designed and equipped for use as a dwelling', 'name': 'houseboat'}, {'frequency': 'c', 'synset': 'hummingbird.n.01', 'synonyms': ['hummingbird'], 'id': 576, 'def': 'tiny American bird having brilliant iridescent plumage and long slender bills', 'name': 'hummingbird'}, {'frequency': 'r', 'synset': 'hummus.n.01', 'synonyms': ['hummus', 'humus', 'hommos', 'hoummos', 'humous'], 'id': 577, 'def': 'a thick spread made from mashed chickpeas', 'name': 'hummus'}, {'frequency': 'f', 'synset': 'ice_bear.n.01', 'synonyms': ['polar_bear'], 'id': 578, 'def': 'white bear of Arctic regions', 'name': 'polar_bear'}, {'frequency': 'c', 'synset': 'ice_cream.n.01', 'synonyms': ['icecream'], 'id': 579, 'def': 'frozen dessert containing cream and sugar and flavoring', 'name': 'icecream'}, {'frequency': 'r', 'synset': 'ice_lolly.n.01', 'synonyms': ['popsicle'], 'id': 580, 'def': 'ice cream or water ice on a small wooden stick', 'name': 'popsicle'}, {'frequency': 'c', 'synset': 'ice_maker.n.01', 'synonyms': ['ice_maker'], 'id': 581, 'def': 'an appliance included in some electric refrigerators for making ice cubes', 'name': 'ice_maker'}, {'frequency': 'r', 'synset': 'ice_pack.n.01', 'synonyms': ['ice_pack', 'ice_bag'], 'id': 582, 'def': 'a waterproof bag filled with ice: applied to the body (especially the head) to cool or reduce swelling', 'name': 'ice_pack'}, {'frequency': 'r', 'synset': 'ice_skate.n.01', 'synonyms': ['ice_skate'], 'id': 583, 'def': 'skate consisting of a boot with a steel blade fitted to the sole', 'name': 'ice_skate'}, {'frequency': 'c', 'synset': 'igniter.n.01', 'synonyms': ['igniter', 'ignitor', 'lighter'], 'id': 584, 'def': 'a substance or device used to start a fire', 'name': 'igniter'}, {'frequency': 'r', 'synset': 'inhaler.n.01', 'synonyms': ['inhaler', 'inhalator'], 'id': 585, 'def': 'a dispenser that produces a chemical vapor to be inhaled through mouth or nose', 'name': 'inhaler'}, {'frequency': 'f', 'synset': 'ipod.n.01', 'synonyms': ['iPod'], 'id': 586, 'def': 'a pocket-sized device used to play music files', 'name': 'iPod'}, {'frequency': 'c', 'synset': 'iron.n.04', 'synonyms': ['iron_(for_clothing)', 'smoothing_iron_(for_clothing)'], 'id': 587, 'def': 'home appliance consisting of a flat metal base that is heated and used to smooth cloth', 'name': 'iron_(for_clothing)'}, {'frequency': 'c', 'synset': 'ironing_board.n.01', 'synonyms': ['ironing_board'], 'id': 588, 'def': 'narrow padded board on collapsible supports; used for ironing clothes', 'name': 'ironing_board'}, {'frequency': 'f', 'synset': 'jacket.n.01', 'synonyms': ['jacket'], 'id': 589, 'def': 'a waist-length coat', 'name': 'jacket'}, {'frequency': 'c', 'synset': 'jam.n.01', 'synonyms': ['jam'], 'id': 590, 'def': 'preserve of crushed fruit', 'name': 'jam'}, {'frequency': 'f', 'synset': 'jar.n.01', 'synonyms': ['jar'], 'id': 591, 'def': 'a vessel (usually cylindrical) with a wide mouth and without handles', 'name': 'jar'}, {'frequency': 'f', 'synset': 'jean.n.01', 'synonyms': ['jean', 'blue_jean', 'denim'], 'id': 592, 'def': '(usually plural) close-fitting trousers of heavy denim for manual work or casual wear', 'name': 'jean'}, {'frequency': 'c', 'synset': 'jeep.n.01', 'synonyms': ['jeep', 'landrover'], 'id': 593, 'def': 'a car suitable for traveling over rough terrain', 'name': 'jeep'}, {'frequency': 'r', 'synset': 'jelly_bean.n.01', 'synonyms': ['jelly_bean', 'jelly_egg'], 'id': 594, 'def': 'sugar-glazed jellied candy', 'name': 'jelly_bean'}, {'frequency': 'f', 'synset': 'jersey.n.03', 'synonyms': ['jersey', 'T-shirt', 'tee_shirt'], 'id': 595, 'def': 'a close-fitting pullover shirt', 'name': 'jersey'}, {'frequency': 'c', 'synset': 'jet.n.01', 'synonyms': ['jet_plane', 'jet-propelled_plane'], 'id': 596, 'def': 'an airplane powered by one or more jet engines', 'name': 'jet_plane'}, {'frequency': 'r', 'synset': 'jewel.n.01', 'synonyms': ['jewel', 'gem', 'precious_stone'], 'id': 597, 'def': 'a precious or semiprecious stone incorporated into a piece of jewelry', 'name': 'jewel'}, {'frequency': 'c', 'synset': 'jewelry.n.01', 'synonyms': ['jewelry', 'jewellery'], 'id': 598, 'def': 'an adornment (as a bracelet or ring or necklace) made of precious metals and set with gems (or imitation gems)', 'name': 'jewelry'}, {'frequency': 'r', 'synset': 'joystick.n.02', 'synonyms': ['joystick'], 'id': 599, 'def': 'a control device for computers consisting of a vertical handle that can move freely in two directions', 'name': 'joystick'}, {'frequency': 'c', 'synset': 'jump_suit.n.01', 'synonyms': ['jumpsuit'], 'id': 600, 'def': "one-piece garment fashioned after a parachutist's uniform", 'name': 'jumpsuit'}, {'frequency': 'c', 'synset': 'kayak.n.01', 'synonyms': ['kayak'], 'id': 601, 'def': 'a small canoe consisting of a light frame made watertight with animal skins', 'name': 'kayak'}, {'frequency': 'r', 'synset': 'keg.n.02', 'synonyms': ['keg'], 'id': 602, 'def': 'small cask or barrel', 'name': 'keg'}, {'frequency': 'r', 'synset': 'kennel.n.01', 'synonyms': ['kennel', 'doghouse'], 'id': 603, 'def': 'outbuilding that serves as a shelter for a dog', 'name': 'kennel'}, {'frequency': 'c', 'synset': 'kettle.n.01', 'synonyms': ['kettle', 'boiler'], 'id': 604, 'def': 'a metal pot for stewing or boiling; usually has a lid', 'name': 'kettle'}, {'frequency': 'f', 'synset': 'key.n.01', 'synonyms': ['key'], 'id': 605, 'def': 'metal instrument used to unlock a lock', 'name': 'key'}, {'frequency': 'r', 'synset': 'keycard.n.01', 'synonyms': ['keycard'], 'id': 606, 'def': 'a plastic card used to gain access typically to a door', 'name': 'keycard'}, {'frequency': 'c', 'synset': 'kilt.n.01', 'synonyms': ['kilt'], 'id': 607, 'def': 'a knee-length pleated tartan skirt worn by men as part of the traditional dress in the Highlands of northern Scotland', 'name': 'kilt'}, {'frequency': 'c', 'synset': 'kimono.n.01', 'synonyms': ['kimono'], 'id': 608, 'def': 'a loose robe; imitated from robes originally worn by Japanese', 'name': 'kimono'}, {'frequency': 'f', 'synset': 'kitchen_sink.n.01', 'synonyms': ['kitchen_sink'], 'id': 609, 'def': 'a sink in a kitchen', 'name': 'kitchen_sink'}, {'frequency': 'r', 'synset': 'kitchen_table.n.01', 'synonyms': ['kitchen_table'], 'id': 610, 'def': 'a table in the kitchen', 'name': 'kitchen_table'}, {'frequency': 'f', 'synset': 'kite.n.03', 'synonyms': ['kite'], 'id': 611, 'def': 'plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string', 'name': 'kite'}, {'frequency': 'c', 'synset': 'kitten.n.01', 'synonyms': ['kitten', 'kitty'], 'id': 612, 'def': 'young domestic cat', 'name': 'kitten'}, {'frequency': 'c', 'synset': 'kiwi.n.03', 'synonyms': ['kiwi_fruit'], 'id': 613, 'def': 'fuzzy brown egg-shaped fruit with slightly tart green flesh', 'name': 'kiwi_fruit'}, {'frequency': 'f', 'synset': 'knee_pad.n.01', 'synonyms': ['knee_pad'], 'id': 614, 'def': 'protective garment consisting of a pad worn by football or baseball or hockey players', 'name': 'knee_pad'}, {'frequency': 'f', 'synset': 'knife.n.01', 'synonyms': ['knife'], 'id': 615, 'def': 'tool with a blade and point used as a cutting instrument', 'name': 'knife'}, {'frequency': 'r', 'synset': 'knitting_needle.n.01', 'synonyms': ['knitting_needle'], 'id': 616, 'def': 'needle consisting of a slender rod with pointed ends; usually used in pairs', 'name': 'knitting_needle'}, {'frequency': 'f', 'synset': 'knob.n.02', 'synonyms': ['knob'], 'id': 617, 'def': 'a round handle often found on a door', 'name': 'knob'}, {'frequency': 'r', 'synset': 'knocker.n.05', 'synonyms': ['knocker_(on_a_door)', 'doorknocker'], 'id': 618, 'def': 'a device (usually metal and ornamental) attached by a hinge to a door', 'name': 'knocker_(on_a_door)'}, {'frequency': 'r', 'synset': 'koala.n.01', 'synonyms': ['koala', 'koala_bear'], 'id': 619, 'def': 'sluggish tailless Australian marsupial with grey furry ears and coat', 'name': 'koala'}, {'frequency': 'r', 'synset': 'lab_coat.n.01', 'synonyms': ['lab_coat', 'laboratory_coat'], 'id': 620, 'def': 'a light coat worn to protect clothing from substances used while working in a laboratory', 'name': 'lab_coat'}, {'frequency': 'f', 'synset': 'ladder.n.01', 'synonyms': ['ladder'], 'id': 621, 'def': 'steps consisting of two parallel members connected by rungs', 'name': 'ladder'}, {'frequency': 'c', 'synset': 'ladle.n.01', 'synonyms': ['ladle'], 'id': 622, 'def': 'a spoon-shaped vessel with a long handle frequently used to transfer liquids', 'name': 'ladle'}, {'frequency': 'c', 'synset': 'ladybug.n.01', 'synonyms': ['ladybug', 'ladybeetle', 'ladybird_beetle'], 'id': 623, 'def': 'small round bright-colored and spotted beetle, typically red and black', 'name': 'ladybug'}, {'frequency': 'f', 'synset': 'lamb.n.01', 'synonyms': ['lamb_(animal)'], 'id': 624, 'def': 'young sheep', 'name': 'lamb_(animal)'}, {'frequency': 'r', 'synset': 'lamb_chop.n.01', 'synonyms': ['lamb-chop', 'lambchop'], 'id': 625, 'def': 'chop cut from a lamb', 'name': 'lamb-chop'}, {'frequency': 'f', 'synset': 'lamp.n.02', 'synonyms': ['lamp'], 'id': 626, 'def': 'a piece of furniture holding one or more electric light bulbs', 'name': 'lamp'}, {'frequency': 'f', 'synset': 'lamppost.n.01', 'synonyms': ['lamppost'], 'id': 627, 'def': 'a metal post supporting an outdoor lamp (such as a streetlight)', 'name': 'lamppost'}, {'frequency': 'f', 'synset': 'lampshade.n.01', 'synonyms': ['lampshade'], 'id': 628, 'def': 'a protective ornamental shade used to screen a light bulb from direct view', 'name': 'lampshade'}, {'frequency': 'c', 'synset': 'lantern.n.01', 'synonyms': ['lantern'], 'id': 629, 'def': 'light in a transparent protective case', 'name': 'lantern'}, {'frequency': 'f', 'synset': 'lanyard.n.02', 'synonyms': ['lanyard', 'laniard'], 'id': 630, 'def': 'a cord worn around the neck to hold a knife or whistle, etc.', 'name': 'lanyard'}, {'frequency': 'f', 'synset': 'laptop.n.01', 'synonyms': ['laptop_computer', 'notebook_computer'], 'id': 631, 'def': 'a portable computer small enough to use in your lap', 'name': 'laptop_computer'}, {'frequency': 'r', 'synset': 'lasagna.n.01', 'synonyms': ['lasagna', 'lasagne'], 'id': 632, 'def': 'baked dish of layers of lasagna pasta with sauce and cheese and meat or vegetables', 'name': 'lasagna'}, {'frequency': 'f', 'synset': 'latch.n.02', 'synonyms': ['latch'], 'id': 633, 'def': 'a bar that can be lowered or slid into a groove to fasten a door or gate', 'name': 'latch'}, {'frequency': 'r', 'synset': 'lawn_mower.n.01', 'synonyms': ['lawn_mower'], 'id': 634, 'def': 'garden tool for mowing grass on lawns', 'name': 'lawn_mower'}, {'frequency': 'r', 'synset': 'leather.n.01', 'synonyms': ['leather'], 'id': 635, 'def': 'an animal skin made smooth and flexible by removing the hair and then tanning', 'name': 'leather'}, {'frequency': 'c', 'synset': 'legging.n.01', 'synonyms': ['legging_(clothing)', 'leging_(clothing)', 'leg_covering'], 'id': 636, 'def': 'a garment covering the leg (usually extending from the knee to the ankle)', 'name': 'legging_(clothing)'}, {'frequency': 'c', 'synset': 'lego.n.01', 'synonyms': ['Lego', 'Lego_set'], 'id': 637, 'def': "a child's plastic construction set for making models from blocks", 'name': 'Lego'}, {'frequency': 'r', 'synset': 'legume.n.02', 'synonyms': ['legume'], 'id': 638, 'def': 'the fruit or seed of bean or pea plants', 'name': 'legume'}, {'frequency': 'f', 'synset': 'lemon.n.01', 'synonyms': ['lemon'], 'id': 639, 'def': 'yellow oval fruit with juicy acidic flesh', 'name': 'lemon'}, {'frequency': 'r', 'synset': 'lemonade.n.01', 'synonyms': ['lemonade'], 'id': 640, 'def': 'sweetened beverage of diluted lemon juice', 'name': 'lemonade'}, {'frequency': 'f', 'synset': 'lettuce.n.02', 'synonyms': ['lettuce'], 'id': 641, 'def': 'leafy plant commonly eaten in salad or on sandwiches', 'name': 'lettuce'}, {'frequency': 'f', 'synset': 'license_plate.n.01', 'synonyms': ['license_plate', 'numberplate'], 'id': 642, 'def': "a plate mounted on the front and back of car and bearing the car's registration number", 'name': 'license_plate'}, {'frequency': 'f', 'synset': 'life_buoy.n.01', 'synonyms': ['life_buoy', 'lifesaver', 'life_belt', 'life_ring'], 'id': 643, 'def': 'a ring-shaped life preserver used to prevent drowning (NOT a life-jacket or vest)', 'name': 'life_buoy'}, {'frequency': 'f', 'synset': 'life_jacket.n.01', 'synonyms': ['life_jacket', 'life_vest'], 'id': 644, 'def': 'life preserver consisting of a sleeveless jacket of buoyant or inflatable design', 'name': 'life_jacket'}, {'frequency': 'f', 'synset': 'light_bulb.n.01', 'synonyms': ['lightbulb'], 'id': 645, 'def': 'lightblub/source of light', 'name': 'lightbulb'}, {'frequency': 'r', 'synset': 'lightning_rod.n.02', 'synonyms': ['lightning_rod', 'lightning_conductor'], 'id': 646, 'def': 'a metallic conductor that is attached to a high point and leads to the ground', 'name': 'lightning_rod'}, {'frequency': 'f', 'synset': 'lime.n.06', 'synonyms': ['lime'], 'id': 647, 'def': 'the green acidic fruit of any of various lime trees', 'name': 'lime'}, {'frequency': 'r', 'synset': 'limousine.n.01', 'synonyms': ['limousine'], 'id': 648, 'def': 'long luxurious car; usually driven by a chauffeur', 'name': 'limousine'}, {'frequency': 'c', 'synset': 'lion.n.01', 'synonyms': ['lion'], 'id': 649, 'def': 'large gregarious predatory cat of Africa and India', 'name': 'lion'}, {'frequency': 'c', 'synset': 'lip_balm.n.01', 'synonyms': ['lip_balm'], 'id': 650, 'def': 'a balm applied to the lips', 'name': 'lip_balm'}, {'frequency': 'r', 'synset': 'liquor.n.01', 'synonyms': ['liquor', 'spirits', 'hard_liquor', 'liqueur', 'cordial'], 'id': 651, 'def': 'liquor or beer', 'name': 'liquor'}, {'frequency': 'c', 'synset': 'lizard.n.01', 'synonyms': ['lizard'], 'id': 652, 'def': 'a reptile with usually two pairs of legs and a tapering tail', 'name': 'lizard'}, {'frequency': 'f', 'synset': 'log.n.01', 'synonyms': ['log'], 'id': 653, 'def': 'a segment of the trunk of a tree when stripped of branches', 'name': 'log'}, {'frequency': 'c', 'synset': 'lollipop.n.02', 'synonyms': ['lollipop'], 'id': 654, 'def': 'hard candy on a stick', 'name': 'lollipop'}, {'frequency': 'f', 'synset': 'loudspeaker.n.01', 'synonyms': ['speaker_(stero_equipment)'], 'id': 655, 'def': 'electronic device that produces sound often as part of a stereo system', 'name': 'speaker_(stero_equipment)'}, {'frequency': 'c', 'synset': 'love_seat.n.01', 'synonyms': ['loveseat'], 'id': 656, 'def': 'small sofa that seats two people', 'name': 'loveseat'}, {'frequency': 'r', 'synset': 'machine_gun.n.01', 'synonyms': ['machine_gun'], 'id': 657, 'def': 'a rapidly firing automatic gun', 'name': 'machine_gun'}, {'frequency': 'f', 'synset': 'magazine.n.02', 'synonyms': ['magazine'], 'id': 658, 'def': 'a paperback periodic publication', 'name': 'magazine'}, {'frequency': 'f', 'synset': 'magnet.n.01', 'synonyms': ['magnet'], 'id': 659, 'def': 'a device that attracts iron and produces a magnetic field', 'name': 'magnet'}, {'frequency': 'c', 'synset': 'mail_slot.n.01', 'synonyms': ['mail_slot'], 'id': 660, 'def': 'a slot (usually in a door) through which mail can be delivered', 'name': 'mail_slot'}, {'frequency': 'f', 'synset': 'mailbox.n.01', 'synonyms': ['mailbox_(at_home)', 'letter_box_(at_home)'], 'id': 661, 'def': 'a private box for delivery of mail', 'name': 'mailbox_(at_home)'}, {'frequency': 'r', 'synset': 'mallard.n.01', 'synonyms': ['mallard'], 'id': 662, 'def': 'wild dabbling duck from which domestic ducks are descended', 'name': 'mallard'}, {'frequency': 'r', 'synset': 'mallet.n.01', 'synonyms': ['mallet'], 'id': 663, 'def': 'a sports implement with a long handle and a hammer-like head used to hit a ball', 'name': 'mallet'}, {'frequency': 'r', 'synset': 'mammoth.n.01', 'synonyms': ['mammoth'], 'id': 664, 'def': 'any of numerous extinct elephants widely distributed in the Pleistocene', 'name': 'mammoth'}, {'frequency': 'r', 'synset': 'manatee.n.01', 'synonyms': ['manatee'], 'id': 665, 'def': 'sirenian mammal of tropical coastal waters of America', 'name': 'manatee'}, {'frequency': 'c', 'synset': 'mandarin.n.05', 'synonyms': ['mandarin_orange'], 'id': 666, 'def': 'a somewhat flat reddish-orange loose skinned citrus of China', 'name': 'mandarin_orange'}, {'frequency': 'c', 'synset': 'manger.n.01', 'synonyms': ['manger', 'trough'], 'id': 667, 'def': 'a container (usually in a barn or stable) from which cattle or horses feed', 'name': 'manger'}, {'frequency': 'f', 'synset': 'manhole.n.01', 'synonyms': ['manhole'], 'id': 668, 'def': 'a hole (usually with a flush cover) through which a person can gain access to an underground structure', 'name': 'manhole'}, {'frequency': 'f', 'synset': 'map.n.01', 'synonyms': ['map'], 'id': 669, 'def': "a diagrammatic representation of the earth's surface (or part of it)", 'name': 'map'}, {'frequency': 'f', 'synset': 'marker.n.03', 'synonyms': ['marker'], 'id': 670, 'def': 'a writing implement for making a mark', 'name': 'marker'}, {'frequency': 'r', 'synset': 'martini.n.01', 'synonyms': ['martini'], 'id': 671, 'def': 'a cocktail made of gin (or vodka) with dry vermouth', 'name': 'martini'}, {'frequency': 'r', 'synset': 'mascot.n.01', 'synonyms': ['mascot'], 'id': 672, 'def': 'a person or animal that is adopted by a team or other group as a symbolic figure', 'name': 'mascot'}, {'frequency': 'c', 'synset': 'mashed_potato.n.01', 'synonyms': ['mashed_potato'], 'id': 673, 'def': 'potato that has been peeled and boiled and then mashed', 'name': 'mashed_potato'}, {'frequency': 'r', 'synset': 'masher.n.02', 'synonyms': ['masher'], 'id': 674, 'def': 'a kitchen utensil used for mashing (e.g. potatoes)', 'name': 'masher'}, {'frequency': 'f', 'synset': 'mask.n.04', 'synonyms': ['mask', 'facemask'], 'id': 675, 'def': 'a protective covering worn over the face', 'name': 'mask'}, {'frequency': 'f', 'synset': 'mast.n.01', 'synonyms': ['mast'], 'id': 676, 'def': 'a vertical spar for supporting sails', 'name': 'mast'}, {'frequency': 'c', 'synset': 'mat.n.03', 'synonyms': ['mat_(gym_equipment)', 'gym_mat'], 'id': 677, 'def': 'sports equipment consisting of a piece of thick padding on the floor for gymnastics', 'name': 'mat_(gym_equipment)'}, {'frequency': 'r', 'synset': 'matchbox.n.01', 'synonyms': ['matchbox'], 'id': 678, 'def': 'a box for holding matches', 'name': 'matchbox'}, {'frequency': 'f', 'synset': 'mattress.n.01', 'synonyms': ['mattress'], 'id': 679, 'def': 'a thick pad filled with resilient material used as a bed or part of a bed', 'name': 'mattress'}, {'frequency': 'c', 'synset': 'measuring_cup.n.01', 'synonyms': ['measuring_cup'], 'id': 680, 'def': 'graduated cup used to measure liquid or granular ingredients', 'name': 'measuring_cup'}, {'frequency': 'c', 'synset': 'measuring_stick.n.01', 'synonyms': ['measuring_stick', 'ruler_(measuring_stick)', 'measuring_rod'], 'id': 681, 'def': 'measuring instrument having a sequence of marks at regular intervals', 'name': 'measuring_stick'}, {'frequency': 'c', 'synset': 'meatball.n.01', 'synonyms': ['meatball'], 'id': 682, 'def': 'ground meat formed into a ball and fried or simmered in broth', 'name': 'meatball'}, {'frequency': 'c', 'synset': 'medicine.n.02', 'synonyms': ['medicine'], 'id': 683, 'def': 'something that treats or prevents or alleviates the symptoms of disease', 'name': 'medicine'}, {'frequency': 'c', 'synset': 'melon.n.01', 'synonyms': ['melon'], 'id': 684, 'def': 'fruit of the gourd family having a hard rind and sweet juicy flesh', 'name': 'melon'}, {'frequency': 'f', 'synset': 'microphone.n.01', 'synonyms': ['microphone'], 'id': 685, 'def': 'device for converting sound waves into electrical energy', 'name': 'microphone'}, {'frequency': 'r', 'synset': 'microscope.n.01', 'synonyms': ['microscope'], 'id': 686, 'def': 'magnifier of the image of small objects', 'name': 'microscope'}, {'frequency': 'f', 'synset': 'microwave.n.02', 'synonyms': ['microwave_oven'], 'id': 687, 'def': 'kitchen appliance that cooks food by passing an electromagnetic wave through it', 'name': 'microwave_oven'}, {'frequency': 'r', 'synset': 'milestone.n.01', 'synonyms': ['milestone', 'milepost'], 'id': 688, 'def': 'stone post at side of a road to show distances', 'name': 'milestone'}, {'frequency': 'f', 'synset': 'milk.n.01', 'synonyms': ['milk'], 'id': 689, 'def': 'a white nutritious liquid secreted by mammals and used as food by human beings', 'name': 'milk'}, {'frequency': 'r', 'synset': 'milk_can.n.01', 'synonyms': ['milk_can'], 'id': 690, 'def': 'can for transporting milk', 'name': 'milk_can'}, {'frequency': 'r', 'synset': 'milkshake.n.01', 'synonyms': ['milkshake'], 'id': 691, 'def': 'frothy drink of milk and flavoring and sometimes fruit or ice cream', 'name': 'milkshake'}, {'frequency': 'f', 'synset': 'minivan.n.01', 'synonyms': ['minivan'], 'id': 692, 'def': 'a small box-shaped passenger van', 'name': 'minivan'}, {'frequency': 'r', 'synset': 'mint.n.05', 'synonyms': ['mint_candy'], 'id': 693, 'def': 'a candy that is flavored with a mint oil', 'name': 'mint_candy'}, {'frequency': 'f', 'synset': 'mirror.n.01', 'synonyms': ['mirror'], 'id': 694, 'def': 'polished surface that forms images by reflecting light', 'name': 'mirror'}, {'frequency': 'c', 'synset': 'mitten.n.01', 'synonyms': ['mitten'], 'id': 695, 'def': 'glove that encases the thumb separately and the other four fingers together', 'name': 'mitten'}, {'frequency': 'c', 'synset': 'mixer.n.04', 'synonyms': ['mixer_(kitchen_tool)', 'stand_mixer'], 'id': 696, 'def': 'a kitchen utensil that is used for mixing foods', 'name': 'mixer_(kitchen_tool)'}, {'frequency': 'c', 'synset': 'money.n.03', 'synonyms': ['money'], 'id': 697, 'def': 'the official currency issued by a government or national bank', 'name': 'money'}, {'frequency': 'f', 'synset': 'monitor.n.04', 'synonyms': ['monitor_(computer_equipment) computer_monitor'], 'id': 698, 'def': 'a computer monitor', 'name': 'monitor_(computer_equipment) computer_monitor'}, {'frequency': 'c', 'synset': 'monkey.n.01', 'synonyms': ['monkey'], 'id': 699, 'def': 'any of various long-tailed primates', 'name': 'monkey'}, {'frequency': 'f', 'synset': 'motor.n.01', 'synonyms': ['motor'], 'id': 700, 'def': 'machine that converts other forms of energy into mechanical energy and so imparts motion', 'name': 'motor'}, {'frequency': 'f', 'synset': 'motor_scooter.n.01', 'synonyms': ['motor_scooter', 'scooter'], 'id': 701, 'def': 'a wheeled vehicle with small wheels and a low-powered engine', 'name': 'motor_scooter'}, {'frequency': 'r', 'synset': 'motor_vehicle.n.01', 'synonyms': ['motor_vehicle', 'automotive_vehicle'], 'id': 702, 'def': 'a self-propelled wheeled vehicle that does not run on rails', 'name': 'motor_vehicle'}, {'frequency': 'f', 'synset': 'motorcycle.n.01', 'synonyms': ['motorcycle'], 'id': 703, 'def': 'a motor vehicle with two wheels and a strong frame', 'name': 'motorcycle'}, {'frequency': 'f', 'synset': 'mound.n.01', 'synonyms': ['mound_(baseball)', "pitcher's_mound"], 'id': 704, 'def': '(baseball) the slight elevation on which the pitcher stands', 'name': 'mound_(baseball)'}, {'frequency': 'f', 'synset': 'mouse.n.04', 'synonyms': ['mouse_(computer_equipment)', 'computer_mouse'], 'id': 705, 'def': 'a computer input device that controls an on-screen pointer (does not include trackpads / touchpads)', 'name': 'mouse_(computer_equipment)'}, {'frequency': 'f', 'synset': 'mousepad.n.01', 'synonyms': ['mousepad'], 'id': 706, 'def': 'a small portable pad that provides an operating surface for a computer mouse', 'name': 'mousepad'}, {'frequency': 'c', 'synset': 'muffin.n.01', 'synonyms': ['muffin'], 'id': 707, 'def': 'a sweet quick bread baked in a cup-shaped pan', 'name': 'muffin'}, {'frequency': 'f', 'synset': 'mug.n.04', 'synonyms': ['mug'], 'id': 708, 'def': 'with handle and usually cylindrical', 'name': 'mug'}, {'frequency': 'f', 'synset': 'mushroom.n.02', 'synonyms': ['mushroom'], 'id': 709, 'def': 'a common mushroom', 'name': 'mushroom'}, {'frequency': 'r', 'synset': 'music_stool.n.01', 'synonyms': ['music_stool', 'piano_stool'], 'id': 710, 'def': 'a stool for piano players; usually adjustable in height', 'name': 'music_stool'}, {'frequency': 'c', 'synset': 'musical_instrument.n.01', 'synonyms': ['musical_instrument', 'instrument_(musical)'], 'id': 711, 'def': 'any of various devices or contrivances that can be used to produce musical tones or sounds', 'name': 'musical_instrument'}, {'frequency': 'r', 'synset': 'nailfile.n.01', 'synonyms': ['nailfile'], 'id': 712, 'def': 'a small flat file for shaping the nails', 'name': 'nailfile'}, {'frequency': 'f', 'synset': 'napkin.n.01', 'synonyms': ['napkin', 'table_napkin', 'serviette'], 'id': 713, 'def': 'a small piece of table linen or paper that is used to wipe the mouth and to cover the lap in order to protect clothing', 'name': 'napkin'}, {'frequency': 'r', 'synset': 'neckerchief.n.01', 'synonyms': ['neckerchief'], 'id': 714, 'def': 'a kerchief worn around the neck', 'name': 'neckerchief'}, {'frequency': 'f', 'synset': 'necklace.n.01', 'synonyms': ['necklace'], 'id': 715, 'def': 'jewelry consisting of a cord or chain (often bearing gems) worn about the neck as an ornament', 'name': 'necklace'}, {'frequency': 'f', 'synset': 'necktie.n.01', 'synonyms': ['necktie', 'tie_(necktie)'], 'id': 716, 'def': 'neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front', 'name': 'necktie'}, {'frequency': 'c', 'synset': 'needle.n.03', 'synonyms': ['needle'], 'id': 717, 'def': 'a sharp pointed implement (usually metal)', 'name': 'needle'}, {'frequency': 'c', 'synset': 'nest.n.01', 'synonyms': ['nest'], 'id': 718, 'def': 'a structure in which animals lay eggs or give birth to their young', 'name': 'nest'}, {'frequency': 'f', 'synset': 'newspaper.n.01', 'synonyms': ['newspaper', 'paper_(newspaper)'], 'id': 719, 'def': 'a daily or weekly publication on folded sheets containing news, articles, and advertisements', 'name': 'newspaper'}, {'frequency': 'c', 'synset': 'newsstand.n.01', 'synonyms': ['newsstand'], 'id': 720, 'def': 'a stall where newspapers and other periodicals are sold', 'name': 'newsstand'}, {'frequency': 'c', 'synset': 'nightwear.n.01', 'synonyms': ['nightshirt', 'nightwear', 'sleepwear', 'nightclothes'], 'id': 721, 'def': 'garments designed to be worn in bed', 'name': 'nightshirt'}, {'frequency': 'r', 'synset': 'nosebag.n.01', 'synonyms': ['nosebag_(for_animals)', 'feedbag'], 'id': 722, 'def': 'a canvas bag that is used to feed an animal (such as a horse); covers the muzzle and fastens at the top of the head', 'name': 'nosebag_(for_animals)'}, {'frequency': 'c', 'synset': 'noseband.n.01', 'synonyms': ['noseband_(for_animals)', 'nosepiece_(for_animals)'], 'id': 723, 'def': "a strap that is the part of a bridle that goes over the animal's nose", 'name': 'noseband_(for_animals)'}, {'frequency': 'f', 'synset': 'notebook.n.01', 'synonyms': ['notebook'], 'id': 724, 'def': 'a book with blank pages for recording notes or memoranda', 'name': 'notebook'}, {'frequency': 'c', 'synset': 'notepad.n.01', 'synonyms': ['notepad'], 'id': 725, 'def': 'a pad of paper for keeping notes', 'name': 'notepad'}, {'frequency': 'f', 'synset': 'nut.n.03', 'synonyms': ['nut'], 'id': 726, 'def': 'a small metal block (usually square or hexagonal) with internal screw thread to be fitted onto a bolt', 'name': 'nut'}, {'frequency': 'r', 'synset': 'nutcracker.n.01', 'synonyms': ['nutcracker'], 'id': 727, 'def': 'a hand tool used to crack nuts open', 'name': 'nutcracker'}, {'frequency': 'f', 'synset': 'oar.n.01', 'synonyms': ['oar'], 'id': 728, 'def': 'an implement used to propel or steer a boat', 'name': 'oar'}, {'frequency': 'r', 'synset': 'octopus.n.01', 'synonyms': ['octopus_(food)'], 'id': 729, 'def': 'tentacles of octopus prepared as food', 'name': 'octopus_(food)'}, {'frequency': 'r', 'synset': 'octopus.n.02', 'synonyms': ['octopus_(animal)'], 'id': 730, 'def': 'bottom-living cephalopod having a soft oval body with eight long tentacles', 'name': 'octopus_(animal)'}, {'frequency': 'c', 'synset': 'oil_lamp.n.01', 'synonyms': ['oil_lamp', 'kerosene_lamp', 'kerosine_lamp'], 'id': 731, 'def': 'a lamp that burns oil (as kerosine) for light', 'name': 'oil_lamp'}, {'frequency': 'c', 'synset': 'olive_oil.n.01', 'synonyms': ['olive_oil'], 'id': 732, 'def': 'oil from olives', 'name': 'olive_oil'}, {'frequency': 'r', 'synset': 'omelet.n.01', 'synonyms': ['omelet', 'omelette'], 'id': 733, 'def': 'beaten eggs cooked until just set; may be folded around e.g. ham or cheese or jelly', 'name': 'omelet'}, {'frequency': 'f', 'synset': 'onion.n.01', 'synonyms': ['onion'], 'id': 734, 'def': 'the bulb of an onion plant', 'name': 'onion'}, {'frequency': 'f', 'synset': 'orange.n.01', 'synonyms': ['orange_(fruit)'], 'id': 735, 'def': 'orange (FRUIT of an orange tree)', 'name': 'orange_(fruit)'}, {'frequency': 'c', 'synset': 'orange_juice.n.01', 'synonyms': ['orange_juice'], 'id': 736, 'def': 'bottled or freshly squeezed juice of oranges', 'name': 'orange_juice'}, {'frequency': 'c', 'synset': 'ostrich.n.02', 'synonyms': ['ostrich'], 'id': 737, 'def': 'fast-running African flightless bird with two-toed feet; largest living bird', 'name': 'ostrich'}, {'frequency': 'f', 'synset': 'ottoman.n.03', 'synonyms': ['ottoman', 'pouf', 'pouffe', 'hassock'], 'id': 738, 'def': 'a thick standalone cushion used as a seat or footrest, often next to a chair', 'name': 'ottoman'}, {'frequency': 'f', 'synset': 'oven.n.01', 'synonyms': ['oven'], 'id': 739, 'def': 'kitchen appliance used for baking or roasting', 'name': 'oven'}, {'frequency': 'c', 'synset': 'overall.n.01', 'synonyms': ['overalls_(clothing)'], 'id': 740, 'def': 'work clothing consisting of denim trousers usually with a bib and shoulder straps', 'name': 'overalls_(clothing)'}, {'frequency': 'c', 'synset': 'owl.n.01', 'synonyms': ['owl'], 'id': 741, 'def': 'nocturnal bird of prey with hawk-like beak and claws and large head with front-facing eyes', 'name': 'owl'}, {'frequency': 'c', 'synset': 'packet.n.03', 'synonyms': ['packet'], 'id': 742, 'def': 'a small package or bundle', 'name': 'packet'}, {'frequency': 'r', 'synset': 'pad.n.03', 'synonyms': ['inkpad', 'inking_pad', 'stamp_pad'], 'id': 743, 'def': 'absorbent material saturated with ink used to transfer ink evenly to a rubber stamp', 'name': 'inkpad'}, {'frequency': 'c', 'synset': 'pad.n.04', 'synonyms': ['pad'], 'id': 744, 'def': 'mostly arm/knee pads labeled', 'name': 'pad'}, {'frequency': 'f', 'synset': 'paddle.n.04', 'synonyms': ['paddle', 'boat_paddle'], 'id': 745, 'def': 'a short light oar used without an oarlock to propel a canoe or small boat', 'name': 'paddle'}, {'frequency': 'c', 'synset': 'padlock.n.01', 'synonyms': ['padlock'], 'id': 746, 'def': 'a detachable, portable lock', 'name': 'padlock'}, {'frequency': 'c', 'synset': 'paintbrush.n.01', 'synonyms': ['paintbrush'], 'id': 747, 'def': 'a brush used as an applicator to apply paint', 'name': 'paintbrush'}, {'frequency': 'f', 'synset': 'painting.n.01', 'synonyms': ['painting'], 'id': 748, 'def': 'graphic art consisting of an artistic composition made by applying paints to a surface', 'name': 'painting'}, {'frequency': 'f', 'synset': 'pajama.n.02', 'synonyms': ['pajamas', 'pyjamas'], 'id': 749, 'def': 'loose-fitting nightclothes worn for sleeping or lounging', 'name': 'pajamas'}, {'frequency': 'c', 'synset': 'palette.n.02', 'synonyms': ['palette', 'pallet'], 'id': 750, 'def': 'board that provides a flat surface on which artists mix paints and the range of colors used', 'name': 'palette'}, {'frequency': 'f', 'synset': 'pan.n.01', 'synonyms': ['pan_(for_cooking)', 'cooking_pan'], 'id': 751, 'def': 'cooking utensil consisting of a wide metal vessel', 'name': 'pan_(for_cooking)'}, {'frequency': 'r', 'synset': 'pan.n.03', 'synonyms': ['pan_(metal_container)'], 'id': 752, 'def': 'shallow container made of metal', 'name': 'pan_(metal_container)'}, {'frequency': 'c', 'synset': 'pancake.n.01', 'synonyms': ['pancake'], 'id': 753, 'def': 'a flat cake of thin batter fried on both sides on a griddle', 'name': 'pancake'}, {'frequency': 'r', 'synset': 'pantyhose.n.01', 'synonyms': ['pantyhose'], 'id': 754, 'def': "a woman's tights consisting of underpants and stockings", 'name': 'pantyhose'}, {'frequency': 'r', 'synset': 'papaya.n.02', 'synonyms': ['papaya'], 'id': 755, 'def': 'large oval melon-like tropical fruit with yellowish flesh', 'name': 'papaya'}, {'frequency': 'f', 'synset': 'paper_plate.n.01', 'synonyms': ['paper_plate'], 'id': 756, 'def': 'a disposable plate made of cardboard', 'name': 'paper_plate'}, {'frequency': 'f', 'synset': 'paper_towel.n.01', 'synonyms': ['paper_towel'], 'id': 757, 'def': 'a disposable towel made of absorbent paper', 'name': 'paper_towel'}, {'frequency': 'r', 'synset': 'paperback_book.n.01', 'synonyms': ['paperback_book', 'paper-back_book', 'softback_book', 'soft-cover_book'], 'id': 758, 'def': 'a book with paper covers', 'name': 'paperback_book'}, {'frequency': 'r', 'synset': 'paperweight.n.01', 'synonyms': ['paperweight'], 'id': 759, 'def': 'a weight used to hold down a stack of papers', 'name': 'paperweight'}, {'frequency': 'c', 'synset': 'parachute.n.01', 'synonyms': ['parachute'], 'id': 760, 'def': 'rescue equipment consisting of a device that fills with air and retards your fall', 'name': 'parachute'}, {'frequency': 'c', 'synset': 'parakeet.n.01', 'synonyms': ['parakeet', 'parrakeet', 'parroket', 'paraquet', 'paroquet', 'parroquet'], 'id': 761, 'def': 'any of numerous small slender long-tailed parrots', 'name': 'parakeet'}, {'frequency': 'c', 'synset': 'parasail.n.01', 'synonyms': ['parasail_(sports)'], 'id': 762, 'def': 'parachute that will lift a person up into the air when it is towed by a motorboat or a car', 'name': 'parasail_(sports)'}, {'frequency': 'c', 'synset': 'parasol.n.01', 'synonyms': ['parasol', 'sunshade'], 'id': 763, 'def': 'a handheld collapsible source of shade', 'name': 'parasol'}, {'frequency': 'r', 'synset': 'parchment.n.01', 'synonyms': ['parchment'], 'id': 764, 'def': 'a superior paper resembling sheepskin', 'name': 'parchment'}, {'frequency': 'c', 'synset': 'parka.n.01', 'synonyms': ['parka', 'anorak'], 'id': 765, 'def': "a kind of heavy jacket (`windcheater' is a British term)", 'name': 'parka'}, {'frequency': 'f', 'synset': 'parking_meter.n.01', 'synonyms': ['parking_meter'], 'id': 766, 'def': 'a coin-operated timer located next to a parking space', 'name': 'parking_meter'}, {'frequency': 'c', 'synset': 'parrot.n.01', 'synonyms': ['parrot'], 'id': 767, 'def': 'usually brightly colored tropical birds with short hooked beaks and the ability to mimic sounds', 'name': 'parrot'}, {'frequency': 'c', 'synset': 'passenger_car.n.01', 'synonyms': ['passenger_car_(part_of_a_train)', 'coach_(part_of_a_train)'], 'id': 768, 'def': 'a railcar where passengers ride', 'name': 'passenger_car_(part_of_a_train)'}, {'frequency': 'r', 'synset': 'passenger_ship.n.01', 'synonyms': ['passenger_ship'], 'id': 769, 'def': 'a ship built to carry passengers', 'name': 'passenger_ship'}, {'frequency': 'c', 'synset': 'passport.n.02', 'synonyms': ['passport'], 'id': 770, 'def': 'a document issued by a country to a citizen allowing that person to travel abroad and re-enter the home country', 'name': 'passport'}, {'frequency': 'f', 'synset': 'pastry.n.02', 'synonyms': ['pastry'], 'id': 771, 'def': 'any of various baked foods made of dough or batter', 'name': 'pastry'}, {'frequency': 'r', 'synset': 'patty.n.01', 'synonyms': ['patty_(food)'], 'id': 772, 'def': 'small flat mass of chopped food', 'name': 'patty_(food)'}, {'frequency': 'c', 'synset': 'pea.n.01', 'synonyms': ['pea_(food)'], 'id': 773, 'def': 'seed of a pea plant used for food', 'name': 'pea_(food)'}, {'frequency': 'c', 'synset': 'peach.n.03', 'synonyms': ['peach'], 'id': 774, 'def': 'downy juicy fruit with sweet yellowish or whitish flesh', 'name': 'peach'}, {'frequency': 'c', 'synset': 'peanut_butter.n.01', 'synonyms': ['peanut_butter'], 'id': 775, 'def': 'a spread made from ground peanuts', 'name': 'peanut_butter'}, {'frequency': 'f', 'synset': 'pear.n.01', 'synonyms': ['pear'], 'id': 776, 'def': 'sweet juicy gritty-textured fruit available in many varieties', 'name': 'pear'}, {'frequency': 'c', 'synset': 'peeler.n.03', 'synonyms': ['peeler_(tool_for_fruit_and_vegetables)'], 'id': 777, 'def': 'a device for peeling vegetables or fruits', 'name': 'peeler_(tool_for_fruit_and_vegetables)'}, {'frequency': 'r', 'synset': 'peg.n.04', 'synonyms': ['wooden_leg', 'pegleg'], 'id': 778, 'def': 'a prosthesis that replaces a missing leg', 'name': 'wooden_leg'}, {'frequency': 'r', 'synset': 'pegboard.n.01', 'synonyms': ['pegboard'], 'id': 779, 'def': 'a board perforated with regularly spaced holes into which pegs can be fitted', 'name': 'pegboard'}, {'frequency': 'c', 'synset': 'pelican.n.01', 'synonyms': ['pelican'], 'id': 780, 'def': 'large long-winged warm-water seabird having a large bill with a distensible pouch for fish', 'name': 'pelican'}, {'frequency': 'f', 'synset': 'pen.n.01', 'synonyms': ['pen'], 'id': 781, 'def': 'a writing implement with a point from which ink flows', 'name': 'pen'}, {'frequency': 'f', 'synset': 'pencil.n.01', 'synonyms': ['pencil'], 'id': 782, 'def': 'a thin cylindrical pointed writing implement made of wood and graphite', 'name': 'pencil'}, {'frequency': 'r', 'synset': 'pencil_box.n.01', 'synonyms': ['pencil_box', 'pencil_case'], 'id': 783, 'def': 'a box for holding pencils', 'name': 'pencil_box'}, {'frequency': 'r', 'synset': 'pencil_sharpener.n.01', 'synonyms': ['pencil_sharpener'], 'id': 784, 'def': 'a rotary implement for sharpening the point on pencils', 'name': 'pencil_sharpener'}, {'frequency': 'r', 'synset': 'pendulum.n.01', 'synonyms': ['pendulum'], 'id': 785, 'def': 'an apparatus consisting of an object mounted so that it swings freely under the influence of gravity', 'name': 'pendulum'}, {'frequency': 'c', 'synset': 'penguin.n.01', 'synonyms': ['penguin'], 'id': 786, 'def': 'short-legged flightless birds of cold southern regions having webbed feet and wings modified as flippers', 'name': 'penguin'}, {'frequency': 'r', 'synset': 'pennant.n.02', 'synonyms': ['pennant'], 'id': 787, 'def': 'a flag longer than it is wide (and often tapering)', 'name': 'pennant'}, {'frequency': 'r', 'synset': 'penny.n.02', 'synonyms': ['penny_(coin)'], 'id': 788, 'def': 'a coin worth one-hundredth of the value of the basic unit', 'name': 'penny_(coin)'}, {'frequency': 'f', 'synset': 'pepper.n.03', 'synonyms': ['pepper', 'peppercorn'], 'id': 789, 'def': 'pungent seasoning from the berry of the common pepper plant; whole or ground', 'name': 'pepper'}, {'frequency': 'c', 'synset': 'pepper_mill.n.01', 'synonyms': ['pepper_mill', 'pepper_grinder'], 'id': 790, 'def': 'a mill for grinding pepper', 'name': 'pepper_mill'}, {'frequency': 'c', 'synset': 'perfume.n.02', 'synonyms': ['perfume'], 'id': 791, 'def': 'a toiletry that emits and diffuses a fragrant odor', 'name': 'perfume'}, {'frequency': 'r', 'synset': 'persimmon.n.02', 'synonyms': ['persimmon'], 'id': 792, 'def': 'orange fruit resembling a plum; edible when fully ripe', 'name': 'persimmon'}, {'frequency': 'f', 'synset': 'person.n.01', 'synonyms': ['person', 'baby', 'child', 'boy', 'girl', 'man', 'woman', 'human'], 'id': 793, 'def': 'a human being', 'name': 'person'}, {'frequency': 'c', 'synset': 'pet.n.01', 'synonyms': ['pet'], 'id': 794, 'def': 'a domesticated animal kept for companionship or amusement', 'name': 'pet'}, {'frequency': 'c', 'synset': 'pew.n.01', 'synonyms': ['pew_(church_bench)', 'church_bench'], 'id': 795, 'def': 'long bench with backs; used in church by the congregation', 'name': 'pew_(church_bench)'}, {'frequency': 'r', 'synset': 'phonebook.n.01', 'synonyms': ['phonebook', 'telephone_book', 'telephone_directory'], 'id': 796, 'def': 'a directory containing an alphabetical list of telephone subscribers and their telephone numbers', 'name': 'phonebook'}, {'frequency': 'c', 'synset': 'phonograph_record.n.01', 'synonyms': ['phonograph_record', 'phonograph_recording', 'record_(phonograph_recording)'], 'id': 797, 'def': 'sound recording consisting of a typically black disk with a continuous groove', 'name': 'phonograph_record'}, {'frequency': 'f', 'synset': 'piano.n.01', 'synonyms': ['piano'], 'id': 798, 'def': 'a keyboard instrument that is played by depressing keys that cause hammers to strike tuned strings and produce sounds', 'name': 'piano'}, {'frequency': 'f', 'synset': 'pickle.n.01', 'synonyms': ['pickle'], 'id': 799, 'def': 'vegetables (especially cucumbers) preserved in brine or vinegar', 'name': 'pickle'}, {'frequency': 'f', 'synset': 'pickup.n.01', 'synonyms': ['pickup_truck'], 'id': 800, 'def': 'a light truck with an open body and low sides and a tailboard', 'name': 'pickup_truck'}, {'frequency': 'c', 'synset': 'pie.n.01', 'synonyms': ['pie'], 'id': 801, 'def': 'dish baked in pastry-lined pan often with a pastry top', 'name': 'pie'}, {'frequency': 'c', 'synset': 'pigeon.n.01', 'synonyms': ['pigeon'], 'id': 802, 'def': 'wild and domesticated birds having a heavy body and short legs', 'name': 'pigeon'}, {'frequency': 'r', 'synset': 'piggy_bank.n.01', 'synonyms': ['piggy_bank', 'penny_bank'], 'id': 803, 'def': "a child's coin bank (often shaped like a pig)", 'name': 'piggy_bank'}, {'frequency': 'f', 'synset': 'pillow.n.01', 'synonyms': ['pillow'], 'id': 804, 'def': 'a cushion to support the head of a sleeping person', 'name': 'pillow'}, {'frequency': 'r', 'synset': 'pin.n.09', 'synonyms': ['pin_(non_jewelry)'], 'id': 805, 'def': 'a small slender (often pointed) piece of wood or metal used to support or fasten or attach things', 'name': 'pin_(non_jewelry)'}, {'frequency': 'f', 'synset': 'pineapple.n.02', 'synonyms': ['pineapple'], 'id': 806, 'def': 'large sweet fleshy tropical fruit with a tuft of stiff leaves', 'name': 'pineapple'}, {'frequency': 'c', 'synset': 'pinecone.n.01', 'synonyms': ['pinecone'], 'id': 807, 'def': 'the seed-producing cone of a pine tree', 'name': 'pinecone'}, {'frequency': 'r', 'synset': 'ping-pong_ball.n.01', 'synonyms': ['ping-pong_ball'], 'id': 808, 'def': 'light hollow ball used in playing table tennis', 'name': 'ping-pong_ball'}, {'frequency': 'r', 'synset': 'pinwheel.n.03', 'synonyms': ['pinwheel'], 'id': 809, 'def': 'a toy consisting of vanes of colored paper or plastic that is pinned to a stick and spins when it is pointed into the wind', 'name': 'pinwheel'}, {'frequency': 'r', 'synset': 'pipe.n.01', 'synonyms': ['tobacco_pipe'], 'id': 810, 'def': 'a tube with a small bowl at one end; used for smoking tobacco', 'name': 'tobacco_pipe'}, {'frequency': 'f', 'synset': 'pipe.n.02', 'synonyms': ['pipe', 'piping'], 'id': 811, 'def': 'a long tube made of metal or plastic that is used to carry water or oil or gas etc.', 'name': 'pipe'}, {'frequency': 'r', 'synset': 'pistol.n.01', 'synonyms': ['pistol', 'handgun'], 'id': 812, 'def': 'a firearm that is held and fired with one hand', 'name': 'pistol'}, {'frequency': 'c', 'synset': 'pita.n.01', 'synonyms': ['pita_(bread)', 'pocket_bread'], 'id': 813, 'def': 'usually small round bread that can open into a pocket for filling', 'name': 'pita_(bread)'}, {'frequency': 'f', 'synset': 'pitcher.n.02', 'synonyms': ['pitcher_(vessel_for_liquid)', 'ewer'], 'id': 814, 'def': 'an open vessel with a handle and a spout for pouring', 'name': 'pitcher_(vessel_for_liquid)'}, {'frequency': 'r', 'synset': 'pitchfork.n.01', 'synonyms': ['pitchfork'], 'id': 815, 'def': 'a long-handled hand tool with sharp widely spaced prongs for lifting and pitching hay', 'name': 'pitchfork'}, {'frequency': 'f', 'synset': 'pizza.n.01', 'synonyms': ['pizza'], 'id': 816, 'def': 'Italian open pie made of thin bread dough spread with a spiced mixture of e.g. tomato sauce and cheese', 'name': 'pizza'}, {'frequency': 'f', 'synset': 'place_mat.n.01', 'synonyms': ['place_mat'], 'id': 817, 'def': 'a mat placed on a table for an individual place setting', 'name': 'place_mat'}, {'frequency': 'f', 'synset': 'plate.n.04', 'synonyms': ['plate'], 'id': 818, 'def': 'dish on which food is served or from which food is eaten', 'name': 'plate'}, {'frequency': 'c', 'synset': 'platter.n.01', 'synonyms': ['platter'], 'id': 819, 'def': 'a large shallow dish used for serving food', 'name': 'platter'}, {'frequency': 'r', 'synset': 'playpen.n.01', 'synonyms': ['playpen'], 'id': 820, 'def': 'a portable enclosure in which babies may be left to play', 'name': 'playpen'}, {'frequency': 'c', 'synset': 'pliers.n.01', 'synonyms': ['pliers', 'plyers'], 'id': 821, 'def': 'a gripping hand tool with two hinged arms and (usually) serrated jaws', 'name': 'pliers'}, {'frequency': 'r', 'synset': 'plow.n.01', 'synonyms': ['plow_(farm_equipment)', 'plough_(farm_equipment)'], 'id': 822, 'def': 'a farm tool having one or more heavy blades to break the soil and cut a furrow prior to sowing', 'name': 'plow_(farm_equipment)'}, {'frequency': 'r', 'synset': 'plume.n.02', 'synonyms': ['plume'], 'id': 823, 'def': 'a feather or cluster of feathers worn as an ornament', 'name': 'plume'}, {'frequency': 'r', 'synset': 'pocket_watch.n.01', 'synonyms': ['pocket_watch'], 'id': 824, 'def': 'a watch that is carried in a small watch pocket', 'name': 'pocket_watch'}, {'frequency': 'c', 'synset': 'pocketknife.n.01', 'synonyms': ['pocketknife'], 'id': 825, 'def': 'a knife with a blade that folds into the handle; suitable for carrying in the pocket', 'name': 'pocketknife'}, {'frequency': 'c', 'synset': 'poker.n.01', 'synonyms': ['poker_(fire_stirring_tool)', 'stove_poker', 'fire_hook'], 'id': 826, 'def': 'fire iron consisting of a metal rod with a handle; used to stir a fire', 'name': 'poker_(fire_stirring_tool)'}, {'frequency': 'f', 'synset': 'pole.n.01', 'synonyms': ['pole', 'post'], 'id': 827, 'def': 'a long (usually round) rod of wood or metal or plastic', 'name': 'pole'}, {'frequency': 'f', 'synset': 'polo_shirt.n.01', 'synonyms': ['polo_shirt', 'sport_shirt'], 'id': 828, 'def': 'a shirt with short sleeves designed for comfort and casual wear', 'name': 'polo_shirt'}, {'frequency': 'r', 'synset': 'poncho.n.01', 'synonyms': ['poncho'], 'id': 829, 'def': 'a blanket-like cloak with a hole in the center for the head', 'name': 'poncho'}, {'frequency': 'c', 'synset': 'pony.n.05', 'synonyms': ['pony'], 'id': 830, 'def': 'any of various breeds of small gentle horses usually less than five feet high at the shoulder', 'name': 'pony'}, {'frequency': 'r', 'synset': 'pool_table.n.01', 'synonyms': ['pool_table', 'billiard_table', 'snooker_table'], 'id': 831, 'def': 'game equipment consisting of a heavy table on which pool is played', 'name': 'pool_table'}, {'frequency': 'f', 'synset': 'pop.n.02', 'synonyms': ['pop_(soda)', 'soda_(pop)', 'tonic', 'soft_drink'], 'id': 832, 'def': 'a sweet drink containing carbonated water and flavoring', 'name': 'pop_(soda)'}, {'frequency': 'c', 'synset': 'postbox.n.01', 'synonyms': ['postbox_(public)', 'mailbox_(public)'], 'id': 833, 'def': 'public box for deposit of mail', 'name': 'postbox_(public)'}, {'frequency': 'c', 'synset': 'postcard.n.01', 'synonyms': ['postcard', 'postal_card', 'mailing-card'], 'id': 834, 'def': 'a card for sending messages by post without an envelope', 'name': 'postcard'}, {'frequency': 'f', 'synset': 'poster.n.01', 'synonyms': ['poster', 'placard'], 'id': 835, 'def': 'a sign posted in a public place as an advertisement', 'name': 'poster'}, {'frequency': 'f', 'synset': 'pot.n.01', 'synonyms': ['pot'], 'id': 836, 'def': 'metal or earthenware cooking vessel that is usually round and deep; often has a handle and lid', 'name': 'pot'}, {'frequency': 'f', 'synset': 'pot.n.04', 'synonyms': ['flowerpot'], 'id': 837, 'def': 'a container in which plants are cultivated', 'name': 'flowerpot'}, {'frequency': 'f', 'synset': 'potato.n.01', 'synonyms': ['potato'], 'id': 838, 'def': 'an edible tuber native to South America', 'name': 'potato'}, {'frequency': 'c', 'synset': 'potholder.n.01', 'synonyms': ['potholder'], 'id': 839, 'def': 'an insulated pad for holding hot pots', 'name': 'potholder'}, {'frequency': 'c', 'synset': 'pottery.n.01', 'synonyms': ['pottery', 'clayware'], 'id': 840, 'def': 'ceramic ware made from clay and baked in a kiln', 'name': 'pottery'}, {'frequency': 'c', 'synset': 'pouch.n.01', 'synonyms': ['pouch'], 'id': 841, 'def': 'a small or medium size container for holding or carrying things', 'name': 'pouch'}, {'frequency': 'c', 'synset': 'power_shovel.n.01', 'synonyms': ['power_shovel', 'excavator', 'digger'], 'id': 842, 'def': 'a machine for excavating', 'name': 'power_shovel'}, {'frequency': 'c', 'synset': 'prawn.n.01', 'synonyms': ['prawn', 'shrimp'], 'id': 843, 'def': 'any of various edible decapod crustaceans', 'name': 'prawn'}, {'frequency': 'c', 'synset': 'pretzel.n.01', 'synonyms': ['pretzel'], 'id': 844, 'def': 'glazed and salted cracker typically in the shape of a loose knot', 'name': 'pretzel'}, {'frequency': 'f', 'synset': 'printer.n.03', 'synonyms': ['printer', 'printing_machine'], 'id': 845, 'def': 'a machine that prints', 'name': 'printer'}, {'frequency': 'c', 'synset': 'projectile.n.01', 'synonyms': ['projectile_(weapon)', 'missile'], 'id': 846, 'def': 'a weapon that is forcibly thrown or projected at a targets', 'name': 'projectile_(weapon)'}, {'frequency': 'c', 'synset': 'projector.n.02', 'synonyms': ['projector'], 'id': 847, 'def': 'an optical instrument that projects an enlarged image onto a screen', 'name': 'projector'}, {'frequency': 'f', 'synset': 'propeller.n.01', 'synonyms': ['propeller', 'propellor'], 'id': 848, 'def': 'a mechanical device that rotates to push against air or water', 'name': 'propeller'}, {'frequency': 'r', 'synset': 'prune.n.01', 'synonyms': ['prune'], 'id': 849, 'def': 'dried plum', 'name': 'prune'}, {'frequency': 'r', 'synset': 'pudding.n.01', 'synonyms': ['pudding'], 'id': 850, 'def': 'any of various soft thick unsweetened baked dishes', 'name': 'pudding'}, {'frequency': 'r', 'synset': 'puffer.n.02', 'synonyms': ['puffer_(fish)', 'pufferfish', 'blowfish', 'globefish'], 'id': 851, 'def': 'fishes whose elongated spiny body can inflate itself with water or air to form a globe', 'name': 'puffer_(fish)'}, {'frequency': 'r', 'synset': 'puffin.n.01', 'synonyms': ['puffin'], 'id': 852, 'def': 'seabirds having short necks and brightly colored compressed bills', 'name': 'puffin'}, {'frequency': 'r', 'synset': 'pug.n.01', 'synonyms': ['pug-dog'], 'id': 853, 'def': 'small compact smooth-coated breed of Asiatic origin having a tightly curled tail and broad flat wrinkled muzzle', 'name': 'pug-dog'}, {'frequency': 'c', 'synset': 'pumpkin.n.02', 'synonyms': ['pumpkin'], 'id': 854, 'def': 'usually large pulpy deep-yellow round fruit of the squash family maturing in late summer or early autumn', 'name': 'pumpkin'}, {'frequency': 'r', 'synset': 'punch.n.03', 'synonyms': ['puncher'], 'id': 855, 'def': 'a tool for making holes or indentations', 'name': 'puncher'}, {'frequency': 'r', 'synset': 'puppet.n.01', 'synonyms': ['puppet', 'marionette'], 'id': 856, 'def': 'a small figure of a person operated from above with strings by a puppeteer', 'name': 'puppet'}, {'frequency': 'c', 'synset': 'puppy.n.01', 'synonyms': ['puppy'], 'id': 857, 'def': 'a young dog', 'name': 'puppy'}, {'frequency': 'r', 'synset': 'quesadilla.n.01', 'synonyms': ['quesadilla'], 'id': 858, 'def': 'a tortilla that is filled with cheese and heated', 'name': 'quesadilla'}, {'frequency': 'r', 'synset': 'quiche.n.02', 'synonyms': ['quiche'], 'id': 859, 'def': 'a tart filled with rich unsweetened custard; often contains other ingredients (as cheese or ham or seafood or vegetables)', 'name': 'quiche'}, {'frequency': 'f', 'synset': 'quilt.n.01', 'synonyms': ['quilt', 'comforter'], 'id': 860, 'def': 'bedding made of two layers of cloth filled with stuffing and stitched together', 'name': 'quilt'}, {'frequency': 'c', 'synset': 'rabbit.n.01', 'synonyms': ['rabbit'], 'id': 861, 'def': 'any of various burrowing animals of the family Leporidae having long ears and short tails', 'name': 'rabbit'}, {'frequency': 'r', 'synset': 'racer.n.02', 'synonyms': ['race_car', 'racing_car'], 'id': 862, 'def': 'a fast car that competes in races', 'name': 'race_car'}, {'frequency': 'c', 'synset': 'racket.n.04', 'synonyms': ['racket', 'racquet'], 'id': 863, 'def': 'a sports implement used to strike a ball in various games', 'name': 'racket'}, {'frequency': 'r', 'synset': 'radar.n.01', 'synonyms': ['radar'], 'id': 864, 'def': 'measuring instrument in which the echo of a pulse of microwave radiation is used to detect and locate distant objects', 'name': 'radar'}, {'frequency': 'f', 'synset': 'radiator.n.03', 'synonyms': ['radiator'], 'id': 865, 'def': 'a mechanism consisting of a metal honeycomb through which hot fluids circulate', 'name': 'radiator'}, {'frequency': 'c', 'synset': 'radio_receiver.n.01', 'synonyms': ['radio_receiver', 'radio_set', 'radio', 'tuner_(radio)'], 'id': 866, 'def': 'an electronic receiver that detects and demodulates and amplifies transmitted radio signals', 'name': 'radio_receiver'}, {'frequency': 'c', 'synset': 'radish.n.03', 'synonyms': ['radish', 'daikon'], 'id': 867, 'def': 'pungent edible root of any of various cultivated radish plants', 'name': 'radish'}, {'frequency': 'c', 'synset': 'raft.n.01', 'synonyms': ['raft'], 'id': 868, 'def': 'a flat float (usually made of logs or planks) that can be used for transport or as a platform for swimmers', 'name': 'raft'}, {'frequency': 'r', 'synset': 'rag_doll.n.01', 'synonyms': ['rag_doll'], 'id': 869, 'def': 'a cloth doll that is stuffed and (usually) painted', 'name': 'rag_doll'}, {'frequency': 'c', 'synset': 'raincoat.n.01', 'synonyms': ['raincoat', 'waterproof_jacket'], 'id': 870, 'def': 'a water-resistant coat', 'name': 'raincoat'}, {'frequency': 'c', 'synset': 'ram.n.05', 'synonyms': ['ram_(animal)'], 'id': 871, 'def': 'uncastrated adult male sheep', 'name': 'ram_(animal)'}, {'frequency': 'c', 'synset': 'raspberry.n.02', 'synonyms': ['raspberry'], 'id': 872, 'def': 'red or black edible aggregate berries usually smaller than the related blackberries', 'name': 'raspberry'}, {'frequency': 'r', 'synset': 'rat.n.01', 'synonyms': ['rat'], 'id': 873, 'def': 'any of various long-tailed rodents similar to but larger than a mouse', 'name': 'rat'}, {'frequency': 'c', 'synset': 'razorblade.n.01', 'synonyms': ['razorblade'], 'id': 874, 'def': 'a blade that has very sharp edge', 'name': 'razorblade'}, {'frequency': 'c', 'synset': 'reamer.n.01', 'synonyms': ['reamer_(juicer)', 'juicer', 'juice_reamer'], 'id': 875, 'def': 'a squeezer with a conical ridged center that is used for squeezing juice from citrus fruit', 'name': 'reamer_(juicer)'}, {'frequency': 'f', 'synset': 'rearview_mirror.n.01', 'synonyms': ['rearview_mirror'], 'id': 876, 'def': 'vehicle mirror (side or rearview)', 'name': 'rearview_mirror'}, {'frequency': 'c', 'synset': 'receipt.n.02', 'synonyms': ['receipt'], 'id': 877, 'def': 'an acknowledgment (usually tangible) that payment has been made', 'name': 'receipt'}, {'frequency': 'c', 'synset': 'recliner.n.01', 'synonyms': ['recliner', 'reclining_chair', 'lounger_(chair)'], 'id': 878, 'def': 'an armchair whose back can be lowered and foot can be raised to allow the sitter to recline in it', 'name': 'recliner'}, {'frequency': 'c', 'synset': 'record_player.n.01', 'synonyms': ['record_player', 'phonograph_(record_player)', 'turntable'], 'id': 879, 'def': 'machine in which rotating records cause a stylus to vibrate and the vibrations are amplified acoustically or electronically', 'name': 'record_player'}, {'frequency': 'f', 'synset': 'reflector.n.01', 'synonyms': ['reflector'], 'id': 880, 'def': 'device that reflects light, radiation, etc.', 'name': 'reflector'}, {'frequency': 'f', 'synset': 'remote_control.n.01', 'synonyms': ['remote_control'], 'id': 881, 'def': 'a device that can be used to control a machine or apparatus from a distance', 'name': 'remote_control'}, {'frequency': 'c', 'synset': 'rhinoceros.n.01', 'synonyms': ['rhinoceros'], 'id': 882, 'def': 'massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout', 'name': 'rhinoceros'}, {'frequency': 'r', 'synset': 'rib.n.03', 'synonyms': ['rib_(food)'], 'id': 883, 'def': 'cut of meat including one or more ribs', 'name': 'rib_(food)'}, {'frequency': 'c', 'synset': 'rifle.n.01', 'synonyms': ['rifle'], 'id': 884, 'def': 'a shoulder firearm with a long barrel', 'name': 'rifle'}, {'frequency': 'f', 'synset': 'ring.n.08', 'synonyms': ['ring'], 'id': 885, 'def': 'jewelry consisting of a circlet of precious metal (often set with jewels) worn on the finger', 'name': 'ring'}, {'frequency': 'r', 'synset': 'river_boat.n.01', 'synonyms': ['river_boat'], 'id': 886, 'def': 'a boat used on rivers or to ply a river', 'name': 'river_boat'}, {'frequency': 'r', 'synset': 'road_map.n.02', 'synonyms': ['road_map'], 'id': 887, 'def': '(NOT A ROAD) a MAP showing roads (for automobile travel)', 'name': 'road_map'}, {'frequency': 'c', 'synset': 'robe.n.01', 'synonyms': ['robe'], 'id': 888, 'def': 'any loose flowing garment', 'name': 'robe'}, {'frequency': 'c', 'synset': 'rocking_chair.n.01', 'synonyms': ['rocking_chair'], 'id': 889, 'def': 'a chair mounted on rockers', 'name': 'rocking_chair'}, {'frequency': 'r', 'synset': 'rodent.n.01', 'synonyms': ['rodent'], 'id': 890, 'def': 'relatively small placental mammals having a single pair of constantly growing incisor teeth specialized for gnawing', 'name': 'rodent'}, {'frequency': 'r', 'synset': 'roller_skate.n.01', 'synonyms': ['roller_skate'], 'id': 891, 'def': 'a shoe with pairs of rollers (small hard wheels) fixed to the sole', 'name': 'roller_skate'}, {'frequency': 'r', 'synset': 'rollerblade.n.01', 'synonyms': ['Rollerblade'], 'id': 892, 'def': 'an in-line variant of a roller skate', 'name': 'Rollerblade'}, {'frequency': 'c', 'synset': 'rolling_pin.n.01', 'synonyms': ['rolling_pin'], 'id': 893, 'def': 'utensil consisting of a cylinder (usually of wood) with a handle at each end; used to roll out dough', 'name': 'rolling_pin'}, {'frequency': 'r', 'synset': 'root_beer.n.01', 'synonyms': ['root_beer'], 'id': 894, 'def': 'carbonated drink containing extracts of roots and herbs', 'name': 'root_beer'}, {'frequency': 'c', 'synset': 'router.n.02', 'synonyms': ['router_(computer_equipment)'], 'id': 895, 'def': 'a device that forwards data packets between computer networks', 'name': 'router_(computer_equipment)'}, {'frequency': 'f', 'synset': 'rubber_band.n.01', 'synonyms': ['rubber_band', 'elastic_band'], 'id': 896, 'def': 'a narrow band of elastic rubber used to hold things (such as papers) together', 'name': 'rubber_band'}, {'frequency': 'c', 'synset': 'runner.n.08', 'synonyms': ['runner_(carpet)'], 'id': 897, 'def': 'a long narrow carpet', 'name': 'runner_(carpet)'}, {'frequency': 'f', 'synset': 'sack.n.01', 'synonyms': ['plastic_bag', 'paper_bag'], 'id': 898, 'def': "a bag made of paper or plastic for holding customer's purchases", 'name': 'plastic_bag'}, {'frequency': 'f', 'synset': 'saddle.n.01', 'synonyms': ['saddle_(on_an_animal)'], 'id': 899, 'def': 'a seat for the rider of a horse or camel', 'name': 'saddle_(on_an_animal)'}, {'frequency': 'f', 'synset': 'saddle_blanket.n.01', 'synonyms': ['saddle_blanket', 'saddlecloth', 'horse_blanket'], 'id': 900, 'def': 'stable gear consisting of a blanket placed under the saddle', 'name': 'saddle_blanket'}, {'frequency': 'c', 'synset': 'saddlebag.n.01', 'synonyms': ['saddlebag'], 'id': 901, 'def': 'a large bag (or pair of bags) hung over a saddle', 'name': 'saddlebag'}, {'frequency': 'r', 'synset': 'safety_pin.n.01', 'synonyms': ['safety_pin'], 'id': 902, 'def': 'a pin in the form of a clasp; has a guard so the point of the pin will not stick the user', 'name': 'safety_pin'}, {'frequency': 'f', 'synset': 'sail.n.01', 'synonyms': ['sail'], 'id': 903, 'def': 'a large piece of fabric by means of which wind is used to propel a sailing vessel', 'name': 'sail'}, {'frequency': 'f', 'synset': 'salad.n.01', 'synonyms': ['salad'], 'id': 904, 'def': 'food mixtures either arranged on a plate or tossed and served with a moist dressing; usually consisting of or including greens', 'name': 'salad'}, {'frequency': 'r', 'synset': 'salad_plate.n.01', 'synonyms': ['salad_plate', 'salad_bowl'], 'id': 905, 'def': 'a plate or bowl for individual servings of salad', 'name': 'salad_plate'}, {'frequency': 'c', 'synset': 'salami.n.01', 'synonyms': ['salami'], 'id': 906, 'def': 'highly seasoned fatty sausage of pork and beef usually dried', 'name': 'salami'}, {'frequency': 'c', 'synset': 'salmon.n.01', 'synonyms': ['salmon_(fish)'], 'id': 907, 'def': 'any of various large food and game fishes of northern waters', 'name': 'salmon_(fish)'}, {'frequency': 'r', 'synset': 'salmon.n.03', 'synonyms': ['salmon_(food)'], 'id': 908, 'def': 'flesh of any of various marine or freshwater fish of the family Salmonidae', 'name': 'salmon_(food)'}, {'frequency': 'c', 'synset': 'salsa.n.01', 'synonyms': ['salsa'], 'id': 909, 'def': 'spicy sauce of tomatoes and onions and chili peppers to accompany Mexican foods', 'name': 'salsa'}, {'frequency': 'f', 'synset': 'saltshaker.n.01', 'synonyms': ['saltshaker'], 'id': 910, 'def': 'a shaker with a perforated top for sprinkling salt', 'name': 'saltshaker'}, {'frequency': 'f', 'synset': 'sandal.n.01', 'synonyms': ['sandal_(type_of_shoe)'], 'id': 911, 'def': 'a shoe consisting of a sole fastened by straps to the foot', 'name': 'sandal_(type_of_shoe)'}, {'frequency': 'f', 'synset': 'sandwich.n.01', 'synonyms': ['sandwich'], 'id': 912, 'def': 'two (or more) slices of bread with a filling between them', 'name': 'sandwich'}, {'frequency': 'r', 'synset': 'satchel.n.01', 'synonyms': ['satchel'], 'id': 913, 'def': 'luggage consisting of a small case with a flat bottom and (usually) a shoulder strap', 'name': 'satchel'}, {'frequency': 'r', 'synset': 'saucepan.n.01', 'synonyms': ['saucepan'], 'id': 914, 'def': 'a deep pan with a handle; used for stewing or boiling', 'name': 'saucepan'}, {'frequency': 'f', 'synset': 'saucer.n.02', 'synonyms': ['saucer'], 'id': 915, 'def': 'a small shallow dish for holding a cup at the table', 'name': 'saucer'}, {'frequency': 'f', 'synset': 'sausage.n.01', 'synonyms': ['sausage'], 'id': 916, 'def': 'highly seasoned minced meat stuffed in casings', 'name': 'sausage'}, {'frequency': 'r', 'synset': 'sawhorse.n.01', 'synonyms': ['sawhorse', 'sawbuck'], 'id': 917, 'def': 'a framework for holding wood that is being sawed', 'name': 'sawhorse'}, {'frequency': 'r', 'synset': 'sax.n.02', 'synonyms': ['saxophone'], 'id': 918, 'def': "a wind instrument with a `J'-shaped form typically made of brass", 'name': 'saxophone'}, {'frequency': 'f', 'synset': 'scale.n.07', 'synonyms': ['scale_(measuring_instrument)'], 'id': 919, 'def': 'a measuring instrument for weighing; shows amount of mass', 'name': 'scale_(measuring_instrument)'}, {'frequency': 'r', 'synset': 'scarecrow.n.01', 'synonyms': ['scarecrow', 'strawman'], 'id': 920, 'def': 'an effigy in the shape of a man to frighten birds away from seeds', 'name': 'scarecrow'}, {'frequency': 'f', 'synset': 'scarf.n.01', 'synonyms': ['scarf'], 'id': 921, 'def': 'a garment worn around the head or neck or shoulders for warmth or decoration', 'name': 'scarf'}, {'frequency': 'c', 'synset': 'school_bus.n.01', 'synonyms': ['school_bus'], 'id': 922, 'def': 'a bus used to transport children to or from school', 'name': 'school_bus'}, {'frequency': 'f', 'synset': 'scissors.n.01', 'synonyms': ['scissors'], 'id': 923, 'def': 'a tool having two crossed pivoting blades with looped handles', 'name': 'scissors'}, {'frequency': 'f', 'synset': 'scoreboard.n.01', 'synonyms': ['scoreboard'], 'id': 924, 'def': 'a large board for displaying the score of a contest (and some other information)', 'name': 'scoreboard'}, {'frequency': 'r', 'synset': 'scraper.n.01', 'synonyms': ['scraper'], 'id': 925, 'def': 'any of various hand tools for scraping', 'name': 'scraper'}, {'frequency': 'c', 'synset': 'screwdriver.n.01', 'synonyms': ['screwdriver'], 'id': 926, 'def': 'a hand tool for driving screws; has a tip that fits into the head of a screw', 'name': 'screwdriver'}, {'frequency': 'f', 'synset': 'scrub_brush.n.01', 'synonyms': ['scrubbing_brush'], 'id': 927, 'def': 'a brush with short stiff bristles for heavy cleaning', 'name': 'scrubbing_brush'}, {'frequency': 'c', 'synset': 'sculpture.n.01', 'synonyms': ['sculpture'], 'id': 928, 'def': 'a three-dimensional work of art', 'name': 'sculpture'}, {'frequency': 'c', 'synset': 'seabird.n.01', 'synonyms': ['seabird', 'seafowl'], 'id': 929, 'def': 'a bird that frequents coastal waters and the open ocean: gulls; pelicans; gannets; cormorants; albatrosses; petrels; etc.', 'name': 'seabird'}, {'frequency': 'c', 'synset': 'seahorse.n.02', 'synonyms': ['seahorse'], 'id': 930, 'def': 'small fish with horse-like heads bent sharply downward and curled tails', 'name': 'seahorse'}, {'frequency': 'r', 'synset': 'seaplane.n.01', 'synonyms': ['seaplane', 'hydroplane'], 'id': 931, 'def': 'an airplane that can land on or take off from water', 'name': 'seaplane'}, {'frequency': 'c', 'synset': 'seashell.n.01', 'synonyms': ['seashell'], 'id': 932, 'def': 'the shell of a marine organism', 'name': 'seashell'}, {'frequency': 'c', 'synset': 'sewing_machine.n.01', 'synonyms': ['sewing_machine'], 'id': 933, 'def': 'a textile machine used as a home appliance for sewing', 'name': 'sewing_machine'}, {'frequency': 'c', 'synset': 'shaker.n.03', 'synonyms': ['shaker'], 'id': 934, 'def': 'a container in which something can be shaken', 'name': 'shaker'}, {'frequency': 'c', 'synset': 'shampoo.n.01', 'synonyms': ['shampoo'], 'id': 935, 'def': 'cleansing agent consisting of soaps or detergents used for washing the hair', 'name': 'shampoo'}, {'frequency': 'c', 'synset': 'shark.n.01', 'synonyms': ['shark'], 'id': 936, 'def': 'typically large carnivorous fishes with sharpe teeth', 'name': 'shark'}, {'frequency': 'r', 'synset': 'sharpener.n.01', 'synonyms': ['sharpener'], 'id': 937, 'def': 'any implement that is used to make something (an edge or a point) sharper', 'name': 'sharpener'}, {'frequency': 'r', 'synset': 'sharpie.n.03', 'synonyms': ['Sharpie'], 'id': 938, 'def': 'a pen with indelible ink that will write on any surface', 'name': 'Sharpie'}, {'frequency': 'r', 'synset': 'shaver.n.03', 'synonyms': ['shaver_(electric)', 'electric_shaver', 'electric_razor'], 'id': 939, 'def': 'a razor powered by an electric motor', 'name': 'shaver_(electric)'}, {'frequency': 'c', 'synset': 'shaving_cream.n.01', 'synonyms': ['shaving_cream', 'shaving_soap'], 'id': 940, 'def': 'toiletry consisting that forms a rich lather for softening the beard before shaving', 'name': 'shaving_cream'}, {'frequency': 'r', 'synset': 'shawl.n.01', 'synonyms': ['shawl'], 'id': 941, 'def': 'cloak consisting of an oblong piece of cloth used to cover the head and shoulders', 'name': 'shawl'}, {'frequency': 'r', 'synset': 'shears.n.01', 'synonyms': ['shears'], 'id': 942, 'def': 'large scissors with strong blades', 'name': 'shears'}, {'frequency': 'f', 'synset': 'sheep.n.01', 'synonyms': ['sheep'], 'id': 943, 'def': 'woolly usually horned ruminant mammal related to the goat', 'name': 'sheep'}, {'frequency': 'r', 'synset': 'shepherd_dog.n.01', 'synonyms': ['shepherd_dog', 'sheepdog'], 'id': 944, 'def': 'any of various usually long-haired breeds of dog reared to herd and guard sheep', 'name': 'shepherd_dog'}, {'frequency': 'r', 'synset': 'sherbert.n.01', 'synonyms': ['sherbert', 'sherbet'], 'id': 945, 'def': 'a frozen dessert made primarily of fruit juice and sugar', 'name': 'sherbert'}, {'frequency': 'c', 'synset': 'shield.n.02', 'synonyms': ['shield'], 'id': 946, 'def': 'armor carried on the arm to intercept blows', 'name': 'shield'}, {'frequency': 'f', 'synset': 'shirt.n.01', 'synonyms': ['shirt'], 'id': 947, 'def': 'a garment worn on the upper half of the body', 'name': 'shirt'}, {'frequency': 'f', 'synset': 'shoe.n.01', 'synonyms': ['shoe', 'sneaker_(type_of_shoe)', 'tennis_shoe'], 'id': 948, 'def': 'common footwear covering the foot', 'name': 'shoe'}, {'frequency': 'f', 'synset': 'shopping_bag.n.01', 'synonyms': ['shopping_bag'], 'id': 949, 'def': 'a bag made of plastic or strong paper (often with handles); used to transport goods after shopping', 'name': 'shopping_bag'}, {'frequency': 'c', 'synset': 'shopping_cart.n.01', 'synonyms': ['shopping_cart'], 'id': 950, 'def': 'a handcart that holds groceries or other goods while shopping', 'name': 'shopping_cart'}, {'frequency': 'f', 'synset': 'short_pants.n.01', 'synonyms': ['short_pants', 'shorts_(clothing)', 'trunks_(clothing)'], 'id': 951, 'def': 'trousers that end at or above the knee', 'name': 'short_pants'}, {'frequency': 'r', 'synset': 'shot_glass.n.01', 'synonyms': ['shot_glass'], 'id': 952, 'def': 'a small glass adequate to hold a single swallow of whiskey', 'name': 'shot_glass'}, {'frequency': 'f', 'synset': 'shoulder_bag.n.01', 'synonyms': ['shoulder_bag'], 'id': 953, 'def': 'a large handbag that can be carried by a strap looped over the shoulder', 'name': 'shoulder_bag'}, {'frequency': 'c', 'synset': 'shovel.n.01', 'synonyms': ['shovel'], 'id': 954, 'def': 'a hand tool for lifting loose material such as snow, dirt, etc.', 'name': 'shovel'}, {'frequency': 'f', 'synset': 'shower.n.01', 'synonyms': ['shower_head'], 'id': 955, 'def': 'a plumbing fixture that sprays water over you', 'name': 'shower_head'}, {'frequency': 'r', 'synset': 'shower_cap.n.01', 'synonyms': ['shower_cap'], 'id': 956, 'def': 'a tight cap worn to keep hair dry while showering', 'name': 'shower_cap'}, {'frequency': 'f', 'synset': 'shower_curtain.n.01', 'synonyms': ['shower_curtain'], 'id': 957, 'def': 'a curtain that keeps water from splashing out of the shower area', 'name': 'shower_curtain'}, {'frequency': 'r', 'synset': 'shredder.n.01', 'synonyms': ['shredder_(for_paper)'], 'id': 958, 'def': 'a device that shreds documents', 'name': 'shredder_(for_paper)'}, {'frequency': 'f', 'synset': 'signboard.n.01', 'synonyms': ['signboard'], 'id': 959, 'def': 'structure displaying a board on which advertisements can be posted', 'name': 'signboard'}, {'frequency': 'c', 'synset': 'silo.n.01', 'synonyms': ['silo'], 'id': 960, 'def': 'a cylindrical tower used for storing goods', 'name': 'silo'}, {'frequency': 'f', 'synset': 'sink.n.01', 'synonyms': ['sink'], 'id': 961, 'def': 'plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe', 'name': 'sink'}, {'frequency': 'f', 'synset': 'skateboard.n.01', 'synonyms': ['skateboard'], 'id': 962, 'def': 'a board with wheels that is ridden in a standing or crouching position and propelled by foot', 'name': 'skateboard'}, {'frequency': 'c', 'synset': 'skewer.n.01', 'synonyms': ['skewer'], 'id': 963, 'def': 'a long pin for holding meat in position while it is being roasted', 'name': 'skewer'}, {'frequency': 'f', 'synset': 'ski.n.01', 'synonyms': ['ski'], 'id': 964, 'def': 'sports equipment for skiing on snow', 'name': 'ski'}, {'frequency': 'f', 'synset': 'ski_boot.n.01', 'synonyms': ['ski_boot'], 'id': 965, 'def': 'a stiff boot that is fastened to a ski with a ski binding', 'name': 'ski_boot'}, {'frequency': 'f', 'synset': 'ski_parka.n.01', 'synonyms': ['ski_parka', 'ski_jacket'], 'id': 966, 'def': 'a parka to be worn while skiing', 'name': 'ski_parka'}, {'frequency': 'f', 'synset': 'ski_pole.n.01', 'synonyms': ['ski_pole'], 'id': 967, 'def': 'a pole with metal points used as an aid in skiing', 'name': 'ski_pole'}, {'frequency': 'f', 'synset': 'skirt.n.02', 'synonyms': ['skirt'], 'id': 968, 'def': 'a garment hanging from the waist; worn mainly by girls and women', 'name': 'skirt'}, {'frequency': 'r', 'synset': 'skullcap.n.01', 'synonyms': ['skullcap'], 'id': 969, 'def': 'rounded brimless cap fitting the crown of the head', 'name': 'skullcap'}, {'frequency': 'c', 'synset': 'sled.n.01', 'synonyms': ['sled', 'sledge', 'sleigh'], 'id': 970, 'def': 'a vehicle or flat object for transportation over snow by sliding or pulled by dogs, etc.', 'name': 'sled'}, {'frequency': 'c', 'synset': 'sleeping_bag.n.01', 'synonyms': ['sleeping_bag'], 'id': 971, 'def': 'large padded bag designed to be slept in outdoors', 'name': 'sleeping_bag'}, {'frequency': 'r', 'synset': 'sling.n.05', 'synonyms': ['sling_(bandage)', 'triangular_bandage'], 'id': 972, 'def': 'bandage to support an injured forearm; slung over the shoulder or neck', 'name': 'sling_(bandage)'}, {'frequency': 'c', 'synset': 'slipper.n.01', 'synonyms': ['slipper_(footwear)', 'carpet_slipper_(footwear)'], 'id': 973, 'def': 'low footwear that can be slipped on and off easily; usually worn indoors', 'name': 'slipper_(footwear)'}, {'frequency': 'r', 'synset': 'smoothie.n.02', 'synonyms': ['smoothie'], 'id': 974, 'def': 'a thick smooth drink consisting of fresh fruit pureed with ice cream or yoghurt or milk', 'name': 'smoothie'}, {'frequency': 'r', 'synset': 'snake.n.01', 'synonyms': ['snake', 'serpent'], 'id': 975, 'def': 'limbless scaly elongate reptile; some are venomous', 'name': 'snake'}, {'frequency': 'f', 'synset': 'snowboard.n.01', 'synonyms': ['snowboard'], 'id': 976, 'def': 'a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes', 'name': 'snowboard'}, {'frequency': 'c', 'synset': 'snowman.n.01', 'synonyms': ['snowman'], 'id': 977, 'def': 'a figure of a person made of packed snow', 'name': 'snowman'}, {'frequency': 'c', 'synset': 'snowmobile.n.01', 'synonyms': ['snowmobile'], 'id': 978, 'def': 'tracked vehicle for travel on snow having skis in front', 'name': 'snowmobile'}, {'frequency': 'f', 'synset': 'soap.n.01', 'synonyms': ['soap'], 'id': 979, 'def': 'a cleansing agent made from the salts of vegetable or animal fats', 'name': 'soap'}, {'frequency': 'f', 'synset': 'soccer_ball.n.01', 'synonyms': ['soccer_ball'], 'id': 980, 'def': "an inflated ball used in playing soccer (called `football' outside of the United States)", 'name': 'soccer_ball'}, {'frequency': 'f', 'synset': 'sock.n.01', 'synonyms': ['sock'], 'id': 981, 'def': 'cloth covering for the foot; worn inside the shoe; reaches to between the ankle and the knee', 'name': 'sock'}, {'frequency': 'f', 'synset': 'sofa.n.01', 'synonyms': ['sofa', 'couch', 'lounge'], 'id': 982, 'def': 'an upholstered seat for more than one person', 'name': 'sofa'}, {'frequency': 'r', 'synset': 'softball.n.01', 'synonyms': ['softball'], 'id': 983, 'def': 'ball used in playing softball', 'name': 'softball'}, {'frequency': 'c', 'synset': 'solar_array.n.01', 'synonyms': ['solar_array', 'solar_battery', 'solar_panel'], 'id': 984, 'def': 'electrical device consisting of a large array of connected solar cells', 'name': 'solar_array'}, {'frequency': 'r', 'synset': 'sombrero.n.02', 'synonyms': ['sombrero'], 'id': 985, 'def': 'a straw hat with a tall crown and broad brim; worn in American southwest and in Mexico', 'name': 'sombrero'}, {'frequency': 'f', 'synset': 'soup.n.01', 'synonyms': ['soup'], 'id': 986, 'def': 'liquid food especially of meat or fish or vegetable stock often containing pieces of solid food', 'name': 'soup'}, {'frequency': 'r', 'synset': 'soup_bowl.n.01', 'synonyms': ['soup_bowl'], 'id': 987, 'def': 'a bowl for serving soup', 'name': 'soup_bowl'}, {'frequency': 'c', 'synset': 'soupspoon.n.01', 'synonyms': ['soupspoon'], 'id': 988, 'def': 'a spoon with a rounded bowl for eating soup', 'name': 'soupspoon'}, {'frequency': 'c', 'synset': 'sour_cream.n.01', 'synonyms': ['sour_cream', 'soured_cream'], 'id': 989, 'def': 'soured light cream', 'name': 'sour_cream'}, {'frequency': 'r', 'synset': 'soya_milk.n.01', 'synonyms': ['soya_milk', 'soybean_milk', 'soymilk'], 'id': 990, 'def': 'a milk substitute containing soybean flour and water; used in some infant formulas and in making tofu', 'name': 'soya_milk'}, {'frequency': 'r', 'synset': 'space_shuttle.n.01', 'synonyms': ['space_shuttle'], 'id': 991, 'def': "a reusable spacecraft with wings for a controlled descent through the Earth's atmosphere", 'name': 'space_shuttle'}, {'frequency': 'r', 'synset': 'sparkler.n.02', 'synonyms': ['sparkler_(fireworks)'], 'id': 992, 'def': 'a firework that burns slowly and throws out a shower of sparks', 'name': 'sparkler_(fireworks)'}, {'frequency': 'f', 'synset': 'spatula.n.02', 'synonyms': ['spatula'], 'id': 993, 'def': 'a hand tool with a thin flexible blade used to mix or spread soft substances', 'name': 'spatula'}, {'frequency': 'r', 'synset': 'spear.n.01', 'synonyms': ['spear', 'lance'], 'id': 994, 'def': 'a long pointed rod used as a tool or weapon', 'name': 'spear'}, {'frequency': 'f', 'synset': 'spectacles.n.01', 'synonyms': ['spectacles', 'specs', 'eyeglasses', 'glasses'], 'id': 995, 'def': 'optical instrument consisting of a frame that holds a pair of lenses for correcting defective vision', 'name': 'spectacles'}, {'frequency': 'c', 'synset': 'spice_rack.n.01', 'synonyms': ['spice_rack'], 'id': 996, 'def': 'a rack for displaying containers filled with spices', 'name': 'spice_rack'}, {'frequency': 'c', 'synset': 'spider.n.01', 'synonyms': ['spider'], 'id': 997, 'def': 'predatory arachnid with eight legs, two poison fangs, two feelers, and usually two silk-spinning organs at the back end of the body', 'name': 'spider'}, {'frequency': 'r', 'synset': 'spiny_lobster.n.02', 'synonyms': ['crawfish', 'crayfish'], 'id': 998, 'def': 'large edible marine crustacean having a spiny carapace but lacking the large pincers of true lobsters', 'name': 'crawfish'}, {'frequency': 'c', 'synset': 'sponge.n.01', 'synonyms': ['sponge'], 'id': 999, 'def': 'a porous mass usable to absorb water typically used for cleaning', 'name': 'sponge'}, {'frequency': 'f', 'synset': 'spoon.n.01', 'synonyms': ['spoon'], 'id': 1000, 'def': 'a piece of cutlery with a shallow bowl-shaped container and a handle', 'name': 'spoon'}, {'frequency': 'c', 'synset': 'sportswear.n.01', 'synonyms': ['sportswear', 'athletic_wear', 'activewear'], 'id': 1001, 'def': 'attire worn for sport or for casual wear', 'name': 'sportswear'}, {'frequency': 'c', 'synset': 'spotlight.n.02', 'synonyms': ['spotlight'], 'id': 1002, 'def': 'a lamp that produces a strong beam of light to illuminate a restricted area; used to focus attention of a stage performer', 'name': 'spotlight'}, {'frequency': 'r', 'synset': 'squid.n.01', 'synonyms': ['squid_(food)', 'calamari', 'calamary'], 'id': 1003, 'def': '(Italian cuisine) squid prepared as food', 'name': 'squid_(food)'}, {'frequency': 'c', 'synset': 'squirrel.n.01', 'synonyms': ['squirrel'], 'id': 1004, 'def': 'a kind of arboreal rodent having a long bushy tail', 'name': 'squirrel'}, {'frequency': 'r', 'synset': 'stagecoach.n.01', 'synonyms': ['stagecoach'], 'id': 1005, 'def': 'a large coach-and-four formerly used to carry passengers and mail on regular routes between towns', 'name': 'stagecoach'}, {'frequency': 'c', 'synset': 'stapler.n.01', 'synonyms': ['stapler_(stapling_machine)'], 'id': 1006, 'def': 'a machine that inserts staples into sheets of paper in order to fasten them together', 'name': 'stapler_(stapling_machine)'}, {'frequency': 'c', 'synset': 'starfish.n.01', 'synonyms': ['starfish', 'sea_star'], 'id': 1007, 'def': 'echinoderms characterized by five arms extending from a central disk', 'name': 'starfish'}, {'frequency': 'f', 'synset': 'statue.n.01', 'synonyms': ['statue_(sculpture)'], 'id': 1008, 'def': 'a sculpture representing a human or animal', 'name': 'statue_(sculpture)'}, {'frequency': 'c', 'synset': 'steak.n.01', 'synonyms': ['steak_(food)'], 'id': 1009, 'def': 'a slice of meat cut from the fleshy part of an animal or large fish', 'name': 'steak_(food)'}, {'frequency': 'r', 'synset': 'steak_knife.n.01', 'synonyms': ['steak_knife'], 'id': 1010, 'def': 'a sharp table knife used in eating steak', 'name': 'steak_knife'}, {'frequency': 'f', 'synset': 'steering_wheel.n.01', 'synonyms': ['steering_wheel'], 'id': 1011, 'def': 'a handwheel that is used for steering', 'name': 'steering_wheel'}, {'frequency': 'r', 'synset': 'step_ladder.n.01', 'synonyms': ['stepladder'], 'id': 1012, 'def': 'a folding portable ladder hinged at the top', 'name': 'stepladder'}, {'frequency': 'c', 'synset': 'step_stool.n.01', 'synonyms': ['step_stool'], 'id': 1013, 'def': 'a stool that has one or two steps that fold under the seat', 'name': 'step_stool'}, {'frequency': 'c', 'synset': 'stereo.n.01', 'synonyms': ['stereo_(sound_system)'], 'id': 1014, 'def': 'electronic device for playing audio', 'name': 'stereo_(sound_system)'}, {'frequency': 'r', 'synset': 'stew.n.02', 'synonyms': ['stew'], 'id': 1015, 'def': 'food prepared by stewing especially meat or fish with vegetables', 'name': 'stew'}, {'frequency': 'r', 'synset': 'stirrer.n.02', 'synonyms': ['stirrer'], 'id': 1016, 'def': 'an implement used for stirring', 'name': 'stirrer'}, {'frequency': 'f', 'synset': 'stirrup.n.01', 'synonyms': ['stirrup'], 'id': 1017, 'def': "support consisting of metal loops into which rider's feet go", 'name': 'stirrup'}, {'frequency': 'f', 'synset': 'stool.n.01', 'synonyms': ['stool'], 'id': 1018, 'def': 'a simple seat without a back or arms', 'name': 'stool'}, {'frequency': 'f', 'synset': 'stop_sign.n.01', 'synonyms': ['stop_sign'], 'id': 1019, 'def': 'a traffic sign to notify drivers that they must come to a complete stop', 'name': 'stop_sign'}, {'frequency': 'f', 'synset': 'stoplight.n.01', 'synonyms': ['brake_light'], 'id': 1020, 'def': 'a red light on the rear of a motor vehicle that signals when the brakes are applied', 'name': 'brake_light'}, {'frequency': 'f', 'synset': 'stove.n.01', 'synonyms': ['stove', 'kitchen_stove', 'range_(kitchen_appliance)', 'kitchen_range', 'cooking_stove'], 'id': 1021, 'def': 'a kitchen appliance used for cooking food', 'name': 'stove'}, {'frequency': 'c', 'synset': 'strainer.n.01', 'synonyms': ['strainer'], 'id': 1022, 'def': 'a filter to retain larger pieces while smaller pieces and liquids pass through', 'name': 'strainer'}, {'frequency': 'f', 'synset': 'strap.n.01', 'synonyms': ['strap'], 'id': 1023, 'def': 'an elongated strip of material for binding things together or holding', 'name': 'strap'}, {'frequency': 'f', 'synset': 'straw.n.04', 'synonyms': ['straw_(for_drinking)', 'drinking_straw'], 'id': 1024, 'def': 'a thin paper or plastic tube used to suck liquids into the mouth', 'name': 'straw_(for_drinking)'}, {'frequency': 'f', 'synset': 'strawberry.n.01', 'synonyms': ['strawberry'], 'id': 1025, 'def': 'sweet fleshy red fruit', 'name': 'strawberry'}, {'frequency': 'f', 'synset': 'street_sign.n.01', 'synonyms': ['street_sign'], 'id': 1026, 'def': 'a sign visible from the street', 'name': 'street_sign'}, {'frequency': 'f', 'synset': 'streetlight.n.01', 'synonyms': ['streetlight', 'street_lamp'], 'id': 1027, 'def': 'a lamp supported on a lamppost; for illuminating a street', 'name': 'streetlight'}, {'frequency': 'r', 'synset': 'string_cheese.n.01', 'synonyms': ['string_cheese'], 'id': 1028, 'def': 'cheese formed in long strings twisted together', 'name': 'string_cheese'}, {'frequency': 'r', 'synset': 'stylus.n.02', 'synonyms': ['stylus'], 'id': 1029, 'def': 'a pointed tool for writing or drawing or engraving, including pens', 'name': 'stylus'}, {'frequency': 'r', 'synset': 'subwoofer.n.01', 'synonyms': ['subwoofer'], 'id': 1030, 'def': 'a loudspeaker that is designed to reproduce very low bass frequencies', 'name': 'subwoofer'}, {'frequency': 'r', 'synset': 'sugar_bowl.n.01', 'synonyms': ['sugar_bowl'], 'id': 1031, 'def': 'a dish in which sugar is served', 'name': 'sugar_bowl'}, {'frequency': 'r', 'synset': 'sugarcane.n.01', 'synonyms': ['sugarcane_(plant)'], 'id': 1032, 'def': 'juicy canes whose sap is a source of molasses and commercial sugar; fresh canes are sometimes chewed for the juice', 'name': 'sugarcane_(plant)'}, {'frequency': 'f', 'synset': 'suit.n.01', 'synonyms': ['suit_(clothing)'], 'id': 1033, 'def': 'a set of garments (usually including a jacket and trousers or skirt) for outerwear all of the same fabric and color', 'name': 'suit_(clothing)'}, {'frequency': 'c', 'synset': 'sunflower.n.01', 'synonyms': ['sunflower'], 'id': 1034, 'def': 'any plant of the genus Helianthus having large flower heads with dark disk florets and showy yellow rays', 'name': 'sunflower'}, {'frequency': 'f', 'synset': 'sunglasses.n.01', 'synonyms': ['sunglasses'], 'id': 1035, 'def': 'spectacles that are darkened or polarized to protect the eyes from the glare of the sun', 'name': 'sunglasses'}, {'frequency': 'c', 'synset': 'sunhat.n.01', 'synonyms': ['sunhat'], 'id': 1036, 'def': 'a hat with a broad brim that protects the face from direct exposure to the sun', 'name': 'sunhat'}, {'frequency': 'f', 'synset': 'surfboard.n.01', 'synonyms': ['surfboard'], 'id': 1037, 'def': 'a narrow buoyant board for riding surf', 'name': 'surfboard'}, {'frequency': 'c', 'synset': 'sushi.n.01', 'synonyms': ['sushi'], 'id': 1038, 'def': 'rice (with raw fish) wrapped in seaweed', 'name': 'sushi'}, {'frequency': 'c', 'synset': 'swab.n.02', 'synonyms': ['mop'], 'id': 1039, 'def': 'cleaning implement consisting of absorbent material fastened to a handle; for cleaning floors', 'name': 'mop'}, {'frequency': 'c', 'synset': 'sweat_pants.n.01', 'synonyms': ['sweat_pants'], 'id': 1040, 'def': 'loose-fitting trousers with elastic cuffs; worn by athletes', 'name': 'sweat_pants'}, {'frequency': 'c', 'synset': 'sweatband.n.02', 'synonyms': ['sweatband'], 'id': 1041, 'def': 'a band of material tied around the forehead or wrist to absorb sweat', 'name': 'sweatband'}, {'frequency': 'f', 'synset': 'sweater.n.01', 'synonyms': ['sweater'], 'id': 1042, 'def': 'a crocheted or knitted garment covering the upper part of the body', 'name': 'sweater'}, {'frequency': 'f', 'synset': 'sweatshirt.n.01', 'synonyms': ['sweatshirt'], 'id': 1043, 'def': 'cotton knit pullover with long sleeves worn during athletic activity', 'name': 'sweatshirt'}, {'frequency': 'c', 'synset': 'sweet_potato.n.02', 'synonyms': ['sweet_potato'], 'id': 1044, 'def': 'the edible tuberous root of the sweet potato vine', 'name': 'sweet_potato'}, {'frequency': 'f', 'synset': 'swimsuit.n.01', 'synonyms': ['swimsuit', 'swimwear', 'bathing_suit', 'swimming_costume', 'bathing_costume', 'swimming_trunks', 'bathing_trunks'], 'id': 1045, 'def': 'garment worn for swimming', 'name': 'swimsuit'}, {'frequency': 'c', 'synset': 'sword.n.01', 'synonyms': ['sword'], 'id': 1046, 'def': 'a cutting or thrusting weapon that has a long metal blade', 'name': 'sword'}, {'frequency': 'r', 'synset': 'syringe.n.01', 'synonyms': ['syringe'], 'id': 1047, 'def': 'a medical instrument used to inject or withdraw fluids', 'name': 'syringe'}, {'frequency': 'r', 'synset': 'tabasco.n.02', 'synonyms': ['Tabasco_sauce'], 'id': 1048, 'def': 'very spicy sauce (trade name Tabasco) made from fully-aged red peppers', 'name': 'Tabasco_sauce'}, {'frequency': 'r', 'synset': 'table-tennis_table.n.01', 'synonyms': ['table-tennis_table', 'ping-pong_table'], 'id': 1049, 'def': 'a table used for playing table tennis', 'name': 'table-tennis_table'}, {'frequency': 'f', 'synset': 'table.n.02', 'synonyms': ['table'], 'id': 1050, 'def': 'a piece of furniture having a smooth flat top that is usually supported by one or more vertical legs', 'name': 'table'}, {'frequency': 'c', 'synset': 'table_lamp.n.01', 'synonyms': ['table_lamp'], 'id': 1051, 'def': 'a lamp that sits on a table', 'name': 'table_lamp'}, {'frequency': 'f', 'synset': 'tablecloth.n.01', 'synonyms': ['tablecloth'], 'id': 1052, 'def': 'a covering spread over a dining table', 'name': 'tablecloth'}, {'frequency': 'r', 'synset': 'tachometer.n.01', 'synonyms': ['tachometer'], 'id': 1053, 'def': 'measuring instrument for indicating speed of rotation', 'name': 'tachometer'}, {'frequency': 'r', 'synset': 'taco.n.02', 'synonyms': ['taco'], 'id': 1054, 'def': 'a small tortilla cupped around a filling', 'name': 'taco'}, {'frequency': 'f', 'synset': 'tag.n.02', 'synonyms': ['tag'], 'id': 1055, 'def': 'a label associated with something for the purpose of identification or information', 'name': 'tag'}, {'frequency': 'f', 'synset': 'taillight.n.01', 'synonyms': ['taillight', 'rear_light'], 'id': 1056, 'def': 'lamp (usually red) mounted at the rear of a motor vehicle', 'name': 'taillight'}, {'frequency': 'r', 'synset': 'tambourine.n.01', 'synonyms': ['tambourine'], 'id': 1057, 'def': 'a shallow drum with a single drumhead and with metallic disks in the sides', 'name': 'tambourine'}, {'frequency': 'r', 'synset': 'tank.n.01', 'synonyms': ['army_tank', 'armored_combat_vehicle', 'armoured_combat_vehicle'], 'id': 1058, 'def': 'an enclosed armored military vehicle; has a cannon and moves on caterpillar treads', 'name': 'army_tank'}, {'frequency': 'f', 'synset': 'tank.n.02', 'synonyms': ['tank_(storage_vessel)', 'storage_tank'], 'id': 1059, 'def': 'a large (usually metallic) vessel for holding gases or liquids', 'name': 'tank_(storage_vessel)'}, {'frequency': 'f', 'synset': 'tank_top.n.01', 'synonyms': ['tank_top_(clothing)'], 'id': 1060, 'def': 'a tight-fitting sleeveless shirt with wide shoulder straps and low neck and no front opening', 'name': 'tank_top_(clothing)'}, {'frequency': 'f', 'synset': 'tape.n.01', 'synonyms': ['tape_(sticky_cloth_or_paper)'], 'id': 1061, 'def': 'a long thin piece of cloth or paper as used for binding or fastening', 'name': 'tape_(sticky_cloth_or_paper)'}, {'frequency': 'c', 'synset': 'tape.n.04', 'synonyms': ['tape_measure', 'measuring_tape'], 'id': 1062, 'def': 'measuring instrument consisting of a narrow strip (cloth or metal) marked in inches or centimeters and used for measuring lengths', 'name': 'tape_measure'}, {'frequency': 'c', 'synset': 'tapestry.n.02', 'synonyms': ['tapestry'], 'id': 1063, 'def': 'a heavy textile with a woven design; used for curtains and upholstery', 'name': 'tapestry'}, {'frequency': 'f', 'synset': 'tarpaulin.n.01', 'synonyms': ['tarp'], 'id': 1064, 'def': 'waterproofed canvas', 'name': 'tarp'}, {'frequency': 'c', 'synset': 'tartan.n.01', 'synonyms': ['tartan', 'plaid'], 'id': 1065, 'def': 'a cloth having a crisscross design', 'name': 'tartan'}, {'frequency': 'c', 'synset': 'tassel.n.01', 'synonyms': ['tassel'], 'id': 1066, 'def': 'adornment consisting of a bunch of cords fastened at one end', 'name': 'tassel'}, {'frequency': 'c', 'synset': 'tea_bag.n.01', 'synonyms': ['tea_bag'], 'id': 1067, 'def': 'a measured amount of tea in a bag for an individual serving of tea', 'name': 'tea_bag'}, {'frequency': 'c', 'synset': 'teacup.n.02', 'synonyms': ['teacup'], 'id': 1068, 'def': 'a cup from which tea is drunk', 'name': 'teacup'}, {'frequency': 'c', 'synset': 'teakettle.n.01', 'synonyms': ['teakettle'], 'id': 1069, 'def': 'kettle for boiling water to make tea', 'name': 'teakettle'}, {'frequency': 'f', 'synset': 'teapot.n.01', 'synonyms': ['teapot'], 'id': 1070, 'def': 'pot for brewing tea; usually has a spout and handle', 'name': 'teapot'}, {'frequency': 'f', 'synset': 'teddy.n.01', 'synonyms': ['teddy_bear'], 'id': 1071, 'def': "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", 'name': 'teddy_bear'}, {'frequency': 'f', 'synset': 'telephone.n.01', 'synonyms': ['telephone', 'phone', 'telephone_set'], 'id': 1072, 'def': 'electronic device for communicating by voice over long distances (includes wired and wireless/cell phones)', 'name': 'telephone'}, {'frequency': 'c', 'synset': 'telephone_booth.n.01', 'synonyms': ['telephone_booth', 'phone_booth', 'call_box', 'telephone_box', 'telephone_kiosk'], 'id': 1073, 'def': 'booth for using a telephone', 'name': 'telephone_booth'}, {'frequency': 'f', 'synset': 'telephone_pole.n.01', 'synonyms': ['telephone_pole', 'telegraph_pole', 'telegraph_post'], 'id': 1074, 'def': 'tall pole supporting telephone wires', 'name': 'telephone_pole'}, {'frequency': 'r', 'synset': 'telephoto_lens.n.01', 'synonyms': ['telephoto_lens', 'zoom_lens'], 'id': 1075, 'def': 'a camera lens that magnifies the image', 'name': 'telephoto_lens'}, {'frequency': 'c', 'synset': 'television_camera.n.01', 'synonyms': ['television_camera', 'tv_camera'], 'id': 1076, 'def': 'television equipment for capturing and recording video', 'name': 'television_camera'}, {'frequency': 'f', 'synset': 'television_receiver.n.01', 'synonyms': ['television_set', 'tv', 'tv_set'], 'id': 1077, 'def': 'an electronic device that receives television signals and displays them on a screen', 'name': 'television_set'}, {'frequency': 'f', 'synset': 'tennis_ball.n.01', 'synonyms': ['tennis_ball'], 'id': 1078, 'def': 'ball about the size of a fist used in playing tennis', 'name': 'tennis_ball'}, {'frequency': 'f', 'synset': 'tennis_racket.n.01', 'synonyms': ['tennis_racket'], 'id': 1079, 'def': 'a racket used to play tennis', 'name': 'tennis_racket'}, {'frequency': 'r', 'synset': 'tequila.n.01', 'synonyms': ['tequila'], 'id': 1080, 'def': 'Mexican liquor made from fermented juices of an agave plant', 'name': 'tequila'}, {'frequency': 'c', 'synset': 'thermometer.n.01', 'synonyms': ['thermometer'], 'id': 1081, 'def': 'measuring instrument for measuring temperature', 'name': 'thermometer'}, {'frequency': 'c', 'synset': 'thermos.n.01', 'synonyms': ['thermos_bottle'], 'id': 1082, 'def': 'vacuum flask that preserves temperature of hot or cold drinks', 'name': 'thermos_bottle'}, {'frequency': 'f', 'synset': 'thermostat.n.01', 'synonyms': ['thermostat'], 'id': 1083, 'def': 'a regulator for automatically regulating temperature by starting or stopping the supply of heat', 'name': 'thermostat'}, {'frequency': 'r', 'synset': 'thimble.n.02', 'synonyms': ['thimble'], 'id': 1084, 'def': 'a small metal cap to protect the finger while sewing; can be used as a small container', 'name': 'thimble'}, {'frequency': 'c', 'synset': 'thread.n.01', 'synonyms': ['thread', 'yarn'], 'id': 1085, 'def': 'a fine cord of twisted fibers (of cotton or silk or wool or nylon etc.) used in sewing and weaving', 'name': 'thread'}, {'frequency': 'c', 'synset': 'thumbtack.n.01', 'synonyms': ['thumbtack', 'drawing_pin', 'pushpin'], 'id': 1086, 'def': 'a tack for attaching papers to a bulletin board or drawing board', 'name': 'thumbtack'}, {'frequency': 'c', 'synset': 'tiara.n.01', 'synonyms': ['tiara'], 'id': 1087, 'def': 'a jeweled headdress worn by women on formal occasions', 'name': 'tiara'}, {'frequency': 'c', 'synset': 'tiger.n.02', 'synonyms': ['tiger'], 'id': 1088, 'def': 'large feline of forests in most of Asia having a tawny coat with black stripes', 'name': 'tiger'}, {'frequency': 'c', 'synset': 'tights.n.01', 'synonyms': ['tights_(clothing)', 'leotards'], 'id': 1089, 'def': 'skintight knit hose covering the body from the waist to the feet worn by acrobats and dancers and as stockings by women and girls', 'name': 'tights_(clothing)'}, {'frequency': 'c', 'synset': 'timer.n.01', 'synonyms': ['timer', 'stopwatch'], 'id': 1090, 'def': 'a timepiece that measures a time interval and signals its end', 'name': 'timer'}, {'frequency': 'f', 'synset': 'tinfoil.n.01', 'synonyms': ['tinfoil'], 'id': 1091, 'def': 'foil made of tin or an alloy of tin and lead', 'name': 'tinfoil'}, {'frequency': 'c', 'synset': 'tinsel.n.01', 'synonyms': ['tinsel'], 'id': 1092, 'def': 'a showy decoration that is basically valueless', 'name': 'tinsel'}, {'frequency': 'f', 'synset': 'tissue.n.02', 'synonyms': ['tissue_paper'], 'id': 1093, 'def': 'a soft thin (usually translucent) paper', 'name': 'tissue_paper'}, {'frequency': 'c', 'synset': 'toast.n.01', 'synonyms': ['toast_(food)'], 'id': 1094, 'def': 'slice of bread that has been toasted', 'name': 'toast_(food)'}, {'frequency': 'f', 'synset': 'toaster.n.02', 'synonyms': ['toaster'], 'id': 1095, 'def': 'a kitchen appliance (usually electric) for toasting bread', 'name': 'toaster'}, {'frequency': 'f', 'synset': 'toaster_oven.n.01', 'synonyms': ['toaster_oven'], 'id': 1096, 'def': 'kitchen appliance consisting of a small electric oven for toasting or warming food', 'name': 'toaster_oven'}, {'frequency': 'f', 'synset': 'toilet.n.02', 'synonyms': ['toilet'], 'id': 1097, 'def': 'a plumbing fixture for defecation and urination', 'name': 'toilet'}, {'frequency': 'f', 'synset': 'toilet_tissue.n.01', 'synonyms': ['toilet_tissue', 'toilet_paper', 'bathroom_tissue'], 'id': 1098, 'def': 'a soft thin absorbent paper for use in toilets', 'name': 'toilet_tissue'}, {'frequency': 'f', 'synset': 'tomato.n.01', 'synonyms': ['tomato'], 'id': 1099, 'def': 'mildly acid red or yellow pulpy fruit eaten as a vegetable', 'name': 'tomato'}, {'frequency': 'f', 'synset': 'tongs.n.01', 'synonyms': ['tongs'], 'id': 1100, 'def': 'any of various devices for taking hold of objects; usually have two hinged legs with handles above and pointed hooks below', 'name': 'tongs'}, {'frequency': 'c', 'synset': 'toolbox.n.01', 'synonyms': ['toolbox'], 'id': 1101, 'def': 'a box or chest or cabinet for holding hand tools', 'name': 'toolbox'}, {'frequency': 'f', 'synset': 'toothbrush.n.01', 'synonyms': ['toothbrush'], 'id': 1102, 'def': 'small brush; has long handle; used to clean teeth', 'name': 'toothbrush'}, {'frequency': 'f', 'synset': 'toothpaste.n.01', 'synonyms': ['toothpaste'], 'id': 1103, 'def': 'a dentifrice in the form of a paste', 'name': 'toothpaste'}, {'frequency': 'f', 'synset': 'toothpick.n.01', 'synonyms': ['toothpick'], 'id': 1104, 'def': 'pick consisting of a small strip of wood or plastic; used to pick food from between the teeth', 'name': 'toothpick'}, {'frequency': 'f', 'synset': 'top.n.09', 'synonyms': ['cover'], 'id': 1105, 'def': 'covering for a hole (especially a hole in the top of a container)', 'name': 'cover'}, {'frequency': 'c', 'synset': 'tortilla.n.01', 'synonyms': ['tortilla'], 'id': 1106, 'def': 'thin unleavened pancake made from cornmeal or wheat flour', 'name': 'tortilla'}, {'frequency': 'c', 'synset': 'tow_truck.n.01', 'synonyms': ['tow_truck'], 'id': 1107, 'def': 'a truck equipped to hoist and pull wrecked cars (or to remove cars from no-parking zones)', 'name': 'tow_truck'}, {'frequency': 'f', 'synset': 'towel.n.01', 'synonyms': ['towel'], 'id': 1108, 'def': 'a rectangular piece of absorbent cloth (or paper) for drying or wiping', 'name': 'towel'}, {'frequency': 'f', 'synset': 'towel_rack.n.01', 'synonyms': ['towel_rack', 'towel_rail', 'towel_bar'], 'id': 1109, 'def': 'a rack consisting of one or more bars on which towels can be hung', 'name': 'towel_rack'}, {'frequency': 'f', 'synset': 'toy.n.03', 'synonyms': ['toy'], 'id': 1110, 'def': 'a device regarded as providing amusement', 'name': 'toy'}, {'frequency': 'c', 'synset': 'tractor.n.01', 'synonyms': ['tractor_(farm_equipment)'], 'id': 1111, 'def': 'a wheeled vehicle with large wheels; used in farming and other applications', 'name': 'tractor_(farm_equipment)'}, {'frequency': 'f', 'synset': 'traffic_light.n.01', 'synonyms': ['traffic_light'], 'id': 1112, 'def': 'a device to control vehicle traffic often consisting of three or more lights', 'name': 'traffic_light'}, {'frequency': 'c', 'synset': 'trail_bike.n.01', 'synonyms': ['dirt_bike'], 'id': 1113, 'def': 'a lightweight motorcycle equipped with rugged tires and suspension for off-road use', 'name': 'dirt_bike'}, {'frequency': 'f', 'synset': 'trailer_truck.n.01', 'synonyms': ['trailer_truck', 'tractor_trailer', 'trucking_rig', 'articulated_lorry', 'semi_truck'], 'id': 1114, 'def': 'a truck consisting of a tractor and trailer together', 'name': 'trailer_truck'}, {'frequency': 'f', 'synset': 'train.n.01', 'synonyms': ['train_(railroad_vehicle)', 'railroad_train'], 'id': 1115, 'def': 'public or private transport provided by a line of railway cars coupled together and drawn by a locomotive', 'name': 'train_(railroad_vehicle)'}, {'frequency': 'r', 'synset': 'trampoline.n.01', 'synonyms': ['trampoline'], 'id': 1116, 'def': 'gymnastic apparatus consisting of a strong canvas sheet attached with springs to a metal frame', 'name': 'trampoline'}, {'frequency': 'f', 'synset': 'tray.n.01', 'synonyms': ['tray'], 'id': 1117, 'def': 'an open receptacle for holding or displaying or serving articles or food', 'name': 'tray'}, {'frequency': 'r', 'synset': 'trench_coat.n.01', 'synonyms': ['trench_coat'], 'id': 1118, 'def': 'a military style raincoat; belted with deep pockets', 'name': 'trench_coat'}, {'frequency': 'r', 'synset': 'triangle.n.05', 'synonyms': ['triangle_(musical_instrument)'], 'id': 1119, 'def': 'a percussion instrument consisting of a metal bar bent in the shape of an open triangle', 'name': 'triangle_(musical_instrument)'}, {'frequency': 'c', 'synset': 'tricycle.n.01', 'synonyms': ['tricycle'], 'id': 1120, 'def': 'a vehicle with three wheels that is moved by foot pedals', 'name': 'tricycle'}, {'frequency': 'f', 'synset': 'tripod.n.01', 'synonyms': ['tripod'], 'id': 1121, 'def': 'a three-legged rack used for support', 'name': 'tripod'}, {'frequency': 'f', 'synset': 'trouser.n.01', 'synonyms': ['trousers', 'pants_(clothing)'], 'id': 1122, 'def': 'a garment extending from the waist to the knee or ankle, covering each leg separately', 'name': 'trousers'}, {'frequency': 'f', 'synset': 'truck.n.01', 'synonyms': ['truck'], 'id': 1123, 'def': 'an automotive vehicle suitable for hauling', 'name': 'truck'}, {'frequency': 'r', 'synset': 'truffle.n.03', 'synonyms': ['truffle_(chocolate)', 'chocolate_truffle'], 'id': 1124, 'def': 'creamy chocolate candy', 'name': 'truffle_(chocolate)'}, {'frequency': 'c', 'synset': 'trunk.n.02', 'synonyms': ['trunk'], 'id': 1125, 'def': 'luggage consisting of a large strong case used when traveling or for storage', 'name': 'trunk'}, {'frequency': 'r', 'synset': 'tub.n.02', 'synonyms': ['vat'], 'id': 1126, 'def': 'a large vessel for holding or storing liquids', 'name': 'vat'}, {'frequency': 'c', 'synset': 'turban.n.01', 'synonyms': ['turban'], 'id': 1127, 'def': 'a traditional headdress consisting of a long scarf wrapped around the head', 'name': 'turban'}, {'frequency': 'c', 'synset': 'turkey.n.04', 'synonyms': ['turkey_(food)'], 'id': 1128, 'def': 'flesh of large domesticated fowl usually roasted', 'name': 'turkey_(food)'}, {'frequency': 'r', 'synset': 'turnip.n.01', 'synonyms': ['turnip'], 'id': 1129, 'def': 'widely cultivated plant having a large fleshy edible white or yellow root', 'name': 'turnip'}, {'frequency': 'c', 'synset': 'turtle.n.02', 'synonyms': ['turtle'], 'id': 1130, 'def': 'any of various aquatic and land reptiles having a bony shell and flipper-like limbs for swimming', 'name': 'turtle'}, {'frequency': 'c', 'synset': 'turtleneck.n.01', 'synonyms': ['turtleneck_(clothing)', 'polo-neck'], 'id': 1131, 'def': 'a sweater or jersey with a high close-fitting collar', 'name': 'turtleneck_(clothing)'}, {'frequency': 'c', 'synset': 'typewriter.n.01', 'synonyms': ['typewriter'], 'id': 1132, 'def': 'hand-operated character printer for printing written messages one character at a time', 'name': 'typewriter'}, {'frequency': 'f', 'synset': 'umbrella.n.01', 'synonyms': ['umbrella'], 'id': 1133, 'def': 'a lightweight handheld collapsible canopy', 'name': 'umbrella'}, {'frequency': 'f', 'synset': 'underwear.n.01', 'synonyms': ['underwear', 'underclothes', 'underclothing', 'underpants'], 'id': 1134, 'def': 'undergarment worn next to the skin and under the outer garments', 'name': 'underwear'}, {'frequency': 'r', 'synset': 'unicycle.n.01', 'synonyms': ['unicycle'], 'id': 1135, 'def': 'a vehicle with a single wheel that is driven by pedals', 'name': 'unicycle'}, {'frequency': 'f', 'synset': 'urinal.n.01', 'synonyms': ['urinal'], 'id': 1136, 'def': 'a plumbing fixture (usually attached to the wall) used by men to urinate', 'name': 'urinal'}, {'frequency': 'c', 'synset': 'urn.n.01', 'synonyms': ['urn'], 'id': 1137, 'def': 'a large vase that usually has a pedestal or feet', 'name': 'urn'}, {'frequency': 'c', 'synset': 'vacuum.n.04', 'synonyms': ['vacuum_cleaner'], 'id': 1138, 'def': 'an electrical home appliance that cleans by suction', 'name': 'vacuum_cleaner'}, {'frequency': 'f', 'synset': 'vase.n.01', 'synonyms': ['vase'], 'id': 1139, 'def': 'an open jar of glass or porcelain used as an ornament or to hold flowers', 'name': 'vase'}, {'frequency': 'c', 'synset': 'vending_machine.n.01', 'synonyms': ['vending_machine'], 'id': 1140, 'def': 'a slot machine for selling goods', 'name': 'vending_machine'}, {'frequency': 'f', 'synset': 'vent.n.01', 'synonyms': ['vent', 'blowhole', 'air_vent'], 'id': 1141, 'def': 'a hole for the escape of gas or air', 'name': 'vent'}, {'frequency': 'f', 'synset': 'vest.n.01', 'synonyms': ['vest', 'waistcoat'], 'id': 1142, 'def': "a man's sleeveless garment worn underneath a coat", 'name': 'vest'}, {'frequency': 'c', 'synset': 'videotape.n.01', 'synonyms': ['videotape'], 'id': 1143, 'def': 'a video recording made on magnetic tape', 'name': 'videotape'}, {'frequency': 'r', 'synset': 'vinegar.n.01', 'synonyms': ['vinegar'], 'id': 1144, 'def': 'sour-tasting liquid produced usually by oxidation of the alcohol in wine or cider and used as a condiment or food preservative', 'name': 'vinegar'}, {'frequency': 'r', 'synset': 'violin.n.01', 'synonyms': ['violin', 'fiddle'], 'id': 1145, 'def': 'bowed stringed instrument that is the highest member of the violin family', 'name': 'violin'}, {'frequency': 'r', 'synset': 'vodka.n.01', 'synonyms': ['vodka'], 'id': 1146, 'def': 'unaged colorless liquor originating in Russia', 'name': 'vodka'}, {'frequency': 'c', 'synset': 'volleyball.n.02', 'synonyms': ['volleyball'], 'id': 1147, 'def': 'an inflated ball used in playing volleyball', 'name': 'volleyball'}, {'frequency': 'r', 'synset': 'vulture.n.01', 'synonyms': ['vulture'], 'id': 1148, 'def': 'any of various large birds of prey having naked heads and weak claws and feeding chiefly on carrion', 'name': 'vulture'}, {'frequency': 'c', 'synset': 'waffle.n.01', 'synonyms': ['waffle'], 'id': 1149, 'def': 'pancake batter baked in a waffle iron', 'name': 'waffle'}, {'frequency': 'r', 'synset': 'waffle_iron.n.01', 'synonyms': ['waffle_iron'], 'id': 1150, 'def': 'a kitchen appliance for baking waffles', 'name': 'waffle_iron'}, {'frequency': 'c', 'synset': 'wagon.n.01', 'synonyms': ['wagon'], 'id': 1151, 'def': 'any of various kinds of wheeled vehicles drawn by an animal or a tractor', 'name': 'wagon'}, {'frequency': 'c', 'synset': 'wagon_wheel.n.01', 'synonyms': ['wagon_wheel'], 'id': 1152, 'def': 'a wheel of a wagon', 'name': 'wagon_wheel'}, {'frequency': 'c', 'synset': 'walking_stick.n.01', 'synonyms': ['walking_stick'], 'id': 1153, 'def': 'a stick carried in the hand for support in walking', 'name': 'walking_stick'}, {'frequency': 'c', 'synset': 'wall_clock.n.01', 'synonyms': ['wall_clock'], 'id': 1154, 'def': 'a clock mounted on a wall', 'name': 'wall_clock'}, {'frequency': 'f', 'synset': 'wall_socket.n.01', 'synonyms': ['wall_socket', 'wall_plug', 'electric_outlet', 'electrical_outlet', 'outlet', 'electric_receptacle'], 'id': 1155, 'def': 'receptacle providing a place in a wiring system where current can be taken to run electrical devices', 'name': 'wall_socket'}, {'frequency': 'f', 'synset': 'wallet.n.01', 'synonyms': ['wallet', 'billfold'], 'id': 1156, 'def': 'a pocket-size case for holding papers and paper money', 'name': 'wallet'}, {'frequency': 'r', 'synset': 'walrus.n.01', 'synonyms': ['walrus'], 'id': 1157, 'def': 'either of two large northern marine mammals having ivory tusks and tough hide over thick blubber', 'name': 'walrus'}, {'frequency': 'r', 'synset': 'wardrobe.n.01', 'synonyms': ['wardrobe'], 'id': 1158, 'def': 'a tall piece of furniture that provides storage space for clothes; has a door and rails or hooks for hanging clothes', 'name': 'wardrobe'}, {'frequency': 'r', 'synset': 'washbasin.n.01', 'synonyms': ['washbasin', 'basin_(for_washing)', 'washbowl', 'washstand', 'handbasin'], 'id': 1159, 'def': 'a bathroom sink that is permanently installed and connected to a water supply and drainpipe; where you can wash your hands and face', 'name': 'washbasin'}, {'frequency': 'c', 'synset': 'washer.n.03', 'synonyms': ['automatic_washer', 'washing_machine'], 'id': 1160, 'def': 'a home appliance for washing clothes and linens automatically', 'name': 'automatic_washer'}, {'frequency': 'f', 'synset': 'watch.n.01', 'synonyms': ['watch', 'wristwatch'], 'id': 1161, 'def': 'a small, portable timepiece', 'name': 'watch'}, {'frequency': 'f', 'synset': 'water_bottle.n.01', 'synonyms': ['water_bottle'], 'id': 1162, 'def': 'a bottle for holding water', 'name': 'water_bottle'}, {'frequency': 'c', 'synset': 'water_cooler.n.01', 'synonyms': ['water_cooler'], 'id': 1163, 'def': 'a device for cooling and dispensing drinking water', 'name': 'water_cooler'}, {'frequency': 'c', 'synset': 'water_faucet.n.01', 'synonyms': ['water_faucet', 'water_tap', 'tap_(water_faucet)'], 'id': 1164, 'def': 'a faucet for drawing water from a pipe or cask', 'name': 'water_faucet'}, {'frequency': 'r', 'synset': 'water_heater.n.01', 'synonyms': ['water_heater', 'hot-water_heater'], 'id': 1165, 'def': 'a heater and storage tank to supply heated water', 'name': 'water_heater'}, {'frequency': 'c', 'synset': 'water_jug.n.01', 'synonyms': ['water_jug'], 'id': 1166, 'def': 'a jug that holds water', 'name': 'water_jug'}, {'frequency': 'r', 'synset': 'water_pistol.n.01', 'synonyms': ['water_gun', 'squirt_gun'], 'id': 1167, 'def': 'plaything consisting of a toy pistol that squirts water', 'name': 'water_gun'}, {'frequency': 'c', 'synset': 'water_scooter.n.01', 'synonyms': ['water_scooter', 'sea_scooter', 'jet_ski'], 'id': 1168, 'def': 'a motorboat resembling a motor scooter (NOT A SURFBOARD OR WATER SKI)', 'name': 'water_scooter'}, {'frequency': 'c', 'synset': 'water_ski.n.01', 'synonyms': ['water_ski'], 'id': 1169, 'def': 'broad ski for skimming over water towed by a speedboat (DO NOT MARK WATER)', 'name': 'water_ski'}, {'frequency': 'c', 'synset': 'water_tower.n.01', 'synonyms': ['water_tower'], 'id': 1170, 'def': 'a large reservoir for water', 'name': 'water_tower'}, {'frequency': 'c', 'synset': 'watering_can.n.01', 'synonyms': ['watering_can'], 'id': 1171, 'def': 'a container with a handle and a spout with a perforated nozzle; used to sprinkle water over plants', 'name': 'watering_can'}, {'frequency': 'f', 'synset': 'watermelon.n.02', 'synonyms': ['watermelon'], 'id': 1172, 'def': 'large oblong or roundish melon with a hard green rind and sweet watery red or occasionally yellowish pulp', 'name': 'watermelon'}, {'frequency': 'f', 'synset': 'weathervane.n.01', 'synonyms': ['weathervane', 'vane_(weathervane)', 'wind_vane'], 'id': 1173, 'def': 'mechanical device attached to an elevated structure; rotates freely to show the direction of the wind', 'name': 'weathervane'}, {'frequency': 'c', 'synset': 'webcam.n.01', 'synonyms': ['webcam'], 'id': 1174, 'def': 'a digital camera designed to take digital photographs and transmit them over the internet', 'name': 'webcam'}, {'frequency': 'c', 'synset': 'wedding_cake.n.01', 'synonyms': ['wedding_cake', 'bridecake'], 'id': 1175, 'def': 'a rich cake with two or more tiers and covered with frosting and decorations; served at a wedding reception', 'name': 'wedding_cake'}, {'frequency': 'c', 'synset': 'wedding_ring.n.01', 'synonyms': ['wedding_ring', 'wedding_band'], 'id': 1176, 'def': 'a ring given to the bride and/or groom at the wedding', 'name': 'wedding_ring'}, {'frequency': 'f', 'synset': 'wet_suit.n.01', 'synonyms': ['wet_suit'], 'id': 1177, 'def': 'a close-fitting garment made of a permeable material; worn in cold water to retain body heat', 'name': 'wet_suit'}, {'frequency': 'f', 'synset': 'wheel.n.01', 'synonyms': ['wheel'], 'id': 1178, 'def': 'a circular frame with spokes (or a solid disc) that can rotate on a shaft or axle', 'name': 'wheel'}, {'frequency': 'c', 'synset': 'wheelchair.n.01', 'synonyms': ['wheelchair'], 'id': 1179, 'def': 'a movable chair mounted on large wheels', 'name': 'wheelchair'}, {'frequency': 'c', 'synset': 'whipped_cream.n.01', 'synonyms': ['whipped_cream'], 'id': 1180, 'def': 'cream that has been beaten until light and fluffy', 'name': 'whipped_cream'}, {'frequency': 'c', 'synset': 'whistle.n.03', 'synonyms': ['whistle'], 'id': 1181, 'def': 'a small wind instrument that produces a whistling sound by blowing into it', 'name': 'whistle'}, {'frequency': 'c', 'synset': 'wig.n.01', 'synonyms': ['wig'], 'id': 1182, 'def': 'hairpiece covering the head and made of real or synthetic hair', 'name': 'wig'}, {'frequency': 'c', 'synset': 'wind_chime.n.01', 'synonyms': ['wind_chime'], 'id': 1183, 'def': 'a decorative arrangement of pieces of metal or glass or pottery that hang together loosely so the wind can cause them to tinkle', 'name': 'wind_chime'}, {'frequency': 'c', 'synset': 'windmill.n.01', 'synonyms': ['windmill'], 'id': 1184, 'def': 'A mill or turbine that is powered by wind', 'name': 'windmill'}, {'frequency': 'c', 'synset': 'window_box.n.01', 'synonyms': ['window_box_(for_plants)'], 'id': 1185, 'def': 'a container for growing plants on a windowsill', 'name': 'window_box_(for_plants)'}, {'frequency': 'f', 'synset': 'windshield_wiper.n.01', 'synonyms': ['windshield_wiper', 'windscreen_wiper', 'wiper_(for_windshield/screen)'], 'id': 1186, 'def': 'a mechanical device that cleans the windshield', 'name': 'windshield_wiper'}, {'frequency': 'c', 'synset': 'windsock.n.01', 'synonyms': ['windsock', 'air_sock', 'air-sleeve', 'wind_sleeve', 'wind_cone'], 'id': 1187, 'def': 'a truncated cloth cone mounted on a mast/pole; shows wind direction', 'name': 'windsock'}, {'frequency': 'f', 'synset': 'wine_bottle.n.01', 'synonyms': ['wine_bottle'], 'id': 1188, 'def': 'a bottle for holding wine', 'name': 'wine_bottle'}, {'frequency': 'c', 'synset': 'wine_bucket.n.01', 'synonyms': ['wine_bucket', 'wine_cooler'], 'id': 1189, 'def': 'a bucket of ice used to chill a bottle of wine', 'name': 'wine_bucket'}, {'frequency': 'f', 'synset': 'wineglass.n.01', 'synonyms': ['wineglass'], 'id': 1190, 'def': 'a glass that has a stem and in which wine is served', 'name': 'wineglass'}, {'frequency': 'f', 'synset': 'winker.n.02', 'synonyms': ['blinder_(for_horses)'], 'id': 1191, 'def': 'blinds that prevent a horse from seeing something on either side', 'name': 'blinder_(for_horses)'}, {'frequency': 'c', 'synset': 'wok.n.01', 'synonyms': ['wok'], 'id': 1192, 'def': 'pan with a convex bottom; used for frying in Chinese cooking', 'name': 'wok'}, {'frequency': 'r', 'synset': 'wolf.n.01', 'synonyms': ['wolf'], 'id': 1193, 'def': 'a wild carnivorous mammal of the dog family, living and hunting in packs', 'name': 'wolf'}, {'frequency': 'c', 'synset': 'wooden_spoon.n.02', 'synonyms': ['wooden_spoon'], 'id': 1194, 'def': 'a spoon made of wood', 'name': 'wooden_spoon'}, {'frequency': 'c', 'synset': 'wreath.n.01', 'synonyms': ['wreath'], 'id': 1195, 'def': 'an arrangement of flowers, leaves, or stems fastened in a ring', 'name': 'wreath'}, {'frequency': 'c', 'synset': 'wrench.n.03', 'synonyms': ['wrench', 'spanner'], 'id': 1196, 'def': 'a hand tool that is used to hold or twist a nut or bolt', 'name': 'wrench'}, {'frequency': 'f', 'synset': 'wristband.n.01', 'synonyms': ['wristband'], 'id': 1197, 'def': 'band consisting of a part of a sleeve that covers the wrist', 'name': 'wristband'}, {'frequency': 'f', 'synset': 'wristlet.n.01', 'synonyms': ['wristlet', 'wrist_band'], 'id': 1198, 'def': 'a band or bracelet worn around the wrist', 'name': 'wristlet'}, {'frequency': 'c', 'synset': 'yacht.n.01', 'synonyms': ['yacht'], 'id': 1199, 'def': 'an expensive vessel propelled by sail or power and used for cruising or racing', 'name': 'yacht'}, {'frequency': 'c', 'synset': 'yogurt.n.01', 'synonyms': ['yogurt', 'yoghurt', 'yoghourt'], 'id': 1200, 'def': 'a custard-like food made from curdled milk', 'name': 'yogurt'}, {'frequency': 'c', 'synset': 'yoke.n.07', 'synonyms': ['yoke_(animal_equipment)'], 'id': 1201, 'def': 'gear joining two animals at the neck; NOT egg yolk', 'name': 'yoke_(animal_equipment)'}, {'frequency': 'f', 'synset': 'zebra.n.01', 'synonyms': ['zebra'], 'id': 1202, 'def': 'any of several fleet black-and-white striped African equines', 'name': 'zebra'}, {'frequency': 'c', 'synset': 'zucchini.n.02', 'synonyms': ['zucchini', 'courgette'], 'id': 1203, 'def': 'small cucumber-shaped vegetable marrow; typically dark green', 'name': 'zucchini'}] # noqa -# fmt: on diff --git a/src/regionclip/loader.py b/src/regionclip/loader.py deleted file mode 100644 index 4a8865f6c29e71b484e62befc915cb68720ec678..0000000000000000000000000000000000000000 --- a/src/regionclip/loader.py +++ /dev/null @@ -1,232 +0,0 @@ -from .clip_backbone import CLIP, convert_weights -import yaml -import os -import torch - -""" -Use the method load_regionclip_from_checkpoint to load a RegionCLIP model from a checkpoint file. -This function will automatically handle the conversion of RegionCLIP-specific state_dict keys to the standard CLIP format. -It also allows you to specify a configuration file to set parameters like out_features and freeze_at. -""" - - -def load_regionclip_config(config_name): - """ - Load RegionCLIP configuration from YAML file. - - Args: - config_name (str): Name of the YAML configuration file (from the regionclip/configs directory) - - Returns: - dict: Configuration dictionary - """ - config_path = os.path.join(os.path.dirname(__file__), 'configs', config_name) - - if not os.path.exists(config_path): - raise FileNotFoundError(f"Configuration file not found: {config_path}") - - #print(f"Loading RegionCLIP config from: {config_path}") - - with open(config_path, 'r') as f: - config = yaml.safe_load(f) - - #print(f"Successfully loaded configuration") - #print(f" - Model architecture: {config.get('MODEL', {}).get('META_ARCHITECTURE', 'Unknown')}") - #print(f" - Backbone: {config.get('MODEL', {}).get('BACKBONE', {}).get('NAME', 'Unknown')}") - #print(f" - Freeze at: {config.get('MODEL', {}).get('BACKBONE', {}).get('FREEZE_AT', 'Unknown')}") - - return config - -def load_regionclip_from_checkpoint(checkpoint_path, device='cpu', config=None, override_config=None): - """ - Load CLIP model from a checkpoint file using build_model function. - - Args: - checkpoint_path (str): Path to the .pth checkpoint file - device (str): Device to load the model on ('cpu', 'cuda', etc.) - config (dict | str): RegionCLIP configuration dictionary from YAML file or name of the config file. - If a string is provided, it will be loaded using load_regionclip_config. - override_config (dict): Optional dictionary to override specific configuration parameters. - - Returns: - CLIP model with loaded weights - """ - - if isinstance(config, str): - # If config is a string, load it from the YAML file - config = load_regionclip_config(config) - - if override_config: - # Override specific configuration parameters if provided - if config is None: - config = {} - # handle case of nested dictionaries - def recursive_update(d, u): - for k, v in u.items(): - if isinstance(v, dict) and k in d: - recursive_update(d[k], v) - else: - d[k] = v - recursive_update(config, override_config) - - #print(f"Loading checkpoint from: {checkpoint_path}") - - if not os.path.exists(checkpoint_path): - raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}") - - # Load the checkpoint - checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) - - # Extract state_dict if it's wrapped in a checkpoint structure - if 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] - #print("Found 'state_dict' in checkpoint") - elif 'model' in checkpoint: - state_dict = checkpoint['model'] - #print("Found 'model' in checkpoint") - else: - # Assume the checkpoint is directly the state_dict - state_dict = checkpoint - #print("Using checkpoint as state_dict directly") - - # Convert RegionCLIP format to standard CLIP format if needed - if any(k.startswith('lang_encoder.') or k.startswith('backbone.') for k in state_dict.keys()): - #print("Converting RegionCLIP format to standard CLIP format...") - converted_state_dict = {} - - for key, value in state_dict.items(): - if key.startswith('lang_encoder.'): - # Remove lang_encoder prefix for text encoder - new_key = key.replace('lang_encoder.', '') - converted_state_dict[new_key] = value - elif key.startswith('backbone.'): - # Convert backbone to visual - new_key = key.replace('backbone.', 'visual.') - converted_state_dict[new_key] = value - - clip_state_dict = converted_state_dict - #print(f"Extracted {len(clip_state_dict)} CLIP-specific parameters") - else: - # Filter to only CLIP-related keys - clip_keys = [k for k in state_dict.keys() if any(clip_prefix in k for clip_prefix in [ - 'visual.', 'transformer.', 'token_embedding', 'positional_embedding', - 'ln_final', 'text_projection', 'logit_scale' - ])] - - if clip_keys: - clip_state_dict = {k: state_dict[k] for k in clip_keys} - #print(f"Extracted {len(clip_keys)} CLIP-specific parameters") - else: - # This checkpoint doesn't contain standard CLIP weights - print("No CLIP weights found in this checkpoint") - raise ValueError("No CLIP weights found in checkpoint") - - # Add missing logit_scale if not present - if 'logit_scale' not in clip_state_dict: - import numpy as np - print("Adding missing logit_scale parameter") - clip_state_dict['logit_scale'] = torch.ones([]) * np.log(1 / 0.07) - #self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) - - # Use build_model function directly with custom wrapper - try: - # Create a custom build_model that provides the missing parameters - model = build_model_with_defaults(clip_state_dict, config) - model.to(device) - #print("Successfully created model using build_model()") - return model - except Exception as e: - print(f"build_model() failed: {e}") - raise e - -def build_model_with_defaults(state_dict, config=None): - """ - Wrapper around build_model that provides the required out_features and freeze_at parameters - - Args: - state_dict: Model state dictionary - config: RegionCLIP configuration dictionary from YAML file - """ - - # Get configuration parameters - if config: - model_config = config.get('MODEL', {}) - resnets_config = model_config.get('RESNETS', {}) - backbone_config = model_config.get('BACKBONE', {}) - - # Extract configuration values - out_features = resnets_config.get('OUT_FEATURES', ['res4']) - freeze_at = backbone_config.get('FREEZE_AT', 0) - - depth = resnets_config.get('DEPTH', None) # Optional depth parameter - - image_resolution = resnets_config.get('IMAGE_RESOLUTION', None) - - #print(f"Using config values - out_features: {out_features}, freeze_at: {freeze_at}") - else: - # Default values if no config is provided - out_features = ['res4'] - freeze_at = 0 - - vit = "visual.proj" in state_dict - - if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) - image_resolution = vision_patch_size * grid_size - else: - counts = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] - vision_layers = tuple(counts) - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) - vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] - if image_resolution is None: - image_resolution = output_width * 32 - else: - if image_resolution / 32 != output_width: - # The positional embedding is not compatible with the image resolution - # Remove it from state_dict and let the model create a new one - print(f"Warning: Removing incompatible positional embedding from checkpoint.") - print(f" Checkpoint spatial size: {output_width}x{output_width} (for image resolution {output_width * 32})") - print(f" Config image resolution: {image_resolution} (requires {image_resolution // 32}x{image_resolution // 32})") - if "visual.attnpool.positional_embedding" in state_dict: - del state_dict["visual.attnpool.positional_embedding"] - # Update output_width to match the config - output_width = image_resolution // 32 - - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) - - # Create CLIP model with the required parameters - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, - out_features=out_features, # Use configuration parameter - freeze_at=freeze_at, # Use configuration parameter - depth=depth # Use configuration parameter - ) - - # Clean up state_dict - for key in ["input_resolution", "context_length", "vocab_size"]: - if key in state_dict: - del state_dict[key] - - convert_weights(model) - - # Load state dict with flexibility for missing or incompatible keys - incompatible_keys = model.load_state_dict(state_dict, strict=False) - - if incompatible_keys.missing_keys: - print(f"Note: Missing keys in checkpoint (will use model defaults): {incompatible_keys.missing_keys}") - if incompatible_keys.unexpected_keys: - print(f"Note: Unexpected keys in checkpoint (ignored): {incompatible_keys.unexpected_keys}") - - return model.eval() \ No newline at end of file diff --git a/src/talk2dino/talk2dino.py b/src/talk2dino/talk2dino.py deleted file mode 100644 index 0a7052b980299f7f2c135215ee904b13aa2fc4af..0000000000000000000000000000000000000000 --- a/src/talk2dino/talk2dino.py +++ /dev/null @@ -1,97 +0,0 @@ -import torch -import torch.nn as nn -import yaml - - -class ProjectionLayer(nn.Module): - """ - Creates a projection layer on top of the CLIP-text encoder. - The forward method calculate the similarity between the DINO CLS token and the projected CLIP textual CLS token. - """ - def __init__(self, act=nn.Tanh(), hidden_layer=False, cosine=True, dino_embed_dim=1024, clip_embed_dim=512, num_attn_head=16, weight_attn_heads=None, - alignment_strategy='max_score', alpha=0.6, keep_cls=False, keep_end_seq=False): - # mlp_dims list of mlp dimensions - super().__init__() - self.num_attn_head = num_attn_head - - self.linear_layer = nn.Linear(clip_embed_dim, dino_embed_dim) - if hidden_layer: - hidden_layer = 1 if hidden_layer is True else hidden_layer # ensuring compatibility with old code - # self.linear_layer2 = nn.Linear(dino_embed_dim, dino_embed_dim) - self.hidden_layers = nn.ModuleList([nn.Linear(dino_embed_dim, dino_embed_dim) for _ in range(hidden_layer)]) - self.act = act - self.cosine = cosine - - self.weight_attn_heads = weight_attn_heads - if weight_attn_heads == 'static': - self.attn_weights = nn.Parameter(torch.rand(self.num_attn_head)) - elif weight_attn_heads == 'conditioned': - self.weight_layer1 = nn.Linear(dino_embed_dim, dino_embed_dim) - self.weight_layer2 = nn.Linear(dino_embed_dim, self.num_attn_head) - - self.alignment_strategy = alignment_strategy # relevant only if we use disentangled_self_attn - self.keep_cls = keep_cls # relevant only if we use clip_txt_tokens_out - self.keep_end_seq = keep_end_seq # relevant only if we use clip_txt_tokens_out - self.alpha = alpha - - @classmethod - def from_config(cls, config): - if type(config) is str: - # if the configuration is a string, we treat it as a file path - with open(config, 'r') as f: - config = yaml.safe_load(f)['model'] - - # loading the activation function - act = config.get('act', None) - if act == 'tanh': - act = nn.Tanh() - elif act == 'relu': - act = nn.ReLU() - elif act == 'sigmoid': - act = nn.Sigmoid() - elif act is not None: - raise Exception("Unknown activation function") - - model = cls( - act=act, - hidden_layer=config.get('hidden_layer', False), - cosine=config.get('cosine', True), - dino_embed_dim=config.get('dino_embed_dim', 1024), - num_attn_head=config.get('num_attn_head', 16), - clip_embed_dim=config.get('clip_embed_dim', 512), - weight_attn_heads=config.get('weight_attn_heads', None), - alignment_strategy=config.get('alignment_strategy', 'max_score'), - alpha=config.get('alpha', 0.6), - keep_cls=config.get('keep_cls', None), - keep_end_seq=config.get('keep_end_seq', None), - ) - if config.get('starting_checkpoint', None) is not None: - model.load_state_dict(torch.load(config['starting_checkpoint'], 'cpu')) - - return model - - def project_clip_txt(self, textual_embedding): - textual_embedding = textual_embedding.float() - x = self.linear_layer(textual_embedding) - - if hasattr(self, 'hidden_layers'): - for hidden_layer in self.hidden_layers: - if self.act: - x = self.act(x) - x = hidden_layer(x) - - return x - def load_state_dict(self, state_dict, strict=True): - # compatibility with old code - if 'linear_layer2.weight' in state_dict: - state_dict['hidden_layers.0.weight'] = state_dict.pop('linear_layer2.weight') - state_dict['hidden_layers.0.bias'] = state_dict.pop('linear_layer2.bias') - # Call the parent class's load_state_dict with the modified state_dict - super(ProjectionLayer, self).load_state_dict(state_dict, strict) - - def set_alignment_strategy(self, alignment_strategy): - self.alignment_strategy = alignment_strategy - return - - def __len__(self): - return sum(p.numel() for p in self.parameters()) \ No newline at end of file diff --git a/src/viecap/ClipCap.py b/src/viecap/ClipCap.py deleted file mode 100644 index 32b152021849be777a037defa32dbce9201a84b2..0000000000000000000000000000000000000000 --- a/src/viecap/ClipCap.py +++ /dev/null @@ -1,251 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as nnf -from typing import Tuple, Optional, List -from transformers import GPT2LMHeadModel - -class MlpTransformer(nn.Module): - - def __init__( - self, - input_size: int, # the input size of mlp - hidden_size: int, # the hidden layer size of mlp - output_size: Optional[int] = None, # the output size of mlp - act = nnf.relu, - dropout: float = 0.0 - ) -> None: - super().__init__() - output_size = output_size if output_size is not None else input_size - self.fc1 = nn.Linear(input_size, hidden_size) - self.act = act - self.fc2 = nn.Linear(hidden_size, output_size) - self.dropout = nn.Dropout(dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.fc1(x) - x = self.act(x) - x = self.dropout(x) - x = self.fc2(x) - x = self.dropout(x) - return x - -class MultiHeadAttention(nn.Module): - - def __init__( - self, - query_size: int, - key_value_size: int, - num_heads: int, - bias = True, - dropout: float = 0.0 - ) -> None: - super(MultiHeadAttention, self).__init__() - self.num_heads = num_heads - self.head_size = query_size // num_heads # the size of each head - self.scale = self.head_size ** -0.5 # normalization factor for each head - self.to_queries = nn.Linear(query_size, query_size, bias = bias) - # projecting key and value together and spliting them for computing efficiently - self.to_keys_values = nn.Linear(key_value_size, 2 * query_size, bias = bias) - self.project = nn.Linear(query_size, query_size) - self.dropout = nn.Dropout(dropout) - - def forward(self, query: torch.Tensor, key_value: torch.Tensor = None, mask: torch.Tensor = None) -> torch.Tensor: - key_value = key_value if key_value is not None else query - b, n, d_query = query.shape - _, m, _ = key_value.shape - queries = self.to_queries(query).reshape(b, n, self.num_heads, self.head_size) # (batch_size, n_seq, num_heads, head_size) - keys_values = self.to_keys_values(key_value).reshape(b, m, 2, self.num_heads, self.head_size) # (batch_size, m_seq, 2, num_heads, head_size) - keys, values = keys_values[:, :, 0], keys_values[:, :, 1] # (batch_size, m_seq, num_heads, head_size), (batch_size, m_seq, num_heads, head_size) - attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale # (batch_size, n_seq, m_seq, num_heads) - - if mask is not None: - if mask.dim() == 2: - mask = mask.unsqueeze(dim = 1) # expending dimension, shape: (batch_size, 1, m_seq) - attention = attention.masked_fill(mask.unsqueeze(dim = 3), float("-inf")) # expending dimension n_seq head and fill -inf according to mask - - attention = attention.softmax(dim = 2) # softmax alongside the dimension of key_value pairs - outputs = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, d_query) # (batch_size, n_seq, d_query) - outputs = self.project(outputs) - return outputs, attention - -class TransformerLayer(nn.Module): - - def __init__( - self, - query_size: int, - key_value_size: int, - num_heads: int, - mlp_ratio = 4.0, - bias = False, - dropout: float = 0.0, - act = nnf.relu, - norm_layer: nn.Module = nn.LayerNorm - ) -> None: - super(TransformerLayer, self).__init__() - self.norm1 = norm_layer(query_size) - self.attn = MultiHeadAttention(query_size, key_value_size, num_heads, bias = bias, dropout = dropout) - self.norm2 = norm_layer(query_size) - self.mlp = MlpTransformer(query_size, int(query_size * mlp_ratio), act = act, dropout = dropout) - - def forward(self, query: torch.Tensor, key_value: torch.Tensor = None, mask: torch.Tensor = None) -> torch.Tensor: - query_, self.attention = self.attn(self.norm1(query), key_value, mask) - query = query + query_ - query = query + self.mlp(self.norm2(query)) - return query - -class Transformer(nn.Module): - - def __init__( - self, - query_size: int, # query size - num_layers: int, # number of layer - num_heads: int, # number of head - key_value_size: Optional[int] = None, # key/value size - mlp_ratio: float = 2.0, # ratio for hidden size in mlp - act = nnf.relu, # activation - norm_layer: nn.Module = nn.LayerNorm # normalization - ) -> None: - super(Transformer, self).__init__() - key_value_size = key_value_size if key_value_size is not None else query_size - layers = [] - for _ in range(num_layers): - layers.append(TransformerLayer(query_size, key_value_size, num_heads, mlp_ratio = mlp_ratio, act = act, norm_layer = norm_layer)) - self.layers = nn.Sequential(*layers) - - def forward(self, query: torch.Tensor, key_value: torch.Tensor = None, mask: torch.Tensor = None) -> torch.Tensor: - self.attentions = [] - for layer in self.layers: - query = layer(query, key_value, mask) - self.attentions.append(layer.attention) - return query - -class MappingNetwork(nn.Module): - - def __init__( - self, - clip_project_length: int, - clip_hidden_size: int, - prefix_length: int, - d_model: int, # the hidden size of language model - num_layers: int = 8, - num_heads: int = 8 - ) -> None: - super(MappingNetwork, self).__init__() - self.clip_project_length = clip_project_length - # projector for input - self.linear = nn.Linear(clip_hidden_size, clip_project_length * d_model) - # learnable prefix embeddings - self.prefix_const = nn.Parameter(torch.randn(prefix_length, d_model), requires_grad = True) - self.transformer = Transformer(d_model, num_layers, num_heads) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: clip cls feature with a shape of (batch_size, clip_hidden_size) - Return: - the embeddings of prefix with the shape of (batch_size, prefix_length, d_model) - """ - x = self.linear(x).view(x.shape[0], self.clip_project_length, -1) # (b, clip_project_length, d_model) - prefix = self.prefix_const.unsqueeze(dim = 0).expand(x.shape[0], *self.prefix_const.shape) # (b, prefix_length, d_model) - inputs = torch.cat((x, prefix), dim = 1) # (b, clip_project_length + prefix_length, d_model) - outputs = self.transformer(inputs)[:,self.clip_project_length:,:] # (b, prefix_length, d_model) - - return outputs - -def get_language_mode(lm_type): - if 'gpt' in lm_type: - model = GPT2LMHeadModel.from_pretrained(lm_type) - hidden_size = model.config.hidden_size - elif 'opt' in lm_type: - from modeling_opt import OPTForCausalLM - model = OPTForCausalLM.from_pretrained(lm_type, torch_dtype = torch.float16) - hidden_size = model.config.word_embed_proj_dim - return model, hidden_size - -class ClipCaptionModel(nn.Module): - - def __init__( - self, - continuous_length: int = 10, - clip_project_length: int = 10, - clip_hidden_size: int = 512, - num_layers: int = 8, - num_heads: int = 8, - gpt_type: str = 'gpt2', - soft_prompt_first: bool = False, - only_hard_prompt: bool = False - ) -> None: - """ - Args: - continuous_length: the length of soft prompts which will be fed into language model as continuous part - clip_project_length: clip cls features (b, 1, d) -> (b, n, d) - clip_hidden_size: the dimensions of CLIP features - num_layers: the number of layer in projector - num_heads: the number of heads each layer - gpt_type: the language model - soft_prompt_first: False -> hard prompt + soft prompt; True -> soft prompt + hard prompt - only_hard_prompt: using the hard prompts only - """ - super(ClipCaptionModel, self).__init__() - self.soft_prompt_first = soft_prompt_first - self.only_hard_prompt = only_hard_prompt - self.continuous_length = continuous_length - self.gpt, self.gpt_hidden_size = get_language_mode(gpt_type) - self.mapping_network = MappingNetwork(clip_project_length, clip_hidden_size, continuous_length, self.gpt_hidden_size, num_layers, num_heads) - self.gpt_type = gpt_type - - def word_embed(self, caption_tokens): - if 'gpt' in self.gpt_type: - caption_embeddings = self.gpt.transformer.wte(caption_tokens) # (b, caption_length, gpt_hidden_size) - elif 'opt' in self.gpt_type: - caption_embeddings = self.gpt.model.decoder.embed_tokens(caption_tokens) - return caption_embeddings - - def forward( - self, - continuous_prompt: torch.Tensor, - caption_tokens: torch.Tensor, - hard_prompts_length: Optional[List] = None, - mask: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, ...]: - """ - Args: - continuous_prompt: tensor with a shape of (b, clip_hidden_size), in text-only training, the caption features are eaxtracted from CLIP and used as image features - caption_tokens: caption tokens with a shape of (b, max_length_per_caption) - hard_prompts_length: list with len = batch size, the length of hard prompts constructed for each caption - mask: tensor with a shape of (b, discrete_length + continuous_length + max_length_per_caption), valid texts for attention computing - Return: - the output of language model - """ - caption_embeddings = self.word_embed(caption_tokens) - continuous_embeddings = self.mapping_network(continuous_prompt).view(-1, self.continuous_length, self.gpt_hidden_size) # (b, continuous_length, gpt_hidden_size) - if hard_prompts_length is not None: # with hard prompts - if self.only_hard_prompt: - embeddings = caption_embeddings - elif self.soft_prompt_first: # soft prompts + hard prompts - embeddings = torch.cat((continuous_embeddings, caption_embeddings), dim = 1) - else: # hard prompts + soft prompts - embeddings = None - for i in range(len(hard_prompts_length)): - length = hard_prompts_length[i] - temp_embeddings = torch.cat((caption_embeddings[i][:length], continuous_embeddings[i], caption_embeddings[i][length:]), dim = 0).unsqueeze(dim = 0) - if embeddings is None: - embeddings = temp_embeddings - else: - embeddings = torch.cat((embeddings, temp_embeddings), dim = 0) - else: # without hard prompts - embeddings = torch.cat((continuous_embeddings, caption_embeddings), dim = 1) # (b, continuous_length + caption_length, gpt_hidden_size) - - out = self.gpt(inputs_embeds = embeddings.type(self.gpt.dtype), attention_mask = mask) - - return out - -class ClipCaptionPrefix(ClipCaptionModel): - - def parameters(self, recurse: bool = True): - return self.mapping_network.parameters() - - def train(self, mode: bool = True): - super(ClipCaptionPrefix, self).train(mode) - self.gpt.eval() - return self \ No newline at end of file diff --git a/src/viecap/README.md b/src/viecap/README.md deleted file mode 100644 index ad1123f78ace63a4d41a006edb0304ee786ba82a..0000000000000000000000000000000000000000 --- a/src/viecap/README.md +++ /dev/null @@ -1,15 +0,0 @@ -The code is from [VieCap Repo](https://github.com/FeiElysia/ViECap/blob/main/infer_by_instance.py). - - - -## Execution Example - - -python infer_by_instance.py --image_path /raid/datasets/coco/train2017/000000000064.jpg - -the generated caption: in front of a building with a clock on the front of the building. -![http://images.cocodataset.org/train2017/000000000064.jpg](http://images.cocodataset.org/train2017/000000000064.jpg) - - -the generated caption: Blue and Yellow trains are on the tracks. -![http://images.cocodataset.org/train2017/000000000071.jpg](http://images.cocodataset.org/train2017/000000000071.jpg) \ No newline at end of file diff --git a/src/viecap/entrypoint.py b/src/viecap/entrypoint.py deleted file mode 100644 index b8e5aa947769161301f949e02e93242678219eba..0000000000000000000000000000000000000000 --- a/src/viecap/entrypoint.py +++ /dev/null @@ -1,203 +0,0 @@ -import torch -from torch.nn.utils.rnn import pad_sequence -from .ClipCap import ClipCaptionModel -from transformers import AutoTokenizer -from .utils import compose_discrete_prompts -from .load_annotations import load_entities_text -from .search import greedy_search, beam_search, opt_search -from .retrieval_categories import clip_texts_embeddings, image_text_simiarlity, top_k_categories -import os -from typing import List -from argparse import Namespace - - -class VieCap(torch.nn.Module): - - def __init__(self, args, device, clip_name): - super(VieCap, self).__init__() - args_dict = args.copy() - self.args = args = self.load_config(args) - self.device = device - - if args_dict.get('clip_hidden_size', None) is not None: - print(f"Using provided clip_hidden_size: {args_dict['clip_hidden_size']}") - self.clip_hidden_size = args_dict['clip_hidden_size'] - else: - print(f"Using default clip_hidden_size: {640 if 'RN' in clip_name else 512}") - self.clip_hidden_size = 640 if 'RN' in clip_name else 512 - - if args_dict.get('suffix', None) is not None: - suffix = args_dict['suffix'] - print(f"Using provided suffix: {suffix}") - else: - suffix = clip_name - print("No suffix provided, using empty string.") - - self.entities_text, self.texts_embeddings = self.get_viecap_texts_embeddings(args, suffix) - - self.tokenizer = AutoTokenizer.from_pretrained(args.language_model) - self.model = ClipCaptionModel(args.continuous_prompt_length, args.clip_project_length, self.clip_hidden_size, gpt_type = args.language_model) - self.model.load_state_dict(torch.load(args.weight_path, map_location = device), strict = False) - self.model.to(device) - - self.eval() - - defaults = { - #"clip_model": "ViT-B/32", - "language_model": "gpt2", - "continuous_prompt_length": 10, - "clip_project_length": 10, - "temperature": 0.01, - "top_k": 3, - "threshold": 0.2, - "disable_all_entities": False, - "name_of_entities_text": 'vinvl_vgoi_entities', - 'prompt_ensemble' : False, - "weight_path" : '/raid/datasets/viecap_files/checkpoints/train_coco/coco_prefix-0014.pt', - 'files_path' : '/raid/datasets/viecap_files/', - "using_hard_prompt": False, - "soft_prompt_first": False, - "only_hard_prompt": False, - "using_greedy_search": False, - "beam_width": 5, - "text_prompt": None, - } - - def load_config(self, args_dict : dict) -> Namespace: - - def dict_to_namespace(d): - if isinstance(d, dict): - return Namespace(**{k: dict_to_namespace(v) for k, v in d.items()}) - return d - # namespace should be loaded recursively - for key, value in self.defaults.items(): - if isinstance(value, dict): - for sub_key, sub_value in value.items(): - args_dict.setdefault(key, {}).setdefault(sub_key, sub_value) - else: - args_dict.setdefault(key, value) - args = dict_to_namespace(args_dict) - return args - - def forward(self, image_features, compute_scores : bool = False) -> List[str]: - """ - Image Features: (batch_size, clip_hidden_size) - - returns: List[str] - """ - #args = self.args - #model = self.model - pad_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0 - - - image_features /= image_features.norm(2, dim = -1, keepdim = True) - - continuous_embeddings = self.model.mapping_network(image_features).view(-1, self.args.continuous_prompt_length, self.model.gpt_hidden_size) - - if self.args.using_hard_prompt: - - #logits = image_text_simiarlity(self.texts_embeddings, temperature = self.args.temperature, images_features = image_features) - #detected_objects, _ = top_k_categories(self.entities_text, logits, self.args.top_k, self.args.threshold) # List[List[]], [[category1, category2, ...], [], ...] - #detected_objects = detected_objects[0] # infering single image -> List[category1, category2, ...] - #discrete_tokens = compose_discrete_prompts(self.tokenizer, detected_objects).unsqueeze(dim = 0).to(self.args.device) - logits = image_text_simiarlity(self.texts_embeddings, temperature=self.args.temperature, images_features=image_features) - all_discrete_tokens = [] - for i in range(image_features.shape[0]): - detected_objects, _ = top_k_categories(self.entities_text, logits[i:i+1], self.args.top_k, self.args.threshold) - discrete_tokens = compose_discrete_prompts(self.tokenizer, detected_objects[0]) - all_discrete_tokens.append(discrete_tokens) - - all_discrete_tokens = [t.to(self.device) for t in all_discrete_tokens] - discrete_tokens = pad_sequence(all_discrete_tokens, batch_first=True, padding_value=pad_id) - #discrete_tokens = torch.stack(all_discrete_tokens).to(self.device) - - discrete_embeddings = self.model.word_embed(discrete_tokens) - if self.args.only_hard_prompt: - embeddings = discrete_embeddings - elif self.args.soft_prompt_first: - embeddings = torch.cat((continuous_embeddings, discrete_embeddings), dim = 1) - else: - embeddings = torch.cat((discrete_embeddings, continuous_embeddings), dim = 1) - else: - embeddings = continuous_embeddings - - if 'gpt' in self.args.language_model: - if not self.args.using_greedy_search: - - #sentences = beam_search(embeddings = embeddings, tokenizer = self.tokenizer, beam_width = self.args.beam_width, model = self.model.gpt) # List[str] - # make one beam_search call for each element in the batch - sentences = [] - for i in range(embeddings.shape[0]): - sentence = beam_search(embeddings = embeddings[i:i+1], tokenizer = self.tokenizer, beam_width = self.args.beam_width, model = self.model.gpt) - sentences.append(sentence[0]) - else: - sentences = greedy_search(embeddings = embeddings, tokenizer = self.tokenizer, model = self.model.gpt) - else: - sentences = opt_search(prompts=self.args.text_prompt, embeddings = embeddings, tokenizer = self.tokenizer, beam_width = self.args.beam_width, model = self.model.gpt) - - if compute_scores: - perplexities = self.compute_perplexity( - sentences, - tokenizer=self.tokenizer, - model=self.model.gpt, - device=self.device, - ) - return sentences, perplexities - else: - return sentences - - def compute_perplexity(self, sentences, tokenizer, model, device): - perplexities = [] - model.eval() - with torch.no_grad(): - for sentence in sentences: - encodings = tokenizer(sentence, return_tensors="pt").to(device) - input_ids = encodings.input_ids - attention_mask = encodings.attention_mask - - outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids) - loss = outputs.loss - perplexity = torch.exp(loss).item() - perplexities.append(perplexity) - return perplexities - - - def get_viecap_texts_embeddings(self, args, suffix : str): - suffix = suffix.replace('/', '') - - - # loading categories vocabulary for objects - if args.name_of_entities_text == 'visual_genome_entities': - entities_text = load_entities_text(args.name_of_entities_text, os.path.join(args.files_path, 'annotations/vocabulary/all_objects_attributes_relationships.pickle'), not args.disable_all_entities) - if args.prompt_ensemble: # loading ensemble embeddings - texts_embeddings = clip_texts_embeddings(entities_text, os.path.join(args.files_path, f'annotations/vocabulary/visual_genome_embedding_{suffix}_with_ensemble.pickle')) - else: - texts_embeddings = clip_texts_embeddings(entities_text, os.path.join(args.files_path, f'annotations/vocabulary/visual_genome_embedding_{suffix}.pickle')) - elif args.name_of_entities_text == 'coco_entities': - entities_text = load_entities_text(args.name_of_entities_text, os.path.join(args.files_path, 'annotations/vocabulary/coco_categories.json'), not args.disable_all_entities) - if args.prompt_ensemble: - texts_embeddings = clip_texts_embeddings(entities_text, os.path.join(args.files_path, f'annotations/vocabulary/coco_embeddings_{suffix}_with_ensemble.pickle')) - else: - texts_embeddings = clip_texts_embeddings(entities_text, os.path.join(args.files_path, f'annotations/vocabulary/coco_embeddings_{suffix}.pickle')) - elif args.name_of_entities_text == 'open_image_entities': - entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/oidv7-class-descriptions-boxable.csv', not args.disable_all_entities) - if args.prompt_ensemble: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/open_image_embeddings_{suffix}_with_ensemble.pickle') - else: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/open_image_embeddings_{suffix}.pickle') - elif args.name_of_entities_text == 'vinvl_vg_entities': - entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/VG-SGG-dicts-vgoi6-clipped.json', not args.disable_all_entities) - if args.prompt_ensemble: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vg_embeddings_{suffix}_with_ensemble.pickle') - else: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vg_embeddings_{suffix}.pickle') - elif args.name_of_entities_text == 'vinvl_vgoi_entities': - entities_text = load_entities_text(args.name_of_entities_text, os.path.join(args.files_path, 'annotations/vocabulary/vgcocooiobjects_v1_class2ind.json'), not args.disable_all_entities) - if args.prompt_ensemble: - texts_embeddings = clip_texts_embeddings(entities_text, os.path.join(args.files_path, f'annotations/vocabulary/vgoi_embeddings_{suffix}_with_ensemble.pickle')) - else: - texts_embeddings = clip_texts_embeddings(entities_text, os.path.join(args.files_path, f'annotations/vocabulary/vgoi_embeddings_{suffix}.pickle')) - else: - print('The entities text should be input correctly!') - return None - return entities_text, texts_embeddings - diff --git a/src/viecap/infer_by_instance.py b/src/viecap/infer_by_instance.py deleted file mode 100644 index 7fbf3bcf0429e3f39626bdf18b3a02190e0545b3..0000000000000000000000000000000000000000 --- a/src/viecap/infer_by_instance.py +++ /dev/null @@ -1,121 +0,0 @@ -import clip -import torch -import argparse -from PIL import Image -from ClipCap import ClipCaptionModel -from transformers import AutoTokenizer -from utils import compose_discrete_prompts -from load_annotations import load_entities_text -from search import greedy_search, beam_search, opt_search -from retrieval_categories import clip_texts_embeddings, image_text_simiarlity, top_k_categories -import os - - -@torch.no_grad() -def main(args) -> None: - # initializing - device = args.device - clip_name = args.clip_model.replace('/', '') - clip_hidden_size = 640 if 'RN' in args.clip_model else 512 - - # loading categories vocabulary for objects - if args.name_of_entities_text == 'visual_genome_entities': - entities_text = load_entities_text(args.name_of_entities_text, os.path.join(args.files_path, 'annotations/vocabulary/all_objects_attributes_relationships.pickle'), not args.disable_all_entities) - if args.prompt_ensemble: # loading ensemble embeddings - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/visual_genome_embedding_{clip_name}_with_ensemble.pickle') - else: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/visual_genome_embedding_{clip_name}.pickle') - elif args.name_of_entities_text == 'coco_entities': - entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/coco_categories.json', not args.disable_all_entities) - if args.prompt_ensemble: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/coco_embeddings_{clip_name}_with_ensemble.pickle') - else: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/coco_embeddings_{clip_name}.pickle') - elif args.name_of_entities_text == 'open_image_entities': - entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/oidv7-class-descriptions-boxable.csv', not args.disable_all_entities) - if args.prompt_ensemble: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/open_image_embeddings_{clip_name}_with_ensemble.pickle') - else: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/open_image_embeddings_{clip_name}.pickle') - elif args.name_of_entities_text == 'vinvl_vg_entities': - entities_text = load_entities_text(args.name_of_entities_text, './annotations/vocabulary/VG-SGG-dicts-vgoi6-clipped.json', not args.disable_all_entities) - if args.prompt_ensemble: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vg_embeddings_{clip_name}_with_ensemble.pickle') - else: - texts_embeddings = clip_texts_embeddings(entities_text, f'./annotations/vocabulary/vg_embeddings_{clip_name}.pickle') - elif args.name_of_entities_text == 'vinvl_vgoi_entities': - entities_text = load_entities_text(args.name_of_entities_text, os.path.join(args.files_path, 'annotations/vocabulary/vgcocooiobjects_v1_class2ind.json'), not args.disable_all_entities) - if args.prompt_ensemble: - texts_embeddings = clip_texts_embeddings(entities_text, os.path.join(args.files_path, f'annotations/vocabulary/vgoi_embeddings_{clip_name}_with_ensemble.pickle')) - else: - texts_embeddings = clip_texts_embeddings(entities_text, os.path.join(args.files_path, f'annotations/vocabulary/vgoi_embeddings_{clip_name}.pickle')) - else: - print('The entities text should be input correctly!') - return - - # loading model - tokenizer = AutoTokenizer.from_pretrained(args.language_model) - model = ClipCaptionModel(args.continuous_prompt_length, args.clip_project_length, clip_hidden_size, gpt_type = args.language_model) - model.load_state_dict(torch.load(args.weight_path, map_location = device), strict = False) - model.to(device) - encoder, preprocess = clip.load(args.clip_model, device = device) - - image = preprocess(Image.open(args.image_path)).unsqueeze(dim = 0).to(device) - image_features = encoder.encode_image(image).float() - image_features /= image_features.norm(2, dim = -1, keepdim = True) - continuous_embeddings = model.mapping_network(image_features).view(-1, args.continuous_prompt_length, model.gpt_hidden_size) - if args.using_hard_prompt: - logits = image_text_simiarlity(texts_embeddings, temperature = args.temperature, images_features = image_features) - detected_objects, _ = top_k_categories(entities_text, logits, args.top_k, args.threshold) # List[List[]], [[category1, category2, ...], [], ...] - detected_objects = detected_objects[0] # infering single image -> List[category1, category2, ...] - discrete_tokens = compose_discrete_prompts(tokenizer, detected_objects).unsqueeze(dim = 0).to(args.device) - - discrete_embeddings = model.word_embed(discrete_tokens) - if args.only_hard_prompt: - embeddings = discrete_embeddings - elif args.soft_prompt_first: - embeddings = torch.cat((continuous_embeddings, discrete_embeddings), dim = 1) - else: - embeddings = torch.cat((discrete_embeddings, continuous_embeddings), dim = 1) - else: - embeddings = continuous_embeddings - - if 'gpt' in args.language_model: - if not args.using_greedy_search: - sentence = beam_search(embeddings = embeddings, tokenizer = tokenizer, beam_width = args.beam_width, model = model.gpt) # List[str] - sentence = sentence[0] # selected top 1 - else: - sentence = greedy_search(embeddings = embeddings, tokenizer = tokenizer, model = model.gpt) - else: - sentence = opt_search(prompts=args.text_prompt, embeddings = embeddings, tokenizer = tokenizer, beam_width = args.beam_width, model = model.gpt) - sentence=sentence[0] - - print(f'the generated caption: {sentence}') - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument('--device', default = 'cuda:0') - parser.add_argument('--clip_model', default = 'ViT-B/32') - parser.add_argument('--language_model', default = 'gpt2') - parser.add_argument('--continuous_prompt_length', type = int, default = 10) - parser.add_argument('--clip_project_length', type = int, default = 10) - parser.add_argument('--temperature', type = float, default = 0.01) - parser.add_argument('--top_k', type = int, default = 3) - parser.add_argument('--threshold', type = float, default = 0.2) - parser.add_argument('--disable_all_entities', action = 'store_true', default = False, help = 'whether to use entities with a single word only') - parser.add_argument('--name_of_entities_text', default = 'vinvl_vgoi_entities', choices = ('visual_genome_entities', 'coco_entities', 'open_image_entities', 'vinvl_vg_entities', 'vinvl_vgoi_entities')) - parser.add_argument('--prompt_ensemble', action = 'store_true', default = False) - parser.add_argument('--weight_path', default = '/raid/datasets/viecap_files/checkpoints/train_coco/coco_prefix-0014.pt') - parser.add_argument('--files_path', default = '/raid/datasets/viecap_files/') - parser.add_argument('--image_path', default = './images/') - parser.add_argument('--using_hard_prompt', action = 'store_true', default = False) - parser.add_argument('--soft_prompt_first', action = 'store_true', default = False) - parser.add_argument('--only_hard_prompt', action = 'store_true', default = False) - parser.add_argument('--using_greedy_search', action = 'store_true', default = False, help = 'greedy search or beam search') - parser.add_argument('--beam_width', type = int, default = 5, help = 'width of beam') - parser.add_argument('--text_prompt', type = str, default = None) - args = parser.parse_args() - print('args: {}\n'.format(vars(args))) - - main(args) \ No newline at end of file diff --git a/src/viecap/load_annotations.py b/src/viecap/load_annotations.py deleted file mode 100644 index 7eb52967684607e4a12c5a66c6cd3c64b8a373a8..0000000000000000000000000000000000000000 --- a/src/viecap/load_annotations.py +++ /dev/null @@ -1,213 +0,0 @@ -import json -import pickle -import pandas as pd -from typing import List - -def load_coco_captions(path: str) -> List[str]: - - with open(path, 'r') as infile: - annotations = json.load(infile) # dictionary -> {image_path: List[caption1, caption2, ...]} - punctuations = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', ' ', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'] - - captions = [] - for image_path in annotations: - temp_captions = annotations[image_path] # List: [caption1, caption2, ...], captions for the ith image - for caption in temp_captions: # caption - caption = caption.strip() # removing space at the end of the caption - if caption.isupper(): # processing the special caption in the COCO Caption, e.g., 'A BOY IS PLAYING BASEBALL.' - caption = caption.lower() - caption = caption[0].upper() + caption[1:] # capitalizing the first letter in the caption - if caption[-1] not in punctuations: # adding a '.' at the end of the caption if there are no punctuations. - caption += '.' - captions.append(caption) # final versin: A boy is playing baseball. - - return captions - -def load_flickr30k_captions(path: str) -> List[str]: - - with open(path, 'r') as infile: - annotations = json.load(infile) # dictionary -> {image_path: List[caption1, caption2, ...]} - punctuations = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', ' ', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'] - - captions = [] - for image_path in annotations: - temp_captions = annotations[image_path] - for caption in temp_captions: - caption = caption.strip() - if caption.isupper(): - caption = caption.lower() - caption = caption[0].upper() + caption[1:] - if caption[-1] not in punctuations: - caption += '.' - captions.append(caption) - - return captions - -def load_captions(name_of_datasets: str, path_of_datasets: str) -> List[str]: - """ - Args: - name_of_datasets: specifying the name of datasets - path_of_datasets: specifying the path of datasets - Return: - [caption1, caption2, ...] - """ - if name_of_datasets == 'coco_captions': - return load_coco_captions(path_of_datasets) - - if name_of_datasets == 'flickr30k_captions': - return load_flickr30k_captions(path_of_datasets) - - print('The datasets for training fail to load!') - -def load_stopwords() -> List[str]: - # Return: stopwords and punctuations - - stopwords = {'per', 'โ€™ll', 'could', 'fifteen', 'been', "isn't", 'whoever', 'any', 'whole', 'front', "won't", 'upon', 'there', 's', 'am', 'via', 'the', 'as', "haven't", 'on', 'km', 'further', 'their', 'quite', 'have', 'twenty', 'during', 'full', 'it', 'thin', 'so', 'what', 'an', 't', 'less', 'if', 'sixty', 'everyone', 'us', 'were', 'side', 'she', 'cannot', 'thereby', 'โ€˜ve', 'amount', 'nโ€™t', 'be', 'nine', 'isn', 'wouldn', 'by', 'along', "'ll", 'themselves', 'forty', 'everywhere', "'d", 'thru', 'sometimes', 'hasnt', 'seeming', 'own', 'that', "'ve", 'least', 'with', 'inc', 'really', 'afterwards', 'due', 'for', 'sometime', 'last', 'find', 'therein', 'all', 'thick', 'detail', 'few', 'hundred', 'some', 'even', 'off', 'โ€™m', 'ain', 'โ€™re', 'hence', 'etc', 'into', 'rather', 'where', 'm', 'its', 'onto', 'โ€™s', 'get', 'other', 'moreover', 'noone', 'being', 'must', 'bill', "wasn't", 'system', 'neither', "you'll", 'third', 'whereby', 'nobody', 'among', 'throughout', 'except', 'beforehand', "didn't", 'was', 'without', 'whose', 'hasn', 'โ€˜d', 'or', 'theirs', 'various', 'name', 'twelve', 'myself', 'former', 'though', 'we', 'ours', 'many', 'sincere', 'regarding', 'had', 'before', 'mustn', 'either', 'doing', 'why', 'fill', 'eight', 'won', 'anything', 'hereupon', 'this', 'amoungst', 'โ€˜s', 'of', 'yourselves', 'beside', 'within', 'ourselves', 'โ€˜re', 'about', 'elsewhere', 'latter', 'through', 'll', 'i', 'wasn', 'anywhere', 'weren', 'just', 'itself', "you're", 'wherein', 'four', 'keep', 'whether', 'nothing', 'found', 'back', 'needn', "aren't", 'has', 'one', 'wherever', 'serious', 'everything', 'hadn', 'first', 'anyway', 'co', 'still', 'five', 'becomes', "don't", 'formerly', 'ever', 'part', 'nowhere', 'made', 'himself', "couldn't", 'none', 'others', 'now', 'doesn', 'at', 'another', 'does', 'kg', 'see', 'often', 'them', 'shan', 'fifty', 'ltd', 'namely', 'they', 'somewhere', 'haven', 'take', 'latterly', 'well', 'whatever', 'nor', 'whereafter', 'might', 'only', 'de', 'our', 'hers', "mustn't", 'aren', 'you', 'his', "wouldn't", 'please', 'empty', 'but', 'mightn', 'then', 'should', 'and', 'each', 'such', 'a', 'yet', 'y', 'enough', 'someone', 'would', 'since', 'however', 'make', 'alone', 'anyone', 'amongst', 'these', 'whereupon', 'fire', "hasn't", 'shouldn', 'didn', 'do', 'me', 'becoming', 'after', 'several', 'seem', 'her', 'three', 'out', 'ten', 'whence', 'eg', 'couldn', 'un', 'did', "she's", 'whither', 'toward', 'once', "should've", 'call', "weren't", 'again', 'more', 'show', 'seems', "needn't", 'thereupon', 'used', 'most', 'hereby', 'put', 'ie', 've', 'my', 'your', 'thence', 'already', 'always', 'having', 'much', 'move', 'eleven', "'re", 'here', 'yours', 'con', 'done', 'up', 'over', 'yourself', "it's", 'o', 'six', 'can', 'how', "hadn't", 'anyhow', 'below', 'also', 'say', 'together', 'down', 'using', 'while', 'almost', 'cry', "you've", 'โ€™ve', 'two', 'towards', 'meanwhile', 'perhaps', 'when', 'ma', "shouldn't", 'both', 'hereafter', 'he', 'describe', 'ca', 'which', 'every', 'between', 'give', 'go', 'very', 'โ€™d', 'nevertheless', 'is', 'nโ€˜t', 'therefore', 'โ€˜ll', 'unless', 'next', 'who', 'became', 'mill', 'him', 'don', 'same', "'s", 'seemed', 'mostly', 'will', 're', "you'd", 'no', 'in', 'too', "mightn't", 'besides', 'are', 'because', 'couldnt', 'd', 'against', "doesn't", 'cant', 'whenever', 'somehow', 'thereafter', 'although', 'beyond', 'from', 'whereas', 'thus', 'than', "shan't", 'to', 'top', 'until', 'those', 'whom', 'bottom', 'else', 'herein', 'something', 'โ€˜m', 'may', 'not', "that'll", "'m", 'indeed', 'never', 'herself', 'interest', "n't", 'become', 'mine', 'otherwise'} - punctuations = ['!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', ' ', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'] - other_words = {'photo', 'image', 'picture', 'pic', 'side', 'part', 'background'} - stopwords_and_punctuations = stopwords.union(punctuations) - stopwords_and_punctuations = stopwords_and_punctuations.union(other_words) - stopwords_and_punctuations = [stopword.lower() for stopword in stopwords_and_punctuations] - stopwords_and_punctuations.sort() - - return stopwords_and_punctuations - -def load_visual_genome_entities(path: str, all_entities: bool = True) -> List[str]: - # Visual Genome Vocabulary - - with open(path, 'rb') as infile: - all_objects_attributes_relationships = pickle.load(infile) # dictionary {'relationships': dict, 'attributes': dict, 'objects': dict} - entities = all_objects_attributes_relationships['objects'] # dictionary {'gqa': set, 'vg': set, 'joint': set}, joint = gqa + vg - entities = entities['joint'] # set - - if all_entities: - entities = [entity.lower().strip() for entity in entities] - else: - entities = [entity.lower().strip() for entity in entities if len(entity.split()) == 1] - entities.sort() # sort - - return entities - -def load_coco_entities(path: str, all_entities: bool = True) -> List[str]: - # COCO Vocabulary - - with open(path, 'r') as infile: - entities = json.load(infile) # List [category1, category2, ...] - - if all_entities: - entities = [entity.lower().strip() for entity in entities] - else: - entities = [entity.lower().strip() for entity in entities if len(entity.split()) == 1] - entities.sort() # sort - - return entities - -def load_open_image_entities(path: str, all_entities: bool = True) -> List[str]: - # Open Image Vocabulary - - open_images = pd.read_csv(path) # 601x2, i.e., [LabelName, DisplayName] - open_image_entities = list(open_images.DisplayName) # list - - for i in range(len(open_image_entities)): - entity = open_image_entities[i].lower().strip() - if entity[-1] == ')': - entity = entity[:entity.find('(')].strip() - open_image_entities[i] = entity - - if all_entities: - entities = [entity for entity in open_image_entities] - else: - entities = [entity for entity in open_image_entities if len(entity.split()) == 1] - entities.sort() # sort - - return entities - -def load_vinvl_vg_entities(path: str, all_entities: bool = True) -> List[str]: - # VG Vocabulary - - with open(path, 'r') as infile: - annotations = json.load(infile) # dictionary = {'label_to_idx':dict,'idx_to_label':dict,'attribute_to_idx':dict,'idx_to_attribute':dict,'predicate_to_idx':dict,'idx_to_predicate':dict,'object_count':dict,'attribute_count':dict,'predicate_count':dict,} - vinvl_entities = annotations['object_count'] # dictionary = {str: int, str: int, ...} - - if all_entities: - entities = [entity.lower().strip() for entity in vinvl_entities] - else: - entities = [entity.lower().strip() for entity in vinvl_entities if len(entity.split()) == 1] - entities.sort() # sort - - return entities - -def load_vinvl_vgoi_entities(path: str, all_entities: bool = True) -> List[str]: - - with open(path, 'r') as infile: - vgoi_entities = json.load(infile) # dictionary = {str: int} - - if all_entities: - entities = [entity.lower().strip() for entity in vgoi_entities] - else: - entities = [entity.lower().strip() for entity in vgoi_entities if len(entity.split()) == 1] - entities.sort() # sort - - return entities - -def load_entities_text(name_of_entities: str, path_of_entities: str, all_entities: bool = True) -> List[str]: - """ - Args: - name_of_entities: specifying the name of entities text - path_of_entities: specifying the path of entities text - all_entities: whether to apply all entities text. True denotes using entities including len(entitites.split()) > 1 - Return: - [entity1, entity2, ...] - """ - if name_of_entities == 'visual_genome_entities': - return load_visual_genome_entities(path_of_entities, all_entities) - - if name_of_entities == 'coco_entities': - return load_coco_entities(path_of_entities, all_entities) - - if name_of_entities == 'open_image_entities': - return load_open_image_entities(path_of_entities, all_entities) - - if name_of_entities == 'vinvl_vg_entities': - return load_vinvl_vg_entities(path_of_entities, all_entities) - - if name_of_entities == 'vinvl_vgoi_entities': - return load_vinvl_vgoi_entities(path_of_entities, all_entities) - - print('The entities text fails to load!') - -if __name__ == '__main__': - - # loading captions - datasets = ['coco_captions', 'flickr30k_captions'] - captions_path = [ - './annotations/coco/train_captions.json', - './annotations/flickr30k/train_captions.json', - ] - captions_idx = 1 - captions = load_captions(datasets[captions_idx], captions_path[captions_idx]) - for caption in captions[:20]: - print(caption) - print(len(captions), type(captions)) - - # loading stopwords - stopwords = load_stopwords() - print('stopwords: ', stopwords[:10], type(stopwords), len(stopwords)) - - # loading entities text - entities_text = ['visual_genome_entities', 'coco_entities', 'open_image_entities', 'vinvl_vg_entities', 'vinvl_vgoi_entities'] - entities_path = [ - './annotations/vocabulary/all_objects_attributes_relationships.pickle', - './annotations/vocabulary/coco_categories.json', - './annotations/vocabulary/oidv7-class-descriptions-boxable.csv', - './annotations/vocabulary/VG-SGG-dicts-vgoi6-clipped.json', - './annotations/vocabulary/vgcocooiobjects_v1_class2ind.json' - ] - # using all entities text - entities_idx = 4 - entities = load_entities_text(entities_text[entities_idx], entities_path[entities_idx]) - print('entities text: ', entities[:10], type(entities), len(entities)) - # using entities text with a single word - entities_idx = 4 - entities = load_entities_text(entities_text[entities_idx], entities_path[entities_idx], all_entities = False) - print('entities text: ', entities[:10], type(entities), len(entities)) \ No newline at end of file diff --git a/src/viecap/retrieval_categories.py b/src/viecap/retrieval_categories.py deleted file mode 100644 index 98359eb57e72dd9280db9f3dfd3c84b9105befe2..0000000000000000000000000000000000000000 --- a/src/viecap/retrieval_categories.py +++ /dev/null @@ -1,117 +0,0 @@ -import os -import clip -import torch -import pickle -from PIL import Image -from typing import List, Optional, Tuple - -@torch.no_grad() -def clip_texts_embeddings( - texts: List[str], - outpath = '', - device: Optional[str] = None, - batch_size: Optional[int] = 32, - clip_type: Optional[str] = None -) -> torch.Tensor: - """ - Args: - texts: name of categories, i.e., ['category1', 'category2', ...] - outpath: saving embeddings of category texts to outpath. reading it directly if existing - device: specifying device used - batch_size: the number of categories that would be transformed to embeddings per epoch - clip_type: specifying clip backbone used - Return: - tensor with a shape of (num_categories, clip_hidden_size), float32 - """ - if os.path.exists(outpath): - with open(outpath, 'rb') as infile: - texts_embeddings = pickle.load(infile) # (num_categories, clip_hidden_size) - return texts_embeddings - - # adding prompt for each category text, i.e., Photo of an ariplane. / Photo of a bicycle. - vowel = ['a', 'e', 'i', 'o', 'u', 'A', 'E', 'I', 'O', 'U'] - prompt_texts = [] - for text in texts: - if text[0] in vowel: - prompt_texts.append(f'A photo of an {text}.') - else: - prompt_texts.append(f'A photo of a {text}.') - - clip_texts_tokens = clip.tokenize(prompt_texts) # (num_categories, 77) - model, _ = clip.load(clip_type, device = device) # loading clip encoder - model.eval() - num_categories = len(texts) - texts_embeddings = None - epochs = int(num_categories / batch_size) if num_categories % batch_size == 0 else 1 + int (num_categories // batch_size) - for epoch in range(epochs): - temp_texts_tokens = clip_texts_tokens[batch_size * epoch : batch_size * (epoch + 1)] # (batch_size/(num_categories % batch_size), 77) - temp_texts_tokens = temp_texts_tokens.to(device) - with torch.no_grad(): - temp_texts_embeddings = model.encode_text(temp_texts_tokens).float().to('cpu') # (batch_size/(num_categories % batch_size), clip_hidden_size) - if texts_embeddings is None: - texts_embeddings = temp_texts_embeddings - else: - texts_embeddings = torch.cat((texts_embeddings, temp_texts_embeddings), dim = 0) - - with open(outpath, 'wb') as outfile: - pickle.dump(texts_embeddings, outfile) - - return texts_embeddings - -def image_text_simiarlity( - texts_embeddings: torch.Tensor, - temperature: float = 0.01, - image_path: Optional[str] = None, - images_features: Optional[torch.Tensor] = None, - clip_type: Optional[str] = None, - device: Optional[str] = None -) -> torch.Tensor: - """ - Args: - texts_embeddings: (num_categories, clip_hidden_size), float32, the embeddings of categories - temperature: temperature hyperparameter for computing similarity - image_path: Optional, the path of a single image - images_feature: (num_images, clip_hidden_size), float32, Optional - clip_type: clip type, using when input is image path - device: device using when input is device - Return: - logits with a shape of (num_images, num_categories) - """ - if images_features is None: - encoder, preprocess = clip.load(clip_type, device) - assert image_path is not None, 'Either image path or images feature should be given!' - image = preprocess(Image.open(image_path)).unsqueeze(dim = 0).to(device) # (1, 3, 224, 224) - with torch.no_grad(): - images_features = encoder.encode_image(image) # (1, clip_hidden_size) - - # computing on cpu to avoid out of memory - images_features = images_features.float().to('cpu') # (num_images, clip_hidden_size) - texts_embeddings = texts_embeddings.float().to('cpu') # (num_categories, clip_hidden_size) - images_features /= images_features.norm(dim = -1, keepdim = True) # (num_images, clip_hidden_size) - texts_embeddings /= texts_embeddings.norm(dim = -1, keepdim = True) # (num_categories, clip_hidden_size) - - image_to_text_similarity = torch.matmul(images_features, texts_embeddings.transpose(1, 0)) / temperature # (num_imegs, num_categories) - image_to_text_logits = torch.nn.functional.softmax(image_to_text_similarity, dim = -1) # (num_imegs, num_categories) - - return image_to_text_logits - -def top_k_categories( - texts: List[str], # ['category1', 'category2', ...], len = num_categories - logits: torch.Tensor, # (num_images, num_categories) - top_k: Optional[int] = 5, # choosing top k categories as retrieved category - threshold: Optional[float] = 0.0 # probability which is less than threshold will be filtered -) -> Tuple: - - top_k_probs, top_k_indices = torch.topk(logits, k = top_k, dim = -1) # (num_images, top_k) - top_k_texts = [] - for i in range(len(top_k_probs)): - per_image_top_k_probs = top_k_probs[i] # the ith image top k probability - per_image_top_k_indices = top_k_indices[i] # the ith image top k indices - temp_texts = [] - for j in range(top_k): - if per_image_top_k_probs[j] < threshold: - break - temp_texts.append(texts[per_image_top_k_indices[j]]) - top_k_texts.append(temp_texts) - - return top_k_texts, top_k_probs \ No newline at end of file diff --git a/src/viecap/search.py b/src/viecap/search.py deleted file mode 100644 index 8f2199eadc208aa894b17912c2df089d2943b8ba..0000000000000000000000000000000000000000 --- a/src/viecap/search.py +++ /dev/null @@ -1,682 +0,0 @@ -import clip -import torch -import numpy as np -from PIL import Image -import torch.nn.functional as F -from typing import Optional, Tuple, List -from transformers import GPT2Tokenizer, GPT2LMHeadModel - - -@torch.no_grad() -def opt_search( - prompts: Optional[str] = None, - tokens: Optional[torch.Tensor] = None, - embeddings: Optional[torch.Tensor] = None, - max_len: int = 64, - beam_width: int = 5, - end_of_sentence: str = ".", - tokenizer: GPT2Tokenizer = None, - model: GPT2LMHeadModel = None, -) -> List[str]: - """ - Sentence generation through choosing token guided by model confidence. - Taking text input as prompts, tokens or embeddings, if more than one input a time, priority should follow: embeddings > tokens > prompts. - Args: - prompts: str, prompts for generated sentence - tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64 - embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder) - max_len: int, the maximum length of generated sentence (without considering the length of prompts/tokens/embeddings) - end_of_sentence: str, early stopping once generated word is equal to end_of_sentence - tokenizer: transforming word/sentence to indice/list and vice versa, i.e., str -> List[int64] or List[int64] -> str - model: language model (taking input as either tokens or embeddings) - Return: - list[str] for generated sentence when batch size is greater than 1 (i.e., len(list) = batch_size), and string when batch size is equal to 1 - """ - model.eval() - device = model.device - - # tokenizing end of sentence, when the length of eos tokens is greater than 1, setting the first token of eos tokens as eos token - eos = tokenizer.encode(end_of_sentence)[-1] - - # prefix should transform into word embeddings so that sentence generation is capable of processing input of prompts, tokens or embeddings unifiedly - # priority: embeddings > tokens > prompts - if embeddings is not None: - generating = embeddings # (b, n_seq, lm_hidden_size) - else: - if tokens is None: - tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts - tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension - generating = word_embed(model, tokens) - # generating = model.transformer.wte(tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings - generating = generating.float() # (b, n_seq, lm_hidden_size) - assert generating.dim() == 3, 'The dimension of prompts should equal to 3!' - - b = generating.shape[0] - # past_key_values = None - inputs_opt = generating - - use_nucleus_sampling = False - num_beams=beam_width - max_length=max_len - min_length=1 - top_p=0.9 - repetition_penalty=1.0 - length_penalty=1.0 - num_captions=1 - temperature=1 - - if use_nucleus_sampling: - query_embeds = inputs_opt.repeat_interleave(num_captions, dim=0) - num_beams = 1 - else: - query_embeds = inputs_opt.repeat_interleave(num_beams, dim=0) - - atts_opt = torch.ones(inputs_opt.size()[:-1], dtype=torch.long).to(inputs_opt.device) - - prompt = tokenizer.eos_token + prompts if prompts else tokenizer.eos_token - prompt = [prompt] * b - opt_tokens = tokenizer(prompt, add_special_tokens=False, return_tensors="pt").to(embeddings.device) - input_ids = opt_tokens.input_ids - attention_mask = torch.cat([atts_opt, opt_tokens.attention_mask], dim=1) - - # import pdb - # pdb.set_trace() - - outputs = model.generate( - input_ids=input_ids, - query_embeds=query_embeds.type(model.dtype), - attention_mask=attention_mask, - do_sample=use_nucleus_sampling, - top_p=top_p, - temperature=temperature, - num_beams=num_beams, - max_new_tokens=max_length, - min_length=min_length, - eos_token_id= eos, - repetition_penalty=repetition_penalty, - length_penalty=length_penalty, - num_return_sequences=num_captions, - ) - - output_text = tokenizer.batch_decode(outputs[:, :], skip_special_tokens=True) - output_text = [text.strip() for text in output_text] - print(output_text) - return output_text - - -@torch.no_grad() -def greedy_search( - prompts: Optional[str] = None, - tokens: Optional[torch.Tensor] = None, - embeddings: Optional[torch.Tensor] = None, - max_len: int = 64, - end_of_sentences: List = [".", " ."], - tokenizer: GPT2Tokenizer = None, - model: GPT2LMHeadModel = None -) -> List[str]: - """ - Sentence generation through choosing token guided by model confidence. - Taking text input as prompts, tokens or embeddings, if more than one input a time, priority should follow: embeddings > tokens > prompts. - Args: - prompts: str, prompts for generated sentence - tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64 - embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder) - max_len: int, the maximum length of generated sentence (without considering the length of prompts/tokens/embeddings) - end_of_sentence: str, early stopping once generated word is equal to end_of_sentence - tokenizer: transforming word/sentence to indice/list and vice versa, i.e., str -> List[int64] or List[int64] -> str - model: language model (taking input as either tokens or embeddings) - Return: - list[str] for generated sentence when batch size is greater than 1 (i.e., len(list) = batch_size), and string when batch size is equal to 1 - """ - model.eval() - device = model.device - - # tokenizing end of sentence, when the length of eos tokens is greater than 1, setting the first token of eos tokens as eos token - eos = [tokenizer.encode(end_of_sentence)[-1] for end_of_sentence in end_of_sentences] - - # prefix should transform into word embeddings so that sentence generation is capable of processing input of prompts, tokens or embeddings unifiedly - # priority: embeddings > tokens > prompts - if embeddings is not None: - generating = embeddings # (b, n_seq, lm_hidden_size) - else: - if tokens is None: - tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts - tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension - generating = word_embed(model, tokens) - # generating = model.transformer.wte(tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings - generating = generating.float() # (b, n_seq, lm_hidden_size) - assert generating.dim() == 3, 'The dimension of prompts should equal to 3!' - - b = generating.shape[0] - past_key_values = None - for step in range(max_len): - # generating initial states of language model - if step == 0: - outputs = model(inputs_embeds = generating.type(model.dtype), past_key_values = past_key_values, use_cache = True) - next_token_logits = outputs.logits[:, -1, :] # (b, n_seq, vocal_size) -> (b, vocal_size), logits of the last token - past_key_values = outputs.past_key_values # Tuple[Tuple[(b, h, n_seq, lm_hidden_size/h)]], layers -> (key, value) -> torch.tensor - - next_token = torch.argmax(next_token_logits, dim = -1, keepdim = True) # (b, 1) - next_embedding = word_embed(model, next_token) # (b, 1, lm_hidden_size) - # next_embedding = model.transformer.wte(next_token) # (b, 1, lm_hidden_size) - outputs = model(inputs_embeds = next_embedding.type(model.dtype), past_key_values = past_key_values, use_cache = True) - next_token_logits = outputs.logits[:, -1, :] # (b, 1, vocal_size) -> (b, vocal_size) - past_key_values = outputs.past_key_values # Tuple[Tuple[(b, h, n_seq + 1, lm_hidden_size/h)]] - - # updating tokens - if tokens is None: - tokens = next_token - else: - tokens = torch.cat((tokens, next_token), dim = 1) # (b, n_seq + 1) - - # whether to stop early according to the end of sentence, only working when batch size is equal to 1 - if b == 1 and next_token.item() in eos: - new_tokens = tokens.squeeze(dim = 0).tolist() - sentence = tokenizer.decode(new_tokens) - return sentence - - # tokens: (1/b, n_seq + max_len) where n_seq refers to the length of inputs tokens or prompts - # torch.tensor(1/b, n_seq + max_Len) -> str/list[str] - sentence = [] - if b == 1: - new_tokens = tokens.squeeze(dim = 0).tolist() - sentence = tokenizer.decode(new_tokens) - else: - for temp_tokens in tokens: - for i in range(len(temp_tokens)): - if temp_tokens[i].item() in eos: - break - new_tokens = temp_tokens[:i + 1].tolist() - sentence.append(tokenizer.decode(new_tokens)) - return sentence - -def beam_search( - prompts: Optional[str] = None, - tokens: Optional[torch.Tensor] = None, - embeddings: Optional[torch.Tensor] = None, - temperature = 1.0, - max_len: int = 64, - beam_width: int = 5, - end_of_sentences: List = [".", " ."], - tokenizer: GPT2Tokenizer = None, - model: GPT2LMHeadModel = None -) -> List[str]: - """ - Sentence generation through choosing token guided by model confidence. - Taking text input as prompts, tokens or embeddings, if more than one input a time, priority should follow: embeddings > tokens > prompts. - Args: - prompts: str, prompts for generated sentence - tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64 - embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder) - max_len: int, the maximum length of generated sentence (without considering the length of prompts/tokens/embeddings) - beam_width: the width of beam - end_of_sentence: str, early stopping once generated word is equal to end_of_sentence - tokenizer: transforming word/sentence to indice/list and vice versa, i.e., str -> List[int64] or List[int64] -> str - model: language model (taking input as either tokens or embeddings) - Return: - list[str] for generated sentence when batch size is greater than 1 (i.e., len(list) = batch_size), and string when batch size is equal to 1 - """ - model.eval() - device = model.device - - # tokenizing end of sentence, when the length of eos tokens is greater than 1, setting the first token of eos tokens as eos token - eos = [tokenizer.encode(end_of_sentence)[-1] for end_of_sentence in end_of_sentences] - scores = None - seq_lengths = torch.ones(beam_width, device = device) - is_stopped = torch.zeros(beam_width, device = device, dtype=torch.bool) - # prefix should transform into word embeddings so that sentence generation is capable of processing input of prompts, tokens or embeddings unifiedly - # priority: embeddings > tokens > prompts - if embeddings is not None: - generated = embeddings # (b, n_seq, lm_hidden_size) - else: - if tokens is None: - tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts - tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension - generated = word_embed(model, tokens) - # generated = model.transformer.wte(tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings - generated = generated.float() # (b, n_seq, lm_hidden_size) - assert generated.dim() == 3, 'The dimension of prompts should equal to 3!' - - - for i in range(max_len): - outputs = model(inputs_embeds=generated.type(model.dtype)) - logits = outputs.logits - logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0) - logits = logits.softmax(-1).log() - if scores is None: - scores, next_tokens = logits.topk(beam_width, -1) - generated = generated.expand(beam_width, *generated.shape[1:]) - next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0) - if tokens is None: - tokens = next_tokens - else: - tokens = tokens.expand(beam_width, *tokens.shape[1:]) - tokens = torch.cat((tokens, next_tokens), dim=1) - else: - logits[is_stopped] = -float(np.inf) - logits[is_stopped, 0] = 0 - scores_sum = scores[:, None] + logits - seq_lengths[~is_stopped] += 1 - scores_sum_average = scores_sum / seq_lengths[:, None] - scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_width, -1) - # next_tokens_source = torch.floor(torch.div(next_tokens, scores_sum.shape[1])).long() - next_tokens_source = torch.div(next_tokens, scores_sum.shape[1], rounding_mode = 'trunc') - seq_lengths = seq_lengths[next_tokens_source] - next_tokens = next_tokens % scores_sum.shape[1] - next_tokens = next_tokens.unsqueeze(1) - tokens = tokens[next_tokens_source] - tokens = torch.cat((tokens, next_tokens), dim=1) - generated = generated[next_tokens_source] - scores = scores_sum_average * seq_lengths - is_stopped = is_stopped[next_tokens_source] - next_token_embed = word_embed(model, next_tokens.squeeze()).view(generated.shape[0], 1, -1) - # next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1) - generated = torch.cat((generated, next_token_embed), dim=1) - assert len(eos) == 2 # hack - is_stopped = is_stopped + (next_tokens.eq(eos[0]) | next_tokens.eq(eos[1])).squeeze() - if is_stopped.all(): - break - scores = scores / seq_lengths - output_list = tokens.cpu().numpy() - output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)] - order = scores.argsort(descending=True) - output_texts = [output_texts[i] for i in order] - - return output_texts - -def word_embed(gpt, caption_tokens): - if hasattr(gpt, 'transformer'): - embedding_text = gpt.transformer.wte(caption_tokens) - elif hasattr(gpt, 'model'): - embedding_text = gpt.model.decoder.embed_tokens(caption_tokens) - return embedding_text - -@torch.no_grad() -def contrastive_search( - prompts: Optional[str] = None, - tokens: Optional[torch.Tensor] = None, - embeddings: Optional[torch.Tensor] = None, - alpha: float = 0.1, - top_k: int = 48, - max_len: int = 64, - end_of_sentence: str = '.', - tokenizer: GPT2Tokenizer = None, - model: GPT2LMHeadModel = None -) -> List[str]: - """ - Sentence generation through choosing token guided by model confidence, degeneration penality. - Taking text input as prompts, tokens or embeddings, if more than one input a time, priority should follow: embeddings > tokens > prompts. - Args: - prompts: str, prompts for generated sentence - tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64 - embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder) - alpha: float from 0.0 to 1.0, controlling the strength of degenration penalty (i.e., avoiding repeat) - top_k: int, generating k candidate tokens each time step in next token predicition (i.e., next token will be selected from the top k candidates) - max_len: int, the maximum length of generated sentence (without considering the length of prompts/tokens/embeddings) - end_of_sentence: str, early stopping once generated word is equal to end_of_sentence - tokenizer: transforming word/sentence to indice/list and vice versa, i.e., str -> List[int64] or List[int64] -> str - model: language model (taking input as either tokens or embeddings) - Return: - list[str] for generated sentence when batch size is greater than 1 (i.e., len(list) = batch_size), and string when batch size is equal to 1 - """ - model.eval() - device = model.device - - # tokenizing end of sentence, when the length of eos tokens is greater than 1, setting the first token of eos tokens as eos token - eos = tokenizer.encode(end_of_sentence)[0] - - # prefix should transform into word embeddings so that sentence generation is capable of processing input of prompts, tokens or embeddings unifiedly - # priority: embeddings > tokens > prompts - if embeddings is not None: - generating = embeddings # (b, n_seq, lm_hidden_size) - else: - if tokens is None: - tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts - tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension - generated = word_embed(model, tokens) - # generating = model.transformer.wte(tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings - generating = generating.float() # (b, n_seq, lm_hidden_size) - assert generating.dim() == 3, 'The dimension of prompts should equal to 3!' - - past_key_values = None - for step in range(max_len): - # generating the initial states of model - if step == 0: - outputs = model(inputs_embeds = generating, past_key_values = past_key_values, use_cache = True, output_hidden_states = True) - next_token_logits = outputs.logits[:, -1, :] # (b, n_seq, vocal_size) -> (b, vocal_size), logits of the last token - past_key_values = outputs.past_key_values # Tuple[Tuple[(b, h, n_seq, lm_hidden_size/h)]], layers -> (key, value) -> torch.tensor - past_hidden_states = outputs.hidden_states[-1] # Tuple[(b, n_seq, lm_hidden_size)] -> (b, n_seq, lm_hidden_size) (i.e., hidden state of last layer) - - # selecting top k candidates and their probability from next_tokens_logits - b, n_seq, lm_hidden_size = past_hidden_states.size() - next_token_probs = F.softmax(next_token_logits, dim = -1) # (b, vocal_size) - _, top_k_indices = torch.topk(next_token_logits, dim = -1, k = top_k) # (b, k), the indices for top k candidates (i.e., tokens) - top_k_probs = torch.gather(next_token_probs, dim = 1, index = top_k_indices) # (b, k), the probability for top k candidates - - # transformering b*k tokens to embeddings and processing past_key_values to compute simultaneously for k tokens - top_k_embeddings = model.transformer.wte(top_k_indices.view(-1, 1)) # (b*k, 1, lm_hidden_size) - past_key_values = reshape_from_past_key_values(past_key_values, top_k) # Tuple[Tuple[(b*k, h, n_seq, lm_hidden_size/h)]] - # computing hidden state of next token (b * top_k in total) - outputs = model(inputs_embeds = top_k_embeddings, past_key_values = past_key_values, use_cache = True, output_hidden_states = True) - logits = outputs.logits[:, -1, :] # (b*k, 1, vocal_size) -> (b*k, vocal_size) - past_key_values = outputs.past_key_values # Tuple[Tuple[(b*k, h, n_seq + 1, lm_hidden_size/h)]] - next_hidden_state = outputs.hidden_states[-1] # Tuple[(b*k, 1, lm_hidden_size)] -> (b*k, 1, lm_hidden_size) - context_hidden_states = past_hidden_states.unsqueeze(dim = 1).expand(-1, top_k, -1, -1).reshape(b*top_k, n_seq, lm_hidden_size) # (b*k, n_seq, lm_hidden_size) - - # selecting next token within top k candidates for each sentence - selected_max_prob_indices = ranking_and_selecting(context_hidden_states, next_hidden_state, top_k_probs, alpha, top_k) # (b) - - # updating next_token_logits, past key-values and last hidden state - logits = torch.stack(torch.split(logits, top_k), dim = 0) # (b, k, vocal_size) - next_token_logits = logits[range(b), selected_max_prob_indices, :] # (b, vocal_size) - past_key_values = reshape_to_past_key_values(past_key_values, selected_max_prob_indices, top_k) # (b, h, n_seq + 1, lm_hidden_size/h) - next_hidden_state = torch.stack(torch.split(next_hidden_state.squeeze(dim = 1), top_k), dim = 0) # (b, k, lm_hidden_size) - next_hidden_state = next_hidden_state[range(b), selected_max_prob_indices, :] # (b, lm_hidden_size) - past_hidden_states = torch.cat([past_hidden_states, next_hidden_state.unsqueeze(dim = 1)], dim=1) # [b, n_seq + 1, lm_hidden_size] - - # computing next token and saving it - next_token = top_k_indices[range(b), selected_max_prob_indices].unsqueeze(dim = -1) # (b, 1) - if tokens is None: - tokens = next_token - else: - tokens = torch.cat((tokens, next_token), dim = 1) # (b, n_seq + 1) - - # whether to stop early according to the end of sentence, only working when batch size is equal to 1 - if b == 1 and next_token.item() == eos: - new_tokens = tokens.squeeze(dim = 0).tolist() - sentence = tokenizer.decode(new_tokens) - return sentence - - # tokens: (1/b, n_seq + max_len) where n_seq refers to the length of inputs tokens or prompts - # torch.tensor(1/b, n_seq + max_Len) -> str/list[str] - sentence = [] - if b == 1: - new_tokens = tokens.squeeze(dim = 0).tolist() - sentence = tokenizer.decode(new_tokens) - else: - for temp_tokens in tokens: - for i in range(len(temp_tokens)): - if temp_tokens[i].item() == eos: - break - new_tokens = temp_tokens[:i + 1].tolist() - sentence.append(tokenizer.decode(new_tokens)) - return sentence - -@torch.no_grad() -def magic_search( - prompts: Optional[str] = None, - tokens: Optional[torch.Tensor] = None, - embeddings: Optional[torch.Tensor] = None, - image_path: Optional[str] = None, - images_feature: Optional[torch.Tensor] = None, - alpha: float = 0.1, - beta: float = 2.0, - top_k: int = 48, - max_len: int = 64, - clip_text_max_len: int = 60, - end_of_sentence: str = '.', - tokenizer: GPT2Tokenizer = None, - model: GPT2LMHeadModel = None -) -> List[str]: - """ - Sentence generation through choosing token guided by model confidence, degeneration penality and image at each time step. - Taking text input as prompts, tokens or embeddings, if more than one input a time, priority should follow: embeddings > tokens > prompts. - Taking image input as images_path or images_feature, if more than one input a time, priority should follow images_feature > image_path. - Args: - prompts: str, prompts for generated sentence - tokens: tensor with shape of (b, n_seq), device = model.device, dtype = int64 - embeddings: tensor with shape of (b, n_seq, lm_hidden_size), device = model.device, dtype = float16/float32 (from clip encoder/gpt2 encoder) - image_path: str, the path of a single image - images_feature: tensor with shape of (b, clip_hidden_size), device = model.device, dtype = float32 - alpha: float from 0.0 to 1.0, controlling the strength of degenration penalty (i.e., avoiding repeat) - beta: float, controlling image-guided strength - top_k: int, generating k candidate tokens each time step in next token predicition (i.e., next token will be selected from the top k candidates) - max_len: int, the maximum length of generated sentence (without considering the length of prompts/tokens/embeddings) - clip_text_max_len: int, the maximum length of clip textual encoder - end_of_sentence: str, early stopping once generated word is equal to end_of_sentence - tokenizer: transforming word/sentence to indice/list and vice versa, i.e., str -> List[int64] or List[int64] -> str - model: language model (taking input as either tokens or embeddings) - Return: - list[str] for generated sentence when batch size is greater than 1 (i.e., len(list) = batch_size), and string when batch size is equal to 1 - """ - model.eval() - device = model.device - - # tokenizing end of sentence, when the length of eos tokens is greater than 1, setting the first token of eos tokens as eos token - eos = tokenizer.encode(end_of_sentence)[0] - - # prefix should transform into word embeddings so that sentence generation is capable of processing input of prompts, tokens or embeddings unifiedly - # priority: embeddings > tokens > prompts - if embeddings is not None: - generating = embeddings # (b, n_seq, lm_hidden_size) - else: - if tokens is None: - tokens = torch.tensor(tokenizer.encode(prompts)) # (n_seq), tokenizing prompts - tokens = tokens.unsqueeze(dim = 0).to(device) # (b(=1), n_seq), adding batch dimension - generating = model.transformer.wte(tokens) # (b, n_seq, lm_hidden_size), transforming to word embeddings - generating = generating.float() # (b, n_seq, lm_hidden_size) - assert generating.dim() == 3, 'The dimension of prompts should equal to 3!' - - # generating image feature using clip visual encoder - # note that the dtype of feature from clip visual encoder is equal to float16, transforming it into float32 - # priority: images_feature > image_path - clip_model, preprocess = clip.load('ViT-B/32', device = device) - clip_model.eval() - if images_feature is None: - image = preprocess(Image.open(image_path)).unsqueeze(dim = 0).to(device) # (b(=1), 3, 224, 224) - images_feature = clip_model.encode_image(image) # (b, clip_hidden_size) - images_feature = images_feature.float() # (b, clip_hidden_size) - assert images_feature.dim() == 2, 'The dimension of images feature should equal to 2!' - assert images_feature.shape[0] == generating.shape[0], 'The number of images should be equal to the number of prompts/tokens/embeddings!' - - past_key_values = None - tokens_generated = None - for step in range(max_len): - # generating the initial states of model - if step == 0: - outputs = model(inputs_embeds = generating, past_key_values = past_key_values, use_cache = True, output_hidden_states = True) - next_token_logits = outputs.logits[:, -1, :] # (b, n_seq, vocal_size) -> (b, vocal_size), logits of the last token - past_key_values = outputs.past_key_values # Tuple[Tuple[(b, h, n_seq, lm_hidden_size/h)]], layers -> (key, value) -> torch.tensor - past_hidden_states = outputs.hidden_states[-1] # Tuple[(b, n_seq, lm_hidden_size)] -> (b, n_seq, lm_hidden_size) (i.e., hidden state of last layer) - - # selecting top k candidates and their probability from next_tokens_logits - b, n_seq, lm_hidden_size = past_hidden_states.size() - next_token_probs = F.softmax(next_token_logits, dim = -1) # (b, vocal_size) - _, top_k_indices = torch.topk(next_token_logits, dim = -1, k = top_k) # (b, k), the indices for top k candidates (i.e., tokens) - top_k_probs = torch.gather(next_token_probs, dim = 1, index = top_k_indices) # (b, k), the probability for top k candidates - - # computing similarity between image and sentence (b * k in total) - image_sentence_score = image_sentence_similarity(tokens_generated, top_k_indices, images_feature, top_k, clip_text_max_len, tokenizer, clip_model) # (b, k) - - # transformering b*k tokens to embeddings and processing past_key_values to compute simultaneously for k tokens - top_k_embeddings = model.transformer.wte(top_k_indices.view(-1, 1)) # (b*k, 1, lm_hidden_size) - past_key_values = reshape_from_past_key_values(past_key_values, top_k) # Tuple[Tuple[(b*k, h, n_seq, lm_hidden_size/h)]] - # computing hidden state of next token (b * top_k in total) - outputs = model(inputs_embeds = top_k_embeddings, past_key_values = past_key_values, use_cache = True, output_hidden_states = True) - logits = outputs.logits[:, -1, :] # (b*k, 1, vocal_size) -> (b*k, vocal_size) - past_key_values = outputs.past_key_values # Tuple[Tuple[(b*k, h, n_seq + 1, lm_hidden_size/h)]] - next_hidden_state = outputs.hidden_states[-1] # Tuple[(b*k, 1, lm_hidden_size)] -> (b*k, 1, lm_hidden_size) - context_hidden_states = past_hidden_states.unsqueeze(dim = 1).expand(-1, top_k, -1, -1).reshape(b*top_k, n_seq, lm_hidden_size) # (b*k, n_seq, lm_hidden_size) - - # selecting next token within top k candidates for each sentence - selected_max_prob_indices = ranking_and_selecting(context_hidden_states, next_hidden_state, top_k_probs, alpha, top_k, beta, image_sentence_score) # (b) - - # updating next_token_logits, past key-values and last hidden state - logits = torch.stack(torch.split(logits, top_k), dim = 0) # (b, k, vocal_size) - next_token_logits = logits[range(b), selected_max_prob_indices, :] # (b, vocal_size) - past_key_values = reshape_to_past_key_values(past_key_values, selected_max_prob_indices, top_k) # (b, h, n_seq + 1, lm_hidden_size/h) - next_hidden_state = torch.stack(torch.split(next_hidden_state.squeeze(dim = 1), top_k), dim = 0) # (b, k, lm_hidden_size) - next_hidden_state = next_hidden_state[range(b), selected_max_prob_indices, :] # (b, lm_hidden_size) - past_hidden_states = torch.cat([past_hidden_states, next_hidden_state.unsqueeze(dim = 1)], dim=1) # [b, n_seq + 1, lm_hidden_size] - - # computing next token and saving it - next_token = top_k_indices[range(b), selected_max_prob_indices].unsqueeze(dim = -1) # (b, 1) - if tokens is None: - tokens = next_token - tokens_generated = next_token - else: - if tokens_generated is None: - tokens_generated = next_token - else: - tokens_generated = torch.cat((tokens_generated, next_token), dim = 1) - tokens = torch.cat((tokens, next_token), dim = 1) # (b, n_seq + 1) - - # whether to stop early according to the end of sentence, only working when batch size is equal to 1 - if b == 1 and next_token.item() == eos: - new_tokens = tokens.squeeze(dim = 0).tolist() - sentence = tokenizer.decode(new_tokens) - return sentence - - # tokens: (1/b, n_seq + max_len) where n_seq refers to the length of inputs tokens or prompts - # torch.tensor(1/b, n_seq + max_Len) -> str/list[str] - sentence = [] - if b == 1: - new_tokens = tokens.squeeze(dim = 0).tolist() - sentence = tokenizer.decode(new_tokens) - else: - for temp_tokens in tokens: - for i in range(len(temp_tokens)): - if temp_tokens[i].item() == eos: - break - new_tokens = temp_tokens[:i + 1].tolist() - sentence.append(tokenizer.decode(new_tokens)) - return sentence - -def image_sentence_similarity( - tokens_generated: torch.Tensor, - top_k_indices: torch.Tensor, - images_feature: torch.Tensor, - top_k: int, - clip_text_max_len: int, - tokenizer: GPT2Tokenizer, - clip_model: clip -) -> torch.Tensor: - """ - Args: - tokens_generated: tensor with shape of (b, n_seq), the sentence generated (without considering the prompts) - top_k_indices: tensor with shape of (b, top_k), the top k candidates for each sentence - images_feature: tensor with shape of (b, clip_hidden_size), image feature encoded by clip - top_k: int, k candidates - clip_text_max_len: int, the maximum length of clip textual encoder - tokenizer: transforming word/sentence to indice/list and vice versa - clip_model: pre-trained clip model which encodes image or image to embeddings with dtype of float16 (transforming to float32) - - Return: - image-sentence similarity score with shape of (b, k), i.e., for each sentence (b in total), returning top k tokens similarity with image - """ - device = top_k_indices.device - - # obtaining tokens of generated (b sentences and k tokens for each sentence, i.e., b * k sentences in total) - if tokens_generated is None: - temp_tokens = top_k_indices.view(-1).unsqueeze(dim = 1) # (b*k, n_seq + 1), where n_seq = 0 - else: - b, n = tokens_generated.size() - tokens_generated = tokens_generated.unsqueeze(dim = 1).expand(-1, top_k, -1).reshape(b*top_k, n) # (b*k, n_seq) - top_k_indices = top_k_indices.view(-1).unsqueeze(dim = 1) # (b*k, 1) - temp_tokens = torch.cat([tokens_generated, top_k_indices], dim = 1) # (b*k, n_seq + 1) - - # converting to sentence - sentences = [] - for temp_token in temp_tokens: - # taking the latest clip_text_max_len tokens when tokens length is greater than clip_text_max_len - sentence = tokenizer.decode(temp_token[-clip_text_max_len:].to('cpu').tolist()) - sentences.append(sentence) # len(sentences) = b*k - - # converting to text tokens and embeddings of clip - clip_tokens = clip.tokenize(sentences).to(device) # (b*k, n_seq) - clip_embeddings = clip_model.encode_text(clip_tokens) # (b*k, clip_hidden_size) - clip_embeddings = torch.stack(torch.split(clip_embeddings, top_k), dim = 0).float() # (b, k, clip_hidden_size) - - # computing similarity score - images_feature = images_feature.unsqueeze(dim = 1) # (b, 1, clip_hidden_size) - clip_embeddings = clip_embeddings / clip_embeddings.norm(dim = -1, keepdim = True) # (b, k, clip_hidden_size) - images_feature = images_feature / images_feature.norm(dim = -1, keepdim = True) # (b, 1, clip_hidden_size) - scaling = clip_model.logit_scale.exp() - score = torch.matmul(clip_embeddings, images_feature.transpose(1, 2)).squeeze(dim = 2) * scaling # (b, k) - - return F.softmax(score, dim = -1) - -def reshape_from_past_key_values(past_key_values: Tuple[Tuple[torch.Tensor]], top_k: int) -> Tuple[Tuple[torch.Tensor]]: - """ - To compute top k candidates simultaneously for each sentence in a batch, duplicating k times for each sentence. - Args: - past_key_values: Tuple[Tuple[(b, h, n_seq, lm_hidden_size/h)]], the first tuple refers to layers and the second tuple refers to key-value pair - top_k: int, k candidates - Return: - Tuple[Tuple[(b*k, h, n_seq, lm_hidden_size/h)]] - """ - new_key_values = [] - for layer in past_key_values: - items = [] - for item in layer: - b, h, n, d = item.size() # d = lm_hidden_size/h - # duplicating k times for each sentence in a batch, the only difference between each k repeated sample is the candidate waiting to concatenate - item = item.unsqueeze(dim = 1).expand(-1, top_k, -1, -1, -1).reshape(b*top_k, h, n, d) # (b*k, h, n_seq, lm_hidden_size/h) - items.append(item) - new_key_values.append(items) - return new_key_values - -def reshape_to_past_key_values(past_key_values: Tuple[Tuple[torch.Tensor]], selected_max_prob_indices: torch.Tensor, top_k: int) -> Tuple[Tuple[torch.Tensor]]: - """ - Args: - past_key_values: Tuple[Tuple[(b*k, h, n_seq + 1, lm_hidden_size/h)]] - selected_max_prob_indices: tensor with shape of (b), indices of maximum probability in k candidates - top_k: int, k candidates - Return: - Tuple[Tuple[(b, h, n_seq + 1, lm_hidden_size/h)]] - """ - new_key_values = [] - for layer in past_key_values: - items = [] - for item in layer: - bk = item.shape[0] - b = int(bk//top_k) - item = torch.stack(torch.split(item, top_k), dim = 0) # (b, k, h, n_seq + 1, lm_hidden_size/h) - item = item[range(b), selected_max_prob_indices, :, :, :] # (b, h, n_seq + 1, lm_hidden_size/h) - items.append(item) - new_key_values.append(items) - return new_key_values - -def ranking_and_selecting( - context_hidden_states: torch.Tensor, - next_hidden_state: torch.Tensor, - top_k_probs: torch.Tensor, - alpha: float, - top_k: int, - beta: Optional[float] = None, - image_sentence_score: Optional[torch.Tensor] = None -) -> torch.Tensor: - """ - Args: - context_hidden_states: tensor with shape of (b*k, n_seq, lm_hidden_size), the hidden state of each token in sentence before candidates (i.e.