mahimairaja commited on
Commit
63dd91e
·
1 Parent(s): 0251dd6

fix: torch dtype while loading the model

Browse files
Files changed (1) hide show
  1. utils/embedding_utils.py +1 -1
utils/embedding_utils.py CHANGED
@@ -17,7 +17,7 @@ class ColPaliEmbeddingGenerator:
17
 
18
  self.model = ColIdefics3.from_pretrained(
19
  model_name,
20
- dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
21
  device_map=self.device,
22
  ).eval()
23
 
 
17
 
18
  self.model = ColIdefics3.from_pretrained(
19
  model_name,
20
+ torch_dtype=torch.bfloat16 if self.device != "cpu" else torch.float32,
21
  device_map=self.device,
22
  ).eval()
23