LLaDA Distilled 24L Draft Model (v3)
A 24-layer structurally pruned + distilled version of LLaDA-8B-Instruct (32 layers), designed as a draft model for speculative decoding in masked diffusion language models.
Last updated: 2026-04-12 01:59 UTC · Step 13000, Epoch 3 · Best argmax agreement: 0.6502
Current Status
| Metric | Value |
|---|---|
| Training step | 13000 |
| Epoch | 3 |
| Best agreement (avg) | 0.6502 |
| Agreement @ mask 0.3 | 0.7503 |
| Agreement @ mask 0.5 | 0.6413 |
| Agreement @ mask 0.7 | 0.5415 |
| Confident-correct (avg) | 0.8440 |
Architecture
| | Teacher | Student (this model) |
|---|---|---|
| Layers | 32 | 24 (removed layers 12–19) |
| Attention heads | 32 | 32 |
| FFN intermediate dim | 12,288 | 9,216 (25% pruned per layer) |
| Hidden dim (d_model) | 4,096 | 4,096 |
| Vocab size | 126,464 | 126,464 |
| Total params | 8.0B | ~5.36B |
| Parameter reduction | – | ~33% |
Layers kept (original indices): [0–11, 20–31]
Layers removed: [12, 13, 14, 15, 16, 17, 18, 19]
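The kept/removed split above amounts to a simple index mapping from student layers to teacher layers. A minimal sketch (the variable names are illustrative, not from the released code):

```python
REMOVED = set(range(12, 20))  # the contiguous middle block dropped from the teacher
KEEP = [i for i in range(32) if i not in REMOVED]

# Map each student layer index to the teacher layer it was copied from:
# student layers 0-11 come from teacher 0-11, student 12-23 from teacher 20-31.
student_to_teacher = dict(enumerate(KEEP))
```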
How It Was Made
Step 1: Structured Pruning (Data-Driven)
12 ablation experiments identified the redundancies to remove:
- Layer removal: layers 12–19 (a contiguous middle block) retain ~89% leave-one-out argmax agreement, making them the most expendable
- FFN pruning: the bottom 25% of neurons per layer (ranked by gated activation magnitude) are removed; the pruned neurons differ per layer
- Pre-distillation agreement: ~60% argmax agreement with the full model
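The FFN pruning criterion above can be sketched as: score each intermediate neuron by its mean absolute gated activation over a calibration batch, then drop the lowest-scoring 25%. This is a minimal sketch under those assumptions; the function name and calibration details are illustrative.

```python
import numpy as np

def ffn_neurons_to_keep(gate_acts: np.ndarray, prune_frac: float = 0.25) -> np.ndarray:
    """Return sorted indices of the FFN neurons to keep.

    gate_acts: (num_tokens, intermediate_dim) gated activations collected
    on calibration data. Here each layer goes 12,288 -> 9,216 neurons.
    """
    scores = np.abs(gate_acts).mean(axis=0)                # one score per neuron
    n_keep = int(gate_acts.shape[1] * (1.0 - prune_frac))  # e.g. 12288 * 0.75 = 9216
    return np.sort(np.argsort(scores)[-n_keep:])           # keep the top scorers
```

Because scores are computed per layer, each layer ends up with a different set of surviving neurons, as noted above.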
Step 2: Knowledge Distillation (Spec-Dec Optimized)
The loss function is designed to maximize the speculative decoding acceptance rate:
- 0.5× Top-1 CE: `CE(student_logits, teacher_argmax)` directly optimizes the CtV/VtC acceptance criterion
- 0.3× Top-K KL: `KL(student_topK ∥ teacher_topK)` over the teacher's top-32 tokens only; sharpens the decision boundary while ignoring the ~126K irrelevant vocabulary tokens
- 0.2× Confidence MSE: `MSE(max(softmax(student)), max(softmax(teacher)))` calibrates confidence for threshold-based drafting
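The three weighted terms can be combined into a single objective. A minimal PyTorch sketch under the stated weights; the function name and the `reduction` choices are assumptions, not the released training code:

```python
import torch
import torch.nn.functional as F

def spec_dec_distill_loss(student_logits, teacher_logits, k=32):
    """Combined 0.5 / 0.3 / 0.2 distillation loss; logits are (batch, vocab)."""
    # Top-1 CE against the teacher's argmax (the spec-dec acceptance target)
    top1_ce = F.cross_entropy(student_logits, teacher_logits.argmax(dim=-1))

    # Top-K KL over the teacher's top-k tokens only
    topk_vals, topk_idx = teacher_logits.topk(k, dim=-1)
    student_topk = student_logits.gather(-1, topk_idx)
    topk_kl = F.kl_div(
        F.log_softmax(student_topk, dim=-1),
        F.softmax(topk_vals, dim=-1),
        reduction="batchmean",
    )

    # Confidence MSE on the max softmax probability
    conf_mse = F.mse_loss(
        student_logits.softmax(-1).amax(-1),
        teacher_logits.softmax(-1).amax(-1),
    )
    return 0.5 * top1_ce + 0.3 * topk_kl + 0.2 * conf_mse
```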
Training config:
- Data: StarCoderData Python (40%) + OpenMathInstruct-2 (40%) + C4 (20%)
- Mask ratio: sampled from Beta(2,2), clipped to [0.15, 0.85]
- Optimizer: AdamW (lr=5e-05, cosine schedule, 5% warmup)
- Effective batch size: 16
- Sequence length: 512
- Precision: bf16
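The mask-ratio schedule above can be sketched with the standard library. The function names and the clipping interpretation of the Beta(2,2) entry are assumptions based on the config; `MASK_ID` is the tokenizer's mask id noted in the Limitations section.

```python
import random

MASK_ID = 126336  # LLaDA mask token id (same tokenizer as the teacher)

def sample_mask_ratio(rng=random) -> float:
    """Draw one per-sequence mask ratio from Beta(2, 2), clipped to [0.15, 0.85]."""
    return min(0.85, max(0.15, rng.betavariate(2.0, 2.0)))

def mask_tokens(token_ids, rng=random):
    """Replace a sampled fraction of positions with the mask token."""
    ratio = sample_mask_ratio(rng)
    return [MASK_ID if rng.random() < ratio else t for t in token_ids]
```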
Training History
| Step | Epoch | Avg Agree | Agree@0.3 | Agree@0.5 | Agree@0.7 | Loss |
|---|---|---|---|---|---|---|
| 4500 | 1 | 0.6403 | 0.7364 | 0.6563 | 0.5282 | 0.5767 |
| 5000 | 1 | 0.6482 | 0.7535 | 0.6647 | 0.5264 | 0.5651 |
| 5500 | 1 | 0.6273 | 0.7359 | 0.6353 | 0.5106 | 0.5549 |
| 6000 | 1 | 0.6287 | 0.7413 | 0.6341 | 0.5108 | 0.5476 |
| 6250 | 1 | 0.6276 | 0.7390 | 0.6339 | 0.5098 | 0.5444 |
| 6500 | 2 | 0.6329 | 0.7391 | 0.6421 | 0.5176 | 0.4692 |
| 7000 | 2 | 0.6368 | 0.7565 | 0.6415 | 0.5125 | 0.4699 |
| 7500 | 2 | 0.6356 | 0.7483 | 0.6413 | 0.5173 | 0.4700 |
| 8000 | 2 | 0.6502 | 0.7658 | 0.6463 | 0.5385 | 0.4707 |
| 8500 | 2 | 0.6372 | 0.7573 | 0.6347 | 0.5195 | 0.4728 |
| 9000 | 2 | 0.6168 | 0.7482 | 0.6381 | 0.4642 | 0.4731 |
| 9500 | 2 | 0.6269 | 0.7472 | 0.6294 | 0.5043 | 0.4738 |
| 10000 | 2 | 0.6323 | 0.7430 | 0.6342 | 0.5198 | 0.4734 |
| 10500 | 2 | 0.6416 | 0.7512 | 0.6462 | 0.5273 | 0.4735 |
| 11000 | 2 | 0.6269 | 0.7343 | 0.6366 | 0.5097 | 0.4741 |
| 11500 | 2 | 0.6359 | 0.7475 | 0.6421 | 0.5180 | 0.4728 |
| 12000 | 2 | 0.6372 | 0.7309 | 0.6525 | 0.5282 | 0.4718 |
| 12500 | 2 | 0.6498 | 0.7609 | 0.6592 | 0.5293 | 0.4715 |
| 13000 | 3 | 0.6444 | 0.7503 | 0.6413 | 0.5415 | 0.4549 |
Intended Use
This is a draft model for speculative decoding with LLaDA-8B-Instruct as the target. NOT for standalone generation.
```python
from transformers import AutoModel, AutoTokenizer
import torch

draft_model = AutoModel.from_pretrained(
    "jaygala223/llada-distilled-24L-v3",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

target_model = AutoModel.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()
```
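A minimal sketch of how the pair could be used in one verification step, assuming both models expose logits at the masked positions. The acceptance rule here (argmax agreement plus a draft-confidence threshold) mirrors the metrics reported above, but the function name, threshold, and surrounding drafting loop are illustrative assumptions:

```python
import torch

@torch.no_grad()
def accept_drafted_tokens(draft_logits, target_logits, conf_threshold=0.9):
    """Logits are (num_masked_positions, vocab_size).

    A drafted token is accepted where the target's argmax agrees with the
    draft's AND the draft's max softmax probability clears the threshold.
    Returns (drafted tokens, boolean accept mask).
    """
    draft_probs = draft_logits.softmax(dim=-1)
    draft_tokens = draft_probs.argmax(dim=-1)
    agree = draft_tokens == target_logits.argmax(dim=-1)
    confident = draft_probs.amax(dim=-1) >= conf_threshold
    return draft_tokens, agree & confident
```

Rejected positions would simply stay masked and be re-predicted by the target in the next diffusion step.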
Limitations
- Not standalone: designed for speculative decoding only
- Uses the same tokenizer as LLaDA-8B-Instruct (mask_id=126336)
- Training is ongoing; later checkpoints may improve