mr-kush committed on
Commit 63c461f · 1 Parent(s): 9a4dd71

Refactor Dockerfile and implement UrgencyPredictor class with response schema for urgency classification API

Files changed (4)
  1. Dockerfile +11 -15
  2. app.py +58 -43
  3. predict_urgency_model.py +62 -0
  4. response_schema.py +49 -0
Dockerfile CHANGED
@@ -1,25 +1,21 @@
-
- # Lightweight Python base
  FROM python:3.12-slim

  WORKDIR /app
  COPY . /app

- # Create writable cache folder
- RUN mkdir -p /app/model_cache && chmod -R 777 /app/model_cache
-
- # Environment variables for Hugging Face cache
- ENV HF_HOME=/app/model_cache
- ENV TRANSFORMERS_CACHE=/app/model_cache
- ENV HF_DATASETS_CACHE=/app/model_cache
- ENV HF_METRICS_CACHE=/app/model_cache
+ # use dedicated cache dir
+ ENV HF_HOME=/app/hf_cache
+ ENV HF_DATASETS_CACHE=/app/hf_cache
+ ENV HF_METRICS_CACHE=/app/hf_cache
+ ENV MODEL_REPO=sambodhan/sambodhan_urgency_classifier

- # Install dependencies
- RUN apt-get update && apt-get install -y git
+ RUN apt-get update && apt-get install -y git curl && rm -rf /var/lib/apt/lists/*
  RUN pip install --upgrade pip
  RUN pip install --no-cache-dir -r requirements.txt

- EXPOSE 7860
-
- # Run FastAPI server
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
+ # make sure cache dir is writable
+ RUN mkdir -p /app/hf_cache && chmod -R 777 /app/hf_cache
+
+ EXPOSE 7860
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT:-7860}"]
app.py CHANGED
@@ -1,53 +1,68 @@
-
+ from fastapi import FastAPI, HTTPException
+ from typing import Union, List
+ from contextlib import asynccontextmanager
  import os
- from fastapi import FastAPI
- from pydantic import BaseModel
- from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
- import torch
-
- # ✅ Create writable model cache directory
- os.makedirs("/app/model_cache", exist_ok=True)
- os.environ["HF_HOME"] = "/app/model_cache"
- os.environ["TRANSFORMERS_CACHE"] = "/app/model_cache"
- os.environ["HF_DATASETS_CACHE"] = "/app/model_cache"
- os.environ["HF_METRICS_CACHE"] = "/app/model_cache"
-
- MODEL_REPO = "sambodhan/sambodhan_urgency_classifier"
- device = 0 if torch.cuda.is_available() else -1
-
- # Load model and tokenizer safely
- tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, cache_dir="/app/model_cache")
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_REPO, cache_dir="/app/model_cache")
-
- classifier = pipeline(
-     "text-classification",
-     model=model,
-     tokenizer=tokenizer,
-     device=device,
-     return_all_scores=True
+ import uvicorn
+
+ from predict_urgency_model import UrgencyPredictor
+ from response_schema import TextInput, UrgencyClassificationOutput
+ from huggingface_hub import HfApi
+
+ # Model repository setup
+ model_repo = os.getenv("MODEL_REPO", "sambodhan/sambodhan_urgency_classifier")
+
+ # Hugging Face API for version info
+ hf_api = HfApi()
+
+ # Predictor is created at startup; None until the lifespan hook runs
+ predictor = None
+
+ # Startup and shutdown
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     global predictor
+     predictor = UrgencyPredictor(model_repo=model_repo)
+     yield
+
+ # FastAPI app
+ app = FastAPI(
+     title="Sambodhan Urgency Classifier API",
+     description="AI model that classifies citizen grievances by urgency with confidence scores.",
+     version="1.0.0",
+     lifespan=lifespan
  )

- LABELS = ["NORMAL", "URGENT", "HIGHLY URGENT"]
-
- app = FastAPI(title="Sambodhan Urgency Classifier API", version="2.0.3")
-
- class TextInput(BaseModel):
-     text: str
-
- @app.post("/predict_urgency")
- async def predict(input_data: TextInput):
-     text = input_data.text.strip()
-     if not text:
-         return {"error": "Empty input"}
+ # Routes
+ @app.post("/predict_urgency", response_model=Union[UrgencyClassificationOutput, List[UrgencyClassificationOutput]])
+ def predict_urgency(input_data: TextInput):
+     try:
+         prediction = predictor.predict(input_data.text)
+         return prediction
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")

-     results = classifier(text)[0]
-     top = max(results, key=lambda x: x["score"])
+ @app.get("/")
+ def root():
+     try:
+         latest_tag = hf_api.list_repo_refs(repo_id=model_repo, repo_type="model").tags[0].name
+     except Exception:
+         latest_tag = "unknown"
+
      return {
-         "label": top["label"],
-         "confidence": round(top["score"], 4),
-         "scores": {r["label"]: round(r["score"], 4) for r in results},
+         "message": "Sambodhan Urgency Classifier API is running.",
+         "status": "Active" if predictor else "Inactive",
+         "model_version": latest_tag
      }

- @app.get("/")
- def root():
-     return {"message": "✅ Sambodhan Urgency Classifier API running successfully!"}
+
+ # For local testing (optional)
+ # if __name__ == "__main__":
+ #     port = int(os.getenv("PORT", 7860))
+ #     uvicorn.run("app:app", host="0.0.0.0", port=port)
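The endpoint accepts the TextInput schema and returns the predictor's output as a single object or a list. A minimal client sketch, assuming the server runs locally on port 7860 and the requests package is installed; the sample grievance text and printed labels are illustrative, not part of this commit:

import requests

resp = requests.post(
    "http://localhost:7860/predict_urgency",
    json={"text": "Main road is flooded and vehicles are stranded!"},
)
resp.raise_for_status()
result = resp.json()
print(result["label"], result["confidence"])  # top label, e.g. "URGENT" 0.97
print(result["scores"])                       # confidence for every label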
predict_urgency_model.py ADDED
@@ -0,0 +1,62 @@
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
+ import torch
+ import os
+
+ class UrgencyPredictor:
+     def __init__(self, model_repo="sambodhan/sambodhan_urgency_classifier",
+                  cache_dir="/app/hf_cache"):
+         """Load model and tokenizer once at startup."""
+         self.model_repo = model_repo
+         self.cache_dir = cache_dir
+
+         # Ensure cache folder exists
+         os.makedirs(self.cache_dir, exist_ok=True)
+
+         # Device selection: GPU (device 0) if available, else CPU
+         self.device = 0 if torch.cuda.is_available() else -1
+
+         print("Loading tokenizer and model...")
+         # Load tokenizer and model; force_download=True re-fetches the weights
+         # on every startup so the newest revision is always served
+         self.tokenizer = AutoTokenizer.from_pretrained(self.model_repo, cache_dir=self.cache_dir, force_download=True)
+         self.model = AutoModelForSequenceClassification.from_pretrained(self.model_repo, cache_dir=self.cache_dir, force_download=True)
+
+         # Create classification pipeline
+         self.classifier = pipeline(
+             "text-classification",
+             model=self.model,
+             tokenizer=self.tokenizer,
+             device=self.device,
+             return_all_scores=True
+         )
+         print("Model and tokenizer loaded successfully.")
+
+     def predict(self, texts):
+         """Predict urgency labels with scores for a single text or a batch."""
+         if isinstance(texts, str):
+             texts = [texts]
+
+         results = self.classifier(texts)
+         formatted_results = []
+
+         for preds in results:
+             # Sort by descending confidence
+             preds = sorted(preds, key=lambda x: x["score"], reverse=True)
+             top_pred = preds[0]
+             label = top_pred["label"]
+             confidence = round(top_pred["score"], 4)
+             scores_dict = {p["label"]: round(p["score"], 4) for p in preds}
+
+             formatted_results.append({
+                 "label": label,
+                 "confidence": confidence,
+                 "scores": scores_dict
+             })
+
+         # Return single dict if only one input
+         return formatted_results[0] if len(formatted_results) == 1 else formatted_results
+
+     @staticmethod
+     def load_model():
+         """Helper to preload the model during Docker build."""
+         _ = UrgencyPredictor()
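For reference, a usage sketch of UrgencyPredictor outside the API, based on the class above. The local cache path and sample texts are assumptions for illustration; the model repo must be reachable:

from predict_urgency_model import UrgencyPredictor

# Override the container cache path for a local run
predictor = UrgencyPredictor(cache_dir="./hf_cache")

# Single input -> one dict
single = predictor.predict("There is a gas leak near the school.")
print(single["label"], single["confidence"])

# Batch input -> list of dicts, one per text
batch = predictor.predict([
    "Garbage has not been collected this week.",
    "A bridge pillar has cracked and may collapse.",
])
print([r["label"] for r in batch])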
response_schema.py ADDED
@@ -0,0 +1,49 @@
+ from typing import Dict
+ from pydantic import BaseModel, Field, field_validator, model_validator
+ import re
+
+ # ---------------------------
+ # Text cleaning function
+ # ---------------------------
+ def clean_text(text: str) -> str:
+     """Clean grievance text by removing URLs, HTML tags, and extra whitespace."""
+     text = re.sub(r'https?://\S+|www\.\S+', '', text)  # Remove URLs
+     text = re.sub(r'<.*?>', '', text)                  # Remove HTML tags
+     text = re.sub(r'\n', ' ', text)                    # Replace newlines with spaces
+     text = re.sub(r'\s+', ' ', text).strip()           # Collapse repeated whitespace
+     return text
+
+ # ---------------------------
+ # Request schema
+ # ---------------------------
+ class TextInput(BaseModel):
+     text: str = Field(..., description="Grievance text to classify urgency")
+
+     @field_validator("text")
+     def validate_non_empty(cls, value: str) -> str:
+         value = value.strip()
+         if not value:
+             raise ValueError("Input text cannot be empty")
+         return value
+
+     # An "after" model validator receives the model instance, not the class
+     @model_validator(mode="after")
+     def clean_text_after(self):
+         self.text = clean_text(self.text)
+         return self
+
+     model_config = {
+         "json_schema_extra": {
+             "examples": [
+                 {"text": "The water supply has been cut off for 3 days."},
+                 {"text": "Streetlight on my street is not working, please fix urgently."}
+             ]
+         }
+     }
+
+ # ---------------------------
+ # Response schema
+ # ---------------------------
+ class UrgencyClassificationOutput(BaseModel):
+     label: str = Field(..., description="Top predicted urgency label")
+     confidence: float = Field(..., ge=0, le=1, description="Confidence score for top label")
+     scores: Dict[str, float] = Field(..., description="All label confidence scores")
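A quick validation sketch of the request schema, assuming Pydantic v2 semantics: the field validator rejects empty input first, then the "after" model validator runs clean_text on the accepted value. The sample strings are illustrative only:

from pydantic import ValidationError
from response_schema import TextInput

inp = TextInput(text="Pipe burst!\nSee <b>photo</b> at https://example.com/p.jpg")
print(inp.text)  # "Pipe burst! See photo at" -- URL, tags, newline stripped

try:
    TextInput(text="   ")  # whitespace-only input
except ValidationError:
    print("rejected empty input")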