import datetime
import io
import os
import random

import numpy as np
import rasterio
from datasets import Array2D, Features, Image, Value, load_dataset
from dotenv import load_dotenv
from pyproj import Transformer

load_dotenv()
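
# The HLS Burn Scars dataset is distributed with a loading script,
# hence trust_remote_code=True.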
ds = load_dataset("ibm-nasa-geospatial/hls_burn_scars", trust_remote_code=True)
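
# Disable decoding so each sample keeps its raw GeoTIFF bytes; the default
# Pillow decoder cannot handle multi-band GeoTIFFs, so rasterio reads them below.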
ds_casted = ds.cast_column("image", Image(decode=False))
ds_annotated = ds_casted.cast_column("annotation", Image(decode=False))


def process_geotiff_data(geotiff_bytes):
    """
    Processes raw GeoTIFF bytes to extract the spectral bands, compute NDVI,
    and derive the scene centroid in geographic coordinates.
    """
    processed_data = {}
    try:
        with rasterio.open(io.BytesIO(geotiff_bytes)) as src:
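            # HLS Burn Scars scenes are assumed to carry six bands in order:
            # Blue, Green, Red, NIR, SWIR1, SWIR2. rasterio band indices are
            # 1-based, so band 1 is Blue.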
            band_blue = src.read(1).astype(np.float32)
            band_green = src.read(2).astype(np.float32)
            band_red = src.read(3).astype(np.float32)
            band_nir = src.read(4).astype(np.float32)
            band_swir = src.read(5).astype(np.float32)  # SWIR1

            processed_data["blue"] = band_blue
            processed_data["green"] = band_green
            processed_data["red"] = band_red
            processed_data["nir"] = band_nir
            processed_data["swir"] = band_swir
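
            # NDVI = (NIR - Red) / (NIR + Red); replace zero denominators with
            # a small epsilon to avoid division by zero.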
            denominator = band_nir + band_red
            denominator[denominator == 0] = 1e-6
            ndvi = (band_nir - band_red) / denominator
            ndvi = np.nan_to_num(ndvi, nan=0.0, posinf=0.0, neginf=0.0)
            processed_data["ndvi"] = ndvi
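
            # The affine geotransform maps pixel coordinates into the scene's
            # CRS; applying it to the image center yields the scene centroid.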
            crs = src.crs
            transform = src.transform

            center_pixel_x = src.width / 2
            center_pixel_y = src.height / 2
            center_x_crs, center_y_crs = transform * (center_pixel_x, center_pixel_y)

            if crs.is_projected:
                transformer = Transformer.from_crs(crs, "EPSG:4326", always_xy=True)
                centroid_lon, centroid_lat = transformer.transform(
                    center_x_crs, center_y_crs
                )
            else:
                centroid_lon, centroid_lat = center_x_crs, center_y_crs

            processed_data["centroid_lat"] = centroid_lat
            processed_data["centroid_lon"] = centroid_lon

            return processed_data

    except Exception as e:
        print(f"Error processing GeoTIFF: {e}")
        return None


def calculate_burn_area(annotation_bytes):
    """
    Calculates the burn area in hectares from the annotation mask.
    Assumes the HLS pixel size of 30 m x 30 m; mask value 1 marks burn scar.
    """
    try:
        with rasterio.open(io.BytesIO(annotation_bytes)) as src:
            mask = src.read(1)
            burn_pixel_count = np.count_nonzero(mask == 1)
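
            # Each 30 m x 30 m pixel covers 900 m^2; 10,000 m^2 = 1 hectare.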
            area_hectares = (burn_pixel_count * 900) / 10000
            return float(area_hectares)
    except Exception as e:
        print(f"Error calculating burn area: {e}")
        return 0.0
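

# The source dataset carries no acquisition timestamps, so a random 2018-2021
# date is synthesized purely as placeholder metadata.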
def synthesize_temporal_metadata():
    """Generates a random ISO 8601 timestamp between 2018 and 2021."""
    start_date = datetime.datetime(2018, 1, 1)
    end_date = datetime.datetime(2021, 12, 31)
    days_between = (end_date - start_date).days
    random_days = random.randrange(days_between)
    random_date = start_date + datetime.timedelta(days=random_days)

    random_date = random_date.replace(
        hour=random.randint(0, 23),
        minute=random.randint(0, 59),
        second=random.randint(0, 59),
    )
    return random_date.isoformat(timespec="seconds") + "Z"


def process_sample_data(data_point):
    """
    Extracts raw features and metadata for one sample.
    Does NOT compute embeddings (students will do this).
    """
    try:
        img_bytes = data_point["image"]["bytes"]
        annot_bytes = data_point["annotation"]["bytes"]

        processed_geotiff = process_geotiff_data(img_bytes)
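        # Returning an empty dict leaves the row without the new columns, so
        # the features= cast in .map() below should fail fast on bad samples
        # instead of silently writing partial rows.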
        if processed_geotiff is None:
            return {}

        burn_area = calculate_burn_area(annot_bytes)
        acquisition_date = synthesize_temporal_metadata()

        return {
            "annotation": data_point["annotation"],
            "red": processed_geotiff["red"],
            "green": processed_geotiff["green"],
            "blue": processed_geotiff["blue"],
            "nir": processed_geotiff["nir"],
            "swir": processed_geotiff["swir"],
            "ndvi": processed_geotiff["ndvi"],
            "latitude": processed_geotiff["centroid_lat"],
            "longitude": processed_geotiff["centroid_lon"],
            "date": acquisition_date,
            "burn_area": burn_area,
        }

    except Exception as e:
        print(f"Error processing sample: {e}")
        return {}


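# Target schema: HLS Burn Scars chips are 512 x 512 pixels, hence the
# fixed-shape float32 arrays.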
features = Features(
    {
        "annotation": Image(),
        "red": Array2D(shape=(512, 512), dtype="float32"),
        "green": Array2D(shape=(512, 512), dtype="float32"),
        "blue": Array2D(shape=(512, 512), dtype="float32"),
        "nir": Array2D(shape=(512, 512), dtype="float32"),
        "swir": Array2D(shape=(512, 512), dtype="float32"),
        "ndvi": Array2D(shape=(512, 512), dtype="float32"),
        "latitude": Value("float64"),
        "longitude": Value("float64"),
        "date": Value("string"),
        "burn_area": Value("float32"),
    }
)
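
# DatasetDict.map processes every split; the raw image bytes are dropped once
# the bands are extracted, and small writer batches keep memory bounded.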
print("Processing dataset...")
ds_processed = ds_annotated.map(
    process_sample_data,
    remove_columns=["image"],
    features=features,
    writer_batch_size=100,
)
print("Processing complete. Sample keys:", ds_processed["train"][0].keys())
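
# Pushing requires a Hugging Face token with write access, read from .env.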
repo_id = "mahimairaja/ibm-hls-burn-original"
print(f"Pushing to {repo_id}...")
ds_processed.push_to_hub(repo_id, private=False, token=os.getenv("HF_TOKEN"))
print("Done!")
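
# Downstream usage (sketch): reload the processed dataset and pull a band:
#   ds = load_dataset("mahimairaja/ibm-hls-burn-original")
#   red = np.array(ds["train"][0]["red"])  # (512, 512) float32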