mtyrrell committed on
Commit
f852f01
·
1 Parent(s): caa8809
Files changed (2) hide show
  1. app.py +108 -9
  2. utils/generator.py +121 -23
app.py CHANGED
@@ -1,7 +1,10 @@
1
  import gradio as gr
2
  import asyncio
3
  import logging
4
- from utils.generator import generate, generate_streaming
 
 
 
5
 
6
  # Configure logging
7
  logging.basicConfig(
@@ -14,6 +17,100 @@ logging.basicConfig(
14
  )
15
  logger = logging.getLogger(__name__)
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # ---------------------------------------------------------------------
18
  # Wrapper function to handle async streaming for Gradio
19
  # ---------------------------------------------------------------------
@@ -22,7 +119,7 @@ def generate_streaming_wrapper(query: str, context: str):
22
  logger.info(f"Starting generation request - Query length: {len(query)}, Context length: {len(context)}")
23
 
24
  async def _async_generator():
25
- async for chunk in generate_streaming(query, context):
26
  yield chunk
27
 
28
  # Create a new event loop for this thread
@@ -81,13 +178,15 @@ ui = gr.Interface(
81
  api_name="generate"
82
  )
83
 
 
 
 
84
  # Launch with MCP server enabled
85
  if __name__ == "__main__":
 
86
  logger.info("Starting ChatFed Generation Module server")
87
- logger.info("Server will be available at http://0.0.0.0:7860")
88
- ui.launch(
89
- server_name="0.0.0.0",
90
- server_port=7860,
91
- # mcp_server=True,
92
- show_error=True
93
- )
 
1
  import gradio as gr
2
  import asyncio
3
  import logging
4
+ import json
5
+ from fastapi import FastAPI, Request
6
+ from fastapi.responses import StreamingResponse
7
+ from utils.generator import generate_streaming, generate
8
 
9
  # Configure logging
10
  logging.basicConfig(
 
17
  )
18
  logger = logging.getLogger(__name__)
19
 
20
+ # ---------------------------------------------------------------------
21
+ # FastAPI app for ChatUI endpoints
22
+ # ---------------------------------------------------------------------
23
+ app = FastAPI(title="ChatFed Generator", version="1.0.0")
24
+
25
@app.post("/generate")
async def generate_endpoint(request: Request):
    """
    Non-streaming generation endpoint for ChatUI format.

    Reads a JSON body and forwards its "query" and "context" entries to
    ``generate`` with ``chatui_format=True``; whatever ``generate`` returns
    is sent back unchanged.

    Expected request body:
        {
            "query": "user question",
            "context": [...]  // list of retrieval results
        }

    Returns ChatUI format:
        {
            "answer": "response with citations [1][2]",
            "sources": [{"link": "doc://...", "title": "..."}]
        }

    Any exception raised while parsing the body or generating the answer is
    logged and reported to the client as ``{"error": "<message>"}``.
    """
    try:
        payload = await request.json()
        # Missing keys fall back to an empty query / empty context list;
        # `generate` is responsible for rejecting those.
        return await generate(
            payload.get("query", ""),
            payload.get("context", []),
            chatui_format=True,
        )
    except Exception as exc:
        logger.exception("Generation endpoint failed")
        return {"error": str(exc)}
53
+
54
@app.post("/generate/stream")
async def generate_stream_endpoint(request: Request):
    """
    Streaming generation endpoint for ChatUI format.

    Expected request body:
        {
            "query": "user question",
            "context": [...]  // list of retrieval results
        }

    Returns Server-Sent Events in ChatUI format:
        event: data
        data: "response chunk"

        event: sources
        data: {"sources": [...]}

        event: end

    Failures are reported in-band as an ``event: error`` frame whose data is
    ``{"error": "<message>"}``. This covers both a request body that cannot
    be parsed and an exception raised while the stream is being produced.
    """

    def sse_frame(event_name: str, payload_json: str) -> str:
        # One Server-Sent Events frame: event line, data line, blank line.
        return f"event: {event_name}\ndata: {payload_json}\n\n"

    try:
        body = await request.json()
        query = body.get("query", "")
        context = body.get("context", [])
    except Exception as e:
        logger.exception("Streaming endpoint failed")
        # Serialize the message NOW. The generator below only runs after
        # this function has returned, and Python unbinds the `as e` target
        # when the except clause exits, so closing over `e` itself would
        # raise NameError at stream time.
        error_payload = json.dumps({"error": str(e)})

        async def error_stream():
            yield sse_frame("error", error_payload)

        return StreamingResponse(
            error_stream(),
            media_type="text/event-stream",
        )

    async def event_stream():
        # This body executes lazily while the response is being sent, so it
        # needs its own error handling: a try/except around the
        # StreamingResponse construction cannot see exceptions raised here.
        try:
            async for event in generate_streaming(query, context, chatui_format=True):
                event_type = event["event"]
                event_data = event["data"]

                if event_type == "end":
                    # The end frame always carries an empty JSON object.
                    yield sse_frame("end", "{}")
                elif event_type in ("data", "sources", "error"):
                    yield sse_frame(event_type, json.dumps(event_data))
                # Any other event type is ignored.
        except Exception as exc:
            logger.exception("Streaming endpoint failed")
            yield sse_frame("error", json.dumps({"error": str(exc)}))

    return StreamingResponse(
        event_stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
            "Access-Control-Allow-Headers": "*",
        },
    )
113
+
114
  # ---------------------------------------------------------------------
115
  # Wrapper function to handle async streaming for Gradio
116
  # ---------------------------------------------------------------------
 
119
  logger.info(f"Starting generation request - Query length: {len(query)}, Context length: {len(context)}")
120
 
121
  async def _async_generator():
122
+ async for chunk in generate_streaming(query, context, chatui_format=False):
123
  yield chunk
124
 
125
  # Create a new event loop for this thread
 
178
  api_name="generate"
179
  )
180
 
181
+ # Mount Gradio app to FastAPI
182
+ app = gr.mount_gradio_app(app, ui, path="/gradio")
183
+
184
  # Launch the FastAPI app (with the Gradio UI mounted at /gradio) via uvicorn
185
  if __name__ == "__main__":
186
+ import uvicorn
187
  logger.info("Starting ChatFed Generation Module server")
188
+ logger.info("FastAPI server will be available at http://0.0.0.0:7860")
189
+ logger.info("Gradio UI will be available at http://0.0.0.0:7860/gradio")
190
+ logger.info("ChatUI endpoints: /generate (non-streaming), /generate/stream (streaming)")
191
+
192
+ uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
utils/generator.py CHANGED
@@ -86,8 +86,8 @@ def extract_relevant_fields(retrieval_results: List[Dict[str, Any]]) -> List[Dic
86
  Returns:
87
  List of processed objects with only relevant fields
88
  """
89
-
90
- retrieval_results = ast.literal_eval(retrieval_results)
91
 
92
  processed_results = []
93
 
@@ -191,7 +191,11 @@ def build_messages(question: str, context: str) -> list:
191
  """
192
  system_content = (
193
  "You are an expert assistant. Answer the USER question using only the "
194
- "CONTEXT provided. If the context is insufficient say 'I don't know.'"
 
 
 
 
195
  )
196
 
197
  user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
@@ -201,8 +205,7 @@ def build_messages(question: str, context: str) -> list:
201
  HumanMessage(content=user_content)
202
  ]
203
 
204
-
205
- async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str:
206
  """
207
  Generate an answer to a query using provided context through RAG.
208
 
@@ -211,42 +214,79 @@ async def generate(query: str, context: Union[str, List[Dict[str, Any]]]) -> str
211
 
212
  Args:
213
  query (str): User query
214
- context (list): List of retrieval result objects (dictionaries)
 
 
215
  Returns:
216
- str: The generated answer based on the query and context
217
  """
218
  if not query.strip():
219
- return "Error: Query cannot be empty"
 
 
220
 
221
  # Handle both string context (for Gradio UI) and list context (from retriever)
222
  if isinstance(context, list):
223
  if not context:
224
- return "Error: No retrieval results provided"
225
 
226
  # Process the retrieval results
227
  processed_results = extract_relevant_fields(context)
228
  formatted_context = format_context_from_results(processed_results)
229
 
230
  if not formatted_context.strip():
231
- return "Error: No valid content found in retrieval results"
232
 
233
  elif isinstance(context, str):
234
  if not context.strip():
235
- return "Error: Context cannot be empty"
236
  formatted_context = context
237
 
238
  else:
239
- return "Error: Context must be either a string or list of retrieval results"
240
 
241
  try:
242
  messages = build_messages(query, formatted_context)
243
  answer = await _call_llm(messages)
244
- return answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  except Exception as e:
246
  logging.exception("Generation failed")
247
- return f"Error: {str(e)}"
248
 
249
- async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]]) -> AsyncGenerator[str, None]:
250
  """
251
  Generate a streaming answer to a query using provided context through RAG.
252
 
@@ -256,18 +296,27 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
256
  Args:
257
  query (str): User query
258
  context (Union[str, List[Dict[str, Any]]]): Context as string or list of retrieval results
 
259
 
260
  Yields:
261
- str: Streaming chunks of the generated answer
262
  """
263
  if not query.strip():
264
- yield "Error: Query cannot be empty"
 
 
 
265
  return
266
 
 
 
267
  # Handle both string context (for Gradio UI) and list context (from retriever)
268
  if isinstance(context, list):
269
  if not context:
270
- yield "Error: No retrieval results provided"
 
 
 
271
  return
272
 
273
  # Process the retrieval results
@@ -275,23 +324,72 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
275
  formatted_context = format_context_from_results(processed_results)
276
 
277
  if not formatted_context.strip():
278
- yield "Error: No valid content found in retrieval results"
 
 
 
279
  return
280
 
281
  elif isinstance(context, str):
282
  if not context.strip():
283
- yield "Error: Context cannot be empty"
 
 
 
284
  return
285
  formatted_context = context
286
 
287
  else:
288
- yield "Error: Context must be either a string or list of retrieval results"
 
 
 
289
  return
290
 
291
  try:
292
  messages = build_messages(query, formatted_context)
 
 
293
  async for chunk in _call_llm_streaming(messages):
294
- yield chunk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
  except Exception as e:
296
  logging.exception("Streaming generation failed")
297
- yield f"Error: {str(e)}"
 
 
 
 
86
  Returns:
87
  List of processed objects with only relevant fields
88
  """
89
+ if isinstance(retrieval_results, str):
90
+ retrieval_results = ast.literal_eval(retrieval_results)
91
 
92
  processed_results = []
93
 
 
191
  """
192
  system_content = (
193
  "You are an expert assistant. Answer the USER question using only the "
194
+ "CONTEXT provided. When referencing information from the context, use inline "
195
+ "citations in square brackets like [1], [2], etc. to reference the document "
196
+ "numbers shown in the context. Use multiple citations when information comes "
197
+ "from multiple documents, like [1][2]. If the context is insufficient, say "
198
+ "'I don't know.'"
199
  )
200
 
201
  user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
 
205
  HumanMessage(content=user_content)
206
  ]
207
 
208
+ async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> Union[str, Dict[str, Any]]:
 
209
  """
210
  Generate an answer to a query using provided context through RAG.
211
 
 
214
 
215
  Args:
216
  query (str): User query
217
+ context (Union[str, List[Dict[str, Any]]]): Context as string or list of retrieval results
218
+ chatui_format (bool): If True, return ChatUI format with sources
219
+
220
  Returns:
221
+ Union[str, Dict]: The generated answer or ChatUI format response
222
  """
223
  if not query.strip():
224
+ return {"error": "Query cannot be empty"} if chatui_format else "Error: Query cannot be empty"
225
+
226
+ processed_results = []
227
 
228
  # Handle both string context (for Gradio UI) and list context (from retriever)
229
  if isinstance(context, list):
230
  if not context:
231
+ return {"error": "No retrieval results provided"} if chatui_format else "Error: No retrieval results provided"
232
 
233
  # Process the retrieval results
234
  processed_results = extract_relevant_fields(context)
235
  formatted_context = format_context_from_results(processed_results)
236
 
237
  if not formatted_context.strip():
238
+ return {"error": "No valid content found in retrieval results"} if chatui_format else "Error: No valid content found in retrieval results"
239
 
240
  elif isinstance(context, str):
241
  if not context.strip():
242
+ return {"error": "Context cannot be empty"} if chatui_format else "Error: Context cannot be empty"
243
  formatted_context = context
244
 
245
  else:
246
+ return {"error": "Context must be either a string or list of retrieval results"} if chatui_format else "Error: Context must be either a string or list of retrieval results"
247
 
248
  try:
249
  messages = build_messages(query, formatted_context)
250
  answer = await _call_llm(messages)
251
+
252
+ if chatui_format:
253
+ # Return ChatUI format
254
+ result = {"answer": answer}
255
+ if processed_results:
256
+ # Extract sources for ChatUI
257
+ sources = []
258
+ for result_item in processed_results:
259
+ filename = result_item.get('filename', 'Unknown')
260
+ page = result_item.get('page', 'Unknown')
261
+ year = result_item.get('year', 'Unknown')
262
+
263
+ # Create link using doc:// scheme
264
+ link = f"doc://{filename}"
265
+
266
+ # Create descriptive title
267
+ title_parts = [filename]
268
+ if page != 'Unknown':
269
+ title_parts.append(f"Page {page}")
270
+ if year != 'Unknown':
271
+ title_parts.append(f"({year})")
272
+
273
+ title = " - ".join(title_parts)
274
+
275
+ sources.append({
276
+ "link": link,
277
+ "title": title
278
+ })
279
+
280
+ result["sources"] = sources
281
+ return result
282
+ else:
283
+ return answer
284
+
285
  except Exception as e:
286
  logging.exception("Generation failed")
287
+ return {"error": str(e)} if chatui_format else f"Error: {str(e)}"
288
 
289
+ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]]], chatui_format: bool = False) -> AsyncGenerator[Union[str, Dict[str, Any]], None]:
290
  """
291
  Generate a streaming answer to a query using provided context through RAG.
292
 
 
296
  Args:
297
  query (str): User query
298
  context (Union[str, List[Dict[str, Any]]]): Context as string or list of retrieval results
299
+ chatui_format (bool): If True, yield ChatUI format events
300
 
301
  Yields:
302
+ Union[str, Dict]: Streaming chunks or ChatUI format events
303
  """
304
  if not query.strip():
305
+ if chatui_format:
306
+ yield {"event": "error", "data": {"error": "Query cannot be empty"}}
307
+ else:
308
+ yield "Error: Query cannot be empty"
309
  return
310
 
311
+ processed_results = []
312
+
313
  # Handle both string context (for Gradio UI) and list context (from retriever)
314
  if isinstance(context, list):
315
  if not context:
316
+ if chatui_format:
317
+ yield {"event": "error", "data": {"error": "No retrieval results provided"}}
318
+ else:
319
+ yield "Error: No retrieval results provided"
320
  return
321
 
322
  # Process the retrieval results
 
324
  formatted_context = format_context_from_results(processed_results)
325
 
326
  if not formatted_context.strip():
327
+ if chatui_format:
328
+ yield {"event": "error", "data": {"error": "No valid content found in retrieval results"}}
329
+ else:
330
+ yield "Error: No valid content found in retrieval results"
331
  return
332
 
333
  elif isinstance(context, str):
334
  if not context.strip():
335
+ if chatui_format:
336
+ yield {"event": "error", "data": {"error": "Context cannot be empty"}}
337
+ else:
338
+ yield "Error: Context cannot be empty"
339
  return
340
  formatted_context = context
341
 
342
  else:
343
+ if chatui_format:
344
+ yield {"event": "error", "data": {"error": "Context must be either a string or list of retrieval results"}}
345
+ else:
346
+ yield "Error: Context must be either a string or list of retrieval results"
347
  return
348
 
349
  try:
350
  messages = build_messages(query, formatted_context)
351
+
352
+ # Stream the text response
353
  async for chunk in _call_llm_streaming(messages):
354
+ if chatui_format:
355
+ yield {"event": "data", "data": chunk}
356
+ else:
357
+ yield chunk
358
+
359
+ # Send sources at the end if available and in ChatUI format
360
+ if chatui_format and processed_results:
361
+ sources = []
362
+ for result in processed_results:
363
+ filename = result.get('filename', 'Unknown')
364
+ page = result.get('page', 'Unknown')
365
+ year = result.get('year', 'Unknown')
366
+
367
+ # Create link using doc:// scheme
368
+ link = f"doc://{filename}"
369
+
370
+ # Create descriptive title
371
+ title_parts = [filename]
372
+ if page != 'Unknown':
373
+ title_parts.append(f"Page {page}")
374
+ if year != 'Unknown':
375
+ title_parts.append(f"({year})")
376
+
377
+ title = " - ".join(title_parts)
378
+
379
+ sources.append({
380
+ "link": link,
381
+ "title": title
382
+ })
383
+
384
+ yield {"event": "sources", "data": {"sources": sources}}
385
+
386
+ # Send end event for ChatUI format
387
+ if chatui_format:
388
+ yield {"event": "end", "data": {}}
389
+
390
  except Exception as e:
391
  logging.exception("Streaming generation failed")
392
+ if chatui_format:
393
+ yield {"event": "error", "data": {"error": str(e)}}
394
+ else:
395
+ yield f"Error: {str(e)}"