|
|
import os |
|
|
import torch |
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel, Field |
|
|
from comet import load_from_checkpoint |
|
|
from huggingface_hub import snapshot_download, HfApi |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI( |
|
|
title="XCOMET-XXL API", |
|
|
version="2.0.0", |
|
|
description="API para avaliação de traduções usando Unbabel/XCOMET-XXL API " |
|
|
|
|
|
) |
|
|
|
|
|
MODEL_NAME = "Unbabel/XCOMET-XXL" |
|
|
HF_TOKEN = os.environ.get("HF_TOKEN") |
|
|
SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "nairut/xcomet-xxl") |
|
|
|
|
|
|
|
|
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model") |
|
|
MODEL_CKPT = os.path.join(MODEL_DIR, "checkpoints", "model.ckpt") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def ensure_model_persisted_once(): |
|
|
""" |
|
|
Faz o download do modelo COMETKiwi-DA-XXL para ./model (caso ainda não exista) |
|
|
e tenta commitar essa pasta no próprio Space, para persistência. |
|
|
""" |
|
|
if os.path.exists(MODEL_CKPT): |
|
|
print(f"✅ Modelo já existe em {MODEL_CKPT}. Pulando download.") |
|
|
return |
|
|
|
|
|
print("🔽 Baixando snapshot do modelo para ./model ...") |
|
|
snapshot_download( |
|
|
repo_id=MODEL_NAME, |
|
|
token=HF_TOKEN, |
|
|
local_dir=MODEL_DIR, |
|
|
local_dir_use_symlinks=False |
|
|
) |
|
|
|
|
|
assert os.path.exists(MODEL_CKPT), f"Checkpoint não encontrado: {MODEL_CKPT}" |
|
|
|
|
|
try: |
|
|
print("⬆️ Enviando pasta 'model/' para o repositório do Space ...") |
|
|
api = HfApi(token=HF_TOKEN) |
|
|
api.upload_folder( |
|
|
repo_id=SPACE_REPO_ID, |
|
|
repo_type="space", |
|
|
folder_path=MODEL_DIR, |
|
|
path_in_repo="model", |
|
|
commit_message="Persistência automática do modelo COMETKiwi-DA-XXL" |
|
|
) |
|
|
print("✅ Modelo persistido no Space.") |
|
|
except Exception as e: |
|
|
print(f"⚠️ Falha ao persistir modelo no Space: {e}") |
|
|
print(" (prosseguindo com o modelo local para esta sessão)") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if "model" in globals(): |
|
|
del model |
|
|
torch.cuda.empty_cache() |
|
|
print("🧹 Modelo anterior removido da memória.") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ensure_model_persisted_once() |
|
|
|
|
|
print(f"📂 Carregando modelo de {MODEL_CKPT} ...") |
|
|
model = load_from_checkpoint(MODEL_CKPT) |
|
|
print("✅ Modelo COMETKiwi-DA-XXL carregado com sucesso!") |
|
|
|
|
|
USE_GPU = 1 if torch.cuda.is_available() else 0 |
|
|
print(f"⚙️ GPU detectada: {'sim' if USE_GPU else 'não'}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TranslationPair(BaseModel): |
|
|
source: str = Field(alias="source", description="Texto original") |
|
|
target: str = Field(alias="target", description="Tradução humana") |
|
|
machine_translation: str = Field(alias="machine_translation", description="Tradução automática") |
|
|
|
|
|
class Config: |
|
|
allow_population_by_field_name = True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def prepare_data(pairs: list[TranslationPair]): |
|
|
""" |
|
|
Converte lista de TranslationPair no formato esperado pelo COMET: |
|
|
[{"src": ..., "mt": ..., "ref": ...}, ...] |
|
|
""" |
|
|
data = [] |
|
|
for p in pairs: |
|
|
src = str(p.source).strip() |
|
|
mt = str(p.machine_translation).strip() |
|
|
ref = str(p.target).strip() |
|
|
data.append({"src": src, "mt": mt, "ref": ref}) |
|
|
return data |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/") |
|
|
def root(): |
|
|
return { |
|
|
"message": "🚀 XCOMET-XXL API ativa e pronta para uso!", |
|
|
"gpu_enabled": torch.cuda.is_available(), |
|
|
"available_endpoints": ["/score", "/score_batch"] |
|
|
} |
|
|
|
|
|
|
|
|
@app.post("/score") |
|
|
def score_single(pair: TranslationPair): |
|
|
""" |
|
|
Avalia um único par de tradução (source → target) com COMET-XXL. |
|
|
""" |
|
|
try: |
|
|
data = [{ |
|
|
"src": str(pair.source), |
|
|
"mt": str(pair.target), |
|
|
"ref": str(pair.machine_translation) |
|
|
}] |
|
|
real_model = model.module if hasattr(model, "module") else model |
|
|
output = real_model.predict(data, batch_size=8, gpus=1) |
|
|
|
|
|
return { |
|
|
"system_score": getattr(output, "system_score", None), |
|
|
"segment_scores": getattr(output, "scores", None), |
|
|
"metadata": getattr(output, "metadata", None) |
|
|
} |
|
|
except Exception as e: |
|
|
print(f"❌ Erro em /score: {e}") |
|
|
return {"error": str(e)} |
|
|
|
|
|
|
|
|
@app.post("/score_batch") |
|
|
def score_batch(pairs: list[TranslationPair]): |
|
|
""" |
|
|
Avalia múltiplos pares de tradução em lote (batch). |
|
|
""" |
|
|
try: |
|
|
data = prepare_data(pairs) |
|
|
print(f"📊 Lote recebido: {len(data)} pares válidos") |
|
|
|
|
|
real_model = model.module if hasattr(model, "module") else model |
|
|
output = real_model.predict(data, batch_size=8, gpus=1) |
|
|
|
|
|
return { |
|
|
"system_score": getattr(output, "system_score", None), |
|
|
"segment_scores": getattr(output, "scores", None), |
|
|
"metadata": getattr(output, "metadata", None) |
|
|
} |
|
|
except Exception as e: |
|
|
print(f"❌ Erro no batch: {e}") |
|
|
return {"error": str(e)} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import uvicorn |
|
|
uvicorn.run(app, host="0.0.0.0", port=7860) |