# clip_ / dataset.py
# Mustafa Acikgoz
# Initial clean commit for Gradio app
# 2e51bae
# raw
# history blame
# 1.78 kB
# dataset.py
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import torchvision.transforms as transforms
class Flickr8kDataset(Dataset):
    """
    Custom PyTorch Dataset for the Flickr8k data.

    Loads images and their corresponding captions. All captions are tokenized
    once in ``__init__`` so that ``__getitem__`` only indexes into the
    pre-computed encodings.

    Args:
        image_dir: Directory containing the image files.
        caption_file: CSV file (read with ``pandas.read_csv``) with an
            ``image`` column (file name relative to ``image_dir``) and a
            ``caption`` column. One dataset item is produced per CSV row.
        tokenizer: Callable used to encode all captions in one call with
            ``truncation``/``padding``/``max_length``/``return_tensors="pt"``
            keyword arguments; its result must support ``.items()``.
            Presumably a Hugging Face tokenizer — confirm against caller.
        image_size: Side length (pixels) images are resized to; also the
            size of the placeholder returned for unreadable images.
        max_length: Token length every caption is truncated/padded to.

    Each item is a dict with one row of every tokenizer output for that
    caption, ``'image'`` (a ``(3, image_size, image_size)`` float tensor) and
    ``'caption_text'`` (the raw caption string from the CSV).
    """

    def __init__(self, image_dir, caption_file, tokenizer, image_size=224, max_length=200):
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.image_size = image_size
        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean/std values.
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        df = pd.read_csv(caption_file)
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in df['image']]
        self.captions = df['caption'].tolist()

        # Tokenize everything up front (padded to a fixed length) so each
        # __getitem__ is a cheap row lookup instead of a tokenizer call.
        print("Tokenizing all captions... (This may take a moment)")
        self.caption_encodings = self.tokenizer(
            self.captions,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors="pt",
        )
        print("Tokenization complete.")

    def __len__(self):
        """Return the number of (image, caption) rows read from the CSV."""
        return len(self.captions)

    def __getitem__(self, idx):
        """Return the encodings, image tensor and raw caption for row ``idx``."""
        item = {key: val[idx] for key, val in self.caption_encodings.items()}
        item['image'] = self._load_image(idx)
        item["caption_text"] = self.captions[idx]
        return item

    def _load_image(self, idx):
        """Load and transform image ``idx``; fall back to a placeholder on I/O errors."""
        path = self.image_paths[idx]
        try:
            # The context manager releases the file handle; convert() forces
            # the pixel data to be read before the file is closed.
            with Image.open(path) as img:
                rgb = img.convert("RGB")
        except OSError:
            # OSError covers missing files (FileNotFoundError) as well as
            # unreadable/corrupt/truncated images raised by PIL.
            print(f"Warning: Could not load image at {path}. Returning a black image.")
            return self._blank_image()
        return self.transform(rgb)

    def _blank_image(self):
        """Placeholder image tensor with the same shape the transform produces.

        Note: all-zeros here is in *normalized* space, i.e. it corresponds to
        the per-channel mean colour rather than true black.
        """
        return torch.zeros((3, self.image_size, self.image_size))