Mustafa Acikgoz committed
Commit 8818841 · Parent: 296fb5d

Fix: Correct image_encoder attribute and prevent startup timeout

Files changed (2):
  1. app.py (+33 -39)
  2. inference_model.py (+30 -9)
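
In short, the Space crashed at startup because app.py accessed model.image_encoder while the attribute on CLIPModel was named vision_encoder, and it timed out because every Flickr8k image was embedded before the server came up. A toy stand-in (a hypothetical class, not the real model) reproduces the first failure mode:

import torch.nn as nn

class ToyCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_encoder = nn.Identity()  # old attribute name

model = ToyCLIP()
try:
    model.image_encoder  # what the old app.py accessed
except AttributeError as err:
    print(f"Startup failure reproduced: {err}")

# The commit below renames the attribute to image_encoder and, in app.py,
# routes both image and text inputs through CLIPModel's forward pass instead.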
app.py CHANGED
@@ -10,11 +10,10 @@ import glob
 from tqdm import tqdm
 
 # --- Custom Modules ---
-# These imports assume your config.py and model files are in the same directory
 import config
 from inference_model import CLIPModel
 
-# --- 1. Initial Setup: Load Model and Tokenizer (runs once on startup) ---
+# --- 1. Initial Setup: Load Model and Tokenizer ---
 print("Starting application setup...")
 device = config.DEVICE
 
@@ -32,35 +31,34 @@ try:
     print("CLIP Model loaded successfully.")
 except Exception as e:
     print(f"Error loading model: {e}")
-    model = None # Set model to None if loading fails
+    model = None
 
 # Load the text tokenizer
 tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
 print("Tokenizer loaded successfully.")
 
-# --- 2. Data Handling: Download and Pre-process Images (runs once on startup) ---
-# This is the key section that connects your app to your image dataset.
 
-# Define the dataset repository on the Hugging Face Hub
+# --- 2. Data Handling: Download and Pre-process Images ---
 DATASET_REPO_ID = "mustafa2ak/Flickr8k-Images"
-# Define the local folder where the images will be stored inside the Space
 IMAGE_STORAGE_PATH = "./flickr8k_images"
 
 print(f"Downloading image dataset from {DATASET_REPO_ID}...")
-# Use snapshot_download for a fast, server-to-server transfer
 snapshot_download(
     repo_id=DATASET_REPO_ID,
     repo_type="dataset",
     local_dir=IMAGE_STORAGE_PATH,
-    local_dir_use_symlinks=False # Important for compatibility
+    local_dir_use_symlinks=False # Set to False for Spaces compatibility
 )
 print("Image dataset download complete.")
 
-# Get a list of all image file paths from the downloaded folder
-# **CORRECTION**: The dataset structure has images directly in 'Flicker8k_Dataset'
-# The original code was looking for a subfolder named 'images', which doesn't exist.
+# Get a list of all image file paths
 all_image_paths = glob.glob(os.path.join(IMAGE_STORAGE_PATH, "Flicker8k_Dataset", "*.jpg"))
-print(f"Found {len(all_image_paths)} images.")
+
+# **CRITICAL FIX FOR TIMEOUT**: Use a smaller subset of images for the demo.
+# Processing all 8000+ images on startup will cause a timeout on Hugging Face Spaces.
+NUM_IMAGES_TO_PROCESS = 1000
+print(f"Found {len(all_image_paths)} total images. Using a subset of {NUM_IMAGES_TO_PROCESS} to prevent timeout.")
+all_image_paths = all_image_paths[:NUM_IMAGES_TO_PROCESS]
 
 # Define the image preprocessing pipeline
 image_transform = transforms.Compose([
@@ -70,43 +68,39 @@ image_transform = transforms.Compose([
 ])
 
 def precompute_image_embeddings(image_paths, model, transform, device):
-    """
-    Processes all images and computes their embeddings for fast searching.
-    This is a crucial optimization.
-    """
-    print("Pre-computing image embeddings... This may take a few minutes.")
+    """Processes all images and computes their final embeddings for fast searching."""
+    print("Pre-computing image embeddings... This may take a minute.")
     all_embeddings = []
-    # torch.no_grad() disables gradient calculation, making this much faster
    with torch.no_grad():
-        # tqdm creates a progress bar in your logs
         for path in tqdm(image_paths, desc="Processing Images"):
             try:
                 image = Image.open(path).convert("RGB")
                 image_tensor = transform(image).unsqueeze(0).to(device)
-                # Pass the image through the model's image encoder part
-                embedding = model.image_encoder(image_tensor)
+
+                # **CORRECTION**: Use the full model's forward pass to get projected embeddings.
+                # This returns (image_embedding, text_embedding), so we take the first element.
+                embedding, _ = model(image_features=image_tensor)
+
                 all_embeddings.append(embedding)
             except Exception as e:
                 print(f"Warning: Could not process image {path}. Error: {e}")
                 continue
-    # Combine the list of individual tensors into one large tensor
     return torch.cat(all_embeddings, dim=0)
 
 # Pre-compute all image embeddings and store them in memory
 if model and all_image_paths:
     image_embeddings_precomputed = precompute_image_embeddings(all_image_paths, model, image_transform, device)
-    # Normalize the embeddings once for faster similarity calculation later
+    # Normalize the embeddings once for faster similarity calculation
     image_embeddings_precomputed = F.normalize(image_embeddings_precomputed, p=2, dim=-1)
     print("Image embeddings pre-computed and stored.")
 else:
     image_embeddings_precomputed = None
     print("Skipping embedding pre-computation due to missing model or images.")
 
+
 # --- 3. The Main Gradio Function for Text-to-Image Search ---
 def find_image_from_text(text_query):
-    """
-    Takes a text query and finds the best matching image from the pre-computed embeddings.
-    """
+    """Takes a text query and finds the best matching image."""
     if not text_query:
         return None, "Please enter a text query."
     if image_embeddings_precomputed is None:
@@ -114,31 +108,32 @@ def find_image_from_text(text_query):
 
     print(f"Searching for text: '{text_query}'")
     with torch.no_grad():
-        # 1. Process the text query into tokens and get its embedding
+        # 1. Process the text query
         text_inputs = tokenizer([text_query], padding=True, truncation=True, return_tensors="pt").to(device)
-        text_embedding = model.text_encoder(
-            input_ids=text_inputs['input_ids'],
-            attention_mask=text_inputs['attention_mask']
+
+        # 2. **CORRECTION**: Use the full model's forward pass to get the projected text embedding.
+        # This returns (image_embedding, text_embedding), so we take the second element.
+        _, text_embedding = model(
+            text_input_ids=text_inputs['input_ids'],
+            text_attention_mask=text_inputs['attention_mask']
         )
-        # 2. Normalize the text embedding
+
+        # 3. Normalize the text embedding
         text_embedding_norm = F.normalize(text_embedding, p=2, dim=-1)
 
-        # 3. Calculate similarity against all pre-computed image embeddings
-        # This is a fast matrix multiplication: (1, 512) @ (512, N_images) -> (1, N_images)
+        # 4. Calculate similarity against all pre-computed image embeddings
         similarity_scores = (text_embedding_norm @ image_embeddings_precomputed.T).squeeze(0)
 
-        # 4. Find the index of the image with the highest score
+        # 5. Find the index of the image with the highest score
         best_image_index = similarity_scores.argmax().item()
-
-        # 5. Get the file path of the best image
         best_image_path = all_image_paths[best_image_index]
         best_score = similarity_scores[best_image_index].item()
 
     print(f"Found best match: {best_image_path} with score {best_score:.4f}")
 
-    # Return the path to the best image and a caption for the UI
     return best_image_path, f"Best match with score: {best_score:.4f}"
 
+
 # --- 4. Create and Launch the Gradio Interface ---
 iface = gr.Interface(
     fn=find_image_from_text,
@@ -148,9 +143,8 @@ iface = gr.Interface(
         gr.Textbox(label="Result Details")
     ],
     title="🖼️ Text-to-Image Search with CLIP",
-    description="Enter a text description to search for the most relevant image in the Flickr8k dataset. The app will download the dataset and pre-process images on startup.",
+    description="Enter a text description to search for the most relevant image in the Flickr8k dataset. The app uses a pre-trained CLIP-like model to find the best match from a subset of 1000 images.",
     allow_flagging="never"
 )
 
-# This starts the web server
 iface.launch()
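
For reference, the retrieval step inside find_image_from_text() boils down to one normalized matrix multiply followed by an argmax. A minimal, self-contained sketch with random tensors standing in for the real embeddings (the 256-dimensional size is an illustrative assumption; the app's actual dimension comes from its projection heads):

import torch
import torch.nn.functional as F

num_images, dim = 1000, 256  # illustrative sizes only
image_embeddings = F.normalize(torch.randn(num_images, dim), p=2, dim=-1)  # pre-computed, normalized once
text_embedding = F.normalize(torch.randn(1, dim), p=2, dim=-1)             # query embedding

# Cosine similarity via a single matrix multiply: (1, dim) @ (dim, N) -> (1, N)
scores = (text_embedding @ image_embeddings.T).squeeze(0)
best_index = scores.argmax().item()
print(best_index, f"{scores[best_index].item():.4f}")

Because both sides are L2-normalized, the dot product equals cosine similarity, which is why app.py ranks images directly on these scores without any further scaling.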
inference_model.py CHANGED
@@ -1,32 +1,38 @@
-# inference_model.py
 import torch
 import torch.nn as nn
 from torchvision.models import resnet50
 from transformers import DistilBertModel
 
-# --- Copy these classes from your original file ---
+# --- Helper Classes (VisionEncoder, TextEncoder, ProjectionHead) ---
+# These define the components of the overall CLIP model.
+
 class VisionEncoder(nn.Module):
     def __init__(self):
         super().__init__()
-        # Note: Using the newer 'weights' parameter is recommended
+        # Use the recommended 'weights' parameter for pre-trained models
         pretrained_resnet50 = resnet50(weights='IMAGENET1K_V1')
+        # Use all layers of ResNet50 except for the final fully connected layer
         self.model = nn.Sequential(*list(pretrained_resnet50.children())[:-1])
+        # Freeze the parameters of the vision encoder
         for param in self.model.parameters():
             param.requires_grad = False
 
     def forward(self, x):
         x = self.model(x)
+        # Flatten the output to a 1D tensor per image
         return x.view(x.size(0), -1)
 
 class TextEncoder(nn.Module):
     def __init__(self):
         super().__init__()
         self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
+        # Freeze the parameters of the text encoder
         for param in self.model.parameters():
             param.requires_grad = False
 
     def forward(self, input_ids, attention_mask=None):
         outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
+        # Use the embedding of the [CLS] token as the sentence representation
         return outputs.last_hidden_state[:, 0, :]
 
 class ProjectionHead(nn.Module):
@@ -43,31 +49,46 @@ class ProjectionHead(nn.Module):
         x = self.gelu(projected)
         x = self.fc(x)
         x = self.dropout(x)
+        # Add a residual connection
         x = x + projected
         x = self.layer_norm(x)
         return x
 
-# --- This is the MODIFIED CLIPModel for inference ---
+# --- Main CLIPModel for Inference ---
+# This class combines the encoders and projection heads.
+
 class CLIPModel(nn.Module):
     def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
         super().__init__()
-        self.vision_encoder = VisionEncoder()
+
+        # **CORRECTION**: Renamed 'vision_encoder' to 'image_encoder' to match
+        # the attribute name used by the forward pass below.
+        self.image_encoder = VisionEncoder()
+
         self.text_encoder = TextEncoder()
         self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
         self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)
 
     def forward(self, image_features=None, text_input_ids=None, text_attention_mask=None):
+        """
+        This forward pass handles both image and text inputs.
+        app.py calls this to get the final, projected embeddings.
+        """
         image_embedding = None
         if image_features is not None:
-            image_features = self.vision_encoder(image_features)
-            image_embedding = self.image_projection(image_features)
+            # Get raw features from the vision backbone
+            image_features_raw = self.image_encoder(image_features)
+            # Project them into the shared embedding space
+            image_embedding = self.image_projection(image_features_raw)
 
         text_embedding = None
         if text_input_ids is not None:
-            text_features = self.text_encoder(
+            # Get raw features from the text backbone
+            text_features_raw = self.text_encoder(
                 input_ids=text_input_ids,
                 attention_mask=text_attention_mask
            )
-            text_embedding = self.text_projection(text_features)
+            # Project them into the shared embedding space
+            text_embedding = self.text_projection(text_features_raw)
 
         return image_embedding, text_embedding
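
As a quick sanity check of the corrected CLIPModel interface, the sketch below builds the model and pushes one dummy image and one tokenized sentence through the forward pass. The 2048 and 768 input dimensions follow from the ResNet-50 and DistilBERT encoders in the diff; projection_dim=256, the 224x224 input size, and running with untrained projection heads (no checkpoint loaded) are assumptions for illustration only, whereas app.py takes the real values from config and loads trained weights:

import torch
from transformers import DistilBertTokenizer
from inference_model import CLIPModel

# Assumed dimensions: ResNet-50 pooled features are 2048-d, DistilBERT hidden states are 768-d.
# projection_dim=256 is a placeholder; the app reads its value from config.
model = CLIPModel(image_embedding_dim=2048, text_embedding_dim=768, projection_dim=256)
model.eval()

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
text_inputs = tokenizer(["a dog running on grass"], return_tensors="pt")

with torch.no_grad():
    # Image-only call: the second element of the returned tuple is None
    image_embedding, _ = model(image_features=torch.randn(1, 3, 224, 224))
    # Text-only call: the first element is None
    _, text_embedding = model(
        text_input_ids=text_inputs['input_ids'],
        text_attention_mask=text_inputs['attention_mask'],
    )

print(image_embedding.shape, text_embedding.shape)  # both (1, 256) under these assumptions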