Upload 8 files
- inference.py +241 -0
- moe_config.py +119 -0
- moe_layers.py +323 -0
- moe_model.py +459 -0
- moe_trainer.py +168 -0
- requirements.txt +96 -0
- sample_generation_callback.py +148 -0
- train_moe_v8_clean.py +429 -0
inference.py
ADDED
@@ -0,0 +1,241 @@
"""
Inference script for the trained MoE model.
Automatically loads the newest checkpoint and tests different sampling strategies.
"""

import os
import sys
import torch
from transformers import AutoTokenizer
from moe_config import MoEGPTConfig
from moe_model import MoEGPTForCausalLM

# Force UTF-8 encoding for Windows console
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')


def find_latest_checkpoint(checkpoint_dir="./moe_checkpoints_v8_clean"):
    """
    Finds the newest checkpoint automatically (v8 OPUS Edition!)

    Returns:
        str: Path to the newest checkpoint, or None
    """
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoints = [
        os.path.join(checkpoint_dir, d)
        for d in os.listdir(checkpoint_dir)
        if d.startswith("checkpoint-")
    ]

    if not checkpoints:
        return None

    # Find the newest checkpoint (by creation time)
    latest = max(checkpoints, key=os.path.getctime)

    # Extract the step number
    step = latest.split("checkpoint-")[-1]
    print(f"\n🔍 Newest checkpoint found: step {step}")

    return latest


def load_model(model_path=None, device="cuda"):
    """
    Loads the trained MoE model.
    If model_path is None, the newest checkpoint is loaded automatically.

    Args:
        model_path: Path to the saved model (None = auto-find)
        device: Device for inference (cuda/cpu)

    Returns:
        model: Loaded model
        config: Model config
    """
    # Auto-find the newest checkpoint
    if model_path is None:
        model_path = find_latest_checkpoint()
        if model_path is None:
            # Fallback: try the final model (v8)
            model_path = "./moe_final_v8_clean"
            if not os.path.exists(model_path):
                raise ValueError("No checkpoint found! Train a model first.")

    print(f"\n📥 Loading model from: {model_path}")

    config = MoEGPTConfig.from_pretrained(model_path)
    model = MoEGPTForCausalLM.from_pretrained(model_path)

    # Move to device
    if device == "cuda" and torch.cuda.is_available():
        model = model.cuda()
        print("✅ Model loaded on GPU")
    else:
        model = model.cpu()
        print("✅ Model loaded on CPU")

    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    print(f"   📊 Parameters: {total_params:,} ({total_params/1e6:.1f}M)")
    print(f"   🧠 Experts: {config.total_experts}")
    print(f"   ⚡ Active params: {config.active_parameters_ratio:.1%}")

    return model, config


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=400,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.0,
    device="cuda",
):
    """
    Generates text with the MoE model.

    Args:
        model: MoE model
        tokenizer: Tokenizer
        prompt: Input prompt (string)
        max_new_tokens: Maximum number of new tokens (400!)
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        repetition_penalty: Penalty for repetitions
        device: Device

    Returns:
        generated_text: Generated text
    """
    # Tokenize prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    if device == "cuda":
        input_ids = input_ids.cuda()

    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return generated_text


def test_sampling_strategies(model, tokenizer, prompts, device="cuda"):
    """
    Tests different sampling strategies.

    Args:
        model: MoE model
        tokenizer: Tokenizer
        prompts: List of test prompts
        device: Device
    """
    # Recommended strategies (based on extensive testing)
    strategies = {
        "Standard (temp=0.7, rep=1.2, top_k=50, top_p=0.7)": {
            "temperature": 0.7,
            "top_k": 50,
            "top_p": 0.7,
            "repetition_penalty": 1.2,
        },
        "Focused (temp=0.7, rep=1.4, top_k=20, top_p=0.7)": {
            "temperature": 0.7,
            "top_k": 20,
            "top_p": 0.7,
            "repetition_penalty": 1.4,
        },
    }

    print("\n" + "=" * 80)
    print("🧪 TESTING SAMPLING STRATEGIES")
    print("=" * 80)

    for prompt in prompts:
        print(f"\n{'='*80}")
        print(f"PROMPT: '{prompt}'")
        print(f"{'='*80}\n")

        for strategy_name, params in strategies.items():
            print(f"\n🎯 Strategy: {strategy_name}")
            print("-" * 80)

            try:
                generated = generate_text(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=prompt,
                    max_new_tokens=400,  # 400 tokens!
                    device=device,
                    **params
                )

                print(f"{generated}")
                print()

            except Exception as e:
                print(f"❌ Error: {str(e)}\n")

    print("\n" + "=" * 80)
    print("💡 RECOMMENDATION")
    print("=" * 80)
    print("""

""")


def main():
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n🖥️ Device: {device}")

    # Load model (newest checkpoint automatically!)
    model, config = load_model(model_path=None, device=device)

    # Load tokenizer
    print("\n📚 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Llama 3.2 tokenizer loaded")
    print(f"   - Vocab size: {tokenizer.vocab_size:,}")
    print(f"   - EOS token: {tokenizer.eos_token}")

    # ==================== SAMPLING STRATEGY TESTS ====================

    # Test prompts (diverse!)
    test_prompts = [
        "Gestern bin ich ",  # narrative
        "Der Mond ",  # poetic
        "Im Labor ",  # scientific
        "Hast du auch das Gefühl, dass",  # personal / forum style
        "Die Zeit",
        "Was ist die Definition von Philosophie?"
    ]

    # Test the different sampling strategies
    test_sampling_strategies(model, tokenizer, test_prompts, device)


if __name__ == "__main__":
    main()
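A minimal usage sketch of the functions above from a Python session, assuming a trained checkpoint already exists under ./moe_checkpoints_v8_clean (or ./moe_final_v8_clean) and the Llama 3.2 tokenizer is accessible; the prompt and sampling values are illustrative, not the script's defaults:

# Minimal usage sketch (assumes a trained checkpoint and tokenizer access).
import torch
from transformers import AutoTokenizer
from inference import load_model, generate_text

device = "cuda" if torch.cuda.is_available() else "cpu"
model, config = load_model(model_path=None, device=device)   # picks the newest checkpoint

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token

text = generate_text(
    model, tokenizer,
    prompt="Der Mond ",              # German prompt, matching the training data
    max_new_tokens=100,
    temperature=0.7, top_k=50, top_p=0.7, repetition_penalty=1.2,
    device=device,
)
print(text)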
moe_config.py
ADDED
@@ -0,0 +1,119 @@
"""
HuggingFace-compatible MoE configuration.
Based on the nanoMoE blog post.
"""

from transformers import PretrainedConfig


class MoEGPTConfig(PretrainedConfig):
    """
    Configuration for the MoE-based GPT model.

    Args:
        vocab_size (int): Size of the vocabulary
        n_positions (int): Maximum sequence length
        n_embd (int): Dimensionality of the embeddings (d in the blog)
        n_layer (int): Number of transformer blocks
        n_head (int): Number of attention heads
        n_experts (int): Number of experts per MoE layer
        n_experts_active (int): Number of active experts (top-k)
        moe_layer_frequency (int): Every n-th layer becomes a MoE layer (P in the blog)
        capacity_factor (float): Expert capacity factor for training
        eval_capacity_factor (float): Expert capacity factor for evaluation
        use_noisy_gating (bool): Whether to use noisy top-k gating
        aux_loss_alpha (float): Scaling for the load-balancing loss
        router_z_loss_alpha (float): Scaling for the router z-loss
        bias (bool): Whether to use bias in linear layers
        dropout (float): Dropout probability
        activation_function (str): Activation function (gelu, relu, swiglu)
        initializer_range (float): Standard deviation for weight initialization
        layer_norm_epsilon (float): Epsilon for layer normalization
    """

    model_type = "moe_gpt"

    def __init__(
        self,
        vocab_size=128256,  # Llama 3.2 tokenizer (incl. special tokens)
        n_positions=2048,  # Default 2048 for RoPE
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_experts=8,
        n_experts_active=2,
        moe_layer_frequency=2,
        capacity_factor=1.25,
        eval_capacity_factor=2.0,
        use_noisy_gating=True,
        aux_loss_alpha=0.01,
        router_z_loss_alpha=0.001,
        bias=False,
        dropout=0.1,
        activation_function="gelu",
        initializer_range=0.1,
        layer_norm_epsilon=1e-5,
        use_cache=True,
        rope_theta=10000.0,  # RoPE base theta
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.moe_layer_frequency = moe_layer_frequency
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.use_noisy_gating = use_noisy_gating
        self.aux_loss_alpha = aux_loss_alpha
        self.router_z_loss_alpha = router_z_loss_alpha
        self.bias = bias
        self.dropout = dropout
        self.activation_function = activation_function
        self.initializer_range = initializer_range
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        # Standard HuggingFace attributes (needed for .generate())
        self.num_hidden_layers = n_layer
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.max_position_embeddings = n_positions

        # Validation
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        assert n_experts_active <= n_experts, "n_experts_active must not exceed n_experts"
        assert moe_layer_frequency >= 1, "moe_layer_frequency must be at least 1"

    @property
    def head_dim(self):
        """Dimension per attention head"""
        return self.n_embd // self.n_head

    @property
    def total_experts(self):
        """Total number of experts in the model"""
        num_moe_layers = sum(1 for i in range(self.n_layer) if i % self.moe_layer_frequency == 0)
        return num_moe_layers * self.n_experts

    @property
    def active_parameters_ratio(self):
        """Ratio of active parameters (approximate)"""
        num_moe_layers = sum(1 for i in range(self.n_layer) if i % self.moe_layer_frequency == 0)
        num_dense_layers = self.n_layer - num_moe_layers

        # Simplified estimate (ignores attention)
        dense_params = num_dense_layers * (8 * self.n_embd**2)  # FFN params
        moe_total_params = num_moe_layers * self.n_experts * (8 * self.n_embd**2)
        moe_active_params = num_moe_layers * self.n_experts_active * (8 * self.n_embd**2)

        total = dense_params + moe_total_params
        active = dense_params + moe_active_params

        return active / total if total > 0 else 1.0
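A quick sanity check of the derived properties under the default configuration (12 layers, every 2nd layer MoE, 8 experts, top-2); the printed numbers follow directly from the formulas above:

from moe_config import MoEGPTConfig

cfg = MoEGPTConfig()          # defaults: n_layer=12, moe_layer_frequency=2 -> 6 MoE layers
print(cfg.head_dim)           # 768 // 12 = 64
print(cfg.total_experts)      # 6 MoE layers * 8 experts = 48
print(round(cfg.active_parameters_ratio, 3))
# FFN-only estimate: (6 + 6*2) / (6 + 6*8) = 18 / 54 ≈ 0.333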
moe_layers.py
ADDED
@@ -0,0 +1,323 @@
"""
MoE layer components.
Based on the nanoMoE blog post and HuggingFace best practices.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class MoERouter(nn.Module):
    """
    Noisy top-k router for MoE.
    Routes tokens to the top-k experts based on learned probabilities.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        n_experts_active: int,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.use_noisy_gating = use_noisy_gating
        self.capacity_factor = capacity_factor

        # Linear projections for the router (no bias, see Shazeer et al. 2017)
        self.w_gate = nn.Linear(d_model, n_experts, bias=False)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False) if use_noisy_gating else None

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            expert_weights: Weights for each expert [batch_size * seq_len, n_experts, capacity]
            expert_mask: Mask of the experts used [batch_size * seq_len, n_experts, capacity]
            expert_batches: Batches for each expert [n_experts, capacity, d_model]
            router_logits: Router logits for the z-loss [batch_size, seq_len, n_experts]
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # The router ALWAYS runs in FP32 for numerical stability!
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.amp.autocast(device_type=device_type, enabled=False):
            x_fp32 = x.float()

            # Compute router logits
            router_logits = self.w_gate(x_fp32)  # [B, T, n_experts]

            # Noisy top-k gating (optional)
            if self.use_noisy_gating and self.training:
                noise = F.softplus(self.w_noise(x_fp32))
                noise = noise * torch.randn_like(noise)
                router_logits = router_logits + noise

            # Select the top-k experts
            top_k_logits, top_k_indices = router_logits.topk(
                self.n_experts_active, dim=-1
            )  # [B, T, K]

            # Softmax over all experts (not only the top-k)
            router_probs = torch.full_like(router_logits, float("-inf"))
            router_probs.scatter_(-1, top_k_indices, top_k_logits)
            router_probs = F.softmax(router_probs, dim=-1)  # [B, T, n_experts]

            # Compute the expert capacity
            capacity = self._compute_capacity(num_tokens)

            # Multi-hot mask of the selected experts
            expert_mask = F.one_hot(
                top_k_indices, num_classes=self.n_experts
            )  # [B, T, K, n_experts]
            expert_mask = expert_mask.view(num_tokens, self.n_experts_active, self.n_experts)
            expert_mask = expert_mask.permute(1, 0, 2)  # [K, num_tokens, n_experts]

            # Position of each token in the expert batch (cumsum gives top-1-first prioritization)
            expert_rank = expert_mask.reshape(
                self.n_experts_active * num_tokens, self.n_experts
            )
            expert_rank = torch.cumsum(expert_rank, dim=0) - 1
            expert_rank = expert_rank.reshape(
                self.n_experts_active, num_tokens, self.n_experts
            )

            # Mask out tokens beyond the capacity
            expert_mask = expert_mask * torch.lt(expert_rank, capacity)

            # Position within the expert batch
            expert_rank = torch.sum(expert_mask * expert_rank, dim=-1)  # [K, num_tokens]

            # Multiply the probabilities with the mask
            router_probs = router_probs.view(num_tokens, self.n_experts)[
                None, :
            ]  # [1, num_tokens, n_experts]
            expert_weights = expert_mask * router_probs  # [K, num_tokens, n_experts]

            # One-hot encoding of the position in the expert batch
            expert_rank_one_hot = F.one_hot(
                expert_rank, num_classes=capacity
            )  # [K, num_tokens, capacity]

            # Place the weights at their expert batch positions
            expert_weights = torch.sum(
                expert_weights.unsqueeze(3) * expert_rank_one_hot.unsqueeze(2), dim=0
            )  # [num_tokens, n_experts, capacity]
            expert_mask = expert_weights.bool()

            # Build the expert batches
            x_flat = x.view(num_tokens, d_model)
            expert_batches = (
                expert_mask.permute(1, 2, 0).type_as(x) @ x_flat
            )  # [n_experts, capacity, d_model]

        return expert_weights, expert_mask, expert_batches, router_logits

    def _compute_capacity(self, num_tokens: int) -> int:
        """Computes the expert capacity"""
        capacity = math.floor(
            self.n_experts_active * self.capacity_factor * num_tokens / self.n_experts
        )
        capacity += capacity % 2  # Even number for better hardware utilization
        return max(int(capacity), 2)  # Minimum of 2 for small batches


class ExpertMLP(nn.Module):
    """
    Batch of MLP experts.
    All experts share the same architecture but have independent weights.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.bias = bias

        # 4x hidden dimension (standard for GPT)
        hidden_dim = 4 * d_model

        # Weights for all experts (batch matmul)
        self.w_fc = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
        self.w_proj = nn.Parameter(torch.empty(n_experts, hidden_dim, d_model))

        if bias:
            self.fc_bias = nn.Parameter(torch.empty(n_experts, 1, hidden_dim))
            self.proj_bias = nn.Parameter(torch.empty(n_experts, 1, d_model))
        else:
            self.register_parameter("fc_bias", None)
            self.register_parameter("proj_bias", None)

        # Activation function
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "swiglu":
            # SwiGLU needs extra weights
            self.w_gate = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation: {activation}")

        self.dropout = nn.Dropout(dropout)
        self.activation_type = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [n_experts, capacity, d_model]

        Returns:
            output: [n_experts, capacity, d_model]
        """
        # First linear layer via batch matmul
        h = torch.bmm(x, self.w_fc)
        if self.bias:
            h = h + self.fc_bias

        # Activation
        if self.activation_type == "swiglu":
            # SwiGLU: silu(x @ W_gate) * (x @ W_fc)
            gate = torch.bmm(x, self.w_gate)
            h = self.activation(gate) * h
        else:
            h = self.activation(h)

        # Second linear layer
        output = torch.bmm(h, self.w_proj)
        if self.bias:
            output = output + self.proj_bias

        output = self.dropout(output)

        return output


class MoELayer(nn.Module):
    """
    Complete Mixture-of-Experts layer.
    Combines the router and the experts.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int = 8,
        n_experts_active: int = 2,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super().__init__()

        self.router = MoERouter(
            d_model=d_model,
            n_experts=n_experts,
            n_experts_active=n_experts_active,
            use_noisy_gating=use_noisy_gating,
            capacity_factor=capacity_factor,
        )

        self.experts = ExpertMLP(
            d_model=d_model,
            n_experts=n_experts,
            bias=bias,
            dropout=dropout,
            activation=activation,
        )

        self.n_experts = n_experts
        self.n_experts_active = n_experts_active

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            output: [batch_size, seq_len, d_model]
            load_balance_loss: Scalar load-balancing loss
            router_z_loss: Scalar router z-loss
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # Routing
        expert_weights, expert_mask, expert_batches, router_logits = self.router(x)

        # Expert forward pass
        expert_outputs = self.experts(expert_batches)  # [n_experts, capacity, d_model]

        # Combine the outputs (weighted average)
        expert_weights_flat = expert_weights.view(num_tokens, -1)  # [num_tokens, n_experts * capacity]
        expert_outputs_flat = expert_outputs.view(-1, d_model)  # [n_experts * capacity, d_model]
        output = expert_weights_flat @ expert_outputs_flat  # [num_tokens, d_model]
        output = output.view(batch_size, seq_len, d_model)

        # Compute the auxiliary losses
        load_balance_loss = self._compute_load_balance_loss(router_logits, expert_mask)
        router_z_loss = self._compute_router_z_loss(router_logits)

        return output, load_balance_loss, router_z_loss

    def _compute_load_balance_loss(
        self, router_logits: torch.Tensor, expert_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Load-balancing loss (Switch Transformer, Fedus et al. 2022).
        Encourages uniform distribution of tokens across experts.
        """
        batch_size, seq_len, n_experts = router_logits.shape
        num_tokens = batch_size * seq_len

        # Probability per expert
        router_probs = F.softmax(router_logits, dim=-1)  # [B, T, n_experts]
        prob_per_expert = torch.mean(router_probs, dim=(0, 1))  # [n_experts]

        # Token ratio per expert
        with torch.no_grad():
            # expert_mask is [num_tokens, n_experts, capacity]
            tokens_per_expert = torch.sum(expert_mask.float(), dim=(0, 2))  # [n_experts]
            tokens_per_expert = tokens_per_expert / (num_tokens * self.n_experts_active)

        # Dot product (scaled by n_experts)
        loss = self.n_experts * torch.sum(prob_per_expert * tokens_per_expert)

        return loss

    def _compute_router_z_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """
        Router z-loss (ST-MoE, Zoph et al. 2022).
        Penalizes large router logits for numerical stability.
        """
        # Squared logsumexp over the experts
        z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0  # [B, T]
        z_loss = torch.mean(z_loss)

        return z_loss
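A small standalone shape check for MoELayer, useful when experimenting with the routing code in isolation. Note that the expert weights are allocated with torch.empty, so this sketch initializes them explicitly (in the full model that is the job of _init_weights); the sizes are arbitrary:

import torch
from moe_layers import MoELayer

layer = MoELayer(d_model=64, n_experts=4, n_experts_active=2)

# Expert weights are allocated with torch.empty; give them values for this test.
for p in layer.experts.parameters():
    torch.nn.init.normal_(p, mean=0.0, std=0.02)

x = torch.randn(2, 16, 64)                       # [batch, seq_len, d_model]
out, load_balance_loss, router_z_loss = layer(x)

print(out.shape)                                 # torch.Size([2, 16, 64])
print(load_balance_loss.item(), router_z_loss.item())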
moe_model.py
ADDED
@@ -0,0 +1,459 @@
"""
MoE GPT model - HuggingFace compatible.
Based on nanoMoE and the blog post.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from typing import Optional, Tuple, Union
from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from moe_config import MoEGPTConfig
from moe_layers import MoELayer, ExpertMLP


@dataclass
class MoECausalLMOutput(CausalLMOutputWithPast):
    """
    Extended output class with MoE-specific losses.
    """

    aux_loss: Optional[torch.FloatTensor] = None
    router_z_loss: Optional[torch.FloatTensor] = None


def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """
    Applies Rotary Position Embeddings (RoPE) to input tensor.

    Args:
        x: Input tensor of shape [B, H, T, D]
        freqs_cos: Cosine frequencies of shape [T, D//2]
        freqs_sin: Sine frequencies of shape [T, D//2]

    Returns:
        Tensor with RoPE applied
    """
    # Reshape x to separate real and imaginary parts for rotation
    # x: [B, H, T, D] -> [B, H, T, D//2, 2]
    x_complex = x.float().reshape(*x.shape[:-1], -1, 2)

    # Apply rotation: (a + bi) * (cos + i*sin) = (a*cos - b*sin) + i(a*sin + b*cos)
    x_rot_real = x_complex[..., 0] * freqs_cos - x_complex[..., 1] * freqs_sin
    x_rot_imag = x_complex[..., 0] * freqs_sin + x_complex[..., 1] * freqs_cos

    # Stack back together and flatten
    x_out = torch.stack([x_rot_real, x_rot_imag], dim=-1)
    x_out = x_out.flatten(-2)

    return x_out.type_as(x)


def precompute_freqs_rope(dim: int, max_seq_len: int, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precomputes RoPE frequencies.

    Args:
        dim: Head dimension
        max_seq_len: Maximum sequence length
        theta: RoPE theta parameter (base for frequency calculation)

    Returns:
        Tuple of (freqs_cos, freqs_sin) tensors of shape [max_seq_len, dim//2]
    """
    # Compute frequencies for each dimension pair
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    # Create position indices
    t = torch.arange(max_seq_len, dtype=torch.float32)

    # Compute outer product: [max_seq_len, dim//2]
    freqs = torch.outer(t, freqs)

    # Compute cos and sin
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)

    return freqs_cos, freqs_sin


class CausalSelfAttention(nn.Module):
    """
    Multi-Head Causal Self-Attention with Rotary Position Embeddings (RoPE).
    Uses PyTorch SDPA for optimized performance.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Key, query, value for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = config.n_embd // config.n_head

        # Precompute RoPE frequencies
        freqs_cos, freqs_sin = precompute_freqs_rope(
            dim=self.head_dim,
            max_seq_len=config.n_positions,
            theta=config.rope_theta
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch, sequence length, embedding dim

        # Compute Q, K, V
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # Reshape for multi-head attention
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # [B, H, T, d]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K
        q = apply_rotary_emb(q, self.freqs_cos[:T], self.freqs_sin[:T])
        k = apply_rotary_emb(k, self.freqs_cos[:T], self.freqs_sin[:T])

        # Use PyTorch SDPA (Scaled Dot Product Attention) - optimized!
        # SDPA handles causal masking, dropout, and is memory efficient
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,  # Causal mask handled by is_causal
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True  # Efficient causal masking
        )  # [B, H, T, d]

        # Reshape back
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        y = self.resid_dropout(self.c_proj(y))

        return y


class MLP(nn.Module):
    """
    Standard feed-forward network (for non-MoE layers).
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        if config.activation_function == "gelu":
            self.activation = nn.GELU()
        elif config.activation_function == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation: {config.activation_function}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.activation(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """
    Standard transformer block (attention + MLP).
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class MoETransformerBlock(nn.Module):
    """
    MoE transformer block (attention + MoE layer).
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Capacity factor depends on training/eval
        self.moe = MoELayer(
            d_model=config.n_embd,
            n_experts=config.n_experts,
            n_experts_active=config.n_experts_active,
            use_noisy_gating=config.use_noisy_gating,
            capacity_factor=config.capacity_factor,
            bias=config.bias,
            dropout=config.dropout,
            activation=config.activation_function,
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Attention
        x = x + self.attn(self.ln_1(x))

        # MoE layer
        moe_out, aux_loss, router_z_loss = self.moe(self.ln_2(x))
        x = x + moe_out

        return x, aux_loss, router_z_loss


class MoEGPTPreTrainedModel(PreTrainedModel):
    """
    Base class for MoE GPT built on the HuggingFace PreTrainedModel.
    """

    config_class = MoEGPTConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Weight initialization following ST-MoE (Zoph et al. 2022):
        truncated normal with reduced std for MoE stability.
        """
        if isinstance(module, nn.Linear):
            # Fan-in initialization
            fan_in = module.weight.shape[-1]
            std = (self.config.initializer_range / fan_in) ** 0.5

            torch.nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-2 * std,
                b=2 * std,
            )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

        elif isinstance(module, ExpertMLP):
            # Expert weights are plain nn.Parameter tensors, so nn.Module.apply()
            # never routes them through the nn.Linear branch above; initialize them here.
            for name, param in module.named_parameters(recurse=False):
                if "bias" in name:
                    torch.nn.init.zeros_(param)
                else:
                    # Expert weights are laid out as [n_experts, fan_in, fan_out]
                    fan_in = param.shape[-2]
                    std = (self.config.initializer_range / fan_in) ** 0.5
                    torch.nn.init.trunc_normal_(
                        param,
                        mean=0.0,
                        std=std,
                        a=-2 * std,
                        b=2 * std,
                    )


class MoEGPTModel(MoEGPTPreTrainedModel):
    """
    MoE GPT model (without the LM head).
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False  # For HF gradient checkpointing support

        # Token embeddings only (RoPE handles positions)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

        # Transformer blocks (mixed: standard + MoE)
        self.h = nn.ModuleList()
        for i in range(config.n_layer):
            if i % config.moe_layer_frequency == 0:
                # MoE block
                self.h.append(MoETransformerBlock(config))
            else:
                # Standard block
                self.h.append(TransformerBlock(config))

        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        device = input_ids.device
        b, t = input_ids.size()

        assert t <= self.config.n_positions, f"Sequence too long: {t} > {self.config.n_positions}"

        # Token embeddings only (RoPE in the attention layers)
        tok_emb = self.wte(input_ids)  # [B, T, n_embd]
        x = self.drop(tok_emb)

        # Collect the auxiliary losses
        total_aux_loss = 0.0
        total_router_z_loss = 0.0

        # Run through all blocks
        for block in self.h:
            if isinstance(block, MoETransformerBlock):
                if self.gradient_checkpointing and self.training:
                    # Gradient checkpointing for MoE blocks
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    x, aux_loss, router_z_loss = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        use_reentrant=False
                    )
                else:
                    x, aux_loss, router_z_loss = block(x)
                total_aux_loss = total_aux_loss + aux_loss
                total_router_z_loss = total_router_z_loss + router_z_loss
            else:
                if self.gradient_checkpointing and self.training:
                    x = torch.utils.checkpoint.checkpoint(
                        block,
                        x,
                        use_reentrant=False
                    )
                else:
                    x = block(x)

        x = self.ln_f(x)

        return x, total_aux_loss, total_router_z_loss


class MoEGPTForCausalLM(MoEGPTPreTrainedModel, GenerationMixin):
    """
    MoE GPT with a language modeling head (for pretraining).
    Inherits from GenerationMixin for .generate() support.
    """

    # Tell HuggingFace which weights are shared
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.transformer = MoEGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying (the LM head shares weights with the token embedding)
        self.lm_head.weight = self.transformer.wte.weight

        # Initialize weights
        self.post_init()

    def get_output_embeddings(self):
        """For HuggingFace weight tying"""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """For HuggingFace weight tying"""
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        """For HuggingFace weight tying"""
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        """For HuggingFace weight tying"""
        self.transformer.wte = new_embeddings

    def tie_weights(self):
        """
        Tie lm_head weights to input embeddings (weight tying).
        Called after loading a checkpoint to fix missing lm_head.weight.
        """
        self.lm_head.weight = self.transformer.wte.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoECausalLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through the transformer
        hidden_states, aux_loss, router_z_loss = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # LM head
        if labels is not None:
            # Training: logits for all positions (needed for the shifted LM loss)
            logits = self.lm_head(hidden_states)
        else:
            # Inference: only the last position
            logits = self.lm_head(hidden_states[:, [-1], :])

        # Compute the loss
        loss = None
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Cross-entropy loss
            loss_fct = nn.CrossEntropyLoss()
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # Add the auxiliary losses
            loss = lm_loss
            if self.training:
                loss = loss + self.config.aux_loss_alpha * aux_loss
                loss = loss + self.config.router_z_loss_alpha * router_z_loss

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MoECausalLMOutput(
            loss=loss,
            logits=logits,
            aux_loss=aux_loss if self.training else None,
            router_z_loss=router_z_loss if self.training else None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """For the HuggingFace generate() function"""
        return {"input_ids": input_ids}
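A tiny end-to-end smoke test of the model class: build a deliberately small (hypothetical) configuration, run one forward pass with labels, and inspect the combined loss plus the MoE auxiliary terms:

import torch
from moe_config import MoEGPTConfig
from moe_model import MoEGPTForCausalLM

# Hypothetical toy sizes, not the v8 training configuration.
cfg = MoEGPTConfig(vocab_size=256, n_positions=64, n_embd=32,
                   n_layer=2, n_head=4, n_experts=4, n_experts_active=2)
model = MoEGPTForCausalLM(cfg)
model.train()                                      # aux losses are only returned in training mode

input_ids = torch.randint(0, cfg.vocab_size, (2, 16))
out = model(input_ids=input_ids, labels=input_ids)
print(out.loss, out.aux_loss, out.router_z_loss)   # total loss plus MoE auxiliary terms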
moe_trainer.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Custom MoE Trainer mit erweiterten Logging-Funktionen
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from typing import Dict, Optional, Any
|
| 7 |
+
from transformers import Trainer
|
| 8 |
+
from transformers.trainer_callback import TrainerCallback
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class MoETrainer(Trainer):
|
| 12 |
+
"""
|
| 13 |
+
Erweiterter Trainer für MoE Modelle mit speziellem Logging für:
|
| 14 |
+
- Auxiliary Losses (Load Balancing, Router Z-Loss)
|
| 15 |
+
- Expert Utilization
|
| 16 |
+
- Capacity Factor Anpassung
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
| 20 |
+
"""
|
| 21 |
+
Überschreibt compute_loss um MoE-spezifische Losses zu berücksichtigen.
|
| 22 |
+
Diese sind bereits im model.forward() eingerechnet, aber wir loggen sie separat.
|
| 23 |
+
"""
|
| 24 |
+
# Labels für next token prediction
|
| 25 |
+
if "labels" not in inputs:
|
| 26 |
+
inputs["labels"] = inputs["input_ids"].clone()
|
| 27 |
+
|
| 28 |
+
# Forward pass
|
| 29 |
+
outputs = model(**inputs)
|
| 30 |
+
|
| 31 |
+
# Loss ist bereits total loss (LM + aux losses)
|
| 32 |
+
loss = outputs.loss
|
| 33 |
+
|
| 34 |
+
        # Log the auxiliary losses (during training)
        if self.state.global_step % self.args.logging_steps == 0:
            if hasattr(outputs, "aux_loss") and outputs.aux_loss is not None:
                self.log({"train/aux_loss": outputs.aux_loss.item()})

            if hasattr(outputs, "router_z_loss") and outputs.router_z_loss is not None:
                self.log({"train/router_z_loss": outputs.router_z_loss.item()})

            # Full loss breakdown: subtract the weighted auxiliary terms to recover the pure LM loss
            if (
                hasattr(outputs, "aux_loss") and outputs.aux_loss is not None
                and hasattr(outputs, "router_z_loss") and outputs.router_z_loss is not None
            ):
                lm_loss = (
                    loss.item()
                    - self.model.config.aux_loss_alpha * outputs.aux_loss.item()
                    - self.model.config.router_z_loss_alpha * outputs.router_z_loss.item()
                )
                self.log({"train/lm_loss": lm_loss})

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Overrides prediction_step so that eval_loss is returned correctly
        """
        # Make sure labels are present
        if "labels" not in inputs:
            inputs["labels"] = inputs["input_ids"].clone()

        # Call the standard prediction_step
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys
        )

        return loss, logits, labels

    def log(self, logs: Dict[str, float], start_time=None) -> None:
        """
        Extends the default logging with MoE-specific metrics
        """
        # GPU memory tracking
        if torch.cuda.is_available():
            logs["gpu_memory_allocated_gb"] = (
                torch.cuda.memory_allocated() / 1024**3
            )
            logs["gpu_memory_reserved_gb"] = (
                torch.cuda.memory_reserved() / 1024**3
            )

        if start_time is not None:
            super().log(logs, start_time)
        else:
            super().log(logs)


class MoEEvalCallback(TrainerCallback):
    """
    Callback for extended MoE-specific evaluation
    """

    def on_evaluate(self, args, state, control, model, metrics=None, **kwargs):
        """
        After each evaluation, log additional MoE metrics
        """
        if metrics is not None and model is not None:
            # Model statistics
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(
                p.numel() for p in model.parameters() if p.requires_grad
            )

            metrics["model/total_params_M"] = total_params / 1e6
            metrics["model/trainable_params_M"] = trainable_params / 1e6

            # MoE-specific
            if hasattr(model.config, "n_experts"):
                metrics["model/total_experts"] = model.config.total_experts
                metrics["model/active_params_ratio"] = (
                    model.config.active_parameters_ratio
                )


class DataCollatorForLanguageModeling:
    """
    Simple data collator for causal language modeling.
    Assumes the data is already tokenized.
    """

    def __init__(self, pad_token_id: int = 0):
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        """
        Args:
            examples: List of dicts with 'input_ids' and 'attention_mask'

        Returns:
            Batch dict with padded input_ids and attention_mask
        """
        # Maximum length in this batch
        max_length = max(len(ex["input_ids"]) for ex in examples)

        input_ids = []
        attention_mask = []

        for ex in examples:
            seq_len = len(ex["input_ids"])
            padding_length = max_length - seq_len

            # Right padding
            padded_input_ids = ex["input_ids"] + [self.pad_token_id] * padding_length
            padded_attention_mask = ex["attention_mask"] + [0] * padding_length

            input_ids.append(padded_input_ids)
            attention_mask.append(padded_attention_mask)

        # As tensors
        batch = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

        return batch


def compute_metrics(eval_preds):
    """
    Compute perplexity for evaluation
    """
    predictions, labels = eval_preds

    # For language modeling, predictions are the logits
    # and labels are the actual token IDs.
    # This function is optional - the loss is already computed and logged by the Trainer.
    return {}
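For reference, a minimal usage sketch of the collator above (illustrative token values, assuming moe_trainer.py is importable from the working directory): shorter sequences are right-padded with pad_token_id and masked out.

from moe_trainer import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(pad_token_id=0)
batch = collator([
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]},
    {"input_ids": [8, 9], "attention_mask": [1, 1]},
])
# batch["input_ids"]      -> tensor([[5, 6, 7], [8, 9, 0]])
# batch["attention_mask"] -> tensor([[1, 1, 1], [1, 1, 0]])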
requirements.txt
ADDED
@@ -0,0 +1,96 @@
# German MoE GPT v6 - Requirements
# Environment: nano_moe (Conda)
# Python: 3.10+
# CUDA: 12.4

# ============================================================================
# CRITICAL: PyTorch Installation
# ============================================================================
# IMPORTANT: Install PyTorch FIRST with CUDA support!
# DO NOT use pip for PyTorch on Windows - use conda instead:
#
#   conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
#
# Or from PyTorch website (pip with CUDA):
#   pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124
#
# Current installed versions:
#   torch==2.6.0+cu124
#   torchvision==0.21.0+cu124
#   torchaudio==2.6.0+cu124
# ============================================================================

# Core ML Libraries (install AFTER PyTorch!)
transformers==4.56.1
datasets==4.0.0
accelerate==1.10.1

# Training & Monitoring
tensorboard==2.20.0
tensorboard-data-server==0.7.2

# Tokenization
tokenizers==0.22.0
tiktoken==0.11.0

# Data Processing
numpy==1.26.4
pandas==2.3.2
pyarrow==21.0.0

# Utilities
tqdm==4.67.1
safetensors==0.6.2
huggingface-hub==0.34.4
regex==2025.9.1
fsspec==2025.3.0
dill==0.3.8
multiprocess==0.70.16
xxhash==3.5.0

# Performance (Windows CUDA)
triton-windows==3.2.0.post19  # Optimized kernels for CUDA

# Configuration & Logging
PyYAML==6.0.2
python-dotenv==1.0.1
requests==2.32.5
httpx[http2]==0.27.0

# Optional: Weights & Biases (uncomment if needed)
# wandb>=0.15.0

# ============================================================================
# Installation Instructions
# ============================================================================
#
# STEP 1: Create conda environment
#   conda create -n nano_moe python=3.10
#   conda activate nano_moe
#
# STEP 2: Install PyTorch with CUDA 12.4
#   conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
#
# STEP 3: Install remaining dependencies
#   pip install -r requirements.txt --no-deps
#   (--no-deps prevents pip from reinstalling PyTorch!)
#
# STEP 4: Verify installation
#   python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
#
# ============================================================================
# Notes
# ============================================================================
#
# - DO NOT install PyTorch via pip requirements.txt on Windows!
#   It will install the CPU version or the wrong CUDA version
#
# - triton-windows only works on Windows with CUDA
#   On Linux, use: triton>=2.0.0
#
# - datasets 4.0.0 has breaking changes from 2.x
#   Use load_from_disk() / save_to_disk() for the eval dataset
#
# - transformers 4.56.1 is compatible with our custom MoE implementation
#
# ============================================================================
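Beyond the STEP 4 check above, a slightly broader sanity check can confirm that the pinned versions and the CUDA build were actually picked up (a sketch, not part of the uploaded files):

import torch, transformers, datasets
print(torch.__version__, torch.version.cuda, torch.cuda.is_available())  # expect 2.6.0+cu124 / 12.4 / True
print(transformers.__version__, datasets.__version__)                    # expect 4.56.1 / 4.0.0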
sample_generation_callback.py
ADDED
@@ -0,0 +1,148 @@
"""
Sample generation callback for MoE training
Generates texts during training to monitor progress
"""

import torch
from transformers import TrainerCallback, AutoTokenizer
from typing import Optional
import os


class SampleGenerationCallback(TrainerCallback):
    """
    Generates sample texts every N steps during training
    """

    def __init__(
        self,
        tokenizer,
        prompts: list[str],
        generate_every_n_steps: int = 100,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        output_dir: str = "./samples",
    ):
        """
        Args:
            tokenizer: HuggingFace tokenizer
            prompts: List of prompts for generation
            generate_every_n_steps: Generate every N steps
            max_new_tokens: Maximum number of new tokens
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            output_dir: Directory for sample outputs
        """
        self.tokenizer = tokenizer
        self.prompts = prompts
        self.generate_every_n_steps = generate_every_n_steps
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.output_dir = output_dir

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Samples log file
        self.log_file = os.path.join(output_dir, "generation_log.txt")

        # Write header
        with open(self.log_file, "w", encoding="utf-8") as f:
            f.write("=" * 80 + "\n")
            f.write("MoE Training - Sample Generation Log\n")
            f.write("=" * 80 + "\n\n")

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """
        Called after every training step
        """
        # Only generate every N steps
        if state.global_step % self.generate_every_n_steps != 0:
            return

        # Skip if there is no model
        if model is None:
            return

        print(f"\n{'='*80}")
        print(f"🎨 GENERATING SAMPLES @ STEP {state.global_step}")
        print(f"{'='*80}\n")

        # Put the model into eval mode
        model.eval()

        samples = []
        samples.append(f"\n{'='*80}\n")
        samples.append(f"Step: {state.global_step}\n")
        samples.append(f"{'='*80}\n\n")

        with torch.no_grad():
            for i, prompt in enumerate(self.prompts, 1):
                print(f"[{i}/{len(self.prompts)}] Prompt: '{prompt}'")

                # Tokenize
                input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
                input_ids = input_ids.to(model.device)

                try:
                    # Generate
                    # NOTE: repetition_penalty is REQUIRED for longer generations!
                    # For 300 tokens, 1.3-1.5 works better than 1.2
                    output_ids = model.generate(
                        input_ids,
                        max_new_tokens=self.max_new_tokens,
                        temperature=self.temperature,
                        top_k=self.top_k,
                        top_p=self.top_p,
                        repetition_penalty=1.4,  # ← higher for 300 tokens!
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                    # Decode
                    generated_text = self.tokenizer.decode(
                        output_ids[0], skip_special_tokens=True
                    )

                    # Print
                    print(f"  → {generated_text}\n")

                    # Store for the log
                    samples.append(f"Prompt {i}: {prompt}\n")
                    samples.append(f"Output: {generated_text}\n\n")

                except Exception as e:
                    error_msg = f"  ❌ Error: {str(e)}\n"
                    print(error_msg)
                    samples.append(f"Prompt {i}: {prompt}\n")
                    samples.append(f"Error: {str(e)}\n\n")

        # Write samples to the log file
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.writelines(samples)

        print(f"{'='*80}\n")

        # Put the model back into training mode
        model.train()


def get_german_sample_prompts():
    """
    Returns a list of German sample prompts
    """
    return [
        "Die Künstliche Intelligenz",
        "Im finsteren Wald",
        "In der Zukunft werden wir",
        "Machine Learning bedeutet",
        "Das Wetter heute ist",
        "Ein wichtiger Aspekt der",
        "Die Geschichte von",
        "Wissenschaftler haben herausgefunden",
    ]
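For context, the training script below wires this callback into the trainer roughly as follows (a sketch; tokenizer is assumed to be the Llama 3.2 tokenizer loaded there):

sample_callback = SampleGenerationCallback(
    tokenizer=tokenizer,
    prompts=get_german_sample_prompts(),
    generate_every_n_steps=1000,   # every 1K steps
    max_new_tokens=500,
    temperature=0.7,
    top_p=0.7,
    output_dir="./samples_v8_clean",
)
# ...then passed to MoETrainer via callbacks=[MoEEvalCallback(), sample_callback]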
train_moe_v8_clean.py
ADDED
@@ -0,0 +1,429 @@
"""
German MoE GPT v8 - CLEAN DATA + OPUS EDITION
Training with Wikipedia + OpenSubtitles + Belletristik (fiction)

Datasets (v8 - CLEAN + DIALOGUES! 🎉):
- Clean Wikipedia (local)             - 11 GB  (64%)
- OpenSubtitles OPUS (local)          - 4.2 GB (24%)
- Belletristik (arnomatic/merged_all) - 2.2 GB (12%)

Total: ~17.4 GB of 100% CLEAN German text!
NO spam, NO ads, NO SEO garbage! ✅
PLUS natural dialogues from movie subtitles! 🎬
"""

import os
import sys

# Disable HF transfer (can cause issues on Windows)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"

# Force UTF-8 encoding for Windows console
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')

import torch
from datasets import load_dataset, interleave_datasets
from transformers import TrainingArguments, set_seed, AutoTokenizer

from moe_config import MoEGPTConfig
from moe_model import MoEGPTForCausalLM
from moe_trainer import MoETrainer, MoEEvalCallback, DataCollatorForLanguageModeling
from sample_generation_callback import SampleGenerationCallback, get_german_sample_prompts


def load_clean_datasets(tokenizer, max_length=2048, seed=42, resume_step=0):
    """
    Loads 3 clean datasets (v8 - INTERLEAVED!):
    - Wikipedia (WITH EOS)        - 64%
    - OpenSubtitles OPUS (NO EOS) - 24%
    - Belletristik (NO EOS)       - 12%

    Args:
        resume_step: If > 0, adjusts the seed to continue from a checkpoint
    """
    # Adjust seed based on resume step (for reproducibility when resuming)
    effective_seed = seed + (resume_step // 1000)
    print("📚 Loading CLEAN datasets (v8 - OPUS Edition)...")
    if resume_step > 0:
        print(f"   🔄 Resume from step {resume_step} → Effective seed: {effective_seed}\n")
    else:
        print()

    # ========================================================================
    # 1. WIKIPEDIA (WITH EOS between articles)
    # ========================================================================
    print("1️⃣ Wikipedia (WITH EOS)...")
    try:
        wiki_ds = load_dataset(
            "jonas-is-coding/german-wikipedia-articles",
            split="train",
            streaming=True
        )
        print("   ✅ Dataset loaded (streaming mode)")

        # Shuffle
        print("   🔀 Shuffling with buffer_size=10,000...")
        wiki_ds = wiki_ds.shuffle(seed=effective_seed, buffer_size=10000)
        print("   ✅ Shuffle applied")

    except Exception as e:
        print(f"   ❌ Wikipedia Error: {e}")
        raise ValueError(f"Failed to load Wikipedia: {e}")

    # ========================================================================
    # 2. OPENSUBTITLES OPUS (NO EOS - continuous dialogues)
    # ========================================================================
    print("\n2️⃣ OpenSubtitles OPUS (NO EOS - continuous dialogues)...")
    try:
        opus_ds = load_dataset(
            "arnomatic/german-opus-subtitles",
            split="train",
            streaming=True
        )
        print("   ✅ Dataset loaded (streaming mode)")

        # Shuffle
        print("   🔀 Shuffling with buffer_size=10,000...")
        opus_ds = opus_ds.shuffle(seed=effective_seed, buffer_size=10000)
        print("   ✅ Shuffle applied")

    except Exception as e:
        print(f"   ❌ OpenSubtitles Error: {e}")
        raise ValueError(f"Failed to load OpenSubtitles: {e}")

    # ========================================================================
    # 3. BELLETRISTIK (NO EOS - continuous)
    # ========================================================================
    print("\n3️⃣ Belletristik (NO EOS - continuous)...")
    try:
        belle_ds = load_dataset(
            "arnomatic/merged_all",
            split="train",
            streaming=True
        )
        print("   ✅ Dataset loaded (streaming mode)")

        # Shuffle
        print("   🔀 Shuffling with buffer_size=10,000...")
        belle_ds = belle_ds.shuffle(seed=effective_seed, buffer_size=10000)
        print("   ✅ Shuffle applied")

    except Exception as e:
        print(f"   ❌ Belletristik Error: {e}")
        raise ValueError(f"Failed to load Belletristik: {e}")

    print("\n✅ All datasets loaded!")
    print("   Wikipedia:     4 GB   (WITH EOS)")
    print("   OpenSubtitles: 4.2 GB (NO EOS)")
    print("   Belletristik:  2.2 GB (NO EOS)")
    print("   Total: ~10.4 GB clean German!")

    # ========================================================================
    # DIRECT PACKING (no intermediate tokenization)
    # ========================================================================
    print("\n🔤 Tokenizing & packing datasets...")

    from datasets import IterableDataset as HFIterableDataset

    def pack_dataset_with_eos(dataset, text_field='text'):
        """Pack dataset WITH EOS directly into 2048-token batches"""
        def gen():
            buffer = []
            for example in dataset:
                text = example.get(text_field, '')
                if not text or not text.strip():
                    continue

                # Tokenize
                tokens = tokenizer.encode(text, add_special_tokens=False)

                # Add tokens + EOS
                buffer.extend(tokens)
                buffer.append(tokenizer.eos_token_id)

                # Yield complete chunks
                while len(buffer) >= max_length:
                    yield {
                        "input_ids": buffer[:max_length],
                        "attention_mask": [1] * max_length,
                        "labels": buffer[:max_length],
                    }
                    buffer = buffer[max_length:]

        return HFIterableDataset.from_generator(gen)

    def pack_dataset_no_eos(dataset, text_field='text'):
        """Pack dataset WITHOUT EOS directly into 2048-token batches"""
        def gen():
            buffer = []
            for example in dataset:
                text = example.get(text_field, '')
                if not text or not text.strip():
                    continue

                # Tokenize
                tokens = tokenizer.encode(text, add_special_tokens=False)

                # Add tokens (NO EOS)
                buffer.extend(tokens)

                # Yield complete chunks
                while len(buffer) >= max_length:
                    yield {
                        "input_ids": buffer[:max_length],
                        "attention_mask": [1] * max_length,
                        "labels": buffer[:max_length],
                    }
                    buffer = buffer[max_length:]

        return HFIterableDataset.from_generator(gen)

    print("   Wikipedia (WITH EOS)...")
    wiki_batched = pack_dataset_with_eos(wiki_ds, text_field='content')

    print("   OpenSubtitles (NO EOS)...")
    opus_batched = pack_dataset_no_eos(opus_ds, text_field='text')

    print("   Belletristik (NO EOS)...")
    belle_batched = pack_dataset_no_eos(belle_ds, text_field='text')

    print("✅ Batching complete!")

    # ========================================================================
    # INTERLEAVE DATASETS (64% Wiki, 24% OPUS, 12% Belle)
    # ========================================================================
    print("\n🔀 Interleaving datasets (64/24/12)...")

    train_dataset = interleave_datasets(
        [wiki_batched, opus_batched, belle_batched],
        probabilities=[0.64, 0.24, 0.12],
        seed=effective_seed,
        stopping_strategy="all_exhausted"
    )

    print("✅ Datasets interleaved! (v8 strategy)")
    print("   Wikipedia:     64%")
    print("   OpenSubtitles: 24%")
    print("   Belletristik:  12%")

    # ========================================================================
    # EVAL DATASET (fixed 500 samples from Wikipedia)
    # ========================================================================
    eval_dataset_path = "./eval_dataset_v8_clean"

    if os.path.exists(eval_dataset_path):
        print(f"\n📊 Loading existing eval dataset from {eval_dataset_path}...")
        from datasets import load_from_disk
        eval_dataset = load_from_disk(eval_dataset_path)
        print(f"✅ Eval dataset loaded: {len(eval_dataset)} samples (from disk)")
    else:
        print("\n📊 Creating fixed eval set (500 samples from Wikipedia)...")

        eval_samples = []
        eval_iter = iter(wiki_batched)
        for i in range(500):
            try:
                sample = next(eval_iter)
                eval_samples.append(sample)
                if (i + 1) % 100 == 0:
                    print(f"   Collected {i+1}/500 samples...")
            except StopIteration:
                print(f"   ⚠️ Only {i} eval samples available (dataset exhausted)")
                break

        if len(eval_samples) == 0:
            raise ValueError("No eval samples collected! Dataset exhausted immediately.")

        print(f"   Collected {len(eval_samples)} samples total")

        # Convert to a regular Dataset (not streaming!)
        from datasets import Dataset
        eval_dataset = Dataset.from_dict({
            key: [sample[key] for sample in eval_samples]
            for key in eval_samples[0].keys()
        })

        # Save to disk
        print(f"💾 Saving eval dataset to {eval_dataset_path}...")
        eval_dataset.save_to_disk(eval_dataset_path)
        print("✅ Eval dataset saved to disk!")

    print("   → No more fsspec cache leak!")
    print("   Training: Clean mix (streaming)")
    print(f"   Eval: {len(eval_dataset)} samples (fixed, from disk)\n")

    return train_dataset, eval_dataset


def main():
    SEED = 42
    set_seed(SEED)

    # Config
    config = MoEGPTConfig(
        vocab_size=128256,
        n_positions=2048,
        n_embd=512,
        n_layer=8,
        n_head=8,
        n_experts=8,
        n_experts_active=2,
        moe_layer_frequency=2,
        capacity_factor=1.25,
        eval_capacity_factor=2.0,
        use_noisy_gating=True,
        aux_loss_alpha=0.01,
        router_z_loss_alpha=0.001,
        bias=False,
        dropout=0.1,
        activation_function="gelu",
        initializer_range=0.1,
        rope_theta=10000.0,
    )

    print("\n🔧 Model config:")
    print(f"   - Experts: {config.n_experts} (Top-{config.n_experts_active})")
    print(f"   - Parameters: {config.total_experts} MoE experts")

    # Training args
    # Dataset: ~10.4 GB ≈ 2.5B tokens ≈ 1.2M batches (2048 tokens each)
    # With batch size 32: ~38K steps per epoch
    # ~1.3 epochs = ~50K steps (interleaved = more efficient)
    training_args = TrainingArguments(
        output_dir="./moe_checkpoints_v8_clean",
        run_name="german_moe_v8_clean",
        max_steps=200000,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=16,
        learning_rate=6e-4,
        warmup_steps=2000,
        lr_scheduler_type="cosine",
        weight_decay=0.1,
        bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
        fp16=not torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
        logging_dir="./logs_v8_clean",
        logging_steps=100,
        logging_first_step=True,
        report_to=["tensorboard"],
        eval_strategy="steps",
        eval_steps=1000,  # Every 1K steps (more frequent than v7)
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=10,
        dataloader_num_workers=0,
        dataloader_pin_memory=True,
        gradient_checkpointing=True,
        seed=SEED,
        load_best_model_at_end=False,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        ignore_data_skip=True,  # CRITICAL: Don't skip batches, use fresh shuffled data!
    )

    # Check for existing checkpoints (auto-resume) - DO THIS EARLY!
    import glob
    checkpoints = glob.glob(os.path.join(training_args.output_dir, "checkpoint-*"))
    resume_from_checkpoint = None
    resume_step = 0

    if checkpoints:
        latest_checkpoint = max(checkpoints, key=lambda x: int(x.split("-")[-1]))
        resume_from_checkpoint = latest_checkpoint
        resume_step = int(latest_checkpoint.split("-")[-1])
        print(f"\n🔄 RESUME training from: {latest_checkpoint} (Step {resume_step})")
    else:
        print("\n🆕 Starting fresh training (no checkpoints found)")

    # Tokenizer
    print("\n📚 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Llama 3.2 tokenizer loaded")

    # Load clean datasets (with resume_step for reproducibility!)
    train_dataset, eval_dataset = load_clean_datasets(
        tokenizer=tokenizer,
        max_length=2048,
        seed=SEED,
        resume_step=resume_step,
    )

    # Data collator
    data_collator = DataCollatorForLanguageModeling(pad_token_id=tokenizer.pad_token_id)

    # Model
    print("\n🏗️ Creating MoE model...")
    model = MoEGPTForCausalLM(config)

    # Ensure weight tying (especially after checkpoint load)
    model.tie_weights()

    total_params = sum(p.numel() for p in model.parameters())
    print(f"✅ Model created! ({total_params/1e6:.1f}M params)")

    # Callbacks
    sample_callback = SampleGenerationCallback(
        tokenizer=tokenizer,
        prompts=get_german_sample_prompts(),
        generate_every_n_steps=1000,  # Every 1K steps - fast feedback!
        max_new_tokens=500,
        temperature=0.7,
        top_p=0.7,
        output_dir="./samples_v8_clean",
    )

    # Trainer
    print("\n🚀 Initializing trainer...")
    trainer = MoETrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        callbacks=[MoEEvalCallback(), sample_callback],
    )

    print("✅ Trainer ready!")

    print("\n" + "=" * 60)
    print("🎯 STARTING TRAINING v8 - OPUS EDITION!")
    print("=" * 60)
    print("\nDataset composition (INTERLEAVED!):")
    print("   Wikipedia (WITH EOS):        64%")
    print("   OpenSubtitles OPUS (NO EOS): 24%")
    print("   Belletristik (NO EOS):       12%")
    print("\nTotal: ~10.4 GB CLEAN German!")
    print("NO spam, NO ads, NO SEO garbage! 🎉")
    print("PLUS natural dialogues from movie subtitles! 🎬")
    print("=" * 60 + "\n")

    # Train with resume support
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Save
    print("\n💾 Saving final model...")
    final_model_path = "./moe_final_v8_clean"
    trainer.save_model(final_model_path)
    config.save_pretrained(final_model_path)
    print(f"✅ Model saved to: {final_model_path}")

    # Eval
    print("\n📊 Final evaluation...")
    eval_results = trainer.evaluate()

    for key, value in eval_results.items():
        print(f"   - {key}: {value:.4f}")

    if "eval_loss" in eval_results:
        perplexity = torch.exp(torch.tensor(eval_results["eval_loss"]))
        print(f"\n🎯 Final perplexity: {perplexity:.2f}")

    print("\n" + "=" * 60)
    print("✅ TRAINING COMPLETE!")
    print("=" * 60)


if __name__ == "__main__":
    main()
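Typical usage of the script above (the commands are assumptions based on the paths it configures, not part of the upload):

# Start training, or auto-resume from the latest checkpoint in ./moe_checkpoints_v8_clean:
#   python train_moe_v8_clean.py
# Monitor train/aux_loss, train/router_z_loss and eval_loss in TensorBoard:
#   tensorboard --logdir ./logs_v8_clean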