servejj / app.py
froidhj's picture
Update app.py
b037393 verified
# app.py — TrashTrack Turbo (compatível com ESP32 + multipart/form-data)
import os
# Pin OpenMP to a single thread (avoids libgomp errors and CPU thread
# over-subscription) and disable HF tokenizers parallelism. These are set here,
# before the torch/transformers imports below, so the libraries see them when
# they initialise.
os.environ.update(OMP_NUM_THREADS="1", TOKENIZERS_PARALLELISM="false")
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from PIL import Image, ImageOps
import io, time, torch
import numpy as np
from transformers import AutoProcessor, AutoModel
# ==============================================
# ⚙️ Configurações
# ==============================================
MODEL_ID = "google/siglip-so400m-patch14-384"  # highest-accuracy model variant

device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Prefer the fast (torchvision-backed) image processor, but fall back to the
# slow path if torchvision cannot be imported for any reason (missing, or
# installed but broken against the current torch build).
USE_FAST = True
try:
    import torchvision  # noqa: F401
except Exception:
    USE_FAST = False

print(f"🚀 Carregando modelo {MODEL_ID} (use_fast={USE_FAST})...")
processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=USE_FAST)

# Half precision is requested only on GPU; on CPU dtype=None is passed through.
model = AutoModel.from_pretrained(
    MODEL_ID,
    dtype=None if device != "cuda" else torch.float16,
)
model = model.to(device).eval()
print("✅ Modelo carregado com sucesso.")
# ==============================================
# 📋 Classes (PT + EN) — **SEM VIDRO**
# ==============================================
# Prompt terms for each waste class, mixing Portuguese and English synonyms.
# Glass is deliberately not a class. Order matters: `texts` below is flattened
# in exactly this class/term order, and the /predict handler slices the
# resulting probabilities back into classes by group length.
labels = {
    "plastico": [
        "plástico",
        "garrafa PET",
        "tampinha plástica",
        "sacola plástica",
        "plastic bottle",
    ],
    "papel": [
        "papel",
        "folha",
        "envelope de papel",
        "paper sheet",
        "paper wrapper",
    ],
    "metal": [
        "lata",
        "alumínio",
        "tampinha metálica",
        "metal cap",
        "can",
    ],
}


def _promptize(term: str) -> str:
    """Embed *term* in the fixed English prompt template used for zero-shot scoring."""
    prompt = f"centered {term} on a white background; ignore the background; classify only the object"
    return prompt


# One prompt per term, all classes concatenated in `labels` order.
texts = list(map(_promptize, (term for group in labels.values() for term in group)))
# ==============================================
# 🔧 Util — recorte do foreground ignorando fundo branco
# ==============================================
def crop_foreground_ignore_white(
    pil: Image.Image,
    *,
    white_threshold: int = 230,
    min_foreground_pixels: int = 500,
    pad_fraction: float = 0.03,
    fallback_fraction: float = 0.8,
) -> Image.Image:
    """Crop *pil* to the bounding box of its non-white pixels.

    A pixel is treated as background ("white-ish") when all three RGB channels
    are strictly greater than ``white_threshold``; every other pixel is
    foreground. The foreground bounding box is padded on each side by
    ``pad_fraction`` of the image height/width and clamped to the image.

    If fewer than ``min_foreground_pixels`` foreground pixels exist (or none at
    all), a centred crop covering ``fallback_fraction`` of each dimension is
    returned instead.

    Args:
        pil: Input image; converted to RGB before analysis.
        white_threshold: Per-channel value above which a pixel counts as white.
        min_foreground_pixels: Minimum foreground pixel count needed to trust
            the bounding box.
        pad_fraction: Padding added around the bounding box, as a fraction of
            the image size.
        fallback_fraction: Size of the centred fallback crop, as a fraction of
            the image size.

    Returns:
        A new RGB image (never zero-sized).
    """
    img = pil.convert("RGB")
    arr = np.asarray(img)
    fg = ~(arr > white_threshold).all(axis=-1)
    n_fg = int(fg.sum())

    if n_fg == 0 or n_fg < min_foreground_pixels:
        w, h = img.size
        # max(1, ...) keeps the crop non-empty for 1-pixel-wide/high images.
        cw = max(1, int(w * fallback_fraction))
        ch = max(1, int(h * fallback_fraction))
        left, top = (w - cw) // 2, (h - ch) // 2
        return img.crop((left, top, left + cw, top + ch))

    # Project the mask onto each axis instead of materialising every
    # foreground coordinate with np.where.
    rows = np.flatnonzero(fg.any(axis=1))
    cols = np.flatnonzero(fg.any(axis=0))
    y0, y1 = int(rows[0]), int(rows[-1])
    x0, x1 = int(cols[0]), int(cols[-1])

    py, px = int(pad_fraction * img.height), int(pad_fraction * img.width)
    y0 = max(0, y0 - py)
    y1 = min(img.height - 1, y1 + py)
    x0 = max(0, x0 - px)
    x1 = min(img.width - 1, x1 + px)
    # PIL crop boxes exclude the right/bottom edge, hence the +1.
    return img.crop((x0, y0, x1 + 1, y1 + 1))
# ==============================================
# 🌐 App FastAPI
# ==============================================
app = FastAPI(title="TrashTrack Turbo — ESP32 Compatible")


@app.get("/")
def root():
    """Info endpoint: report the model id, upload mode and available class keys."""
    return dict(
        ok=True,
        model=MODEL_ID,
        mode="multipart/files[]",
        classes=[*labels],
    )
# The prompts in `texts` never change, so their embeddings are encoded once and
# reused across requests. Holds {"emb": tensor} or {"emb": None}; None means the
# model has no standalone get_text_features and text is encoded jointly per image.
_TEXT_EMB_CACHE = {}


def _cached_text_embeddings():
    """Return L2-normalised embeddings of ``texts``, computing them on first use.

    Returns None when ``model`` has no ``get_text_features`` method.
    """
    if "emb" not in _TEXT_EMB_CACHE:
        emb = None
        if hasattr(model, "get_text_features"):
            # SigLIP checkpoints were trained with fixed-length padding;
            # padding=True (pad to longest) yields degraded text embeddings.
            text_inputs = processor(
                text=texts, return_tensors="pt", padding="max_length"
            ).to(device)
            with torch.inference_mode():
                emb = torch.nn.functional.normalize(
                    model.get_text_features(**text_inputs), dim=-1
                )
        _TEXT_EMB_CACHE["emb"] = emb
    return _TEXT_EMB_CACHE["emb"]


def _class_scores(image):
    """Score one RGB image against every class in ``labels``.

    A softmax is taken over all prompts in ``texts``. Each class score is the
    probability mass of that class's prompts, so scores sum to 1 across classes.
    """
    txt_emb = _cached_text_embeddings()
    image_inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.inference_mode():
        img_emb = None
        if hasattr(model, "get_image_features"):
            img_emb = model.get_image_features(**image_inputs)
        if txt_emb is None or img_emb is None:
            # Fallback: joint forward pass, reading the embeddings off the output.
            joint = processor(
                text=texts, images=image, return_tensors="pt", padding="max_length"
            ).to(device)
            out = model(**joint)
            if txt_emb is None:
                txt_emb = getattr(out, "text_embeds", getattr(out, "text_embeds_projected", None))
            if img_emb is None:
                img_emb = getattr(out, "image_embeds", getattr(out, "image_embeds_projected", None))
        img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
        txt_emb = torch.nn.functional.normalize(txt_emb, dim=-1)
        logits = (img_emb @ txt_emb.t()).squeeze(0).float()
        # Cosine similarities lie in [-1, 1]; without the model's learned
        # temperature the softmax is almost uniform (~1/len(texts) per prompt).
        logit_scale = getattr(model, "logit_scale", None)
        if logit_scale is not None:
            logits = logits * logit_scale.exp().float()
    probs = torch.softmax(logits.cpu(), dim=-1).tolist()

    scores, start = {}, 0
    for key, group in labels.items():
        # `texts` was flattened in `labels` order, so each class owns a
        # contiguous slice of `probs`.
        scores[key] = sum(probs[start:start + len(group)])
        start += len(group)
    return scores


@app.post("/predict")
async def predict(files_: list[UploadFile] = File(..., alias="files[]")):
    """Classify one or more uploaded photos and return a single consensus label.

    Expects multipart/form-data with one or more parts named ``files[]``. Each
    image is EXIF-rotated, converted to RGB, cropped to its non-white
    foreground and classified on its own. The returned label is the one whose
    per-image winning confidences have the highest mean.

    Returns:
        200 ``{"label", "conf", "latency_s"}`` on success;
        400 ``{"error"}`` if an upload is not a decodable image;
        500 ``{"error"}`` for any other failure.
    """
    # NOTE(review): inference runs synchronously inside this async handler and
    # blocks the event loop while it runs; consider offloading it to a thread.
    try:
        t0 = time.perf_counter()
        votes = {}
        for f in files_:
            data = await f.read()
            try:
                image = ImageOps.exif_transpose(Image.open(io.BytesIO(data))).convert("RGB")
            except OSError as e:
                # Covers PIL's UnidentifiedImageError and truncated files: a
                # client-side problem, so report 400 instead of 500.
                print(f"[ERRO] invalid image {f.filename!r}: {e}")
                return JSONResponse(
                    {"error": f"invalid image {f.filename!r}: {e}"}, status_code=400
                )
            scores = _class_scores(crop_foreground_ignore_white(image))
            best = max(scores, key=scores.get)
            votes.setdefault(best, []).append(round(float(scores[best]), 3))

        if not votes:
            return JSONResponse({"error": "no files received"}, status_code=400)

        final = max(votes, key=lambda k: sum(votes[k]) / len(votes[k]))
        conf = round(sum(votes[final]) / len(votes[final]), 3)
        latency = round(time.perf_counter() - t0, 2)
        print(f"[OK] {final} ({conf}) em {latency}s")
        return JSONResponse({"label": final, "conf": conf, "latency_s": latency})
    except Exception as e:
        print(f"[ERRO] {e}")
        return JSONResponse({"error": str(e)}, status_code=500)