from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from fastapi.templating import Jinja2Templates
from langchain_pinecone import PineconeVectorStore
from src.config import Config
from src.helper import download_embeddings
from src.utility import QueryClassifier, StreamingHandler
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_classic.chains import create_retrieval_chain
from langchain_classic.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_community.chat_message_histories import ChatMessageHistory
from src.prompt import system_prompt
import json
import uuid
Config.validate()
PINECONE_API_KEY = Config.PINECONE_API_KEY
GEMINI_API_KEY = Config.GEMINI_API_KEY

templates = Jinja2Templates(directory="templates")

# Initialize FastAPI app
app = FastAPI(title="Medical Chatbot", version="0.0.0")

# Store for session-based chat histories (resets on server restart)
chat_histories = {}
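# Note: chat_histories lives in process memory, so sessions are not shared across
# multiple uvicorn workers or replicas; a shared store would be needed for that.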
# Initialize embedding model
print("Loading the Embedding model...")
embeddings = download_embeddings()

# Connect to existing Pinecone index
index_name = Config.PINECONE_INDEX_NAME
print(f"Connecting to Pinecone index: {index_name}")
docsearch = PineconeVectorStore.from_existing_index(
    index_name=index_name, embedding=embeddings
)
# Create retriever from the vector store
retriever = docsearch.as_retriever(
    search_type=Config.SEARCH_TYPE, search_kwargs={"k": Config.RETRIEVAL_K}
)

# Initialize Google Gemini chat model
print("Initializing Gemini model...")
llm = ChatGoogleGenerativeAI(
    model=Config.GEMINI_MODEL,
    google_api_key=GEMINI_API_KEY,
    temperature=Config.LLM_TEMPERATURE,
    convert_system_message_to_human=True,
)
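# convert_system_message_to_human merges the system prompt into the first human turn,
# for Gemini endpoints that reject a separate system role (documented behavior of
# langchain_google_genai; newer releases may accept system messages natively).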
# Create chat prompt template with memory
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{input}"),
    ]
)

# Create the question-answer chain
question_answer_chain = create_stuff_documents_chain(llm, prompt)

# Create the RAG chain
rag_chain = create_retrieval_chain(retriever, question_answer_chain)
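# Chain flow: the retriever fetches the top-k chunks for "input", the stuff-documents
# chain injects them as "context" into the prompt alongside "chat_history", and the
# Gemini model produces the "answer" field returned by the RAG chain.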

# Function to get chat history for a session
def get_chat_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in chat_histories:
        chat_histories[session_id] = ChatMessageHistory()
    return chat_histories[session_id]


# Function to maintain a conversation window buffer (keep the last 5 exchanges)
def manage_memory_window(session_id: str, max_messages: int = 10):
    """Keep only the last max_messages (5 user/AI pairs = 10 messages)."""
    if session_id in chat_histories:
        history = chat_histories[session_id]
        if len(history.messages) > max_messages:
            # Keep only the last max_messages
            history.messages = history.messages[-max_messages:]
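# Note: with max_messages=10, at most five prior user/assistant exchanges are replayed
# into the prompt per request, which keeps prompt size bounded over a long session.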
| print("Intialized Medical Chabot successfuly!") | |
| print("Vector Store connected") | |

# Route path assumed ("/"); the decorator is missing in the extracted source
@app.get("/", response_class=HTMLResponse)
async def index(request: Request):
    """Render the chatbot interface"""
    # Clear all old sessions to prevent memory overflow
    chat_histories.clear()
    # Generate a new session ID for each page load
    session_id = str(uuid.uuid4())
    return templates.TemplateResponse(
        "index.html", {"request": request, "session_id": session_id}
    )

# Route path assumed ("/chat"); the decorator is missing in the extracted source
@app.post("/chat")
async def chat(msg: str = Form(...), session_id: str = Form(...)):
    """Handle chat messages and return streaming AI responses with conversation memory"""
    # Get chat history for this session
    history = get_chat_history(session_id)
    # Classify the query to determine whether retrieval is needed
    needs_retrieval, reason = QueryClassifier.needs_retrieval(msg)
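    # SSE contract (inferred from the chunk handling below): each chunk is an event of
    # the form 'data: {json}\n\n'; the final event carries "done": true and, for RAG
    # responses, the accumulated "full_answer".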
    async def generate_response():
        """Generator for streaming response"""
        full_answer = ""
        try:
            if needs_retrieval:
                # Stream RAG chain response for medical queries
                print(f"✓ [RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
                async for chunk in StreamingHandler.stream_rag_response(
                    rag_chain, {"input": msg, "chat_history": history.messages}
                ):
                    yield chunk
                    # Extract the full answer from the final chunk
                    if '"done": true' in chunk:
                        data = json.loads(chunk.replace("data: ", "").strip())
                        if "full_answer" in data:
                            full_answer = data["full_answer"]
            else:
                # Stream simple response for greetings/acknowledgments
                print(f"[NO RETRIEVAL STREAM] Reason: {reason} | Query: {msg[:50]}...")
                simple_resp = QueryClassifier.get_simple_response(msg)
                full_answer = simple_resp
                async for chunk in StreamingHandler.stream_simple_response(simple_resp):
                    yield chunk

            # Add the conversation to history after streaming completes
            history.add_user_message(msg)
            history.add_ai_message(full_answer)
            # Manage memory window
            manage_memory_window(session_id, max_messages=10)
        except Exception as e:
            print(f"Error during streaming: {str(e)}")
            yield f"data: {json.dumps({'error': 'An error occurred', 'done': True})}\n\n"
    return StreamingResponse(
        generate_response(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        },
    )

if __name__ == "__main__":
    import uvicorn
    import os

    # Use PORT from environment (7860 for HF Spaces, 8080 for Render)
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)
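
# Example request (hypothetical; assumes the /chat route declared above). curl's -N flag
# disables buffering so the SSE events print as they arrive:
#   curl -N -X POST http://localhost:7860/chat \
#        -F "msg=What are the symptoms of anemia?" \
#        -F "session_id=demo-session"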