"""Preprocess the IBM/NASA HLS burn-scars dataset and push it to the Hub.

For each sample this script extracts five spectral bands from the GeoTIFF
image, derives an NDVI layer and a WGS84 centroid, computes the burned area
from the annotation mask, synthesizes a random acquisition timestamp, and
uploads the resulting dataset to ``mahimairaja/ibm-hls-burn-original``.
"""

import datetime
import io
import os
import random

import numpy as np
import rasterio
from datasets import Array2D, Features, Image, Value, load_dataset
from dotenv import load_dotenv
from pyproj import Transformer


def process_geotiff_data(geotiff_bytes):
    """Extract spectral bands, NDVI, and a WGS84 centroid from a GeoTIFF.

    Args:
        geotiff_bytes: Raw bytes of a GeoTIFF with at least 6 bands.

    Returns:
        dict with float32 arrays under ``blue``/``green``/``red``/``nir``/
        ``swir``/``ndvi`` plus ``centroid_lat``/``centroid_lon`` floats,
        or ``None`` if the file cannot be read.
    """
    processed_data = {}
    try:
        with rasterio.open(io.BytesIO(geotiff_bytes)) as src:
            # Band indices 2-6 are mapped to blue..swir here; presumably this
            # matches the HLS band ordering of this dataset — TODO confirm
            # against the dataset card.
            band_blue = src.read(2).astype(np.float32)
            band_green = src.read(3).astype(np.float32)
            band_red = src.read(4).astype(np.float32)
            band_nir = src.read(5).astype(np.float32)
            band_swir = src.read(6).astype(np.float32)

            processed_data["blue"] = band_blue
            processed_data["green"] = band_green
            processed_data["red"] = band_red
            processed_data["nir"] = band_nir
            processed_data["swir"] = band_swir

            # NDVI = (NIR - Red) / (NIR + Red); zero denominators are nudged
            # to a tiny epsilon to avoid divide-by-zero, then any residual
            # non-finite values are zeroed.
            denominator = band_nir + band_red
            denominator[denominator == 0] = 1e-6
            ndvi = (band_nir - band_red) / denominator
            ndvi = np.nan_to_num(ndvi, nan=0.0, posinf=0.0, neginf=0.0)
            processed_data["ndvi"] = ndvi

            # Centroid: map the center pixel through the affine transform,
            # then reproject to WGS84 lon/lat if the CRS is projected.
            crs = src.crs
            transform = src.transform
            center_pixel_x = src.width / 2
            center_pixel_y = src.height / 2
            center_x_crs, center_y_crs = transform * (center_pixel_x, center_pixel_y)

            if crs.is_projected:
                transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
                centroid_lon, centroid_lat = transformer.transform(
                    center_x_crs, center_y_crs
                )
            else:
                # Geographic CRS: coordinates are already lon/lat.
                centroid_lon, centroid_lat = center_x_crs, center_y_crs

            processed_data["centroid_lat"] = centroid_lat
            processed_data["centroid_lon"] = centroid_lon

            return processed_data
    except Exception as e:
        # Best-effort: a corrupt sample is reported and skipped by the caller.
        print(f"Error processing GeoTIFF: {e}")
        return None


def calculate_burn_area(annotation_bytes):
    """Calculate the burn area in hectares from the annotation mask.

    Assumes HLS pixel size is 30m x 30m (TODO confirm against the dataset
    card). Mask value 1 = Burn Scar.

    Args:
        annotation_bytes: Raw bytes of the single-band annotation GeoTIFF.

    Returns:
        Burned area in hectares as a float; 0.0 if the mask cannot be read.
    """
    try:
        with rasterio.open(io.BytesIO(annotation_bytes)) as src:
            mask = src.read(1)
            burn_pixel_count = np.count_nonzero(mask == 1)
            # 1 pixel = 30m * 30m = 900 m^2; 1 ha = 10,000 m^2.
            area_hectares = (burn_pixel_count * 900) / 10000
            return float(area_hectares)
    except Exception as e:
        print(f"Error calculating burn area: {e}")
        return 0.0


def synthesize_temporal_metadata():
    """Generate a random ISO-8601 UTC timestamp between 2018 and 2021.

    Returns:
        A string like ``"2019-07-04T13:42:07Z"``. Not seeded, so results
        are not reproducible across runs.
    """
    start_date = datetime.datetime(2018, 1, 1)
    end_date = datetime.datetime(2021, 12, 31)
    days_between = (end_date - start_date).days
    # +1 so the end date itself (2021-12-31) can be drawn; randrange's stop
    # is exclusive, so the original could never produce it.
    random_days = random.randrange(days_between + 1)
    random_date = start_date + datetime.timedelta(days=random_days)
    random_date = random_date.replace(
        hour=random.randint(0, 23),
        minute=random.randint(0, 59),
        second=random.randint(0, 59),
    )
    return random_date.isoformat(timespec="seconds") + "Z"


def process_sample_data(data_point):
    """Build the output record for one dataset sample.

    Extracts raw features and metadata. Does NOT compute embeddings
    (students will do this).

    Args:
        data_point: A dataset row with undecoded ``image`` and
            ``annotation`` columns (dicts holding raw ``bytes``).

    Returns:
        A dict matching the ``features`` schema below, or ``{}`` on error.
        NOTE(review): returning ``{}`` drops every declared column, which
        will likely make ``Dataset.map`` fail under the fixed ``features``
        schema — confirm whether failed samples should instead be filtered
        out beforehand or filled with sentinel values.
    """
    try:
        img_bytes = data_point["image"]["bytes"]
        annot_bytes = data_point["annotation"]["bytes"]

        processed_geotiff = process_geotiff_data(img_bytes)
        if processed_geotiff is None:
            return {}

        burn_area = calculate_burn_area(annot_bytes)
        acquisition_date = synthesize_temporal_metadata()

        return {
            "annotation": data_point["annotation"],
            "red": processed_geotiff["red"],
            "green": processed_geotiff["green"],
            "blue": processed_geotiff["blue"],
            "nir": processed_geotiff["nir"],
            "swir": processed_geotiff["swir"],
            "ndvi": processed_geotiff["ndvi"],
            "latitude": processed_geotiff["centroid_lat"],
            "longitude": processed_geotiff["centroid_lon"],
            "date": acquisition_date,
            "burn_area": burn_area,
        }
    except Exception as e:
        print(f"Error processing sample: {e}")
        return {}


def main():
    """Load, preprocess, and push the dataset (script entry point)."""
    load_dotenv()

    ds = load_dataset("ibm-nasa-geospatial/hls_burn_scars", trust_remote_code=True)
    # Keep image/annotation as raw bytes so rasterio can decode the GeoTIFFs.
    ds_casted = ds.cast_column("image", Image(decode=False))
    ds_annoted = ds_casted.cast_column("annotation", Image(decode=False))

    # Output schema; band arrays are assumed to be 512x512 — TODO confirm
    # against the source dataset's tile size.
    features = Features(
        {
            "annotation": Image(),
            "red": Array2D(shape=(512, 512), dtype="float32"),
            "green": Array2D(shape=(512, 512), dtype="float32"),
            "blue": Array2D(shape=(512, 512), dtype="float32"),
            "nir": Array2D(shape=(512, 512), dtype="float32"),
            "swir": Array2D(shape=(512, 512), dtype="float32"),
            "ndvi": Array2D(shape=(512, 512), dtype="float32"),
            "latitude": Value("float64"),
            "longitude": Value("float64"),
            "date": Value("string"),
            "burn_area": Value("float32"),
        }
    )

    print("Processing dataset...")
    ds_processed = ds_annoted.map(
        process_sample_data,
        remove_columns=["image"],
        features=features,
        writer_batch_size=100,
    )
    print("Processing complete. Sample keys:", ds_processed["train"][0].keys())

    repo_id = "mahimairaja/ibm-hls-burn-original"
    print(f"Pushing to {repo_id}...")
    ds_processed.push_to_hub(repo_id, private=False, token=os.getenv("HF_TOKEN"))
    print("Done!")


if __name__ == "__main__":
    main()