import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import DistilBertTokenizer
from huggingface_hub import snapshot_download
import os
import glob
from tqdm import tqdm
# --- Custom Modules ---
import config
from inference_model import CLIPModel
# --- 1. Initial Setup: Load Model and Tokenizer ---
print("Starting application setup...")
device = config.DEVICE
# Load the CLIP model's structure
model = CLIPModel(
    image_embedding_dim=config.IMAGE_EMBEDDING_DIM,
    text_embedding_dim=config.TEXT_EMBEDDING_DIM,
    projection_dim=config.PROJECTION_DIM
).to(device)
# --- CRITICAL STEP ---
# Load the state dictionary with `strict=False`.
# This allows the model to load only the weights present in the file (e.g., your trained
# projection heads) and ignore the missing ones (e.g., the base ResNet and DistilBERT weights,
# which are already pre-loaded by the model class itself).
try:
    model.load_state_dict(torch.load(config.MODEL_PATH, map_location=device), strict=False)
    model.eval()
    print("CLIP Model loaded successfully.")
except Exception as e:
    print(f"Error loading model: {e}")
    model = None
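# Optional sanity check (a sketch, not part of the original code): with `strict=False`,
# `load_state_dict` returns a NamedTuple of `missing_keys` / `unexpected_keys`, so logging
# their lengths confirms that only the pre-loaded backbone weights were skipped, e.g.:
#   result = model.load_state_dict(torch.load(config.MODEL_PATH, map_location=device), strict=False)
#   print(f"missing: {len(result.missing_keys)}, unexpected: {len(result.unexpected_keys)}")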
# Load the text tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
print("Tokenizer loaded successfully.")
# --- 2. Data Handling: Download and Pre-process Images ---
DATASET_REPO_ID = "mustafa2ak/Flickr8k-Images"
IMAGE_STORAGE_PATH = "./flickr8k_images"
print(f"Downloading image dataset from {DATASET_REPO_ID}...")
snapshot_download(
    repo_id=DATASET_REPO_ID,
    repo_type="dataset",
    local_dir=IMAGE_STORAGE_PATH,
    local_dir_use_symlinks=False
)
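# Note: in newer huggingface_hub releases `local_dir_use_symlinks` is deprecated and
# effectively ignored; files are downloaded directly into `local_dir` either way.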
print("Image dataset download complete.")
# Get a list of all image file paths
all_image_paths = glob.glob(os.path.join(IMAGE_STORAGE_PATH, "Flicker8k_Dataset", "*.jpg"))
total_image_count = len(all_image_paths)
# Use a smaller subset of images to prevent timeouts and for faster testing.
# You can increase this value after confirming the app works.
NUM_IMAGES_TO_PROCESS = 100
all_image_paths = all_image_paths[:NUM_IMAGES_TO_PROCESS]
print(f"Found {total_image_count} images in total. Using a subset of {len(all_image_paths)} to prevent timeout.")
# Define the image preprocessing pipeline
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
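# The mean/std values above are the standard ImageNet normalization statistics, which is
# what torchvision's pretrained backbones (such as the base ResNet mentioned above) expect.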
def precompute_image_embeddings(image_paths, model, transform, device):
    """Processes all images and computes their final embeddings for fast searching.

    Returns the stacked embeddings together with the list of paths that were processed
    successfully, so that embedding indices always line up with image paths.
    """
    print("Pre-computing image embeddings... This may take a minute.")
    all_embeddings = []
    valid_paths = []
    with torch.no_grad():
        for path in tqdm(image_paths, desc="Processing Images"):
            try:
                image = Image.open(path).convert("RGB")
                image_tensor = transform(image).unsqueeze(0).to(device)
                # Pass image_features to the model to get the embedding
                embedding, _ = model(image_features=image_tensor)
                all_embeddings.append(embedding)
                valid_paths.append(path)
            except Exception as e:
                print(f"Warning: Could not process image {path}. Error: {e}")
                continue
    return torch.cat(all_embeddings, dim=0), valid_paths
# Pre-compute all image embeddings and store them in memory
if model and all_image_paths:
    # Also take the list of successfully processed paths, so indices stay aligned
    # with the embeddings even if some images failed to load
    image_embeddings_precomputed, all_image_paths = precompute_image_embeddings(
        all_image_paths, model, image_transform, device
    )
    # Normalize the embeddings once for faster similarity calculation
    image_embeddings_precomputed = F.normalize(image_embeddings_precomputed, p=2, dim=-1)
    print("Image embeddings pre-computed and stored.")
else:
    image_embeddings_precomputed = None
    print("Skipping embedding pre-computation due to missing model or images.")
# --- 3. The Main Gradio Function for Text-to-Image Search ---
def find_image_from_text(text_query):
    """Takes a text query and finds the best matching image."""
    if not text_query:
        return None, "Please enter a text query."
    if image_embeddings_precomputed is None:
        return None, "Error: Image embeddings are not available. Check logs for errors."
    print(f"Searching for text: '{text_query}'")
    with torch.no_grad():
        # 1. Process the text query
        text_inputs = tokenizer([text_query], padding=True, truncation=True, return_tensors="pt").to(device)
        # 2. Get the projected text embedding from the model.
        _, text_embedding = model(
            text_input_ids=text_inputs['input_ids'],
            attention_mask=text_inputs['attention_mask']
        )
        # 3. Normalize the text embedding
        text_embedding_norm = F.normalize(text_embedding, p=2, dim=-1)
        # 4. Calculate similarity against all pre-computed image embeddings
        similarity_scores = (text_embedding_norm @ image_embeddings_precomputed.T).squeeze(0)
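        # Because both the text and image embeddings are L2-normalized, this matrix
        # product is exactly the cosine similarity between the query and every image.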
        # 5. Find the index of the image with the highest score
        best_image_index = similarity_scores.argmax().item()
    best_image_path = all_image_paths[best_image_index]
    best_score = similarity_scores[best_image_index].item()
    print(f"Found best match: {best_image_path} with score {best_score:.4f}")
    return best_image_path, f"Best match with score: {best_score:.4f}"
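# A top-k variant (sketch, not part of the original app): `torch.Tensor.topk` on the same
# similarity vector would return several candidate images instead of a single best match:
#   top_scores, top_indices = similarity_scores.topk(k=5)
#   candidates = [all_image_paths[i] for i in top_indices.tolist()]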
# --- 4. Create and Launch the Gradio Interface ---
iface = gr.Interface(
    fn=find_image_from_text,
    inputs=gr.Textbox(lines=2, label="Text Query", placeholder="Enter text to find a matching image..."),
    outputs=[
        gr.Image(type="filepath", label="Best Matching Image"),
        gr.Textbox(label="Result Details")
    ],
    title="🖼️ Text-to-Image Search with CLIP",
    description="Enter a text description to search for the most relevant image in the Flickr8k dataset. The app uses a pre-trained CLIP-like model to find the best match.",
    allow_flagging="never"
)
iface.launch()