mahimairaja commited on
Commit
0849f9e
Β·
1 Parent(s): 38b5248

feat: updated streamlit app to use structured filters

Browse files
Files changed (1) hide show
  1. app.py +125 -36
app.py CHANGED
@@ -13,6 +13,7 @@ from qdrant_client import QdrantClient, models
13
  from streamlit_folium import st_folium
14
 
15
  from utils.embedding_utils import ColPaliEmbeddingGenerator
 
16
 
17
  load_dotenv()
18
 
@@ -123,6 +124,14 @@ st.markdown(
123
 
124
  with st.sidebar:
125
  st.header("Filters")
 
 
 
 
 
 
 
 
126
 
127
  min_area = st.slider(
128
  "Minimum Burn Area (Hectares)",
@@ -131,13 +140,18 @@ with st.sidebar:
131
  value=0,
132
  step=10,
133
  help="Filter results to show only burn scars larger than this value.",
 
134
  )
135
 
136
  st.subheader("Date Range")
137
- min_date_input = st.date_input("Start Date", value=date(2018, 1, 1))
138
- max_date_input = st.date_input("End Date", value=date(2021, 12, 31))
 
 
 
 
139
  st.subheader("Spatial Filter")
140
- use_spatial = st.checkbox("Filter by Location")
141
 
142
  if "lat_input_widget" not in st.session_state:
143
  st.session_state.lat_input_widget = 37.0
@@ -168,14 +182,20 @@ with st.sidebar:
168
  col_lat, col_lon = st.columns(2)
169
  with col_lat:
170
  lat_input = st.number_input(
171
- "Latitude", format="%.4f", key="lat_input_widget"
 
 
 
172
  )
173
  with col_lon:
174
  lon_input = st.number_input(
175
- "Longitude", format="%.4f", key="lon_input_widget"
 
 
 
176
  )
177
 
178
- radius_km = st.slider("Radius (km)", 10, 5000, 230)
179
 
180
  user_query = st.text_input(
181
  "Enter your query:", "Help me find the burn scars that have more than 100 hectares"
@@ -190,46 +210,115 @@ if st.button("Search"):
190
 
191
  filter_conditions = []
192
 
193
- if min_area > 0:
194
- filter_conditions.append(
195
- models.FieldCondition(
196
- key="burn_area",
197
- range=models.Range(gt=min_area),
 
 
 
 
 
 
198
  )
199
- )
200
-
201
- start_dt = datetime.combine(min_date_input, datetime.min.time())
202
- end_dt = datetime.combine(max_date_input, datetime.max.time())
203
 
204
- filter_conditions.append(
205
- models.FieldCondition(
206
- key="acquisition_date",
207
- range=models.DatetimeRange(
208
- gte=start_dt.isoformat(), lte=end_dt.isoformat()
209
- ),
210
- )
211
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
- if use_spatial:
214
- lat_delta = radius_km / 111.0
215
- lon_delta = radius_km / (111.0 * np.cos(np.radians(lat_input)))
216
 
217
  filter_conditions.append(
218
  models.FieldCondition(
219
- key="centroid_lat",
220
- range=models.Range(
221
- gte=lat_input - lat_delta, lte=lat_input + lat_delta
222
  ),
223
  )
224
  )
225
- filter_conditions.append(
226
- models.FieldCondition(
227
- key="centroid_lon",
228
- range=models.Range(
229
- gte=lon_input - lon_delta, lte=lon_input + lon_delta
230
- ),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  )
232
- )
233
 
234
  query_filter = None
235
  if filter_conditions:
 
13
  from streamlit_folium import st_folium
14
 
15
  from utils.embedding_utils import ColPaliEmbeddingGenerator
16
+ from utils.llm_utils import extract_filters_from_query
17
 
18
  load_dotenv()
19
 
 
124
 
125
  with st.sidebar:
126
  st.header("Filters")
127
+ use_auto_filter = st.toggle(
128
+ "πŸ€– Enable Auto-Filter",
129
+ value=False,
130
+ help="Use AI to automatically extract filters from your query.",
131
+ )
132
+
133
+ if use_auto_filter:
134
+ st.info("Filters will be extracted from your query automatically.")
135
 
136
  min_area = st.slider(
137
  "Minimum Burn Area (Hectares)",
 
140
  value=0,
141
  step=10,
142
  help="Filter results to show only burn scars larger than this value.",
143
+ disabled=use_auto_filter,
144
  )
145
 
146
  st.subheader("Date Range")
147
+ min_date_input = st.date_input(
148
+ "Start Date", value=date(2018, 1, 1), disabled=use_auto_filter
149
+ )
150
+ max_date_input = st.date_input(
151
+ "End Date", value=date(2021, 12, 31), disabled=use_auto_filter
152
+ )
153
  st.subheader("Spatial Filter")
154
+ use_spatial = st.checkbox("Filter by Location", disabled=use_auto_filter)
155
 
156
  if "lat_input_widget" not in st.session_state:
157
  st.session_state.lat_input_widget = 37.0
 
182
  col_lat, col_lon = st.columns(2)
183
  with col_lat:
184
  lat_input = st.number_input(
185
+ "Latitude",
186
+ format="%.4f",
187
+ key="lat_input_widget",
188
+ disabled=use_auto_filter,
189
  )
190
  with col_lon:
191
  lon_input = st.number_input(
192
+ "Longitude",
193
+ format="%.4f",
194
+ key="lon_input_widget",
195
+ disabled=use_auto_filter,
196
  )
197
 
198
+ radius_km = st.slider("Radius (km)", 10, 5000, 230, disabled=use_auto_filter)
199
 
200
  user_query = st.text_input(
201
  "Enter your query:", "Help me find the burn scars that have more than 100 hectares"
 
210
 
211
  filter_conditions = []
212
 
213
+ if use_auto_filter:
214
+ extracted = extract_filters_from_query(user_query)
215
+ st.success(f"πŸ€– Extracted Filters: {extracted}")
216
+
217
+ # 1. Burn Area
218
+ if extracted.min_burn_area:
219
+ filter_conditions.append(
220
+ models.FieldCondition(
221
+ key="burn_area",
222
+ range=models.Range(gt=extracted.min_burn_area),
223
+ )
224
  )
 
 
 
 
225
 
226
+ # 2. Date Range
227
+ if extracted.date_range:
228
+ try:
229
+ ex_start = extracted.date_range.start
230
+ ex_end = extracted.date_range.end
231
+
232
+ # Defaults if only one is provided
233
+ if not ex_start:
234
+ ex_start = "2018-01-01"
235
+ if not ex_end:
236
+ ex_end = "2021-12-31"
237
+
238
+ filter_conditions.append(
239
+ models.FieldCondition(
240
+ key="acquisition_date",
241
+ range=models.DatetimeRange(
242
+ gte=f"{ex_start}T00:00:00", lte=f"{ex_end}T23:59:59"
243
+ ),
244
+ )
245
+ )
246
+ except Exception as e:
247
+ st.warning(f"Could not parse extracted dates: {e}")
248
+
249
+ # 3. Spatial
250
+ if extracted.location:
251
+ lat, lon = get_coordinates(extracted.location)
252
+ if lat is not None:
253
+ st.info(
254
+ f"πŸ“ Geocoded '{extracted.location}' to ({lat:.4f}, {lon:.4f})"
255
+ )
256
+
257
+ ex_radius = (
258
+ extracted.radius_km if extracted.radius_km else 100.0
259
+ )
260
+
261
+ lat_delta = ex_radius / 111.0
262
+ lon_delta = ex_radius / (111.0 * np.cos(np.radians(lat)))
263
+
264
+ filter_conditions.append(
265
+ models.FieldCondition(
266
+ key="centroid_lat",
267
+ range=models.Range(
268
+ gte=lat - lat_delta, lte=lat + lat_delta
269
+ ),
270
+ )
271
+ )
272
+ filter_conditions.append(
273
+ models.FieldCondition(
274
+ key="centroid_lon",
275
+ range=models.Range(
276
+ gte=lon - lon_delta, lte=lon + lon_delta
277
+ ),
278
+ )
279
+ )
280
+ else:
281
+ # Manual Filters
282
+ if min_area > 0:
283
+ filter_conditions.append(
284
+ models.FieldCondition(
285
+ key="burn_area",
286
+ range=models.Range(gt=min_area),
287
+ )
288
+ )
289
 
290
+ start_dt = datetime.combine(min_date_input, datetime.min.time())
291
+ end_dt = datetime.combine(max_date_input, datetime.max.time())
 
292
 
293
  filter_conditions.append(
294
  models.FieldCondition(
295
+ key="acquisition_date",
296
+ range=models.DatetimeRange(
297
+ gte=start_dt.isoformat(), lte=end_dt.isoformat()
298
  ),
299
  )
300
  )
301
+
302
+ if use_spatial:
303
+ lat_delta = radius_km / 111.0
304
+ lon_delta = radius_km / (111.0 * np.cos(np.radians(lat_input)))
305
+
306
+ filter_conditions.append(
307
+ models.FieldCondition(
308
+ key="centroid_lat",
309
+ range=models.Range(
310
+ gte=lat_input - lat_delta, lte=lat_input + lat_delta
311
+ ),
312
+ )
313
+ )
314
+ filter_conditions.append(
315
+ models.FieldCondition(
316
+ key="centroid_lon",
317
+ range=models.Range(
318
+ gte=lon_input - lon_delta, lte=lon_input + lon_delta
319
+ ),
320
+ )
321
  )
 
322
 
323
  query_filter = None
324
  if filter_conditions: