"""FastAPI service serving a joblib model downloaded from the Hugging Face Hub.

Environment variables:
    API_KEY          -- shared secret expected in the ``X-API-Key`` header.
                        When unset, every /predict request is rejected.
    HF_MODEL_REPO    -- Hub repo id holding the model and preprocessing artifacts.
    PROM_PUSHGATEWAY -- optional Prometheus Pushgateway base URL; when set,
                        metrics are pushed after every successful prediction.
"""

import logging
import os
import secrets
import time

import joblib
import pandas as pd
import requests
from fastapi import FastAPI, Header, HTTPException, Response
from huggingface_hub import hf_hub_download
from prometheus_client import (
    CONTENT_TYPE_LATEST,
    Counter,
    Gauge,
    Histogram,
    generate_latest,
)
from pydantic import BaseModel  # noqa: F401  (currently unused)

logger = logging.getLogger(__name__)

app = FastAPI()

API_KEY = os.getenv("API_KEY")
HF_MODEL_REPO = os.getenv("HF_MODEL_REPO")
PROM_PUSHGATEWAY = os.getenv("PROM_PUSHGATEWAY")

# Upper bound on the Pushgateway call so a slow or unreachable gateway
# cannot stall /predict indefinitely.
PUSH_TIMEOUT_SECONDS = 2.0

# Prometheus metrics
REQS = Counter("pred_requests_total", "Total prediction requests")
LAT = Histogram("pred_request_latency_seconds", "Request latency")
LATEST = Gauge("latest_prediction", "Last predicted value")

if not API_KEY:
    logger.warning("API_KEY is not set; every /predict request will be rejected.")


def _load_artifacts():
    """Download and deserialize the model and preprocessing artifacts.

    Returns:
        A ``(model, encoders, scaler, feature_columns)`` tuple.

    Raises:
        RuntimeError: if HF_MODEL_REPO is not configured.
        Any exception raised by ``hf_hub_download`` or ``joblib.load``.
    """
    if not HF_MODEL_REPO:
        raise RuntimeError("HF_MODEL_REPO is not set")

    def fetch(filename):
        # SECURITY: joblib.load unpickles the file, which can execute arbitrary
        # code -- HF_MODEL_REPO must point at a trusted repository.
        return joblib.load(hf_hub_download(repo_id=HF_MODEL_REPO, filename=filename))

    return (
        fetch("best_model.joblib"),
        fetch("models/encoders.joblib"),
        fetch("models/scaler.joblib"),
        fetch("models/feature_columns.joblib"),
    )


# Bind safe defaults first so health() and predict() can reference these names
# even when loading fails.
model = encoders = scaler = None
feature_columns = []
try:
    model, encoders, scaler, feature_columns = _load_artifacts()
    # Normalise to a plain list in case it was saved as a pandas Index / numpy
    # array, so it JSON-serialises cleanly in responses.
    feature_columns = list(feature_columns)
    loaded = True
except Exception:
    # Top-level boundary: keep the service up (health reports model_loaded=False).
    logger.exception("Model load error")
    loaded = False


def _check_api_key(x_api_key):
    """Raise HTTP 401 unless *x_api_key* matches API_KEY.

    Fails closed: if API_KEY is unset, no key is accepted. Comparison is
    constant-time to avoid leaking the key through response timing.
    """
    if (
        not API_KEY
        or x_api_key is None
        or not secrets.compare_digest(x_api_key.encode("utf-8"), API_KEY.encode("utf-8"))
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")


def _push_metrics():
    """Best-effort push of the default registry to PROM_PUSHGATEWAY, if configured.

    Failures are logged and never propagated, so metrics problems cannot
    break a prediction response.
    """
    if not PROM_PUSHGATEWAY:
        return
    try:
        resp = requests.post(
            f"{PROM_PUSHGATEWAY.rstrip('/')}/metrics/job/loan_model",
            data=generate_latest(),
            headers={"Content-Type": CONTENT_TYPE_LATEST},
            timeout=PUSH_TIMEOUT_SECONDS,
        )
        resp.raise_for_status()
    except requests.RequestException:
        logger.warning("Pushgateway push failed", exc_info=True)


@app.get("/")
def health():
    """Report liveness, whether the model loaded, and the expected feature columns."""
    return {"status": "ok", "model_loaded": loaded, "features": feature_columns}


@app.post("/predict")
def predict(payload: dict, x_api_key: str = Header(None)):
    """Score a single record.

    Args:
        payload: mapping of feature name -> value for one record.
        x_api_key: value of the ``X-API-Key`` header.

    Returns:
        ``{"prediction": int, "used_features": [...]}``.

    Raises:
        HTTPException 401: missing/invalid API key (or API_KEY not configured).
        HTTPException 503: artifacts failed to load at startup.
        HTTPException 422: payload values could not be scaled (e.g. non-numeric).
    """
    _check_api_key(x_api_key)
    if not loaded:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Time preprocessing + inference so the histogram matches its
    # "Request latency" description; perf_counter is monotonic.
    start = time.perf_counter()

    # Convert input dict to DataFrame and align to the training columns:
    # unknown keys are dropped and missing features are filled with 0.
    # NOTE(review): `encoders` is loaded but never applied, so categorical
    # inputs are passed to the scaler unencoded -- confirm how the training
    # pipeline expects them to be encoded.
    df = pd.DataFrame([payload]).reindex(columns=feature_columns, fill_value=0)

    try:
        df_scaled = scaler.transform(df)
    except (ValueError, TypeError) as err:
        # Bad client input (e.g. strings in numeric columns) is a 422, not a 500.
        raise HTTPException(status_code=422, detail=f"Invalid input: {err}") from err

    pred = model.predict(df_scaled)[0]
    LAT.observe(time.perf_counter() - start)
    REQS.inc()
    LATEST.set(pred)

    _push_metrics()

    return {"prediction": int(pred), "used_features": feature_columns}


@app.get("/metrics")
def metrics():
    """Expose metrics in the Prometheus text exposition format."""
    # Returning raw bytes would be JSON-encoded by FastAPI; wrap them in a
    # Response with the Prometheus content type instead.
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)