"""

Custom MoE Trainer mit erweiterten Logging-Funktionen

"""

from typing import Dict

import torch
from transformers import Trainer
from transformers.trainer_callback import TrainerCallback


class MoETrainer(Trainer):
    """

    Erweiterter Trainer für MoE Modelle mit speziellem Logging für:

    - Auxiliary Losses (Load Balancing, Router Z-Loss)

    - Expert Utilization

    - Capacity Factor Anpassung

    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """

        Überschreibt compute_loss um MoE-spezifische Losses zu berücksichtigen.

        Diese sind bereits im model.forward() eingerechnet, aber wir loggen sie separat.

        """
        # Labels for next-token prediction
        if "labels" not in inputs:
            inputs["labels"] = inputs["input_ids"].clone()

        # Forward pass
        outputs = model(**inputs)

        # The loss is already the total loss (LM + aux losses)
        loss = outputs.loss

        # Log the auxiliary losses (only at logging steps)
        if self.state.global_step % self.args.logging_steps == 0:
            if hasattr(outputs, "aux_loss") and outputs.aux_loss is not None:
                self.log({"train/aux_loss": outputs.aux_loss.item()})

            if hasattr(outputs, "router_z_loss") and outputs.router_z_loss is not None:
                self.log({"train/router_z_loss": outputs.router_z_loss.item()})

            # Full loss breakdown: recover the pure LM loss by subtracting
            # the weighted auxiliary terms (assuming the model combines them
            # as total = lm + aux_loss_alpha * aux + router_z_loss_alpha * z)
            if (
                hasattr(outputs, "aux_loss")
                and outputs.aux_loss is not None
                and hasattr(outputs, "router_z_loss")
                and outputs.router_z_loss is not None
            ):
                lm_loss = (
                    loss.item()
                    - self.model.config.aux_loss_alpha * outputs.aux_loss.item()
                    - self.model.config.router_z_loss_alpha * outputs.router_z_loss.item()
                )
                self.log({"train/lm_loss": lm_loss})
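                # Worked example (illustrative numbers only): with
                # aux_loss_alpha=0.01, router_z_loss_alpha=0.001, a total
                # loss of 2.500, aux_loss=1.0 and router_z_loss=10.0:
                #     lm_loss = 2.500 - 0.01 * 1.0 - 0.001 * 10.0 = 2.480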

        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """

        Überschreibt prediction_step um eval_loss korrekt zurückzugeben

        """
        # Ensure labels are present
        if "labels" not in inputs:
            inputs["labels"] = inputs["input_ids"].clone()

        # Call the standard prediction_step
        loss, logits, labels = super().prediction_step(
            model, inputs, prediction_loss_only, ignore_keys
        )

        return loss, logits, labels

    def log(self, logs: Dict[str, float], start_time=None) -> None:
        """

        Erweitert das Standard-Logging um MoE-spezifische Metriken

        """
        # GPU Memory Tracking
        if torch.cuda.is_available():
            logs["gpu_memory_allocated_gb"] = (
                torch.cuda.memory_allocated() / 1024**3
            )
            logs["gpu_memory_reserved_gb"] = (
                torch.cuda.memory_reserved() / 1024**3
            )

        # Newer transformers versions pass start_time to log(); forward it
        # only when given so older versions keep working
        if start_time is not None:
            super().log(logs, start_time)
        else:
            super().log(logs)


class MoEEvalCallback(TrainerCallback):
    """

    Callback für erweiterte MoE-spezifische Evaluation

    """

    def on_evaluate(self, args, state, control, model, metrics=None, **kwargs):
        """

        Nach jeder Evaluation loggen wir zusätzliche MoE Metriken

        """
        if metrics is not None and model is not None:
            # Model statistics
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(
                p.numel() for p in model.parameters() if p.requires_grad
            )

            metrics["model/total_params_M"] = total_params / 1e6
            metrics["model/trainable_params_M"] = trainable_params / 1e6

            # MoE-specific metrics; these config attributes are custom, so
            # guard each attribute that is actually read
            if hasattr(model.config, "total_experts"):
                metrics["model/total_experts"] = model.config.total_experts
            if hasattr(model.config, "active_parameters_ratio"):
                metrics["model/active_params_ratio"] = (
                    model.config.active_parameters_ratio
                )


class DataCollatorForLanguageModeling:
    """

    Einfacher Data Collator für Causal Language Modeling.

    Geht davon aus, dass Daten bereits tokenisiert sind.

    """

    def __init__(self, pad_token_id: int = 0):
        self.pad_token_id = pad_token_id

    def __call__(self, examples):
        """

        Args:

            examples: Liste von Dicts mit 'input_ids' und 'attention_mask'



        Returns:

            Batch dict mit gepaddetem input_ids und attention_mask

        """
        # Maximum sequence length in this batch
        max_length = max(len(ex["input_ids"]) for ex in examples)

        input_ids = []
        attention_mask = []

        for ex in examples:
            seq_len = len(ex["input_ids"])
            padding_length = max_length - seq_len

            # Right-side padding
            padded_input_ids = ex["input_ids"] + [self.pad_token_id] * padding_length
            padded_attention_mask = ex["attention_mask"] + [0] * padding_length

            input_ids.append(padded_input_ids)
            attention_mask.append(padded_attention_mask)

        # Convert to tensors
        batch = {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        }

        return batch
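
# Example of the collator's padding behavior (illustrative):
#
#     collator = DataCollatorForLanguageModeling(pad_token_id=0)
#     batch = collator([
#         {"input_ids": [5, 6, 7], "attention_mask": [1, 1, 1]},
#         {"input_ids": [8, 9], "attention_mask": [1, 1]},
#     ])
#     batch["input_ids"]      -> tensor([[5, 6, 7], [8, 9, 0]])
#     batch["attention_mask"] -> tensor([[1, 1, 1], [1, 1, 0]])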


def compute_metrics(eval_preds):
    """
    Placeholder for additional evaluation metrics (e.g., perplexity).
    """
    predictions, labels = eval_preds

    # For language modeling, the predictions are the logits and the labels
    # are the actual token IDs. The eval loss is already computed and logged
    # by the Trainer, and perplexity follows from it as exp(eval_loss), so
    # nothing extra is returned here.
    return {}
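

# Wiring sketch (illustrative; `model`, `training_args`, `train_ds`, and
# `eval_ds` are hypothetical names, not defined in this module). The model's
# forward() is expected to return `loss`, and optionally `aux_loss` and
# `router_z_loss`, for the extra logging to kick in:
#
#     trainer = MoETrainer(
#         model=model,
#         args=training_args,  # transformers.TrainingArguments
#         train_dataset=train_ds,
#         eval_dataset=eval_ds,
#         data_collator=DataCollatorForLanguageModeling(pad_token_id=0),
#         callbacks=[MoEEvalCallback()],
#         compute_metrics=compute_metrics,
#     )
#     trainer.train()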