"""Input/output checkpointing.""" |

import json
import os
import random
import re
import shutil
import sys
from glob import glob
from pathlib import Path
from pprint import pformat

import numpy as np
import torch

from megatron import mpu
from megatron import print_rank_0
from megatron.utils import natural_sort
from megatron.text_generation_utils import get_batch, forward_model


def check_checkpoint_args(neox_args, checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the ones retrieved from the checkpoint."""

    assert isinstance(checkpoint_args, dict), "args stored in checkpoint must be a dict"
    for checkpoint_arg_name, checkpoint_arg_value in checkpoint_args.items():
        args_value = getattr(neox_args, checkpoint_arg_name)
        error_message = "{} value from checkpoint ({}) is not equal to the currently set argument value ({}).".format(
            checkpoint_arg_name, checkpoint_arg_value, args_value
        )
        assert checkpoint_arg_value == args_value, error_message


def do_forward_pass(neox_args, model, inference=False):
    """Run a forward pass over a fixed dummy batch and return the logits of the
    first sample (or None), so a reloaded checkpoint can be compared against the
    logits stored at save time."""

    # switch to eval mode, remembering the previous mode
    model_was_in_train = model.training
    model.eval()

    # deterministic dummy input: a ramp of token ids of shape
    # (train_micro_batch_size_per_gpu, seq_length + 1)
    context_tokens_tensor = (
        torch.arange(neox_args.seq_length + 1)
        .repeat((neox_args.train_micro_batch_size_per_gpu, 1))
        .cuda()
    )

    # forward pass
    if inference:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        model_inputs = (
            tokens,
            position_ids,
            attention_mask,
            torch.Tensor(),  # empty placeholder tensor passed along with the batch in inference mode
        )
        logits, _ = forward_model(neox_args, model, model_inputs)
    elif neox_args.is_pipe_parallel:
        data_iterator = iter([{"text": context_tokens_tensor}])
        _, logits = model.eval_batch(data_iter=data_iterator, return_logits=True)
    else:
        tokens, attention_mask, position_ids = get_batch(
            neox_args, context_tokens_tensor[:, : neox_args.seq_length]
        )
        logits = model((tokens, position_ids, attention_mask))

    # reset to train mode if the model was training before
    if model_was_in_train:
        model.train()

    if logits is not None:
        logits = logits.detach().cpu()[0]  # keep only the first sample of the batch

    return logits


def check_forward_pass(neox_args, model, checkpoint_logits, inference):
    """Run a forward pass with the loaded weights and compare against the logits
    stored in the checkpoint."""
    logits = do_forward_pass(neox_args=neox_args, model=model, inference=inference)

    # logits may be None on non-final pipeline-parallel stages
    if logits is not None and checkpoint_logits is not None:
        if not (logits == checkpoint_logits).all().item():
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: validate_checkpoint_forward() forward after load of checkpoint does not yield exactly the same result"
                )
            assert (
                torch.isclose(logits, checkpoint_logits).all().item()
            ), "validate_checkpoint_forward() forward after load of checkpoint does not yield a close result"


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist."""
    dirname = os.path.dirname(filename)
    # exist_ok avoids a race when several ranks create the directory at the same time
    os.makedirs(dirname, exist_ok=True)


def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None):
    """A unified checkpoint name."""
    if release:
        directory = "release"
    else:
        directory = "iter_{:07d}".format(iteration)
    return os.path.join(
        checkpoints_path,
        directory,
        "mp_rank_{:02d}".format(
            mpu.get_model_parallel_rank() if mp_rank is None else mp_rank
        ),
        "model_optim_rng.pt",
    )
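# For example (the save path is illustrative), get_checkpoint_name("/ckpts", 1000, mp_rank=0)
# returns "/ckpts/iter_0001000/mp_rank_00/model_optim_rng.pt".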


def delete_old_checkpoints(save_dir, n_to_keep):
    """Delete all but the newest `n_to_keep` checkpoint directories under `save_dir`."""
    if torch.distributed.get_rank() == 0:
        ckpt_dir_regex = r"global_step[\d]*"
        if save_dir.endswith("/"):
            # strip only the trailing slash so absolute paths stay absolute
            save_dir = save_dir.rstrip("/")
        all_ckpts = natural_sort(
            [
                i
                for i in glob(f"{save_dir}/*")
                if os.path.isdir(i) and re.search(ckpt_dir_regex, i)
            ]
        )
        n_to_delete = len(all_ckpts) - n_to_keep
        if n_to_delete > 0:
            to_delete = all_ckpts[:n_to_delete]
            print(f"WARNING: Deleting old checkpoints: \n\t{', '.join(to_delete)}")
            for ckpt in to_delete:
                try:
                    shutil.rmtree(ckpt)
                except FileNotFoundError:
                    pass
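# For example, with n_to_keep=2 and directories global_step1000, global_step2000 and
# global_step3000 on disk, only global_step1000 is removed (directory names are illustrative).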


def save_ds_checkpoint(iteration, model, neox_args):
    """Save a model checkpoint via the DeepSpeed engine."""
    sd = {
        "iteration": iteration,
        "args": {
            "num_layers": neox_args.num_layers,
            "hidden_size": neox_args.hidden_size,
            "num_attention_heads": neox_args.num_attention_heads,
            "max_position_embeddings": neox_args.max_position_embeddings,
            "make_vocab_size_divisible_by": neox_args.make_vocab_size_divisible_by,
            "padded_vocab_size": neox_args.padded_vocab_size,
            "tokenizer_type": neox_args.tokenizer_type,
            "model_parallel_size": neox_args.model_parallel_size,
        },
    }

    # rng states
    if not neox_args.no_save_rng:
        sd["random_rng_state"] = random.getstate()
        sd["np_rng_state"] = np.random.get_state()
        sd["torch_rng_state"] = torch.get_rng_state()
        sd["cuda_rng_state"] = torch.cuda.get_rng_state()
        sd["rng_tracker_states"] = mpu.get_cuda_rng_tracker().get_states()

    # logits used to validate the checkpoint on reload
    if neox_args.checkpoint_validation_with_forward_pass:
        logits = do_forward_pass(neox_args=neox_args, model=model)
        sd["checkpoint_validation_logits"] = logits

    # checkpoint folder name
    tag = f"global_step{iteration}"

    # save checkpoint
    model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)

    # save config files alongside the checkpoint
    if torch.distributed.get_rank() == 0 and neox_args.config_files is not None:
        configs_directory = os.path.join(neox_args.save, tag, "configs")
        os.makedirs(configs_directory, exist_ok=True)
        for config_filename, config_data in neox_args.config_files.items():
            with open(os.path.join(configs_directory, config_filename), "w") as f:
                if isinstance(config_data, str):
                    f.write(config_data)
                else:
                    json.dump(config_data, f)
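# Files land in os.path.join(neox_args.save, tag), i.e. ".../global_step{iteration}/",
# with the run's config files copied into a "configs/" subdirectory of that folder.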


def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint."""

    if neox_args.deepspeed:
        save_ds_checkpoint(iteration, model, neox_args)
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # wait for all ranks to finish writing before pruning old checkpoints
    torch.distributed.barrier()
    if neox_args.keep_last_n_checkpoints is not None:
        delete_old_checkpoints(neox_args.save, neox_args.keep_last_n_checkpoints)

    # wait for rank 0 to finish deleting before returning
    torch.distributed.barrier()


def load_checkpoint(
    neox_args, model, optimizer, lr_scheduler, inference=False, iteration=None
):
    """Load a model checkpoint and return the iteration."""
    if neox_args.deepspeed:
        load_optim_and_scheduler = not neox_args.no_load_optim
        if neox_args.finetune:
            load_optim_and_scheduler = False
        if iteration is not None:
            tag = f"global_step{iteration}"
        else:
            tag = None
        checkpoint_name, state_dict = model.load_checkpoint(
            neox_args.load,
            load_optimizer_states=load_optim_and_scheduler,
            load_lr_scheduler_states=load_optim_and_scheduler,
            tag=tag,
        )

        if checkpoint_name is None:
            # if a specific iteration was requested, fail loudly instead of
            # silently starting from scratch
            if iteration is not None:
                available_checkpoints = sorted(
                    [
                        int(i.name.replace("global_step", ""))
                        for i in Path(neox_args.load).glob("global_step*")
                    ]
                )
                raise ValueError(
                    f"Unable to load checkpoint for iteration {iteration}. \nAvailable iterations: {pformat(available_checkpoints)}"
                )
            if mpu.get_data_parallel_rank() == 0:
                print("Unable to load checkpoint.")

            return 0  # iteration 0, if no checkpoint could be loaded
    else:
        raise ValueError("Must be using deepspeed to use neox")

    # set iteration
    if neox_args.finetune:
        iteration = 0
    else:
        iteration = state_dict.get("iteration")
        if iteration is None:
            # some checkpoints store the step count under a different key
            iteration = state_dict.get("total_iters")

    # check arguments stored in the checkpoint against the current configuration
    if "args" in state_dict:
        checkpoint_args = state_dict["args"]
        check_checkpoint_args(neox_args=neox_args, checkpoint_args=checkpoint_args)
        print_rank_0(
            " > validated currently set args with arguments in the checkpoint ..."
        )
    else:
        print_rank_0(" > could not find arguments in the checkpoint for validation...")

    # check the loaded checkpoint with a forward pass
    if neox_args.checkpoint_validation_with_forward_pass:
        if "checkpoint_validation_logits" in state_dict:
            check_forward_pass(
                neox_args=neox_args,
                model=model,
                checkpoint_logits=state_dict["checkpoint_validation_logits"],
                inference=inference,
            )
            print_rank_0(" > validated loaded checkpoint with forward pass ...")
        else:
            if mpu.get_data_parallel_rank() == 0:
                print(
                    " > WARNING: checkpoint_validation_with_forward_pass is configured but no checkpoint validation data available in checkpoint {}".format(
                        checkpoint_name
                    )
                )

    # restore rng states
    if not neox_args.finetune and not neox_args.no_load_rng:
        try:
            random.setstate(state_dict["random_rng_state"])
            np.random.set_state(state_dict["np_rng_state"])
            torch.set_rng_state(state_dict["torch_rng_state"])
            torch.cuda.set_rng_state(state_dict["cuda_rng_state"])
            mpu.get_cuda_rng_tracker().set_states(state_dict["rng_tracker_states"])
        except KeyError:
            print_rank_0(
                "Unable to load rng state from checkpoint {}. "
                "Specify --no-load-rng or --finetune to prevent "
                "attempting to load the rng state, "
                "exiting ...".format(checkpoint_name)
            )
            sys.exit()

    torch.distributed.barrier()
    if mpu.get_data_parallel_rank() == 0:
        print(" successfully loaded {}".format(checkpoint_name))

    return iteration
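# Minimal usage sketch (illustrative only): neox_args, model, optimizer and lr_scheduler
# are assumed to come from the usual GPT-NeoX training setup.
#
#   iteration = load_checkpoint(neox_args, model, optimizer, lr_scheduler)
#   ...  # run some training steps
#   save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler)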
|
|
|