onullusoy's picture
Upload 12 files
5e94db5 verified
raw
history blame
643 Bytes
import os, csv, torch
def l2norm_rows(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Divide every row of ``x`` by its Euclidean (L2) length.

    The length is computed along dimension 1 and kept as a size-1 axis so it
    broadcasts back over ``x``. ``eps`` is *added* to each length (it is not
    used as a lower clamp). As a result, an all-zero row maps to zeros instead
    of NaN, and non-zero rows come out very slightly shorter than unit length.

    Args:
        x: Tensor with at least two dimensions (dimension 1 is reduced);
            presumably shaped (rows, features) — confirm against callers.
        eps: Small constant added to each row length before dividing.

    Returns:
        Tensor with the same shape as ``x`` whose rows are L2-normalized.
    """
    row_len = x.norm(dim=1, keepdim=True)
    return x.div(row_len.add(eps))
def load_tag_names(T: int, csv_name: str) -> list[str]:
    """Return exactly ``max(T, 0)`` tag names, read from a CSV beside this module.

    ``csv_name`` is joined onto the directory containing this source file. Each
    CSV row that has at least two columns and a non-blank second column
    contributes its stripped second column as a name. Reading stops as soon as
    ``T`` names have been collected. If the file does not exist, or it yields
    fewer than ``T`` names, the result is padded with placeholders
    ``tag_NNNN``, where ``NNNN`` is the zero-padded list position.

    NOTE(review): rows lacking a usable second column are skipped rather than
    replaced by a placeholder, so every later name moves up one position. If
    list positions must line up with tag indices, confirm the CSV has no such
    rows. A header row, if present, is not skipped and would be returned as a
    name — confirm the CSV is headerless.

    Args:
        T: Number of names to return. Negative values are treated as 0.
        csv_name: CSV file name, relative to this module's directory.

    Returns:
        A list of exactly ``max(T, 0)`` strings.
    """
    # Clamp so a negative T cannot turn the final slice into "drop from the end".
    limit = max(T, 0)
    path = os.path.join(os.path.dirname(__file__), csv_name)
    names: list[str] = []
    if limit and os.path.isfile(path):
        with open(path, "r", encoding="utf-8", newline="") as f:
            for row in csv.reader(f):
                if len(row) > 1 and (name := row[1].strip()):
                    names.append(name)
                    if len(names) == limit:
                        break  # enough names; no need to parse the rest of the file
    # Pad with positional placeholders up to the requested count.
    return names + [f"tag_{i:04d}" for i in range(len(names), limit)]