"""Image-classification model adapters and the registry that serves them.

Every adapter wraps a project model (imported lazily from ``src.models`` the
first time it is loaded) behind the small :class:`BaseModel` interface.
``REGISTRY`` holds one shared instance per model, and :func:`get_model`
returns that instance after loading it once.
"""
from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple

if TYPE_CHECKING:
    # PIL is referenced only in (string) type annotations in this module, so it
    # is imported for type checkers only; the registry stays importable without
    # Pillow installed.
    from PIL import Image

ALL_CATEGORIES = [
    "alcohol", "drugs", "weapons", "gambling",
    "nudity", "sexy", "smoking", "violence",
]
# NOTE(review): not referenced in this module; presumably the default score
# threshold callers apply to predict_image output — confirm.
DEFAULT_THRESHOLD = 0.5
# Registry names of the adapters whose only output is a single "sexual" score.
NUDENET_ONLY = {"clip-nudenet-lp", "siglip-nudenet-lp"}


def _select_probs(
    categories: Sequence[str],
    probs: Sequence[float],
    requested_categories: Sequence[str],
) -> Dict[str, float]:
    """Return ``{category: float(probs[i])}`` for each requested entry of *categories*.

    *probs* is indexed in the same order as *categories*; requested names the
    model does not know are silently ignored.
    """
    wanted = set(requested_categories)  # O(1) membership per category
    return {c: float(probs[i]) for i, c in enumerate(categories) if c in wanted}


class BaseModel:
    """Interface shared by every registered model adapter.

    Subclasses set ``name`` and ``categories`` and implement ``load`` and
    ``predict_image``. ``extra_selected_tags`` is optional; adapters that
    implement it also set ``supports_selected_tags = True``.
    """

    name = "base"
    supports_selected_tags = False
    categories = ALL_CATEGORIES

    def load(self):
        """Build the underlying model. Implementations must be idempotent."""
        raise NotImplementedError

    def predict_image(
        self, pil_image: "Image.Image", requested_categories: List[str]
    ) -> Dict[str, float]:
        """Return ``{category: probability}`` for *pil_image*."""
        raise NotImplementedError

    def extra_selected_tags(
        self, pil_image: "Image.Image", top_k: int = 10
    ) -> List[Tuple[str, float]]:
        """Return ``(tag, score)`` pairs; the base implementation returns none."""
        return []


class Clip_MultiLabel(BaseModel):
    """Adapter for ``src.models.CLIPMultiLabel`` scoring every ALL_CATEGORIES entry."""

    name = "clip-multilabel"
    categories = ALL_CATEGORIES

    def __init__(self, head_path="weights/clip_multilabel.pt"):
        # Hand the wrapped model a copy so it can never mutate ALL_CATEGORIES.
        self._cfg = dict(head_path=head_path, categories=list(self.categories))
        self._m = None

    def load(self):
        if self._m is None:
            # Deferred import: only needed once the model is actually built.
            from src.models import CLIPMultiLabel

            self._m = CLIPMultiLabel(**self._cfg)

    def predict_image(
        self, pil_image, requested_categories: List[str]
    ) -> Dict[str, float]:
        """Score *pil_image*, keeping only the requested categories this model knows.

        Loads the model on first use if ``load`` has not been called yet.
        """
        if self._m is None:
            self.load()
        # assumes prob() returns one row per input image, in category order
        probs = self._m.prob([pil_image])[0].tolist()
        return _select_probs(self.categories, probs, requested_categories)


class _EVABaseAdapter(BaseModel):
    """Shared adapter for ``EVAHeadPreserving`` backbones.

    Subclasses set ``REPO_ID`` and ``TAG_CSV``, which are forwarded to
    ``EVAHeadPreserving`` when the model is loaded.
    """

    supports_selected_tags = True
    REPO_ID = ""
    TAG_CSV = ""

    def __init__(self, head_path: str):
        # Copy the class-level category list so the model cannot mutate it.
        self._cfg = dict(head_path=head_path, categories=list(self.categories))
        self._m = None

    def load(self):
        if self._m is None:
            from src.models.eva_headpreserving import EVAHeadPreserving

            self._m = EVAHeadPreserving(
                repo_id=self.REPO_ID,
                head_path=self._cfg["head_path"],
                categories=self._cfg["categories"],
                tag_csv=self.TAG_CSV,
            )

    def predict_image(
        self, pil_image, requested_categories: List[str]
    ) -> Dict[str, float]:
        """Score *pil_image*, keeping only the requested categories this model knows.

        Loads the model on first use if ``load`` has not been called yet.
        """
        if self._m is None:
            self.load()
        # assumes prob() returns one row per input image, in category order
        probs = self._m.prob([pil_image])[0].tolist()
        return _select_probs(self.categories, probs, requested_categories)

    def extra_selected_tags(
        self, pil_image: "Image.Image", top_k: int = 50
    ) -> List[Tuple[str, float]]:
        """Return the wrapped model's ``top_tags(pil_image, top_k=top_k)``."""
        if self._m is None:
            self.load()
        return self._m.top_tags(pil_image, top_k=top_k)


class WDEva02_Multitask(_EVABaseAdapter):
    name = "wdeva02-multitask"
    REPO_ID = "SmilingWolf/wd-eva02-large-tagger-v3"
    TAG_CSV = "wdeva02_tags.csv"

    def __init__(self, head_path="weights/wdeva02.pt"):
        super().__init__(head_path=head_path)


class Animetimm_Multitask(_EVABaseAdapter):
    name = "animetimm-multitask"
    REPO_ID = "animetimm/eva02_large_patch14_448.dbv4-full"
    TAG_CSV = "animetimm_tags.csv"

    def __init__(self, head_path="weights/animetimm.pt"):
        super().__init__(head_path=head_path)


class _NudeNetLPAdapter(BaseModel):
    """Shared adapter for single-output linear probes producing a "sexual" score.

    Subclasses implement ``_probe_class`` to return the probe class to build.
    """

    categories = ["sexual"]

    def __init__(self, head_path: str):
        self._cfg = dict(head_path=head_path)
        self._lp = None

    def _probe_class(self):
        """Return the (lazily imported) linear-probe class for this adapter."""
        raise NotImplementedError

    def load(self):
        if self._lp is None:
            self._lp = self._probe_class()(**self._cfg)

    def predict_image(
        self, pil_image: "Image.Image", requested_categories: List[str]
    ) -> Dict[str, float]:
        """Return ``{"sexual": probability}`` for *pil_image*.

        The score is returned regardless of *requested_categories* (existing
        behavior; callers appear to special-case these models via NUDENET_ONLY).
        Loads the probe on first use if ``load`` has not been called yet.
        """
        if self._lp is None:
            self.load()
        return {"sexual": float(self._lp.prob([pil_image])[0])}


class Clip_NudeNet_LP(_NudeNetLPAdapter):
    name = "clip-nudenet-lp"

    def __init__(self, head_path: str = "weights/clip_nudenet_lp.npz"):
        super().__init__(head_path=head_path)

    def _probe_class(self):
        from src.models import CLIPLinearProbe

        return CLIPLinearProbe


class Siglip_NudeNet_LP(_NudeNetLPAdapter):
    name = "siglip-nudenet-lp"

    def __init__(self, head_path: str = "weights/siglip_nudenet_lp.npz"):
        super().__init__(head_path=head_path)

    def _probe_class(self):
        from src.models import SigLIPLinearProbe

        return SigLIPLinearProbe


# Shared adapter instances keyed by public model name. The two EVA keys end in
# "-multilabel" while the adapters' own ``name`` ends in "-multitask";
# get_model accepts either spelling.
REGISTRY = {
    "clip-multilabel": Clip_MultiLabel(),
    "wdeva02-multilabel": WDEva02_Multitask(),
    "animetimm-multilabel": Animetimm_Multitask(),
    "clip-nudenet-lp": Clip_NudeNet_LP(),
    "siglip-nudenet-lp": Siglip_NudeNet_LP(),
}


def get_model(name: str) -> BaseModel:
    """Return the shared adapter registered under *name*, loading it once.

    *name* is normally a ``REGISTRY`` key; an adapter's own ``name`` attribute
    is also accepted as an alias.

    Raises:
        KeyError: if no registered adapter matches *name*.
    """
    m = REGISTRY.get(name)
    if m is None:
        m = next((cand for cand in REGISTRY.values() if cand.name == name), None)
    if m is None:
        raise KeyError(f"unknown model {name!r}; available: {sorted(REGISTRY)}")
    if not hasattr(m, "_loaded"):
        m.load()
        m._loaded = True
    return m