File size: 775 Bytes
0873e19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# convert_pth_to_safetensors.py
import sys, torch
from safetensors.torch import save_file

src = sys.argv[1]
dst = sys.argv[2] if len(sys.argv) > 2 else src.rsplit(".", 1)[0] + ".safetensors"

obj = torch.load(src, map_location="cpu")  # weights_only=True if your torch>=2.0 supports it
sd = obj.get("state_dict", obj)            # many .pth have nested {"state_dict": ...}

# If someone saved a whole nn.Module:
if hasattr(sd, "state_dict"):
    sd = sd.state_dict()

# Ensure tensors are contiguous (defensive; usually not required)
sd = {k: (v.contiguous() if isinstance(v, torch.Tensor) else v) for k, v in sd.items()}

# Optional: add a tiny breadcrumb
meta = {"converted_from": src, "format": "pt->safetensors"}

save_file(sd, dst, metadata=meta)
print(f"Wrote: {dst}")