update model code

- config.json +1 -1
- modeling_plamo.py +152 -59
config.json
CHANGED

@@ -30,7 +30,7 @@
     "mamba_num_heads": 32,
     "mamba_step": 2,
     "max_position_embeddings": 10485760,
-    "model_type": "
+    "model_type": "plamo2",
     "n_expert": null,
     "num_attention_heads": 16,
     "num_hidden_layers": 16,
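For reference, a checkpoint whose config declares `model_type: "plamo2"` and ships custom code such as modeling_plamo.py has to be loaded with remote code enabled; transformers has no built-in "plamo2" architecture. A minimal sketch, assuming a hypothetical repository id:

    # Minimal sketch; "org/plamo2-checkpoint" is a placeholder repository id.
    # trust_remote_code=True is required so that transformers executes the
    # bundled modeling_plamo.py instead of looking for a built-in "plamo2" class.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "org/plamo2-checkpoint"  # hypothetical
    tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)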
modeling_plamo.py
CHANGED

@@ -551,6 +551,68 @@ def _ssd_chunk_scan_combined_naive(
     return torch.cat(ys, dim=1), ssm_state
 
 
+def _ssd_chunk_scan_combined_cpu(
+    x: torch.Tensor,
+    dt: torch.Tensor,
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    chunk_size: int,
+    D: torch.Tensor,
+    z: torch.Tensor,
+    dt_bias: torch.Tensor,
+    dt_softplus: bool,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    # (bsize, nhead, nchunk, chunk_size)
+    dt = dt.float()  # We want high precision for this before cumsum
+    dt = dt.permute(0, 2, 1).unflatten(2, (-1, chunk_size))  # type: ignore
+    if dt_bias is not None:
+        dt = dt + dt_bias[None, :, None, None]
+    if dt_softplus:
+        dt = F.softplus(dt)
+    dA = dt * A[None, :, None, None]
+    dA_cumsum = torch.cumsum(dA, dim=-1)
+
+    _, _, nheads, _ = x.shape
+    dstate = B.shape[-1]
+    _ = dt.shape[2]
+
+    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_state"):
+        # Following is equivalent to `mamba_ssm.ops.triton.ssd_combined.chunk_state_ref(B, x, dt, dA_cumsum)`
+        # But `einsum` in the above function is too slow in CPU.
+        x_ = torch.unflatten(x, 1, (-1, chunk_size))
+        assert B.shape[2] == nheads  # B should be already expanded
+        B_ = torch.unflatten(B, 1, (-1, chunk_size)).to(x.dtype)  # (bsize, nchunk, chunk_size, nheads, dstate)
+        decay_states = torch.exp((dA_cumsum[:, :, :, -1:] - dA_cumsum)).to(x.dtype)
+        dt_ = dt.to(x.dtype)
+
+        # einsum("bclhn,bhcl,bhcl,bclhp->bchpn", B_, decay_states, dt_, x_)
+        B_ = B_.permute(0, 1, 3, 4, 2)  # bchnl
+        tmp = dt_ * decay_states  # bhcl
+        tmp = tmp.permute(0, 2, 1, 3)[:, :, :, None]  # bch1l
+        tmp = B_ * tmp  # bchnl
+        x_ = x_.permute(0, 1, 3, 2, 4)  # bchlp
+        tmp = tmp @ x_  # bchnp
+        states = tmp.permute(0, 1, 2, 4, 3)  # bchpn
+
+    states_dtype = states.dtype
+    if states.dtype not in [torch.float32, torch.float64]:
+        states = states.to(torch.float32)
+    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_state_passing"):
+        out, last_state = mamba_ssm.ops.triton.ssd_combined.state_passing_ref(
+            states.flatten(start_dim=-2, end_dim=-1),
+            dA_cumsum[:, :, :, -1],
+        )
+        states = torch.unflatten(out, -1, (-1, dstate))
+        last_state = torch.unflatten(last_state, -1, (-1, dstate))
+        states = states.to(states_dtype)
+    with torch.profiler.record_function("ssd_chunk_scan_combined_cpu_chunk_scan"):
+        out = mamba_ssm.ops.triton.ssd_combined.chunk_scan_ref(B, C, x, dt, dA_cumsum, states, D=D, z=z)
+
+    return out, last_state
+
+
+@torch.profiler.record_function("ssd_chunk_scan_combined")
 def ssd_chunk_scan_combined(
     x: torch.Tensor,
     dt: torch.Tensor,

@@ -587,19 +649,19 @@ def ssd_chunk_scan_combined(
     To avoid updating state, we set dt to -inf and x to 0
     because `softplus(-inf) = 0` and `exp(0) = 1`
     """
-        seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
-    length = x.shape[1]
-    assert length % chunk_size == 0, (length, chunk_size)
+    pad = (chunk_size - length % chunk_size) % chunk_size
+    x = torch.nn.functional.pad(x, pad=[0, 0, 0, 0, pad, 0], value=0.0)
+    dt = torch.nn.functional.pad(dt, pad=[0, 0, pad, 0], value=float("-inf"))
+    B = torch.nn.functional.pad(B, pad=[0, 0, 0, 0, pad, 0], value=0.0)
+    C = torch.nn.functional.pad(C, pad=[0, 0, 0, 0, pad, 0], value=0.0)
+    z = torch.nn.functional.pad(z, pad=[0, 0, 0, 0, pad, 0], value=0.0)
+    if seq_idx is not None:
+        seq_idx = torch.nn.functional.pad(seq_idx, pad=[pad, 0], value=0)
 
+    length = x.shape[1]
+    assert length % chunk_size == 0, (length, chunk_size)
+
+    if dt.is_cuda:
         dtype = _get_trition_dtype(x.dtype)
         out = mamba_ssm.ops.triton.ssd_combined.mamba_chunk_scan_combined(  # type: ignore
             x.to(dtype),

@@ -622,19 +684,75 @@ def ssd_chunk_scan_combined(
         assert isinstance(out, torch.Tensor)
         return out[:, pad:]
     else:
-        if ssm_state is None:
+        if ssm_state is None and seq_idx is None:
+            tmp = _ssd_chunk_scan_combined_cpu(
+                x,
+                dt,
+                A,
+                B,
+                C,
+                chunk_size,
+                D=D,
+                z=z,
+                dt_bias=dt_bias.float(),
+                dt_softplus=dt_softplus,
+            )
+        else:
+            if ssm_state is None:
+                bsize, _, num_heads, channel = x.shape
+                state = B.shape[-1]
+                ssm_state = torch.zeros(bsize, num_heads, channel, state, dtype=torch.float32, device=x.device)
+            tmp = _ssd_chunk_scan_combined_naive(
+                x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=dt_softplus, seq_idx=seq_idx, ssm_state=ssm_state
+            )
+        tmp = (tmp[0][:, pad:], tmp[1])
         if return_final_states:
             return tmp
         else:
             return tmp[0]
 
 
+def _causal_conv1d_update(
+    conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
+) -> tuple[torch.Tensor, torch.Tensor]:
+    dtype = conv_state.dtype
+    xBC = xBC.to(dtype)
+    weight = weight.to(dtype)
+    if conv_state.is_cuda:
+        x = causal_conv1d.causal_conv1d_update(
+            x=xBC,
+            conv_state=conv_state,
+            weight=weight[:, 0, :],
+            activation="silu",
+        )
+        return x, conv_state
+    else:
+        x = causal_conv1d.causal_conv1d_update_ref(
+            x=xBC,
+            conv_state=conv_state,
+            weight=weight[:, 0, :],
+            activation="silu",
+        )
+        return x, conv_state
+
+
+def _causal_conv1d_naive(
+    conv_state: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
+) -> tuple[torch.Tensor, torch.Tensor]:
+    length = x.shape[-1]
+    out = torch.zeros_like(x)
+    for i in range(length):
+        if i != 0 and seq_idx is not None:
+            conv_state = torch.where(
+                (seq_idx[:, i - 1] != seq_idx[:, i])[:, None, None],
+                torch.zeros_like(conv_state),
+                conv_state,
+            )
+        out[:, :, i : i + 1], conv_state = _causal_conv1d_update(conv_state, weight, x[:, :, i : i + 1])
+    return out, conv_state
+
+
+@torch.profiler.record_function("causal_conv1d")
 def _causal_conv1d(
     conv_state: torch.Tensor | None, weight: torch.Tensor, x: torch.Tensor, seq_idx: torch.Tensor | None
 ) -> tuple[torch.Tensor, torch.Tensor | None]:

@@ -670,52 +788,27 @@ def _causal_conv1d(
         else:
             x = tmp
     else:
-        if
-        )
-        x = out
+        if seq_idx is None:
+            x, conv_state = causal_conv1d.causal_conv1d_ref(
+                x=x,
+                initial_states=conv_state,
+                return_final_states=True,
+                weight=weight[:, 0, :],
+                activation="silu",
+            )
+        else:
+            if conv_state is None:
+                bsize = x.shape[0]
+                dim = weight.shape[0]
+                d_conv = weight.shape[-1]
+                conv_state = torch.zeros(bsize, dim, d_conv - 1, dtype=x.dtype, device=x.device)
+            x, conv_state = _causal_conv1d_naive(conv_state, weight, x, seq_idx)
     if return_final_states:
         return x, conv_state
     else:
         return x, None
 
 
-def _causal_conv1d_update(
-    conv_state: torch.Tensor, weight: torch.Tensor, xBC: torch.Tensor
-) -> tuple[torch.Tensor, torch.Tensor]:
-    dtype = conv_state.dtype
-    xBC = xBC.to(dtype)
-    weight = weight.to(dtype)
-    if conv_state.is_cuda:
-        x = causal_conv1d.causal_conv1d_update(
-            x=xBC,
-            conv_state=conv_state,
-            weight=weight[:, 0, :],
-            activation="silu",
-        )
-        return x, conv_state
-    else:
-        x = causal_conv1d.causal_conv1d_update_ref(
-            x=xBC,
-            conv_state=conv_state,
-            weight=weight[:, 0, :],
-            activation="silu",
-        )
-        return x, conv_state
-
-
 class Mamba(torch.nn.Module):
     def __init__(self, config: PlamoConfig, layer_idx: int) -> None:
         super().__init__()
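The padding added in ssd_chunk_scan_combined (prepend `pad` timesteps with dt = -inf and x = 0) relies on the identities quoted in the docstring: `softplus(-inf) = 0` and `exp(0) = 1`, so padded positions neither decay nor perturb the SSM state. A standalone sketch of that check, independent of the model code:

    # Why dt = -inf on padded timesteps is a no-op for the SSM state:
    # softplus(-inf) = 0, so the discretized step size is 0; the state decay
    # factor exp(0 * A) = exp(0) = 1, and the input contribution 0 * B * x = 0.
    import torch
    import torch.nn.functional as F

    dt = torch.tensor(float("-inf"))
    A = torch.tensor(-1.0)

    step = F.softplus(dt)        # -> 0.0
    decay = torch.exp(step * A)  # -> exp(0) = 1.0, state is carried through unchanged
    assert step.item() == 0.0 and decay.item() == 1.0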