|
|
import base64
import contextlib
import hmac
import io
import json
import logging
import os
import tempfile
import threading

import torch
from PIL import Image
import open_clip
from huggingface_hub import hf_hub_download, create_commit, CommitOperationAdd
from safetensors.torch import save_file, load_file

from reparam import reparameterize_model
|
|
|
|
|
# Shared secret that admin operations (upsert/remove/reload labels) must present.
ADMIN_TOKEN = os.environ.get("ADMIN_TOKEN", "")
# Hugging Face dataset repo id holding versioned label snapshots;
# an empty string disables snapshot loading/publishing.
HF_LABEL_REPO = os.environ.get("HF_LABEL_REPO", "")
# Token used when committing new snapshots to HF_LABEL_REPO.
HF_WRITE_TOKEN = os.environ.get("HF_WRITE_TOKEN", "")
# Token used when downloading snapshots; defaults to the write token.
HF_READ_TOKEN = os.environ.get("HF_READ_TOKEN", HF_WRITE_TOKEN)
|
|
|
|
|
|
|
|
def _fingerprint(device: str, dtype: torch.dtype) -> dict:
    """Describe the model and runtime that produced the label embeddings.

    The returned dict is written into each snapshot's ``meta.json`` and
    compared against the running handler's fingerprint when a snapshot is
    loaded.

    Args:
        device: Device the model runs on ("cpu", "cuda", "cuda:0", ...).
        dtype: Dtype the model runs inference in.

    Returns:
        A JSON-serializable dict of model identity and runtime details.
    """
    # Derive the CUDA entry from the device actually in use; previously the
    # ``device`` argument was ignored and any visible GPU was reported.
    on_cuda = str(device).startswith("cuda")
    return {
        "model_id": "MobileCLIP-B",
        "pretrained": "datacompdr",
        "open_clip": getattr(open_clip, "__version__", "unknown"),
        "torch": torch.__version__,
        "cuda": torch.version.cuda if on_cuda else None,
        "dtype_runtime": str(dtype),
        "text_norm": "L2",
        "logit_scale": 100.0,
    }
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Zero-shot image classifier with an editable label set (MobileCLIP-B).

    Each label is an ``(id, name)`` pair plus the L2-normalized text embedding
    of its prompt. Embeddings are kept on the CPU in float32
    (``text_features_cpu``) and mirrored to the inference device/dtype
    (``text_features``). Admin ops add, remove or reload labels and, when
    ``HF_LABEL_REPO`` is set, publish/load versioned snapshots in that
    Hugging Face dataset repo.

    Thread-safety: ``self._lock`` guards the label state. Mutators always
    rebind ``class_ids``, ``class_names`` and ``text_features*`` to fresh
    objects instead of mutating them in place. A reader that copies the three
    references under the lock therefore gets a mutually consistent view.
    """

    # Fingerprint fields that identify the embedding space. Runtime details
    # (torch/CUDA versions, inference dtype) do not change which model wrote
    # the stored float32 embeddings, so they are not compared on load.
    # Comparing them made snapshots unloadable after a library upgrade or on
    # a CPU replica, and __init__ then silently fell back to items.json.
    _COMPAT_FINGERPRINT_KEYS = ("model_id", "pretrained", "text_norm")

    def __init__(self, path: str = ""):
        """Load the model and the initial label set.

        Labels come from the newest hub snapshot when ``HF_LABEL_REPO`` is set
        and that load succeeds. Otherwise they come from ``items.json`` (a list
        of objects with ``id``, ``name`` and ``prompt``), whose prompts are
        encoded with the text tower; that set gets ``labels_version = 1``.

        Args:
            path: Directory containing ``items.json``; empty means the CWD.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained="datacompdr"
        )
        model.eval()
        # Inference-time reparameterization (helper from the local `reparam` module).
        model = reparameterize_model(model)
        model.to(self.device)
        if self.device == "cuda":
            model = model.to(torch.float16)
        self.model = model
        self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
        self.fingerprint = _fingerprint(self.device, self.dtype)
        self._lock = threading.Lock()

        loaded = False
        if HF_LABEL_REPO:
            try:
                loaded = self._load_snapshot_from_hub_latest()
            except Exception:
                # Best effort: fall back to the bundled items.json below,
                # but leave a trace instead of failing silently.
                logging.getLogger(__name__).warning(
                    "Could not load label snapshot from %s; using items.json",
                    HF_LABEL_REPO,
                    exc_info=True,
                )

        if not loaded:
            items_path = os.path.join(path, "items.json") if path else "items.json"
            with open(items_path, "r", encoding="utf-8") as f:
                items = json.load(f)
            feats = self._encode_text([it["prompt"] for it in items])
            with self._lock:
                self._install_labels(
                    feats.detach().cpu().to(torch.float32),
                    [int(it["id"]) for it in items],
                    [it["name"] for it in items],
                )
                self.labels_version = 1

    def __call__(self, data):
        """Dispatch an admin operation or classify an image.

        ``data`` is either the payload itself or ``{"inputs": payload}``.
        When ``op`` is ``upsert_labels``, ``remove_labels`` or
        ``reload_labels``, the payload is an admin request and must carry a
        ``token`` equal to ``ADMIN_TOKEN``. Any other payload is a
        classification request (see ``_classify``).

        Returns:
            A status/error dict for admin ops, or the classification result.
        """
        payload = data.get("inputs", data)
        op = payload.get("op")
        if op == "upsert_labels":
            return self._handle_upsert(payload)
        if op == "reload_labels":
            return self._handle_reload(payload)
        if op == "remove_labels":
            return self._handle_remove(payload)
        return self._classify(payload)

    @staticmethod
    def _authorized(payload) -> bool:
        """Return True if ``payload["token"]`` matches ``ADMIN_TOKEN``.

        An unset/empty ``ADMIN_TOKEN`` disables admin ops entirely. Without
        this, a client could authenticate by sending ``"token": ""``. The
        comparison is constant-time.
        """
        token = payload.get("token")
        if not ADMIN_TOKEN or not isinstance(token, str):
            return False
        return hmac.compare_digest(token.encode("utf-8"), ADMIN_TOKEN.encode("utf-8"))

    def _publish_next_version(self):
        """Publish the current label set as ``labels_version + 1``.

        ``labels_version`` only advances when the hub commit succeeds. On
        failure the in-memory change is kept but remains unpublished.

        NOTE(review): concurrent admin requests can compute the same next
        version; serialize admin ops upstream if that matters.
        """
        new_ver = int(getattr(self, "labels_version", 1)) + 1
        self._persist_snapshot_to_hub(new_ver)
        self.labels_version = new_ver

    def _handle_upsert(self, payload):
        """Admin op: add new labels from ``payload["items"]`` and publish."""
        if not self._authorized(payload):
            return {"error": "unauthorized"}
        items = payload.get("items") or []
        try:
            added = self._upsert_items(items)
        except ValueError as e:
            return {"error": "invalid_items", "detail": str(e)}
        if added > 0:
            try:
                self._publish_next_version()
            except Exception as e:
                return {"status": "error", "added": added, "detail": str(e)}
        return {"status": "ok", "added": added, "labels_version": getattr(self, "labels_version", 1)}

    def _handle_reload(self, payload):
        """Admin op: replace the label set with hub snapshot ``payload["version"]``."""
        if not self._authorized(payload):
            return {"error": "unauthorized"}
        try:
            ver = int(payload.get("version"))
        except (TypeError, ValueError, OverflowError):
            return {"error": "invalid_version"}
        try:
            ok = self._load_snapshot_from_hub_version(ver)
        except Exception as e:
            # Download failures and fingerprint mismatches used to escape
            # as unhandled exceptions.
            return {"status": "error", "detail": str(e), "labels_version": getattr(self, "labels_version", 0)}
        return {"status": "ok" if ok else "nochange", "labels_version": getattr(self, "labels_version", 0)}

    def _handle_remove(self, payload):
        """Admin op: drop the labels whose ids are in ``payload["ids"]`` and publish."""
        if not self._authorized(payload):
            return {"error": "unauthorized"}
        raw_ids = payload.get("ids")
        if raw_ids is None:
            raw_ids = []
        elif isinstance(raw_ids, (str, int)):
            # A bare string would otherwise be split into characters by set().
            raw_ids = [raw_ids]
        try:
            ids_to_remove = set(raw_ids)
        except TypeError:
            return {"error": "invalid_ids"}
        if not ids_to_remove:
            return {"error": "no_ids_provided"}
        try:
            removed = self._remove_items(ids_to_remove)
        except (TypeError, ValueError, OverflowError):
            return {"error": "invalid_ids"}
        if removed > 0:
            try:
                self._publish_next_version()
            except Exception as e:
                return {"status": "error", "removed": removed, "detail": str(e)}
        return {"status": "ok", "removed": removed, "labels_version": getattr(self, "labels_version", 1)}

    def _classify(self, payload):
        """Score a base64-encoded ``payload["image"]`` against every label.

        Optional keys: ``top_k`` (number of results, default all) and
        ``min_labels_version``. If the latter is an int above the loaded
        version, a best-effort reload of that snapshot is attempted first.

        Returns:
            A list of ``{"id", "label", "score"}`` dicts sorted by descending
            softmax score, or ``{"error": "missing_image"}``.
        """
        min_ver = payload.get("min_labels_version")
        if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
            try:
                self._load_snapshot_from_hub_version(min_ver)
            except Exception:
                logging.getLogger(__name__).warning(
                    "Could not load label snapshot v%s", min_ver, exc_info=True
                )

        img_b64 = payload.get("image")
        if img_b64 is None:
            return {"error": "missing_image"}
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)
        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)

        # Copy the references together so ids, names and embeddings agree
        # even if an admin op swaps the label set mid-request.
        with self._lock:
            class_ids = self.class_ids
            class_names = self.class_names
            text_features = self.text_features

        with torch.no_grad():
            img_feat = self.model.encode_image(img_tensor)
            img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True)
            # 100.0 matches the fingerprint's "logit_scale".
            probs = (100.0 * img_feat @ text_features.T).softmax(dim=-1)[0]
        scores = probs.detach().cpu().tolist()

        top_k = payload.get("top_k")
        top_k = len(class_ids) if top_k is None else max(0, int(top_k))
        ranked = sorted(
            (
                {"id": i, "label": name, "score": float(p)}
                for i, name, p in zip(class_ids, class_names, scores)
            ),
            key=lambda r: r["score"],
            reverse=True,
        )
        return ranked[:top_k]

    def _encode_text(self, prompts):
        """Tokenize and encode ``prompts``; return L2-normalized embeddings.

        The result stays on ``self.device`` in the model's dtype.
        """
        with torch.no_grad():
            toks = self.tokenizer(prompts).to(self.device)
            feats = self.model.encode_text(toks)
            feats = feats / feats.norm(dim=-1, keepdim=True)
        return feats

    def _to_device(self):
        """Rebuild ``text_features`` from ``text_features_cpu``.

        The copy goes to the inference device: fp16 on CUDA, fp32 otherwise.
        """
        self.text_features = self.text_features_cpu.to(
            self.device, dtype=(torch.float16 if self.device == "cuda" else torch.float32)
        )

    def _install_labels(self, feats_cpu, class_ids, class_names):
        """Replace the whole label state with fresh objects.

        The caller must hold ``self._lock``, except during ``__init__``.

        Raises:
            ValueError: If the embedding rows, ids and names disagree in count.
        """
        if not (feats_cpu.shape[0] == len(class_ids) == len(class_names)):
            raise ValueError(
                f"label state mismatch: {feats_cpu.shape[0]} embeddings, "
                f"{len(class_ids)} ids, {len(class_names)} names"
            )
        self.text_features_cpu = feats_cpu.to(torch.float32).contiguous()
        self.class_ids = list(class_ids)
        self.class_names = list(class_names)
        self._to_device()

    def _upsert_items(self, new_items):
        """Append labels whose id and (case-insensitive) name are both new.

        Args:
            new_items: Iterable of dicts with ``id`` (int-convertible),
                ``name`` (str) and ``prompt`` (str).

        Returns:
            Number of labels added. An item is skipped when its id or name is
            already present, including earlier items in the same request.

        Raises:
            ValueError: If an item is malformed. No labels are added then.
        """
        if not new_items:
            return 0
        # Held across encoding so concurrent upserts cannot both add the
        # same id/name.
        with self._lock:
            ids = list(getattr(self, "class_ids", []))
            names = list(getattr(self, "class_names", []))
            seen_ids = set(ids)
            seen_names = {n.lower() for n in names}

            batch = []
            for it in new_items:
                if not isinstance(it, dict):
                    raise ValueError(f"label item must be an object, got {type(it).__name__}")
                try:
                    item_id = int(it.get("id"))
                except (TypeError, ValueError, OverflowError) as e:
                    raise ValueError(f"invalid label id: {it.get('id')!r}") from e
                if item_id in seen_ids:
                    continue
                item_name = it.get("name")
                if not isinstance(item_name, str):
                    raise ValueError(f"label {item_id} needs a string 'name'")
                if item_name.lower() in seen_names:
                    continue
                prompt = it.get("prompt")
                if not isinstance(prompt, str):
                    raise ValueError(f"label {item_id} needs a string 'prompt'")
                # Track within-request additions so duplicates in one call
                # are not added twice.
                seen_ids.add(item_id)
                seen_names.add(item_name.lower())
                batch.append((item_id, item_name, prompt))

            if not batch:
                return 0

            new_feats = self._encode_text([p for _, _, p in batch]).detach().cpu().to(torch.float32)
            current = getattr(self, "text_features_cpu", None)
            feats = new_feats if current is None else torch.cat([current.to(torch.float32), new_feats], dim=0)
            self._install_labels(
                feats,
                ids + [item_id for item_id, _, _ in batch],
                names + [name for _, name, _ in batch],
            )
            return len(batch)

    def _remove_items(self, ids_to_remove):
        """Drop every label whose id is in ``ids_to_remove``.

        Ids are converted with ``int()``, so ``"3"`` and ``3`` match the same
        label.

        Returns:
            Number of labels removed.
        """
        if not ids_to_remove or not hasattr(self, "class_ids"):
            return 0
        drop = {int(x) for x in ids_to_remove}
        with self._lock:
            keep = [i for i, class_id in enumerate(self.class_ids) if class_id not in drop]
            removed = len(self.class_ids) - len(keep)
            if removed == 0:
                return 0
            # index_select with an empty index yields a (0, dim) tensor,
            # covering the "everything removed" case as well.
            feats = self.text_features_cpu.index_select(0, torch.tensor(keep, dtype=torch.long))
            self._install_labels(
                feats,
                [self.class_ids[i] for i in keep],
                [self.class_names[i] for i in keep],
            )
            return removed

    def _persist_snapshot_to_hub(self, version: int):
        """Commit the current label set to the hub as ``snapshots/v{version}``.

        The commit writes ``embeddings.safetensors`` and ``meta.json``. It also
        points ``snapshots/latest.json`` at this version, all in one commit.

        Raises:
            RuntimeError: If ``HF_LABEL_REPO`` or ``HF_WRITE_TOKEN`` is unset.
        """
        if not HF_LABEL_REPO:
            raise RuntimeError("HF_LABEL_REPO not set")
        if not HF_WRITE_TOKEN:
            raise RuntimeError("HF_WRITE_TOKEN not set for publishing")
        version = int(version)

        # Grab one consistent view; mutators rebind rather than mutate.
        with self._lock:
            feats = self.text_features_cpu
            class_ids = self.class_ids
            class_names = self.class_names

        meta = {
            "items": [{"id": int(i), "name": n} for i, n in zip(class_ids, class_names)],
            "fingerprint": self.fingerprint,
            "dims": int(feats.shape[1]),
            "count": int(feats.shape[0]),
            "version": version,
        }
        latest_bytes = io.BytesIO(json.dumps({"version": version}).encode("utf-8"))

        # A private temp dir avoids concurrent publishes clobbering fixed /tmp paths.
        with tempfile.TemporaryDirectory() as tmp:
            emb_path = os.path.join(tmp, "embeddings.safetensors")
            meta_path = os.path.join(tmp, "meta.json")
            save_file({"embeddings": feats.to(torch.float32).contiguous()}, emb_path)
            with open(meta_path, "w", encoding="utf-8") as f:
                json.dump(meta, f)
            ops = [
                CommitOperationAdd(
                    path_in_repo=f"snapshots/v{version}/embeddings.safetensors",
                    path_or_fileobj=emb_path,
                ),
                CommitOperationAdd(
                    path_in_repo=f"snapshots/v{version}/meta.json",
                    path_or_fileobj=meta_path,
                ),
                CommitOperationAdd(
                    path_in_repo="snapshots/latest.json",
                    path_or_fileobj=latest_bytes,
                ),
            ]
            create_commit(
                repo_id=HF_LABEL_REPO,
                repo_type="dataset",
                operations=ops,
                token=HF_WRITE_TOKEN,
                commit_message=f"labels v{version}",
            )

    def _fingerprint_compatible(self, other) -> bool:
        """Return True if snapshot fingerprint ``other`` matches this model's identity keys."""
        if not isinstance(other, dict):
            return False
        return all(other.get(k) == self.fingerprint.get(k) for k in self._COMPAT_FINGERPRINT_KEYS)

    def _load_snapshot_from_hub_version(self, version: int) -> bool:
        """Replace the label set with hub snapshot ``version``.

        Returns:
            False if ``HF_LABEL_REPO`` is unset, True after a successful load.

        Raises:
            RuntimeError: On a model fingerprint mismatch.
            ValueError: If the snapshot's items and embeddings disagree in count.
            Exception: Whatever ``hf_hub_download`` raises on download failure.
        """
        if not HF_LABEL_REPO:
            return False
        version = int(version)
        prefix = f"snapshots/v{version}"

        # Download and validate without holding the lock so inference keeps running.
        emb_p = hf_hub_download(
            HF_LABEL_REPO,
            f"{prefix}/embeddings.safetensors",
            repo_type="dataset",
            token=HF_READ_TOKEN,
            force_download=True,
        )
        meta_p = hf_hub_download(
            HF_LABEL_REPO,
            f"{prefix}/meta.json",
            repo_type="dataset",
            token=HF_READ_TOKEN,
            force_download=True,
        )
        with open(meta_p, "r", encoding="utf-8") as f:
            meta = json.load(f)
        if not self._fingerprint_compatible(meta.get("fingerprint")):
            raise RuntimeError("Embedding/model fingerprint mismatch")
        feats = load_file(emb_p)["embeddings"]
        items = meta.get("items", [])

        with self._lock:
            self._install_labels(
                feats,
                [int(x["id"]) for x in items],
                [x["name"] for x in items],
            )
            self.labels_version = int(meta.get("version", version))
        return True

    def _load_snapshot_from_hub_latest(self) -> bool:
        """Load the snapshot that ``snapshots/latest.json`` points to.

        Returns:
            False if the repo is unset, the pointer cannot be downloaded, or
            it holds no positive version. Otherwise the result of
            ``_load_snapshot_from_hub_version``.
        """
        if not HF_LABEL_REPO:
            return False
        try:
            latest_p = hf_hub_download(
                HF_LABEL_REPO,
                "snapshots/latest.json",
                repo_type="dataset",
                token=HF_READ_TOKEN,
            )
        except Exception:
            # E.g. no snapshot published yet, or the hub is unreachable.
            return False
        with open(latest_p, "r", encoding="utf-8") as f:
            latest = json.load(f)
        ver = int(latest.get("version", 0))
        if ver <= 0:
            return False
        return self._load_snapshot_from_hub_version(ver)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|