import torch from colpali_engine.models import ColIdefics3, ColIdefics3Processor class ColPaliEmbeddingGenerator: def __init__(self, model_name="vidore/colSmol-500M"): """ Initializes the ColPali embedding generator. """ print(f"Initializing ColPali Model (Smol): {model_name}...") self.device = "cuda" if torch.cuda.is_available() else "cpu" if torch.backends.mps.is_available(): self.device = "mps" print(f"Using device: {self.device}") self.model = ColIdefics3.from_pretrained( model_name, torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32, device_map=self.device, ).eval() self.processor = ColIdefics3Processor.from_pretrained(model_name) def generate_image_embeddings(self, images): """ Generates embeddings for a list of PIL Images. Returns a list of list of vectors (one list of vectors per image). """ if not isinstance(images, list): images = [images] batch_images = self.processor.process_images(images).to(self.device) with torch.no_grad(): image_embeddings = self.model(**batch_images) return [emb.cpu().float().numpy().tolist() for emb in image_embeddings] def generate_query_embeddings(self, queries): """ Generates embeddings for a list of text queries. Returns a list of list of vectors (one list of vectors per query). """ if not isinstance(queries, list): queries = [queries] batch_queries = self.processor.process_queries(queries).to(self.device) with torch.no_grad(): query_embeddings = self.model(**batch_queries) return [emb.cpu().float().numpy().tolist() for emb in query_embeddings]