# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
import io
import pickle
import re
import threading

import numpy as np
import open_clip
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from transformers import AutoModel, AutoTokenizer
# Run every model on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# IndoBERT text encoder and its tokenizer, fetched from the Hugging Face Hub.
TEXT_MODEL_NAME = "indobenchmark/indobert-large-p1"
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
# .eval() switches off dropout for inference and returns the module itself,
# so the chained call still binds text_model to the moved model.
text_model = AutoModel.from_pretrained(TEXT_MODEL_NAME).to(device).eval()
# EVA-CLIP image encoder via open_clip. Only the model and the inference-time
# transform are kept; the middle return value (training transform) is unused.
clip_model, _, clip_preprocess = open_clip.create_model_and_transforms(
    "EVA01-g-14-plus",
    pretrained="merged2b_s11b_b114k",
)
# nn.Module.to() and .eval() both return the same module instance.
clip_model = clip_model.to(device).eval()
# Trained XGBoost classifier that consumes the fused feature vector.
# NOTE: unpickling can execute arbitrary code -- only ever load the trusted
# artifact shipped alongside this app, never a user-supplied file.
with open("xgb_full.pkl", mode="rb") as model_file:
    xgb_model = pickle.load(model_file)
# Ordered (pattern, replacement) cleaning passes, applied after lower-casing.
_CLEANING_STEPS = (
    (re.compile(r'http\S+|www\.\S+'), ''),  # drop URLs
    (re.compile(r'@\w+|#\w+'), ''),  # drop whole @mention / #hashtag tokens
    (re.compile(r'[^a-z\s]'), ' '),  # non ASCII-letter, non-whitespace -> space
    (re.compile(r'\s+'), ' '),  # collapse whitespace runs to one space
)


def preprocess_text(text: str) -> str:
    """Normalise free text before it is tokenised for the text encoder.

    The input is coerced with ``str`` and lower-cased. URLs, @mentions and
    #hashtags are then removed. Every character that is neither an ASCII
    letter nor whitespace becomes a space, so digits, punctuation and accented
    letters disappear. Finally, whitespace is collapsed to single spaces with
    no leading or trailing blanks.

    NOTE(review): this presumably mirrors the cleaning applied to the training
    data. Changing it would skew inference-time features, so keep the two in
    sync.
    """
    cleaned = str(text).lower()
    for pattern, replacement in _CLEANING_STEPS:
        cleaned = pattern.sub(replacement, cleaned)
    # The collapse pass leaves at most one space at each end.
    return cleaned.strip()
# FastAPI application. The title, description and version appear in the
# generated OpenAPI docs.
app = FastAPI(
    title="Multimodal Water Pollution Risk API",
    description=(
        "Input: text + image + geospatial + time\n"
        # Name the classifier artifact actually loaded at startup.
        "Model: IndoBERT + EVA-CLIP (HF Hub) + XGBoost (xgb_full.pkl)\n"
    ),
    version="1.0.0",
)
# Fully open CORS: browser front-ends on any origin may call the API with any
# method and any request headers.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.get("/")
def root():
    """Landing / liveness endpoint: report status and point clients at /predict."""
    return {
        "status": "OK",
        "message": "Multimodal Water Pollution Risk API is running.",
        "info": "Use POST /predict with text, image, and features.",
    }
# Serialises inference across worker threads. The tokenizer and models are
# shared module-level singletons, and several large forward passes running at
# once would also compete for (GPU) memory.
_INFERENCE_LOCK = threading.Lock()


def _decode_image(img_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 when the bytes are not a readable image, so a bad
            upload is reported as a client error rather than a 500.
    """
    try:
        with Image.open(io.BytesIO(img_bytes)) as raw:
            # convert() loads pixel data into a new image that outlives `raw`.
            return raw.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # UnidentifiedImageError (unknown format) and truncated-data errors
        # are both OSError subclasses.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image."
        ) from err


def _embed_text(text: str) -> np.ndarray:
    """Clean *text* and return its [CLS] embedding as an array of shape (1, hidden_size)."""
    cleaned_text = preprocess_text(text)
    # Fixed 128-token padding/truncation; presumably matches how the training
    # features were produced -- keep in sync.
    inputs = tokenizer(
        cleaned_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    with torch.no_grad():
        last_hidden = text_model(**inputs).last_hidden_state
    # Keep only position 0, the [CLS] token.
    return last_hidden[:, 0, :].cpu().numpy()


def _embed_image(pil_img: Image.Image) -> np.ndarray:
    """Return the EVA-CLIP embedding of *pil_img* as an array with a leading batch dim of 1."""
    img_tensor = clip_preprocess(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        img_emb = clip_model.encode_image(img_tensor)
    return img_emb.cpu().numpy()


def _predict_proba(text: str, img_bytes: bytes, add_feats: np.ndarray) -> np.ndarray:
    """Run the blocking multimodal pipeline and return class probabilities for one sample.

    This function is intended to run in a worker thread.
    """
    pil_img = _decode_image(img_bytes)  # reject bad uploads before taking the lock
    with _INFERENCE_LOCK:
        text_emb = _embed_text(text)
        img_emb = _embed_image(pil_img)
        # Early fusion. The column order [image_emb, text_emb, add_feats] must
        # match the feature layout the XGBoost model was trained on.
        fused = np.concatenate([img_emb, text_emb, add_feats], axis=1)
        return xgb_model.predict_proba(fused)[0]


@app.post("/predict")
async def predict(
    text: str = Form(...),
    longitude: float = Form(...),
    latitude: float = Form(...),
    location_cluster: int = Form(...),
    hour: int = Form(...),
    dayofweek: int = Form(...),
    month: int = Form(...),
    image: UploadFile = File(...),
):
    """Classify one report as "KRITIS" or "WASPADA" from text, image and context features.

    Returns:
        ``{"prediction": label, "probabilities": {"WASPADA": p0, "KRITIS": p1}}``,
        where p0 and p1 are columns 0 and 1 of ``xgb_model.predict_proba``.

    Raises:
        HTTPException: 400 if ``image`` cannot be decoded.
    """
    img_bytes = await image.read()
    # Numeric context features, in the order the classifier expects.
    add_feats = np.array(
        [[longitude, latitude, location_cluster, hour, dayofweek, month]],
        dtype=np.float32,
    )
    # Inference is blocking CPU/GPU work. Running it inside this coroutine would
    # stall the event loop for every other request, so hand it to the threadpool.
    proba = await run_in_threadpool(_predict_proba, text, img_bytes, add_feats)
    pred_idx = int(np.argmax(proba))
    label = "KRITIS" if pred_idx == 1 else "WASPADA"
    return {
        "prediction": label,
        "probabilities": {
            "WASPADA": float(proba[0]),
            "KRITIS": float(proba[1]),
        },
    }
if __name__ == "__main__":
    # Direct-run entry point: serve on all network interfaces.
    # NOTE(review): port 7860 presumably matches the Hugging Face Spaces default -- confirm.
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")