Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
ts cleanup langserve
Browse files- app/main.py +42 -149
app/main.py
CHANGED
|
@@ -791,109 +791,8 @@ async def chatui_adapter(data):
|
|
| 791 |
logger.error(f"ChatUI error: {str(e)}")
|
| 792 |
yield f"Error: {str(e)}"
|
| 793 |
|
| 794 |
-
# # MAIN CHATUI STREAMING ENDPOINT - KEEP THIS (ChatUI likely uses this)
|
| 795 |
-
# async def chatui_stream_with_sources(request: Request):
|
| 796 |
-
# """ChatUI streaming endpoint that supports sources format"""
|
| 797 |
-
# try:
|
| 798 |
-
# body = await request.json()
|
| 799 |
-
# query = body.get("text", "")
|
| 800 |
-
|
| 801 |
-
# if not query:
|
| 802 |
-
# async def error_stream():
|
| 803 |
-
# yield f"event: error\ndata: {json.dumps({'error': 'No query provided'})}\n\n"
|
| 804 |
-
# return StreamingResponse(error_stream(), media_type="text/event-stream")
|
| 805 |
-
|
| 806 |
-
# async def event_stream():
|
| 807 |
-
# try:
|
| 808 |
-
# # Process the query through the orchestrator pipeline
|
| 809 |
-
# start_time = datetime.now()
|
| 810 |
-
# session_id = f"chatui_{start_time.strftime('%Y%m%d_%H%M%S')}"
|
| 811 |
-
|
| 812 |
-
# initial_state = {
|
| 813 |
-
# "query": query,
|
| 814 |
-
# "context": "",
|
| 815 |
-
# "ingestor_context": "",
|
| 816 |
-
# "result": "",
|
| 817 |
-
# "sources": [],
|
| 818 |
-
# "reports_filter": "",
|
| 819 |
-
# "sources_filter": "",
|
| 820 |
-
# "subtype_filter": "",
|
| 821 |
-
# "year_filter": "",
|
| 822 |
-
# "file_content": None,
|
| 823 |
-
# "filename": None,
|
| 824 |
-
# "file_type": "unknown",
|
| 825 |
-
# "workflow_type": "standard",
|
| 826 |
-
# "metadata": {
|
| 827 |
-
# "session_id": session_id,
|
| 828 |
-
# "start_time": start_time.isoformat(),
|
| 829 |
-
# "has_file_attachment": False
|
| 830 |
-
# }
|
| 831 |
-
# }
|
| 832 |
-
|
| 833 |
-
# # Process non-streaming steps
|
| 834 |
-
# state_after_detect = {**initial_state, **detect_file_type_node(initial_state)}
|
| 835 |
-
# state_after_ingest = {**state_after_detect, **ingest_node(state_after_detect)}
|
| 836 |
-
# workflow_type = route_workflow(state_after_ingest)
|
| 837 |
-
|
| 838 |
-
# if workflow_type == "geojson_direct":
|
| 839 |
-
# final_state = geojson_direct_result_node(state_after_ingest)
|
| 840 |
-
# yield f"event: data\ndata: {json.dumps(final_state['result'])}\n\n"
|
| 841 |
-
# yield f"event: end\ndata: {{}}\n\n"
|
| 842 |
-
# else:
|
| 843 |
-
# state_after_retrieve = {**state_after_ingest, **retrieve_node(state_after_ingest)}
|
| 844 |
-
|
| 845 |
-
# # Stream generation with sources
|
| 846 |
-
# sources_sent = False
|
| 847 |
-
# final_sources = None
|
| 848 |
-
# partial_count = 0
|
| 849 |
-
|
| 850 |
|
| 851 |
-
|
| 852 |
-
# async for partial_state in generate_node_streaming(state_after_retrieve):
|
| 853 |
-
# if "result" in partial_state:
|
| 854 |
-
# sse_message = f"event: data\ndata: {json.dumps(partial_state['result'])}\n\n"
|
| 855 |
-
# yield sse_message
|
| 856 |
-
|
| 857 |
-
# # Store sources but don't send immediately
|
| 858 |
-
# if "sources" in partial_state:
|
| 859 |
-
# final_sources = partial_state["sources"]
|
| 860 |
-
|
| 861 |
-
# # Send sources as the last event before end
|
| 862 |
-
# if final_sources and not sources_sent:
|
| 863 |
-
# sources_data = {"sources": final_sources}
|
| 864 |
-
# sse_sources_message = f"event: sources\ndata: {json.dumps(sources_data)}\n\n"
|
| 865 |
-
# yield sse_sources_message
|
| 866 |
-
# sources_sent = True
|
| 867 |
-
|
| 868 |
-
# end_message = f"event: end\ndata: {{}}\n\n"
|
| 869 |
-
# yield end_message
|
| 870 |
-
|
| 871 |
-
# except Exception as e:
|
| 872 |
-
# logger.error(f"ChatUI streaming error: {str(e)}")
|
| 873 |
-
# yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
| 874 |
-
|
| 875 |
-
# return StreamingResponse(
|
| 876 |
-
# event_stream(),
|
| 877 |
-
# media_type="text/event-stream",
|
| 878 |
-
# headers={
|
| 879 |
-
# "Cache-Control": "no-cache",
|
| 880 |
-
# "Connection": "keep-alive",
|
| 881 |
-
# "Access-Control-Allow-Origin": "*",
|
| 882 |
-
# "Access-Control-Allow-Headers": "*",
|
| 883 |
-
# }
|
| 884 |
-
# )
|
| 885 |
-
|
| 886 |
-
# except Exception as e:
|
| 887 |
-
# logger.error(f"ChatUI request parsing failed: {str(e)}")
|
| 888 |
-
# async def error_stream():
|
| 889 |
-
# yield f"event: error\ndata: {json.dumps({'error': str(e)})}\n\n"
|
| 890 |
-
|
| 891 |
-
# return StreamingResponse(
|
| 892 |
-
# error_stream(),
|
| 893 |
-
# media_type="text/event-stream"
|
| 894 |
-
# )
|
| 895 |
-
|
| 896 |
-
# TODO: PROBABLY REMOVE - USED BY LANGSERVE ENDPOINT
|
| 897 |
def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
|
| 898 |
result = process_query_core(
|
| 899 |
query=input_data["query"],
|
|
@@ -909,37 +808,37 @@ def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
|
|
| 909 |
)
|
| 910 |
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
|
| 911 |
|
| 912 |
-
#
|
| 913 |
-
|
| 914 |
-
|
| 915 |
-
|
| 916 |
-
|
| 917 |
|
| 918 |
-
|
| 919 |
-
|
| 920 |
-
|
| 921 |
-
|
| 922 |
|
| 923 |
-
|
| 924 |
-
|
| 925 |
-
|
| 926 |
-
|
| 927 |
-
|
| 928 |
|
| 929 |
-
|
| 930 |
|
| 931 |
-
|
| 932 |
-
|
| 933 |
|
| 934 |
-
#
|
| 935 |
-
|
| 936 |
-
|
| 937 |
-
|
| 938 |
-
|
| 939 |
-
|
| 940 |
-
|
| 941 |
|
| 942 |
-
|
| 943 |
|
| 944 |
@asynccontextmanager
|
| 945 |
async def lifespan(app: FastAPI):
|
|
@@ -969,15 +868,10 @@ async def root():
|
|
| 969 |
"chatfed-ui-stream": "/chatfed-ui-stream",
|
| 970 |
"chatfed-with-file": "/chatfed-with-file",
|
| 971 |
# "chatfed-with-file-stream": "/chatfed-with-file/stream",
|
| 972 |
-
# "chatui-stream": "/chatui-stream"
|
| 973 |
}
|
| 974 |
}
|
| 975 |
|
| 976 |
-
|
| 977 |
-
# @app.post("/chatui-stream")
|
| 978 |
-
# async def chatui_stream_endpoint(request: Request):
|
| 979 |
-
# """ChatUI compatible streaming endpoint with sources support"""
|
| 980 |
-
# return await chatui_stream_with_sources(request)
|
| 981 |
|
| 982 |
# # FILE UPLOAD ADAPTER - KEEP THIS
|
| 983 |
async def chatfed_with_file_adapter(
|
|
@@ -1113,7 +1007,7 @@ async def chatfed_with_file_stream(
|
|
| 1113 |
filename = file.filename
|
| 1114 |
|
| 1115 |
async def generate_sse_stream():
|
| 1116 |
-
"""Generate Server-Sent Events format
|
| 1117 |
try:
|
| 1118 |
# Small delay to ensure SSE connection is established
|
| 1119 |
await asyncio.sleep(0.1)
|
|
@@ -1148,7 +1042,7 @@ async def chatfed_with_file_stream(
|
|
| 1148 |
|
| 1149 |
return StreamingResponse(
|
| 1150 |
generate_sse_stream(),
|
| 1151 |
-
media_type="text/event-stream",
|
| 1152 |
headers={
|
| 1153 |
"Cache-Control": "no-cache",
|
| 1154 |
"Connection": "keep-alive",
|
|
@@ -1158,13 +1052,13 @@ async def chatfed_with_file_stream(
|
|
| 1158 |
|
| 1159 |
# TODO: TEST IF CHATUI NEEDS THESE LANGSERVE ENDPOINTS
|
| 1160 |
# If ChatUI works without these, they can be removed
|
| 1161 |
-
add_routes(
|
| 1162 |
-
|
| 1163 |
-
|
| 1164 |
-
|
| 1165 |
-
|
| 1166 |
-
|
| 1167 |
-
)
|
| 1168 |
|
| 1169 |
add_routes(
|
| 1170 |
app,
|
|
@@ -1177,17 +1071,16 @@ add_routes(
|
|
| 1177 |
)
|
| 1178 |
|
| 1179 |
if __name__ == "__main__":
|
| 1180 |
-
#
|
| 1181 |
-
|
| 1182 |
-
# demo = create_gradio_interface()
|
| 1183 |
|
| 1184 |
-
#
|
| 1185 |
-
|
| 1186 |
|
| 1187 |
-
host = os.getenv("HOST", "
|
| 1188 |
port = int(os.getenv("PORT", "7860"))
|
| 1189 |
|
| 1190 |
logger.info(f"Starting FastAPI server on {host}:{port}")
|
| 1191 |
-
|
| 1192 |
|
| 1193 |
uvicorn.run(app, host=host, port=port, log_level="info", access_log=True)
|
|
|
|
| 791 |
logger.error(f"ChatUI error: {str(e)}")
|
| 792 |
yield f"Error: {str(e)}"
|
| 793 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 794 |
|
| 795 |
+
# USED BY LANGSERVE ENDPOINT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 796 |
def process_query_langserve(input_data: ChatFedInput) -> ChatFedOutput:
|
| 797 |
result = process_query_core(
|
| 798 |
query=input_data["query"],
|
|
|
|
| 808 |
)
|
| 809 |
return ChatFedOutput(result=result["result"], metadata=result["metadata"])
|
| 810 |
|
| 811 |
+
# GRADIO TEST UI
|
| 812 |
+
def create_gradio_interface():
    """Build the Gradio Blocks test UI for the ChatFed orchestrator.

    The page has a query textbox, a document upload widget, four optional
    filter textboxes inside a collapsed accordion, a submit button and a
    response textbox. Clicking the button calls
    ``process_query_gradio_streaming`` with the query, the uploaded file and
    the four filters (in that order) and writes its output to the response
    textbox.

    NOTE(review): the nesting below (submit button in the left column,
    response box in the right column) was reconstructed from a listing whose
    indentation was lost -- confirm against the original layout.

    Returns:
        The constructed ``gr.Blocks`` object; it is not launched here.
    """
    # (label, placeholder) pairs for the optional filters, in the order the
    # click handler passes them on.
    filter_specs = (
        ("Reports Filter", "e.g., annual_reports"),
        ("Sources Filter", "e.g., internal"),
        ("Subtype Filter", "e.g., financial"),
        ("Year Filter", "e.g., 2024"),
    )

    with gr.Blocks(title="ChatFed Orchestrator") as demo:
        gr.Markdown("# ChatFed Orchestrator")
        gr.Markdown("Upload documents (PDF/DOCX/GeoJSON) alongside your queries for enhanced context. MCP endpoints available at `/gradio_api/mcp/sse`")

        with gr.Row():
            # Left column: everything the user fills in.
            with gr.Column():
                question_box = gr.Textbox(
                    label="Query",
                    lines=2,
                    placeholder="Enter your question...",
                )
                upload_box = gr.File(
                    label="Upload Document (PDF/DOCX/GeoJSON)",
                    file_types=[".pdf", ".docx", ".geojson", ".json"],
                )

                with gr.Accordion("Filters (Optional)", open=False):
                    filter_boxes = [
                        gr.Textbox(label=label, placeholder=hint)
                        for label, hint in filter_specs
                    ]

                run_button = gr.Button("Submit", variant="primary")

            # Right column: the response.
            with gr.Column():
                response_box = gr.Textbox(
                    label="Response",
                    lines=15,
                    show_copy_button=True,
                )

        # Wire the button to the streaming handler.
        run_button.click(
            fn=process_query_gradio_streaming,
            inputs=[question_box, upload_box, *filter_boxes],
            outputs=response_box,
        )

    return demo
|
| 842 |
|
| 843 |
@asynccontextmanager
|
| 844 |
async def lifespan(app: FastAPI):
|
|
|
|
| 868 |
"chatfed-ui-stream": "/chatfed-ui-stream",
|
| 869 |
"chatfed-with-file": "/chatfed-with-file",
|
| 870 |
# "chatfed-with-file-stream": "/chatfed-with-file/stream",
|
|
|
|
| 871 |
}
|
| 872 |
}
|
| 873 |
|
| 874 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 875 |
|
| 876 |
# # FILE UPLOAD ADAPTER - KEEP THIS
|
| 877 |
async def chatfed_with_file_adapter(
|
|
|
|
| 1007 |
filename = file.filename
|
| 1008 |
|
| 1009 |
async def generate_sse_stream():
|
| 1010 |
+
"""Generate Server-Sent Events format for ChatUI"""
|
| 1011 |
try:
|
| 1012 |
# Small delay to ensure SSE connection is established
|
| 1013 |
await asyncio.sleep(0.1)
|
|
|
|
| 1042 |
|
| 1043 |
return StreamingResponse(
|
| 1044 |
generate_sse_stream(),
|
| 1045 |
+
media_type="text/event-stream", #ChatUI format
|
| 1046 |
headers={
|
| 1047 |
"Cache-Control": "no-cache",
|
| 1048 |
"Connection": "keep-alive",
|
|
|
|
| 1052 |
|
| 1053 |
# TODO: TEST IF CHATUI NEEDS THESE LANGSERVE ENDPOINTS
|
| 1054 |
# If ChatUI works without these, they can be removed
|
| 1055 |
+
# add_routes(
|
| 1056 |
+
# app,
|
| 1057 |
+
# RunnableLambda(process_query_langserve),
|
| 1058 |
+
# path="/chatfed",
|
| 1059 |
+
# input_type=ChatFedInput,
|
| 1060 |
+
# output_type=ChatFedOutput
|
| 1061 |
+
# )
|
| 1062 |
|
| 1063 |
add_routes(
|
| 1064 |
app,
|
|
|
|
| 1071 |
)
|
| 1072 |
|
| 1073 |
if __name__ == "__main__":
    # Build the Gradio test UI and attach it to the FastAPI app under /gradio.
    demo = create_gradio_interface()
    app = gr.mount_gradio_app(app, demo, path="/gradio")

    # Bind address comes from the environment, with local defaults.
    host = os.environ.get("HOST", "localhost")
    port = int(os.environ.get("PORT", "7860"))

    logger.info(f"Starting FastAPI server on {host}:{port}")
    # Remove this line if the Gradio UI is removed.
    logger.info(f"Gradio UI available at: http://{host}:{port}/gradio")

    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
        access_log=True,
    )
|