"""Ingest the HLS Burn Scars dataset into Qdrant with ColPali multivector embeddings."""

import base64
import os
import uuid
from datetime import datetime
from io import BytesIO

import numpy as np
from datasets import load_dataset
from dotenv import load_dotenv
from PIL import Image
from qdrant_client import QdrantClient, models

from .embedding_utils import ColPaliEmbeddingGenerator
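
# Note: the relative import above means this file is meant to run as a module
# inside its package (e.g. `python -m <package>.<this_module>`); invoking it as
# a bare script path would raise "attempted relative import" at startup.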

load_dotenv()

# Source dataset on the Hugging Face Hub and the target Qdrant collection.
repo_id = "mahimairaja/ibm-hls-burn-original"
ds_from_hub = load_dataset(repo_id)

collection_name = "hls_burn_scars_data_colpali"

# Shared ColPali wrapper (see embedding_utils) used to embed every image batch.
generator = ColPaliEmbeddingGenerator()


def image_to_base64(image):
    """Converts a PIL Image to a base64-encoded PNG string."""
    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")
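
# Round-trip note (illustrative, not executed here): the stored base64 strings
# decode back into PIL images on the retrieval side, e.g.
#     Image.open(BytesIO(base64.b64decode(payload["rgb_image_base64"])))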


def normalize_band(band):
    """Scale a single spectral band to uint8 [0, 255].

    Bands stored as floats in [0, 1] are rescaled to 0-255; values already
    above 1.0 are assumed to be in the 0-255 range and only cast.
    """
    band = np.array(band, dtype=np.float32)
    if band.max() <= 1.0:
        band = band * 255.0
    return band.astype(np.uint8)


def generate_qdrant_points(dataset_split):
    """Build Qdrant points (ColPali multivector + metadata payload) for a batch."""
    points = []

    images_to_process = []
    items_to_process = []

    # First pass: reconstruct an RGB preview image from the three band arrays.
    for i, item in enumerate(dataset_split):
        try:
            r_img = normalize_band(item["red"])
            g_img = normalize_band(item["green"])
            b_img = normalize_band(item["blue"])

            rgb_array = np.stack([r_img, g_img, b_img], axis=2)
            image = Image.fromarray(rgb_array)

            images_to_process.append(image)
            items_to_process.append(item)

        except Exception as e:
            print(f"Skipping item {i} due to error preparing image: {e}")
            continue

    if not images_to_process:
        return []

    # Embed the whole batch in a single call.
    try:
        embeddings_batch = generator.generate_image_embeddings(images_to_process)
    except Exception as e:
        print(f"Error generating embeddings: {e}")
        return []

    # Second pass: pair each surviving item and image with its embedding.
    for item, image, embedding in zip(
        items_to_process, images_to_process, embeddings_batch
    ):
        try:
            burn_area = item.get("burn_area", 0.0)

            annotation = item["annotation"]
            if isinstance(annotation, np.ndarray):
                annotation = Image.fromarray(annotation)

            # "image_base64" holds the burn-scar annotation mask;
            # "rgb_image_base64" holds the reconstructed RGB preview.
            image_b64 = image_to_base64(annotation)
            rgb_image_b64 = image_to_base64(image)

            point_id = str(uuid.uuid4())

            point = models.PointStruct(
                id=point_id,
                vector={
                    "colpali": embedding,
                },
                payload={
                    "centroid_lat": float(item["latitude"]),
                    "centroid_lon": float(item["longitude"]),
                    "acquisition_date": datetime.strptime(
                        item["date"], "%Y-%m-%dT%H:%M:%SZ"
                    ),
                    "burn_area": float(burn_area),
                    "image_base64": image_b64,
                    "rgb_image_base64": rgb_image_b64,
                },
            )
            points.append(point)

        except Exception as e:
            print(f"Error creating point: {e}")
            continue

    return points
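
# Note: generator.generate_image_embeddings (defined in embedding_utils, not
# shown here) is assumed to return one multivector per input image -- a
# sequence of 128-dimensional token vectors -- matching the "colpali"
# multivector configuration created in main() below.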


def main():
    # Qdrant connection details come from the environment (.env, loaded above).
    client = QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
    )

    # Start from a clean slate: drop any existing collection, then recreate it
    # with a named 128-dim multivector ("colpali") compared via MaxSim.
    if client.collection_exists(collection_name=collection_name):
        client.delete_collection(collection_name=collection_name)

    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "colpali": models.VectorParams(
                size=128,
                distance=models.Distance.COSINE,
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM
                ),
            ),
        },
        optimizers_config=models.OptimizersConfigDiff(default_segment_number=2),
    )
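
    # With MAX_SIM, Qdrant scores a query multivector against a stored one by
    # taking, for each query token vector, its maximum similarity over the
    # stored token vectors and summing those maxima -- the late-interaction
    # scoring scheme used by ColBERT/ColPali.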

    # Payload indexes for the metadata fields used in filtered search.
    client.create_payload_index(
        collection_name=collection_name,
        field_name="centroid_lat",
        field_schema=models.PayloadSchemaType.FLOAT,
    )
    client.create_payload_index(
        collection_name=collection_name,
        field_name="centroid_lon",
        field_schema=models.PayloadSchemaType.FLOAT,
    )
    client.create_payload_index(
        collection_name=collection_name,
        field_name="acquisition_date",
        field_schema=models.PayloadSchemaType.DATETIME,
    )
    client.create_payload_index(
        collection_name=collection_name,
        field_name="burn_area",
        field_schema=models.PayloadSchemaType.FLOAT,
    )
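
    # Illustrative (not executed here): these indexes let later queries combine
    # similarity search with metadata filters, e.g.
    #     models.Filter(must=[models.FieldCondition(
    #         key="burn_area", range=models.Range(gte=0.1))])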

    # Ingest each split in small batches to bound embedding memory usage.
    for split_name, dataset_split in ds_from_hub.items():
        print(f"Ingesting {len(dataset_split)} points from '{split_name}' split...")

        dataset_split = dataset_split.with_format("numpy")

        batch_size = 4
        total_points = len(dataset_split)

        for start_idx in range(0, total_points, batch_size):
            end_idx = min(start_idx + batch_size, total_points)
            batch = dataset_split.select(range(start_idx, end_idx))

            print(f"Processing batch {start_idx} to {end_idx}...")
            qdrant_points = generate_qdrant_points(batch)

            if qdrant_points:
                # wait=True blocks until the batch is persisted before continuing.
                client.upsert(
                    collection_name=collection_name, points=qdrant_points, wait=True
                )
                print(f"Upserted batch {start_idx} to {end_idx}.")

        print(f"Finished ingesting '{split_name}' split.")

    print("Data ingestion complete for all splits.")

    # Sanity check: report the exact number of points now in the collection.
    count_result = client.count(collection_name=collection_name, exact=True)
    print(f"Total points in Qdrant collection: {count_result.count}")


if __name__ == "__main__":
    main()