import json
import logging
import os
import re
from typing import Optional

from huggingface_hub import InferenceClient
from pydantic import BaseModel, Field, ValidationError

logger = logging.getLogger(__name__)

# Matches a reply wrapped in a Markdown code fence, with or without a
# language tag (```json ... ```, ```JSON ... ```, ``` ... ```), and captures
# the fenced payload without surrounding whitespace.
_CODE_FENCE_RE = re.compile(r"^```[\w-]*\s*(.*?)\s*```$", re.DOTALL)


# NOTE: the models below are documented with comments rather than docstrings
# on purpose: pydantic copies a model's docstring into model_json_schema(),
# and that schema is embedded verbatim in the LLM system prompt.


# Optional date window. Both bounds are plain "YYYY-MM-DD" strings; they are
# not parsed or validated as dates here.
class DateRange(BaseModel):
    start: Optional[str] = Field(
        None, description="Start date in YYYY-MM-DD format. e.g. 2020-01-01"
    )
    end: Optional[str] = Field(
        None, description="End date in YYYY-MM-DD format. e.g. 2020-12-31"
    )


# Structured filters extracted from a user's query. Every field is optional
# and None means "not specified". The field descriptions are part of the JSON
# schema shown to the model, so edit them with the prompt in mind.
class SearchFilters(BaseModel):
    min_burn_area: Optional[float] = Field(
        None, description="Minimum burn area in hectares. e.g. 100"
    )
    date_range: Optional[DateRange] = Field(
        None, description="Date range for the search."
    )
    location: Optional[str] = Field(
        None, description="Geographic location or region name. e.g. California, USA"
    )
    # The "100" below is only a hint to the LLM; the field's real default is None.
    radius_km: Optional[float] = Field(
        None,
        description="Search radius in kilometers around the location. Default to 100 if vague.",
    )


def _build_system_prompt() -> str:
    """Return the system prompt that embeds the ``SearchFilters`` JSON schema."""
    schema = json.dumps(SearchFilters.model_json_schema(), indent=2)
    return f"""
You are a helpful assistant that extracts search filters from a natural language query about burn scars/fires.
You must return a VALID JSON object that matches the following schema:
{schema}
Do not add any markdown formatting (like ```json). Just return the raw JSON string.
If a field is not mentioned or cannot be inferred, leave it as null.
For 'location', try to catch specific state or region names.
For 'date_range', if the user says '2020', infer start=2020-01-01 and end=2020-12-31.
"""


def _parse_json_reply(content: str) -> object:
    """Decode the JSON value in a model reply, tolerating Markdown fences.

    Args:
        content: Raw text returned by the chat model.

    Returns:
        The decoded JSON value (normally a dict, but not guaranteed).

    Raises:
        json.JSONDecodeError: If no valid JSON can be decoded from the reply.
    """
    text = content.strip()
    fenced = _CODE_FENCE_RE.match(text)
    if fenced:
        text = fenced.group(1)
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Models sometimes wrap the object in prose; retry on the outermost
        # {...} span before giving up.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        return json.loads(text[start : end + 1])


def extract_filters_from_query(
    query: str, model: str = "Qwen/Qwen2.5-72B-Instruct"
) -> SearchFilters:
    """Extract structured search filters from a natural-language query.

    Sends ``query`` to a chat model through the Hugging Face Inference
    Client, asking for JSON that matches the ``SearchFilters`` schema, then
    parses and validates the reply.

    This is best-effort: if ``HF_TOKEN`` is unset, the inference request
    fails, or the reply cannot be parsed/validated, the problem is logged and
    an empty ``SearchFilters()`` (all fields None) is returned instead of
    raising.

    Args:
        query: Free-text search request from the user.
        model: Hugging Face model id used for the chat completion.

    Returns:
        The extracted filters, or an empty ``SearchFilters`` on any failure.
    """
    api_key = os.getenv("HF_TOKEN")
    if not api_key:
        logger.warning("HF_TOKEN not found in environment variables.")
        return SearchFilters()

    client = InferenceClient(api_key=api_key)
    messages = [
        {"role": "system", "content": _build_system_prompt()},
        {"role": "user", "content": query},
    ]

    # JSON mode (response_format) is not requested, so the reply is parsed
    # leniently by _parse_json_reply below.
    try:
        response = client.chat_completion(
            model=model,
            messages=messages,
            max_tokens=500,
            temperature=0.1,  # low temperature keeps the extraction stable
        )
        # content may be None; map it to "" so it is reported below as a
        # parse failure instead of crashing on .strip().
        content = response.choices[0].message.content or ""
    except Exception as err:  # best-effort boundary around the remote call
        logger.error("Error extracting filters: %s", err)
        return SearchFilters()

    logger.debug("LLM response content: %r", content)
    try:
        # model_validate (unlike SearchFilters(**data)) rejects non-object
        # JSON such as null or [] with a ValidationError instead of TypeError.
        return SearchFilters.model_validate(_parse_json_reply(content))
    except (json.JSONDecodeError, ValidationError) as err:
        logger.error("Error extracting filters: %s", err)
        return SearchFilters()