# convert_pth_to_safetensors.py
import pickle, sys, torch
from safetensors.torch import save_file
if len(sys.argv) < 2:
    sys.exit(f"usage: {sys.argv[0]} SRC.pth [DST.safetensors]")
src = sys.argv[1]
dst = sys.argv[2] if len(sys.argv) > 2 else src.rsplit(".", 1)[0] + ".safetensors"
# Prefer weights_only=True (torch>=2.0): it refuses to unpickle arbitrary objects.
try:
    obj = torch.load(src, map_location="cpu", weights_only=True)
except (TypeError, pickle.UnpicklingError):
    # Older torch lacks the kwarg; a checkpoint holding a pickled whole
    # nn.Module also needs full unpickling. Only fall back for trusted files.
    obj = torch.load(src, map_location="cpu")
# Many .pth checkpoints nest the weights as {"state_dict": ...}; a bare
# obj.get() would crash if obj is not a dict, so check the type first.
sd = obj["state_dict"] if isinstance(obj, dict) and "state_dict" in obj else obj
# If someone saved a whole nn.Module, pull its state_dict out.
if hasattr(sd, "state_dict"):
    sd = sd.state_dict()
# Keep only tensor entries (safetensors stores tensors exclusively) and make
# them contiguous (defensive; usually not required).
sd = {k: v.contiguous() for k, v in sd.items() if isinstance(v, torch.Tensor)}
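# Shared/tied weights are a wrinkle the script above doesn't handle: save_file()
# raises if two entries share storage. The sketch below is a hedge, not part of
# the original flow; it deduplicates by data_ptr(), which catches exact aliases
# only (safetensors also offers save_model for saving straight from a module).
seen_ptrs = {}
for k, v in list(sd.items()):
    ptr = v.data_ptr()
    if ptr in seen_ptrs:
        sd[k] = v.clone()  # break the aliasing so save_file() accepts the dict
    else:
        seen_ptrs[ptr] = k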
# Optional: add a tiny breadcrumb (metadata keys and values must be str).
meta = {"converted_from": src, "format": "pt->safetensors"}
save_file(sd, dst, metadata=meta)
print(f"Wrote: {dst}")