Commit 2e51bae by Mustafa Acikgoz
Parent(s):
Initial clean commit for Gradio app
Files changed:
- .gitattributes +1 -0
- .gitignore +0 -0
- README.md +77 -0
- app.py +95 -0
- clip_book_model.pth +3 -0
- config.py +20 -0
- dataset.py +51 -0
- inference_model.py +73 -0
- requirements.txt +7 -0
- train.py +54 -0
- training_model.py +73 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
*.pth filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
File without changes
README.md
ADDED
@@ -0,0 +1,77 @@
# CLIP-Style Image Search Engine (Textbook Implementation)

This project provides a complete, modular, end-to-end implementation of a CLIP-style model for text-to-image search. The architecture and training methodology are a faithful reproduction of the approach described in Chapter 14 of the textbook, "Building an Image Search Engine Using CLIP: a Multimodal Approach".

The project is structured for clarity and maintainability, making it an ideal portfolio piece to showcase skills in PyTorch, model implementation, and MLOps practices such as deployment with Gradio and the Hugging Face Hub.

## Key Features

- **Faithful "Book Version" Architecture:** Implements the specific design choices from the textbook:
  - **Frozen Vision Encoder:** Uses a pre-trained `ResNet50` as a fixed feature extractor.
  - **Frozen Text Encoder:** Uses a pre-trained `DistilBERT` as a fixed feature extractor.
  - **Projection Heads:** Map both image and text features into a shared 256-dimensional space.
  - **Custom Contrastive Loss:** Implements the specific loss function described in the book.
- **Modular & Professional Code Structure:** The code is separated into logical files (`config.py`, `dataset.py`, `training_model.py`, `inference_model.py`, `train.py`, `app.py`) for better organization and scalability.
- **End-to-End MLOps Pipeline:**
  - **Training:** A dedicated script to train the model and save the weights.
  - **Inference:** A standalone Gradio web application for interactive image-text search.
  - **Hub Integration:** Instructions for uploading the trained model and hosting the app on the Hugging Face Hub.

## Project Structure

```
your-clip-project/
│
├── data/
│   ├── Flicker8k_Dataset/
│   └── captions.txt
│
├── app.py
├── config.py
├── dataset.py
├── inference_model.py
├── training_model.py
├── train.py
│
├── requirements.txt
└── README.md
```

## Setup and Installation

**1. Clone the Repository:**

```bash
git clone <your-repo-url>
cd your-clip-project
```

**2. Create a Python Virtual Environment:**

```bash
python -m venv venv
source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
```

**3. Install Dependencies:**

```bash
pip install -r requirements.txt
```

**4. Download the Flickr8k Dataset:**

- Request the dataset from the official source: https://illinois.edu/fb/sec/1713398.
- Download and extract Flickr8k_Dataset.zip into the data/ folder so that the images end up in data/Flicker8k_Dataset/ (the IMAGE_DIR expected by config.py).
- Find a captions.txt file (commonly available with Kaggle versions of the dataset) and place it at data/captions.txt.

## How to Run

### Step 1: Train the Model

First, train the model. This creates a clip_book_model.pth file with the model weights (only the projection heads are actually trained; both encoders stay frozen).

Run the training script from your terminal:

```bash
python train.py
```

### Step 2: Launch the Web Application

Once the model is trained, launch the interactive Gradio app:

```bash
python app.py
```

Gradio prints a local URL; open it in your browser to use the application. The same app.py runs when the project is hosted as a Hugging Face Space.
app.py
ADDED
@@ -0,0 +1,95 @@
import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
import requests
from io import BytesIO
from torchvision import transforms
from transformers import DistilBertTokenizer
import config  # Your config file
from inference_model import CLIPModel  # Your model class file

# --- 1. Load Model and Tokenizer (runs only once) ---
# This section loads your trained model and tokenizer when the app starts.
device = config.DEVICE

# Load model with dimensions from config
model = CLIPModel(
    image_embedding_dim=config.IMAGE_EMBEDDING_DIM,
    text_embedding_dim=config.TEXT_EMBEDDING_DIM,
    projection_dim=config.PROJECTION_DIM
).to(device)

# Load the trained model weights from your .pth file
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=device))
model.eval()

# Load tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
print("Model and Tokenizer loaded successfully.")

# --- 2. Image Preprocessing Function (reused from your code) ---
def preprocess_image(image):
    """Preprocess the image for the model."""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image).unsqueeze(0)

# --- 3. The Main Gradio Function ---
# This is the core function that Gradio will build a UI around.
# It takes the inputs from the UI and returns the outputs to the UI.
def find_best_match(image_input, text_queries_input):
    """
    Takes an image and a block of text queries, and returns a dictionary
    of queries and their similarity scores.
    """
    if image_input is None:
        return "Please provide an image."
    if not text_queries_input:
        return "Please provide text descriptions."

    # Process the image to get a tensor
    image_tensor = preprocess_image(image_input).to(device)

    # Process the text queries into a clean list
    queries = [q.strip() for q in text_queries_input.split('\n') if q.strip()]
    if not queries:
        return "Please provide valid text descriptions."

    # Process the text queries to get tokens
    text_inputs = tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(device)

    # Get model embeddings
    with torch.no_grad():
        image_embedding, text_embeddings = model(
            image_features=image_tensor,
            text_input_ids=text_inputs['input_ids'],
            text_attention_mask=text_inputs['attention_mask']
        )

    # Calculate cosine similarity and format for Gradio's Label component
    image_embedding_norm = F.normalize(image_embedding, p=2, dim=-1)
    text_embeddings_norm = F.normalize(text_embeddings, p=2, dim=-1)
    similarity_scores = (image_embedding_norm @ text_embeddings_norm.T).squeeze(0)

    # Create a results dictionary: { "query text": score, ... }
    results = {query: score.item() for query, score in zip(queries, similarity_scores)}

    return results

# --- 4. Create and Launch the Gradio Interface ---
iface = gr.Interface(
    fn=find_best_match,
    inputs=[
        gr.Image(type="pil", label="Upload or Drag an Image"),
        gr.Textbox(lines=5, label="Text Descriptions (one per line)", placeholder="a person on a beach\na black cat\na city skyline at night")
    ],
    outputs=gr.Label(num_top_classes=5, label="Results"),
    title="🖼️ CLIP Image-Text Search",
    description="Provide an image and several text descriptions. The app will use a trained CLIP model to find the best textual match for the image."
)

iface.launch()
clip_book_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e372639e3a5bfa25a166d5825293e682c684628e1d95eccb6da17fbbac82b522
size 363273830
config.py
ADDED
@@ -0,0 +1,20 @@
# config.py
import torch

# --- Project Paths ---
IMAGE_DIR = "data/Flicker8k_Dataset"
CAPTION_FILE = "data/captions.txt"
MODEL_PATH = "clip_book_model.pth"

# --- Model Dimensions ---
IMAGE_EMBEDDING_DIM = 2048  # ResNet50 output dimension
TEXT_EMBEDDING_DIM = 768    # DistilBERT output dimension
PROJECTION_DIM = 256        # Shared embedding space dimension

# --- Training Parameters ---
BATCH_SIZE = 32
NUM_EPOCHS = 3
LEARNING_RATE = 1e-3

# --- System ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset.py
ADDED
@@ -0,0 +1,51 @@
# dataset.py
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import torchvision.transforms as transforms

class Flickr8kDataset(Dataset):
    """
    Custom PyTorch Dataset for the Flickr8k data.
    It loads images and their corresponding captions, tokenizing the text
    on initialization for efficiency.
    """
    def __init__(self, image_dir, caption_file, tokenizer):
        self.image_dir = image_dir
        self.tokenizer = tokenizer

        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        df = pd.read_csv(caption_file)
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in df['image']]
        self.captions = df['caption'].tolist()

        print("Tokenizing all captions... (This may take a moment)")
        self.caption_encodings = self.tokenizer(
            self.captions,
            truncation=True,
            padding='max_length',
            max_length=200,
            return_tensors="pt"
        )
        print("Tokenization complete.")

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.caption_encodings.items()}
        try:
            img = Image.open(self.image_paths[idx]).convert("RGB")
            item['image'] = self.transform(img)
        except FileNotFoundError:
            print(f"Warning: Could not load image at {self.image_paths[idx]}. Returning a black image.")
            item['image'] = torch.zeros((3, 224, 224))
        item["caption_text"] = self.captions[idx]
        return item
inference_model.py
ADDED
@@ -0,0 +1,73 @@
# inference_model.py
import torch
import torch.nn as nn
from torchvision.models import resnet50
from transformers import DistilBertModel

# --- Copy these classes from your original file ---
class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Note: Using the newer 'weights' parameter is recommended
        pretrained_resnet50 = resnet50(weights='IMAGENET1K_V1')
        self.model = nn.Sequential(*list(pretrained_resnet50.children())[:-1])
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.model(x)
        return x.view(x.size(0), -1)

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

# --- This is the MODIFIED CLIPModel for inference ---
class CLIPModel(nn.Module):
    def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
        super().__init__()
        self.vision_encoder = VisionEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)

    def forward(self, image_features=None, text_input_ids=None, text_attention_mask=None):
        image_embedding = None
        if image_features is not None:
            image_features = self.vision_encoder(image_features)
            image_embedding = self.image_projection(image_features)

        text_embedding = None
        if text_input_ids is not None:
            text_features = self.text_encoder(
                input_ids=text_input_ids,
                attention_mask=text_attention_mask
            )
            text_embedding = self.text_projection(text_features)

        return image_embedding, text_embedding
requirements.txt
ADDED
@@ -0,0 +1,7 @@
gradio
torch
torchvision
transformers
Pillow
requests
huggingface-hub
train.py
ADDED
@@ -0,0 +1,54 @@
# train.py
import torch
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizer
from tqdm import tqdm

import config
from dataset import Flickr8kDataset
from training_model import CLIPModel  # the training-time model (this repo has no model.py)

def main():
    print("--- Starting Training ---")
    print(f"Using device: {config.DEVICE}")

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    dataset = Flickr8kDataset(config.IMAGE_DIR, config.CAPTION_FILE, tokenizer)
    dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True)

    model = CLIPModel(
        image_embedding_dim=config.IMAGE_EMBEDDING_DIM,
        text_embedding_dim=config.TEXT_EMBEDDING_DIM,
        projection_dim=config.PROJECTION_DIM
    ).to(config.DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)

    for epoch in range(config.NUM_EPOCHS):
        print(f"\nEpoch {epoch+1}/{config.NUM_EPOCHS}")
        model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch+1}")

        for batch in progress_bar:
            batch = {k: v.to(config.DEVICE) for k, v in batch.items() if isinstance(v, torch.Tensor)}

            loss = model(batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

        avg_loss = total_loss / len(dataloader)
        print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")

    print("\nTraining complete.")
    torch.save(model.state_dict(), config.MODEL_PATH)
    print(f"Model saved to {config.MODEL_PATH}")
    print("\nTo upload to Hugging Face Hub, run the upload_to_hub.py script.")

if __name__ == '__main__':
    main()
training_model.py
ADDED
@@ -0,0 +1,73 @@
# training_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50
from transformers import DistilBertModel

class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        pretrained_resnet50 = resnet50(weights='IMAGENET1K_V1')
        self.model = nn.Sequential(*list(pretrained_resnet50.children())[:-1])
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.model(x)
        return x.view(x.size(0), -1)

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        x = x + projected
        x = self.layer_norm(x)
        return x

class CLIPModel(nn.Module):
    def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
        super().__init__()
        self.vision_encoder = VisionEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)

    def forward(self, batch):
        image_features = self.vision_encoder(batch["image"])
        text_features = self.text_encoder(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"]
        )
        image_embeddings = self.image_projection(image_features)
        text_embeddings = self.text_projection(text_features)

        # Textbook's specific loss calculation
        logits = text_embeddings @ image_embeddings.T
        images_similarity = image_embeddings @ image_embeddings.T
        texts_similarity = text_embeddings @ text_embeddings.T
        targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
        texts_loss = F.cross_entropy(logits, targets)
        images_loss = F.cross_entropy(logits.T, targets.T)
        return (images_loss + texts_loss) / 2.0