comet-xxl / app.py
nairut's picture
Update app.py
5f557e8 verified
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from comet import load_from_checkpoint
from huggingface_hub import snapshot_download, HfApi
# ==========================================================
# 🚀 Configuração da API
# ==========================================================
app = FastAPI(
title="XCOMET-XXL API",
version="2.0.0",
description="API para avaliação de traduções usando Unbabel/XCOMET-XXL API "
)
MODEL_NAME = "Unbabel/XCOMET-XXL"
HF_TOKEN = os.environ.get("HF_TOKEN") # defina nas Secrets do Space
SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "nairut/xcomet-xxl")
# Diretório de cache local (dentro do Space ou ambiente local)
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model")
MODEL_CKPT = os.path.join(MODEL_DIR, "checkpoints", "model.ckpt")
# ==========================================================
# ⚙️ Função auxiliar: baixa e persiste o modelo
# ==========================================================
def ensure_model_persisted_once():
"""
Faz o download do modelo COMETKiwi-DA-XXL para ./model (caso ainda não exista)
e tenta commitar essa pasta no próprio Space, para persistência.
"""
if os.path.exists(MODEL_CKPT):
print(f"✅ Modelo já existe em {MODEL_CKPT}. Pulando download.")
return
print("🔽 Baixando snapshot do modelo para ./model ...")
snapshot_download(
repo_id=MODEL_NAME,
token=HF_TOKEN,
local_dir=MODEL_DIR,
local_dir_use_symlinks=False
)
assert os.path.exists(MODEL_CKPT), f"Checkpoint não encontrado: {MODEL_CKPT}"
try:
print("⬆️ Enviando pasta 'model/' para o repositório do Space ...")
api = HfApi(token=HF_TOKEN)
api.upload_folder(
repo_id=SPACE_REPO_ID,
repo_type="space",
folder_path=MODEL_DIR,
path_in_repo="model",
commit_message="Persistência automática do modelo COMETKiwi-DA-XXL"
)
print("✅ Modelo persistido no Space.")
except Exception as e:
print(f"⚠️ Falha ao persistir modelo no Space: {e}")
print(" (prosseguindo com o modelo local para esta sessão)")
# ==========================================================
# ♻️ Inicialização limpa
# ==========================================================
# Remove da memória qualquer modelo carregado anteriormente
if "model" in globals():
del model
torch.cuda.empty_cache()
print("🧹 Modelo anterior removido da memória.")
# ==========================================================
# 📦 Inicialização do modelo
# ==========================================================
ensure_model_persisted_once()
print(f"📂 Carregando modelo de {MODEL_CKPT} ...")
model = load_from_checkpoint(MODEL_CKPT)
print("✅ Modelo COMETKiwi-DA-XXL carregado com sucesso!")
USE_GPU = 1 if torch.cuda.is_available() else 0
print(f"⚙️ GPU detectada: {'sim' if USE_GPU else 'não'}")
# ==========================================================
# 🧠 Estrutura dos dados de entrada
# ==========================================================
class TranslationPair(BaseModel):
source: str = Field(alias="source", description="Texto original")
target: str = Field(alias="target", description="Tradução humana")
machine_translation: str = Field(alias="machine_translation", description="Tradução automática")
class Config:
allow_population_by_field_name = True
# ==========================================================
# 🔧 Função utilitária
# ==========================================================
def prepare_data(pairs: list[TranslationPair]):
"""
Converte lista de TranslationPair no formato esperado pelo COMET:
[{"src": ..., "mt": ..., "ref": ...}, ...]
"""
data = []
for p in pairs:
src = str(p.source).strip()
mt = str(p.machine_translation).strip()
ref = str(p.target).strip()
data.append({"src": src, "mt": mt, "ref": ref})
return data
# ==========================================================
# 🌐 Endpoints
# ==========================================================
@app.get("/")
def root():
return {
"message": "🚀 XCOMET-XXL API ativa e pronta para uso!",
"gpu_enabled": torch.cuda.is_available(),
"available_endpoints": ["/score", "/score_batch"]
}
@app.post("/score")
def score_single(pair: TranslationPair):
"""
Avalia um único par de tradução (source → target) com COMET-XXL.
"""
try:
data = [{
"src": str(pair.source),
"mt": str(pair.target),
"ref": str(pair.machine_translation)
}]
real_model = model.module if hasattr(model, "module") else model
output = real_model.predict(data, batch_size=8, gpus=1)
return {
"system_score": getattr(output, "system_score", None),
"segment_scores": getattr(output, "scores", None),
"metadata": getattr(output, "metadata", None)
}
except Exception as e:
print(f"❌ Erro em /score: {e}")
return {"error": str(e)}
@app.post("/score_batch")
def score_batch(pairs: list[TranslationPair]):
"""
Avalia múltiplos pares de tradução em lote (batch).
"""
try:
data = prepare_data(pairs)
print(f"📊 Lote recebido: {len(data)} pares válidos")
real_model = model.module if hasattr(model, "module") else model
output = real_model.predict(data, batch_size=8, gpus=1)
return {
"system_score": getattr(output, "system_score", None),
"segment_scores": getattr(output, "scores", None),
"metadata": getattr(output, "metadata", None)
}
except Exception as e:
print(f"❌ Erro no batch: {e}")
return {"error": str(e)}
# ==========================================================
# ▶️ Execução local (para debug)
# ==========================================================
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)