Mustafa Acikgoz committed
Commit 2e51bae · 0 Parent(s):

Initial clean commit for Gradio app

Files changed (11)
  1. .gitattributes +1 -0
  2. .gitignore +0 -0
  3. README.md +77 -0
  4. app.py +95 -0
  5. clip_book_model.pth +3 -0
  6. config.py +20 -0
  7. dataset.py +51 -0
  8. inference_model.py +73 -0
  9. requirements.txt +7 -0
  10. train.py +54 -0
  11. training_model.py +73 -0
.gitattributes ADDED
@@ -0,0 +1 @@
+ *.pth filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
File without changes
README.md ADDED
@@ -0,0 +1,77 @@
+ # CLIP-Style Image Search Engine (Textbook Implementation)
+
+ This project provides a complete, modular, end-to-end implementation of a CLIP-style model for text-to-image search. The architecture and training methodology faithfully reproduce the approach described in Chapter 14 of the textbook, "Building an Image Search Engine Using CLIP: a Multimodal Approach".
+
+ The project is structured for clarity and maintainability, making it an ideal portfolio piece to showcase skills in PyTorch, model implementation, and MLOps practices such as deployment with Gradio and the Hugging Face Hub.
+
+ ## Key Features
+
+ - **Faithful "Book Version" Architecture:** Implements the specific design choices from the textbook:
+   - **Frozen Vision Encoder:** Uses a pre-trained `ResNet50` as a fixed feature extractor.
+   - **Frozen Text Encoder:** Uses a pre-trained `DistilBERT` as a fixed feature extractor.
+   - **Projection Heads:** Map both image and text features into a shared 256-dimensional space.
+   - **Custom Contrastive Loss:** Implements the loss function described in the book (see `training_model.py`).
+ - **Modular & Professional Code Structure:** The code is separated into logical files (`config.py`, `dataset.py`, `training_model.py`, `inference_model.py`, `train.py`, `app.py`) for better organization and scalability.
+ - **End-to-End MLOps Pipeline:**
+   - **Training:** A dedicated script to train the model and save the weights.
+   - **Inference:** A standalone Gradio web application for interactive image-text matching.
+   - **Hub Integration:** Detailed instructions for uploading the trained model and hosting the app on the Hugging Face Hub.
+
+ ## Project Structure
+
+ ```
+ your-clip-project/
+ ├── data/
+ │   ├── Flicker8k_Dataset/
+ │   └── captions.txt
+ ├── app.py
+ ├── config.py
+ ├── dataset.py
+ ├── training_model.py
+ ├── inference_model.py
+ ├── train.py
+ ├── requirements.txt
+ └── README.md
+ ```
+
+ ## Setup and Installation
+
+ **1. Clone the Repository:**
+
+ ```bash
+ git clone <your-repo-url>
+ cd your-clip-project
+ ```
+
+ **2. Create a Python Virtual Environment:**
+
+ ```bash
+ python -m venv venv
+ source venv/bin/activate  # On Windows, use `venv\Scripts\activate`
+ ```
+
+ **3. Install Dependencies:**
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ **4. Download the Flickr8k Dataset:**
+
+ Request the dataset from the official source: https://illinois.edu/fb/sec/1713398.
+
+ Download and extract `Flickr8k_Dataset.zip` into the `data/` folder so that the images end up in `data/Flicker8k_Dataset/` (the path expected by `config.IMAGE_DIR`).
+
+ Find a `captions.txt` file (commonly available with Kaggle versions of the dataset) and place it at `data/captions.txt`.
+
+ ## How to Run
+
+ ### Step 1: Train the Model
+
+ First, train the model. This creates a `clip_book_model.pth` file containing the learned weights of the projection heads.
+
+ Run the training script from your terminal:
+
+ ```bash
+ python train.py
+ ```
+
+ ### Step 2: Launch the Web Application
+
+ Once the model is trained, launch the interactive search app with Gradio:
+
+ ```bash
+ python app.py
+ ```
+
+ Gradio prints a local URL in the terminal; open it in your browser to use the application.
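
The README describes text-to-image search, while the Gradio app (`app.py`, below) ranks a set of text descriptions against a single image. A rough, hypothetical sketch of the text-to-image direction, reusing `config` and `inference_model.CLIPModel` from this commit (the `search` helper, folder listing, and example query are illustrative, not part of the commit):

```python
# Hypothetical sketch: rank the images in a folder against a free-text query
# using the trained checkpoint and the modules added in this commit.
import os
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import DistilBertTokenizer

import config
from inference_model import CLIPModel

device = config.DEVICE
model = CLIPModel(
    image_embedding_dim=config.IMAGE_EMBEDDING_DIM,
    text_embedding_dim=config.TEXT_EMBEDDING_DIM,
    projection_dim=config.PROJECTION_DIM,
).to(device)
model.load_state_dict(torch.load(config.MODEL_PATH, map_location=device))
model.eval()

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

@torch.no_grad()
def search(query: str, image_dir: str = config.IMAGE_DIR, top_k: int = 5):
    """Return the top_k image paths most similar to the text query (small folders only)."""
    paths = [os.path.join(image_dir, f) for f in sorted(os.listdir(image_dir))
             if f.lower().endswith((".jpg", ".jpeg", ".png"))]
    # Encodes all images in one batch for brevity; chunk this for large folders.
    images = torch.stack([preprocess(Image.open(p).convert("RGB")) for p in paths]).to(device)
    image_emb, _ = model(image_features=images)                       # (N, 256)
    tokens = tokenizer([query], padding=True, truncation=True, return_tensors="pt").to(device)
    _, text_emb = model(text_input_ids=tokens["input_ids"],
                        text_attention_mask=tokens["attention_mask"])  # (1, 256)
    sims = F.normalize(text_emb, dim=-1) @ F.normalize(image_emb, dim=-1).T
    scores, idx = sims.squeeze(0).topk(min(top_k, len(paths)))
    return [(paths[i], s.item()) for i, s in zip(idx.tolist(), scores)]

print(search("a dog running on grass"))
```

For the full Flickr8k folder, encode images in batches and cache the embeddings instead of recomputing them for every query.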
app.py ADDED
@@ -0,0 +1,95 @@
+ import gradio as gr
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ import requests
+ from io import BytesIO
+ from torchvision import transforms
+ from transformers import DistilBertTokenizer
+ import config  # Your config file
+ from inference_model import CLIPModel  # Your model class file
+
+ # --- 1. Load Model and Tokenizer (runs only once) ---
+ # This section loads your trained model and tokenizer when the app starts.
+ device = config.DEVICE
+
+ # Load model with dimensions from config
+ model = CLIPModel(
+     image_embedding_dim=config.IMAGE_EMBEDDING_DIM,
+     text_embedding_dim=config.TEXT_EMBEDDING_DIM,
+     projection_dim=config.PROJECTION_DIM
+ ).to(device)
+
+ # Load the trained model weights from your .pth file
+ model.load_state_dict(torch.load(config.MODEL_PATH, map_location=device))
+ model.eval()
+
+ # Load tokenizer
+ tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
+ print("Model and Tokenizer loaded successfully.")
+
+ # --- 2. Image Preprocessing Function (reused from your code) ---
+ def preprocess_image(image):
+     """Preprocess the image for the model."""
+     transform = transforms.Compose([
+         transforms.Resize((224, 224)),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ])
+     return transform(image).unsqueeze(0)
+
+ # --- 3. The Main Gradio Function ---
+ # This is the core function that Gradio will build a UI around.
+ # It takes the inputs from the UI and returns the outputs to the UI.
+ def find_best_match(image_input, text_queries_input):
+     """
+     Takes an image and a block of text queries, and returns a dictionary
+     of queries and their similarity scores.
+     """
+     if image_input is None:
+         return "Please provide an image."
+     if not text_queries_input:
+         return "Please provide text descriptions."
+
+     # Process the image to get a tensor
+     image_tensor = preprocess_image(image_input).to(device)
+
+     # Process the text queries into a clean list
+     queries = [q.strip() for q in text_queries_input.split('\n') if q.strip()]
+     if not queries:
+         return "Please provide valid text descriptions."
+
+     # Process the text queries to get tokens
+     text_inputs = tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(device)
+
+     # Get model embeddings
+     with torch.no_grad():
+         image_embedding, text_embeddings = model(
+             image_features=image_tensor,
+             text_input_ids=text_inputs['input_ids'],
+             text_attention_mask=text_inputs['attention_mask']
+         )
+
+     # Calculate cosine similarity and format for Gradio's Label component
+     image_embedding_norm = F.normalize(image_embedding, p=2, dim=-1)
+     text_embeddings_norm = F.normalize(text_embeddings, p=2, dim=-1)
+     similarity_scores = (image_embedding_norm @ text_embeddings_norm.T).squeeze(0)
+
+     # Create a results dictionary: { "query text": score, ... }
+     results = {query: score.item() for query, score in zip(queries, similarity_scores)}
+
+     return results
+
+ # --- 4. Create and Launch the Gradio Interface ---
+ iface = gr.Interface(
+     fn=find_best_match,
+     inputs=[
+         gr.Image(type="pil", label="Upload or Drag an Image"),
+         gr.Textbox(lines=5, label="Text Descriptions (one per line)", placeholder="a person on a beach\na black cat\na city skyline at night")
+     ],
+     outputs=gr.Label(num_top_classes=5, label="Results"),
+     title="🖼️ CLIP Image-Text Search",
+     description="Provide an image and several text descriptions. The app will use a trained CLIP model to find the best textual match for the image."
+ )
+
+ iface.launch()
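
Because `iface.launch()` exposes a standard Gradio endpoint, the running app can also be queried from another process. A rough, hypothetical sketch using `gradio_client` (not part of this commit; how image arguments are passed varies between `gradio_client` versions, and the image path is a placeholder):

```python
# Hypothetical client-side check; assumes the app is already running locally
# at the URL printed by iface.launch() and that gradio_client is installed.
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860")
result = client.predict(
    handle_file("example.jpg"),                       # placeholder image path
    "a dog running on grass\na city skyline at night",
    api_name="/predict",                              # default endpoint of a single gr.Interface
)
print(result)  # Label-style payload with per-query similarity scores
```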
clip_book_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e372639e3a5bfa25a166d5825293e682c684628e1d95eccb6da17fbbac82b522
+ size 363273830
config.py ADDED
@@ -0,0 +1,20 @@
+ # config.py
+ import torch
+
+ # --- Project Paths ---
+ IMAGE_DIR = "data/Flicker8k_Dataset"
+ CAPTION_FILE = "data/captions.txt"
+ MODEL_PATH = "clip_book_model.pth"
+
+ # --- Model Dimensions ---
+ IMAGE_EMBEDDING_DIM = 2048  # ResNet50 output dimension
+ TEXT_EMBEDDING_DIM = 768    # DistilBERT output dimension
+ PROJECTION_DIM = 256        # Shared embedding space dimension
+
+ # --- Training Parameters ---
+ BATCH_SIZE = 32
+ NUM_EPOCHS = 3
+ LEARNING_RATE = 1e-3
+
+ # --- System ---
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset.py ADDED
@@ -0,0 +1,51 @@
+ # dataset.py
+ import os
+ import torch
+ from torch.utils.data import Dataset
+ from PIL import Image
+ import pandas as pd
+ import torchvision.transforms as transforms
+
+ class Flickr8kDataset(Dataset):
+     """
+     Custom PyTorch Dataset for the Flickr8k data.
+     It loads images and their corresponding captions, tokenizing the text
+     on initialization for efficiency.
+     """
+     def __init__(self, image_dir, caption_file, tokenizer):
+         self.image_dir = image_dir
+         self.tokenizer = tokenizer
+
+         self.transform = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+         ])
+
+         df = pd.read_csv(caption_file)
+         self.image_paths = [os.path.join(self.image_dir, fname) for fname in df['image']]
+         self.captions = df['caption'].tolist()
+
+         print("Tokenizing all captions... (This may take a moment)")
+         self.caption_encodings = self.tokenizer(
+             self.captions,
+             truncation=True,
+             padding='max_length',
+             max_length=200,
+             return_tensors="pt"
+         )
+         print("Tokenization complete.")
+
+     def __len__(self):
+         return len(self.captions)
+
+     def __getitem__(self, idx):
+         item = {key: val[idx] for key, val in self.caption_encodings.items()}
+         try:
+             img = Image.open(self.image_paths[idx]).convert("RGB")
+             item['image'] = self.transform(img)
+         except FileNotFoundError:
+             print(f"Warning: Could not load image at {self.image_paths[idx]}. Returning a black image.")
+             item['image'] = torch.zeros((3, 224, 224))
+         item["caption_text"] = self.captions[idx]
+         return item
inference_model.py ADDED
@@ -0,0 +1,73 @@
+ # inference_model.py
+ import torch
+ import torch.nn as nn
+ from torchvision.models import resnet50
+ from transformers import DistilBertModel
+
+ # --- Copy these classes from your original file ---
+ class VisionEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         # Note: Using the newer 'weights' parameter is recommended
+         pretrained_resnet50 = resnet50(weights='IMAGENET1K_V1')
+         self.model = nn.Sequential(*list(pretrained_resnet50.children())[:-1])
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+     def forward(self, x):
+         x = self.model(x)
+         return x.view(x.size(0), -1)
+
+ class TextEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+     def forward(self, input_ids, attention_mask=None):
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         return outputs.last_hidden_state[:, 0, :]
+
+ class ProjectionHead(nn.Module):
+     def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
+         super().__init__()
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         x = x + projected
+         x = self.layer_norm(x)
+         return x
+
+ # --- This is the MODIFIED CLIPModel for inference ---
+ class CLIPModel(nn.Module):
+     def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
+         super().__init__()
+         self.vision_encoder = VisionEncoder()
+         self.text_encoder = TextEncoder()
+         self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
+         self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)
+
+     def forward(self, image_features=None, text_input_ids=None, text_attention_mask=None):
+         image_embedding = None
+         if image_features is not None:
+             image_features = self.vision_encoder(image_features)
+             image_embedding = self.image_projection(image_features)
+
+         text_embedding = None
+         if text_input_ids is not None:
+             text_features = self.text_encoder(
+                 input_ids=text_input_ids,
+                 attention_mask=text_attention_mask
+             )
+             text_embedding = self.text_projection(text_features)
+
+         return image_embedding, text_embedding
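
A small, hypothetical sanity check (not part of this commit): because this inference `CLIPModel` uses the same submodule names as the training version in `training_model.py` (later in this commit), the saved `clip_book_model.pth` state dict should load into either class. The sketch below verifies key compatibility without running a forward pass:

```python
# Hypothetical check that the checkpoint keys match both model definitions.
import torch
import config
from inference_model import CLIPModel as InferenceCLIP
from training_model import CLIPModel as TrainingCLIP

state = torch.load(config.MODEL_PATH, map_location="cpu")
for cls in (InferenceCLIP, TrainingCLIP):
    model = cls(
        image_embedding_dim=config.IMAGE_EMBEDDING_DIM,
        text_embedding_dim=config.TEXT_EMBEDDING_DIM,
        projection_dim=config.PROJECTION_DIM,
    )
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(cls.__module__, "missing:", missing, "unexpected:", unexpected)  # both lists should be empty
```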
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ torch
+ torchvision
+ transformers
+ Pillow
+ requests
+ huggingface-hub
train.py ADDED
@@ -0,0 +1,54 @@
+ # train.py
+ import torch
+ from torch.utils.data import DataLoader
+ from transformers import DistilBertTokenizer
+ from tqdm import tqdm
+
+ import config
+ from dataset import Flickr8kDataset
+ from training_model import CLIPModel
+
+ def main():
+     print(f"--- Starting Training ---")
+     print(f"Using device: {config.DEVICE}")
+
+     tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
+     dataset = Flickr8kDataset(config.IMAGE_DIR, config.CAPTION_FILE, tokenizer)
+     dataloader = DataLoader(dataset, batch_size=config.BATCH_SIZE, shuffle=True)
+
+     model = CLIPModel(
+         image_embedding_dim=config.IMAGE_EMBEDDING_DIM,
+         text_embedding_dim=config.TEXT_EMBEDDING_DIM,
+         projection_dim=config.PROJECTION_DIM
+     ).to(config.DEVICE)
+
+     optimizer = torch.optim.Adam(model.parameters(), lr=config.LEARNING_RATE)
+
+     for epoch in range(config.NUM_EPOCHS):
+         print(f"\nEpoch {epoch+1}/{config.NUM_EPOCHS}")
+         model.train()
+         total_loss = 0
+         progress_bar = tqdm(dataloader, desc=f"Training Epoch {epoch+1}")
+
+         for batch in progress_bar:
+             batch = {k: v.to(config.DEVICE) for k, v in batch.items() if isinstance(v, torch.Tensor)}
+
+             loss = model(batch)
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+             total_loss += loss.item()
+             progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
+
+         avg_loss = total_loss / len(dataloader)
+         print(f"Epoch {epoch+1} - Average Loss: {avg_loss:.4f}")
+
+     print("\nTraining complete.")
+     torch.save(model.state_dict(), config.MODEL_PATH)
+     print(f"Model saved to {config.MODEL_PATH}")
+     print("\nTo upload to Hugging Face Hub, run the upload_to_hub.py script.")
+
+ if __name__ == '__main__':
+     main()
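
The final print statement references an `upload_to_hub.py` script that is not part of this commit. A minimal, hypothetical version using `huggingface_hub` could look like the sketch below; the `repo_id` is a placeholder, and authentication is expected via `huggingface-cli login` or the `HF_TOKEN` environment variable:

```python
# upload_to_hub.py (hypothetical; not included in this commit)
from huggingface_hub import HfApi

import config

api = HfApi()
api.create_repo(repo_id="your-username/clip-book-model", repo_type="model", exist_ok=True)  # placeholder repo_id
api.upload_file(
    path_or_fileobj=config.MODEL_PATH,
    path_in_repo=config.MODEL_PATH,
    repo_id="your-username/clip-book-model",
    repo_type="model",
    commit_message="Upload trained CLIP projection heads",
)
```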
training_model.py ADDED
@@ -0,0 +1,73 @@
+ # training_model.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision.models import resnet50
+ from transformers import DistilBertModel
+
+ class VisionEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         pretrained_resnet50 = resnet50(weights='IMAGENET1K_V1')
+         self.model = nn.Sequential(*list(pretrained_resnet50.children())[:-1])
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+     def forward(self, x):
+         x = self.model(x)
+         return x.view(x.size(0), -1)
+
+ class TextEncoder(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
+         for param in self.model.parameters():
+             param.requires_grad = False
+
+     def forward(self, input_ids, attention_mask=None):
+         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+         return outputs.last_hidden_state[:, 0, :]
+
+ class ProjectionHead(nn.Module):
+     def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
+         super().__init__()
+         self.projection = nn.Linear(embedding_dim, projection_dim)
+         self.gelu = nn.GELU()
+         self.fc = nn.Linear(projection_dim, projection_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(projection_dim)
+
+     def forward(self, x):
+         projected = self.projection(x)
+         x = self.gelu(projected)
+         x = self.fc(x)
+         x = self.dropout(x)
+         x = x + projected
+         x = self.layer_norm(x)
+         return x
+
+ class CLIPModel(nn.Module):
+     def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
+         super().__init__()
+         self.vision_encoder = VisionEncoder()
+         self.text_encoder = TextEncoder()
+         self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
+         self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)
+
+     def forward(self, batch):
+         image_features = self.vision_encoder(batch["image"])
+         text_features = self.text_encoder(
+             input_ids=batch["input_ids"],
+             attention_mask=batch["attention_mask"]
+         )
+         image_embeddings = self.image_projection(image_features)
+         text_embeddings = self.text_projection(text_features)
+
+         # Textbook's specific loss calculation
+         logits = text_embeddings @ image_embeddings.T
+         images_similarity = image_embeddings @ image_embeddings.T
+         texts_similarity = text_embeddings @ text_embeddings.T
+         targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
+         texts_loss = F.cross_entropy(logits, targets)    # soft (probability) targets require PyTorch >= 1.10
+         images_loss = F.cross_entropy(logits.T, targets.T)
+         return (images_loss + texts_loss) / 2.0
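
In this loss, the targets are a softmax over the averaged image-image and text-text self-similarities rather than the identity matrix used in the original CLIP objective. A standalone sketch (not part of this commit, assuming PyTorch >= 1.10; random vectors stand in for projected embeddings) that exercises the same computation:

```python
# Standalone check of the book-style loss above.
import torch
import torch.nn.functional as F

def book_loss(image_embeddings, text_embeddings):
    # Mirrors the tail of CLIPModel.forward in training_model.py.
    logits = text_embeddings @ image_embeddings.T
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2, dim=-1)
    texts_loss = F.cross_entropy(logits, targets)
    images_loss = F.cross_entropy(logits.T, targets.T)
    return (images_loss + texts_loss) / 2.0

torch.manual_seed(0)
img = 3 * F.normalize(torch.randn(8, 256), dim=-1)   # batch of 8 "image" embeddings
txt = 3 * F.normalize(torch.randn(8, 256), dim=-1)   # 8 unrelated "text" embeddings

print(f"mismatched pairs: {book_loss(img, txt).item():.3f}")        # higher: captions do not match images
print(f"aligned pairs:    {book_loss(img, img.clone()).item():.3f}")  # near zero: every pair agrees
```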