import torch
import torch.nn as nn
from torchvision.models import resnet50
from transformers import DistilBertModel

# --- Helper Classes (VisionEncoder, TextEncoder, ProjectionHead) ---
# These define the components of the overall CLIP model.

class VisionEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Use the recommended 'weights' parameter for pre-trained models
        pretrained_resnet50 = resnet50(weights='IMAGENET1K_V1')
        # Use all layers of ResNet50 except for the final fully connected layer
        self.model = nn.Sequential(*list(pretrained_resnet50.children())[:-1])
        # Freeze the parameters of the vision encoder
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, x):
        x = self.model(x)
        # Flatten the output to a 1D tensor per image
        return x.view(x.size(0), -1)

class TextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = DistilBertModel.from_pretrained('distilbert-base-uncased')
        # Freeze the parameters of the text encoder
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, input_ids, attention_mask=None):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the embedding of the [CLS] token as the sentence representation
        return outputs.last_hidden_state[:, 0, :]

class ProjectionHead(nn.Module):
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        projected = self.projection(x)
        x = self.gelu(projected)
        x = self.fc(x)
        x = self.dropout(x)
        # Add a residual connection
        x = x + projected
        x = self.layer_norm(x)
        return x

# --- Main CLIPModel for Inference ---
# This class combines the encoders and projection heads.

class CLIPModel(nn.Module):
    def __init__(self, image_embedding_dim, text_embedding_dim, projection_dim):
        super().__init__()
        
        self.image_encoder = VisionEncoder()
        self.text_encoder = TextEncoder()
        self.image_projection = ProjectionHead(embedding_dim=image_embedding_dim, projection_dim=projection_dim)
        self.text_projection = ProjectionHead(embedding_dim=text_embedding_dim, projection_dim=projection_dim)

    def forward(self, image_features=None, text_input_ids=None, attention_mask=None):
        """
        This forward pass handles both image and text inputs.
        app.py will call this to get the final, projected embeddings.
        
        **MODIFICATION**: Renamed 'text_attention_mask' to 'attention_mask' for
        compatibility with the standard Hugging Face tokenizer output.
        """
        image_embedding = None
        if image_features is not None:
            # Get raw features from the vision backbone
            image_features_raw = self.image_encoder(image_features)
            # Project them into the shared embedding space
            image_embedding = self.image_projection(image_features_raw)

        text_embedding = None
        if text_input_ids is not None:
            # Get raw features from the text backbone
            text_features_raw = self.text_encoder(
                input_ids=text_input_ids,
                attention_mask=attention_mask
            )
            # Project them into the shared embedding space
            text_embedding = self.text_projection(text_features_raw)

        return image_embedding, text_embedding
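

# --- Usage sketch ---
# A minimal smoke test, assuming ResNet50's pooled features are 2048-dim and
# DistilBERT's hidden size is 768; projection_dim=256 and the captions below
# are arbitrary choices for illustration.
if __name__ == "__main__":
    from transformers import DistilBertTokenizer

    model = CLIPModel(image_embedding_dim=2048, text_embedding_dim=768, projection_dim=256)
    model.eval()

    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    batch = tokenizer(["a photo of a cat", "a photo of a dog"],
                      padding=True, return_tensors='pt')
    images = torch.randn(2, 3, 224, 224)  # dummy image batch

    with torch.no_grad():
        image_emb, text_emb = model(
            image_features=images,
            text_input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
        )

    # Cosine-similarity logits between every image and every caption
    image_emb = nn.functional.normalize(image_emb, dim=-1)
    text_emb = nn.functional.normalize(text_emb, dim=-1)
    print((image_emb @ text_emb.T).shape)  # torch.Size([2, 2])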