import numpy as np
import torch
import faiss
from torch.utils.data import DataLoader
from tqdm import tqdm

from data import PropertyEmbeddingDataset

# Numeric property features fed to the property tower, in column order.
property_field_list = ["price", "average_rating", "lat", "lon", "type_enc"]


def build_faiss_index(
    model,
    property_df,
    batch_size=128,
    index_path="property_faiss.index",
    id_map_path="property_id_map.npy",
):
    """
    Builds a FAISS index of property embeddings using batched processing.

    Args:
        model: The model with a prop_tower(texts, features) method.
        property_df (pd.DataFrame): Property data containing the columns in
            property_field_list plus "text" and "id".
        batch_size (int): Batch size for embedding computation.
        index_path (str): Path to save the FAISS index.
        id_map_path (str): Path to save the property ID map aligned with the
            index rows.
    """
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    dataset = PropertyEmbeddingDataset(
        property_df[property_field_list].values.astype(np.float32),
        property_df["text"].values,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    all_embs = []
    for batch in tqdm(dataloader, desc="Building FAISS index"):
        batch_texts = batch["text"]
        batch_feats = batch["data"].to(device)
        with torch.no_grad():
            emb, _ = model.prop_tower(batch_texts, batch_feats)
        emb = emb.cpu().numpy()
        # L2-normalize so inner-product search is equivalent to cosine similarity.
        faiss.normalize_L2(emb)
        all_embs.append(emb)

    # Stack all batch embeddings into a single (n_properties, dim) matrix.
    all_embs = np.vstack(all_embs)

    # Build an exact inner-product index over the normalized embeddings.
    dim = all_embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(all_embs)

    # Save the index and the property IDs aligned row-for-row with it.
    faiss.write_index(index, index_path)
    np.save(id_map_path, property_df["id"].values)
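

# --- Example usage (illustrative sketch) ---
# The helper below shows how the artifacts written by build_faiss_index could be
# loaded back and queried. It assumes the query embedding is produced elsewhere
# (e.g. by a query-side tower of the same model, whose interface is not shown
# here); because the index is built over L2-normalized vectors, inner-product
# search amounts to cosine similarity. `search_properties` is a hypothetical
# name introduced only for this sketch.


def search_properties(
    query_emb,
    k=10,
    index_path="property_faiss.index",
    id_map_path="property_id_map.npy",
):
    """Return the top-k property IDs and similarity scores for one query embedding."""
    index = faiss.read_index(index_path)
    property_ids = np.load(id_map_path)

    # FAISS expects a contiguous 2-D float32 array of shape (n_queries, dim).
    query = np.ascontiguousarray(query_emb.reshape(1, -1), dtype=np.float32)
    faiss.normalize_L2(query)

    scores, indices = index.search(query, k)
    return property_ids[indices[0]], scores[0]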