"""Ingest the HLS Burn Scars dataset into Qdrant with ColPali multivector embeddings."""

import base64
import os
import uuid
from datetime import datetime
from io import BytesIO

import numpy as np
from datasets import load_dataset
from dotenv import load_dotenv
from PIL import Image
from qdrant_client import QdrantClient, models

from .embedding_utils import ColPaliEmbeddingGenerator

load_dotenv()

repo_id = "mahimairaja/ibm-hls-burn-original"
ds_from_hub = load_dataset(repo_id)
collection_name = "hls_burn_scars_data_colpali"
generator = ColPaliEmbeddingGenerator()


def image_to_base64(image):
    """Convert a PIL Image to a base64-encoded PNG string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def normalize_band(band):
    """Scale a spectral band to uint8, handling both [0, 1] and [0, 255] inputs."""
    band = np.array(band, dtype=np.float32)
    if band.max() <= 1.0:
        band = band * 255.0
    # Clip before casting: values outside [0, 255] would otherwise wrap
    # around during the uint8 conversion.
    return np.clip(band, 0.0, 255.0).astype(np.uint8)


def generate_qdrant_points(dataset_split):
    """Build Qdrant points (ColPali embeddings + metadata payload) for a batch."""
    points = []
    images_to_process = []
    items_to_process = []

    # First pass: compose an RGB preview image from the red/green/blue bands.
    for i, item in enumerate(dataset_split):
        try:
            r_img = normalize_band(item["red"])
            g_img = normalize_band(item["green"])
            b_img = normalize_band(item["blue"])
            rgb_array = np.stack([r_img, g_img, b_img], axis=2)
            image = Image.fromarray(rgb_array)
            images_to_process.append(image)
            items_to_process.append(item)
        except Exception as e:
            print(f"Skipping item {i} due to error preparing image: {e}")
            continue

    if not images_to_process:
        return []

    # Embed the whole batch at once; ColPali returns one multivector per image.
    try:
        embeddings_batch = generator.generate_image_embeddings(images_to_process)
    except Exception as e:
        print(f"Error generating embeddings: {e}")
        return []

    # Second pass: assemble points from the embeddings and the item metadata.
    for item, image, embedding in zip(
        items_to_process, images_to_process, embeddings_batch
    ):
        try:
            burn_area = item.get("burn_area", 0.0)
            annotation = item["annotation"]
            if isinstance(annotation, np.ndarray):
                annotation = Image.fromarray(annotation)
            annotation_b64 = image_to_base64(annotation)
            rgb_image_b64 = image_to_base64(image)
            point = models.PointStruct(
                id=str(uuid.uuid4()),
                vector={
                    "colpali": embedding,
                },
                payload={
                    "centroid_lat": float(item["latitude"]),
                    "centroid_lon": float(item["longitude"]),
                    # Store as an RFC 3339 string so the DATETIME payload
                    # index can parse it.
                    "acquisition_date": datetime.strptime(
                        item["date"], "%Y-%m-%dT%H:%M:%SZ"
                    ).isoformat(),
                    "burn_area": float(burn_area),
                    "image_base64": annotation_b64,  # segmentation mask
                    "rgb_image_base64": rgb_image_b64,  # RGB preview
                },
            )
            points.append(point)
        except Exception as e:
            print(f"Error creating point: {e}")
            continue
    return points


def main():
    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
    )

    # Recreate the collection from scratch on every run.
    if client.collection_exists(collection_name=collection_name):
        client.delete_collection(collection_name=collection_name)
    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "colpali": models.VectorParams(
                size=128,  # ColPali dim
                distance=models.Distance.COSINE,
                # MaxSim late-interaction scoring over the token-level vectors.
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
        optimizers_config=models.OptimizersConfigDiff(default_segment_number=2),
    )

    # Payload indexes enable efficient filtering by location, date, and burn area.
    client.create_payload_index(
        collection_name=collection_name,
        field_name="centroid_lat",
        field_schema=models.PayloadSchemaType.FLOAT,
    )
    client.create_payload_index(
        collection_name=collection_name,
        field_name="centroid_lon",
        field_schema=models.PayloadSchemaType.FLOAT,
    )
    client.create_payload_index(
        collection_name=collection_name,
        field_name="acquisition_date",
        field_schema=models.PayloadSchemaType.DATETIME,
    )
    client.create_payload_index(
        collection_name=collection_name,
        field_name="burn_area",
        field_schema=models.PayloadSchemaType.FLOAT,
    )

    for split_name, dataset_split in ds_from_hub.items():
        print(f"Ingesting {len(dataset_split)} points from '{split_name}' split...")
        dataset_split = dataset_split.with_format("numpy")
        batch_size = 4
        total_points = len(dataset_split)
        for start_idx in range(0, total_points, batch_size):
            end_idx = min(start_idx + batch_size, total_points)
            batch = dataset_split.select(range(start_idx, end_idx))
            print(f"Processing batch {start_idx} to {end_idx}...")
            qdrant_points = generate_qdrant_points(batch)
            if qdrant_points:
                client.upsert(
                    collection_name=collection_name, points=qdrant_points, wait=True
                )
                print(f"Upserted batch {start_idx} to {end_idx}.")
        print(f"Finished ingesting '{split_name}' split.")

    print("Data ingestion complete for all splits.")
    count_result = client.count(collection_name=collection_name, exact=True)
    print(f"Total points in Qdrant collection: {count_result.count}")


if __name__ == "__main__":
    main()
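

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, never invoked by this script): querying the
# collection built above with MaxSim over the ColPali multivectors, plus a
# filter that exercises the payload indexes created in main(). It assumes the
# generator exposes a query-side embedding method; `generate_query_embeddings`
# is a hypothetical name -- substitute whatever ColPaliEmbeddingGenerator
# actually provides.
# ---------------------------------------------------------------------------
def example_search(client: QdrantClient, text_query: str, limit: int = 5):
    """Sketch: MaxSim multivector search with a burn-area payload filter."""
    # Hypothetical helper: encode the text query into a ColPali multivector
    # (a list of 128-dim token vectors).
    query_embedding = generator.generate_query_embeddings([text_query])[0]
    return client.query_points(
        collection_name=collection_name,
        query=query_embedding,
        using="colpali",  # the named vector configured in main()
        query_filter=models.Filter(
            must=[
                models.FieldCondition(
                    key="burn_area",
                    range=models.Range(gte=10.0),  # example threshold
                )
            ]
        ),
        limit=limit,
    )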