mr-kush's picture
Refactor Dockerfile and implement UrgencyPredictor class with response schema for urgency classification API
63c461f
from fastapi import FastAPI, HTTPException
from typing import Union, List
from contextlib import asynccontextmanager
import os
import uvicorn
from predict_urgency_model import UrgencyPredictor
from response_schema import TextInput, UrgencyClassificationOutput
from huggingface_hub import HfApi
# Model repository setup
model_repo = os.getenv("MODEL_REPO", "sambodhan/sambodhan_urgency_classifier")
# Hugging Face API for version info
hf_api = HfApi()
# Startup and shutdown
@asynccontextmanager
async def lifespan(app: FastAPI):
global predictor
predictor = UrgencyPredictor(model_repo=model_repo)
yield
# FastAPI app
app = FastAPI(
title="Sambodhan Urgency Classifier API",
description="AI model that classifies citizen grievances by urgency with confidence scores.",
version="1.0.0",
lifespan=lifespan
)
# Routes
@app.post("/predict_urgency", response_model=Union[UrgencyClassificationOutput, List[UrgencyClassificationOutput]])
def predict_urgency(input_data: TextInput):
try:
prediction = predictor.predict(input_data.text)
return prediction
except Exception as e:
raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
@app.get("/")
def root():
latest_tag = None
try:
latest_tag = hf_api.list_repo_refs(repo_id=model_repo, repo_type="model").tags[0].name
except Exception:
latest_tag = "unknown"
return {
"message": "Sambodhan Urgency Classifier API is running.",
"status": "Active" if predictor else "Inactive",
"model_version": latest_tag
}
# For local testing (optional)
# if __name__ == "__main__":
# port = int(os.getenv("PORT", 7860))
# uvicorn.run("app:app", host="0.0.0.0", port=port)