Qwen2.5-32B Circuit-Level Transcoders (CLT)

High-quality Circuit-Level Transcoders for Qwen2.5-32B, trained with TopK sparsity for interpretability and circuit discovery.

🎯 Key Features

  • ✅ 63 layers (L0 → L62)
  • ✅ Fixed 12% L0 sparsity: Consistent activation patterns across all layers
  • ✅ TopK activation: Deterministic feature selection for reproducibility
  • ✅ Large feature space: 12,288 features per layer (2.4x expansion)
  • ✅ Excellent reconstruction: Low validation loss across all layers

📊 Model Scale Comparison

| Model | Layers | Hidden Dim | Feature Dim | Expansion | Total Size |
|---|---|---|---|---|---|
| Qwen2.5-VL-7B CLT | 27 | 3,584 | 8,192 | 2.29x | ~3.2 GB |
| Qwen2.5-32B CLT | 63 | 5,120 | 12,288 | 2.40x | ~15 GB |

📊 Training Quality

| Layer Range | L0 Sparsity | Dead Features | Status |
|---|---|---|---|
| L0-L15 | 12.0% | 0% | ✅ Excellent |
| L16-L31 | 12.0% | 0% | ✅ Excellent |
| L32-L47 | 12.0% | 0% | ✅ Excellent |
| L48-L62 | 12.0% | 0% | ✅ Excellent |

All layers maintain consistent 12% L0 sparsity (approximately 1,475 active features per token) with strong reconstruction quality and no dead features.

🚀 Quick Start

Installation

pip install torch huggingface-hub

The usage examples further below additionally require transformers, accelerate (for device_map="auto"), and tqdm.

Loading Transcoders

import torch
from huggingface_hub import hf_hub_download

# Download a specific layer
layer_idx = 30
transcoder_path = hf_hub_download(
    repo_id="KokosDev/qwen25-32b-clt",
    filename=f"transcoder_L{layer_idx}.pt"
)

# Load the transcoder
transcoder = torch.load(transcoder_path, map_location="cpu")

# Access the state dict
state_dict = transcoder['state_dict']
print(f"Hidden dim: {transcoder['hidden_dim']}")
print(f"Feature dim: {transcoder['feature_dim']}")
print(f"Layer: {transcoder['layer']}")

Using for Circuit Discovery

import torch
import torch.nn.functional as F

# Load transcoder
transcoder = torch.load("transcoder_L30.pt")
state_dict = transcoder['state_dict']

# Extract encoder and decoder weights
W_enc = state_dict['_orig_mod.enc.1.weight']  # [12288, 5120]
b_enc = state_dict['_orig_mod.enc.1.bias']    # [12288]
W_dec = state_dict['_orig_mod.dec.weight']    # [5120, 12288]
b_dec = state_dict['_orig_mod.dec.bias']      # [5120]
pre_enc_bias = state_dict['_orig_mod.enc.0.bias']  # [5120]

# Encode activations to sparse features
activations = torch.randn(1, 128, 5120)  # [batch, seq, hidden_dim]

# Apply pre-encoder bias (centering)
activations_centered = activations + pre_enc_bias

# Encode: features = ReLU(W_enc @ hidden + b_enc)
features = F.linear(activations_centered, W_enc, b_enc)
features = F.relu(features)  # [batch, seq, 12288]

# TopK sparsification (12% = ~1,475 features)
k = int(0.12 * features.shape[-1])
topk_values, topk_indices = torch.topk(features, k, dim=-1)
sparse_features = torch.zeros_like(features)
sparse_features.scatter_(-1, topk_indices, topk_values)

# Reconstruct
reconstructed = F.linear(sparse_features, W_dec, b_dec)  # [batch, seq, 5120]

# Calculate reconstruction error
rec_loss = F.mse_loss(reconstructed, activations)
print(f"Reconstruction loss: {rec_loss.item():.4f}")

πŸ“ Model Architecture

Input (5120) → Pre-bias → Encoder → ReLU → TopK(12%) → Features (12288) → Decoder → Output (5120)
  • Hidden dim: 5,120 (Qwen2.5-32B residual stream)
  • Feature dim: 12,288 (sparse features, 2.4x expansion)
  • Activation: ReLU + TopK
  • Sparsity: Fixed 12% L0 (~1,475 active features per token)
  • Architecture: Linear encoder/decoder with pre-encoder bias centering
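
The pipeline above can be packaged into a single helper that works directly on a loaded state dict. This is a minimal sketch that mirrors the circuit discovery example; the function name transcoder_forward is ours and, like that example, it uses only the bias term of enc.0:

import torch
import torch.nn.functional as F

def transcoder_forward(acts, state_dict, k_frac=0.12):
    """Pre-bias -> encoder -> ReLU -> TopK -> decoder, as diagrammed above.

    acts: [batch, seq, hidden_dim] residual-stream activations.
    Returns (sparse_features, reconstruction).
    """
    W_enc = state_dict['_orig_mod.enc.1.weight']
    b_enc = state_dict['_orig_mod.enc.1.bias']
    W_dec = state_dict['_orig_mod.dec.weight']
    b_dec = state_dict['_orig_mod.dec.bias']
    pre_bias = state_dict['_orig_mod.enc.0.bias']

    # Pre-encoder bias, linear encoder, ReLU
    features = F.relu(F.linear(acts + pre_bias, W_enc, b_enc))

    # Keep only the top k_frac of features per token
    k = int(k_frac * features.shape[-1])
    values, indices = torch.topk(features, k, dim=-1)
    sparse = torch.zeros_like(features).scatter_(-1, indices, values)

    # Linear decoder back to the residual stream
    return sparse, F.linear(sparse, W_dec, b_dec)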

🔬 Training Details

Dataset

  • Source: Large-scale text corpus
  • Preprocessing: Cached activations from Qwen2.5-32B
  • Validation: Held-out samples for quality monitoring
  • Scale: 5,000 training steps per layer

Hyperparameters

  • Steps: 5,000 per layer
  • Learning rate: Adaptive with cosine schedule
  • Optimizer: AdamW
  • Sparsity: TopK with k = 12% of features (~1,475 features)
  • Validation interval: Regular monitoring
  • Batch size: Optimized for GPU memory

Training Infrastructure

  • GPU: NVIDIA A100-SXM4-40GB
  • Framework: PyTorch 2.0+ with mixed precision
  • Total layers: 63 (L0-L62)
  • Training time: ~5 days for all layers

🎯 CLT vs Traditional SAEs

Circuit-Level Transcoders (CLTs) offer several advantages:

  1. Deterministic sparsity: TopK keeps exactly 12% of features active for every token
  2. Reproducible: Same input always activates same features
  3. Interpretable: Fixed sparsity makes feature analysis consistent
  4. Efficient: sparsity is enforced by a cheap TopK operation at inference, with no L1 penalty to tune during training
  5. No dead features: All 12,288 features per layer are active across the dataset
  6. Scalable: Successfully applied to large 32B model

📖 Use Cases

  • Circuit discovery: Identify which features activate for specific inputs
  • Mechanistic interpretability: Understand large language model internals
  • Feature analysis: Study what concepts are encoded at each layer depth
  • Ablation studies: Remove specific features to test causal relationships (see the sketch after this list)
  • Activation steering: Modify feature activations to control model behavior
  • Prompt engineering: Understand which features drive specific responses
  • Safety research: Identify and analyze potentially harmful feature activations
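
To make the ablation and steering use cases concrete, here is a minimal sketch that rescales a single feature in the sparse code and measures the resulting change in the decoded residual-stream contribution. The function name edit_feature and the feature index in the usage comments are illustrative, not part of the release:

import torch
import torch.nn.functional as F

def edit_feature(sparse_features, W_dec, b_dec, feature_idx, scale=0.0):
    """Ablate (scale=0) or amplify (scale>1) one feature and return the change
    this makes to the decoded residual-stream contribution."""
    baseline = F.linear(sparse_features, W_dec, b_dec)

    edited = sparse_features.clone()
    edited[..., feature_idx] *= scale
    edited_recon = F.linear(edited, W_dec, b_dec)

    # Difference that would be injected into (or removed from) the residual stream
    return edited_recon - baseline

# Usage, reusing sparse_features, W_dec, b_dec from the circuit discovery example:
# delta = edit_feature(sparse_features, W_dec, b_dec, feature_idx=1234, scale=0.0)  # ablate
# delta = edit_feature(sparse_features, W_dec, b_dec, feature_idx=1234, scale=5.0)  # steer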

🔗 Related Resources

📊 File Structure

qwen25-32b-clt/
├── README.md
├── .gitattributes
├── transcoder_L0.pt    (241 MB)
├── transcoder_L1.pt    (241 MB)
├── transcoder_L2.pt    (241 MB)
├── ...
├── transcoder_L61.pt   (241 MB)
└── transcoder_L62.pt   (241 MB)

Each .pt file contains:

{
    'layer': int,                    # Layer index (0-62)
    'hidden_dim': 5120,             # Qwen2.5-32B hidden dimension
    'feature_dim': 12288,           # Sparse feature dimension
    'state_dict': OrderedDict({
        '_orig_mod.enc.0.weight': Tensor[5120],      # Pre-encoder normalization
        '_orig_mod.enc.0.bias': Tensor[5120],        # Pre-encoder bias
        '_orig_mod.enc.1.weight': Tensor[12288, 5120],  # Encoder weights
        '_orig_mod.enc.1.bias': Tensor[12288],       # Encoder bias
        '_orig_mod.dec.weight': Tensor[5120, 12288], # Decoder weights
        '_orig_mod.dec.bias': Tensor[5120]           # Decoder bias
    }),
    'training_metadata': {
        'steps': 5000,
        'final_l0_pct': float,      # Final L0 sparsity percentage
        'dead_features': 0,          # Number of dead features
        'dead_pct': 0.0,            # Percentage of dead features
        'final_rec_loss': float      # Final reconstruction loss
    }
}
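
The per-layer training metadata can be read straight off a loaded checkpoint, for example (reusing the transcoder dict from the Quick Start):

meta = transcoder['training_metadata']
print(f"Layer {transcoder['layer']}: "
      f"L0 = {meta['final_l0_pct']:.1f}%, "
      f"dead features = {meta['dead_features']} ({meta['dead_pct']:.1f}%), "
      f"reconstruction loss = {meta['final_rec_loss']:.4f}")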

🔒 Technical Specifications

Memory Requirements

  • Inference: ~32 GB GPU memory for full 32B model + transcoders
  • Per-layer transcoder: ~241 MB (see the back-of-envelope after this list)
  • Recommended: A100 40GB or H100 80GB for comfortable inference
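
The ~241 MB per-layer figure is consistent with the parameter count below. The back-of-envelope assumes 16-bit storage of the weights, which is an inference from the file size rather than a documented fact:

hidden_dim, feature_dim = 5120, 12288

# Encoder + decoder weight matrices dominate; biases and the pre-encoder
# parameters add only ~28K parameters.
params = 2 * hidden_dim * feature_dim + feature_dim + 3 * hidden_dim
print(f"Parameters per layer: {params / 1e6:.1f}M")            # ~125.9M
print(f"Size at 2 bytes/param: {params * 2 / 2**20:.0f} MiB")  # ~240 MiB, matching ~241 MB per file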

Computational Costs

  • Encoding: O(hidden_dim × feature_dim) = O(5120 × 12288)
  • TopK selection: O(feature_dim × log(k)) = O(12288 × log(1475))
  • Decoding: O(feature_dim × hidden_dim) = O(12288 × 5120)
  • Total per token: ~250M FLOPs per layer (two 5,120 × 12,288 matmuls at 2 FLOPs per multiply-add; the TopK term is negligible)

Performance Benchmarks

| Operation | Time (A100) | Memory |
|---|---|---|
| Load single transcoder | ~0.5s | 241 MB |
| Encode batch (32 tokens) | ~2ms | ~1 GB |
| TopK selection | ~0.5ms | negligible |
| Decode batch | ~2ms | ~1 GB |

📄 License

Apache 2.0 - Same as Qwen2.5-32B base model

πŸ™ Acknowledgments

  • Qwen team for the excellent Qwen2.5-32B language model
  • Anthropic for pioneering sparse autoencoder research
  • The mechanistic interpretability community for insights and tools
  • OpenAI for foundational work in transformer interpretability

📊 Citation

If you use these transcoders in your research, please cite:

@misc{qwen25-32b-clt,
  title={Qwen2.5-32B Circuit-Level Transcoders},
  author={KokosDev},
  year={2024},
  howpublished={\url{https://huggingface.co/KokosDev/qwen25-32b-clt}},
}

📧 Contact

For questions, issues, or collaboration opportunities:

  • Open an issue in this repository

Model Version: v1.0
Last Updated: October 2024
Total Size: ~15 GB (63 layers × 241 MB)
Training Date: October 2024
Base Model: Qwen2.5-32B


🚀 Quick Start Examples

Example 1: Feature Activation Analysis

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-32B", torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-32B")

# Load transcoder for layer 30
transcoder = torch.load("transcoder_L30.pt")
state_dict = transcoder['state_dict']

# Encode text
text = "The capital of France is"
inputs = tokenizer(text, return_tensors="pt").to(model.device)

# Get activations at layer 30
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)
    layer_30_acts = outputs.hidden_states[30]  # [1, seq_len, 5120]

# Encode to features (move the transcoder weights to the activations' device and
# work in float32 so they share a dtype with the fp16 model activations)
device = layer_30_acts.device
W_enc = state_dict['_orig_mod.enc.1.weight'].to(device, torch.float32)
b_enc = state_dict['_orig_mod.enc.1.bias'].to(device, torch.float32)
pre_bias = state_dict['_orig_mod.enc.0.bias'].to(device, torch.float32)

centered = layer_30_acts.float() + pre_bias
features = torch.relu(torch.nn.functional.linear(centered, W_enc, b_enc))

# Find top active features
top_features = torch.topk(features[0, -1, :], k=10)
print(f"Top 10 active features: {top_features.indices.tolist()}")
print(f"Activation values: {top_features.values.tolist()}")

Example 2: Batch Processing

import torch
from tqdm import tqdm

def process_dataset_with_transcoders(texts, transcoder_layers=(0, 15, 30, 45, 62)):
    """Count how often each feature activates across a dataset.

    Assumes `model`, `tokenizer`, and `device` are set up as in Example 1, and that
    encode_with_transcoder(acts, layer_idx) applies the matching layer's transcoder
    (a sketch of such a helper follows this example).
    """
    feature_stats = {layer: {} for layer in transcoder_layers}
    
    for text in tqdm(texts):
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            
            for layer_idx in transcoder_layers:
                acts = outputs.hidden_states[layer_idx]
                
                # Encode with transcoder
                features = encode_with_transcoder(acts, layer_idx)
                
                # Collect statistics
                active_features = (features > 0).any(dim=1)[0]  # [feature_dim]
                for feat_idx in active_features.nonzero():
                    feat_idx = feat_idx.item()
                    if feat_idx not in feature_stats[layer_idx]:
                        feature_stats[layer_idx][feat_idx] = 0
                    feature_stats[layer_idx][feat_idx] += 1
    
    return feature_stats
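
Example 2 leaves encode_with_transcoder undefined. The sketch below is one possible implementation, assuming the relevant state dicts have already been loaded into a TRANSCODERS dict keyed by layer index; both the helper name and that dict are ours, not part of the release:

import torch
import torch.nn.functional as F

# Hypothetical cache of loaded state dicts, e.g.
# TRANSCODERS = {i: torch.load(f"transcoder_L{i}.pt", map_location="cpu")['state_dict']
#                for i in (0, 15, 30, 45, 62)}
TRANSCODERS = {}

def encode_with_transcoder(acts, layer_idx, k_frac=0.12):
    """Map residual-stream activations to sparse TopK features for one layer."""
    sd = TRANSCODERS[layer_idx]
    device, dtype = acts.device, torch.float32

    W_enc = sd['_orig_mod.enc.1.weight'].to(device, dtype)
    b_enc = sd['_orig_mod.enc.1.bias'].to(device, dtype)
    pre_bias = sd['_orig_mod.enc.0.bias'].to(device, dtype)

    # Pre-encoder bias, linear encoder, ReLU (as in the circuit discovery example)
    features = F.relu(F.linear(acts.to(dtype) + pre_bias, W_enc, b_enc))

    # Keep only the top ~12% of features per token
    k = int(k_frac * features.shape[-1])
    values, indices = torch.topk(features, k, dim=-1)
    return torch.zeros_like(features).scatter_(-1, indices, values)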

Happy Circuit Discovering! 🔍🧠
