Mustafa Acikgoz committed · Commit 8818841 · 1 parent: 296fb5d

Fix: Correct image_encoder attribute and prevent startup timeout

Files changed:
- app.py +33 -39
- inference_model.py +30 -9
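For context, the attribute mismatch named in the commit title surfaces as an AttributeError the first time app.py reaches for model.image_encoder. The snippet below is a minimal stand-in that reproduces that failure mode; only the attribute names 'vision_encoder' and 'image_encoder' come from the diff, everything else is hypothetical.

# Minimal, hypothetical stand-in for the pre-fix failure mode.
import torch.nn as nn

class BrokenCLIPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = nn.Identity()  # old attribute name

model = BrokenCLIPModel()
try:
    model.image_encoder  # the attribute app.py expected to find
except AttributeError as err:
    print(err)  # roughly: 'BrokenCLIPModel' object has no attribute 'image_encoder'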
app.py
CHANGED

@@ -10,11 +10,10 @@ import glob
from tqdm import tqdm

# --- Custom Modules ---
-# These imports assume your config.py and model files are in the same directory
import config
from inference_model import CLIPModel

-# --- 1. Initial Setup: Load Model and Tokenizer
+# --- 1. Initial Setup: Load Model and Tokenizer ---
print("Starting application setup...")
device = config.DEVICE

@@ -32,35 +31,34 @@ try:
    print("CLIP Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None

# Load the text tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
print("Tokenizer loaded successfully.")

-# --- 2. Data Handling: Download and Pre-process Images (runs once on startup) ---
-# This is the key section that connects your app to your image dataset.

-#
+# --- 2. Data Handling: Download and Pre-process Images ---
DATASET_REPO_ID = "mustafa2ak/Flickr8k-Images"
-# Define the local folder where the images will be stored inside the Space
IMAGE_STORAGE_PATH = "./flickr8k_images"

print(f"Downloading image dataset from {DATASET_REPO_ID}...")
-# Use snapshot_download for a fast, server-to-server transfer
snapshot_download(
    repo_id=DATASET_REPO_ID,
    repo_type="dataset",
    local_dir=IMAGE_STORAGE_PATH,
-    local_dir_use_symlinks=False #
+    local_dir_use_symlinks=False # Set to False for Spaces compatibility
)
print("Image dataset download complete.")

# Get a list of all image file paths
-# **CORRECTION**: The dataset structure has images directly in 'Flicker8k_Dataset'
-# The original code was looking for a subfolder named 'images', which doesn't exist.
all_image_paths = glob.glob(os.path.join(IMAGE_STORAGE_PATH, "Flicker8k_Dataset", "*.jpg"))
-
+
+# **CRITICAL FIX FOR TIMEOUT**: Use a smaller subset of images for the demo.
+# Processing all 8000+ images on startup will cause a timeout on Hugging Face Spaces.
+NUM_IMAGES_TO_PROCESS = 1000
+print(f"Found {len(all_image_paths)} total images. Using a subset of {NUM_IMAGES_TO_PROCESS} to prevent timeout.")
+all_image_paths = all_image_paths[:NUM_IMAGES_TO_PROCESS]

# Define the image preprocessing pipeline
image_transform = transforms.Compose([

@@ -70,43 +68,39 @@ image_transform = transforms.Compose([
])

def precompute_image_embeddings(image_paths, model, transform, device):
-    """
-
-    This is a crucial optimization.
-    """
-    print("Pre-computing image embeddings... This may take a few minutes.")
+    """Processes all images and computes their final embeddings for fast searching."""
+    print("Pre-computing image embeddings... This may take a minute.")
    all_embeddings = []
-    # torch.no_grad() disables gradient calculation, making this much faster
    with torch.no_grad():
-        # tqdm creates a progress bar in your logs
        for path in tqdm(image_paths, desc="Processing Images"):
            try:
                image = Image.open(path).convert("RGB")
                image_tensor = transform(image).unsqueeze(0).to(device)
-
-
+
+                # **CORRECTION**: Use the full model's forward pass to get projected embeddings.
+                # This returns (image_embedding, text_embedding), so we take the first element.
+                embedding, _ = model(image_features=image_tensor)
+
                all_embeddings.append(embedding)
            except Exception as e:
                print(f"Warning: Could not process image {path}. Error: {e}")
                continue
-    # Combine the list of individual tensors into one large tensor
    return torch.cat(all_embeddings, dim=0)

# Pre-compute all image embeddings and store them in memory
if model and all_image_paths:
    image_embeddings_precomputed = precompute_image_embeddings(all_image_paths, model, image_transform, device)
    # Normalize the embeddings once for faster similarity calculation
    image_embeddings_precomputed = F.normalize(image_embeddings_precomputed, p=2, dim=-1)
    print("Image embeddings pre-computed and stored.")
else:
    image_embeddings_precomputed = None
    print("Skipping embedding pre-computation due to missing model or images.")

+
# --- 3. The Main Gradio Function for Text-to-Image Search ---
def find_image_from_text(text_query):
-    """
-    Takes a text query and finds the best matching image from the pre-computed embeddings.
-    """
+    """Takes a text query and finds the best matching image."""
    if not text_query:
        return None, "Please enter a text query."
    if image_embeddings_precomputed is None:

@@ -114,31 +108,32 @@ def find_image_from_text(text_query):

    print(f"Searching for text: '{text_query}'")
    with torch.no_grad():
        # 1. Process the text query
        text_inputs = tokenizer([text_query], padding=True, truncation=True, return_tensors="pt").to(device)
-
-
-            attention_mask=text_inputs['attention_mask']
+
+        # 2. **CORRECTION**: Use the full model's forward pass to get projected text embedding.
+        # This returns (image_embedding, text_embedding), so we take the second element.
+        _, text_embedding = model(
+            text_input_ids=text_inputs['input_ids'],
+            text_attention_mask=text_inputs['attention_mask']
        )
-
+
+        # 3. Normalize the text embedding
        text_embedding_norm = F.normalize(text_embedding, p=2, dim=-1)

-        #
-        # This is a fast matrix multiplication: (1, 512) @ (512, N_images) -> (1, N_images)
+        # 4. Calculate similarity against all pre-computed image embeddings
        similarity_scores = (text_embedding_norm @ image_embeddings_precomputed.T).squeeze(0)

-        #
+        # 5. Find the index of the image with the highest score
        best_image_index = similarity_scores.argmax().item()
-
-        # 5. Get the file path of the best image
        best_image_path = all_image_paths[best_image_index]
        best_score = similarity_scores[best_image_index].item()

        print(f"Found best match: {best_image_path} with score {best_score:.4f}")

-        # Return the path to the best image and a caption for the UI
        return best_image_path, f"Best match with score: {best_score:.4f}"

+
# --- 4. Create and Launch the Gradio Interface ---
iface = gr.Interface(
    fn=find_image_from_text,

@@ -148,9 +143,8 @@
        gr.Textbox(label="Result Details")
    ],
    title="🖼️ Text-to-Image Search with CLIP",
-    description="Enter a text description to search for the most relevant image in the Flickr8k dataset. The app
+    description="Enter a text description to search for the most relevant image in the Flickr8k dataset. The app uses a pre-trained CLIP-like model to find the best match from a subset of 1000 images.",
    allow_flagging="never"
)

-# This starts the web server
iface.launch()
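The search step in find_image_from_text reduces to a cosine-similarity lookup: because the text embedding and the pre-computed image embeddings are all L2-normalized, a single matrix product yields every similarity score at once, as the removed "(1, 512) @ (512, N_images)" shape comment describes. Below is a self-contained sketch with random tensors; the 512-dimensional projection size is an assumption taken from that comment, and the real value comes from config.py.

# Standalone sketch of the similarity lookup; dimensions are illustrative (1000 images, 512-d embeddings).
import torch
import torch.nn.functional as F

image_embeddings = F.normalize(torch.randn(1000, 512), p=2, dim=-1)  # stand-in for the precomputed bank
text_embedding = F.normalize(torch.randn(1, 512), p=2, dim=-1)       # stand-in for the query embedding

# (1, 512) @ (512, 1000) -> (1, 1000); with unit-norm vectors the dot product is cosine similarity
scores = (text_embedding @ image_embeddings.T).squeeze(0)
best_index = scores.argmax().item()
print(f"best index: {best_index}, score: {scores[best_index].item():.4f}")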
inference_model.py
CHANGED

@@ -1,32 +1,38 @@
-# inference_model.py
import torch
import torch.nn as nn
from torchvision.models import resnet50
from transformers import DistilBertModel

-# ---
+# --- Helper Classes (VisionEncoder, TextEncoder, ProjectionHead) ---
+# These define the components of the overall CLIP model.
+
class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
-        #
+        # Use the recommended 'weights' parameter for pre-trained models
        pretrained_resnet50 = resnet50(weights='IMAGENET1K_V1')
+        # Use all layers of ResNet50 except for the final fully connected layer
        self.model = nn.Sequential(*list(pretrained_resnet50.children())[:-1])
+        # Freeze the parameters of the vision encoder
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.model(x)
+        # Flatten the output to a 1D tensor per image
        return x.view(x.size(0), -1)

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
+        # Freeze the parameters of the text encoder
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+        # Use the embedding of the [CLS] token as the sentence representation
        return outputs.last_hidden_state[:, 0, :]

class ProjectionHead(nn.Module):

@@ -43,31 +49,46 @@ class ProjectionHead(nn.Module):
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
+        # Add a residual connection
        x = x + projected
        x = self.layer_norm(x)
        return x

-# ---
+# --- Main CLIPModel for Inference ---
+# This class combines the encoders and projection heads.
+
class CLIPModel(nn.Module):
    def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
        super().__init__()
-
+
+        # **CORRECTION**: Renamed 'vision_encoder' to 'image_encoder'
+        # This attribute MUST be named 'image_encoder' to match the call in app.py
+        self.image_encoder = VisionEncoder()
+
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)

    def forward(self, image_features=None, text_input_ids=None, text_attention_mask=None):
+        """
+        This forward pass handles both image and text inputs.
+        app.py will call this to get the final, projected embeddings.
+        """
        image_embedding = None
        if image_features is not None:
-
-
+            # Get raw features from the vision backbone
+            image_features_raw = self.image_encoder(image_features)
+            # Project them into the shared embedding space
+            image_embedding = self.image_projection(image_features_raw)

        text_embedding = None
        if text_input_ids is not None:
-
+            # Get raw features from the text backbone
+            text_features_raw = self.text_encoder(
                input_ids=text_input_ids,
                attention_mask=text_attention_mask
            )
-
+            # Project them into the shared embedding space
+            text_embedding = self.text_projection(text_features_raw)

        return image_embedding, text_embedding
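Taken together, the corrected CLIPModel is meant to be instantiated with the backbone feature sizes and then queried through its forward pass, exactly as the updated app.py does. The sketch below shows that intended call pattern; 2048 and 768 are the standard feature sizes of ResNet-50 and DistilBERT, while the projection size of 512 is an assumption matching the comment removed from app.py (the real values come from config.py).

# Hedged usage sketch of the corrected inference path; dimensions are assumptions noted above.
import torch
from transformers import DistilBertTokenizer
from inference_model import CLIPModel

model = CLIPModel(image_embedding_dim=2048, text_embedding_dim=768, projection_dim=512)
model.eval()

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
text_inputs = tokenizer(["a dog running on the beach"], padding=True, truncation=True, return_tensors="pt")
dummy_image = torch.randn(1, 3, 224, 224)  # stand-in for a transformed image tensor

with torch.no_grad():
    image_embedding, _ = model(image_features=dummy_image)
    _, text_embedding = model(
        text_input_ids=text_inputs['input_ids'],
        text_attention_mask=text_inputs['attention_mask'],
    )

print(image_embedding.shape, text_embedding.shape)  # both project to (1, 512) under these assumptions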