import csv
import os

import torch


def l2norm_rows(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale each slice of ``x`` along dim 1 to unit L2 norm.

    For a 2-D tensor this normalizes every row. ``eps`` is added to the norm
    (not clamped), so an all-zero row maps to zeros instead of producing a
    division by zero / NaN.

    Args:
        x: Tensor with at least 2 dimensions (norm is taken over ``dim=1``).
        eps: Small constant added to each norm before dividing.

    Returns:
        A new tensor of the same shape as ``x``.
    """
    # keepdim=True keeps the reduced axis so the division broadcasts per row.
    return x / (x.norm(dim=1, keepdim=True) + eps)


def load_tag_names(T: int, csv_name: str) -> list[str]:
    """Return exactly ``max(T, 0)`` tag names.

    Names come from the second column of ``csv_name``, which is resolved
    relative to the directory containing this module. Rows with fewer than
    two columns or a blank second cell are skipped and names are stripped of
    surrounding whitespace. No header detection is done: if the CSV has a
    header row, its second cell is taken as a name.

    If the file is missing or yields fewer than ``T`` names, the result is
    padded with ``"tag_NNNN"`` placeholders, where ``NNNN`` is the zero-padded
    position in the returned list.

    Args:
        T: Number of names wanted; values <= 0 return an empty list.
        csv_name: File name of the CSV, relative to this module's directory.

    Returns:
        A list of length ``max(T, 0)``.
    """
    # Clamp: a negative T previously sliced from the end (names[:T]).
    count = max(T, 0)
    path = os.path.join(os.path.dirname(__file__), csv_name)
    names: list[str] = []
    if count and os.path.isfile(path):
        with open(path, "r", encoding="utf-8", newline="") as f:
            for row in csv.reader(f):
                if len(row) > 1 and row[1].strip():
                    names.append(row[1].strip())
                    if len(names) >= count:
                        # Enough names collected; no need to read the rest.
                        break
    # Pad any shortfall with placeholders indexed by their final position.
    return names + [f"tag_{i:04d}" for i in range(len(names), count)]