mahimairaja's picture
feat: added streamlit app
b97f4ae
raw
history blame
10.5 kB
import base64
import os
from datetime import date, datetime
from io import BytesIO
import folium
import numpy as np
import streamlit as st
from dotenv import load_dotenv
from geopy.geocoders import Nominatim
from PIL import Image
from qdrant_client import QdrantClient, models
from streamlit_folium import st_folium
from utils.embedding_utils import ColPaliEmbeddingGenerator
# Read key/value pairs from a local .env file into the process environment.
load_dotenv()

# Browser-tab title, icon and full-width layout for the app page.
st.set_page_config(
    page_title="Geo-Spatial Vector Search with Qdrant",
    page_icon="🌍",
    layout="wide",
)
# Name of the Qdrant collection queried by this app.
COLLECTION_NAME = "hls_burn_scars_data_colpali_rgb"

# US states, in the order they are listed in the region dropdown.
_US_STATES = (
    "Alabama",
    "Alaska",
    "Arizona",
    "Arkansas",
    "California",
    "Colorado",
    "Connecticut",
    "Delaware",
    "Florida",
    "Georgia",
    "Hawaii",
    "Idaho",
    "Illinois",
    "Indiana",
    "Iowa",
    "Kansas",
    "Kentucky",
    "Louisiana",
    "Maine",
    "Maryland",
    "Massachusetts",
    "Michigan",
    "Minnesota",
    "Mississippi",
    "Missouri",
    "Montana",
    "Nebraska",
    "Nevada",
    "New Hampshire",
    "New Jersey",
    "New Mexico",
    "New York",
    "North Carolina",
    "North Dakota",
    "Ohio",
    "Oklahoma",
    "Oregon",
    "Pennsylvania",
    "Rhode Island",
    "South Carolina",
    "South Dakota",
    "Tennessee",
    "Texas",
    "Utah",
    "Vermont",
    "Virginia",
    "Washington",
    "West Virginia",
    "Wisconsin",
    "Wyoming",
)

# Canadian provinces, in the order they are listed in the region dropdown.
_CANADIAN_PROVINCES = (
    "Alberta",
    "British Columbia",
    "Manitoba",
    "New Brunswick",
    "Newfoundland and Labrador",
    "Nova Scotia",
    "Ontario",
    "Prince Edward Island",
    "Quebec",
    "Saskatchewan",
)

# Dropdown options: a placeholder entry first, then "<name>, <country>"
# labels for every state and province above.
REGIONS = [
    "Select a Region",
    *(f"{state}, USA" for state in _US_STATES),
    *(f"{province}, Canada" for province in _CANADIAN_PROVINCES),
]
@st.cache_resource
def get_qdrant_client():
    """Create the Qdrant client used by the app.

    The endpoint and API key are taken from the ``QDRANT_URL`` and
    ``QDRANT_API_KEY`` environment variables; either may be unset, in which
    case ``None`` is passed through to ``QdrantClient``.
    """
    endpoint = os.getenv("QDRANT_URL")
    secret = os.getenv("QDRANT_API_KEY")
    return QdrantClient(url=endpoint, api_key=secret)
@st.cache_resource
def get_embedding_model():
    """Instantiate the ColPali embedding generator while showing a spinner."""
    loading_message = "Loading ColPali model... (this may take a minute)"
    with st.spinner(loading_message):
        return ColPaliEmbeddingGenerator()
@st.cache_data
def _lookup_coordinates(place_name):
    """Query Nominatim for *place_name* and return ``(lat, lon)``.

    Returns ``(None, None)`` when the service answers but finds no match.
    Service errors propagate to the caller; because the call then never
    returns, nothing is stored in the Streamlit cache for that name and a
    later attempt can still succeed.
    """
    geolocator = Nominatim(user_agent="geo_spatial_chat_app")
    location = geolocator.geocode(place_name)
    if location:
        return location.latitude, location.longitude
    return None, None


def get_coordinates(place_name):
    """Geocode a place name to (lat, lon).

    Args:
        place_name: Free-text location, e.g. ``"California, USA"``.

    Returns:
        ``(latitude, longitude)`` on success, or ``(None, None)`` when the
        name is empty, no match is found, or the lookup fails. Failures are
        reported in the UI via ``st.error`` instead of being raised.
    """
    if not place_name:
        return None, None
    try:
        return _lookup_coordinates(place_name)
    except Exception as e:
        # UI boundary: surface any geocoder failure to the user and carry on.
        # Handling it out here (not inside the cached helper) keeps a
        # transient failure from being cached as a permanent "not found".
        st.error(f"Error geocoding {place_name}: {e}")
        return None, None


# Keep the cache-reset handle available on the public name, as it was when
# this function itself carried the ``st.cache_data`` decorator.
get_coordinates.clear = _lookup_coordinates.clear
# Shared handles used by the search code later in this script.
client = get_qdrant_client()
generator = get_embedding_model()

# Page heading followed by a one-line description of the app.
st.title("Geo-Spatial Vector Search with Qdrant")
tagline = "Find burn scars using natural language queries with Multi-Vector Retrieval."
st.markdown(tagline)
# Sidebar: collects the area, date and (optional) location filters that the
# search handler reads as min_area, min_date_input, max_date_input,
# use_spatial, lat_input, lon_input and radius_km.
with st.sidebar:
    st.header("Filters")

    min_area = st.slider(
        "Minimum Burn Area (Hectares)",
        min_value=0,
        max_value=1000,
        value=0,
        step=10,
        help="Filter results to show only burn scars larger than this value.",
    )

    st.subheader("Date Range")
    min_date_input = st.date_input("Start Date", value=date(2018, 1, 1))
    max_date_input = st.date_input("End Date", value=date(2021, 12, 31))

    st.subheader("Spatial Filter")
    use_spatial = st.checkbox("Filter by Location")

    # Seed the coordinate widgets' session-state slots before the widgets
    # exist, so the region callback below can overwrite them.
    if "lat_input_widget" not in st.session_state:
        st.session_state.lat_input_widget = 37.0
    if "lon_input_widget" not in st.session_state:
        st.session_state.lon_input_widget = -120.0

    # Defaults used when the spatial filter is switched off.
    lat_input = st.session_state.lat_input_widget
    lon_input = st.session_state.lon_input_widget
    radius_km = 100.0

    if use_spatial:

        def on_region_change():
            """Move the coordinate inputs to the geocoded selected region."""
            region = st.session_state.selected_region
            if region and region != "Select a Region":
                lat, lon = get_coordinates(region)
                if lat is not None:
                    st.session_state.lat_input_widget = lat
                    st.session_state.lon_input_widget = lon

        st.selectbox(
            "Select State/Region (US/Canada)",
            REGIONS,
            key="selected_region",
            on_change=on_region_change,
        )

        col_lat, col_lon = st.columns(2)
        with col_lat:
            # Bounded to valid latitudes so impossible coordinates cannot
            # reach the bounding-box calculation in the search step.
            lat_input = st.number_input(
                "Latitude",
                min_value=-90.0,
                max_value=90.0,
                format="%.4f",
                key="lat_input_widget",
            )
        with col_lon:
            # Bounded to valid longitudes for the same reason.
            lon_input = st.number_input(
                "Longitude",
                min_value=-180.0,
                max_value=180.0,
                format="%.4f",
                key="lon_input_widget",
            )

        radius_km = st.slider("Radius (km)", 10, 5000, 230)
def _build_search_filter(min_area, start_date, end_date, center=None, radius_km=0.0):
    """Turn the sidebar selections into a Qdrant payload filter.

    Args:
        min_area: Burn-area threshold; values <= 0 add no area condition.
        start_date: First acquisition date to include (from 00:00:00).
        end_date: Last acquisition date to include (through end of day).
        center: Optional ``(lat, lon)`` in degrees. When given, results are
            restricted to a latitude/longitude box around that point.
        radius_km: Half-size of that box in kilometres.

    Returns:
        A ``models.Filter`` whose conditions must all match.
    """
    conditions = []

    if min_area > 0:
        conditions.append(
            models.FieldCondition(
                key="burn_area",
                range=models.Range(gt=min_area),
            )
        )

    # Expand the two calendar dates to cover the whole first and last day.
    start_dt = datetime.combine(start_date, datetime.min.time())
    end_dt = datetime.combine(end_date, datetime.max.time())
    conditions.append(
        models.FieldCondition(
            key="acquisition_date",
            range=models.DatetimeRange(
                gte=start_dt.isoformat(), lte=end_dt.isoformat()
            ),
        )
    )

    if center is not None:
        lat, lon = center
        # Radius -> degrees, using ~111 km per degree of latitude.
        lat_delta = radius_km / 111.0
        # A degree of longitude shrinks with cos(latitude). Flooring the
        # cosine and capping the span at 180 degrees keeps the range
        # well-formed at or beyond the poles, where the raw division would
        # explode or flip sign and produce an inverted (always empty) range.
        cos_lat = max(float(np.cos(np.radians(lat))), 1e-6)
        lon_delta = min(radius_km / (111.0 * cos_lat), 180.0)
        # NOTE: this is a plain box on centroid_lon; it does not wrap across
        # the +/-180 degree meridian.
        conditions.append(
            models.FieldCondition(
                key="centroid_lat",
                range=models.Range(gte=lat - lat_delta, lte=lat + lat_delta),
            )
        )
        conditions.append(
            models.FieldCondition(
                key="centroid_lon",
                range=models.Range(gte=lon - lon_delta, lte=lon + lon_delta),
            )
        )

    return models.Filter(must=conditions)


user_query = st.text_input(
    "Enter your query:", "Help me find the burn scars that have more than 100 hectares"
)

# Run a search on click and stash the response in session state so the
# results section can keep rendering it on later reruns.
if st.button("Search"):
    if not user_query or not user_query.strip():
        st.warning("Please enter a query.")
    elif min_date_input > max_date_input:
        # An inverted range can never match; say so instead of silently
        # returning "No results found."
        st.warning("Start Date must not be after End Date.")
    else:
        with st.spinner("Searching..."):
            # The generator takes a batch; take the embedding of the single query.
            query_embeddings = generator.generate_query_embeddings([user_query])[0]

            query_filter = _build_search_filter(
                min_area,
                min_date_input,
                max_date_input,
                center=(lat_input, lon_input) if use_spatial else None,
                radius_km=radius_km,
            )

            results = client.query_points(
                collection_name=COLLECTION_NAME,
                query=query_embeddings,  # List of vectors (tokens)
                using="colpali",
                limit=5,
                query_filter=query_filter,
                with_payload=True,
            )
            st.session_state.search_results = results
def _decode_preview(encoded):
    """Decode a base64-encoded payload image into a displayable PIL image.

    Integer-mode images ("I" / "I;16") are multiplied by 255, cast to 8-bit
    and converted to RGB; every other mode is returned as opened.
    """
    picture = Image.open(BytesIO(base64.b64decode(encoded)))
    if picture.mode in ("I", "I;16"):
        # NOTE(review): the x255 scaling presumably targets 0/1 mask values;
        # larger pixel values would wrap in the uint8 cast — confirm.
        pixels = (np.array(picture) * 255).astype(np.uint8)
        picture = Image.fromarray(pixels).convert("RGB")
    return picture


def _has_coordinates(point):
    """Tell whether a result's payload carries both centroid coordinates."""
    payload = point.payload or {}
    return (
        payload.get("centroid_lat") is not None
        and payload.get("centroid_lon") is not None
    )


# Render the most recent search response: result cards on the left and a
# map with one marker per locatable result on the right.
if "search_results" in st.session_state:
    results = st.session_state.search_results

    if not results.points:
        st.info("No results found.")
    else:
        col1, col2 = st.columns([1, 2])

        with col1:
            st.subheader("Top Results")
            for i, point in enumerate(results.points):
                payload = point.payload or {}
                score = point.score
                date_str = payload.get("acquisition_date")
                lat = payload.get("centroid_lat")
                lon = payload.get("centroid_lon")
                area = payload.get("burn_area", "N/A")
                image_b64 = payload.get("image_base64")
                rgb_image_b64 = payload.get("rgb_image_base64")

                with st.expander(f"Result {i + 1} (Score: {score:.4f})", expanded=True):
                    st.write(f"**Date:** {date_str}")
                    st.write(f"**Area:** {area} ha")
                    # Formatting None with :.4f raises, so fall back to a
                    # placeholder when a coordinate is missing.
                    if lat is None or lon is None:
                        st.write("**Location:** N/A")
                    else:
                        st.write(f"**Location:** {lat:.4f}, {lon:.4f}")

                    # Prefer the RGB rendering; fall back to the other image.
                    display_b64 = rgb_image_b64 if rgb_image_b64 else image_b64
                    caption = "RGB Image" if rgb_image_b64 else "Burn Scar Mask"

                    if display_b64:
                        try:
                            image = _decode_preview(display_b64)
                            st.image(
                                image,
                                caption=caption,
                                width="stretch",
                            )
                        except Exception as e:
                            # A bad image should not take down the whole card.
                            st.error(f"Error loading image: {e}")

        with col2:
            st.subheader("Map Visualization")

            # Only results with both coordinates can be centred on or marked.
            mappable = [p for p in results.points if _has_coordinates(p)]
            if not mappable:
                st.info("None of the results include coordinates to map.")
            else:
                first = mappable[0].payload
                m = folium.Map(
                    location=[first["centroid_lat"], first["centroid_lon"]],
                    zoom_start=10,
                )

                for point in mappable:
                    payload = point.payload
                    lat = payload["centroid_lat"]
                    lon = payload["centroid_lon"]
                    date_str = payload.get("acquisition_date")
                    score = point.score
                    area = payload.get("burn_area", "N/A")
                    image_b64 = payload.get("image_base64")
                    rgb_image_b64 = payload.get("rgb_image_base64")

                    popup_html = (
                        "\n"
                        '<div style="width:200px">\n'
                        f"<b>Date:</b> {date_str}<br>\n"
                        f"<b>Score:</b> {score:.4f}<br>\n"
                        f"<b>Area:</b> {area} ha<br>\n"
                    )

                    display_b64 = rgb_image_b64 if rgb_image_b64 else image_b64
                    if display_b64:
                        # The stored base64 is embedded directly as a data URI.
                        popup_html += f'<img src="data:image/png;base64,{display_b64}" style="width:100%;margin-top:10px;">'
                    popup_html += "</div>"

                    folium.Marker(
                        [lat, lon],
                        popup=folium.Popup(popup_html, max_width=250),
                        tooltip=f"Result (Score: {score:.4f})",
                    ).add_to(m)

                st_folium(m, width="100%", height=500)