LLaDA Distilled 24L Draft Model (v3)

A 24-layer structurally pruned + distilled version of LLaDA-8B-Instruct (32 layers), designed as a draft model for speculative decoding in masked diffusion language models.

Last updated: 2026-04-12 01:59 UTC (step 13000, epoch 3). Best argmax agreement: 0.6502

Current Status

| Metric | Value |
|---|---|
| Training step | 13000 |
| Epoch | 3 |
| Best agreement (avg) | 0.6502 |
| Agreement @ mask 0.3 | 0.7503 |
| Agreement @ mask 0.5 | 0.6413 |
| Agreement @ mask 0.7 | 0.5415 |
| Confident-correct (avg) | 0.8440 |

Architecture

| | Teacher | Student (this model) |
|---|---|---|
| Layers | 32 | 24 (removed layers 12–19) |
| Attention heads | 32 | 32 |
| FFN intermediate dim | 12,288 | 9,216 (25% pruned per layer) |
| Hidden dim (d_model) | 4,096 | 4,096 |
| Vocab size | 126,464 | 126,464 |
| Total params | 8.0B | ~5.36B |
| Parameter reduction | – | ~33% |

Layers kept (original indices): 0–11, 20–31. Layers removed: 12–19.
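The kept/removed split can be written as a small index mapping, which is useful when copying teacher weights into the pruned student. This sketch only restates the indices above; the variable names are illustrative.

```python
# Original 32 layers; the contiguous middle block 12-19 was removed.
REMOVED = set(range(12, 20))
kept = [i for i in range(32) if i not in REMOVED]

assert len(kept) == 24
assert kept == list(range(0, 12)) + list(range(20, 32))

# Map student layer index -> original teacher layer index,
# e.g. for initializing the student from teacher weights.
student_to_teacher = {new: old for new, old in enumerate(kept)}
print(student_to_teacher[12])  # first layer after the gap -> 20
```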

How It Was Made

Step 1: Structured Pruning (Data-Driven)

12 ablation experiments identified redundancies:

  • Layer removal: layers 12–19 (a contiguous middle block) showed ~89% leave-one-out agreement, making them the most expendable
  • FFN pruning: the bottom 25% of neurons per layer (ranked by gated activation magnitude) were removed; the pruned neurons differ per layer
  • Pre-distillation agreement: ~60% argmax agreement with the full model
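The FFN pruning step can be sketched as follows, assuming LLaMA-style SwiGLU weight shapes and a precomputed per-neuron importance score (e.g. mean gated-activation magnitude over a calibration set). The function and argument names are illustrative, not the actual pruning code.

```python
import torch

def prune_ffn_neurons(gate_proj, up_proj, down_proj, importance, keep_frac=0.75):
    """Keep the top `keep_frac` FFN neurons by importance score.

    Shapes follow the LLaMA-style SwiGLU convention:
      gate_proj, up_proj: [d_ff, d_model]; down_proj: [d_model, d_ff].
    `importance` is one score per FFN neuron (length d_ff), e.g. the mean
    |silu(gate(x)) * up(x)| over calibration data (an assumption; the card
    only says "gated activation magnitude").
    """
    d_ff = importance.numel()
    k = int(d_ff * keep_frac)
    # Sort kept indices so the pruned weights stay in original neuron order.
    keep = torch.topk(importance, k).indices.sort().values
    return gate_proj[keep, :], up_proj[keep, :], down_proj[:, keep], keep
```

For this model, d_ff goes from 12,288 to 9,216 per layer (keep_frac = 0.75).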

Step 2: Knowledge Distillation (Spec-Dec Optimized)

Loss function, designed to maximize the speculative decoding acceptance rate:

  • 0.5× Top-1 CE: CE(student_logits, teacher_argmax), which directly optimizes the CtV/VtC acceptance criterion
  • 0.3× Top-K KL: KL(student_topK ∥ teacher_topK) over the teacher's top-32 tokens only, sharpening the decision boundary while ignoring the ~126K irrelevant vocabulary tokens
  • 0.2× Confidence MSE: MSE(max(softmax(student)), max(softmax(teacher))), which calibrates confidence for threshold-based drafting
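The three terms above can be sketched as a single loss function. Tensor shapes, function names, and the exact top-K KL formulation (renormalizing both distributions over the teacher's top-32 slice) are assumptions, not the released training code.

```python
import torch
import torch.nn.functional as F

def spec_dec_distill_loss(student_logits, teacher_logits, k=32,
                          w_ce=0.5, w_kl=0.3, w_conf=0.2):
    """Weighted distillation loss over flattened positions, shape [N, vocab]."""
    # 0.5x top-1 CE against the teacher's argmax token
    ce = F.cross_entropy(student_logits, teacher_logits.argmax(dim=-1))

    # 0.3x KL restricted to the teacher's top-k vocabulary slice
    top_val, top_idx = teacher_logits.topk(k, dim=-1)
    student_top = student_logits.gather(-1, top_idx)
    kl = F.kl_div(F.log_softmax(student_top, dim=-1),
                  F.log_softmax(top_val, dim=-1),
                  log_target=True, reduction="batchmean")

    # 0.2x MSE between max softmax probabilities (confidence calibration)
    conf = F.mse_loss(student_logits.softmax(-1).amax(-1),
                      teacher_logits.softmax(-1).amax(-1))

    return w_ce * ce + w_kl * kl + w_conf * conf
```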

Training config:

  • Data: StarCoderData Python (40%) + OpenMathInstruct-2 (40%) + C4 (20%)
  • Mask ratio: Beta(2,2) ∈ [0.15, 0.85]
  • Optimizer: AdamW (lr=5e-05, cosine schedule, 5% warmup)
  • Effective batch size: 16
  • Sequence length: 512
  • Precision: bf16

Training History

| Step | Epoch | Avg Agree | Agree@0.3 | Agree@0.5 | Agree@0.7 | Loss |
|---|---|---|---|---|---|---|
| 4500 | 1 | 0.6403 | 0.7364 | 0.6563 | 0.5282 | 0.5767 |
| 5000 | 1 | 0.6482 | 0.7535 | 0.6647 | 0.5264 | 0.5651 |
| 5500 | 1 | 0.6273 | 0.7359 | 0.6353 | 0.5106 | 0.5549 |
| 6000 | 1 | 0.6287 | 0.7413 | 0.6341 | 0.5108 | 0.5476 |
| 6250 | 1 | 0.6276 | 0.7390 | 0.6339 | 0.5098 | 0.5444 |
| 6500 | 2 | 0.6329 | 0.7391 | 0.6421 | 0.5176 | 0.4692 |
| 7000 | 2 | 0.6368 | 0.7565 | 0.6415 | 0.5125 | 0.4699 |
| 7500 | 2 | 0.6356 | 0.7483 | 0.6413 | 0.5173 | 0.4700 |
| 8000 | 2 | 0.6502 | 0.7658 | 0.6463 | 0.5385 | 0.4707 |
| 8500 | 2 | 0.6372 | 0.7573 | 0.6347 | 0.5195 | 0.4728 |
| 9000 | 2 | 0.6168 | 0.7482 | 0.6381 | 0.4642 | 0.4731 |
| 9500 | 2 | 0.6269 | 0.7472 | 0.6294 | 0.5043 | 0.4738 |
| 10000 | 2 | 0.6323 | 0.7430 | 0.6342 | 0.5198 | 0.4734 |
| 10500 | 2 | 0.6416 | 0.7512 | 0.6462 | 0.5273 | 0.4735 |
| 11000 | 2 | 0.6269 | 0.7343 | 0.6366 | 0.5097 | 0.4741 |
| 11500 | 2 | 0.6359 | 0.7475 | 0.6421 | 0.5180 | 0.4728 |
| 12000 | 2 | 0.6372 | 0.7309 | 0.6525 | 0.5282 | 0.4718 |
| 12500 | 2 | 0.6498 | 0.7609 | 0.6592 | 0.5293 | 0.4715 |
| 13000 | 3 | 0.6444 | 0.7503 | 0.6413 | 0.5415 | 0.4549 |

Intended Use

This is a draft model for speculative decoding with LLaDA-8B-Instruct as the target. It is NOT intended for standalone generation.

```python
from transformers import AutoModel, AutoTokenizer
import torch

# Draft model (this repo): 24-layer pruned + distilled student
draft_model = AutoModel.from_pretrained(
    "jaygala223/llada-distilled-24L-v3",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

# Target model: the full 32-layer LLaDA-8B-Instruct
target_model = AutoModel.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to("cuda").eval()

# Both models share the LLaDA-8B-Instruct tokenizer (mask_id = 126336)
tokenizer = AutoTokenizer.from_pretrained(
    "GSAI-ML/LLaDA-8B-Instruct", trust_remote_code=True
)
```

Limitations

  • Not standalone: designed for speculative decoding only
  • Same tokenizer as LLaDA-8B-Instruct (mask_id=126336)
  • Training ongoing: later checkpoints may improve