# clipspace/handler.py
"""
MobileCLIP-B zero-shot image classifier (Hugging Face Inference Endpoint).

One container instance is created per replica; the `EndpointHandler` below is
instantiated exactly once at start-up. At request time (`__call__`) we receive
a base64-encoded image, run a single forward pass, and return class
probabilities.

Design choices
--------------
1. Model & transforms come from OpenCLIP, so preprocessing is identical to
   what the model was trained with (224 x 224 crop + mean/std normalisation).
2. Re-parameterisation for inference: MobileCLIP uses MobileOne blocks with
   extra convolution branches for training; `reparameterize_model` fuses them
   so inference is fast and deterministic.
3. Text embeddings are cached: class prompts (e.g. "a photo of a cat") are
   encoded once at start-up (and on label updates), so each request encodes
   only the image and performs a single matrix multiplication.
4. Mixed precision on GPU: if the container has CUDA, the model and inputs are
   cast to float16, which halves memory and roughly doubles throughput on most
   modern GPUs. On CPU we stay in float32 for numerical stability.
"""
import base64
import contextlib
import io
import json
import os
import threading

import torch
from PIL import Image
import open_clip
from huggingface_hub import CommitOperationAdd, create_commit, hf_hub_download
from safetensors.torch import load_file, save_file

from reparam import reparameterize_model  # local copy of Apple's MobileOne fusion helper

ADMIN_TOKEN = os.getenv("ADMIN_TOKEN", "")
HF_LABEL_REPO = os.getenv("HF_LABEL_REPO", "") # e.g. "org/mobileclip-labels"
HF_WRITE_TOKEN = os.getenv("HF_WRITE_TOKEN", "")
HF_READ_TOKEN = os.getenv("HF_READ_TOKEN", HF_WRITE_TOKEN)
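
# Configuration summary (illustrative values only):
#   ADMIN_TOKEN     shared secret required for admin ops (upsert/remove/reload)
#   HF_LABEL_REPO   dataset repo holding label snapshots, e.g. "org/mobileclip-labels"
#   HF_WRITE_TOKEN  HF token with write access, used when publishing snapshots
#   HF_READ_TOKEN   HF token for downloads; falls back to HF_WRITE_TOKEN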
def _fingerprint(device: str, dtype: torch.dtype) -> dict:
    """Describe the model/runtime configuration that produced the cached text
    embeddings. It is stored with every snapshot and compared on load, so
    embeddings from a different model, library version, or runtime dtype are
    rejected. Note that `dtype_runtime` differs between CPU (float32) and GPU
    (float16) replicas, so snapshots are not portable between the two.
    (`device` is accepted but currently unused.)
    """
    return {
"model_id": "MobileCLIP-B",
"pretrained": "datacompdr",
"open_clip": getattr(open_clip, "__version__", "unknown"),
"torch": torch.__version__,
"cuda": torch.version.cuda if torch.cuda.is_available() else None,
"dtype_runtime": str(dtype),
"text_norm": "L2",
"logit_scale": 100.0,
}

class EndpointHandler:
    """Hugging Face Inference Endpoint entry point. The toolkit instantiates
    this class once per replica and calls it for every HTTP request.

    Parameters
    ----------
    path : str, optional
        Root directory of the repository. HF mounts the code under
        `/repository`; used here to locate `items.json` when seeding labels.
    """

    def __init__(self, path: str = ""):
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        # 1) Load model + transforms from OpenCLIP so preprocessing matches training
        model, _, self.preprocess = open_clip.create_model_and_transforms(
            "MobileCLIP-B", pretrained="datacompdr"
        )
        model.eval()                          # disable dropout / BN updates
        model = reparameterize_model(model)   # fuse MobileOne training branches
        model.to(self.device)
        if self.device == "cuda":
            model = model.to(torch.float16)   # FP16 for throughput
self.model = model
self.tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
self.fingerprint = _fingerprint(self.device, self.dtype)
self._lock = threading.Lock()
# 2) Try to load snapshot from Hub; else seed from items.json
loaded = False
if HF_LABEL_REPO:
with contextlib.suppress(Exception):
loaded = self._load_snapshot_from_hub_latest()
if not loaded:
items_path = "items.json" if not path else f"{path}/items.json"
with open(items_path, "r", encoding="utf-8") as f:
items = json.load(f)
prompts = [it["prompt"] for it in items]
self.class_ids = [int(it["id"]) for it in items]
self.class_names = [it["name"] for it in items]
            feats = self._encode_text(prompts)
            self.text_features_cpu = feats.detach().cpu().to(torch.float32).contiguous()
self._to_device()
self.labels_version = 1
    def __call__(self, data):
        """Route a request. When `op` is set, handle an admin operation
        (`upsert_labels`, `reload_labels`, `remove_labels`); otherwise run
        classification. Accepts either the raw payload {"image": "<base64>"}
        or the Hugging Face convention {"inputs": {...}}. Classification
        returns a sorted list of {"id": int, "label": str, "score": float};
        scores are softmax probabilities over the current class list and sum
        to 1.0.
        """
        payload = data.get("inputs", data)
# Admin op: upsert_labels
op = payload.get("op")
if op == "upsert_labels":
if payload.get("token") != ADMIN_TOKEN:
return {"error": "unauthorized"}
items = payload.get("items", []) or []
added = self._upsert_items(items)
if added > 0:
new_ver = int(getattr(self, "labels_version", 1)) + 1
try:
self._persist_snapshot_to_hub(new_ver)
self.labels_version = new_ver
except Exception as e:
return {"status": "error", "added": added, "detail": str(e)}
return {"status": "ok", "added": added, "labels_version": getattr(self, "labels_version", 1)}
# Admin op: reload_labels
if op == "reload_labels":
if payload.get("token") != ADMIN_TOKEN:
return {"error": "unauthorized"}
try:
ver = int(payload.get("version"))
except Exception:
return {"error": "invalid_version"}
            try:
                ok = self._load_snapshot_from_hub_version(ver)
            except Exception as e:
                return {"status": "error", "detail": str(e)}
            return {"status": "ok" if ok else "nochange", "labels_version": getattr(self, "labels_version", 0)}
# Admin op: remove_labels
if op == "remove_labels":
if payload.get("token") != ADMIN_TOKEN:
return {"error": "unauthorized"}
ids_to_remove = set(payload.get("ids", []))
if not ids_to_remove:
return {"error": "no_ids_provided"}
removed = self._remove_items(ids_to_remove)
if removed > 0:
new_ver = int(getattr(self, "labels_version", 1)) + 1
try:
self._persist_snapshot_to_hub(new_ver)
self.labels_version = new_ver
except Exception as e:
return {"status": "error", "removed": removed, "detail": str(e)}
return {"status": "ok", "removed": removed, "labels_version": getattr(self, "labels_version", 1)}
# Freshness guard (optional)
min_ver = payload.get("min_labels_version")
if isinstance(min_ver, int) and min_ver > getattr(self, "labels_version", 0):
with contextlib.suppress(Exception):
self._load_snapshot_from_hub_version(min_ver)
        # Classification path (unchanged contract)
        img_b64 = payload["image"]
        image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)  # [1, 3, 224, 224]
        if self.device == "cuda":
            img_tensor = img_tensor.to(torch.float16)
        with torch.no_grad():
            img_feat = self.model.encode_image(img_tensor)
            img_feat /= img_feat.norm(dim=-1, keepdim=True)  # L2-normalise
            # cosine similarity -> logits -> softmax probabilities
            probs = (100.0 * img_feat @ self.text_features.T).softmax(dim=-1)[0]
results = zip(self.class_ids, self.class_names, probs.detach().cpu().tolist())
top_k = int(payload.get("top_k", len(self.class_ids)))
return sorted(
[{"id": i, "label": name, "score": float(p)} for i, name, p in results],
key=lambda x: x["score"],
reverse=True,
)[:top_k]
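
    # Illustrative request shapes accepted by __call__ (hypothetical values):
    #   classify: {"inputs": {"image": "<base64>", "top_k": 5}}
    #   upsert:   {"inputs": {"op": "upsert_labels", "token": "<ADMIN_TOKEN>",
    #                         "items": [{"id": 7, "name": "cat", "prompt": "a photo of a cat"}]}}
    #   remove:   {"inputs": {"op": "remove_labels", "token": "<ADMIN_TOKEN>", "ids": [7]}}
    #   reload:   {"inputs": {"op": "reload_labels", "token": "<ADMIN_TOKEN>", "version": 3}}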
# ------------- helpers -------------
def _encode_text(self, prompts):
with torch.no_grad():
toks = self.tokenizer(prompts).to(self.device)
feats = self.model.encode_text(toks)
feats = feats / feats.norm(dim=-1, keepdim=True)
return feats
def _to_device(self):
self.text_features = self.text_features_cpu.to(
self.device, dtype=(torch.float16 if self.device == "cuda" else torch.float32)
)
def _upsert_items(self, new_items):
if not new_items:
return 0
with self._lock:
# Get ALL existing IDs and names from current state
known_ids = set(getattr(self, "class_ids", []))
# Create lowercase set for case-insensitive comparison
known_names_lower = set(name.lower() for name in getattr(self, "class_names", []))
# Filter items, checking against both ID and name (case-insensitive)
            batch = []
            for it in new_items:
                item_id = int(it.get("id"))
                item_name = it.get("name")
                # Skip duplicates by ID, or by name case-insensitively
                # ("cat" == "Cat" == "CAT").
                if item_id in known_ids or item_name.lower() in known_names_lower:
                    continue
                # Record the accepted item as well, so duplicates *within* this
                # batch (e.g. "cat" and "Cat" in the same request) are caught too.
                known_ids.add(item_id)
                known_names_lower.add(item_name.lower())
                batch.append(it)
if not batch:
return 0
# Process the filtered batch
prompts = [it["prompt"] for it in batch]
feats = self._encode_text(prompts).detach().cpu().to(torch.float32)
# Update the persistent state
if not hasattr(self, "text_features_cpu"):
self.text_features_cpu = feats.contiguous()
self.class_ids = [int(it["id"]) for it in batch]
self.class_names = [it["name"] for it in batch]
else:
self.text_features_cpu = torch.cat([self.text_features_cpu, feats], dim=0).contiguous()
self.class_ids.extend([int(it["id"]) for it in batch])
self.class_names.extend([it["name"] for it in batch])
self._to_device()
return len(batch)
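
    # Illustrative behaviour (hypothetical items): with "cat" already present,
    # upserting [{"id": 9, "name": "Cat", ...}, {"id": 10, "name": "dog", ...}]
    # adds only "dog" and returns 1; "Cat" is dropped as a case-insensitive duplicate.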
def _remove_items(self, ids_to_remove):
if not ids_to_remove or not hasattr(self, "class_ids"):
return 0
with self._lock:
            ids_to_remove = {int(i) for i in ids_to_remove}
# Find indices to keep
indices_to_keep = []
removed_count = 0
for i, class_id in enumerate(self.class_ids):
if class_id not in ids_to_remove:
indices_to_keep.append(i)
else:
removed_count += 1
if removed_count == 0:
return 0
# Filter the tensors and lists
if indices_to_keep:
self.text_features_cpu = self.text_features_cpu[indices_to_keep].contiguous()
self.class_ids = [self.class_ids[i] for i in indices_to_keep]
self.class_names = [self.class_names[i] for i in indices_to_keep]
else:
# All items removed, reset to empty
self.text_features_cpu = torch.empty(0, self.text_features_cpu.shape[1])
self.class_ids = []
self.class_names = []
self._to_device()
return removed_count
def _persist_snapshot_to_hub(self, version: int):
if not HF_LABEL_REPO:
raise RuntimeError("HF_LABEL_REPO not set")
if not HF_WRITE_TOKEN:
raise RuntimeError("HF_WRITE_TOKEN not set for publishing")
emb_path = "/tmp/embeddings.safetensors"
meta_path = "/tmp/meta.json"
latest_bytes = io.BytesIO(json.dumps({"version": int(version)}).encode("utf-8"))
save_file({"embeddings": self.text_features_cpu.to(torch.float32)}, emb_path)
meta = {
"items": [{"id": int(i), "name": n} for i, n in zip(self.class_ids, self.class_names)],
"fingerprint": self.fingerprint,
"dims": int(self.text_features_cpu.shape[1]),
"count": int(self.text_features_cpu.shape[0]),
"version": int(version),
}
with open(meta_path, "w", encoding="utf-8") as f:
json.dump(meta, f)
ops = [
CommitOperationAdd(
path_in_repo=f"snapshots/v{version}/embeddings.safetensors",
path_or_fileobj=emb_path
),
CommitOperationAdd(
path_in_repo=f"snapshots/v{version}/meta.json",
path_or_fileobj=meta_path
),
CommitOperationAdd(
path_in_repo="snapshots/latest.json",
path_or_fileobj=latest_bytes
),
]
create_commit(
repo_id=HF_LABEL_REPO,
repo_type="dataset",
operations=ops,
token=HF_WRITE_TOKEN,
commit_message=f"labels v{version}",
)
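
    # Published layout in the HF_LABEL_REPO dataset repo (e.g. after version 3):
    #   snapshots/latest.json               -> {"version": 3}
    #   snapshots/v3/embeddings.safetensors -> [count, dims] float32 text embeddings
    #   snapshots/v3/meta.json              -> items, fingerprint, dims, count, version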
def _load_snapshot_from_hub_version(self, version: int) -> bool:
if not HF_LABEL_REPO:
return False
with self._lock:
emb_p = hf_hub_download(
HF_LABEL_REPO,
f"snapshots/v{version}/embeddings.safetensors",
repo_type="dataset",
token=HF_READ_TOKEN,
force_download=True,
)
meta_p = hf_hub_download(
HF_LABEL_REPO,
f"snapshots/v{version}/meta.json",
repo_type="dataset",
token=HF_READ_TOKEN,
force_download=True,
)
            with open(meta_p, "r", encoding="utf-8") as f:
                meta = json.load(f)
if meta.get("fingerprint") != self.fingerprint:
raise RuntimeError("Embedding/model fingerprint mismatch")
feats = load_file(emb_p)["embeddings"] # float32 CPU
self.text_features_cpu = feats.contiguous()
self.class_ids = [int(x["id"]) for x in meta.get("items", [])]
self.class_names = [x["name"] for x in meta.get("items", [])]
self.labels_version = int(meta.get("version", version))
self._to_device()
return True
def _load_snapshot_from_hub_latest(self) -> bool:
if not HF_LABEL_REPO:
return False
try:
latest_p = hf_hub_download(
HF_LABEL_REPO,
"snapshots/latest.json",
repo_type="dataset",
token=HF_READ_TOKEN,
)
except Exception:
return False
        with open(latest_p, "r", encoding="utf-8") as f:
            latest = json.load(f)
ver = int(latest.get("version", 0))
if ver <= 0:
return False
return self._load_snapshot_from_hub_version(ver)
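
if __name__ == "__main__":
    # Minimal local smoke test -- a sketch, not part of the serving contract.
    # Assumes an items.json (or a reachable HF_LABEL_REPO snapshot) is available
    # so the handler can seed its labels; real clients POST the same JSON
    # payload to the deployed endpoint instead of calling the handler directly.
    handler = EndpointHandler()
    buf = io.BytesIO()
    Image.new("RGB", (224, 224), color=(128, 128, 128)).save(buf, format="PNG")
    req = {"inputs": {"image": base64.b64encode(buf.getvalue()).decode("ascii"), "top_k": 3}}
    print(handler(req))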
# """
# MobileCLIP‑B Zero‑Shot Image Classifier (Hugging Face Inference Endpoint)
# ===========================================================================
# * One container instance is created per replica; the `EndpointHandler`
# object below is instantiated exactly **once** at start‑up.
# * At request time (`__call__`) we receive a base‑64‑encoded image, run a
# **single forward pass**, and return class probabilities.
# Design choices
# --------------
# 1. **Model & transform come from OpenCLIP**
# This guarantees we apply **identical preprocessing** to what the model
# was trained with (224 × 224 crop + mean/std normalisation).
# 2. **Re‑parameterisation for inference**
# MobileCLIP uses MobileOne blocks that have extra convolution branches
# for training; `reparameterize_model` fuses them so inference is fast
# and deterministic.
# 3. **Text embeddings are cached**
# The class “prompts” (e.g. `"a photo of a cat"`) are encoded **once at
# start‑up**. Each request therefore encodes *only* the image and
# performs a single matrix multiplication.
# 4. **Mixed precision on GPU**
# If the container has CUDA, we cast the model **and** inputs to
# `float16`. That halves memory and roughly doubles throughput on most
# modern GPUs. On CPU we stay in `float32` for numerical stability.
# """
# import contextlib, io, base64, json
# from pathlib import Path
# from typing import Any, Dict, List
# import torch
# from PIL import Image
# import open_clip
# from reparam import reparameterize_model # local copy (~60 LoC) of Apple’s helper
# class EndpointHandler:
# """
# Hugging Face entry‑point. The toolkit will instantiate this class
# once and call it for every HTTP request.
# Parameters
# ----------
# path : str, optional
# Root directory of the repository. HF mounts the code under
# `/repository`; we use this path to locate `items.json`.
# """
# # ------------------------------------------------------------------ #
# # INITIALISATION (runs **once**) #
# # ------------------------------------------------------------------ #
# def __init__(self, path: str = "") -> None:
# self.device = "cuda" if torch.cuda.is_available() else "cpu"
# # 1️⃣ Load MobileCLIP‑B weights & transforms -------------------
# # `pretrained="datacompdr"` makes OpenCLIP download the
# # official checkpoint from the Hub (cached in the image layer).
# model, _, self.preprocess = open_clip.create_model_and_transforms(
# "MobileCLIP-B", pretrained="datacompdr"
# )
# model.eval() # disable dropout / BN updates
# model = reparameterize_model(model) # fuse MobileOne branches
# model.to(self.device)
# if self.device == "cuda":
# model = model.to(torch.float16) # FP16 for throughput
# self.model = model # hold a reference
# # 2️⃣ Build the tokenizer once --------------------------------
# tokenizer = open_clip.get_tokenizer("MobileCLIP-B")
# # 3️⃣ Load class metadata -------------------------------------
# # Expect JSON file: [{"id": 3, "name": "cat", "prompt": "cat"}, …]
# items_path = Path(path) / "items.json"
# with items_path.open("r", encoding="utf-8") as f:
# class_defs: List[Dict[str, Any]] = json.load(f)
# # Extract the bits we need later
# prompts = [item["prompt"] for item in class_defs]
# self.class_ids: List[int] = [item["id"] for item in class_defs]
# self.class_names: List[str] = [item["name"] for item in class_defs]
# # 4️⃣ Encode all prompts once ---------------------------------
# with torch.no_grad():
# text_tokens = tokenizer(prompts).to(self.device)
# text_feats = self.model.encode_text(text_tokens)
# text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
# self.text_features = text_feats # [num_classes, 512]
# # ------------------------------------------------------------------ #
# # INFERENCE CALL #
# # ------------------------------------------------------------------ #
# def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
# """
# Parameters
# ----------
# data : dict
# Either the raw payload `{"image": "<base64>"}` **or** the
# Hugging Face convention `{"inputs": {...}}`.
# Returns
# -------
# list of dict
# Sorted list of `{"id": int, "label": str, "score": float}`.
# Scores are the softmax probabilities over the *provided*
# class list (they sum to 1.0).
# """
# # 1️⃣ Unpack the request payload ------------------------------
# payload: Dict[str, Any] = data.get("inputs", data)
# img_b64: str = payload["image"]
# # 2️⃣ Decode + preprocess -------------------------------------
# image = Image.open(io.BytesIO(base64.b64decode(img_b64))).convert("RGB")
# img_tensor = self.preprocess(image).unsqueeze(0).to(self.device) # [1, 3, 224, 224]
# if self.device == "cuda":
# img_tensor = img_tensor.to(torch.float16)
# # 3️⃣ Forward pass (image only) -------------------------------
# with torch.no_grad(): # no autograd graph
# img_feat = self.model.encode_image(img_tensor) # [1, 512]
# img_feat = img_feat / img_feat.norm(dim=-1, keepdim=True) # L2‑normalise
# # cosine similarity → logits → softmax probabilities
# probs = (100 * img_feat @ self.text_features.T).softmax(dim=-1)[0] # [num_classes]
# # 4️⃣ Assemble JSON‑serialisable response ---------------------
# results = zip(self.class_ids, self.class_names, probs.cpu().tolist())
# return sorted(
# [{"id": cid, "label": name, "score": float(p)} for cid, name, p in results],
# key=lambda x: x["score"],
# reverse=True,
# )