File size: 1,856 Bytes
13720cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63dd91e
13720cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
import torch
from colpali_engine.models import ColIdefics3, ColIdefics3Processor


class ColPaliEmbeddingGenerator:
    """Embed images and text queries with a ColIdefics3 model and its processor.

    Instance attributes set by ``__init__``:
        device: ``"mps"``, ``"cuda"`` or ``"cpu"``, chosen from what torch
            reports as available.
        model: the ``ColIdefics3`` model, loaded onto ``device`` in eval mode.
        processor: the ``ColIdefics3Processor`` loaded from the same name.
    """

    def __init__(self, model_name="vidore/colSmol-500M"):
        """
        Initializes the ColPali embedding generator.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                model and the processor.
        """
        print(f"Initializing ColPali Model (Smol): {model_name}...")

        # Precedence: MPS when available, otherwise CUDA, otherwise CPU.
        if torch.backends.mps.is_available():
            self.device = "mps"
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        print(f"Using device: {self.device}")

        # bfloat16 on accelerators, full float32 precision on CPU.
        dtype = torch.float32 if self.device == "cpu" else torch.bfloat16
        self.model = ColIdefics3.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map=self.device,
        ).eval()

        self.processor = ColIdefics3Processor.from_pretrained(model_name)

    @staticmethod
    def _as_list(items):
        """Return *items* as a list, wrapping a single (non list/tuple) item."""
        if isinstance(items, (list, tuple)):
            return list(items)
        return [items]

    def _embed(self, batch):
        """Run the model on a processed batch and return one nested list per item."""
        batch = batch.to(self.device)

        with torch.no_grad():
            embeddings = self.model(**batch)

        # Move to CPU and upcast to float32 first (the model may be running in
        # bfloat16) so the result is made of plain Python floats.
        return [emb.cpu().float().tolist() for emb in embeddings]

    def generate_image_embeddings(self, images):
        """
        Generates embeddings for a list of PIL Images.

        Args:
            images: A single image, or a list/tuple of images.

        Returns:
            A list of list of vectors (one list of vectors per image).
            An empty input yields an empty list without invoking the model.
        """
        images = self._as_list(images)
        if not images:
            return []

        return self._embed(self.processor.process_images(images))

    def generate_query_embeddings(self, queries):
        """
        Generates embeddings for a list of text queries.

        Args:
            queries: A single query string, or a list/tuple of query strings.

        Returns:
            A list of list of vectors (one list of vectors per query).
            An empty input yields an empty list without invoking the model.
        """
        queries = self._as_list(queries)
        if not queries:
            return []

        return self._embed(self.processor.process_queries(queries))