Mustafa Acikgoz committed
Commit 2422360 · Parent: 8818841

Fix model argument name and update logic

Files changed (2):
  1. app.py +7 -8
  2. inference_model.py +5 -5
app.py CHANGED
@@ -24,7 +24,9 @@ model = CLIPModel(
     projection_dim=config.PROJECTION_DIM
 ).to(device)
 
-# Load your trained model weights (.pth file)
+# --- CRITICAL STEP ---
+# The application will fail if it cannot find the file specified in config.MODEL_PATH.
+# Make sure "clip_book_model.pth" is in the same directory as this script.
 try:
     model.load_state_dict(torch.load(config.MODEL_PATH, map_location=device))
     model.eval()
@@ -47,15 +49,14 @@ snapshot_download(
     repo_id=DATASET_REPO_ID,
     repo_type="dataset",
     local_dir=IMAGE_STORAGE_PATH,
-    local_dir_use_symlinks=False  # Set to False for Spaces compatibility
+    local_dir_use_symlinks=False
 )
 print("Image dataset download complete.")
 
 # Get a list of all image file paths
 all_image_paths = glob.glob(os.path.join(IMAGE_STORAGE_PATH, "Flicker8k_Dataset", "*.jpg"))
 
-# **CRITICAL FIX FOR TIMEOUT**: Use a smaller subset of images for the demo.
-# Processing all 8000+ images on startup will cause a timeout on Hugging Face Spaces.
+# Use a smaller subset of images to prevent timeouts on public platforms.
 NUM_IMAGES_TO_PROCESS = 1000
 all_image_paths = all_image_paths[:NUM_IMAGES_TO_PROCESS]
 print(f"Found {len(all_image_paths)} total images. Using a subset of {NUM_IMAGES_TO_PROCESS} to prevent timeout.")
@@ -77,8 +78,6 @@ def precompute_image_embeddings(image_paths, model, transform, device):
         image = Image.open(path).convert("RGB")
         image_tensor = transform(image).unsqueeze(0).to(device)
 
-        # **CORRECTION**: Use the full model's forward pass to get projected embeddings.
-        # This returns (image_embedding, text_embedding), so we take the first element.
         embedding, _ = model(image_features=image_tensor)
 
         all_embeddings.append(embedding)
@@ -111,8 +110,8 @@ def find_image_from_text(text_query):
     # 1. Process the text query
     text_inputs = tokenizer([text_query], padding=True, truncation=True, return_tensors="pt").to(device)
 
-    # 2. **CORRECTION**: Use the full model's forward pass to get projected text embedding.
-    # This returns (image_embedding, text_embedding), so we take the second element.
+    # 2. Get the projected text embedding from the model.
+    # No change is needed here because inference_model.py was updated to expect 'attention_mask'.
     _, text_embedding = model(
         text_input_ids=text_inputs['input_ids'],
         attention_mask=text_inputs['attention_mask']
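
For reference, a condensed sketch of the image pre-compute loop touched in the third hunk above. It only restates what the hunk shows plus details that are not in the diff and are assumed here (the torch.no_grad() context and the final torch.cat); the real function body in app.py may differ.

import torch
from PIL import Image

def precompute_image_embeddings(image_paths, model, transform, device):
    # Sketch based on the hunk above; error handling and batching in the
    # actual app.py are not shown in this diff and may differ.
    all_embeddings = []
    with torch.no_grad():  # assumption: embeddings are computed without gradients
        for path in image_paths:
            image = Image.open(path).convert("RGB")
            image_tensor = transform(image).unsqueeze(0).to(device)
            # forward() returns (image_embedding, text_embedding); keep the image part
            embedding, _ = model(image_features=image_tensor)
            all_embeddings.append(embedding)
    return torch.cat(all_embeddings, dim=0)  # assumption about how results are stacked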
inference_model.py CHANGED
@@ -61,18 +61,18 @@ class CLIPModel(nn.Module):
     def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
         super().__init__()
 
-        # **CORRECTION**: Renamed 'vision_encoder' to 'image_encoder'
-        # This attribute MUST be named 'image_encoder' to match the call in app.py
         self.image_encoder = VisionEncoder()
-
         self.text_encoder = TextEncoder()
         self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
         self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)
 
-    def forward(self, image_features=None, text_input_ids=None, text_attention_mask=None):
+    def forward(self, image_features=None, text_input_ids=None, attention_mask=None):
         """
         This forward pass handles both image and text inputs.
         app.py will call this to get the final, projected embeddings.
+
+        **MODIFICATION**: Renamed 'text_attention_mask' to 'attention_mask' for
+        compatibility with the standard Hugging Face tokenizer output.
         """
         image_embedding = None
         if image_features is not None:
@@ -86,7 +86,7 @@ class CLIPModel(nn.Module):
             # Get raw features from the text backbone
             text_features_raw = self.text_encoder(
                 input_ids=text_input_ids,
-                attention_mask=text_attention_mask
+                attention_mask=attention_mask
             )
             # Project them into the shared embedding space
             text_embedding = self.text_projection(text_features_raw)
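
The kwarg rename is what keeps inference_model.py in sync with what app.py receives from the tokenizer. A minimal illustration of why the name matters, using a standard Hugging Face tokenizer; the distilbert-base-uncased checkpoint is only an example, since the diff does not show which backbone TextEncoder wraps.

from transformers import AutoTokenizer

# A standard HF tokenizer returns a dict with 'input_ids' and 'attention_mask'.
# Because CLIPModel.forward() now also names its parameter 'attention_mask',
# app.py can pass the mask through without renaming it.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # example checkpoint
batch = tokenizer(["a dog running on the beach"], padding=True,
                  truncation=True, return_tensors="pt")
print(sorted(batch.keys()))  # ['attention_mask', 'input_ids']

# Matches the call pattern in app.py after this commit:
# _, text_embedding = model(
#     text_input_ids=batch["input_ids"],
#     attention_mask=batch["attention_mask"],
# )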