File size: 6,227 Bytes
8c44aec a7d0488 c4fb99d 9846481 8c44aec 95198c9 c4fb99d 2a8af70 5f557e8 9846481 8c44aec 4c33821 5f557e8 8c44aec c4fb99d 4c33821 95198c9 5f557e8 95198c9 5f557e8 95198c9 9846481 5f557e8 9846481 95198c9 5f557e8 9846481 95198c9 5f557e8 95198c9 9846481 95198c9 9846481 95198c9 9846481 95198c9 9846481 5f557e8 95198c9 9846481 95198c9 9846481 4fe707c 9846481 5f557e8 9846481 95198c9 9846481 c9679e9 5f557e8 a7d0488 5f557e8 9846481 c9679e9 9846481 4fe707c 9846481 76d51af 9846481 c4fb99d 9846481 ac695f8 9846481 0dcce63 76d51af 9846481 ac695f8 9846481 0dcce63 9846481 95198c9 76d51af 9846481 95198c9 1e45caa 9846481 1647f01 9846481 76d51af 9846481 76d51af 9846481 76d51af 9846481 1647f01 9846481 0dcce63 95198c9 0dcce63 9846481 76d51af 0dcce63 95198c9 9846481 95198c9 0dcce63 95198c9 9846481 1e45caa 5f557e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from comet import load_from_checkpoint
from huggingface_hub import snapshot_download, HfApi
# ==========================================================
# 🚀 Configuração da API
# ==========================================================
app = FastAPI(
title="XCOMET-XXL API",
version="2.0.0",
description="API para avaliação de traduções usando Unbabel/XCOMET-XXL API "
)
MODEL_NAME = "Unbabel/XCOMET-XXL"
HF_TOKEN = os.environ.get("HF_TOKEN") # defina nas Secrets do Space
SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "nairut/xcomet-xxl")
# Diretório de cache local (dentro do Space ou ambiente local)
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model")
MODEL_CKPT = os.path.join(MODEL_DIR, "checkpoints", "model.ckpt")
# ==========================================================
# ⚙️ Função auxiliar: baixa e persiste o modelo
# ==========================================================
def ensure_model_persisted_once():
"""
Faz o download do modelo COMETKiwi-DA-XXL para ./model (caso ainda não exista)
e tenta commitar essa pasta no próprio Space, para persistência.
"""
if os.path.exists(MODEL_CKPT):
print(f"✅ Modelo já existe em {MODEL_CKPT}. Pulando download.")
return
print("🔽 Baixando snapshot do modelo para ./model ...")
snapshot_download(
repo_id=MODEL_NAME,
token=HF_TOKEN,
local_dir=MODEL_DIR,
local_dir_use_symlinks=False
)
assert os.path.exists(MODEL_CKPT), f"Checkpoint não encontrado: {MODEL_CKPT}"
try:
print("⬆️ Enviando pasta 'model/' para o repositório do Space ...")
api = HfApi(token=HF_TOKEN)
api.upload_folder(
repo_id=SPACE_REPO_ID,
repo_type="space",
folder_path=MODEL_DIR,
path_in_repo="model",
commit_message="Persistência automática do modelo COMETKiwi-DA-XXL"
)
print("✅ Modelo persistido no Space.")
except Exception as e:
print(f"⚠️ Falha ao persistir modelo no Space: {e}")
print(" (prosseguindo com o modelo local para esta sessão)")
# ==========================================================
# ♻️ Inicialização limpa
# ==========================================================
# Remove da memória qualquer modelo carregado anteriormente
if "model" in globals():
del model
torch.cuda.empty_cache()
print("🧹 Modelo anterior removido da memória.")
# ==========================================================
# 📦 Inicialização do modelo
# ==========================================================
ensure_model_persisted_once()
print(f"📂 Carregando modelo de {MODEL_CKPT} ...")
model = load_from_checkpoint(MODEL_CKPT)
print("✅ Modelo COMETKiwi-DA-XXL carregado com sucesso!")
USE_GPU = 1 if torch.cuda.is_available() else 0
print(f"⚙️ GPU detectada: {'sim' if USE_GPU else 'não'}")
# ==========================================================
# 🧠 Estrutura dos dados de entrada
# ==========================================================
class TranslationPair(BaseModel):
source: str = Field(alias="source", description="Texto original")
target: str = Field(alias="target", description="Tradução humana")
machine_translation: str = Field(alias="machine_translation", description="Tradução automática")
class Config:
allow_population_by_field_name = True
# ==========================================================
# 🔧 Função utilitária
# ==========================================================
def prepare_data(pairs: list[TranslationPair]):
"""
Converte lista de TranslationPair no formato esperado pelo COMET:
[{"src": ..., "mt": ..., "ref": ...}, ...]
"""
data = []
for p in pairs:
src = str(p.source).strip()
mt = str(p.machine_translation).strip()
ref = str(p.target).strip()
data.append({"src": src, "mt": mt, "ref": ref})
return data
# ==========================================================
# 🌐 Endpoints
# ==========================================================
@app.get("/")
def root():
return {
"message": "🚀 XCOMET-XXL API ativa e pronta para uso!",
"gpu_enabled": torch.cuda.is_available(),
"available_endpoints": ["/score", "/score_batch"]
}
@app.post("/score")
def score_single(pair: TranslationPair):
"""
Avalia um único par de tradução (source → target) com COMET-XXL.
"""
try:
data = [{
"src": str(pair.source),
"mt": str(pair.target),
"ref": str(pair.machine_translation)
}]
real_model = model.module if hasattr(model, "module") else model
output = real_model.predict(data, batch_size=8, gpus=1)
return {
"system_score": getattr(output, "system_score", None),
"segment_scores": getattr(output, "scores", None),
"metadata": getattr(output, "metadata", None)
}
except Exception as e:
print(f"❌ Erro em /score: {e}")
return {"error": str(e)}
@app.post("/score_batch")
def score_batch(pairs: list[TranslationPair]):
"""
Avalia múltiplos pares de tradução em lote (batch).
"""
try:
data = prepare_data(pairs)
print(f"📊 Lote recebido: {len(data)} pares válidos")
real_model = model.module if hasattr(model, "module") else model
output = real_model.predict(data, batch_size=8, gpus=1)
return {
"system_score": getattr(output, "system_score", None),
"segment_scores": getattr(output, "scores", None),
"metadata": getattr(output, "metadata", None)
}
except Exception as e:
print(f"❌ Erro no batch: {e}")
return {"error": str(e)}
# ==========================================================
# ▶️ Execução local (para debug)
# ==========================================================
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860) |