mahimairaja's picture
util: ingest data to qdrant
0b41e06
import base64
import os
import uuid
from datetime import datetime
from io import BytesIO
import numpy as np
from datasets import load_dataset
from dotenv import load_dotenv
from PIL import Image
from qdrant_client import QdrantClient, models
from .embedding_utils import ColPaliEmbeddingGenerator
load_dotenv()
repo_id = "mahimairaja/ibm-hls-burn-original"
ds_from_hub = load_dataset(repo_id)
collection_name = "hls_burn_scars_data_colpali"
generator = ColPaliEmbeddingGenerator()
def image_to_base64(image):
"""Converts a PIL Image to a base64 string."""
buffered = BytesIO()
image.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")
def generate_qdrant_points(dataset_split):
points = []
images_to_process = []
items_to_process = []
for i, item in enumerate(dataset_split):
try:
red = item["red"]
green = item["green"]
blue = item["blue"]
def normalize_band(band):
band = np.array(band, dtype=np.float32)
if band.max() <= 1.0:
band = band * 255.0
return band.astype(np.uint8)
r_img = normalize_band(red)
g_img = normalize_band(green)
b_img = normalize_band(blue)
rgb_array = np.stack([r_img, g_img, b_img], axis=2)
image = Image.fromarray(rgb_array)
images_to_process.append(image)
items_to_process.append(item)
except Exception as e:
print(f"Skipping item {i} due to error preparing image: {e}")
continue
if not images_to_process:
return []
try:
embeddings_batch = generator.generate_image_embeddings(images_to_process)
except Exception as e:
print(f"Error generating embeddings: {e}")
return []
for item, image, embedding in zip(
items_to_process, images_to_process, embeddings_batch
):
try:
burn_area = item.get("burn_area", 0.0)
annotation = item["annotation"]
if isinstance(annotation, np.ndarray):
annotation = Image.fromarray(annotation)
image_b64 = image_to_base64(annotation)
rgb_image_b64 = image_to_base64(image)
point_id = str(uuid.uuid4())
point = models.PointStruct(
id=point_id,
vector={
"colpali": embedding,
},
payload={
"centroid_lat": float(item["latitude"]),
"centroid_lon": float(item["longitude"]),
"acquisition_date": datetime.strptime(
item["date"], "%Y-%m-%dT%H:%M:%SZ"
),
"burn_area": float(burn_area),
"image_base64": image_b64,
"rgb_image_base64": rgb_image_b64,
},
)
points.append(point)
except Exception as e:
print(f"Error creating point: {e}")
continue
return points
def main():
client = QdrantClient(
url=os.getenv("QDRANT_URL"),
api_key=os.getenv("QDRANT_API_KEY"),
)
if client.collection_exists(collection_name=collection_name):
client.delete_collection(collection_name=collection_name)
client.create_collection(
collection_name=collection_name,
vectors_config={
"colpali": models.VectorParams(
size=128, # ColPali dim
distance=models.Distance.COSINE,
multivector_config=models.MultiVectorConfig(
comparator=models.MultiVectorComparator.MAX_SIM
),
),
},
optimizers_config=models.OptimizersConfigDiff(default_segment_number=2),
)
client.create_payload_index(
collection_name=collection_name,
field_name="centroid_lat",
field_schema=models.PayloadSchemaType.FLOAT,
)
client.create_payload_index(
collection_name=collection_name,
field_name="centroid_lon",
field_schema=models.PayloadSchemaType.FLOAT,
)
client.create_payload_index(
collection_name=collection_name,
field_name="acquisition_date",
field_schema=models.PayloadSchemaType.DATETIME,
)
client.create_payload_index(
collection_name=collection_name,
field_name="burn_area",
field_schema=models.PayloadSchemaType.FLOAT,
)
for split_name, dataset_split in ds_from_hub.items():
print(f"Ingesting {len(dataset_split)} points from '{split_name}' split...")
dataset_split = dataset_split.with_format("numpy")
batch_size = 4
total_points = len(dataset_split)
for start_idx in range(0, total_points, batch_size):
end_idx = min(start_idx + batch_size, total_points)
batch = dataset_split.select(range(start_idx, end_idx))
print(f"Processing batch {start_idx} to {end_idx}...")
qdrant_points = generate_qdrant_points(batch)
if qdrant_points:
client.upsert(
collection_name=collection_name, points=qdrant_points, wait=True
)
print(f"Upserted batch {start_idx} to {end_idx}.")
print(f"Finished ingesting '{split_name}' split.")
print("Data ingestion complete for all splits.")
count_result = client.count(collection_name=collection_name, exact=True)
print(f"Total points in Qdrant collection: {count_result.count}")
if __name__ == "__main__":
main()