# util: script to build dataset
# Author: mahimairaja (commit 1044a0a)
import datetime
import io
import os
import random
import numpy as np
import rasterio
from datasets import Array2D, Features, Image, Value, load_dataset
from dotenv import load_dotenv
from pyproj import Transformer
# Pull settings such as HF_TOKEN from a local .env file into the environment.
load_dotenv()

# NOTE(review): trust_remote_code=True executes the dataset repo's loading
# script locally; only acceptable for a trusted source (consider pinning a
# `revision=` as well).
ds = load_dataset(
    "ibm-nasa-geospatial/hls_burn_scars",
    trust_remote_code=True,
)

# Keep both columns undecoded so each row carries the raw file contents under
# ["bytes"], which is what the rasterio-based processing below re-opens.
ds_casted = ds.cast_column("image", Image(decode=False))
ds_annoted = ds_casted.cast_column("annotation", Image(decode=False))
# 1-based raster band positions (as rasterio's DatasetReader.read expects) of
# each named band in the HLS burn-scar scenes.
# NOTE(review): taken from the ibm-nasa-geospatial/hls_burn_scars dataset card
# (band order: Blue, Green, Red, NIR, SWIR1, SWIR2). The previous 2..6 indices
# were Landsat band *numbers*, which shifted every band by one position.
HLS_BAND_INDICES = {"blue": 1, "green": 2, "red": 3, "nir": 4, "swir": 5}


def _compute_ndvi(nir, red):
    """Return NDVI = (nir - red) / (nir + red), with 0 wherever it is undefined.

    Pixels whose denominator is exactly zero get 0 instead of being divided by
    a tiny epsilon, which would yield huge out-of-range values whenever
    nir == -red != 0.
    """
    denominator = nir + red
    ndvi = np.divide(
        nir - red,
        denominator,
        out=np.zeros_like(denominator),
        where=denominator != 0,
    )
    # Still neutralise non-finite results coming from NaN/inf input pixels.
    return np.nan_to_num(ndvi, nan=0.0, posinf=0.0, neginf=0.0)


def _centroid_lonlat(src):
    """Return (lon, lat) of the open raster's centre point.

    Projected CRSs are reprojected to EPSG:4326; a geographic CRS is assumed to
    already be expressed in lon/lat degrees.
    """
    crs = src.crs
    if crs is None:
        raise ValueError("GeoTIFF has no CRS; cannot georeference its centroid")
    # The affine transform maps (col, row) pixel coordinates to CRS x/y, so
    # (width/2, height/2) is the geometric centre of the raster.
    center_x, center_y = src.transform * (src.width / 2, src.height / 2)
    if crs.is_projected:
        transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
        return transformer.transform(center_x, center_y)
    # NOTE(review): non-WGS84 geographic datums are passed through unchanged.
    return center_x, center_y


def process_geotiff_data(geotiff_bytes, *, band_indices=None):
    """
    Process raw GeoTIFF bytes into spectral bands, NDVI, and a centroid.

    Args:
        geotiff_bytes: Encoded GeoTIFF file contents.
        band_indices: Optional mapping of band name ("blue", "green", "red",
            "nir", "swir") to 1-based raster band position. Entries override
            the defaults in HLS_BAND_INDICES; omitted names keep the default.

    Returns:
        A dict with float32 arrays under "blue", "green", "red", "nir",
        "swir" and "ndvi", plus "centroid_lat"/"centroid_lon" (degrees).
        Returns None if the bytes cannot be read or processed; the error is
        printed (best-effort, so one bad scene does not abort the whole run).
    """
    indices = {**HLS_BAND_INDICES, **(band_indices or {})}
    try:
        with rasterio.open(io.BytesIO(geotiff_bytes)) as src:
            processed_data = {
                name: src.read(index).astype(np.float32)
                for name, index in indices.items()
            }
            centroid_lon, centroid_lat = _centroid_lonlat(src)
        processed_data["ndvi"] = _compute_ndvi(
            processed_data["nir"], processed_data["red"]
        )
    except Exception as e:
        print(f"Error processing GeoTIFF: {e}")
        return None
    processed_data["centroid_lat"] = centroid_lat
    processed_data["centroid_lon"] = centroid_lon
    return processed_data
def calculate_burn_area(annotation_bytes, *, pixel_size_m=30.0, burn_value=1):
    """
    Return the burned area, in hectares, encoded in an annotation mask.

    Args:
        annotation_bytes: Encoded single-band mask raster (GeoTIFF bytes).
        pixel_size_m: Side length of one square pixel in metres. Defaults to
            30.0, the HLS pixel size this script assumes.
        burn_value: Mask value that marks a burn-scar pixel (default 1).

    Returns:
        Burned area in hectares as a float. Returns 0.0 if the mask cannot be
        read; the error is printed so the run continues.
    """
    try:
        with rasterio.open(io.BytesIO(annotation_bytes)) as src:
            mask = src.read(1)
    except Exception as e:
        print(f"Error calculating burn area: {e}")
        return 0.0
    burn_pixel_count = int(np.count_nonzero(mask == burn_value))
    # Each pixel covers pixel_size_m**2 square metres; 1 ha = 10,000 m^2.
    return float(burn_pixel_count * pixel_size_m**2 / 10_000)
def synthesize_temporal_metadata(*, start_date=None, end_date=None, rng=None):
    """
    Return a random synthetic timestamp between 2018-01-01 and 2021-12-31.

    Both end days are eligible. The time of day is drawn uniformly at
    one-second resolution.

    Args:
        start_date: First eligible day as a naive datetime (default
            2018-01-01).
        end_date: Last eligible day as a naive datetime, inclusive (default
            2021-12-31).
        rng: Object providing randint(), e.g. random.Random(seed) for a
            reproducible dataset. Defaults to the global `random` module.

    Returns:
        ISO-8601 string with second precision and a trailing "Z", e.g.
        "2019-07-04T13:05:59Z".

    Raises:
        ValueError: if end_date is earlier than start_date.
    """
    if start_date is None:
        start_date = datetime.datetime(2018, 1, 1)
    if end_date is None:
        end_date = datetime.datetime(2021, 12, 31)
    if rng is None:
        rng = random
    days_between = (end_date - start_date).days
    if days_between < 0:
        raise ValueError("end_date must not be earlier than start_date")
    # randint is inclusive on both ends, so end_date itself can be drawn;
    # randrange(days_between) silently excluded it (and failed when start == end).
    random_day = start_date + datetime.timedelta(days=rng.randint(0, days_between))
    random_date = random_day.replace(
        hour=rng.randint(0, 23),
        minute=rng.randint(0, 59),
        second=rng.randint(0, 59),
    )
    return random_date.isoformat(timespec="seconds") + "Z"
# Columns produced for every row by process_sample_data, in schema order.
_SAMPLE_COLUMNS = (
    "annotation",
    "red",
    "green",
    "blue",
    "nir",
    "swir",
    "ndvi",
    "latitude",
    "longitude",
    "date",
    "burn_area",
)


def process_sample_data(data_point):
    """
    Extract raw features and metadata for one source sample.

    Does NOT compute embeddings (students will do this).

    Args:
        data_point: A source row whose "image" and "annotation" entries are
            undecoded image dicts carrying the file contents under "bytes".

    Returns:
        A dict with exactly the keys in _SAMPLE_COLUMNS. If the sample cannot
        be processed, every value is None. Returning an empty dict would leave
        the row without the feature columns, which the fixed-schema writer in
        Dataset.map likely cannot handle; null rows can be filtered out later.
    """
    try:
        img_bytes = data_point["image"]["bytes"]
        annot_bytes = data_point["annotation"]["bytes"]
        processed_geotiff = process_geotiff_data(img_bytes)
        if processed_geotiff is None:
            return dict.fromkeys(_SAMPLE_COLUMNS)
        burn_area = calculate_burn_area(annot_bytes)
        acquisition_date = synthesize_temporal_metadata()
        return {
            "annotation": data_point["annotation"],
            "red": processed_geotiff["red"],
            "green": processed_geotiff["green"],
            "blue": processed_geotiff["blue"],
            "nir": processed_geotiff["nir"],
            "swir": processed_geotiff["swir"],
            "ndvi": processed_geotiff["ndvi"],
            "latitude": processed_geotiff["centroid_lat"],
            "longitude": processed_geotiff["centroid_lon"],
            "date": acquisition_date,
            "burn_area": burn_area,
        }
    except Exception as e:
        print(f"Error processing sample: {e}")
        return dict.fromkeys(_SAMPLE_COLUMNS)
# Output schema. The five spectral bands and the NDVI layer share one
# 512x512 float32 grid definition; column order matches the original listing.
features = Features(
    {
        "annotation": Image(),
        **{
            band_name: Array2D(shape=(512, 512), dtype="float32")
            for band_name in ("red", "green", "blue", "nir", "swir", "ndvi")
        },
        "latitude": Value("float64"),
        "longitude": Value("float64"),
        "date": Value("string"),
        "burn_area": Value("float32"),
    }
)
print("Processing dataset...")
# The raw "image" column is dropped; its derived fields come from
# process_sample_data and are encoded with the `features` schema.
ds_processed = ds_annoted.map(
    process_sample_data,
    features=features,
    remove_columns=["image"],
    writer_batch_size=100,  # rows per cache-writer batch
)
sample_keys = ds_processed["train"][0].keys()
print("Processing complete. Sample keys:", sample_keys)

repo_id = "mahimairaja/ibm-hls-burn-original"
print(f"Pushing to {repo_id}...")
# os.getenv yields None when HF_TOKEN is unset.
hub_token = os.getenv("HF_TOKEN")
ds_processed.push_to_hub(repo_id, private=False, token=hub_token)
print("Done!")