codys12 commited on
Commit
83ff72b
·
verified ·
1 Parent(s): bab2a77

Upload modeling_hunyuan.py

Browse files
Files changed (1) hide show
  1. modeling_hunyuan.py +71 -62
modeling_hunyuan.py CHANGED
@@ -1605,68 +1605,15 @@ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
1605
  )
1606
  return reordered_past
1607
 
1608
- # -----------------------------------------------------------------------
1609
- # Helper for module replacement
1610
- # -----------------------------------------------------------------------
1611
- def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
1612
- """Replace a (possibly nested) sub‑module.
1613
-
1614
- ``target`` is the dotted path returned by ``model.named_modules()``.
1615
- """
1616
- parts = target.split('.')
1617
- parent = root
1618
- for p in parts[:-1]:
1619
- parent = getattr(parent, p)
1620
- setattr(parent, parts[-1], new_module)
1621
-
1622
- # -----------------------------------------------------------------------
1623
- # Public APIs
1624
- # -----------------------------------------------------------------------
1625
- def densify(model: nn.Module):
1626
- """Convert all :class:`HunYuanMoE` modules under *model* to
1627
- :class:`HunYuanDenseMoE`. Operates **in‑place**."""
1628
- replacements = []
1629
- for name, module in model.named_modules():
1630
- if isinstance(module, HunYuanMoE):
1631
- replacements.append((name, module))
1632
- for name, sparse_moe in replacements:
1633
- dense_moe = HunYuanDenseMoE(sparse_moe).to(next(sparse_moe.parameters()).device)
1634
- _replace_submodule(model, name, dense_moe)
1635
- return model
1636
-
1637
-
1638
- def sparsify(model: nn.Module):
1639
- """Rebuild standard sparse :class:`HunYuanMoE` modules from their
1640
- fused :class:`HunYuanDenseMoE` form. Operates **in‑place**."""
1641
- replacements = []
1642
- for name, module in model.named_modules():
1643
- if isinstance(module, HunYuanDenseMoE):
1644
- replacements.append((name, module))
1645
- for name, dense_moe in replacements:
1646
- cfg = dense_moe.config
1647
- sparse_moe = HunYuanMoE(cfg, layer_idx=dense_moe.layer_idx).to(next(dense_moe.parameters()).device)
1648
-
1649
- # Copy router
1650
- sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())
1651
-
1652
- # Slice fused weights back to per‑expert
1653
- for idx, expert in enumerate(sparse_moe.experts):
1654
- start = idx * dense_moe.intermediate_size
1655
- end = (idx + 1) * dense_moe.intermediate_size
1656
-
1657
- expert.gate_proj.weight.data.copy_(
1658
- dense_moe.fused_gate_proj.weight.data[start:end]
1659
- )
1660
- expert.up_proj.weight.data.copy_(
1661
- dense_moe.fused_up_proj.weight.data[start:end]
1662
- )
1663
- expert.down_proj.weight.data.copy_(
1664
- dense_moe.fused_down_proj.weight.data[:, start:end]
1665
- )
1666
-
1667
- _replace_submodule(model, name, sparse_moe)
1668
- return model
1669
 
 
 
 
 
1670
 
1671
  @add_start_docstrings(
1672
  """
@@ -1907,4 +1854,66 @@ class HunYuanDenseMoE(nn.Module):
1907
  sparse_out = self._sparse_path(x, probs)
1908
 
1909
  out = dense_out + (sparse_out - dense_out).detach() # STE
1910
- return out.view(bsz, seq_len, self.hidden_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1605
  )
1606
  return reordered_past
1607
 
1608
+ def densify(self):
1609
+ """In-place fusion of every HunYuanMoE block under this model."""
1610
+ densify(self) # just call the standalone helper
1611
+ return self # for chaining
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1612
 
1613
+ def sparsify(self):
1614
+ """Restore the original sparse experts (inverse of densify)."""
1615
+ sparsify(self)
1616
+ return self
1617
 
1618
  @add_start_docstrings(
1619
  """
 
1854
  sparse_out = self._sparse_path(x, probs)
1855
 
1856
  out = dense_out + (sparse_out - dense_out).detach() # STE
1857
+ return out.view(bsz, seq_len, self.hidden_size)
1858
+
1859
+ # -----------------------------------------------------------------------
1860
+ # Helper for module replacement
1861
+ # -----------------------------------------------------------------------
1862
+ def _replace_submodule(root: nn.Module, target: str, new_module: nn.Module):
1863
+ """Replace a (possibly nested) sub‑module.
1864
+
1865
+ ``target`` is the dotted path returned by ``model.named_modules()``.
1866
+ """
1867
+ parts = target.split('.')
1868
+ parent = root
1869
+ for p in parts[:-1]:
1870
+ parent = getattr(parent, p)
1871
+ setattr(parent, parts[-1], new_module)
1872
+
1873
+ # -----------------------------------------------------------------------
1874
+ # Public APIs
1875
+ # -----------------------------------------------------------------------
1876
+ def densify(model: nn.Module):
1877
+ """Convert all :class:`HunYuanMoE` modules under *model* to
1878
+ :class:`HunYuanDenseMoE`. Operates **in‑place**."""
1879
+ replacements = []
1880
+ for name, module in model.named_modules():
1881
+ if isinstance(module, HunYuanMoE):
1882
+ replacements.append((name, module))
1883
+ for name, sparse_moe in replacements:
1884
+ dense_moe = HunYuanDenseMoE(sparse_moe).to(next(sparse_moe.parameters()).device)
1885
+ _replace_submodule(model, name, dense_moe)
1886
+ return model
1887
+
1888
+
1889
+ def sparsify(model: nn.Module):
1890
+ """Rebuild standard sparse :class:`HunYuanMoE` modules from their
1891
+ fused :class:`HunYuanDenseMoE` form. Operates **in‑place**."""
1892
+ replacements = []
1893
+ for name, module in model.named_modules():
1894
+ if isinstance(module, HunYuanDenseMoE):
1895
+ replacements.append((name, module))
1896
+ for name, dense_moe in replacements:
1897
+ cfg = dense_moe.config
1898
+ sparse_moe = HunYuanMoE(cfg, layer_idx=dense_moe.layer_idx).to(next(dense_moe.parameters()).device)
1899
+
1900
+ # Copy router
1901
+ sparse_moe.gate.load_state_dict(dense_moe.gate.state_dict())
1902
+
1903
+ # Slice fused weights back to per‑expert
1904
+ for idx, expert in enumerate(sparse_moe.experts):
1905
+ start = idx * dense_moe.intermediate_size
1906
+ end = (idx + 1) * dense_moe.intermediate_size
1907
+
1908
+ expert.gate_proj.weight.data.copy_(
1909
+ dense_moe.fused_gate_proj.weight.data[start:end]
1910
+ )
1911
+ expert.up_proj.weight.data.copy_(
1912
+ dense_moe.fused_up_proj.weight.data[start:end]
1913
+ )
1914
+ expert.down_proj.weight.data.copy_(
1915
+ dense_moe.fused_down_proj.weight.data[:, start:end]
1916
+ )
1917
+
1918
+ _replace_submodule(model, name, sparse_moe)
1919
+ return model