mtyrrell committed
Commit 335202a · 1 Parent(s): c245449

routing changes

Files changed (1)
  1. app/main.py +142 -91
app/main.py CHANGED
@@ -16,6 +16,7 @@ from contextlib import asynccontextmanager
 import threading
 from langchain_core.runnables import RunnableLambda
 import tempfile
+import mimetypes
 
 from utils import getconfig
 
@@ -23,11 +24,53 @@ config = getconfig("params.cfg")
 RETRIEVER = config.get("retriever", "RETRIEVER", fallback="https://giz-chatfed-retriever.hf.space")
 GENERATOR = config.get("generator", "GENERATOR", fallback="https://giz-chatfed-generator.hf.space")
 INGESTOR = config.get("ingestor", "INGESTOR", fallback="https://mtyrrell-chatfed-ingestor.hf.space")
+GEOJSON_INGESTOR = config.get("ingestor", "GEOJSON_INGESTOR", fallback="https://giz-eudr-chatfed-ingestor.hf.space")
 MAX_CONTEXT_CHARS = config.get("general", "MAX_CONTEXT_CHARS")
 
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
+# File type detection
+def detect_file_type(filename: str, file_content: bytes = None) -> str:
+    """Detect file type based on extension and content"""
+    if not filename:
+        return "unknown"
+
+    # Get file extension
+    _, ext = os.path.splitext(filename.lower())
+
+    # Define file type mappings
+    file_type_mappings = {
+        '.geojson': 'geojson',
+        '.json': 'json',  # Could be geojson, will check content
+        '.pdf': 'text',
+        '.docx': 'text',
+        '.doc': 'text',
+        '.txt': 'text',
+        '.md': 'text',
+        '.csv': 'text',
+        '.xlsx': 'text',
+        '.xls': 'text'
+    }
+
+    detected_type = file_type_mappings.get(ext, 'unknown')
+
+    # For JSON files, check if it's actually GeoJSON
+    if detected_type == 'json' and file_content:
+        try:
+            import json
+            content_str = file_content.decode('utf-8')
+            data = json.loads(content_str)
+            # Check if it has GeoJSON structure
+            if isinstance(data, dict) and ('type' in data and data.get('type') == 'FeatureCollection'):
+                detected_type = 'geojson'
+            elif isinstance(data, dict) and ('type' in data and data.get('type') in ['Feature', 'Point', 'LineString', 'Polygon', 'MultiPoint', 'MultiLineString', 'MultiPolygon', 'GeometryCollection']):
+                detected_type = 'geojson'
+        except:
+            pass  # Keep as json if parsing fails
+
+    logger.info(f"Detected file type: {detected_type} for file: {filename}")
+    return detected_type
 
 # Models
 class GraphState(TypedDict):
@@ -42,6 +85,8 @@ class GraphState(TypedDict):
     file_content: Optional[bytes]
     filename: Optional[str]
    metadata: Optional[Dict[str, Any]]
+    file_type: Optional[str]
+    workflow_type: Optional[str]  # 'standard' or 'geojson_direct'
 
 class ChatFedInput(TypedDict):
     query: str
@@ -61,9 +106,38 @@ class ChatFedOutput(TypedDict):
 class ChatUIInput(BaseModel):
     text: str
 
+# File type detection node
+def detect_file_type_node(state: GraphState) -> GraphState:
+    """Detect file type and determine workflow"""
+    file_type = "unknown"
+    workflow_type = "standard"
+
+    if state.get("file_content") and state.get("filename"):
+        file_type = detect_file_type(state["filename"], state["file_content"])
+
+        # Determine workflow based on file type
+        if file_type == "geojson":
+            workflow_type = "geojson_direct"
+        else:
+            workflow_type = "standard"
+
+    logger.info(f"File type: {file_type}, Workflow: {workflow_type}")
+
+    metadata = state.get("metadata", {})
+    metadata.update({
+        "file_type": file_type,
+        "workflow_type": workflow_type
+    })
+
+    return {
+        "file_type": file_type,
+        "workflow_type": workflow_type,
+        "metadata": metadata
+    }
+
 # Module functions
 def ingest_node(state: GraphState) -> GraphState:
-    """Process file through ingestor if file is provided"""
+    """Process file through appropriate ingestor based on file type"""
     start_time = datetime.now()
 
     # If no file provided, skip this step
@@ -71,10 +145,19 @@ def ingest_node(state: GraphState) -> GraphState:
         logger.info("No file provided, skipping ingestion")
         return {"ingestor_context": "", "metadata": state.get("metadata", {})}
 
-    logger.info(f"Ingesting file: {state['filename']}")
+    file_type = state.get("file_type", "unknown")
+    logger.info(f"Ingesting {file_type} file: {state['filename']}")
 
     try:
-        client = Client(INGESTOR)
+        # Choose ingestor based on file type
+        if file_type == "geojson":
+            ingestor_url = GEOJSON_INGESTOR
+            logger.info(f"Using GeoJSON ingestor: {ingestor_url}")
+        else:
+            ingestor_url = INGESTOR
+            logger.info(f"Using standard ingestor: {ingestor_url}")
+
+        client = Client(ingestor_url)
 
         # Create a temporary file to upload
         with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
@@ -82,9 +165,9 @@ def ingest_node(state: GraphState) -> GraphState:
             tmp_file_path = tmp_file.name
 
         try:
-            # Call the ingestor's ingest endpoint - use gradio_client.file() for proper formatting
+            # Call the ingestor's ingest endpoint
            ingestor_context = client.predict(
-                file(tmp_file_path),  # Use gradio_client.file() to properly format
+                file(tmp_file_path),
                 api_name="/ingest"
             )
 
@@ -103,7 +186,8 @@ def ingest_node(state: GraphState) -> GraphState:
         metadata.update({
             "ingestion_duration": duration,
             "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
-            "ingestion_success": True
+            "ingestion_success": True,
+            "ingestor_used": ingestor_url
         })
 
         return {
@@ -122,52 +206,26 @@ def ingest_node(state: GraphState) -> GraphState:
             "ingestion_error": str(e)
         })
         return {"ingestor_context": "", "metadata": metadata}
+
+def geojson_direct_result_node(state: GraphState) -> GraphState:
+    """For GeoJSON files, return ingestor results directly without retrieval/generation"""
+    logger.info("Processing GeoJSON file - returning direct results")
 
+    ingestor_context = state.get("ingestor_context", "")
+
+    # For GeoJSON files, the ingestor result is the final result
+    result = ingestor_context if ingestor_context else "No results from GeoJSON processing."
+
+    metadata = state.get("metadata", {})
+    metadata.update({
+        "processing_type": "geojson_direct",
+        "result_length": len(result)
+    })
+
+    return {
+        "result": result,
+        "metadata": metadata
+    }
-    try:
-        client = Client(INGESTOR)
-
-        # Create a temporary file to upload
-        with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(state["filename"])[1]) as tmp_file:
-            tmp_file.write(state["file_content"])
-            tmp_file_path = tmp_file.name
-
-        try:
-            # Call the ingestor's ingest endpoint - returns context directly
-            ingestor_context = client.predict(
-                file=tmp_file_path,
-                api_name="/ingest"
-            )
-
-            logger.info(f"Ingest result length: {len(ingestor_context) if ingestor_context else 0}")
-
-        finally:
-            # Clean up temporary file
-            os.unlink(tmp_file_path)
-
-        duration = (datetime.now() - start_time).total_seconds()
-        metadata = state.get("metadata", {})
-        metadata.update({
-            "ingestion_duration": duration,
-            "ingestor_context_length": len(ingestor_context) if ingestor_context else 0,
-            "ingestion_success": True
-        })
-
-        return {
-            "ingestor_context": ingestor_context,
-            "metadata": metadata
-        }
-
-    except Exception as e:
-        duration = (datetime.now() - start_time).total_seconds()
-        logger.error(f"Ingestion failed: {str(e)}")
-
-        metadata = state.get("metadata", {})
-        metadata.update({
-            "ingestion_duration": duration,
-            "ingestion_success": False,
-            "ingestion_error": str(e)
-        })
-        return {"ingestor_context": "", "metadata": metadata}
 
 def retrieve_node(state: GraphState) -> GraphState:
     start_time = datetime.now()
@@ -260,15 +318,41 @@ def generate_node(state: GraphState) -> GraphState:
         })
         return {"result": f"Error: {str(e)}", "metadata": metadata}
 
-# Updated graph with ingest node
+# Conditional routing function
+def route_workflow(state: GraphState) -> str:
+    """Route to appropriate workflow based on file type"""
+    workflow_type = state.get("workflow_type", "standard")
+    return workflow_type
+
+# Updated graph with conditional routing
 workflow = StateGraph(GraphState)
+workflow.add_node("detect_file_type", detect_file_type_node)
 workflow.add_node("ingest", ingest_node)
+workflow.add_node("geojson_direct", geojson_direct_result_node)
 workflow.add_node("retrieve", retrieve_node)
 workflow.add_node("generate", generate_node)
-workflow.add_edge(START, "ingest")
-workflow.add_edge("ingest", "retrieve")
+
+# Add edges
+workflow.add_edge(START, "detect_file_type")
+workflow.add_edge("detect_file_type", "ingest")
+
+# Conditional routing after ingestion
+workflow.add_conditional_edges(
+    "ingest",
+    route_workflow,
+    {
+        "geojson_direct": "geojson_direct",
+        "standard": "retrieve"
+    }
+)
+
+# Standard workflow
 workflow.add_edge("retrieve", "generate")
 workflow.add_edge("generate", END)
+
+# GeoJSON direct workflow
+workflow.add_edge("geojson_direct", END)
+
 compiled_graph = workflow.compile()
 
 def process_query_core(
@@ -299,6 +383,8 @@ def process_query_core(
         "year_filter": year_filter or "",
         "file_content": file_content,
         "filename": filename,
+        "file_type": "unknown",
+        "workflow_type": "standard",
         "metadata": {
             "session_id": session_id,
             "user_id": user_id,
@@ -404,12 +490,12 @@ def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
 def create_gradio_interface():
     with gr.Blocks(title="ChatFed Orchestrator") as demo:
         gr.Markdown("# ChatFed Orchestrator")
-        gr.Markdown("Upload documents (PDF/DOCX) alongside your queries for enhanced context. MCP endpoints available at `/gradio_api/mcp/sse`")
+        gr.Markdown("Upload documents (PDF/DOCX/GeoJSON) alongside your queries for enhanced context. MCP endpoints available at `/gradio_api/mcp/sse`")
 
         with gr.Row():
             with gr.Column():
                 query_input = gr.Textbox(label="Query", lines=2, placeholder="Enter your question...")
-                file_input = gr.File(label="Upload Document (PDF/DOCX)", file_types=[".pdf", ".docx"])
+                file_input = gr.File(label="Upload Document (PDF/DOCX/GeoJSON)", file_types=[".pdf", ".docx", ".geojson", ".json"])
 
                 with gr.Accordion("Filters (Optional)", open=False):
                     reports_filter_input = gr.Textbox(label="Reports Filter", placeholder="e.g., annual_reports")
@@ -496,41 +582,6 @@ async def chatfed_with_file(
 
     return ChatFedOutput(result=result["result"], metadata=result["metadata"])
 
-# Additional endpoint for file uploads via API
-@app.post("/chatfed-with-file")
-async def chatfed_with_file(
-    query: str = Form(...),
-    file: Optional[UploadFile] = File(None),
-    reports_filter: Optional[str] = Form(""),
-    sources_filter: Optional[str] = Form(""),
-    subtype_filter: Optional[str] = Form(""),
-    year_filter: Optional[str] = Form(""),
-    session_id: Optional[str] = Form(None),
-    user_id: Optional[str] = Form(None)
-):
-    """Endpoint for queries with optional file attachments"""
-    file_content = None
-    filename = None
-
-    if file:
-        file_content = await file.read()
-        filename = file.filename
-
-    result = process_query_core(
-        query=query,
-        reports_filter=reports_filter,
-        sources_filter=sources_filter,
-        subtype_filter=subtype_filter,
-        year_filter=year_filter,
-        file_content=file_content,
-        filename=filename,
-        session_id=session_id,
-        user_id=user_id,
-        return_metadata=True
-    )
-
-    return ChatFedOutput(result=result["result"], metadata=result["metadata"])
-
 # LangServe routes (these are the main endpoints)
 add_routes(
     app,
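
For orientation, the sketch below illustrates the conditional-routing pattern this commit wires into the graph: a detection step tags the state, and add_conditional_edges sends GeoJSON uploads from "ingest" straight to a direct-result node, while everything else continues through "retrieve" and "generate". It is a minimal, self-contained example assuming langgraph is installed; the graph node names mirror the diff, but the stub functions here are stand-ins rather than the real ingestor, retriever, and generator calls.

from typing import Optional, TypedDict

from langgraph.graph import StateGraph, START, END

# Toy state carrying only the fields the routing decision depends on.
class ToyState(TypedDict):
    filename: Optional[str]
    workflow_type: str
    result: str

def detect_stub(state: ToyState) -> ToyState:
    # Stand-in for detect_file_type_node: extension check only.
    name = (state.get("filename") or "").lower()
    is_geo = name.endswith(".geojson") or name.endswith(".json")
    return {"workflow_type": "geojson_direct" if is_geo else "standard"}

def ingest_stub(state: ToyState) -> ToyState:
    # Stand-in for the gradio_client call to the (GeoJSON) ingestor.
    return {"result": "(ingestor output)"}

def geojson_direct_stub(state: ToyState) -> ToyState:
    # Mirrors geojson_direct_result_node: the ingestor output is the final result.
    return {"result": state["result"] + " returned directly"}

def retrieve_stub(state: ToyState) -> ToyState:
    return {"result": state["result"] + " + retrieved context"}

def generate_stub(state: ToyState) -> ToyState:
    return {"result": state["result"] + " + generated answer"}

def route_workflow(state: ToyState) -> str:
    # Same routing rule as the commit: branch on workflow_type after ingest.
    return state.get("workflow_type", "standard")

graph = StateGraph(ToyState)
graph.add_node("detect_file_type", detect_stub)
graph.add_node("ingest", ingest_stub)
graph.add_node("geojson_direct", geojson_direct_stub)
graph.add_node("retrieve", retrieve_stub)
graph.add_node("generate", generate_stub)
graph.add_edge(START, "detect_file_type")
graph.add_edge("detect_file_type", "ingest")
graph.add_conditional_edges("ingest", route_workflow,
                            {"geojson_direct": "geojson_direct", "standard": "retrieve"})
graph.add_edge("retrieve", "generate")
graph.add_edge("generate", END)
graph.add_edge("geojson_direct", END)
toy_graph = graph.compile()

print(toy_graph.invoke({"filename": "plots.geojson", "workflow_type": "standard", "result": ""})["result"])
print(toy_graph.invoke({"filename": "report.pdf", "workflow_type": "standard", "result": ""})["result"])

Running this prints the direct path for a .geojson upload and the retrieve-plus-generate path for a PDF, which is the behavioral split the new workflow_type field and route_workflow function introduce.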