# convert_pth_to_safetensors.py
"""Convert a PyTorch ``.pth``/``.pt`` checkpoint into a ``.safetensors`` file.

Usage::

    python convert_pth_to_safetensors.py SRC [DST]

If DST is omitted it is SRC with its last extension replaced by
``.safetensors``. The checkpoint may be a plain state dict, a dict wrapping
one under the ``"state_dict"`` key, or a whole pickled ``nn.Module``.
"""
import inspect
import os
import pickle
import sys
from collections.abc import Mapping

import torch

USAGE = "usage: convert_pth_to_safetensors.py SRC [DST]"


def default_dst(src):
    """Return SRC with its final extension replaced by ``.safetensors``.

    ``os.path.splitext`` only looks at the last path component, so a dot in a
    directory name (``runs.v1/model``) is not mistaken for the extension.
    """
    return os.path.splitext(src)[0] + ".safetensors"


def load_checkpoint(src):
    """Load the checkpoint at SRC onto the CPU.

    Prefers torch's restricted ``weights_only`` loader when this torch version
    supports it. Errors from ``torch.load`` (missing or corrupt file)
    propagate unchanged.
    """
    if "weights_only" not in inspect.signature(torch.load).parameters:
        # Old torch without the restricted loader.
        # SECURITY: full unpickling can execute arbitrary code - only
        # convert files you trust.
        return torch.load(src, map_location="cpu")
    try:
        return torch.load(src, map_location="cpu", weights_only=True)
    except pickle.UnpicklingError:
        # Whole pickled modules, or checkpoints holding arbitrary Python
        # objects, are refused by the restricted loader. Falling back keeps
        # such files convertible, but it is exactly the unsafe path, so say
        # so loudly.
        print(
            "warning: weights_only load failed; falling back to full "
            "unpickling, which can execute arbitrary code. Only convert "
            "files you trust.",
            file=sys.stderr,
        )
        return torch.load(src, map_location="cpu", weights_only=False)


def extract_state_dict(obj):
    """Return the mapping of parameter names to values held by a checkpoint.

    Unwraps a ``{"state_dict": ...}`` wrapper. Then, if the result exposes a
    callable ``state_dict()`` (e.g. a pickled ``nn.Module``), calls it.

    Raises:
        TypeError: if no mapping of weights can be found.
    """
    if isinstance(obj, Mapping) and "state_dict" in obj:
        obj = obj["state_dict"]
    # Checked before requiring a Mapping: a pickled nn.Module has no .get().
    state_dict_fn = getattr(obj, "state_dict", None)
    if callable(state_dict_fn):
        obj = state_dict_fn()
    if not isinstance(obj, Mapping):
        raise TypeError(
            f"cannot find a state dict in checkpoint of type {type(obj).__name__}"
        )
    return obj


def prepare_tensors(state_dict):
    """Return only the tensor entries of STATE_DICT, each made contiguous.

    safetensors can store nothing but tensors, so non-tensor entries (step
    counters, hyper-parameters, nested dicts) are dropped with a warning
    instead of making ``save_file`` fail. ``save_file`` also rejects
    non-contiguous tensors, hence ``.contiguous()``.

    NOTE: ``save_file`` also refuses tensors that share storage (e.g. tied
    embeddings); such checkpoints need de-duplication before conversion.
    """
    tensors = {}
    skipped = []
    for name, value in state_dict.items():
        if isinstance(value, torch.Tensor):
            tensors[name] = value.contiguous()
        else:
            skipped.append(str(name))
    if skipped:
        print(
            f"warning: skipping {len(skipped)} non-tensor entries: "
            f"{', '.join(skipped)}",
            file=sys.stderr,
        )
    return tensors


def main(argv=None):
    """Run the converter on ARGV (default ``sys.argv[1:]``); return an exit code."""
    args = sys.argv[1:] if argv is None else list(argv)
    if not 1 <= len(args) <= 2:
        print(USAGE, file=sys.stderr)
        return 2
    src = args[0]
    dst = args[1] if len(args) == 2 else default_dst(src)
    if os.path.abspath(dst) == os.path.abspath(src):
        print(f"error: refusing to overwrite input file {src}", file=sys.stderr)
        return 2

    # Imported here so the helpers above stay importable without safetensors.
    from safetensors.torch import save_file

    tensors = prepare_tensors(extract_state_dict(load_checkpoint(src)))
    if not tensors:
        print(f"error: no tensors found in {src}", file=sys.stderr)
        return 1

    # "format" must be a framework tag: Hugging Face transformers rejects
    # files whose metadata "format" is not "pt"/"tf"/"flax"/"mlx".
    meta = {"converted_from": src, "format": "pt"}
    save_file(tensors, dst, metadata=meta)
    print(f"Wrote: {dst}")
    return 0


if __name__ == "__main__":
    sys.exit(main())