"""FastAPI service that predicts water-pollution risk from multimodal input.

At import time the module loads a Hugging Face text encoder, an open_clip
image encoder, and a pickled classifier. The ``/predict`` route embeds the
submitted text and image, appends six numeric features, and passes the
concatenated vector to the classifier's ``predict_proba``.
"""
import io
import pickle

import numpy as np
import open_clip
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoModel, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

# Text encoder (IndoBERT), switched to eval mode for inference.
TEXT_MODEL_NAME = "indobenchmark/indobert-large-p1"
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME).to(device)
text_model.eval()

# Image encoder (EVA-CLIP via open_clip); the second return value is unused.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "EVA01-g-14-plus",
    pretrained="merged2b_s11b_b114k",
)
clip_model.to(device)
clip_model.eval()

# SECURITY: unpickling executes arbitrary code embedded in the file. Only ship
# this service with an "xgb_full.pkl" artifact that comes from a trusted source.
with open("xgb_full.pkl", "rb") as f:
    xgb_model = pickle.load(f)


def preprocess_text(text: str) -> str:
    """Return *text* with leading and trailing whitespace removed.

    TODO: further preprocessing steps are planned to be added here later.
    """
    return text.strip()


def _load_rgb_image(img_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image, or raise HTTP 400."""
    try:
        return Image.open(io.BytesIO(img_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # Empty, truncated, oversized, or non-image uploads are a client
        # error; report them as such instead of surfacing a 500.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from exc


def _embed_text(cleaned_text: str) -> np.ndarray:
    """Return the text model's first-token embedding as a NumPy array."""
    text_inputs = tokenizer(
        cleaned_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}
    with torch.no_grad():
        # Position 0 of the last hidden state is the CLS token; the original
        # author notes the result has shape [1, 1024].
        text_emb = text_model(**text_inputs).last_hidden_state[:, 0, :]
    return text_emb.cpu().numpy()


def _embed_image(pil_img: Image.Image) -> np.ndarray:
    """Return the CLIP image embedding for *pil_img* as a NumPy array."""
    # unsqueeze(0) adds the batch dimension the encoder is called with.
    img_tensor = clip_preprocess(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        # Embedding width depends on the CLIP model (the original author
        # suggests [1, 1024]) -- TODO confirm.
        img_emb = clip_model.encode_image(img_tensor)
    return img_emb.cpu().numpy()


app = FastAPI(
    title="Multimodal Water Pollution Risk API",
    description=(
        "Input: text + image + geospatial + time\n"
        "Model: IndoBERT + EVA-CLIP (HF Hub) + XGBoost (xgb.pkl)\n"
    ),
    version="1.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def root():
    """Return a static status payload describing the service."""
    return {
        "status": "OK",
        "message": "Multimodal Water Pollution Risk API is running.",
        "info": "Use POST /predict with text, image, and features.",
    }


@app.post("/predict")
async def predict(
    text: str = Form(...),
    longitude: float = Form(...),
    latitude: float = Form(...),
    location_cluster: int = Form(...),
    hour: int = Form(...),
    dayofweek: int = Form(...),
    month: int = Form(...),
    image: UploadFile = File(...),
):
    """Classify a report as "WASPADA" or "KRITIS" from text, image, and features.

    Args:
        text: Free-text report; whitespace-stripped before tokenization.
        longitude: Numeric feature passed to the classifier.
        latitude: Numeric feature passed to the classifier.
        location_cluster: Numeric feature passed to the classifier.
        hour: Numeric feature passed to the classifier.
        dayofweek: Numeric feature passed to the classifier.
        month: Numeric feature passed to the classifier.
        image: Uploaded image file; decoded and converted to RGB.

    Returns:
        A dict with the predicted label under "prediction" and the two class
        probabilities under "probabilities".

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    # NOTE(review): the model forward passes below are synchronous calls made
    # inside this coroutine, so the event loop is blocked while they run.

    # 1. Decode the image first so an invalid upload is rejected before any
    #    model inference is spent on the request.
    pil_img = _load_rgb_image(await image.read())

    # 2. Preprocess and encode the text.
    text_emb = _embed_text(preprocess_text(text))

    # 3. Encode the image.
    img_emb = _embed_image(pil_img)

    # 4. Additional numeric features; the order must stay the same as the one
    #    used at training time.
    add_feats = np.array(
        [[longitude, latitude, location_cluster, hour, dayofweek, month]],
        dtype=np.float32,
    )

    # 5. Concatenate along the feature axis: [image_emb, text_emb, add_feats],
    #    giving shape [1, dim_image + dim_text + 6].
    fused = np.concatenate([img_emb, text_emb, add_feats], axis=1)

    # 6. Classifier prediction; indexing below assumes exactly two classes,
    #    with index 0 = WASPADA and index 1 = KRITIS.
    proba = xgb_model.predict_proba(fused)[0]
    pred_idx = int(np.argmax(proba))
    label = "KRITIS" if pred_idx == 1 else "WASPADA"

    return {
        "prediction": label,
        "probabilities": {
            "WASPADA": float(proba[0]),
            "KRITIS": float(proba[1]),
        },
    }


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)