File size: 643 Bytes
5e94db5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import os, csv, torch

def l2norm_rows(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Scale every slice of ``x`` along dim 1 to unit Euclidean length.

    For a 2-D tensor this normalizes each row. ``eps`` is added to the
    norm (not used as a lower clamp), so an all-zero row stays all-zero
    instead of producing a division by zero.

    Args:
        x: Tensor with at least two dimensions; norms are taken over dim 1.
        eps: Small constant added to each norm before dividing.

    Returns:
        A new tensor with the same shape as ``x``.
    """
    row_lengths = x.norm(dim=1, keepdim=True)  # keepdim so it broadcasts back over dim 1
    denominator = row_lengths + eps
    return x / denominator

def load_tag_names(T: int, csv_name: str) -> list[str]:
    """Return exactly ``T`` tag names, read from a CSV next to this module.

    The CSV is resolved relative to this file's directory (an absolute
    ``csv_name`` is used as-is, per ``os.path.join``). The second column of
    each row is taken as a name after stripping whitespace. Rows with fewer
    than two columns or a blank second column are skipped, not replaced,
    so every later name moves one index earlier.
    NOTE(review): every row, including the first, is treated as data. If
    the CSV has a header row, it will become ``names[0]``; confirm the file
    has no header.

    If the file is missing or yields fewer than ``T`` names, the list is
    padded with placeholders ``tag_{i:04d}`` for the missing indices.

    Args:
        T: Number of names to return; must be non-negative.
        csv_name: CSV file name (or path) relative to this module's directory.

    Returns:
        A list of exactly ``T`` strings.

    Raises:
        ValueError: If ``T`` is negative.
    """
    if T < 0:
        # Previously names[:T] silently dropped names from the end.
        raise ValueError(f"T must be non-negative, got {T}")
    p = os.path.join(os.path.dirname(__file__), csv_name)
    names: list[str] = []
    if os.path.isfile(p):
        with open(p, "r", encoding="utf-8", newline="") as f:
            for row in csv.reader(f):
                if len(names) >= T:
                    break  # already have T names; no need to parse the rest
                if len(row) > 1:
                    name = row[1].strip()
                    if name:
                        names.append(name)
    # len(names) <= T here, so this pads to exactly T entries.
    return names + [f"tag_{i:04d}" for i in range(len(names), T)]