Kareman commited on
Commit
14faba3
·
0 Parent(s):

feat(ContextAI)

Browse files
Files changed (10) hide show
  1. .DS_Store +0 -0
  2. .gitignore +29 -0
  3. README.md +127 -0
  4. app/auth.py +40 -0
  5. app/config.py +2 -0
  6. app/database.py +11 -0
  7. app/main.py +113 -0
  8. app/models.py +22 -0
  9. app/rag.py +76 -0
  10. requirements.txt +16 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python cache / build files
2
+ __pycache__/
3
+ *.pyc
4
+ *.pyo
5
+ *.pyd
6
+ *.sqlite3
7
+ *.log
8
+
9
+ # Virtual environment
10
+ venv/
11
+ .env/
12
+
13
+ # Environment variables
14
+ .env
15
+
16
+ # Uploaded files
17
+ uploads/
18
+
19
+ # Chroma database
20
+ chroma_db/
21
+
22
+ # IDE/editor configs (optional, but useful to ignore)
23
+ .vscode/
24
+ .idea/
25
+ *.swp
26
+
27
+ # SQLite DB
28
+ *.db
29
+ *.sqlite3
README.md ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ContextAI
2
+
3
+ A **FastAPI-based RAG application** that lets users upload documents (PDF/TXT) and ask questions.
4
+ Powered by **LangChain**, **ChromaDB**, and **LLMs** for context-aware answers.
5
+
6
+ 📚 FastAPI RAG App with LangChain, ChromaDB & Authentication
7
+
8
+ This project is a Retrieval-Augmented Generation (RAG) web application built with FastAPI.
9
+ It allows users to:
10
+
11
+ 🔑 Sign up / Sign in (JWT-based authentication)
12
+
13
+ 📂 Upload PDF or text documents
14
+
15
+ 🧠 Store document embeddings in ChromaDB (vector database)
16
+
17
+ 💬 Ask questions about uploaded documents
18
+
19
+ ⚡ Get context-aware answers powered by LangChain + LLMs (via OpenRouter
20
+ )
21
+
22
+ 🚀 Features
23
+
24
+ User authentication with access & refresh tokens
25
+
26
+ Secure file uploads (.pdf, .txt)
27
+
28
+ Automatic text chunking & embedding with HuggingFace models
29
+
30
+ Persistent vector store using ChromaDB
31
+
32
+ RAG pipeline with LangChain’s RetrievalQA
33
+
34
+ OpenRouter integration for running LLM queries
35
+
36
+ CORS configured for frontend integration
37
+
38
+ 🛠️ Tech Stack
39
+
40
+ FastAPI
41
+
42
+ LangChain
43
+
44
+ ChromaDB
45
+
46
+ SQLModel
47
+ for user database
48
+
49
+ HuggingFace Embeddings
50
+
51
+ OpenRouter
52
+ (for LLM access)
53
+
54
+ 📂 Project Structure
55
+ app/
56
+ ├── main.py # FastAPI routes & entrypoint
57
+ ├── rag.py # RAG pipeline (embeddings, vector store, QA chain)
58
+ ├── models.py # User models & schemas
59
+ ├── auth.py # Auth logic (hashing, tokens, verification)
60
+ ├── database.py # SQLModel setup
61
+ └── config.py # Settings & constants
62
+ uploads/ # User uploaded files (ignored in Git)
63
+ chroma_db/ # Vector DB storage (ignored in Git)
64
+
65
+ ⚙️ Setup & Installation
66
+ 1️⃣ Clone the repo
67
+ git clone https://github.com/your-username/fastapi-rag-app.git
68
+ cd fastapi-rag-app
69
+
70
+ 2️⃣ Create & activate virtual environment
71
+ python -m venv venv
72
+ source venv/bin/activate # Linux/Mac
73
+ venv\Scripts\activate # Windows
74
+
75
+ 3️⃣ Install dependencies
76
+ pip install -r requirements.txt
77
+
78
+ 4️⃣ Configure environment variables
79
+
80
+ Create a .env file in the project root (or copy from .env.example):
81
+
82
+ # OpenRouter
83
+ OPENROUTER=your_openrouter_api_key_here
84
+
85
+ # JWT secret
86
+ SECRET_KEY=your_super_secret_key
87
+
88
+ ⚠️ Never commit your real .env file.
89
+
90
+ ▶️ Run the App
91
+
92
+ Start the FastAPI server:
93
+
94
+ uvicorn app.main:app --reload
95
+
96
+
97
+ The API will be available at:
98
+ 👉 http://127.0.0.1:8000
99
+
100
+ Interactive API docs:
101
+ 👉 http://127.0.0.1:8000/docs
102
+
103
+ 🔑 Authentication Flow
104
+
105
+ Signup → POST /signup with username & password
106
+
107
+ Signin → POST /signin to receive access_token & refresh_token
108
+
109
+ Use Authorization: Bearer <access_token> for protected endpoints
110
+
111
+ 📂 Document Workflow
112
+
113
+ User logs in
114
+
115
+ Upload document → POST /upload (PDF or TXT)
116
+
117
+ Ask a question → GET /ask?q=your+question
118
+
119
+ The system searches your embeddings in ChromaDB and queries the LLM with context
120
+
121
+ 📝 Notes
122
+
123
+ uploads/ and chroma_db/ are auto-created at runtime if they don’t exist.
124
+
125
+ Both folders are ignored by Git (runtime data only).
126
+
127
+ Contributions & pull requests are welcome 🚀
app/auth.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from passlib.context import CryptContext
2
+ from datetime import datetime, timedelta
3
+ from jose import JWTError, jwt
4
+ from typing import Optional
5
+
6
+ # Password hashing
7
+ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
8
+ from dotenv import load_dotenv
9
+ load_dotenv()
10
+
11
+ def hash_password(password: str):
12
+ return pwd_context.hash(password)
13
+
14
+ def verify_password(password: str, hashed: str):
15
+ return pwd_context.verify(password, hashed)
16
+
17
+ # JWT settings
18
+ SECRET_KEY=os.getenv("SECRET_KEY")
19
+ ALGORITHM = "HS256"
20
+ ACCESS_TOKEN_EXPIRE_MINUTES = 15
21
+ REFRESH_TOKEN_EXPIRE_DAYS = 7
22
+
23
+ def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
24
+ to_encode = data.copy()
25
+ expire = datetime.utcnow() + (expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES))
26
+ to_encode.update({"exp": expire})
27
+ return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
28
+
29
+ def create_refresh_token(data: dict, expires_delta: Optional[timedelta] = None):
30
+ to_encode = data.copy()
31
+ expire = datetime.utcnow() + (expires_delta or timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS))
32
+ to_encode.update({"exp": expire})
33
+ return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
34
+
35
+ def decode_token(token: str):
36
+ try:
37
+ payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
38
+ return payload
39
+ except JWTError:
40
+ return None
app/config.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # Vector DB storage
2
+ CHROMA_DB_DIR = "./chroma_db"
app/database.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlmodel import SQLModel, create_engine, Session
2
+
3
+ DATABASE_URL = "sqlite:///./users.db"
4
+ engine = create_engine(DATABASE_URL, echo=True)
5
+
6
+ def init_db():
7
+ SQLModel.metadata.create_all(engine)
8
+
9
+ def get_session():
10
+ with Session(engine) as session:
11
+ yield session
app/main.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from app import rag
2
+ import shutil
3
+ import os
4
+ from fastapi import FastAPI, Depends, HTTPException, status, UploadFile, File
5
+ from fastapi.security import OAuth2PasswordBearer
6
+ from sqlmodel import select
7
+ from app.models import User, UserCreate, UserLogin, Token
8
+ from app.auth import hash_password, verify_password, create_access_token, create_refresh_token, decode_token
9
+ from app.database import init_db, get_session
10
+ from sqlmodel import Session
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+
13
+
14
+ # Initialize DB
15
+ init_db()
16
+
17
+ app = FastAPI()
18
+
19
+ # Allow your frontend origin
20
+ origins = [
21
+ "http://localhost:5173", # React dev server
22
+ ]
23
+
24
+ app.add_middleware(
25
+ CORSMiddleware,
26
+ allow_origins=origins, # or ["*"] for all origins (not recommended for production)
27
+ allow_credentials=True,
28
+ allow_methods=["*"], # allow POST, GET, OPTIONS, etc.
29
+ allow_headers=["*"],
30
+ )
31
+
32
+ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
33
+
34
+
35
+ UPLOAD_DIR = "./uploads"
36
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
37
+
38
+
39
+ # ------------------------
40
+ # Protected Route Example
41
+ # ------------------------
42
+ def get_current_user(token: str = Depends(oauth2_scheme), session: Session = Depends(get_session)):
43
+ payload = decode_token(token)
44
+ if not payload:
45
+ raise HTTPException(status_code=401, detail="Invalid token")
46
+ username = payload.get("sub")
47
+ user = session.exec(select(User).where(User.username == username)).first()
48
+ if not user:
49
+ raise HTTPException(status_code=404, detail="User not found")
50
+ return user
51
+
52
+ @app.get("/protected")
53
+ def protected_route(current_user: User = Depends(get_current_user)):
54
+ return {"message": f"Hello {current_user.username}, you are authenticated!"}
55
+
56
+ @app.post("/upload")
57
+ def upload_file(file: UploadFile = File(...), current_user: User = Depends(get_current_user)):
58
+ user_id = current_user.username
59
+ file_path = f"./uploads/{file.filename}"
60
+ with open(file_path, "wb") as f:
61
+ shutil.copyfileobj(file.file, f)
62
+ rag.add_document(file_path, user_id=user_id)
63
+ return {"message": "Document uploaded successfully."}
64
+
65
+ @app.get("/ask")
66
+ def ask(q: str, current_user: User = Depends(get_current_user)):
67
+ user_id = current_user.username
68
+ qa = rag.get_qa_chain(user_id=user_id)
69
+ answer = qa.run(q)
70
+ return {"question": q, "answer": answer}
71
+
72
+
73
+ # ------------------------
74
+ # Auth Endpoints
75
+ # ------------------------
76
+
77
+ @app.post("/signup", response_model=Token)
78
+ def signup(user: UserCreate, session: Session = Depends(get_session)):
79
+ existing_user = session.exec(select(User).where(User.username == user.username)).first()
80
+ if existing_user:
81
+ raise HTTPException(status_code=400, detail="Username already exists")
82
+ db_user = User(username=user.username, hashed_password=hash_password(user.password))
83
+ session.add(db_user)
84
+ session.commit()
85
+ session.refresh(db_user)
86
+ access_token = create_access_token({"sub": db_user.username})
87
+ refresh_token = create_refresh_token({"sub": db_user.username})
88
+ return {"access_token": access_token, "refresh_token": refresh_token}
89
+
90
+ @app.post("/signin", response_model=Token)
91
+ def signin(user: UserLogin, session: Session = Depends(get_session)):
92
+ db_user = session.exec(select(User).where(User.username == user.username)).first()
93
+ if not db_user or not verify_password(user.password, db_user.hashed_password):
94
+ raise HTTPException(status_code=401, detail="Invalid username or password")
95
+ access_token = create_access_token({"sub": db_user.username})
96
+ refresh_token = create_refresh_token({"sub": db_user.username})
97
+ return {"access_token": access_token, "refresh_token": refresh_token}
98
+
99
+ from fastapi import Body
100
+
101
+ @app.post("/refresh", response_model=Token)
102
+ def refresh_token(refresh_token: str = Body(..., embed=True)):
103
+ payload = decode_token(refresh_token)
104
+ if not payload:
105
+ raise HTTPException(status_code=401, detail="Invalid refresh token")
106
+ username = payload.get("sub")
107
+ # ✅ issue a new access token
108
+ new_access_token = create_access_token({"sub": username})
109
+ # we can either reuse the refresh_token or rotate it (issue a new one)
110
+ return {"access_token": new_access_token, "refresh_token": refresh_token}
111
+
112
+
113
+
app/models.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sqlmodel import SQLModel, Field
2
+ from pydantic import BaseModel
3
+
4
+ # Database model
5
+ class User(SQLModel, table=True):
6
+ id: int | None = Field(default=None, primary_key=True)
7
+ username: str
8
+ hashed_password: str
9
+
10
+ # Pydantic models for API
11
+ class UserCreate(BaseModel):
12
+ username: str
13
+ password: str
14
+
15
+ class UserLogin(BaseModel):
16
+ username: str
17
+ password: str
18
+
19
+ class Token(BaseModel):
20
+ access_token: str
21
+ refresh_token: str
22
+ token_type: str = "bearer"
app/rag.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import os
3
+
4
+ os.environ['NUMPY_IMPORT'] = 'done' # This ensures numpy is loaded
5
+
6
+ from langchain_community.vectorstores import Chroma
7
+ from langchain_community.embeddings import HuggingFaceEmbeddings
8
+ from langchain_community.document_loaders import TextLoader
9
+ from langchain.document_loaders import PyPDFLoader
10
+ from langchain.text_splitter import CharacterTextSplitter
11
+ from app.config import CHROMA_DB_DIR
12
+ from langchain.chat_models import ChatOpenAI
13
+ from langchain.chains import RetrievalQA
14
+
15
+ from dotenv import load_dotenv
16
+ load_dotenv()
17
+ OPENAI_ROUTER_TOKEN=os.getenv("OPENROUTER")
18
+
19
+
20
+ # Embeddings
21
+ embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
22
+
23
+ # Chroma DB
24
+ db = Chroma(persist_directory=CHROMA_DB_DIR, embedding_function=embeddings)
25
+
26
+ from langchain.docstore.document import Document
27
+
28
+ def add_document(file_path: str, user_id: str):
29
+ # Load file
30
+ if file_path.lower().endswith(".pdf"):
31
+ loader = PyPDFLoader(file_path)
32
+ elif file_path.lower().endswith(".txt"):
33
+ loader = TextLoader(file_path, encoding="utf-8")
34
+ else:
35
+ raise RuntimeError(f"Unsupported file type: {file_path}")
36
+
37
+ documents = loader.load()
38
+
39
+ # Split into chunks
40
+ splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
41
+ docs = splitter.split_documents(documents)
42
+
43
+ # Add metadata directly to Document objects
44
+ docs_with_metadata = [
45
+ Document(page_content=d.page_content, metadata={"user_id": user_id, "filename": os.path.basename(file_path)})
46
+ for d in docs
47
+ ]
48
+
49
+ # Add to vector store
50
+ db.add_documents(docs_with_metadata)
51
+
52
+
53
+ def get_qa_chain(user_id: str):
54
+ """
55
+ Return a RetrievalQA pipeline for a specific user using OpenRouter's Phi-3 Medium Instruct model.
56
+
57
+ Args:
58
+ user_id (str): Unique identifier for the user.
59
+ """
60
+ # Initialize LLM with OpenRouter
61
+ llm = ChatOpenAI(
62
+ openai_api_key=OPENAI_ROUTER_TOKEN, # your OpenRouter API key
63
+ model="meta-llama/llama-4-scout:free", # free OpenRouter model
64
+ temperature=0,
65
+ max_tokens=512,
66
+ openai_api_base="https://openrouter.ai/api/v1" # OpenRouter endpoint
67
+ )
68
+ # Create retriever filtered by user_id
69
+ retriever = db.as_retriever(search_kwargs={"filter": {"user_id": user_id}})
70
+
71
+ # Build RetrievalQA pipeline
72
+ qa = RetrievalQA.from_chain_type(llm=llm, retriever=retriever, chain_type="stuff")
73
+ return qa
74
+
75
+
76
+
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ langchain
4
+ langchain-community
5
+ langchain-core
6
+ chromadb
7
+ pydantic
8
+ sentence-transformers
9
+ transformers
10
+ huggingface-hub
11
+ passlib[bcrypt] # for hashing passwords
12
+ python-jose[cryptography] # for JWT tokens
13
+ python-multipart # already there, for file uploads
14
+ sqlmodel
15
+ pypdf
16
+ numpy