mahimairaja commited on
Commit
13720cb
·
1 Parent(s): 1044a0a

utils: embedding class using colpali

Browse files
Files changed (1) hide show
  1. utils/embedding_utils.py +53 -0
utils/embedding_utils.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from colpali_engine.models import ColIdefics3, ColIdefics3Processor
3
+
4
+
5
+ class ColPaliEmbeddingGenerator:
6
+ def __init__(self, model_name="vidore/colSmol-500M"):
7
+ """
8
+ Initializes the ColPali embedding generator.
9
+ """
10
+ print(f"Initializing ColPali Model (Smol): {model_name}...")
11
+
12
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ if torch.backends.mps.is_available():
14
+ self.device = "mps"
15
+
16
+ print(f"Using device: {self.device}")
17
+
18
+ self.model = ColIdefics3.from_pretrained(
19
+ model_name,
20
+ dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
21
+ device_map=self.device,
22
+ ).eval()
23
+
24
+ self.processor = ColIdefics3Processor.from_pretrained(model_name)
25
+
26
+ def generate_image_embeddings(self, images):
27
+ """
28
+ Generates embeddings for a list of PIL Images.
29
+ Returns a list of list of vectors (one list of vectors per image).
30
+ """
31
+ if not isinstance(images, list):
32
+ images = [images]
33
+
34
+ batch_images = self.processor.process_images(images).to(self.device)
35
+
36
+ with torch.no_grad():
37
+ image_embeddings = self.model(**batch_images)
38
+
39
+ return [emb.cpu().float().numpy().tolist() for emb in image_embeddings]
40
+
41
+ def generate_query_embeddings(self, queries):
42
+ """
43
+ Generates embeddings for a list of text queries.
44
+ Returns a list of list of vectors (one list of vectors per query).
45
+ """
46
+ if not isinstance(queries, list):
47
+ queries = [queries]
48
+ batch_queries = self.processor.process_queries(queries).to(self.device)
49
+
50
+ with torch.no_grad():
51
+ query_embeddings = self.model(**batch_queries)
52
+
53
+ return [emb.cpu().float().numpy().tolist() for emb in query_embeddings]