|
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline |
|
|
import torch |
|
|
import os |
|
|
|
|
|
class DepartmentPredictor:
    """Classify text into departments with a Hugging Face sequence-classification model.

    The tokenizer, model and ``text-classification`` pipeline are created once
    in ``__init__`` and reused by every ``predict`` call.
    """

    def __init__(self, model_repo="mr-kush/sambodhan-department-classification-model",
                 cache_dir="/app/hf_cache", *, force_download=False):
        """Load model and tokenizer once at startup.

        Args:
            model_repo: Hugging Face Hub repo id (or local path) of the model.
            cache_dir: Directory used as the Hugging Face download cache. If
                ``None``, falls back to ``$HF_HOME`` or ``./hf_cache``. The
                directory is created if it does not exist.
            force_download: Re-download the files even if they are already in
                ``cache_dir``. Defaults to ``False`` so a cache warmed by
                ``load_model()`` (e.g. during a Docker build) is reused instead
                of being fetched again on every startup.
        """
        self.model_repo = model_repo
        # Resolve the None fallback *before* touching the filesystem:
        # os.makedirs(None) would raise TypeError.
        self.cache_dir = self._resolve_cache_dir(cache_dir)

        # Pipeline device convention: GPU index 0 if available, -1 for CPU.
        self.device = 0 if torch.cuda.is_available() else -1

        print(" Loading tokenizer and model...")

        self.tokenizer = AutoTokenizer.from_pretrained(
            self.model_repo, cache_dir=self.cache_dir, force_download=force_download
        )
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.model_repo, cache_dir=self.cache_dir, force_download=force_download
        )

        # top_k=None asks the pipeline for the score of every label, which
        # predict() needs to build the full "scores" mapping.
        self.classifier = pipeline(
            "text-classification",
            model=self.model,
            tokenizer=self.tokenizer,
            device=self.device,
            top_k=None,
        )

        print(" Model and tokenizer loaded successfully.")

    @staticmethod
    def _resolve_cache_dir(cache_dir):
        """Return a usable cache directory, creating it if necessary.

        ``None`` falls back to ``$HF_HOME``, or ``./hf_cache`` when that
        environment variable is unset.
        """
        if cache_dir is None:
            cache_dir = os.getenv("HF_HOME", "./hf_cache")
        os.makedirs(cache_dir, exist_ok=True)
        return cache_dir

    def predict(self, texts):
        """Predict departments with scores for a single text or a batch.

        Args:
            texts: A single string or a sequence of strings.

        Returns:
            For each input text a dict with keys ``"label"`` (top label),
            ``"confidence"`` (its score) and ``"scores"`` (every label mapped to
            its score, ordered from highest to lowest). Scores are rounded to 4
            decimals. When exactly one text is given (as a str or a one-element
            list) the bare dict is returned; otherwise a list of dicts. An
            empty list/tuple yields ``[]``.
        """
        if isinstance(texts, str):
            texts = [texts]
        if isinstance(texts, (list, tuple)) and not texts:
            # Nothing to classify; skip running the pipeline on an empty batch.
            return []

        # truncation=True: inputs longer than the model's maximum length are
        # cut off instead of making the pipeline raise.
        results = self.classifier(texts, truncation=True)

        formatted_results = []
        for preds in results:
            ranked = sorted(preds, key=lambda p: p["score"], reverse=True)
            top_pred = ranked[0]
            formatted_results.append({
                "label": top_pred["label"],
                "confidence": round(top_pred["score"], 4),
                "scores": {p["label"]: round(p["score"], 4) for p in ranked},
            })

        # NOTE: a one-element batch is unwrapped to a bare dict; kept as-is for
        # compatibility with existing callers.
        return formatted_results[0] if len(formatted_results) == 1 else formatted_results

    @staticmethod
    def load_model():
        """Helper to preload the model during Docker build."""
        _ = DepartmentPredictor()
|
|
|