|
|
from fastapi import FastAPI, HTTPException |
|
|
from typing import Union, List |
|
|
from contextlib import asynccontextmanager |
|
|
import os |
|
|
import uvicorn |
|
|
|
|
|
from predict_urgency_model import UrgencyPredictor |
|
|
from response_schema import TextInput, UrgencyClassificationOutput |
|
|
from huggingface_hub import HfApi |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
model_repo = os.getenv("MODEL_REPO", "sambodhan/sambodhan_urgency_classifier") |
|
|
|
|
|
|
|
|
hf_api = HfApi() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@asynccontextmanager |
|
|
async def lifespan(app: FastAPI): |
|
|
global predictor |
|
|
predictor = UrgencyPredictor(model_repo=model_repo) |
|
|
yield |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI( |
|
|
title="Sambodhan Urgency Classifier API", |
|
|
description="AI model that classifies citizen grievances by urgency with confidence scores.", |
|
|
version="1.0.0", |
|
|
lifespan=lifespan |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post("/predict_urgency", response_model=Union[UrgencyClassificationOutput, List[UrgencyClassificationOutput]]) |
|
|
def predict_urgency(input_data: TextInput): |
|
|
try: |
|
|
prediction = predictor.predict(input_data.text) |
|
|
return prediction |
|
|
except Exception as e: |
|
|
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}") |
|
|
|
|
|
@app.get("/") |
|
|
def root(): |
|
|
latest_tag = None |
|
|
try: |
|
|
latest_tag = hf_api.list_repo_refs(repo_id=model_repo, repo_type="model").tags[0].name |
|
|
except Exception: |
|
|
latest_tag = "unknown" |
|
|
|
|
|
return { |
|
|
"message": "Sambodhan Urgency Classifier API is running.", |
|
|
"status": "Active" if predictor else "Inactive", |
|
|
"model_version": latest_tag |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|