File size: 1,783 Bytes
2e51bae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
# dataset.py
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import pandas as pd
import torchvision.transforms as transforms

class Flickr8kDataset(Dataset):
    """
    Custom PyTorch Dataset for the Flickr8k data.

    Loads images and their corresponding captions. All captions are tokenized
    once in ``__init__`` so that ``__getitem__`` only has to index the
    pre-computed tensors and load a single image.

    Args:
        image_dir: Directory that the file names in the CSV's ``image`` column
            are joined onto.
        caption_file: Path to a CSV file with (at least) ``image`` and
            ``caption`` columns; each row is one (image, caption) sample.
        tokenizer: Callable with the Hugging Face tokenizer call interface
            (accepts ``truncation``/``padding``/``max_length``/``return_tensors``
            and returns a mapping of tensors).
        image_size: Side length, in pixels, every image is resized to; also
            the size of the placeholder returned for unreadable images.
            Keyword-only. Defaults to 224.
        max_length: Token length every caption is truncated/padded to.
            Keyword-only. Defaults to 200.
    """

    def __init__(self, image_dir, caption_file, tokenizer, *, image_size=224, max_length=200):
        self.image_dir = image_dir
        self.tokenizer = tokenizer
        self.image_size = image_size
        self.max_length = max_length

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            # These mean/std values are the ImageNet statistics torchvision's
            # pretrained models expect (presumably the image encoder is
            # ImageNet-pretrained -- confirm against the model in use).
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

        df = pd.read_csv(caption_file)
        self.image_paths = [os.path.join(self.image_dir, fname) for fname in df['image']]
        self.captions = df['caption'].tolist()

        print("Tokenizing all captions... (This may take a moment)")
        # One batched call for all captions. Padding to a fixed max_length is
        # what lets return_tensors="pt" stack every caption into one tensor
        # per key, so each sample later has the same shape.
        self.caption_encodings = self.tokenizer(
            self.captions,
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors="pt",
        )
        print("Tokenization complete.")

    def __len__(self):
        """Return the number of (image, caption) samples."""
        return len(self.captions)

    def __getitem__(self, idx):
        """
        Return sample ``idx`` as a dict.

        The dict contains one entry per key produced by the tokenizer (the
        row ``idx`` of each encoding tensor), the transformed image under
        ``'image'``, and the raw caption string under ``'caption_text'``.

        If the image file cannot be opened or decoded (missing, unreadable,
        or corrupt -- all ``OSError`` subclasses in Pillow), a warning is
        printed and an all-zeros tensor of shape
        ``(3, image_size, image_size)`` is used instead.
        """
        item = {key: val[idx] for key, val in self.caption_encodings.items()}
        path = self.image_paths[idx]
        try:
            # The context manager closes the file handle; convert() forces the
            # pixel data to be read while the file is still open.
            with Image.open(path) as img:
                rgb = img.convert("RGB")
        except OSError:
            # OSError covers FileNotFoundError as well as
            # PIL.UnidentifiedImageError and truncated-file errors, which
            # would otherwise crash the data loader.
            print(f"Warning: Could not load image at {path}. Returning a black image.")
            # NOTE: this tensor bypasses self.transform; in the normalized
            # space 0 corresponds to the per-channel mean color, not black.
            item['image'] = torch.zeros((3, self.image_size, self.image_size))
        else:
            item['image'] = self.transform(rgb)
        item["caption_text"] = self.captions[idx]
        return item