arnomatic committed
Commit 8cd0952 · verified · 1 parent: 0e27037

Upload 8 files

Files changed (8)
  1. inference.py +241 -0
  2. moe_config.py +119 -0
  3. moe_layers.py +323 -0
  4. moe_model.py +459 -0
  5. moe_trainer.py +168 -0
  6. requirements.txt +96 -0
  7. sample_generation_callback.py +148 -0
  8. train_moe_v8_clean.py +429 -0
inference.py ADDED
@@ -0,0 +1,241 @@
"""
Inference script for the trained MoE model.
Automatically loads the latest checkpoint and tests different sampling strategies.
"""

import os
import sys
import torch
from transformers import AutoTokenizer
from moe_config import MoEGPTConfig
from moe_model import MoEGPTForCausalLM

# Force UTF-8 encoding for Windows console
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')


def find_latest_checkpoint(checkpoint_dir="./moe_checkpoints_v8_clean"):
    """
    Finds the latest checkpoint automatically (v8 OPUS edition).

    Returns:
        str: Path to the latest checkpoint, or None
    """
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoints = [
        os.path.join(checkpoint_dir, d)
        for d in os.listdir(checkpoint_dir)
        if d.startswith("checkpoint-")
    ]

    if not checkpoints:
        return None

    # Find the latest checkpoint (by creation time)
    latest = max(checkpoints, key=os.path.getctime)

    # Extract the step number
    step = latest.split("checkpoint-")[-1]
    print(f"\n🔍 Latest checkpoint found: step {step}")

    return latest


def load_model(model_path=None, device="cuda"):
    """
    Loads the trained MoE model.
    If model_path is None, the latest checkpoint is loaded automatically.

    Args:
        model_path: Path to the saved model (None = auto-find)
        device: Device for inference (cuda/cpu)

    Returns:
        model: Loaded model
        config: Model config
    """
    # Auto-find the latest checkpoint
    if model_path is None:
        model_path = find_latest_checkpoint()
        if model_path is None:
            # Fallback: try the final model (v8)
            model_path = "./moe_final_v8_clean"
            if not os.path.exists(model_path):
                raise ValueError("No checkpoint found! Train a model first.")

    print(f"\n📥 Loading model from: {model_path}")

    config = MoEGPTConfig.from_pretrained(model_path)
    model = MoEGPTForCausalLM.from_pretrained(model_path)

    # Move to device
    if device == "cuda" and torch.cuda.is_available():
        model = model.cuda()
        print("✅ Model loaded on GPU")
    else:
        model = model.cpu()
        print("✅ Model loaded on CPU")

    model.eval()

    total_params = sum(p.numel() for p in model.parameters())
    print(f"   📊 Parameters: {total_params:,} ({total_params/1e6:.1f}M)")
    print(f"   🧠 Experts: {config.total_experts}")
    print(f"   ⚡ Active params: {config.active_parameters_ratio:.1%}")

    return model, config


def generate_text(
    model,
    tokenizer,
    prompt,
    max_new_tokens=400,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
    repetition_penalty=1.0,
    device="cuda",
):
    """
    Generates text with the MoE model.

    Args:
        model: MoE model
        tokenizer: Tokenizer
        prompt: Input prompt (string)
        max_new_tokens: Maximum number of new tokens
        temperature: Sampling temperature
        top_k: Top-k sampling
        top_p: Nucleus sampling
        repetition_penalty: Penalty for repetitions
        device: Device

    Returns:
        generated_text: Generated text
    """
    # Tokenize the prompt
    input_ids = tokenizer.encode(prompt, return_tensors="pt")

    if device == "cuda":
        input_ids = input_ids.cuda()

    # Generate
    with torch.no_grad():
        output_ids = model.generate(
            input_ids,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    return generated_text


def test_sampling_strategies(model, tokenizer, prompts, device="cuda"):
    """
    Tests different sampling strategies.

    Args:
        model: MoE model
        tokenizer: Tokenizer
        prompts: List of test prompts
        device: Device
    """
    # Strategies that worked well in extensive tests
    strategies = {
        "Standard (temp=0.7, rep=1.2, top_k=50, top_p=0.7)": {
            "temperature": 0.7,
            "top_k": 50,
            "top_p": 0.7,
            "repetition_penalty": 1.2,
        },
        "Focused (temp=0.7, rep=1.4, top_k=20, top_p=0.7)": {
            "temperature": 0.7,
            "top_k": 20,
            "top_p": 0.7,
            "repetition_penalty": 1.4,
        },
    }

    print("\n" + "=" * 80)
    print("🧪 TESTING SAMPLING STRATEGIES")
    print("=" * 80)

    for prompt in prompts:
        print(f"\n{'='*80}")
        print(f"PROMPT: '{prompt}'")
        print(f"{'='*80}\n")

        for strategy_name, params in strategies.items():
            print(f"\n🎯 Strategy: {strategy_name}")
            print("-" * 80)

            try:
                generated = generate_text(
                    model=model,
                    tokenizer=tokenizer,
                    prompt=prompt,
                    max_new_tokens=400,  # 400 tokens
                    device=device,
                    **params
                )

                print(f"{generated}")
                print()

            except Exception as e:
                print(f"❌ Error: {str(e)}\n")

    print("\n" + "=" * 80)
    print("💡 RECOMMENDATION")
    print("=" * 80)
    print("""
    """)


def main():
    # Device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\n🖥️ Device: {device}")

    # Load the model (latest checkpoint automatically)
    model, config = load_model(model_path=None, device=device)

    # Load the tokenizer
    print("\n📚 Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
    tokenizer.pad_token = tokenizer.eos_token
    print("✅ Llama 3.2 tokenizer loaded")
    print(f"   - Vocab size: {tokenizer.vocab_size:,}")
    print(f"   - EOS token: {tokenizer.eos_token}")

    # ==================== SAMPLING STRATEGY TESTS ====================

    # Test prompts (deliberately diverse)
    test_prompts = [
        "Gestern bin ich ",               # narrative
        "Der Mond ",                      # poetic
        "Im Labor ",                      # scientific
        "Hast du auch das Gefühl, dass",  # personal/forum style
        "Die Zeit",
        "Was ist die Definition von Philosophie?"
    ]

    # Try the different sampling strategies
    test_sampling_strategies(model, tokenizer, test_prompts, device)


if __name__ == "__main__":
    main()
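Besides running the script directly, the helpers above can also be imported. A minimal sketch (the checkpoint path is only an illustrative example, not a path that ships with this commit):

from inference import load_model, generate_text
from transformers import AutoTokenizer

# Load an explicit checkpoint instead of the auto-discovered latest one (example path)
model, config = load_model("./moe_checkpoints_v8_clean/checkpoint-10000", device="cuda")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer.pad_token = tokenizer.eos_token

# One generation with the "Standard" settings from test_sampling_strategies
text = generate_text(model, tokenizer, "Der Mond ", max_new_tokens=100,
                     temperature=0.7, top_k=50, top_p=0.7, repetition_penalty=1.2)
print(text)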
moe_config.py ADDED
@@ -0,0 +1,119 @@
"""
HuggingFace-compatible MoE configuration.
Based on the nanoMoE blog post.
"""

from transformers import PretrainedConfig


class MoEGPTConfig(PretrainedConfig):
    """
    Configuration for the MoE-based GPT model.

    Args:
        vocab_size (int): Vocabulary size
        n_positions (int): Maximum sequence length
        n_embd (int): Embedding dimensionality (d in the blog post)
        n_layer (int): Number of transformer blocks
        n_head (int): Number of attention heads
        n_experts (int): Number of experts per MoE layer
        n_experts_active (int): Number of active experts (top-k)
        moe_layer_frequency (int): Every n-th layer becomes an MoE layer (P in the blog post)
        capacity_factor (float): Expert capacity factor for training
        eval_capacity_factor (float): Expert capacity factor for evaluation
        use_noisy_gating (bool): Whether to use noisy top-k gating
        aux_loss_alpha (float): Scale for the load-balancing loss
        router_z_loss_alpha (float): Scale for the router z-loss
        bias (bool): Whether to use bias in linear layers
        dropout (float): Dropout probability
        activation_function (str): Activation function (gelu, relu, swiglu)
        initializer_range (float): Standard deviation for weight initialization
        layer_norm_epsilon (float): Epsilon for layer normalization
    """

    model_type = "moe_gpt"

    def __init__(
        self,
        vocab_size=128256,        # Llama 3.2 tokenizer (incl. special tokens)
        n_positions=2048,         # default 2048 for RoPE
        n_embd=768,
        n_layer=12,
        n_head=12,
        n_experts=8,
        n_experts_active=2,
        moe_layer_frequency=2,
        capacity_factor=1.25,
        eval_capacity_factor=2.0,
        use_noisy_gating=True,
        aux_loss_alpha=0.01,
        router_z_loss_alpha=0.001,
        bias=False,
        dropout=0.1,
        activation_function="gelu",
        initializer_range=0.1,
        layer_norm_epsilon=1e-5,
        use_cache=True,
        rope_theta=10000.0,       # RoPE base theta
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.moe_layer_frequency = moe_layer_frequency
        self.capacity_factor = capacity_factor
        self.eval_capacity_factor = eval_capacity_factor
        self.use_noisy_gating = use_noisy_gating
        self.aux_loss_alpha = aux_loss_alpha
        self.router_z_loss_alpha = router_z_loss_alpha
        self.bias = bias
        self.dropout = dropout
        self.activation_function = activation_function
        self.initializer_range = initializer_range
        self.layer_norm_epsilon = layer_norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta

        # Standard HuggingFace attributes (needed for .generate())
        self.num_hidden_layers = n_layer
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.max_position_embeddings = n_positions

        # Validation
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        assert n_experts_active <= n_experts, "n_experts_active must not exceed n_experts"
        assert moe_layer_frequency >= 1, "moe_layer_frequency must be at least 1"

    @property
    def head_dim(self):
        """Dimension per attention head"""
        return self.n_embd // self.n_head

    @property
    def total_experts(self):
        """Total number of experts in the model"""
        num_moe_layers = sum(1 for i in range(self.n_layer) if i % self.moe_layer_frequency == 0)
        return num_moe_layers * self.n_experts

    @property
    def active_parameters_ratio(self):
        """Approximate ratio of active parameters"""
        num_moe_layers = sum(1 for i in range(self.n_layer) if i % self.moe_layer_frequency == 0)
        num_dense_layers = self.n_layer - num_moe_layers

        # Rough estimate (ignores attention)
        dense_params = num_dense_layers * (8 * self.n_embd**2)  # FFN params
        moe_total_params = num_moe_layers * self.n_experts * (8 * self.n_embd**2)
        moe_active_params = num_moe_layers * self.n_experts_active * (8 * self.n_embd**2)

        total = dense_params + moe_total_params
        active = dense_params + moe_active_params

        return active / total if total > 0 else 1.0
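The total_experts and active_parameters_ratio properties above feed the statistics printed by inference.py. A minimal sketch of what the default constructor values work out to (the numbers follow directly from the formulas above):

from moe_config import MoEGPTConfig

cfg = MoEGPTConfig()          # defaults: 12 layers, MoE every 2nd layer, 8 experts, top-2
print(cfg.total_experts)      # 6 MoE layers * 8 experts = 48
print(cfg.head_dim)           # 768 // 12 = 64
print(f"{cfg.active_parameters_ratio:.1%}")  # FFN-only estimate: (6 + 6*2) / (6 + 6*8) = 33.3%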
moe_layers.py ADDED
@@ -0,0 +1,323 @@
"""
MoE layer components.
Based on the nanoMoE blog post and HuggingFace best practices.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class MoERouter(nn.Module):
    """
    Noisy top-k router for MoE.
    Routes tokens to the top-k experts based on learned probabilities.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        n_experts_active: int,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
    ):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.n_experts_active = n_experts_active
        self.use_noisy_gating = use_noisy_gating
        self.capacity_factor = capacity_factor

        # Linear projections for the router (no bias, see Shazeer et al. 2017)
        self.w_gate = nn.Linear(d_model, n_experts, bias=False)
        self.w_noise = nn.Linear(d_model, n_experts, bias=False) if use_noisy_gating else None

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Input tensor [batch_size, seq_len, d_model]

        Returns:
            expert_weights: Weights for each expert [batch_size * seq_len, n_experts, capacity]
            expert_mask: Mask of the used experts [batch_size * seq_len, n_experts, capacity]
            expert_batches: Batches for each expert [n_experts, capacity, d_model]
            router_logits: Router logits for the z-loss [batch_size, seq_len, n_experts]
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # The router ALWAYS runs in FP32 for numerical stability!
        device_type = "cuda" if x.is_cuda else "cpu"
        with torch.amp.autocast(device_type=device_type, enabled=False):
            x_fp32 = x.float()

            # Compute the router logits
            router_logits = self.w_gate(x_fp32)  # [B, T, n_experts]

            # Noisy top-k gating (optional)
            if self.use_noisy_gating and self.training:
                noise = F.softplus(self.w_noise(x_fp32))
                noise = noise * torch.randn_like(noise)
                router_logits = router_logits + noise

            # Select the top-k experts
            top_k_logits, top_k_indices = router_logits.topk(
                self.n_experts_active, dim=-1
            )  # [B, T, K]

            # Softmax over all experts (not only the top-k)
            router_probs = torch.full_like(router_logits, float("-inf"))
            router_probs.scatter_(-1, top_k_indices, top_k_logits)
            router_probs = F.softmax(router_probs, dim=-1)  # [B, T, n_experts]

            # Compute the expert capacity
            capacity = self._compute_capacity(num_tokens)

            # Multi-hot mask of the selected experts
            expert_mask = F.one_hot(
                top_k_indices, num_classes=self.n_experts
            )  # [B, T, K, n_experts]
            expert_mask = expert_mask.view(num_tokens, self.n_experts_active, self.n_experts)
            expert_mask = expert_mask.permute(1, 0, 2)  # [K, num_tokens, n_experts]

            # Position of each token within the expert batch (cumsum prioritizes top-1 first)
            expert_rank = expert_mask.reshape(
                self.n_experts_active * num_tokens, self.n_experts
            )
            expert_rank = torch.cumsum(expert_rank, dim=0) - 1
            expert_rank = expert_rank.reshape(
                self.n_experts_active, num_tokens, self.n_experts
            )

            # Mask out tokens beyond the capacity
            expert_mask = expert_mask * torch.lt(expert_rank, capacity)

            # Position within the expert batch
            expert_rank = torch.sum(expert_mask * expert_rank, dim=-1)  # [K, num_tokens]

            # Multiply the probabilities by the mask
            router_probs = router_probs.view(num_tokens, self.n_experts)[
                None, :
            ]  # [1, num_tokens, n_experts]
            expert_weights = expert_mask * router_probs  # [K, num_tokens, n_experts]

            # One-hot for the position in the expert batch
            expert_rank_one_hot = F.one_hot(
                expert_rank, num_classes=capacity
            )  # [K, num_tokens, capacity]

            # Weights at the expert batch positions
            expert_weights = torch.sum(
                expert_weights.unsqueeze(3) * expert_rank_one_hot.unsqueeze(2), dim=0
            )  # [num_tokens, n_experts, capacity]
            expert_mask = expert_weights.bool()

            # Build the expert batches
            x_flat = x.view(num_tokens, d_model)
            expert_batches = (
                expert_mask.permute(1, 2, 0).type_as(x) @ x_flat
            )  # [n_experts, capacity, d_model]

        return expert_weights, expert_mask, expert_batches, router_logits

    def _compute_capacity(self, num_tokens: int) -> int:
        """Computes the expert capacity"""
        capacity = math.floor(
            self.n_experts_active * self.capacity_factor * num_tokens / self.n_experts
        )
        capacity += capacity % 2  # even number for better hardware utilization
        return max(int(capacity), 2)  # minimum of 2 for small batches


class ExpertMLP(nn.Module):
    """
    A batch of MLP experts.
    All experts share the same architecture but have independent weights.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super().__init__()

        self.d_model = d_model
        self.n_experts = n_experts
        self.bias = bias

        # 4x hidden dimension (standard for GPT)
        hidden_dim = 4 * d_model

        # Weights for all experts (batched matmul)
        self.w_fc = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
        self.w_proj = nn.Parameter(torch.empty(n_experts, hidden_dim, d_model))

        if bias:
            self.fc_bias = nn.Parameter(torch.empty(n_experts, 1, hidden_dim))
            self.proj_bias = nn.Parameter(torch.empty(n_experts, 1, d_model))
        else:
            self.register_parameter("fc_bias", None)
            self.register_parameter("proj_bias", None)

        # Activation function
        if activation == "gelu":
            self.activation = nn.GELU()
        elif activation == "relu":
            self.activation = nn.ReLU()
        elif activation == "swiglu":
            # SwiGLU needs extra weights
            self.w_gate = nn.Parameter(torch.empty(n_experts, d_model, hidden_dim))
            self.activation = nn.SiLU()
        else:
            raise ValueError(f"Unknown activation: {activation}")

        self.dropout = nn.Dropout(dropout)
        self.activation_type = activation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [n_experts, capacity, d_model]

        Returns:
            output: [n_experts, capacity, d_model]
        """
        # First linear layer via batched matmul
        h = torch.bmm(x, self.w_fc)
        if self.bias:
            h = h + self.fc_bias

        # Activation
        if self.activation_type == "swiglu":
            # SwiGLU: silu(x @ W_gate) * (x @ W_fc)
            gate = torch.bmm(x, self.w_gate)
            h = self.activation(gate) * h
        else:
            h = self.activation(h)

        # Second linear layer
        output = torch.bmm(h, self.w_proj)
        if self.bias:
            output = output + self.proj_bias

        output = self.dropout(output)

        return output


class MoELayer(nn.Module):
    """
    Full mixture-of-experts layer.
    Combines the router and the experts.
    """

    def __init__(
        self,
        d_model: int,
        n_experts: int = 8,
        n_experts_active: int = 2,
        use_noisy_gating: bool = True,
        capacity_factor: float = 1.25,
        bias: bool = False,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super().__init__()

        self.router = MoERouter(
            d_model=d_model,
            n_experts=n_experts,
            n_experts_active=n_experts_active,
            use_noisy_gating=use_noisy_gating,
            capacity_factor=capacity_factor,
        )

        self.experts = ExpertMLP(
            d_model=d_model,
            n_experts=n_experts,
            bias=bias,
            dropout=dropout,
            activation=activation,
        )

        self.n_experts = n_experts
        self.n_experts_active = n_experts_active

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]

        Returns:
            output: [batch_size, seq_len, d_model]
            load_balance_loss: Scalar load-balancing loss
            router_z_loss: Scalar router z-loss
        """
        batch_size, seq_len, d_model = x.shape
        num_tokens = batch_size * seq_len

        # Routing
        expert_weights, expert_mask, expert_batches, router_logits = self.router(x)

        # Expert forward pass
        expert_outputs = self.experts(expert_batches)  # [n_experts, capacity, d_model]

        # Combine the outputs (weighted average)
        expert_weights_flat = expert_weights.view(num_tokens, -1)  # [num_tokens, n_experts * capacity]
        expert_outputs_flat = expert_outputs.view(-1, d_model)     # [n_experts * capacity, d_model]
        output = expert_weights_flat @ expert_outputs_flat         # [num_tokens, d_model]
        output = output.view(batch_size, seq_len, d_model)

        # Compute the auxiliary losses
        load_balance_loss = self._compute_load_balance_loss(router_logits, expert_mask)
        router_z_loss = self._compute_router_z_loss(router_logits)

        return output, load_balance_loss, router_z_loss

    def _compute_load_balance_loss(
        self, router_logits: torch.Tensor, expert_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Load-balancing loss (Switch Transformer, Fedus et al. 2022).
        Encourages a uniform distribution of tokens across experts.
        """
        batch_size, seq_len, n_experts = router_logits.shape
        num_tokens = batch_size * seq_len

        # Probability per expert
        router_probs = F.softmax(router_logits, dim=-1)       # [B, T, n_experts]
        prob_per_expert = torch.mean(router_probs, dim=(0, 1))  # [n_experts]

        # Token ratio per expert
        with torch.no_grad():
            # expert_mask is [num_tokens, n_experts, capacity]
            tokens_per_expert = torch.sum(expert_mask.float(), dim=(0, 2))  # [n_experts]
            tokens_per_expert = tokens_per_expert / (num_tokens * self.n_experts_active)

        # Dot product (scaled by n_experts)
        loss = self.n_experts * torch.sum(prob_per_expert * tokens_per_expert)

        return loss

    def _compute_router_z_loss(self, router_logits: torch.Tensor) -> torch.Tensor:
        """
        Router z-loss (ST-MoE, Zoph et al. 2022).
        Penalizes large router logits for numerical stability.
        """
        # Squared logsumexp over the experts
        z_loss = torch.logsumexp(router_logits, dim=-1) ** 2.0  # [B, T]
        z_loss = torch.mean(z_loss)

        return z_loss
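For reference, a quick sanity check of the _compute_capacity formula above with the default settings (a sketch; the batch shape is only an illustrative assumption):

import math

# Assumed batch: batch_size=8, seq_len=2048 -> 16384 tokens, top-2 routing over 8 experts, factor 1.25
num_tokens, k, n_experts, factor = 16384, 2, 8, 1.25
capacity = math.floor(k * factor * num_tokens / n_experts)  # 5120
capacity += capacity % 2                                     # stays 5120 (already even)
print(max(int(capacity), 2))                                 # 5120 token slots per expert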
moe_model.py ADDED
@@ -0,0 +1,459 @@
"""
MoE GPT model - HuggingFace compatible.
Based on nanoMoE and the blog post.
"""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, Union
from dataclasses import dataclass

from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast

from moe_config import MoEGPTConfig
from moe_layers import MoELayer


@dataclass
class MoECausalLMOutput(CausalLMOutputWithPast):
    """
    Extended output class with the MoE-specific losses
    """

    aux_loss: Optional[torch.FloatTensor] = None
    router_z_loss: Optional[torch.FloatTensor] = None


def apply_rotary_emb(x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor) -> torch.Tensor:
    """
    Applies Rotary Position Embeddings (RoPE) to the input tensor.

    Args:
        x: Input tensor of shape [B, H, T, D]
        freqs_cos: Cosine frequencies of shape [T, D//2]
        freqs_sin: Sine frequencies of shape [T, D//2]

    Returns:
        Tensor with RoPE applied
    """
    # Reshape x to separate real and imaginary parts for the rotation
    # x: [B, H, T, D] -> [B, H, T, D//2, 2]
    x_complex = x.float().reshape(*x.shape[:-1], -1, 2)

    # Apply the rotation: (a + bi) * (cos + i*sin) = (a*cos - b*sin) + i(a*sin + b*cos)
    x_rot_real = x_complex[..., 0] * freqs_cos - x_complex[..., 1] * freqs_sin
    x_rot_imag = x_complex[..., 0] * freqs_sin + x_complex[..., 1] * freqs_cos

    # Stack back together and flatten
    x_out = torch.stack([x_rot_real, x_rot_imag], dim=-1)
    x_out = x_out.flatten(-2)

    return x_out.type_as(x)


def precompute_freqs_rope(dim: int, max_seq_len: int, theta: float = 10000.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Precomputes the RoPE frequencies.

    Args:
        dim: Head dimension
        max_seq_len: Maximum sequence length
        theta: RoPE theta parameter (base for the frequency calculation)

    Returns:
        Tuple of (freqs_cos, freqs_sin) tensors of shape [max_seq_len, dim//2]
    """
    # Compute the frequencies for each dimension pair
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))

    # Create the position indices
    t = torch.arange(max_seq_len, dtype=torch.float32)

    # Compute the outer product: [max_seq_len, dim//2]
    freqs = torch.outer(t, freqs)

    # Compute cos and sin
    freqs_cos = torch.cos(freqs)
    freqs_sin = torch.sin(freqs)

    return freqs_cos, freqs_sin


class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention with Rotary Position Embeddings (RoPE).
    Uses PyTorch SDPA for optimized performance.
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # Key, query, value for all heads at once
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # Output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)

        # Regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.head_dim = config.n_embd // config.n_head

        # Precompute the RoPE frequencies
        freqs_cos, freqs_sin = precompute_freqs_rope(
            dim=self.head_dim,
            max_seq_len=config.n_positions,
            theta=config.rope_theta
        )
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch, sequence length, embedding dim

        # Compute Q, K, V
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        # Reshape for multi-head attention
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # [B, H, T, d]
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        # Apply RoPE to Q and K
        q = apply_rotary_emb(q, self.freqs_cos[:T], self.freqs_sin[:T])
        k = apply_rotary_emb(k, self.freqs_cos[:T], self.freqs_sin[:T])

        # Use PyTorch SDPA (scaled dot product attention) - optimized!
        # SDPA handles causal masking and dropout, and is memory efficient
        y = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,  # causal mask handled by is_causal
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=True   # efficient causal masking
        )  # [B, H, T, d]

        # Reshape back
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        y = self.resid_dropout(self.c_proj(y))

        return y


class MLP(nn.Module):
    """
    Standard feed-forward network (for the non-MoE layers)
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

        if config.activation_function == "gelu":
            self.activation = nn.GELU()
        elif config.activation_function == "relu":
            self.activation = nn.ReLU()
        else:
            raise ValueError(f"Unknown activation: {config.activation_function}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = self.activation(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """
    Standard transformer block (attention + MLP)
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class MoETransformerBlock(nn.Module):
    """
    MoE transformer block (attention + MoE layer)
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # The capacity factor depends on training vs. evaluation
        self.moe = MoELayer(
            d_model=config.n_embd,
            n_experts=config.n_experts,
            n_experts_active=config.n_experts_active,
            use_noisy_gating=config.use_noisy_gating,
            capacity_factor=config.capacity_factor,
            bias=config.bias,
            dropout=config.dropout,
            activation=config.activation_function,
        )

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Attention
        x = x + self.attn(self.ln_1(x))

        # MoE layer
        moe_out, aux_loss, router_z_loss = self.moe(self.ln_2(x))
        x = x + moe_out

        return x, aux_loss, router_z_loss


class MoEGPTPreTrainedModel(PreTrainedModel):
    """
    Base class for MoE GPT built on the HuggingFace PreTrainedModel
    """

    config_class = MoEGPTConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Weight initialization following ST-MoE (Zoph et al. 2022):
        truncated normal with a reduced std for MoE stability.
        """
        if isinstance(module, nn.Linear):
            # Fan-in initialization
            fan_in = module.weight.shape[-1]
            std = (self.config.initializer_range / fan_in) ** 0.5

            torch.nn.init.trunc_normal_(
                module.weight,
                mean=0.0,
                std=std,
                a=-2 * std,
                b=2 * std,
            )
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

        elif isinstance(module, nn.Parameter):
            # For the expert parameters
            fan_in = module.shape[-1] if len(module.shape) >= 2 else module.shape[0]
            std = (self.config.initializer_range / fan_in) ** 0.5

            torch.nn.init.trunc_normal_(
                module,
                mean=0.0,
                std=std,
                a=-2 * std,
                b=2 * std,
            )


class MoEGPTModel(MoEGPTPreTrainedModel):
    """
    MoE GPT model (without the LM head)
    """

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.config = config
        self.gradient_checkpointing = False  # for HF gradient checkpointing support

        # Token embeddings only (RoPE handles positions)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)

        # Transformer blocks (mixed: standard + MoE)
        self.h = nn.ModuleList()
        for i in range(config.n_layer):
            if i % config.moe_layer_frequency == 0:
                # MoE block
                self.h.append(MoETransformerBlock(config))
            else:
                # Standard block
                self.h.append(TransformerBlock(config))

        # Final layer norm
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        # Initialize weights
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        device = input_ids.device
        b, t = input_ids.size()

        assert t <= self.config.n_positions, f"Sequence too long: {t} > {self.config.n_positions}"

        # Token embeddings only (RoPE is applied in the attention layers)
        tok_emb = self.wte(input_ids)  # [B, T, n_embd]
        x = self.drop(tok_emb)

        # Collect the auxiliary losses
        total_aux_loss = 0.0
        total_router_z_loss = 0.0

        # Run through all blocks
        for block in self.h:
            if isinstance(block, MoETransformerBlock):
                if self.gradient_checkpointing and self.training:
                    # Gradient checkpointing for MoE blocks
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)
                        return custom_forward

                    x, aux_loss, router_z_loss = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(block),
                        x,
                        use_reentrant=False
                    )
                else:
                    x, aux_loss, router_z_loss = block(x)
                total_aux_loss = total_aux_loss + aux_loss
                total_router_z_loss = total_router_z_loss + router_z_loss
            else:
                if self.gradient_checkpointing and self.training:
                    x = torch.utils.checkpoint.checkpoint(
                        block,
                        x,
                        use_reentrant=False
                    )
                else:
                    x = block(x)

        x = self.ln_f(x)

        return x, total_aux_loss, total_router_z_loss


class MoEGPTForCausalLM(MoEGPTPreTrainedModel, GenerationMixin):
    """
    MoE GPT with a language modeling head (for pretraining).
    Inherits from GenerationMixin for .generate() support.
    """

    # Tell HuggingFace which weights are shared
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config: MoEGPTConfig):
        super().__init__(config)
        self.transformer = MoEGPTModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Weight tying (the LM head shares weights with the token embedding)
        self.lm_head.weight = self.transformer.wte.weight

        # Initialize weights
        self.post_init()

    def get_output_embeddings(self):
        """For HuggingFace weight tying"""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        """For HuggingFace weight tying"""
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        """For HuggingFace weight tying"""
        return self.transformer.wte

    def set_input_embeddings(self, new_embeddings):
        """For HuggingFace weight tying"""
        self.transformer.wte = new_embeddings

    def tie_weights(self):
        """
        Tie the lm_head weights to the input embeddings (weight tying).
        Called after loading a checkpoint to fix a missing lm_head.weight.
        """
        self.lm_head.weight = self.transformer.wte.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, MoECausalLMOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Forward through the transformer
        hidden_states, aux_loss, router_z_loss = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # LM head
        if labels is not None:
            # Training: logits for every position
            logits = self.lm_head(hidden_states)
        else:
            # Inference: only the last position
            logits = self.lm_head(hidden_states[:, [-1], :])

        # Compute the loss
        loss = None
        if labels is not None:
            # Shift for next-token prediction
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Cross-entropy loss
            loss_fct = nn.CrossEntropyLoss()
            lm_loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

            # Add the auxiliary losses
            loss = lm_loss
            if self.training:
                loss = loss + self.config.aux_loss_alpha * aux_loss
                loss = loss + self.config.router_z_loss_alpha * router_z_loss

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MoECausalLMOutput(
            loss=loss,
            logits=logits,
            aux_loss=aux_loss if self.training else None,
            router_z_loss=router_z_loss if self.training else None,
        )

    def prepare_inputs_for_generation(self, input_ids, **kwargs):
        """For the HuggingFace generate() function"""
        return {"input_ids": input_ids}
moe_trainer.py ADDED
@@ -0,0 +1,168 @@
"""
Custom MoE trainer with extended logging
"""

import torch
from typing import Dict, Optional, Any
from transformers import Trainer
from transformers.trainer_callback import TrainerCallback


class MoETrainer(Trainer):
    """
    Extended trainer for MoE models with dedicated logging for:
    - auxiliary losses (load balancing, router z-loss)
    - expert utilization
    - capacity factor adjustment
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """
        Overrides compute_loss to account for the MoE-specific losses.
        They are already included in model.forward(), but we log them separately.
        """
        # Labels for next-token prediction
        if "labels" not in inputs:
            inputs["labels"] = inputs["input_ids"].clone()

        # Forward pass
        outputs = model(**inputs)

        # The loss is already the total loss (LM + aux losses)
        loss = outputs.loss

        # Log the auxiliary losses (during training)
        if self.state.global_step % self.args.logging_steps == 0:
            if hasattr(outputs, "aux_loss") and outputs.aux_loss is not None:
                self.log({"train/aux_loss": outputs.aux_loss.item()})

            if hasattr(outputs, "router_z_loss") and outputs.router_z_loss is not None:
                self.log({"train/router_z_loss": outputs.router_z_loss.item()})

            # Breakdown of the total loss
            if hasattr(outputs, "aux_loss") and outputs.aux_loss is not None:
                lm_loss = (
                    loss.item()
                    - self.model.config.aux_loss_alpha * outputs.aux_loss.item()
                    - self.model.config.router_z_loss_alpha * outputs.router_z_loss.item()
                )
                self.log({"train/lm_loss": lm_loss})

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """
        Overrides prediction_step so that eval_loss is returned correctly
        """
        # Ensure labels are present
        if "labels" not in inputs:
            inputs["labels"] = inputs["input_ids"].clone()

        # Call the standard prediction_step
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys
        )

        return loss, logits, labels

    def log(self, logs: Dict[str, float], start_time=None) -> None:
        """
        Extends the standard logging with MoE-specific metrics
        """
        # GPU memory tracking
        if torch.cuda.is_available():
            logs["gpu_memory_allocated_gb"] = (
                torch.cuda.memory_allocated() / 1024**3
            )
            logs["gpu_memory_reserved_gb"] = (
                torch.cuda.memory_reserved() / 1024**3
            )

        if start_time is not None:
            super().log(logs, start_time)
        else:
            super().log(logs)


class MoEEvalCallback(TrainerCallback):
    """
    Callback for extended MoE-specific evaluation
    """

    def on_evaluate(self, args, state, control, model, metrics=None, **kwargs):
        """
        After each evaluation, log additional MoE metrics
        """
        if metrics is not None and model is not None:
            # Model statistics
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(
                p.numel() for p in model.parameters() if p.requires_grad
            )

            metrics["model/total_params_M"] = total_params / 1e6
            metrics["model/trainable_params_M"] = trainable_params / 1e6

            # MoE specific
            if hasattr(model.config, "n_experts"):
                metrics["model/total_experts"] = model.config.total_experts
                metrics["model/active_params_ratio"] = (
                    model.config.active_parameters_ratio
                )


class DataCollatorForLanguageModeling:
    """
    Simple data collator for causal language modeling.
    Assumes the data is already tokenized.
    """

    def __init__(self, pad_token_id: int = 0):
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        """
        Args:
            examples: List of dicts with 'input_ids' and 'attention_mask'

        Returns:
            Batch dict with padded input_ids and attention_mask
        """
        # Maximum length in this batch
        max_length = max(len(ex["input_ids"]) for ex in examples)

        input_ids = []
        attention_mask = []

        for ex in examples:
            seq_len = len(ex["input_ids"])
            padding_length = max_length - seq_len

            # Right padding
            padded_input_ids = ex["input_ids"] + [self.pad_token_id] * padding_length
            padded_attention_mask = ex["attention_mask"] + [0] * padding_length

            input_ids.append(padded_input_ids)
            attention_mask.append(padded_attention_mask)

        # As tensors
        batch = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

        return batch


def compute_metrics(eval_preds):
    """
    Compute perplexity for evaluation
    """
    predictions, labels = eval_preds

    # For language modeling, the predictions are the logits
    # and the labels are the actual token IDs.
    # We only compute perplexity here (the loss is logged automatically).

    # This function is optional - the loss is already computed by the Trainer
    return {}
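A small usage sketch of the collator above, with hypothetical toy sequences (in the real pipeline pad_token_id would be the tokenizer's EOS id):

from moe_trainer import DataCollatorForLanguageModeling

collator = DataCollatorForLanguageModeling(pad_token_id=0)
batch = collator([
    {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]},
    {"input_ids": [8, 9], "attention_mask": [1, 1]},
])
print(batch["input_ids"].shape)    # torch.Size([2, 3]) - padded to the longest sequence
print(batch["attention_mask"][1])  # tensor([1, 1, 0]) - padding position masked out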
requirements.txt ADDED
@@ -0,0 +1,96 @@
# German MoE GPT v6 - Requirements
# Environment: nano_moe (Conda)
# Python: 3.10+
# CUDA: 12.4

# ============================================================================
# CRITICAL: PyTorch Installation
# ============================================================================
# IMPORTANT: Install PyTorch FIRST with CUDA support!
# DO NOT use pip for PyTorch on Windows - use conda instead:
#
#   conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
#
# Or from the PyTorch website (pip with CUDA):
#   pip install torch==2.6.0+cu124 torchvision==0.21.0+cu124 torchaudio==2.6.0+cu124 --index-url https://download.pytorch.org/whl/cu124
#
# Currently installed versions:
#   torch==2.6.0+cu124
#   torchvision==0.21.0+cu124
#   torchaudio==2.6.0+cu124
# ============================================================================

# Core ML libraries (install AFTER PyTorch!)
transformers==4.56.1
datasets==4.0.0
accelerate==1.10.1

# Training & monitoring
tensorboard==2.20.0
tensorboard-data-server==0.7.2

# Tokenization
tokenizers==0.22.0
tiktoken==0.11.0

# Data processing
numpy==1.26.4
pandas==2.3.2
pyarrow==21.0.0

# Utilities
tqdm==4.67.1
safetensors==0.6.2
huggingface-hub==0.34.4
regex==2025.9.1
fsspec==2025.3.0
dill==0.3.8
multiprocess==0.70.16
xxhash==3.5.0

# Performance (Windows CUDA)
triton-windows==3.2.0.post19  # Optimized kernels for CUDA

# Configuration & logging
PyYAML==6.0.2
python-dotenv==1.0.1
requests==2.32.5
httpx[http2]==0.27.0

# Optional: Weights & Biases (uncomment if needed)
# wandb>=0.15.0

# ============================================================================
# Installation Instructions
# ============================================================================
#
# STEP 1: Create the conda environment
#   conda create -n nano_moe python=3.10
#   conda activate nano_moe
#
# STEP 2: Install PyTorch with CUDA 12.4
#   conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
#
# STEP 3: Install the remaining dependencies
#   pip install -r requirements.txt --no-deps
#   (--no-deps prevents pip from reinstalling PyTorch!)
#
# STEP 4: Verify the installation
#   python -c "import torch; print(f'CUDA available: {torch.cuda.is_available()}')"
#
# ============================================================================
# Notes
# ============================================================================
#
# - DO NOT install PyTorch via pip requirements.txt on Windows!
#   It will install the CPU version or the wrong CUDA version.
#
# - triton-windows only works on Windows with CUDA.
#   On Linux, use: triton>=2.0.0
#
# - datasets 4.0.0 has breaking changes from 2.x.
#   Use load_from_disk() / save_to_disk() for the eval dataset.
#
# - transformers 4.56.1 is compatible with our custom MoE implementation.
#
# ============================================================================
sample_generation_callback.py ADDED
@@ -0,0 +1,148 @@
"""
Sample generation callback for MoE training.
Generates texts during training to monitor progress.
"""

import torch
from transformers import TrainerCallback, AutoTokenizer
from typing import Optional
import os


class SampleGenerationCallback(TrainerCallback):
    """
    Generates sample texts every N steps during training
    """

    def __init__(
        self,
        tokenizer,
        prompts: list[str],
        generate_every_n_steps: int = 100,
        max_new_tokens: int = 50,
        temperature: float = 0.8,
        top_k: int = 50,
        top_p: float = 0.95,
        output_dir: str = "./samples",
    ):
        """
        Args:
            tokenizer: HuggingFace tokenizer
            prompts: List of prompts for generation
            generate_every_n_steps: Generate every N steps
            max_new_tokens: Maximum number of new tokens
            temperature: Sampling temperature
            top_k: Top-k sampling
            top_p: Nucleus sampling
            output_dir: Directory for the sample outputs
        """
        self.tokenizer = tokenizer
        self.prompts = prompts
        self.generate_every_n_steps = generate_every_n_steps
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_k = top_k
        self.top_p = top_p
        self.output_dir = output_dir

        # Create the output directory
        os.makedirs(output_dir, exist_ok=True)

        # Samples log file
        self.log_file = os.path.join(output_dir, "generation_log.txt")

        # Write the header
        with open(self.log_file, "w", encoding="utf-8") as f:
            f.write("=" * 80 + "\n")
            f.write("MoE Training - Sample Generation Log\n")
            f.write("=" * 80 + "\n\n")

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """
        Called after every training step
        """
        # Only generate every N steps
        if state.global_step % self.generate_every_n_steps != 0:
            return

        # Skip if there is no model
        if model is None:
            return

        print(f"\n{'='*80}")
        print(f"🎨 GENERATING SAMPLES @ STEP {state.global_step}")
        print(f"{'='*80}\n")

        # Put the model into eval mode
        model.eval()

        samples = []
        samples.append(f"\n{'='*80}\n")
        samples.append(f"Step: {state.global_step}\n")
        samples.append(f"{'='*80}\n\n")

        with torch.no_grad():
            for i, prompt in enumerate(self.prompts, 1):
                print(f"[{i}/{len(self.prompts)}] Prompt: '{prompt}'")

                # Tokenize
                input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
                input_ids = input_ids.to(model.device)

                try:
                    # Generate
                    # NOTE: repetition_penalty is REQUIRED for longer generations!
                    # For 300 tokens, 1.3-1.5 works better than 1.2
                    output_ids = model.generate(
                        input_ids,
                        max_new_tokens=self.max_new_tokens,
                        temperature=self.temperature,
                        top_k=self.top_k,
                        top_p=self.top_p,
                        repetition_penalty=1.4,  # ← higher for 300 tokens!
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )

                    # Decode
                    generated_text = self.tokenizer.decode(
                        output_ids[0], skip_special_tokens=True
                    )

                    # Print
                    print(f"  → {generated_text}\n")

                    # Save to the log
                    samples.append(f"Prompt {i}: {prompt}\n")
                    samples.append(f"Output: {generated_text}\n\n")

                except Exception as e:
                    error_msg = f"  ❌ Error: {str(e)}\n"
                    print(error_msg)
                    samples.append(f"Prompt {i}: {prompt}\n")
                    samples.append(f"Error: {str(e)}\n\n")

        # Write the samples to the file
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.writelines(samples)

        print(f"{'='*80}\n")

        # Put the model back into training mode
        model.train()


def get_german_sample_prompts():
    """
    Returns a list of German sample prompts
    """
    return [
        "Die Künstliche Intelligenz",
        "Im finsteren Wald",
        "In der Zukunft werden wir",
        "Machine Learning bedeutet",
        "Das Wetter heute ist",
        "Ein wichtiger Aspekt der",
        "Die Geschichte von",
        "Wissenschaftler haben herausgefunden",
    ]
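A sketch of how the callback is meant to be wired into training (the step interval is illustrative; the trainer line is a hint only, the actual wiring lives in train_moe_v8_clean.py):

from transformers import AutoTokenizer
from sample_generation_callback import SampleGenerationCallback, get_german_sample_prompts

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
callback = SampleGenerationCallback(
    tokenizer=tokenizer,
    prompts=get_german_sample_prompts(),
    generate_every_n_steps=500,   # illustrative value
    max_new_tokens=50,
)
# later: MoETrainer(..., callbacks=[callback, MoEEvalCallback()])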
train_moe_v8_clean.py ADDED
@@ -0,0 +1,429 @@
1
+ """
2
+ German MoE GPT v8 - CLEAN DATA + OPUS EDITION
3
+ Training mit Wikipedia + OpenSubtitles + Belletristik
4
+
5
+ Datasets (v8 - CLEAN + DIALOGUES! 🎉):
6
+ - Clean Wikipedia (local) - 11 GB (64%)
7
+ - OpenSubtitles OPUS (local) - 4.2 GB (24%)
8
+ - Belletristik (arnomatic/merged_all) - 2.2 GB (12%)
9
+
10
+ Total: ~17.4 GB of 100% CLEAN German text!
11
+ NO spam, NO ads, NO SEO garbage! ✅
12
+ PLUS natural dialogues from movie subtitles! 🎬
13
+ """
14
+
15
+ import os
16
+ import sys
17
+
18
+ # Disable HF transfer (can cause issues on Windows)
19
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
20
+
21
+ # Force UTF-8 encoding for Windows console
22
+ if sys.platform == 'win32':
23
+ sys.stdout.reconfigure(encoding='utf-8')
24
+
25
+ import torch
26
+ from datasets import load_dataset, interleave_datasets
27
+ from transformers import TrainingArguments, set_seed, AutoTokenizer
28
+
29
+ from moe_config import MoEGPTConfig
30
+ from moe_model import MoEGPTForCausalLM
31
+ from moe_trainer import MoETrainer, MoEEvalCallback, DataCollatorForLanguageModeling
32
+ from sample_generation_callback import SampleGenerationCallback, get_german_sample_prompts
33
+
34
+
35
+ def load_clean_datasets(tokenizer, max_length=2048, seed=42, resume_step=0):
36
+ """
37
+ Lädt 3 clean datasets (v8 - INTERLEAVED!):
38
+ - Wikipedia (WITH EOS) - 64%
39
+ - OpenSubtitles OPUS (NO EOS) - 24%
40
+ - Belletristik (NO EOS) - 12%
41
+
42
+ Args:
43
+ resume_step: If > 0, adjusts seed to continue from checkpoint
44
+ """
45
+ # Adjust seed based on resume step (für reproducibility beim Resume)
46
+ effective_seed = seed + (resume_step // 1000)
47
+ print(f"📚 Lade CLEAN Datasets (v8 - OPUS Edition)...")
48
+ if resume_step > 0:
49
+ print(f" 🔄 Resume from step {resume_step} → Effective seed: {effective_seed}\n")
50
+ else:
51
+ print()
52
+
53
+ # ========================================================================
54
+ # 1. WIKIPEDIA (WITH EOS between articles)
55
+ # ========================================================================
56
+ print("1️⃣ Wikipedia (WITH EOS)...")
57
+ try:
58
+ wiki_ds = load_dataset(
59
+ "jonas-is-coding/german-wikipedia-articles",
60
+ split="train",
61
+ streaming=True
62
+ )
63
+ print(" ✅ Dataset loaded (streaming mode)")
64
+
65
+ # Shuffle
66
+ print(" 🔀 Shuffling with buffer_size=10,000...")
67
+ wiki_ds = wiki_ds.shuffle(seed=effective_seed, buffer_size=10000)
68
+ print(" ✅ Shuffle applied")
69
+
70
+ except Exception as e:
71
+ print(f" ❌ Wikipedia Error: {e}")
72
+ raise ValueError(f"Failed to load Wikipedia: {e}")
73
+
74
+ # ========================================================================
75
+ # 2. OPENSUBTITLES OPUS (NO EOS - continuous dialogues)
76
+ # ========================================================================
77
+ print("\n2️⃣ OpenSubtitles OPUS (NO EOS - continuous dialogues)...")
78
+ try:
79
+ opus_ds = load_dataset(
80
+ "arnomatic/german-opus-subtitles",
81
+ split="train",
82
+ streaming=True
83
+ )
84
+ print(" ✅ Dataset loaded (streaming mode)")
85
+
86
+ # Shuffle
87
+ print(" 🔀 Shuffling with buffer_size=10,000...")
88
+ opus_ds = opus_ds.shuffle(seed=effective_seed, buffer_size=10000)
89
+ print(" ✅ Shuffle applied")
90
+
91
+ except Exception as e:
92
+ print(f" ❌ OpenSubtitles Error: {e}")
93
+ raise ValueError(f"Failed to load OpenSubtitles: {e}")
94
+
95
+ # ========================================================================
96
+ # 3. BELLETRISTIK (NO EOS - continuous)
97
+ # ========================================================================
98
+ print("\n3️⃣ Belletristik (NO EOS - continuous)...")
99
+ try:
100
+ belle_ds = load_dataset(
101
+ "arnomatic/merged_all",
102
+ split="train",
103
+ streaming=True
104
+ )
105
+ print(" ✅ Dataset loaded (streaming mode)")
106
+
107
+ # Shuffle
108
+ print(" 🔀 Shuffling with buffer_size=10,000...")
109
+ belle_ds = belle_ds.shuffle(seed=effective_seed, buffer_size=10000)
110
+ print(" ✅ Shuffle applied")
111
+
112
+ except Exception as e:
113
+ print(f" ❌ Belletristik Error: {e}")
114
+ raise ValueError(f"Failed to load Belletristik: {e}")
115
+
116
+ print("\n✅ All datasets loaded!")
117
+ print(" Wikipedia: 4 GB (WITH EOS)")
118
+ print(" OpenSubtitles: 4.2 GB (NO EOS)")
119
+ print(" Belletristik: 2.2 GB (NO EOS)")
120
+ print(" Total: ~10.4 GB clean German!")
121
+
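Note that the three sources expose different column names — the packing calls below use text_field='content' for Wikipedia but 'text' for the other two. A quick way to confirm the field names before a long run (a sketch, assuming the streams above loaded successfully):

# Peek at one example per stream to verify the text column name
for name, ds in [("wiki", wiki_ds), ("opus", opus_ds), ("belle", belle_ds)]:
    example = next(iter(ds))
    print(name, list(example.keys()))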
122
+ # ========================================================================
123
+ # DIRECT PACKING (no intermediate tokenization)
124
+ # ========================================================================
125
+ print("\n🔤 Tokenizing & Packing datasets...")
126
+
127
+ from datasets import IterableDataset as HFIterableDataset
128
+
129
+ def pack_dataset_with_eos(dataset, text_field='text'):
130
+         """Pack dataset, appending EOS after each document, into fixed max_length-token sequences"""
131
+ def gen():
132
+ buffer = []
133
+ for example in dataset:
134
+ text = example.get(text_field, '')
135
+ if not text or not text.strip():
136
+ continue
137
+
138
+ # Tokenize
139
+ tokens = tokenizer.encode(text, add_special_tokens=False)
140
+
141
+ # Add tokens + EOS
142
+ buffer.extend(tokens)
143
+ buffer.append(tokenizer.eos_token_id)
144
+
145
+ # Yield complete chunks
146
+ while len(buffer) >= max_length:
147
+ yield {
148
+ "input_ids": buffer[:max_length],
149
+ "attention_mask": [1] * max_length,
150
+ "labels": buffer[:max_length],
151
+ }
152
+ buffer = buffer[max_length:]
153
+
154
+ return HFIterableDataset.from_generator(gen)
155
+
156
+ def pack_dataset_no_eos(dataset, text_field='text'):
157
+         """Pack dataset without EOS markers into fixed max_length-token sequences"""
158
+ def gen():
159
+ buffer = []
160
+ for example in dataset:
161
+ text = example.get(text_field, '')
162
+ if not text or not text.strip():
163
+ continue
164
+
165
+ # Tokenize
166
+ tokens = tokenizer.encode(text, add_special_tokens=False)
167
+
168
+ # Add tokens (NO EOS)
169
+ buffer.extend(tokens)
170
+
171
+ # Yield complete chunks
172
+ while len(buffer) >= max_length:
173
+ yield {
174
+ "input_ids": buffer[:max_length],
175
+ "attention_mask": [1] * max_length,
176
+ "labels": buffer[:max_length],
177
+ }
178
+ buffer = buffer[max_length:]
179
+
180
+ return HFIterableDataset.from_generator(gen)
181
+
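The two packers above are identical except for the EOS append; they could be collapsed into one helper. A minimal sketch (not in the original file) that would drop in at the same place, reusing the enclosing tokenizer and max_length:

def pack_dataset(dataset, text_field="text", add_eos=False):
    """Pack texts into fixed max_length-token sequences, optionally appending EOS per document."""
    def gen():
        buffer = []
        for example in dataset:
            text = example.get(text_field, "")
            if not text or not text.strip():
                continue
            buffer.extend(tokenizer.encode(text, add_special_tokens=False))
            if add_eos:
                buffer.append(tokenizer.eos_token_id)
            while len(buffer) >= max_length:
                yield {
                    "input_ids": buffer[:max_length],
                    "attention_mask": [1] * max_length,
                    "labels": buffer[:max_length],
                }
                buffer = buffer[max_length:]
    return HFIterableDataset.from_generator(gen)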
182
+ print(" Wikipedia (WITH EOS)...")
183
+ wiki_batched = pack_dataset_with_eos(wiki_ds, text_field='content')
184
+
185
+ print(" OpenSubtitles (NO EOS)...")
186
+ opus_batched = pack_dataset_no_eos(opus_ds, text_field='text')
187
+
188
+ print(" Belletristik (NO EOS)...")
189
+ belle_batched = pack_dataset_no_eos(belle_ds, text_field='text')
190
+
191
+ print("✅ Batching complete!")
192
+
193
+ # ========================================================================
194
+ # INTERLEAVE DATASETS (64% Wiki, 24% OPUS, 12% Belle)
195
+ # ========================================================================
196
+ print("\n🔀 Interleaving datasets (64/24/12)...")
197
+
198
+ train_dataset = interleave_datasets(
199
+ [wiki_batched, opus_batched, belle_batched],
200
+ probabilities=[0.64, 0.24, 0.12],
201
+ seed=effective_seed,
202
+ stopping_strategy="all_exhausted"
203
+ )
204
+
205
+ print("✅ Datasets interleaved! (v8 strategy)")
206
+ print(" Wikipedia: 64%")
207
+ print(" OpenSubtitles: 24%")
208
+ print(" Belletristik: 12%")
209
+
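The probabilities control how often each stream is sampled per packed sequence, not byte shares of the corpora. A tiny self-contained illustration of interleave_datasets (toy data, not part of the training pipeline):

from datasets import Dataset, interleave_datasets

a = Dataset.from_dict({"src": ["a"] * 1000})
b = Dataset.from_dict({"src": ["b"] * 1000})
mixed = interleave_datasets([a, b], probabilities=[0.8, 0.2], seed=0)
counts = {}
for ex in mixed.select(range(100)):
    counts[ex["src"]] = counts.get(ex["src"], 0) + 1
print(counts)  # roughly 80 'a' and 20 'b'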
210
+ # ========================================================================
211
+ # EVAL DATASET (fixed 500 samples from Wikipedia)
212
+ # ========================================================================
213
+ eval_dataset_path = "./eval_dataset_v8_clean"
214
+
215
+ if os.path.exists(eval_dataset_path):
216
+ print(f"\n📊 Loading existing eval dataset from {eval_dataset_path}...")
217
+ from datasets import load_from_disk
218
+ eval_dataset = load_from_disk(eval_dataset_path)
219
+ print(f"✅ Eval dataset loaded: {len(eval_dataset)} samples (from disk)")
220
+ else:
221
+ print("\n📊 Creating fixed eval set (500 samples from Wikipedia)...")
222
+
223
+ eval_samples = []
224
+ eval_iter = iter(wiki_batched)
225
+ for i in range(500):
226
+ try:
227
+ sample = next(eval_iter)
228
+ eval_samples.append(sample)
229
+ if (i + 1) % 100 == 0:
230
+ print(f" Collected {i+1}/500 samples...")
231
+ except StopIteration:
232
+ print(f" ⚠️ Only {i} eval samples available (dataset exhausted)")
233
+ break
234
+
235
+ if len(eval_samples) == 0:
236
+ raise ValueError("No eval samples collected! Dataset exhausted immediately.")
237
+
238
+ print(f" Collected {len(eval_samples)} samples total")
239
+
240
+ # Convert to regular Dataset (not streaming!)
241
+ from datasets import Dataset
242
+ eval_dataset = Dataset.from_dict({
243
+ key: [sample[key] for sample in eval_samples]
244
+ for key in eval_samples[0].keys()
245
+ })
246
+
247
+ # Save to disk
248
+ print(f"💾 Saving eval dataset to {eval_dataset_path}...")
249
+ eval_dataset.save_to_disk(eval_dataset_path)
250
+ print(f"✅ Eval dataset saved to disk!")
251
+
252
+ print(f" → No more fsspec cache leak!")
253
+ print(f" Training: Clean Mix (streaming)")
254
+ print(f" Eval: {len(eval_dataset)} samples (fixed, from disk)\n")
255
+
256
+ return train_dataset, eval_dataset
257
+
258
+
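Before a multi-day run it can be worth sanity-checking the packed stream. A standalone sketch (assumes access to the gated Llama tokenizer and the datasets above):

# Standalone sanity check of the packed stream (illustrative, not part of the script)
from transformers import AutoTokenizer
from train_moe_v8_clean import load_clean_datasets

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
tok.pad_token = tok.eos_token
train_ds, eval_ds = load_clean_datasets(tok, max_length=2048)
first = next(iter(train_ds))
assert len(first["input_ids"]) == 2048
assert first["input_ids"] == first["labels"]
print("packed stream OK,", len(eval_ds), "eval samples")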
259
+ def main():
260
+ SEED = 42
261
+ set_seed(SEED)
262
+
263
+ # Config
264
+ config = MoEGPTConfig(
265
+ vocab_size=128256,
266
+ n_positions=2048,
267
+ n_embd=512,
268
+ n_layer=8,
269
+ n_head=8,
270
+ n_experts=8,
271
+ n_experts_active=2,
272
+ moe_layer_frequency=2,
273
+ capacity_factor=1.25,
274
+ eval_capacity_factor=2.0,
275
+ use_noisy_gating=True,
276
+ aux_loss_alpha=0.01,
277
+ router_z_loss_alpha=0.001,
278
+ bias=False,
279
+ dropout=0.1,
280
+ activation_function="gelu",
281
+ initializer_range=0.1,
282
+ rope_theta=10000.0,
283
+ )
284
+
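With n_experts=8 and n_experts_active=2, each token is routed to 2 of 8 experts within an MoE layer, i.e. a quarter of the expert FFN weights per token (attention, embeddings and the dense layers stay always active). A quick check of that ratio:

# Fraction of expert weights a token touches in an MoE layer (simple ratio, nothing model-specific)
expert_fraction = config.n_experts_active / config.n_experts
print(f"{expert_fraction:.0%} of expert parameters active per token")  # 25%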
285
+ print("\n🔧 Model Config:")
286
+ print(f" - Experten: {config.n_experts} (Top-{config.n_experts_active})")
287
+     print(f"   - MoE-Experten gesamt: {config.total_experts}")
288
+
289
+ # Training Args
290
+     # Dataset: ~10.4 GB ≈ 2.5B tokens ≈ 1.2M sequences of 2048 tokens
291
+     # Effective batch: 2 per device × 16 grad accum = 32 sequences → ~38K steps per epoch
292
+     # ~50K steps ≈ 1.3 epochs; max_steps=200000 below ≈ 5 epochs over the interleaved mix
293
+ training_args = TrainingArguments(
294
+ output_dir="./moe_checkpoints_v8_clean",
295
+ run_name="german_moe_v8_clean",
296
+ max_steps=200000,
297
+ per_device_train_batch_size=2,
298
+ per_device_eval_batch_size=2,
299
+ gradient_accumulation_steps=16,
300
+ learning_rate=6e-4,
301
+ warmup_steps=2000,
302
+ lr_scheduler_type="cosine",
303
+ weight_decay=0.1,
304
+ bf16=torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
305
+ fp16=not torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False,
306
+ logging_dir="./logs_v8_clean",
307
+ logging_steps=100,
308
+ logging_first_step=True,
309
+ report_to=["tensorboard"],
310
+ eval_strategy="steps",
311
+ eval_steps=1000, # Every 1K steps (more frequent than v7)
312
+ save_strategy="steps",
313
+ save_steps=1000,
314
+ save_total_limit=10,
315
+ dataloader_num_workers=0,
316
+ dataloader_pin_memory=True,
317
+ gradient_checkpointing=True,
318
+ seed=SEED,
319
+ load_best_model_at_end=False,
320
+ metric_for_best_model="eval_loss",
321
+ greater_is_better=False,
322
+ ignore_data_skip=True, # CRITICAL: Don't skip batches, use fresh shuffled data!
323
+ )
324
+
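A quick back-of-the-envelope on these arguments (the 2.5B-token corpus size is the rough estimate from the comment above, not a measured count):

seq_len = 2048
effective_batch = 2 * 16                     # per_device_train_batch_size * gradient_accumulation_steps
tokens_per_step = seq_len * effective_batch  # 65,536 tokens per optimizer step
total_tokens = tokens_per_step * 200_000     # ≈ 13.1B tokens over max_steps
epochs = total_tokens / 2.5e9                # ≈ 5.2 passes over the interleaved mix
print(tokens_per_step, total_tokens, round(epochs, 1))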
325
+ # Check for existing checkpoints (auto-resume) - DO THIS EARLY!
326
+ import glob
327
+ checkpoints = glob.glob(os.path.join(training_args.output_dir, "checkpoint-*"))
328
+ resume_from_checkpoint = None
329
+ resume_step = 0
330
+
331
+ if checkpoints:
332
+ latest_checkpoint = max(checkpoints, key=lambda x: int(x.split("-")[-1]))
333
+ resume_from_checkpoint = latest_checkpoint
334
+ resume_step = int(latest_checkpoint.split("-")[-1])
335
+ print(f"\n🔄 RESUME Training from: {latest_checkpoint} (Step {resume_step})")
336
+ else:
337
+ print("\n🆕 Starting fresh training (no checkpoints found)")
338
+
339
+ # Tokenizer
340
+ print("\n📚 Lade Tokenizer...")
341
+ tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
342
+ tokenizer.pad_token = tokenizer.eos_token
343
+ print("✅ Llama 3.2 Tokenizer geladen")
344
+
345
+ # Load Clean Datasets (with resume_step for reproducibility!)
346
+ train_dataset, eval_dataset = load_clean_datasets(
347
+ tokenizer=tokenizer,
348
+ max_length=2048,
349
+ seed=SEED,
350
+ resume_step=resume_step,
351
+ )
352
+
353
+ # Data Collator
354
+ data_collator = DataCollatorForLanguageModeling(pad_token_id=tokenizer.pad_token_id)
355
+
356
+ # Model
357
+ print("\n🏗️ Erstelle MoE Modell...")
358
+ model = MoEGPTForCausalLM(config)
359
+
360
+ # Ensure weight tying (especially after checkpoint load)
361
+ model.tie_weights()
362
+
363
+ total_params = sum(p.numel() for p in model.parameters())
364
+ print(f"✅ Modell erstellt! ({total_params/1e6:.1f}M params)")
365
+
366
+ # Callbacks
367
+ sample_callback = SampleGenerationCallback(
368
+ tokenizer=tokenizer,
369
+ prompts=get_german_sample_prompts(),
370
+ generate_every_n_steps=1000, # Every 1K steps - fast feedback!
371
+ max_new_tokens=500,
372
+ temperature=0.7,
373
+ top_p=0.7,
374
+ output_dir="./samples_v8_clean",
375
+ )
376
+
377
+ # Trainer
378
+ print("\n🚀 Initialisiere Trainer...")
379
+ trainer = MoETrainer(
380
+ model=model,
381
+ args=training_args,
382
+ train_dataset=train_dataset,
383
+ eval_dataset=eval_dataset,
384
+ data_collator=data_collator,
385
+ callbacks=[MoEEvalCallback(), sample_callback],
386
+ )
387
+
388
+ print("✅ Trainer bereit!")
389
+
390
+ print("\n" + "=" * 60)
391
+ print("🎯 STARTE TRAINING v8 - OPUS EDITION!")
392
+ print("=" * 60)
393
+ print("\nDataset Composition (INTERLEAVED!):")
394
+ print(" Wikipedia (WITH EOS): 64%")
395
+ print(" OpenSubtitles OPUS (NO EOS): 24%")
396
+ print(" Belletristik (NO EOS): 12%")
397
+ print("\nTotal: ~10.4 GB CLEAN German!")
398
+ print("NO spam, NO ads, NO SEO garbage! 🎉")
399
+ print("PLUS natural dialogues from movie subtitles! 🎬")
400
+ print("=" * 60 + "\n")
401
+
402
+ # Train with resume support
403
+ trainer.train(resume_from_checkpoint=resume_from_checkpoint)
404
+
405
+ # Save
406
+ print("\n💾 Speichere finales Modell...")
407
+ final_model_path = "./moe_final_v8_clean"
408
+ trainer.save_model(final_model_path)
409
+ config.save_pretrained(final_model_path)
410
+ print(f"✅ Modell gespeichert in: {final_model_path}")
411
+
412
+ # Eval
413
+ print("\n📊 Finale Evaluation...")
414
+ eval_results = trainer.evaluate()
415
+
416
+ for key, value in eval_results.items():
417
+ print(f" - {key}: {value:.4f}")
418
+
419
+ if "eval_loss" in eval_results:
420
+ perplexity = torch.exp(torch.tensor(eval_results["eval_loss"]))
421
+ print(f"\n🎯 Finale Perplexity: {perplexity:.2f}")
422
+
423
+ print("\n" + "=" * 60)
424
+ print("✅ TRAINING ABGESCHLOSSEN!")
425
+ print("=" * 60)
426
+
427
+
428
+ if __name__ == "__main__":
429
+ main()
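For reference, the perplexity reported at the end is simply exp(eval_loss); with illustrative loss values (not results from this run):

import math

for loss in (3.5, 3.0, 2.5):
    print(f"eval_loss={loss:.1f} -> perplexity≈{math.exp(loss):.1f}")  # 33.1, 20.1, 12.2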