Upload modeling_hunyuan.py
Browse files- modeling_hunyuan.py +71 -62
modeling_hunyuan.py
CHANGED
|
@@ -1605,68 +1605,15 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
|
|
| 1605 |
)
|
| 1606 |
return reordered_past
|
| 1607 |
|
| 1608 |
-
|
| 1609 |
-
|
| 1610 |
-
|
| 1611 |
-
|
| 1612 |
-
"""Replace a (possibly nested) sub‑module.
|
| 1613 |
-
|
| 1614 |
-
``target`` is the dotted path returned by ``model.named_modules()``.
|
| 1615 |
-
"""
|
| 1616 |
-
parts = target.split('.')
|
| 1617 |
-
parent = root
|
| 1618 |
-
for p in parts[:-1]:
|
| 1619 |
-
parent = getattr(parent, p)
|
| 1620 |
-
setattr(parent, parts[-1], new_module)
|
| 1621 |
-
|
| 1622 |
-
# -----------------------------------------------------------------------
|
| 1623 |
-
# Public APIs
|
| 1624 |
-
# -----------------------------------------------------------------------
|
| 1625 |
-
def densify(model: nn.Module):
    """Replace every :class:`HunYuanMoE` under *model* with a
    :class:`HunYuanDenseMoE` built from it. Operates **in-place**.

    Args:
        model: Root module whose sub-modules are scanned.

    Returns:
        The same *model* object, so calls can be chained.
    """
    # Snapshot the matches before touching the tree: re-binding children
    # while ``named_modules()`` is still iterating would mutate the
    # structure being walked.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, HunYuanMoE)
    ]
    for name, sparse_moe in targets:
        device = next(sparse_moe.parameters()).device
        dense_moe = HunYuanDenseMoE(sparse_moe).to(device)
        # A newly constructed nn.Module defaults to training mode; mirror the
        # block being replaced so an eval-mode model stays fully in eval mode.
        dense_moe.train(sparse_moe.training)
        _replace_submodule(model, name, dense_moe)
    return model
|
| 1636 |
-
|
| 1637 |
-
|
| 1638 |
-
def sparsify(model: nn.Module):
    """Rebuild sparse :class:`HunYuanMoE` modules from their fused
    :class:`HunYuanDenseMoE` form. Operates **in-place**.

    Args:
        model: Root module whose sub-modules are scanned.

    Returns:
        The same *model* object, so calls can be chained.
    """
    # Snapshot the matches before touching the tree: re-binding children
    # while ``named_modules()`` is still iterating would mutate the
    # structure being walked.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, HunYuanDenseMoE)
    ]
    for name, dense_moe in targets:
        device = next(dense_moe.parameters()).device
        sparse_moe = HunYuanMoE(
            dense_moe.config, layer_idx=dense_moe.layer_idx
        ).to(device)

        # Copy router
        sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())

        # NOTE(review): only the router and the three expert projections are
        # restored here. Any other parameter HunYuanMoE creates keeps its
        # fresh initialisation, and only the device (not the dtype) of the
        # new module is aligned with the fused one -- confirm both are
        # intended.

        # Slice fused weights back to per-expert: expert ``idx`` owns the
        # band [idx * width, (idx + 1) * width) -- rows of the fused gate/up
        # weights, columns of the fused down weight.
        width = dense_moe.intermediate_size
        fused_gate = dense_moe.fused_gate_proj.weight.data
        fused_up = dense_moe.fused_up_proj.weight.data
        fused_down = dense_moe.fused_down_proj.weight.data
        for idx, expert in enumerate(sparse_moe.experts):
            band = slice(idx * width, (idx + 1) * width)
            expert.gate_proj.weight.data.copy_(fused_gate[band])
            expert.up_proj.weight.data.copy_(fused_up[band])
            expert.down_proj.weight.data.copy_(fused_down[:, band])

        # A newly constructed nn.Module defaults to training mode; mirror the
        # block being replaced so an eval-mode model stays fully in eval mode.
        sparse_moe.train(dense_moe.training)
        _replace_submodule(model, name, sparse_moe)
    return model
|
| 1669 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1670 |
|
| 1671 |
@add_start_docstrings(
|
| 1672 |
"""
|
|
@@ -1907,4 +1854,66 @@ class HunYuanDenseMoE(nn.Module):
|
|
| 1907 |
sparse_out = self._sparse_path(x, probs)
|
| 1908 |
|
| 1909 |
out = dense_out + (sparse_out - dense_out).detach() # STE
|
| 1910 |
-
return out.view(bsz, seq_len, self.hidden_size)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1605 |
)
|
| 1606 |
return reordered_past
|
| 1607 |
|
| 1608 |
+
def densify(self):
    """Fuse every HunYuanMoE block under this model, in place.

    Delegates to the module-level ``densify`` helper. Inside a method body
    the bare name ``densify`` resolves to that module-level function, not to
    this method.

    Returns:
        This model, so calls can be chained.
    """
    # The helper hands back the very module it was given, i.e. ``self``.
    return densify(self)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1612 |
|
| 1613 |
+
def sparsify(self):
    """Restore the original sparse experts under this model (inverse of
    ``densify``), in place.

    Delegates to the module-level ``sparsify`` helper. Inside a method body
    the bare name ``sparsify`` resolves to that module-level function, not to
    this method.

    Returns:
        This model, so calls can be chained.
    """
    # The helper hands back the very module it was given, i.e. ``self``.
    return sparsify(self)
|
| 1617 |
|
| 1618 |
@add_start_docstrings(
|
| 1619 |
"""
|
|
|
|
| 1854 |
sparse_out = self._sparse_path(x, probs)
|
| 1855 |
|
| 1856 |
out = dense_out + (sparse_out - dense_out).detach() # STE
|
| 1857 |
+
return out.view(bsz, seq_len, self.hidden_size)
|
| 1858 |
+
|
| 1859 |
+
# -----------------------------------------------------------------------
|
| 1860 |
+
# Helper for module replacement
|
| 1861 |
+
# -----------------------------------------------------------------------
|
| 1862 |
+
def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
|
| 1863 |
+
"""Replace a (possibly nested) sub‑module.
|
| 1864 |
+
|
| 1865 |
+
``target`` is the dotted path returned by ``model.named_modules()``.
|
| 1866 |
+
"""
|
| 1867 |
+
parts = target.split('.')
|
| 1868 |
+
parent = root
|
| 1869 |
+
for p in parts[:-1]:
|
| 1870 |
+
parent = getattr(parent, p)
|
| 1871 |
+
setattr(parent, parts[-1], new_module)
|
| 1872 |
+
|
| 1873 |
+
# -----------------------------------------------------------------------
|
| 1874 |
+
# Public APIs
|
| 1875 |
+
# -----------------------------------------------------------------------
|
| 1876 |
+
def densify(model: nn.Module):
    """Replace every :class:`HunYuanMoE` under *model* with a
    :class:`HunYuanDenseMoE` built from it. Operates **in-place**.

    Args:
        model: Root module whose sub-modules are scanned.

    Returns:
        The same *model* object, so calls can be chained.
    """
    # Snapshot the matches before touching the tree: re-binding children
    # while ``named_modules()`` is still iterating would mutate the
    # structure being walked.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, HunYuanMoE)
    ]
    for name, sparse_moe in targets:
        device = next(sparse_moe.parameters()).device
        dense_moe = HunYuanDenseMoE(sparse_moe).to(device)
        # A newly constructed nn.Module defaults to training mode; mirror the
        # block being replaced so an eval-mode model stays fully in eval mode.
        dense_moe.train(sparse_moe.training)
        _replace_submodule(model, name, dense_moe)
    return model
|
| 1887 |
+
|
| 1888 |
+
|
| 1889 |
+
def sparsify(model: nn.Module):
    """Rebuild sparse :class:`HunYuanMoE` modules from their fused
    :class:`HunYuanDenseMoE` form. Operates **in-place**.

    Args:
        model: Root module whose sub-modules are scanned.

    Returns:
        The same *model* object, so calls can be chained.
    """
    # Snapshot the matches before touching the tree: re-binding children
    # while ``named_modules()`` is still iterating would mutate the
    # structure being walked.
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if isinstance(module, HunYuanDenseMoE)
    ]
    for name, dense_moe in targets:
        device = next(dense_moe.parameters()).device
        sparse_moe = HunYuanMoE(
            dense_moe.config, layer_idx=dense_moe.layer_idx
        ).to(device)

        # Copy router
        sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())

        # NOTE(review): only the router and the three expert projections are
        # restored here. Any other parameter HunYuanMoE creates keeps its
        # fresh initialisation, and only the device (not the dtype) of the
        # new module is aligned with the fused one -- confirm both are
        # intended.

        # Slice fused weights back to per-expert: expert ``idx`` owns the
        # band [idx * width, (idx + 1) * width) -- rows of the fused gate/up
        # weights, columns of the fused down weight.
        width = dense_moe.intermediate_size
        fused_gate = dense_moe.fused_gate_proj.weight.data
        fused_up = dense_moe.fused_up_proj.weight.data
        fused_down = dense_moe.fused_down_proj.weight.data
        for idx, expert in enumerate(sparse_moe.experts):
            band = slice(idx * width, (idx + 1) * width)
            expert.gate_proj.weight.data.copy_(fused_gate[band])
            expert.up_proj.weight.data.copy_(fused_up[band])
            expert.down_proj.weight.data.copy_(fused_down[:, band])

        # A newly constructed nn.Module defaults to training mode; mirror the
        # block being replaced so an eval-mode model stays fully in eval mode.
        sparse_moe.train(dense_moe.training)
        _replace_submodule(model, name, sparse_moe)
    return model
|