"""Streamlit app: natural-language search over burn-scar imagery stored in Qdrant.

The script embeds the user's query with a ColPali generator, optionally
narrows the search with payload filters (burn area, acquisition date, and a
latitude/longitude bounding box), and renders the top hits both as a list of
images and as markers on a Folium map.
"""

import base64
import os
from datetime import date, datetime
from io import BytesIO

import folium
import numpy as np
import streamlit as st
from dotenv import load_dotenv
from geopy.geocoders import Nominatim
from PIL import Image
from qdrant_client import QdrantClient, models
from streamlit_folium import st_folium

from utils.embedding_utils import ColPaliEmbeddingGenerator
from utils.llm_utils import extract_filters_from_query

load_dotenv()

st.set_page_config(
    page_title="Geo-Spatial Vector Search with Qdrant", page_icon="🌍", layout="wide"
)

COLLECTION_NAME = "hls_burn_scars_data_colpali_rgb"

# Defaults shared by the sidebar widgets and the auto-extracted filters.
DEFAULT_START_DATE = date(2018, 1, 1)
DEFAULT_END_DATE = date(2021, 12, 31)
DEFAULT_RADIUS_KM = 100.0

# Approximate kilometres per degree used to turn a radius into a lat/lon box.
KM_PER_DEGREE = 111.0
# Lower bound on cos(latitude) so the longitude span stays finite at the poles.
MIN_COS_LAT = 1e-6
# Seconds to wait for the geocoding service before giving up.
GEOCODER_TIMEOUT_S = 10

REGIONS = [
    "Select a Region",
    "Alabama, USA",
    "Alaska, USA",
    "Arizona, USA",
    "Arkansas, USA",
    "California, USA",
    "Colorado, USA",
    "Connecticut, USA",
    "Delaware, USA",
    "Florida, USA",
    "Georgia, USA",
    "Hawaii, USA",
    "Idaho, USA",
    "Illinois, USA",
    "Indiana, USA",
    "Iowa, USA",
    "Kansas, USA",
    "Kentucky, USA",
    "Louisiana, USA",
    "Maine, USA",
    "Maryland, USA",
    "Massachusetts, USA",
    "Michigan, USA",
    "Minnesota, USA",
    "Mississippi, USA",
    "Missouri, USA",
    "Montana, USA",
    "Nebraska, USA",
    "Nevada, USA",
    "New Hampshire, USA",
    "New Jersey, USA",
    "New Mexico, USA",
    "New York, USA",
    "North Carolina, USA",
    "North Dakota, USA",
    "Ohio, USA",
    "Oklahoma, USA",
    "Oregon, USA",
    "Pennsylvania, USA",
    "Rhode Island, USA",
    "South Carolina, USA",
    "South Dakota, USA",
    "Tennessee, USA",
    "Texas, USA",
    "Utah, USA",
    "Vermont, USA",
    "Virginia, USA",
    "Washington, USA",
    "West Virginia, USA",
    "Wisconsin, USA",
    "Wyoming, USA",
    "Alberta, Canada",
    "British Columbia, Canada",
    "Manitoba, Canada",
    "New Brunswick, Canada",
    "Newfoundland and Labrador, Canada",
    "Nova Scotia, Canada",
    "Ontario, Canada",
    "Prince Edward Island, Canada",
    "Quebec, Canada",
    "Saskatchewan, Canada",
]


@st.cache_resource
def get_qdrant_client():
    """Return a Qdrant client configured from QDRANT_URL / QDRANT_API_KEY."""
    return QdrantClient(
        url=os.getenv("QDRANT_URL"),
        api_key=os.getenv("QDRANT_API_KEY"),
    )


@st.cache_resource
def get_embedding_model():
    """Instantiate the ColPali embedding generator once per server process."""
    with st.spinner("Loading ColPali model... (this may take a minute)"):
        generator = ColPaliEmbeddingGenerator()
    return generator


@st.cache_data
def _geocode(place_name):
    """Look up *place_name* with Nominatim.

    Returns a ``(lat, lon)`` tuple, or ``None`` when the service finds no
    match. Errors propagate to the caller on purpose: a call that raises is
    not stored by ``st.cache_data``, so a transient network failure is retried
    on the next request instead of being remembered as "not found".
    """
    geolocator = Nominatim(
        user_agent="geo_spatial_chat_app", timeout=GEOCODER_TIMEOUT_S
    )
    location = geolocator.geocode(place_name)
    if not location:
        return None
    return location.latitude, location.longitude


def get_coordinates(place_name):
    """Geocode a place name to (lat, lon).

    Returns ``(None, None)`` when the place cannot be resolved or the lookup
    fails; failures are reported in the UI via ``st.error``.
    """
    try:
        coords = _geocode(place_name)
    except Exception as e:  # UI boundary: report the failure and keep the app alive
        st.error(f"Error geocoding {place_name}: {e}")
        return None, None
    if coords is None:
        return None, None
    return coords


def bounding_box_conditions(lat, lon, radius_km):
    """Build centroid_lat / centroid_lon range conditions around a point.

    The radius is approximated by a latitude/longitude box, not a true circle.
    The latitude fed to the cosine is clamped to [-90, 90] and the cosine is
    floored at ``MIN_COS_LAT`` so the longitude half-width is always positive
    and finite; it is also capped at 180 degrees. Boxes that cross the
    antimeridian are not wrapped.
    """
    lat_delta = radius_km / KM_PER_DEGREE
    clamped_lat = min(max(lat, -90.0), 90.0)
    cos_lat = max(float(np.cos(np.radians(clamped_lat))), MIN_COS_LAT)
    lon_delta = min(radius_km / (KM_PER_DEGREE * cos_lat), 180.0)
    return [
        models.FieldCondition(
            key="centroid_lat",
            range=models.Range(gte=lat - lat_delta, lte=lat + lat_delta),
        ),
        models.FieldCondition(
            key="centroid_lon",
            range=models.Range(gte=lon - lon_delta, lte=lon + lon_delta),
        ),
    ]


def pick_display_image(payload):
    """Return ``(base64_string_or_None, caption)``, preferring the RGB image."""
    rgb_image_b64 = payload.get("rgb_image_base64")
    if rgb_image_b64:
        return rgb_image_b64, "RGB Image"
    return payload.get("image_base64"), "Burn Scar Mask"


client = get_qdrant_client()
generator = get_embedding_model()

st.title("Geo-Spatial Vector Search with Qdrant")
st.markdown(
    "Find burn scars using natural language queries with Multi-Vector Retrieval."
)

with st.sidebar:
    st.header("Filters")
    use_auto_filter = st.toggle(
        "🤖 Enable Auto-Filter",
        value=False,
        help="Use AI to automatically extract filters from your query.",
    )
    if use_auto_filter:
        st.info("Filters will be extracted from your query automatically.")

    min_area = st.slider(
        "Minimum Burn Area (Hectares)",
        min_value=0,
        max_value=1000,
        value=0,
        step=10,
        help="Filter results to show only burn scars larger than this value.",
        disabled=use_auto_filter,
    )

    st.subheader("Date Range")
    min_date_input = st.date_input(
        "Start Date", value=DEFAULT_START_DATE, disabled=use_auto_filter
    )
    max_date_input = st.date_input(
        "End Date", value=DEFAULT_END_DATE, disabled=use_auto_filter
    )

    st.subheader("Spatial Filter")
    use_spatial = st.checkbox("Filter by Location", disabled=use_auto_filter)

    # Seed the coordinate widgets' state so the region callback can overwrite
    # it before the number inputs are instantiated.
    if "lat_input_widget" not in st.session_state:
        st.session_state.lat_input_widget = 37.0
    if "lon_input_widget" not in st.session_state:
        st.session_state.lon_input_widget = -120.0

    lat_input = st.session_state.lat_input_widget
    lon_input = st.session_state.lon_input_widget
    radius_km = DEFAULT_RADIUS_KM

    if use_spatial:

        def on_region_change():
            """Copy the geocoded centre of the chosen region into the lat/lon inputs."""
            region = st.session_state.selected_region
            if region and region != "Select a Region":
                lat, lon = get_coordinates(region)
                if lat is not None:
                    st.session_state.lat_input_widget = lat
                    st.session_state.lon_input_widget = lon

        st.selectbox(
            "Select State/Region (US/Canada)",
            REGIONS,
            key="selected_region",
            on_change=on_region_change,
        )

        col_lat, col_lon = st.columns(2)
        with col_lat:
            lat_input = st.number_input(
                "Latitude",
                format="%.4f",
                key="lat_input_widget",
                disabled=use_auto_filter,
            )
        with col_lon:
            lon_input = st.number_input(
                "Longitude",
                format="%.4f",
                key="lon_input_widget",
                disabled=use_auto_filter,
            )
        radius_km = st.slider("Radius (km)", 10, 5000, 230, disabled=use_auto_filter)

user_query = st.text_input(
    "Enter your query:", "Help me find the burn scars that have more than 100 hectares"
)

if st.button("Search"):
    if not user_query:
        st.warning("Please enter a query.")
    else:
        with st.spinner("Searching..."):
            query_embeddings = generator.generate_query_embeddings([user_query])[0]

            filter_conditions = []

            if use_auto_filter:
                extracted = extract_filters_from_query(user_query)
                st.success(f"🤖 Extracted Filters: {extracted}")

                # 1. Burn Area
                if extracted.min_burn_area:
                    filter_conditions.append(
                        models.FieldCondition(
                            key="burn_area",
                            range=models.Range(gt=extracted.min_burn_area),
                        )
                    )

                # 2. Date Range
                if extracted.date_range:
                    try:
                        # Defaults if only one is provided
                        ex_start = (
                            extracted.date_range.start
                            or DEFAULT_START_DATE.isoformat()
                        )
                        ex_end = (
                            extracted.date_range.end or DEFAULT_END_DATE.isoformat()
                        )
                        filter_conditions.append(
                            models.FieldCondition(
                                key="acquisition_date",
                                range=models.DatetimeRange(
                                    gte=f"{ex_start}T00:00:00",
                                    lte=f"{ex_end}T23:59:59",
                                ),
                            )
                        )
                    except Exception as e:
                        st.warning(f"Could not parse extracted dates: {e}")

                # 3. Spatial
                if extracted.location:
                    lat, lon = get_coordinates(extracted.location)
                    if lat is not None:
                        st.info(
                            f"📍 Geocoded '{extracted.location}' to ({lat:.4f}, {lon:.4f})"
                        )
                        ex_radius = extracted.radius_km or DEFAULT_RADIUS_KM
                        filter_conditions.extend(
                            bounding_box_conditions(lat, lon, ex_radius)
                        )
            else:
                # Manual Filters
                if min_area > 0:
                    filter_conditions.append(
                        models.FieldCondition(
                            key="burn_area",
                            range=models.Range(gt=min_area),
                        )
                    )

                if min_date_input > max_date_input:
                    # An inverted range matches nothing; tell the user why.
                    st.warning(
                        "Start Date is after End Date; the date filter will "
                        "match no results."
                    )

                start_dt = datetime.combine(min_date_input, datetime.min.time())
                end_dt = datetime.combine(max_date_input, datetime.max.time())
                filter_conditions.append(
                    models.FieldCondition(
                        key="acquisition_date",
                        range=models.DatetimeRange(
                            gte=start_dt.isoformat(), lte=end_dt.isoformat()
                        ),
                    )
                )

                if use_spatial:
                    filter_conditions.extend(
                        bounding_box_conditions(lat_input, lon_input, radius_km)
                    )

            query_filter = None
            if filter_conditions:
                query_filter = models.Filter(must=filter_conditions)

            results = client.query_points(
                collection_name=COLLECTION_NAME,
                query=query_embeddings,  # List of vectors (tokens)
                using="colpali",
                limit=5,
                query_filter=query_filter,
                with_payload=True,
            )
            # Keep the hits across reruns (e.g. those triggered by map interaction).
            st.session_state.search_results = results

if "search_results" in st.session_state:
    results = st.session_state.search_results

    if not results.points:
        st.info("No results found.")
    else:
        col1, col2 = st.columns([1, 2])

        with col1:
            st.subheader("Top Results")
            for i, point in enumerate(results.points):
                payload = point.payload or {}
                score = point.score
                date_str = payload.get("acquisition_date")
                lat = payload.get("centroid_lat")
                lon = payload.get("centroid_lon")
                area = payload.get("burn_area", "N/A")
                display_b64, caption = pick_display_image(payload)

                with st.expander(f"Result {i + 1} (Score: {score:.4f})", expanded=True):
                    st.write(f"**Date:** {date_str}")
                    st.write(f"**Area:** {area} ha")
                    # Points without centroid fields would break the float format.
                    if lat is not None and lon is not None:
                        st.write(f"**Location:** {lat:.4f}, {lon:.4f}")
                    else:
                        st.write("**Location:** N/A")

                    if display_b64:
                        try:
                            image_data = base64.b64decode(display_b64)
                            image = Image.open(BytesIO(image_data))
                            if image.mode in ("I", "I;16"):
                                # NOTE(review): scaling by 255 presumably targets
                                # 0/1 masks; larger values wrap in uint8 - confirm.
                                arr = np.array(image)
                                arr = (arr * 255).astype(np.uint8)
                                image = Image.fromarray(arr).convert("RGB")
                            st.image(
                                image,
                                caption=caption,
                                width="stretch",
                            )
                        except Exception as e:
                            st.error(f"Error loading image: {e}")

        with col2:
            st.subheader("Map Visualization")
            # Only points carrying both centroid coordinates can be placed on the map.
            mappable_points = [
                point
                for point in results.points
                if point.payload
                and point.payload.get("centroid_lat") is not None
                and point.payload.get("centroid_lon") is not None
            ]

            if not mappable_points:
                st.info("No results have coordinates to display on the map.")
            else:
                first_payload = mappable_points[0].payload
                m = folium.Map(
                    location=[
                        first_payload.get("centroid_lat"),
                        first_payload.get("centroid_lon"),
                    ],
                    zoom_start=10,
                )

                for point in mappable_points:
                    payload = point.payload
                    lat = payload.get("centroid_lat")
                    lon = payload.get("centroid_lon")
                    date_str = payload.get("acquisition_date")
                    score = point.score
                    area = payload.get("burn_area", "N/A")
                    display_b64, _ = pick_display_image(payload)

                    # HTML collapses newlines, so separate the fields with <br>.
                    popup_parts = [
                        '<div style="width: 200px">',
                        f"<b>Date:</b> {date_str}<br>",
                        f"<b>Score:</b> {score:.4f}<br>",
                        f"<b>Area:</b> {area} ha<br>",
                    ]
                    if display_b64:
                        # Assumes the stored image is in a browser-renderable
                        # format - TODO confirm against the ingestion pipeline.
                        popup_parts.append(
                            f'<img src="data:image/png;base64,{display_b64}" width="200">'
                        )
                    popup_parts.append("</div>")
                    popup_html = "".join(popup_parts)

                    folium.Marker(
                        [lat, lon],
                        popup=folium.Popup(popup_html, max_width=250),
                        tooltip=f"Result (Score: {score:.4f})",
                    ).add_to(m)

                st_folium(m, width="100%", height=500)