# Source: HunyuanVideo-Foley / scripts/convert_pth_to_safetensors.py
# (copied from a Hugging Face file view: last commit 7a94e98 "Move scripts"
#  by phazei; original file size 775 bytes)
# convert_pth_to_safetensors.py
import os
import sys

import torch
from safetensors.torch import save_file
def _default_dst(src):
    """Return *src* with its file extension replaced by ``.safetensors``."""
    # splitext only inspects the final path component, so a dot in a directory
    # name (e.g. "./models/weights") is not mistaken for an extension.
    return os.path.splitext(src)[0] + ".safetensors"


def _extract_state_dict(obj):
    """Return the name -> value mapping held by a loaded checkpoint object.

    Accepts a plain state dict, a dict that nests one under ``"state_dict"``,
    or a whole pickled module (at the top level or under that key).

    Raises:
        TypeError: if no dict-like state dict can be found in *obj*.
    """
    sd = obj
    if isinstance(sd, dict) and "state_dict" in sd:
        sd = sd["state_dict"]  # many .pth files nest {"state_dict": ...}
    # A whole saved nn.Module has .state_dict() but no .get(), so it must be
    # unwrapped without assuming dict methods exist.
    if not isinstance(sd, dict) and hasattr(sd, "state_dict"):
        sd = sd.state_dict()
    if not isinstance(sd, dict):
        raise TypeError(
            f"Cannot find a state dict in checkpoint of type {type(obj).__name__}"
        )
    return sd


def _prepare_tensors(sd):
    """Return a tensor-only copy of *sd* that ``save_file`` will accept.

    Non-tensor entries are dropped with a warning on stderr. Every tensor is
    made contiguous. Tensors whose storage was already used by an earlier
    entry (e.g. tied weights or views of one buffer) are cloned, because
    safetensors refuses to write tensors that share memory.

    Raises:
        ValueError: if *sd* contains no tensors at all.
    """
    tensors = {}
    skipped = []
    seen_storages = set()
    for name, value in sd.items():
        if not isinstance(value, torch.Tensor):
            skipped.append(name)
            continue
        tensor = value.contiguous()
        # untyped_storage() exists on torch >= 2.0; older versions only have storage().
        storage = getattr(tensor, "untyped_storage", tensor.storage)()
        ptr = storage.data_ptr()
        if ptr in seen_storages:
            tensor = tensor.clone()  # give the duplicate its own storage
        else:
            seen_storages.add(ptr)
        tensors[name] = tensor
    if skipped:
        print(
            f"Skipping {len(skipped)} non-tensor entries: {', '.join(map(str, skipped))}",
            file=sys.stderr,
        )
    if not tensors:
        raise ValueError("Checkpoint contains no tensors to convert")
    return tensors


def main(argv=None):
    """Convert SRC (.pth/.pt) to DST (default: SRC with a .safetensors suffix).

    Args:
        argv: Arguments after the program name; defaults to ``sys.argv[1:]``.
            Arguments beyond the second are ignored.

    Returns:
        Process exit status: 0 on success, 2 on a usage error.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if not args:
        print("usage: convert_pth_to_safetensors.py SRC [DST]", file=sys.stderr)
        return 2
    src = args[0]
    dst = args[1] if len(args) > 1 else _default_dst(src)

    # SECURITY: torch.load unpickles the file, which can run arbitrary code
    # unless weights_only=True (added in torch 1.13, the default since 2.6).
    # It is left at torch's default because whole pickled nn.Modules, which
    # this script supports, need weights_only=False. Only convert trusted files.
    obj = torch.load(src, map_location="cpu")

    sd = _prepare_tensors(_extract_state_dict(obj))
    # Small provenance breadcrumb; safetensors metadata values must be str.
    meta = {"converted_from": src, "format": "pt->safetensors"}
    save_file(sd, dst, metadata=meta)
    print(f"Wrote: {dst}")
    return 0


if __name__ == "__main__":
    sys.exit(main())