"""Pretrain utilities."""

from datetime import datetime
from functools import partial

import math
import sys

import torch
import deepspeed
import numpy as np

from megatron.utils import (
    Timers,
    init_wandb,
    get_ltor_masks_and_position_ids,
    reduce_losses,
)

from megatron import print_rank_0, mpu
from megatron.model import (
    GPT2ModelPipe,
    SoftEmbedding,
    get_params_for_weight_decay_optimization,
)
from megatron.checkpointing import load_checkpoint, save_checkpoint
from megatron.data.data_utils import build_train_valid_test_data_iterators
from megatron.initialize import initialize_megatron
from megatron.learning_rates import AnnealingLR
from megatron.logging import tb_wandb_log, training_log
from megatron.utils import (
    OverflowMonitor,
    get_noise_scale_logger,
    get_total_params,
    CharCounter,
)
from megatron.model.gpt2_model import cross_entropy
from eval_tasks import run_eval_harness


def pretrain(neox_args):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) set up the model, optimizer and lr schedule.
        3) call build_train_valid_test_data_iterators to get the train/valid/test data iterators.
        4) train the model.

    Arguments:
        neox_args: an instance of NeoXArgs containing the configuration for pretrain
    """
    # set up logging and timers
    init_wandb(neox_args=neox_args)
    timers = Timers(
        use_wandb=neox_args.use_wandb, tensorboard_writer=neox_args.tensorboard_writer
    )

    # initialize megatron
    initialize_megatron(neox_args=neox_args)

    # model, optimizer, and learning rate scheduler
    timers("model and optimizer").start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(
        neox_args=neox_args, use_cache=False
    )
    timers("model and optimizer").stop()

    # build train/valid/test data iterators
    timers("train/valid/test data iterators").start()
    (
        train_data_iterator,
        valid_data_iterator,
        test_data_iterator,
    ) = build_train_valid_test_data_iterators(neox_args=neox_args)
    timers("train/valid/test data iterators").stop()

    # print setup timing
    print_rank_0("done with setups ...")
    timers.log(["model and optimizer", "train/valid/test data iterators"])
    print_rank_0("training ...")

    iteration = neox_args.iteration
    if neox_args.do_train and neox_args.train_iters > 0:
        # if a save at iteration 0 is requested, checkpoint the initial state before training
        if neox_args.save and 0 in neox_args.save_iters and iteration == 0:
            save_checkpoint(
                neox_args=neox_args,
                iteration=iteration,
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )

        iteration = train(
            neox_args=neox_args,
            timers=timers,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            train_data_iterator=train_data_iterator,
            valid_data_iterator=valid_data_iterator,
        )

    if neox_args.do_valid:
        prefix = "the end of training for val data"
        evaluate_and_print_results(
            neox_args=neox_args,
            prefix=prefix,
            forward_step_func=forward_step,
            data_iterator=valid_data_iterator,
            model=model,
            iteration=iteration,
            verbose=False,
            timers=timers,
        )

    if neox_args.save and iteration != 0:
        save_checkpoint(
            neox_args=neox_args,
            iteration=iteration,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )

    if neox_args.do_test:
        # run on test data
        prefix = "the end of training for test data"
        evaluate_and_print_results(
            neox_args=neox_args,
            prefix=prefix,
            forward_step_func=forward_step,
            data_iterator=test_data_iterator,
            model=model,
            iteration=iteration,
            verbose=True,
            timers=timers,
            chart_name="test",
        )

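# A minimal sketch of how pretrain() is typically driven from a launcher script.
# The helper names below are assumptions for illustration (they are not defined
# in this file); the actual entry point in the repository may differ:
#
#     from megatron.neox_arguments import NeoXArgs
#     from megatron.training import pretrain
#
#     neox_args = NeoXArgs.consume_neox_args()
#     neox_args.configure_distributed_args()
#     neox_args.build_tokenizer()
#     pretrain(neox_args)
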

def _get_batch(neox_args, tokenizer, keys, data, datatype):
    """Support function for get_batch / get_batch_pipe (to avoid code repetition)."""
    data_b = mpu.broadcast_data(keys, data, datatype)

    # unpack and shift
    tokens_ = data_b["text"].long()
    labels = tokens_[:, 1:].contiguous()
    tokens = tokens_[:, :-1].contiguous()
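    # Each sample holds seq_len + 1 tokens, so inputs and labels are the same
    # sequence offset by one position: tokens = [t_0 .. t_{n-1}], labels = [t_1 .. t_n]
    # (standard next-token prediction).
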
    attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
        data=tokens,
        eod_token=neox_args.tokenizer.eod,
        eod_mask_loss=neox_args.eod_mask_loss,
    )

    return tokens, labels, loss_mask, attention_mask, position_ids


def get_batch(neox_args, data_iterator):
    """Generate a batch."""

    keys = ["text"]
    datatype = torch.int64

    if data_iterator is not None:
        data = next(data_iterator)
    else:
        data = None
    return _get_batch(
        neox_args=neox_args,
        tokenizer=neox_args.tokenizer,
        keys=keys,
        data=data,
        datatype=datatype,
    )


def get_batch_pipe(data, neox_args):
    """A modification of get_batch() to work with the latest batch instead of an iterator."""

    keys = ["text"]
    datatype = torch.int64

    tokens, labels, loss_mask, attention_mask, position_ids = _get_batch(
        neox_args, neox_args.tokenizer, keys, data, datatype
    )
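
    # DeepSpeed's pipeline engine expects the batch function to return two tuples:
    # the first is unpacked as the input to the first pipeline stage, the second is
    # passed (together with the stage output) to the loss function, here cross_entropy.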
    return (tokens, position_ids, attention_mask), (labels, loss_mask)


def forward_step(data_iterator, model, neox_args, timers, return_logits=False):
    """Forward step."""
    if neox_args.is_pipe_parallel:
        return model.eval_batch(data_iterator, return_logits=return_logits)

    if timers is not None:
        timers("batch generator").start()
    tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
        neox_args=neox_args, data_iterator=data_iterator
    )
    if timers is not None:
        timers("batch generator").stop()

    outputs = model((tokens, position_ids, attention_mask))
    loss = cross_entropy(
        outputs, (labels, loss_mask), _fp16=neox_args.fp16_lm_cross_entropy
    )
    if return_logits:
        return loss, outputs
    return loss


def get_model(neox_args, use_cache=False):
    """Build the model."""

    print_rank_0("building GPT2 model ...")

    model = GPT2ModelPipe(
        neox_args=neox_args,
        num_tokentypes=0,
        parallel_output=True,
        topology=mpu.get_topology(),
        use_cache=use_cache,
    )

    if neox_args.soft_prompt_tuning is not None and neox_args.soft_prompt_tuning.get(
        "enabled", False
    ):
        soft_prompt = SoftEmbedding(
            neox_args,
            wte=getattr(model, "0").word_embeddings,
            n_tokens=neox_args.soft_prompt_tuning.get("n_tokens", 10),
            init_string=neox_args.soft_prompt_tuning.get("init_string", ""),
            init_range=neox_args.soft_prompt_tuning.get("init_range", 0.5),
        )
        model.insert_layers(layers=soft_prompt, idx=1)
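        # idx=1 places the soft prompt immediately after the word embedding layer
        # (index 0 in the pipeline); the loop below then freezes every parameter
        # except the soft embedding, so only the prompt is trained.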
        for name, param in model.named_parameters():
            if "soft_embedding" not in name:
                param.requires_grad = False

    if not neox_args.is_pipe_parallel:
        # export the pipeline model to nn.Sequential when running without pipe parallelism
        model = model.to_sequential()

    if neox_args.deepspeed:
        # DeepSpeed handles CUDA, FP16, and DDP setup
        return model
    else:
        raise ValueError("Must be using deepspeed to run neox")


def get_optimizer(model, neox_args):
    """Set up the optimizer."""
    if neox_args.no_load_optim:
        return None, None

    param_groups = get_params_for_weight_decay_optimization(model, neox_args)
    print_rank_0(
        f'Configuring Optimizer type: {neox_args.optimizer_type} with params: {neox_args.optimizer["params"]}'
    )
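    # Illustrative config shape only (the values here are hypothetical, not defaults
    # taken from this repository):
    #   neox_args.optimizer_type      -> "adam"
    #   neox_args.optimizer["params"] -> {"lr": 6.0e-4, "betas": [0.9, 0.95], "eps": 1.0e-8}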

    # ensure every param has a model_parallel attribute so downstream code can rely on it
    for param_group in param_groups:
        for param in param_group["params"]:
            if not hasattr(param, "model_parallel"):
                param.model_parallel = False

    # filter out params that don't require a grad (e.g. when doing soft prompt tuning)
    _param_groups = []
    for param_group in param_groups:
        trainable_params = [p for p in param_group["params"] if p.requires_grad]
        param_group["params"] = trainable_params
        _param_groups.append(param_group)
    param_groups = _param_groups

    if neox_args.optimizer_type.lower() in ["cpu_adam", "cpu_torch_adam"]:
        if neox_args.optimizer_type.lower() == "cpu_torch_adam":
            cpu_adam_optimizer = torch.optim.Adam
        else:
            from deepspeed.ops.adam import DeepSpeedCPUAdam

            cpu_adam_optimizer = DeepSpeedCPUAdam
        optimizer = cpu_adam_optimizer(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
        )
    elif neox_args.optimizer_type.lower() == "onebitadam":
        assert neox_args.deepspeed
        # onebitadam is instantiated inside the DeepSpeed engine (see setup_model_and_optimizer)
        optimizer = None
    elif neox_args.optimizer_type.lower() == "sm3":
        from .optimizers import SM3

        optimizer = SM3(param_groups, **neox_args.optimizer["params"])
    elif neox_args.optimizer_type.lower() == "madgrad_wd":
        from .optimizers import madgrad_wd

        optimizer = madgrad_wd(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
        )
    elif neox_args.optimizer_type.lower() == "adam":
        if neox_args.use_bnb_optimizer:
            try:
                import bitsandbytes as bnb

                adam_optimizer = bnb.optim.Adam8bit
            except ModuleNotFoundError:
                print(
                    "Please install bitsandbytes following https://github.com/facebookresearch/bitsandbytes."
                )
                raise Exception
        else:
            try:
                # default to apex's fused adam when available
                from apex.optimizers import FusedAdam as Adam
            except ImportError:
                # fall back to deepspeed's fused adam
                print(
                    "WARNING: APEX not installed - defaulting to deepspeed's fused adam"
                )
                from deepspeed.ops.adam import FusedAdam as Adam
            adam_optimizer = Adam
        optimizer = adam_optimizer(
            param_groups,
            weight_decay=neox_args.weight_decay,
            **neox_args.optimizer["params"],
        )
    else:
        raise ValueError(f"Optimizer type {neox_args.optimizer_type} not recognized")

    if neox_args.deepspeed:
        # the optimizer is wrapped by the DeepSpeed engine in deepspeed.initialize
        return optimizer, param_groups
    else:
        raise ValueError("Must be using deepspeed to run neox")


def get_learning_rate_scheduler(optimizer, neox_args):
    """Build the learning rate scheduler."""
    if neox_args.no_load_optim:
        return None
    if neox_args.deepspeed and neox_args.optimizer_type.lower() == "onebitadam":
        print_rank_0(
            "WARNING: onebitadam requires the lr scheduler be built by deepspeed - "
            "Make sure one is added to your deepspeed config"
        )
        return None

    if neox_args.lr_decay_iters is not None:
        num_iters = neox_args.lr_decay_iters
    else:
        num_iters = neox_args.train_iters
    num_iters = max(1, num_iters)
    init_step = 0
    warmup_iter = neox_args.warmup * num_iters
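    # Example (illustrative numbers): warmup = 0.01 with 320,000 decay iterations
    # gives 0.01 * 320000 = 3200 warmup steps before the chosen decay style takes over.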
    lr_scheduler = AnnealingLR(
        optimizer,
        start_lr=neox_args.lr,
        warmup_iter=warmup_iter,
        total_iters=num_iters,
        decay_style=neox_args.lr_decay_style,
        last_iter=init_step,
        min_lr=neox_args.min_lr,
        use_checkpoint_lr_scheduler=neox_args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=neox_args.override_lr_scheduler,
    )

    return lr_scheduler


def setup_model_and_optimizer(neox_args, use_cache=False, iteration=None):
    """Set up the model and optimizer."""
    model = get_model(neox_args=neox_args, use_cache=use_cache)
    optimizer, param_groups = get_optimizer(model=model, neox_args=neox_args)
    lr_scheduler = get_learning_rate_scheduler(optimizer=optimizer, neox_args=neox_args)

    if neox_args.deepspeed:
        print_rank_0("DeepSpeed is enabled.")
        if neox_args.no_load_optim:
            assert optimizer is None
            _model_params = None
            _lr_scheduler = None
        else:
            _model_params = param_groups if optimizer is None else None
            _lr_scheduler = lr_scheduler

        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=neox_args,
            lr_scheduler=_lr_scheduler,
            dist_init_required=False,
            model_parameters=_model_params,
            config_params=neox_args.deepspeed_config,
            mpu=mpu if not neox_args.is_pipe_parallel else None,
        )
        model.total_params = get_total_params(model.module)
        print_rank_0(f" > total params: {model.total_params:,}")

        if neox_args.is_pipe_parallel:
            model.set_has_attention_mask(True)
            model.set_batch_fn(partial(get_batch_pipe, neox_args=neox_args))
    else:
        raise ValueError("Must be using deepspeed to run neox")

    if neox_args.load is not None:
        neox_args.iteration = load_checkpoint(
            neox_args=neox_args,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            iteration=iteration,
        )
        print_rank_0(
            f"Loading checkpoint and starting from iteration {neox_args.iteration}"
        )
    else:
        neox_args.iteration = 0

    return model, optimizer, lr_scheduler


def backward_step(neox_args, timers, optimizer, model, loss):
    """Backward step."""

    # backward pass
    timers("backward-backward").start()
    if neox_args.deepspeed:
        model.backward(loss)
    else:
        raise ValueError("Must be using deepspeed to run neox")
    timers("backward-backward").stop()

    if neox_args.deepspeed:
        # DeepSpeed's backward pass already handles gradient all-reduce; reset the
        # timer so downstream timer logging stays consistent
        timers("backward-allreduce").reset()
    else:
        raise ValueError("Must be using deepspeed to run neox")


def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
    """Single training step."""

    # with pipeline parallelism, the DeepSpeed engine schedules forward/backward/step itself
    if neox_args.is_pipe_parallel:
        reduced_loss = train_step_pipe(
            neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator
        )
    else:
        losses = []
        for _ in range(neox_args.gradient_accumulation_steps):
            # forward pass for one micro-batch
            timers("forward").start()
            loss = forward_step(
                neox_args=neox_args,
                timers=timers,
                data_iterator=data_iterator,
                model=model,
            )
            timers("forward").stop()
            losses.append(loss)

            # calculate gradients
            timers("backward").start()
            backward_step(
                neox_args=neox_args,
                timers=timers,
                optimizer=optimizer,
                model=model,
                loss=loss,
            )
            timers("backward").stop()

            # update parameters
            timers("optimizer").start()
            if neox_args.deepspeed:
                model.step()
            else:
                raise ValueError("Must be using deepspeed to run neox")
            timers("optimizer").stop()
        # reduce losses across machines for logging
        reduced_loss = {"lm_loss": reduce_losses(losses).mean()}

    if neox_args.precision == "fp16" and model.optimizer.overflow:
        skipped_iter = 1
    else:
        skipped_iter = 0

    return reduced_loss, skipped_iter


def train_step_pipe(neox_args, timers, model, data_iterator):
    """Single training step with DeepSpeed's pipeline parallel engine."""

    assert neox_args.deepspeed
    loss = model.train_batch(data_iter=data_iterator)
    loss_dict = {"lm_loss": loss}
    # these timers are not exercised on the pipeline code path; reset them so the
    # interval logging below does not report stale values
    for t in [
        "forward",
        "backward",
        "allreduce",
        "optimizer",
        "batch generator",
        "data loader",
    ]:
        timers(t).reset()
    return loss_dict


def train(
    neox_args,
    timers,
    model,
    optimizer,
    lr_scheduler,
    train_data_iterator,
    valid_data_iterator,
):
    """Train the model."""

    # turn on training mode, which enables dropout
    model.train()

    # tracking loss
    total_loss_dict = {}

    # iterations
    iteration = neox_args.iteration

    timers("interval time").start()
    report_memory_flag = True

    # noise scale logger (used if neox_args.log_gradient_noise_scale is True)
    noise_scale_logger = get_noise_scale_logger(neox_args)

    # monitor repeated fp16 overflows / skipped iterations
    overflow_monitor = OverflowMonitor(optimizer)
    while iteration < neox_args.train_iters:
        loss_dict, skipped_iter = train_step(
            neox_args=neox_args,
            timers=timers,
            data_iterator=train_data_iterator,
            model=model,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
        )
        iteration += 1

        overflow_monitor.check(skipped_iter)
        if neox_args.log_gradient_noise_scale:
            noise_scale_logger.update()

        # get the learning rate (if present); with soft prompt tuning + pipeline
        # parallelism, a rank may have no trainable parameters and hence no param groups
        if optimizer.param_groups:
            lr = optimizer.param_groups[0].get("lr", 0)
        else:
            lr = 0

        # logging
        report_memory_flag = training_log(
            neox_args=neox_args,
            timers=timers,
            loss_dict=loss_dict,
            total_loss_dict=total_loss_dict,
            learning_rate=lr,
            iteration=iteration,
            loss_scale=optimizer.cur_scale if neox_args.precision == "fp16" else None,
            report_memory_flag=report_memory_flag,
            skipped_iter=skipped_iter,
            model=model,
            optimizer=optimizer,
            noise_scale_logger=noise_scale_logger,
        )

        # checkpointing
        if neox_args.save and iteration in neox_args.save_iters:
            save_checkpoint(
                neox_args=neox_args,
                iteration=iteration,
                model=model,
                optimizer=optimizer,
                lr_scheduler=lr_scheduler,
            )

        # evaluation
        if (
            neox_args.eval_interval
            and iteration % neox_args.eval_interval == 0
            and neox_args.do_valid
        ):
            prefix = "iteration {}".format(iteration)
            evaluate_and_print_results(
                neox_args=neox_args,
                prefix=prefix,
                forward_step_func=forward_step,
                data_iterator=valid_data_iterator,
                model=model,
                iteration=iteration,
                verbose=False,
                timers=timers,
            )

        if neox_args.exit_interval and iteration % neox_args.exit_interval == 0:
            torch.distributed.barrier()
            time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            rank = torch.distributed.get_rank()
            print_rank_0(
                "rank: {} | time: {} | exiting the program at iteration {}".format(
                    rank, time_str, iteration
                )
            )
            sys.exit()

    return iteration


def evaluate(
    neox_args, forward_step_fn, data_iterator, model, verbose=False, timers=None
):
    """Evaluation.

    neox_args: NeoX Arguments
    forward_step_fn: function with args `neox_args, timers, data_iterator, model`
        that runs a forward pass on the model
    data_iterator: Iterator that iterates over batches of data. Should return data in the form:
        {'text': np.array([tokens], dtype=np.int64)}
        where the size of the array is the model's context size + 1
        (`get_batch` transforms it into inputs / labels)
    """
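    # Example item (illustrative, assuming seq_length = 2048): each sample would be
    # {"text": np.ndarray of shape (2049,), dtype=np.int64}.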
    # turn on evaluation mode, which disables dropout
    model.eval()
    losses = []
    if neox_args.char_level_ppl:
        data_iterator = CharCounter(data_iterator, neox_args.tokenizer)

    with torch.no_grad():
        iteration = 0
        while iteration < neox_args.eval_iters:
            iteration += 1
            if verbose and iteration % neox_args.log_interval == 0:
                print_rank_0(
                    "Evaluating iter {}/{}".format(iteration, neox_args.eval_iters)
                )

            # one eval iteration runs gradient_accumulation_steps forward passes
            # (a single eval_batch call when pipe parallel), so eval and train
            # iterations cover the same amount of data
            for _ in range(
                1
                if neox_args.is_pipe_parallel
                else neox_args.gradient_accumulation_steps
            ):
                loss = forward_step_fn(
                    model=model,
                    data_iterator=data_iterator,
                    neox_args=neox_args,
                    timers=timers,
                )
                losses.append(loss)

            # with contiguous-memory activation checkpointing, buffers are normally
            # freed during the backward pass; in forward-only evaluation they must
            # be reset explicitly after each pass
            if neox_args.deepspeed and neox_args.deepspeed_activation_checkpointing:
                deepspeed.checkpointing.reset()

    # reduce losses across data parallel ranks and report token-level perplexity
    eval_results = {"lm_loss": reduce_losses(losses).mean().item()}
    eval_results["lm_loss_ppl"] = math.exp(eval_results["lm_loss"])

    if neox_args.char_level_ppl:
        # tokens_per_char rescales the token-level loss to a per-character loss
        tokens_per_char = data_iterator.tokens_per_char()
        print_rank_0(f"Counting chars took {data_iterator.total_time} seconds")

        # unwrap the CharCounter
        data_iterator = data_iterator.data_iterator
        eval_results["lm_loss_char_lvl_ppl"] = math.exp(
            eval_results["lm_loss"] * tokens_per_char
        )

    if neox_args.eval_tasks:
        eval_results.update(
            run_eval_harness(
                model, forward_step_fn, neox_args, eval_tasks=neox_args.eval_tasks
            ).get("results")
        )

    # move the model back to training mode
    model.train()
    return eval_results


def evaluate_and_print_results(
    neox_args,
    prefix,
    forward_step_func,
    data_iterator,
    model,
    iteration,
    verbose=False,
    timers=None,
    chart_name="validation",
):
    """Helper function to evaluate and dump results on screen."""
    total_loss_dict = evaluate(
        neox_args=neox_args,
        forward_step_fn=forward_step_func,
        data_iterator=data_iterator,
        model=model,
        verbose=verbose,
        timers=timers,
    )
    string = f" {chart_name} results at {prefix} | "
    for k, v in total_loss_dict.items():
        if isinstance(v, dict):
            for k2, v2 in v.items():
                k3 = "_".join([k, k2])
                string += f"{k3} value: {v2:.6E} | "
                tb_wandb_log(
                    f"{chart_name}/{k3}",
                    v2,
                    iteration,
                    use_wandb=neox_args.use_wandb,
                    tensorboard_writer=neox_args.tensorboard_writer,
                )
        else:
            string += f"{k} value: {v:.6E} | "
            tb_wandb_log(
                f"{chart_name}/{k}",
                v,
                iteration,
                use_wandb=neox_args.use_wandb,
                tensorboard_writer=neox_args.tensorboard_writer,
            )

    length = len(string) + 1
    print_rank_0("-" * length)
    print_rank_0(string)
    print_rank_0("-" * length)