File size: 6,227 Bytes
8c44aec
a7d0488
c4fb99d
9846481
8c44aec
95198c9
c4fb99d
2a8af70
5f557e8
9846481
8c44aec
4c33821
5f557e8
 
 
8c44aec
c4fb99d
4c33821
95198c9
5f557e8
95198c9
5f557e8
 
95198c9
 
9846481
 
5f557e8
9846481
95198c9
 
5f557e8
9846481
95198c9
 
 
 
 
5f557e8
95198c9
 
 
9846481
 
95198c9
 
9846481
 
95198c9
9846481
95198c9
 
 
 
 
9846481
5f557e8
95198c9
9846481
95198c9
9846481
 
4fe707c
9846481
 
5f557e8
 
 
 
 
 
 
 
 
 
 
9846481
95198c9
 
9846481
c9679e9
5f557e8
a7d0488
5f557e8
 
9846481
c9679e9
9846481
 
 
4fe707c
9846481
76d51af
 
9846481
 
 
 
c4fb99d
9846481
 
 
 
 
 
 
 
 
ac695f8
9846481
0dcce63
 
76d51af
9846481
ac695f8
9846481
 
0dcce63
9846481
95198c9
 
 
76d51af
 
9846481
95198c9
1e45caa
9846481
 
1647f01
9846481
76d51af
9846481
 
 
 
 
76d51af
9846481
76d51af
 
 
9846481
 
 
 
 
 
 
 
1647f01
9846481
0dcce63
 
 
 
 
95198c9
0dcce63
 
9846481
76d51af
 
0dcce63
95198c9
 
9846481
 
95198c9
 
0dcce63
95198c9
 
9846481
 
 
 
1e45caa
 
5f557e8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import os
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from comet import load_from_checkpoint
from huggingface_hub import snapshot_download, HfApi

# ==========================================================
# 🚀 Configuração da API 
# ==========================================================
app = FastAPI(
    title="XCOMET-XXL API",
    version="2.0.0",
    description="API para avaliação de traduções usando Unbabel/XCOMET-XXL API "
               
)

MODEL_NAME = "Unbabel/XCOMET-XXL"
HF_TOKEN = os.environ.get("HF_TOKEN")  # defina nas Secrets do Space
SPACE_REPO_ID = os.environ.get("SPACE_REPO_ID", "nairut/xcomet-xxl")

# Diretório de cache local (dentro do Space ou ambiente local)
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model")
MODEL_CKPT = os.path.join(MODEL_DIR, "checkpoints", "model.ckpt")


# ==========================================================
# ⚙️ Função auxiliar: baixa e persiste o modelo
# ==========================================================
def ensure_model_persisted_once():
    """
    Faz o download do modelo COMETKiwi-DA-XXL para ./model (caso ainda não exista)
    e tenta commitar essa pasta no próprio Space, para persistência.
    """
    if os.path.exists(MODEL_CKPT):
        print(f"✅ Modelo já existe em {MODEL_CKPT}. Pulando download.")
        return

    print("🔽 Baixando snapshot do modelo para ./model ...")
    snapshot_download(
        repo_id=MODEL_NAME,
        token=HF_TOKEN,
        local_dir=MODEL_DIR,
        local_dir_use_symlinks=False
    )

    assert os.path.exists(MODEL_CKPT), f"Checkpoint não encontrado: {MODEL_CKPT}"

    try:
        print("⬆️  Enviando pasta 'model/' para o repositório do Space ...")
        api = HfApi(token=HF_TOKEN)
        api.upload_folder(
            repo_id=SPACE_REPO_ID,
            repo_type="space",
            folder_path=MODEL_DIR,
            path_in_repo="model",
            commit_message="Persistência automática do modelo COMETKiwi-DA-XXL"
        )
        print("✅ Modelo persistido no Space.")
    except Exception as e:
        print(f"⚠️ Falha ao persistir modelo no Space: {e}")
        print("   (prosseguindo com o modelo local para esta sessão)")


# ==========================================================
# ♻️  Inicialização limpa
# ==========================================================
# Remove da memória qualquer modelo carregado anteriormente
if "model" in globals():
    del model
    torch.cuda.empty_cache()
    print("🧹 Modelo anterior removido da memória.")


# ==========================================================
# 📦 Inicialização do modelo
# ==========================================================
ensure_model_persisted_once()

print(f"📂 Carregando modelo de {MODEL_CKPT} ...")
model = load_from_checkpoint(MODEL_CKPT)
print("✅ Modelo COMETKiwi-DA-XXL carregado com sucesso!")

USE_GPU = 1 if torch.cuda.is_available() else 0
print(f"⚙️ GPU detectada: {'sim' if USE_GPU else 'não'}")


# ==========================================================
# 🧠 Estrutura dos dados de entrada
# ==========================================================
class TranslationPair(BaseModel):
    source: str = Field(alias="source", description="Texto original")
    target: str = Field(alias="target", description="Tradução humana")
    machine_translation: str = Field(alias="machine_translation", description="Tradução automática")

    class Config:
        allow_population_by_field_name = True


# ==========================================================
# 🔧 Função utilitária
# ==========================================================
def prepare_data(pairs: list[TranslationPair]):
    """
    Converte lista de TranslationPair no formato esperado pelo COMET:
    [{"src": ..., "mt": ..., "ref": ...}, ...]
    """
    data = []
    for p in pairs:
        src = str(p.source).strip()
        mt = str(p.machine_translation).strip()
        ref = str(p.target).strip()
        data.append({"src": src, "mt": mt, "ref": ref})
    return data


# ==========================================================
# 🌐 Endpoints 
# ==========================================================
@app.get("/")
def root():
    return {
        "message": "🚀 XCOMET-XXL API ativa e pronta para uso!",
        "gpu_enabled": torch.cuda.is_available(),
        "available_endpoints": ["/score", "/score_batch"]
    }


@app.post("/score")
def score_single(pair: TranslationPair):
    """
    Avalia um único par de tradução (source → target) com COMET-XXL.
    """
    try:
        data = [{
            "src": str(pair.source),
            "mt": str(pair.target),
            "ref": str(pair.machine_translation)
        }]
        real_model = model.module if hasattr(model, "module") else model
        output = real_model.predict(data, batch_size=8, gpus=1)

        return {
            "system_score": getattr(output, "system_score", None),
            "segment_scores": getattr(output, "scores", None),
            "metadata": getattr(output, "metadata", None)
        }
    except Exception as e:
        print(f"❌ Erro em /score: {e}")
        return {"error": str(e)}


@app.post("/score_batch")
def score_batch(pairs: list[TranslationPair]):
    """
    Avalia múltiplos pares de tradução em lote (batch).
    """
    try:
        data = prepare_data(pairs)
        print(f"📊 Lote recebido: {len(data)} pares válidos")

        real_model = model.module if hasattr(model, "module") else model
        output = real_model.predict(data, batch_size=8, gpus=1)

        return {
            "system_score": getattr(output, "system_score", None),
            "segment_scores": getattr(output, "scores", None),
            "metadata": getattr(output, "metadata", None)
        }
    except Exception as e:
        print(f"❌ Erro no batch: {e}")
        return {"error": str(e)}


# ==========================================================
# ▶️ Execução local (para debug)
# ==========================================================
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)