|
|
import base64 |
|
|
import os |
|
|
from datetime import date, datetime |
|
|
from io import BytesIO |
|
|
|
|
|
import folium |
|
|
import numpy as np |
|
|
import streamlit as st |
|
|
from dotenv import load_dotenv |
|
|
from geopy.geocoders import Nominatim |
|
|
from PIL import Image |
|
|
from qdrant_client import QdrantClient, models |
|
|
from streamlit_folium import st_folium |
|
|
|
|
|
from utils.embedding_utils import ColPaliEmbeddingGenerator |
|
|
|
|
|
# Populate os.environ from a local .env file (if one exists) before any
# os.getenv(...) lookups below run.
load_dotenv()

# Browser-tab title/icon and the wide page layout.
st.set_page_config(
    layout="wide",
    page_icon="π",  # NOTE(review): possibly a mis-encoded emoji — confirm the intended icon
    page_title="Geo-Spatial Vector Search with Qdrant",
)
|
|
COLLECTION_NAME = "hls_burn_scars_data_colpali_rgb"

# Options for the sidebar region picker: a placeholder entry first (the region
# callback ignores it), then the US states and the Canadian provinces, each
# suffixed with its country so the geocoder gets an unambiguous query.
REGIONS = (
    ["Select a Region"]
    + [
        f"{state}, USA"
        for state in (
            "Alabama", "Alaska", "Arizona", "Arkansas", "California",
            "Colorado", "Connecticut", "Delaware", "Florida", "Georgia",
            "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa",
            "Kansas", "Kentucky", "Louisiana", "Maine", "Maryland",
            "Massachusetts", "Michigan", "Minnesota", "Mississippi", "Missouri",
            "Montana", "Nebraska", "Nevada", "New Hampshire", "New Jersey",
            "New Mexico", "New York", "North Carolina", "North Dakota", "Ohio",
            "Oklahoma", "Oregon", "Pennsylvania", "Rhode Island", "South Carolina",
            "South Dakota", "Tennessee", "Texas", "Utah", "Vermont",
            "Virginia", "Washington", "West Virginia", "Wisconsin", "Wyoming",
        )
    ]
    + [
        f"{province}, Canada"
        for province in (
            "Alberta", "British Columbia", "Manitoba", "New Brunswick",
            "Newfoundland and Labrador", "Nova Scotia", "Ontario",
            "Prince Edward Island", "Quebec", "Saskatchewan",
        )
    ]
)
|
|
|
|
|
|
|
|
@st.cache_resource
def get_qdrant_client():
    """Build the Qdrant client, cached by ``st.cache_resource`` across reruns.

    The endpoint and credentials come from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables; an unset variable is passed
    through to ``QdrantClient`` as ``None``.
    """
    connection = {
        "url": os.getenv("QDRANT_URL"),
        "api_key": os.getenv("QDRANT_API_KEY"),
    }
    return QdrantClient(**connection)
|
|
|
|
|
|
|
|
@st.cache_resource
def get_embedding_model():
    """Instantiate the ColPali embedding generator once and reuse it across reruns.

    A spinner is shown while the (slow) model construction runs.
    """
    with st.spinner("Loading ColPali model... (this may take a minute)"):
        return ColPaliEmbeddingGenerator()
|
|
|
|
|
|
|
|
@st.cache_data
def _geocode_cached(place_name):
    """Geocode *place_name* with Nominatim and memoize the outcome.

    Returns ``(lat, lon)`` on success or ``(None, None)`` when Nominatim finds no
    match. Geocoding errors are deliberately NOT caught here: ``st.cache_data``
    does not store results of calls that raise, so a transient failure (timeout,
    rate limit, network blip) is retried on the next call instead of being
    memoized as a permanent miss.
    """
    # geopy's default per-request timeout is 1 s, which the public Nominatim
    # service regularly exceeds; allow a more forgiving window.
    geolocator = Nominatim(user_agent="geo_spatial_chat_app", timeout=10)
    location = geolocator.geocode(place_name)
    if location:
        return location.latitude, location.longitude
    return None, None


def get_coordinates(place_name):
    """Geocode a place name to (lat, lon).

    Returns ``(None, None)`` when the place cannot be found or geocoding fails.
    Failures are reported with ``st.error`` and are not cached, so selecting the
    same region again retries the lookup.
    """
    try:
        return _geocode_cached(place_name)
    except Exception as e:  # UI boundary: surface any geocoder/network error to the user
        st.error(f"Error geocoding {place_name}: {e}")
        return None, None
|
|
|
|
|
|
|
|
# Shared resources: the Qdrant connection and the ColPali query embedder.
client, generator = get_qdrant_client(), get_embedding_model()

# Page header.
st.title("Geo-Spatial Vector Search with Qdrant")
st.markdown(
    "Find burn scars using natural language queries with Multi-Vector Retrieval."
)
|
|
|
|
|
with st.sidebar:
    st.header("Filters")

    min_area = st.slider(
        "Minimum Burn Area (Hectares)",
        min_value=0,
        max_value=1000,
        value=0,
        step=10,
        help="Filter results to show only burn scars larger than this value.",
    )

    st.subheader("Date Range")
    min_date_input = st.date_input("Start Date", value=date(2018, 1, 1))
    max_date_input = st.date_input("End Date", value=date(2021, 12, 31))

    st.subheader("Spatial Filter")
    use_spatial = st.checkbox("Filter by Location")

    # Seed the coordinate widgets' session state once per session. The number
    # inputs below are bound to these keys, and the region callback writes them.
    if "lat_input_widget" not in st.session_state:
        st.session_state.lat_input_widget = 37.0
    if "lon_input_widget" not in st.session_state:
        st.session_state.lon_input_widget = -120.0

    # Defaults so these names are always defined, even with the spatial filter off.
    lat_input = st.session_state.lat_input_widget
    lon_input = st.session_state.lon_input_widget
    radius_km = 100.0

    if use_spatial:

        def on_region_change():
            """Copy the coordinates returned by get_coordinates for the picked region into the inputs."""
            region = st.session_state.selected_region
            if region and region != "Select a Region":
                lat, lon = get_coordinates(region)
                if lat is not None:
                    st.session_state.lat_input_widget = lat
                    st.session_state.lon_input_widget = lon

        st.selectbox(
            "Select State/Region (US/Canada)",
            REGIONS,
            key="selected_region",
            on_change=on_region_change,
        )

        # Restrict the inputs to valid geographic coordinates. A latitude beyond
        # ±90° makes cos(lat) negative, which inverts the longitude window built
        # for the spatial filter and silently yields no results.
        col_lat, col_lon = st.columns(2)
        with col_lat:
            lat_input = st.number_input(
                "Latitude",
                min_value=-90.0,
                max_value=90.0,
                format="%.4f",
                key="lat_input_widget",
            )
        with col_lon:
            lon_input = st.number_input(
                "Longitude",
                min_value=-180.0,
                max_value=180.0,
                format="%.4f",
                key="lon_input_widget",
            )

        radius_km = st.slider("Radius (km)", 10, 5000, 230)
|
|
|
|
|
# Free-text search query; it is embedded with the ColPali generator when the
# Search button is pressed.
user_query = st.text_input(
    label="Enter your query:",
    value="Help me find the burn scars that have more than 100 hectares",
)
|
|
|
|
|
# Approximate kilometres per degree of latitude (and of longitude at the
# equator); adequate for a coarse bounding-box prefilter.
_KM_PER_DEGREE = 111.0


def _longitude_conditions(center_lon, lon_delta):
    """Build Qdrant conditions limiting ``centroid_lon`` to ``center_lon ± lon_delta``.

    The centre is first normalised into [-180, 180). A window that crosses the
    ±180° antimeridian is split into two ranges OR-ed together through a nested
    ``should`` filter. Returns an empty list (no longitude constraint) when the
    window spans every longitude, i.e. ``lon_delta >= 180``.

    NOTE(review): assumes stored ``centroid_lon`` values use the [-180, 180]
    convention — confirm against the ingestion pipeline.
    """
    if lon_delta >= 180.0:
        return []
    center_lon = (center_lon + 180.0) % 360.0 - 180.0
    low, high = center_lon - lon_delta, center_lon + lon_delta
    if low < -180.0:
        spans = [(low + 360.0, 180.0), (-180.0, high)]
    elif high > 180.0:
        spans = [(low, 180.0), (-180.0, high - 360.0)]
    else:
        spans = [(low, high)]
    conditions = [
        models.FieldCondition(key="centroid_lon", range=models.Range(gte=lo, lte=hi))
        for lo, hi in spans
    ]
    if len(conditions) == 1:
        return conditions
    return [models.Filter(should=conditions)]


if st.button("Search"):
    if not user_query.strip():
        st.warning("Please enter a query.")
    elif min_date_input > max_date_input:
        st.warning("Start Date must be on or before End Date.")
    else:
        with st.spinner("Searching..."):
            query_embeddings = generator.generate_query_embeddings([user_query])[0]

            filter_conditions = []

            if min_area > 0:
                filter_conditions.append(
                    models.FieldCondition(
                        key="burn_area",
                        range=models.Range(gt=min_area),
                    )
                )

            # Whole-day bounds: midnight of the start date through the last
            # microsecond of the end date, both inclusive.
            start_dt = datetime.combine(min_date_input, datetime.min.time())
            end_dt = datetime.combine(max_date_input, datetime.max.time())
            filter_conditions.append(
                models.FieldCondition(
                    key="acquisition_date",
                    range=models.DatetimeRange(
                        gte=start_dt.isoformat(), lte=end_dt.isoformat()
                    ),
                )
            )

            if use_spatial:
                # Bounding-box approximation of the search circle. abs() keeps
                # the longitude half-width positive for any latitude; at the
                # poles (cos -> 0) the box spans every longitude.
                lat_delta = radius_km / _KM_PER_DEGREE
                cos_lat = abs(float(np.cos(np.radians(lat_input))))
                lon_delta = (
                    radius_km / (_KM_PER_DEGREE * cos_lat)
                    if cos_lat > 1e-9
                    else float("inf")
                )

                filter_conditions.append(
                    models.FieldCondition(
                        key="centroid_lat",
                        range=models.Range(
                            gte=lat_input - lat_delta, lte=lat_input + lat_delta
                        ),
                    )
                )
                filter_conditions.extend(_longitude_conditions(lon_input, lon_delta))

            try:
                results = client.query_points(
                    collection_name=COLLECTION_NAME,
                    query=query_embeddings,
                    using="colpali",
                    limit=5,
                    # The date condition is always present, so this is never empty.
                    query_filter=models.Filter(must=filter_conditions),
                    with_payload=True,
                )
            except Exception as exc:  # UI boundary: show backend/network failures, not a traceback
                st.error(f"Search failed: {exc}")
                # Drop previous results so they are not mistaken for this query's.
                st.session_state.pop("search_results", None)
            else:
                st.session_state.search_results = results
|
|
|
|
|
def _pick_display_image(payload):
    """Choose which stored image to show for one search result.

    Returns ``(base64_data, caption)``: the RGB rendering when the payload has a
    non-empty ``rgb_image_base64``, otherwise ``image_base64`` (possibly None)
    captioned as the burn-scar mask.
    """
    rgb_b64 = payload.get("rgb_image_base64")
    if rgb_b64:
        return rgb_b64, "RGB Image"
    return payload.get("image_base64"), "Burn Scar Mask"


def _decode_image(b64_data):
    """Decode a base64-encoded image into a PIL image that ``st.image`` can show.

    Integer images in PIL mode ``"I"`` or ``"I;16"`` are multiplied by 255 and
    converted to 8-bit RGB. NOTE(review): the x255 scaling presumably targets
    0/1 masks — confirm against the ingestion pipeline. Values are clipped to
    [0, 255] so out-of-range pixels saturate instead of wrapping modulo 256.
    """
    image = Image.open(BytesIO(base64.b64decode(b64_data)))
    if image.mode in ("I", "I;16"):
        # Widen first so neither the multiply (uint16) nor negative values overflow.
        scaled = np.asarray(image, dtype=np.int64) * 255
        image = Image.fromarray(np.clip(scaled, 0, 255).astype(np.uint8)).convert("RGB")
    return image


def _format_area(area):
    """Render a burn area for display, or "N/A" when the payload has none."""
    return "N/A" if area is None else f"{area} ha"


def _has_coordinates(payload):
    """Return True when the payload carries both centroid coordinates."""
    return (
        payload.get("centroid_lat") is not None
        and payload.get("centroid_lon") is not None
    )


if "search_results" in st.session_state:
    results = st.session_state.search_results

    if not results.points:
        st.info("No results found.")
    else:
        col1, col2 = st.columns([1, 2])

        with col1:
            st.subheader("Top Results")
            for i, point in enumerate(results.points):
                payload = point.payload or {}
                lat = payload.get("centroid_lat")
                lon = payload.get("centroid_lon")
                # Guard the float formatting: a missing centroid would raise TypeError.
                location_text = (
                    f"{lat:.4f}, {lon:.4f}" if _has_coordinates(payload) else "N/A"
                )

                with st.expander(f"Result {i + 1} (Score: {point.score:.4f})", expanded=True):
                    st.write(f"**Date:** {payload.get('acquisition_date')}")
                    st.write(f"**Area:** {_format_area(payload.get('burn_area'))}")
                    st.write(f"**Location:** {location_text}")

                    display_b64, caption = _pick_display_image(payload)
                    if display_b64:
                        try:
                            st.image(
                                _decode_image(display_b64),
                                caption=caption,
                                width="stretch",
                            )
                        except Exception as e:
                            st.error(f"Error loading image: {e}")

        with col2:
            st.subheader("Map Visualization")

            # Only results with both coordinates can be placed on the map.
            mapped = [p for p in results.points if _has_coordinates(p.payload or {})]
            if not mapped:
                st.info("None of the results have coordinates to show on the map.")
            else:
                first = mapped[0].payload
                m = folium.Map(
                    location=[first["centroid_lat"], first["centroid_lon"]],
                    zoom_start=10,
                )
                for point in mapped:
                    payload = point.payload
                    display_b64, _ = _pick_display_image(payload)

                    popup_html = (
                        '<div style="width:200px">'
                        f"<b>Date:</b> {payload.get('acquisition_date')}<br>"
                        f"<b>Score:</b> {point.score:.4f}<br>"
                        f"<b>Area:</b> {_format_area(payload.get('burn_area'))}<br>"
                    )
                    if display_b64:
                        popup_html += f'<img src="data:image/png;base64,{display_b64}" style="width:100%;margin-top:10px;">'
                    popup_html += "</div>"

                    folium.Marker(
                        [payload["centroid_lat"], payload["centroid_lon"]],
                        popup=folium.Popup(popup_html, max_width=250),
                        tooltip=f"Result (Score: {point.score:.4f})",
                    ).add_to(m)

                st_folium(m, width="100%", height=500)
|
|
|