Mustafa Acikgoz committed
Commit 2422360
1 Parent(s): 8818841

Fix model argument name and update logic

Files changed:
- app.py (+7 -8)
- inference_model.py (+5 -5)
app.py CHANGED
@@ -24,7 +24,9 @@ model = CLIPModel(
     projection_dim=config.PROJECTION_DIM
 ).to(device)
 
-#
+# --- CRITICAL STEP ---
+# The application will fail if it cannot find the file specified in config.MODEL_PATH.
+# Make sure "clip_book_model.pth" is in the same directory as this script.
 try:
     model.load_state_dict(torch.load(config.MODEL_PATH, map_location=device))
     model.eval()
@@ -47,15 +49,14 @@ snapshot_download(
     repo_id=DATASET_REPO_ID,
     repo_type="dataset",
     local_dir=IMAGE_STORAGE_PATH,
-    local_dir_use_symlinks=False
+    local_dir_use_symlinks=False
 )
 print("Image dataset download complete.")
 
 # Get a list of all image file paths
 all_image_paths = glob.glob(os.path.join(IMAGE_STORAGE_PATH, "Flicker8k_Dataset", "*.jpg"))
 
-#
-# Processing all 8000+ images on startup will cause a timeout on Hugging Face Spaces.
+# Use a smaller subset of images to prevent timeouts on public platforms.
 NUM_IMAGES_TO_PROCESS = 1000
 all_image_paths = all_image_paths[:NUM_IMAGES_TO_PROCESS]
 print(f"Found {len(all_image_paths)} total images. Using a subset of {NUM_IMAGES_TO_PROCESS} to prevent timeout.")
@@ -77,8 +78,6 @@ def precompute_image_embeddings(image_paths, model, transform, device):
         image = Image.open(path).convert("RGB")
         image_tensor = transform(image).unsqueeze(0).to(device)
 
-        # **CORRECTION**: Use the full model's forward pass to get projected embeddings.
-        # This returns (image_embedding, text_embedding), so we take the first element.
         embedding, _ = model(image_features=image_tensor)
 
         all_embeddings.append(embedding)
@@ -111,8 +110,8 @@ def find_image_from_text(text_query):
     # 1. Process the text query
     text_inputs = tokenizer([text_query], padding=True, truncation=True, return_tensors="pt").to(device)
 
-    # 2.
-    #
+    # 2. Get the projected text embedding from the model.
+    # No change is needed here because inference_model.py was updated to expect 'attention_mask'.
     _, text_embedding = model(
         text_input_ids=text_inputs['input_ids'],
         attention_mask=text_inputs['attention_mask']
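Taken together, the app.py changes keep both modalities in the same projected embedding space: images are embedded ahead of time through the model's forward pass, and each text query is embedded the same way before ranking. The sketch below is a minimal illustration of that retrieval flow, not the actual app.py: it assumes a model with the dual-input forward shown in this diff (returning (image_embedding, text_embedding)), a Hugging Face tokenizer, and a torchvision-style transform; the cosine-similarity ranking at the end is an illustrative choice and is not shown by this commit.

import torch
import torch.nn.functional as F
from PIL import Image

@torch.no_grad()
def precompute_image_embeddings(image_paths, model, transform, device):
    # Run every image through the model's forward pass so the embeddings
    # come from the projection head, matching the text embedding space.
    embeddings = []
    for path in image_paths:
        image = Image.open(path).convert("RGB")
        image_tensor = transform(image).unsqueeze(0).to(device)
        image_embedding, _ = model(image_features=image_tensor)
        embeddings.append(image_embedding)
    return F.normalize(torch.cat(embeddings, dim=0), dim=-1)

@torch.no_grad()
def find_image_from_text(text_query, model, tokenizer, image_embeddings, image_paths, device):
    # Tokenize the query; after this commit the 'attention_mask' key from the
    # tokenizer can be passed to the model under the same keyword.
    text_inputs = tokenizer([text_query], padding=True, truncation=True, return_tensors="pt").to(device)
    _, text_embedding = model(
        text_input_ids=text_inputs["input_ids"],
        attention_mask=text_inputs["attention_mask"],
    )
    text_embedding = F.normalize(text_embedding, dim=-1)
    # Rank precomputed image embeddings by cosine similarity; highest wins.
    scores = image_embeddings @ text_embedding.squeeze(0)
    return image_paths[int(scores.argmax())]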
inference_model.py CHANGED
@@ -61,18 +61,18 @@ class CLIPModel(nn.Module):
     def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
         super().__init__()
 
-        # **CORRECTION**: Renamed 'vision_encoder' to 'image_encoder'
-        # This attribute MUST be named 'image_encoder' to match the call in app.py
         self.image_encoder = VisionEncoder()
-
         self.text_encoder = TextEncoder()
         self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
         self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)
 
-    def forward(self, image_features=None, text_input_ids=None, text_attention_mask=None):
+    def forward(self, image_features=None, text_input_ids=None, attention_mask=None):
         """
         This forward pass handles both image and text inputs.
         app.py will call this to get the final, projected embeddings.
+
+        **MODIFICATION**: Renamed 'text_attention_mask' to 'attention_mask' for
+        compatibility with the standard Hugging Face tokenizer output.
         """
         image_embedding = None
         if image_features is not None:
@@ -86,7 +86,7 @@ class CLIPModel(nn.Module):
             # Get raw features from the text backbone
             text_features_raw = self.text_encoder(
                 input_ids=text_input_ids,
-                attention_mask=text_attention_mask
+                attention_mask=attention_mask
             )
             # Project them into the shared embedding space
             text_embedding = self.text_projection(text_features_raw)
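The parameter rename is what lets app.py hand the tokenizer's 'attention_mask' entry to the model under the same keyword. The snippet below is a minimal, hypothetical sketch of that calling pattern only; DummyTextEncoder, TinyDualEncoder, and the dimensions are stand-ins for illustration, not the real VisionEncoder/TextEncoder from inference_model.py.

import torch
import torch.nn as nn

class DummyTextEncoder(nn.Module):
    # Stand-in text backbone: mean-pools token embeddings, using the
    # attention mask to ignore padding positions.
    def __init__(self, vocab_size=30522, hidden_dim=256):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_dim)

    def forward(self, input_ids, attention_mask):
        hidden = self.embed(input_ids)                       # (B, T, H)
        mask = attention_mask.unsqueeze(-1).float()          # (B, T, 1)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1)  # (B, H)

class TinyDualEncoder(nn.Module):
    def __init__(self, text_dim=256, projection_dim=128):
        super().__init__()
        self.text_encoder = DummyTextEncoder(hidden_dim=text_dim)
        self.text_projection = nn.Linear(text_dim, projection_dim)

    def forward(self, image_features=None, text_input_ids=None, attention_mask=None):
        # Keyword name matches the tokenizer output key, as in this commit.
        text_embedding = None
        if text_input_ids is not None:
            raw = self.text_encoder(input_ids=text_input_ids, attention_mask=attention_mask)
            text_embedding = self.text_projection(raw)
        return None, text_embedding

# Usage: mirrors the call in app.py after the rename.
model = TinyDualEncoder()
tok_out = {"input_ids": torch.randint(0, 30522, (1, 8)),
           "attention_mask": torch.ones(1, 8, dtype=torch.long)}
_, emb = model(text_input_ids=tok_out["input_ids"],
               attention_mask=tok_out["attention_mask"])
print(emb.shape)  # torch.Size([1, 128])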