import functools
import os
import shutil
import wandb
import yaml
from copy import deepcopy
from dataclasses import dataclass, field
from time import time

import torch
import torch.distributed as dist
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    apply_activation_checkpointing,
    checkpoint_wrapper,
)
from torch.utils.data import DataLoader
from transformers import HfArgumentParser, set_seed
from transformers.optimization import (
    get_constant_schedule_with_warmup,
    get_cosine_with_min_lr_schedule_with_warmup,
)

from data.dataset_base import DataConfig, PackedDataset, collate_wrapper
from data.data_utils import add_special_tokens
from modeling.autoencoder import load_ae
from modeling.bagel import (
    BagelConfig, Bagel, Qwen2Config, Qwen2ForCausalLM, SiglipVisionConfig, SiglipVisionModel
)
from modeling.qwen2 import Qwen2Tokenizer
from train.train_utils import create_logger, get_latest_ckpt
from train.fsdp_utils import (
    FSDPCheckpoint, FSDPConfig, grad_checkpoint_check_fn, fsdp_wrapper,
    fsdp_ema_setup, fsdp_ema_update,
)
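

# Note: a typical launch (an assumption, not prescribed by this file) is one process per GPU
# via torchrun, e.g.
#   torchrun --nproc_per_node=8 <this_script>.py --dataset_config_file data/configs/example.yaml
# Every field of the dataclasses below is exposed as a command-line flag by HfArgumentParser.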


@dataclass
class ModelArguments:
    model_path: str = field(
        default="hf/BAGEL-7B-MoT",
        metadata={"help": "Path of the pretrained BAGEL model."}
    )
    llm_path: str = field(
        default="hf/Qwen2.5-0.5B-Instruct/",
        metadata={"help": "Path or HuggingFace repo ID of the pretrained Qwen2-style language model."}
    )
    llm_qk_norm: bool = field(
        default=True,
        metadata={"help": "Enable QK LayerNorm (qk_norm) inside the attention blocks."}
    )
    tie_word_embeddings: bool = field(
        default=False,
        metadata={"help": "Share input and output word embeddings (tied embeddings)."}
    )
    layer_module: str = field(
        default="Qwen2MoTDecoderLayer",
        metadata={"help": "Python class name of the decoder layer to instantiate."}
    )
    vae_path: str = field(
        default="flux/vae/ae.safetensors",
        metadata={"help": "Path to the pretrained VAE checkpoint for latent-space image generation."}
    )
    vit_path: str = field(
        default="hf/siglip-so400m-14-980-flash-attn2-navit/",
        metadata={"help": "Path or repo ID of the SigLIP Vision Transformer used for image understanding."}
    )
    max_latent_size: int = field(
        default=32,
        metadata={"help": "Maximum latent grid size (patches per side) for the VAE latent tensor."}
    )
    latent_patch_size: int = field(
        default=2,
        metadata={"help": "Spatial size (in VAE pixels) covered by each latent patch."}
    )
    vit_patch_size: int = field(
        default=14,
        metadata={"help": "Patch size (pixels) for the Vision Transformer encoder."}
    )
    vit_max_num_patch_per_side: int = field(
        default=70,
        metadata={"help": "Maximum number of ViT patches along one image side after cropping / resize."}
    )
    connector_act: str = field(
        default="gelu_pytorch_tanh",
        metadata={"help": "Activation function used in the latent-to-text connector MLP."}
    )
    interpolate_pos: bool = field(
        default=False,
        metadata={"help": "Interpolate positional embeddings when image resolution differs from pre-training."}
    )
    vit_select_layer: int = field(
        default=-2,
        metadata={"help": "Which hidden layer of the ViT to take as the visual feature (negative = from the end)."}
    )
    vit_rope: bool = field(
        default=False,
        metadata={"help": "Replace ViT positional encodings with RoPE."}
    )

    text_cond_dropout_prob: float = field(
        default=0.1,
        metadata={"help": "Probability of dropping text embeddings during training."}
    )
    vae_cond_dropout_prob: float = field(
        default=0.3,
        metadata={"help": "Probability of dropping VAE latent inputs during training."}
    )
    vit_cond_dropout_prob: float = field(
        default=0.3,
        metadata={"help": "Probability of dropping ViT visual features during training."}
    )
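

# Token-packing knobs (together with TrainingArguments.expected_num_tokens): raw samples longer
# than max_num_tokens_per_sample are skipped; samples that would push a packed batch past
# max_num_tokens go to an overflow buffer of at most max_buffer_size entries, which is drained
# first while the batch is still shorter than prefer_buffer_before tokens.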
@dataclass
class DataArguments:
    dataset_config_file: str = field(
        default="data/configs/example.yaml",
        metadata={"help": "YAML file specifying dataset groups, weights, and preprocessing rules."}
    )
    prefetch_factor: int = field(
        default=2,
        metadata={"help": "How many batches each DataLoader worker pre-loads in advance."}
    )
    num_workers: int = field(
        default=4,
        metadata={"help": "Number of background workers for the PyTorch DataLoader."}
    )
    max_num_tokens_per_sample: int = field(
        default=16384,
        metadata={"help": "Maximum tokens allowed in one raw sample; longer samples are skipped."}
    )
    max_num_tokens: int = field(
        default=36864,
        metadata={"help": "Hard limit on tokens in a packed batch; flush if adding a sample would exceed it."}
    )
    prefer_buffer_before: int = field(
        default=16384,
        metadata={"help": "While batch length is below this, pop from the overflow buffer before new sampling."}
    )
    max_buffer_size: int = field(
        default=50,
        metadata={"help": "Maximum number of oversized samples kept in the overflow buffer."}
    )
    data_seed: int = field(
        default=42,
        metadata={"help": "Seed used when shuffling / sampling data shards to ensure reproducibility."}
    )


@dataclass
class TrainingArguments:
    visual_gen: bool = field(
        default=True,
        metadata={"help": "Train image generation branch."}
    )
    visual_und: bool = field(
        default=True,
        metadata={"help": "Train image understanding branch."}
    )

    results_dir: str = field(
        default="results",
        metadata={"help": "Root directory for logs."}
    )
    checkpoint_dir: str = field(
        default="results/checkpoints",
        metadata={"help": "Root directory for model checkpoints."}
    )
    wandb_project: str = field(
        default="bagel",
        metadata={"help": "Weights & Biases project name."}
    )
    wandb_name: str = field(
        default="run",
        metadata={"help": "Name shown in the Weights & Biases UI for this run."}
    )
    wandb_runid: str = field(
        default="0",
        metadata={"help": "Unique identifier to resume a previous W&B run, if desired."}
    )
    wandb_resume: str = field(
        default="allow",
        metadata={"help": "W&B resume mode: 'allow', 'must', or 'never'."}
    )
    wandb_offline: bool = field(
        default=True,
        metadata={"help": "Run W&B in offline mode (logs locally, sync later)."}
    )

    global_seed: int = field(
        default=4396,
        metadata={"help": "Base random seed; actual seed is offset by rank for DDP."}
    )
    auto_resume: bool = field(
        default=False,
        metadata={"help": "Automatically pick up the latest checkpoint found in checkpoint_dir."}
    )
    resume_from: str = field(
        default=None,
        metadata={"help": "Explicit checkpoint path to resume from (used when auto_resume is off or finds no checkpoint)."}
    )
    resume_model_only: bool = field(
        default=False,
        metadata={"help": "Load only model weights, ignoring optimizer/scheduler states."}
    )
    finetune_from_ema: bool = field(
        default=False,
        metadata={"help": "When resume_model_only=True, load the EMA (exponential moving average) weights instead of raw weights."}
    )
    finetune_from_hf: bool = field(
        default=False,
        metadata={"help": "Whether to finetune from a Hugging Face BAGEL checkpoint."}
    )

    log_every: int = field(
        default=10,
        metadata={"help": "Print / log every N training steps."}
    )
    save_every: int = field(
        default=2000,
        metadata={"help": "Save a checkpoint every N training steps."}
    )
    max_checkpoints: int = field(
        default=3,
        metadata={"help": "Maximum number of checkpoints to keep. Older checkpoints will be deleted."}
    )
    total_steps: int = field(
        default=20000,
        metadata={"help": "Total number of optimizer steps to train for."}
    )

    warmup_steps: int = field(
        default=2000,
        metadata={"help": "Linear warm-up steps before applying the main LR schedule."}
    )
    lr_scheduler: str = field(
        default="constant",
        metadata={"help": "Type of LR schedule: 'constant' or 'cosine'."}
    )
    lr: float = field(
        default=1e-4,
        metadata={"help": "Peak learning rate after warm-up."}
    )
    min_lr: float = field(
        default=1e-7,
        metadata={"help": "Minimum learning rate for cosine schedule (ignored for constant)."}
    )
    beta1: float = field(
        default=0.9,
        metadata={"help": "AdamW β₁ coefficient."}
    )
    beta2: float = field(
        default=0.95,
        metadata={"help": "AdamW β₂ coefficient."}
    )
    eps: float = field(
        default=1e-8,
        metadata={"help": "AdamW ε for numerical stability."}
    )
    ema: float = field(
        default=0.9999,
        metadata={"help": "Decay rate for the exponential moving average of model weights."}
    )
    max_grad_norm: float = field(
        default=1.0,
        metadata={"help": "Gradient clipping threshold (L2 norm)."}
    )
    timestep_shift: float = field(
        default=1.0,
        metadata={"help": "Shift applied to diffusion timestep indices (for latent prediction)."}
    )
    mse_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the image-reconstruction MSE loss term."}
    )
    ce_weight: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for the language cross-entropy loss term."}
    )
    ce_loss_reweighting: bool = field(
        default=False,
        metadata={"help": "Reweight CE loss by token importance (provided via ce_loss_weights)."}
    )
    expected_num_tokens: int = field(
        default=32768,
        metadata={"help": "Soft target token count; yield the batch once it reaches or exceeds this size."}
    )
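
    # FSDP layout: with HYBRID_SHARD, parameters are sharded across num_shard ranks and the
    # sharded groups are replicated num_replicate times, so num_replicate * num_shard is
    # expected to match the world size (an assumption about FSDPConfig in train.fsdp_utils,
    # not checked here).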
    num_replicate: int = field(
        default=1,
        metadata={"help": "Number of model replicas (replication groups) when using FSDP HYBRID_SHARD."}
    )
    num_shard: int = field(
        default=8,
        metadata={"help": "Number of parameter shards when using FSDP HYBRID_SHARD."}
    )
    sharding_strategy: str = field(
        default="HYBRID_SHARD",
        metadata={"help": "FSDP sharding strategy: FULL_SHARD, SHARD_GRAD_OP, HYBRID_SHARD, etc."}
    )
    backward_prefetch: str = field(
        default="BACKWARD_PRE",
        metadata={"help": "FSDP backward prefetch strategy (BACKWARD_PRE or NO_PREFETCH)."}
    )
    cpu_offload: bool = field(
        default=False,
        metadata={"help": "Enable FSDP parameter offload to CPU."}
    )

    freeze_llm: bool = field(
        default=False,
        metadata={"help": "Keep language-model weights fixed (no gradient updates)."}
    )
    freeze_vit: bool = field(
        default=False,
        metadata={"help": "Keep ViT weights fixed during training."}
    )
    freeze_vae: bool = field(
        default=True,
        metadata={"help": "Keep VAE weights fixed; only predict latents, don't fine-tune encoder/decoder."}
    )
    freeze_und: bool = field(
        default=False,
        metadata={"help": "Freeze the visual understanding connector layers."}
    )
    copy_init_moe: bool = field(
        default=True,
        metadata={"help": "Duplicate initial MoE experts so each has identical initialisation."}
    )
    use_flex: bool = field(
        default=False,
        metadata={"help": "Enable FLEX (flash-ext friendly) packing algorithm for sequence data."}
    )


def main():
    assert torch.cuda.is_available()
    dist.init_process_group("nccl")
    device = dist.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(device)
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if dist.get_rank() == 0:
        os.makedirs(training_args.results_dir, exist_ok=True)
        os.makedirs(training_args.checkpoint_dir, exist_ok=True)
        logger = create_logger(training_args.results_dir, dist.get_rank())
        wandb.init(
            project=training_args.wandb_project,
            id=f"{training_args.wandb_name}-run{training_args.wandb_runid}",
            name=training_args.wandb_name,
            resume=training_args.wandb_resume,
            mode="offline" if training_args.wandb_offline else "online"
        )
        wandb.config.update(training_args, allow_val_change=True)
        wandb.config.update(model_args, allow_val_change=True)
        wandb.config.update(data_args, allow_val_change=True)
    else:
        logger = create_logger(None, dist.get_rank())
    dist.barrier()
    logger.info(f'Training arguments {training_args}')
    logger.info(f'Model arguments {model_args}')
    logger.info(f'Data arguments {data_args}')
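
    # Resolve what to resume from: with auto_resume, the newest checkpoint under checkpoint_dir
    # wins (full training-state resume); if none exists, fall back to resume_from. EMA weights
    # replace the raw weights only when loading model weights alone (resume_model_only with
    # finetune_from_ema).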
    if training_args.auto_resume:
        resume_from = get_latest_ckpt(training_args.checkpoint_dir)
        if resume_from is None:
            resume_from = training_args.resume_from
            resume_model_only = training_args.resume_model_only
            if resume_model_only:
                finetune_from_ema = training_args.finetune_from_ema
            else:
                finetune_from_ema = False
        else:
            resume_model_only = False
            finetune_from_ema = False
    else:
        resume_from = training_args.resume_from
        resume_model_only = training_args.resume_model_only
        if resume_model_only:
            finetune_from_ema = training_args.finetune_from_ema
        else:
            finetune_from_ema = False

    seed = training_args.global_seed * dist.get_world_size() + dist.get_rank()
    set_seed(seed)
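
    # Build the language model (and optionally the ViT / VAE). When finetuning from the released
    # HF BAGEL checkpoint (finetune_from_hf), sub-module configs are read from model_path and the
    # weights are expected to arrive via the resume checkpoint; otherwise the Qwen2 LLM and
    # SigLIP ViT are loaded from their own pretrained paths.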
    if training_args.finetune_from_hf:
        llm_config = Qwen2Config.from_json_file(os.path.join(model_args.model_path, "llm_config.json"))
    else:
        llm_config = Qwen2Config.from_pretrained(model_args.llm_path)
    llm_config.layer_module = model_args.layer_module
    llm_config.qk_norm = model_args.llm_qk_norm
    llm_config.tie_word_embeddings = model_args.tie_word_embeddings
    llm_config.freeze_und = training_args.freeze_und
    if training_args.finetune_from_hf:
        language_model = Qwen2ForCausalLM(llm_config)
    else:
        language_model = Qwen2ForCausalLM.from_pretrained(model_args.llm_path, config=llm_config)
    if training_args.copy_init_moe:
        language_model.init_moe()

    if training_args.visual_und:
        if training_args.finetune_from_hf:
            vit_config = SiglipVisionConfig.from_json_file(os.path.join(model_args.model_path, "vit_config.json"))
        else:
            vit_config = SiglipVisionConfig.from_pretrained(model_args.vit_path)
        vit_config.num_hidden_layers = vit_config.num_hidden_layers + 1 + model_args.vit_select_layer
        vit_config.rope = model_args.vit_rope
        if training_args.finetune_from_hf:
            vit_model = SiglipVisionModel(vit_config)
        else:
            vit_model = SiglipVisionModel.from_pretrained(model_args.vit_path, config=vit_config)

    if training_args.visual_gen:
        vae_model, vae_config = load_ae(
            local_path=os.path.join(model_args.model_path, "ae.safetensors")
            if training_args.finetune_from_hf else model_args.vae_path
        )
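
    # Assemble the unified BAGEL model; the understanding (ViT) and generation (VAE) branches
    # are optional and controlled by visual_und / visual_gen.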
    config = BagelConfig(
        visual_gen=training_args.visual_gen,
        visual_und=training_args.visual_und,
        llm_config=llm_config,
        vit_config=vit_config if training_args.visual_und else None,
        vae_config=vae_config if training_args.visual_gen else None,
        latent_patch_size=model_args.latent_patch_size,
        max_latent_size=model_args.max_latent_size,
        vit_max_num_patch_per_side=model_args.vit_max_num_patch_per_side,
        connector_act=model_args.connector_act,
        interpolate_pos=model_args.interpolate_pos,
        timestep_shift=training_args.timestep_shift,
    )
    model = Bagel(
        language_model,
        vit_model if training_args.visual_und else None,
        config
    )

    if training_args.visual_und:
        model.vit_model.vision_model.embeddings.convert_conv2d_to_linear(vit_config)
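
    # Add the extra special tokens BAGEL needs and grow the embedding table (and the recorded
    # vocab sizes) if the tokenizer gained new ids.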
    tokenizer = Qwen2Tokenizer.from_pretrained(model_args.model_path if training_args.finetune_from_hf else model_args.llm_path)
    tokenizer, new_token_ids, num_new_tokens = add_special_tokens(tokenizer)
    if num_new_tokens > 0:
        model.language_model.resize_token_embeddings(len(tokenizer))
        model.config.llm_config.vocab_size = len(tokenizer)
        model.language_model.config.vocab_size = len(tokenizer)

    if training_args.freeze_vae and training_args.visual_gen:
        for param in vae_model.parameters():
            param.requires_grad = False
    if training_args.freeze_llm:
        model.language_model.eval()
        for param in model.language_model.parameters():
            param.requires_grad = False
    if training_args.freeze_vit and training_args.visual_und:
        model.vit_model.eval()
        for param in model.vit_model.parameters():
            param.requires_grad = False
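
    # FSDP setup. The EMA copy is made and the resume checkpoint is loaded into the plain
    # (unwrapped) modules first; both copies are then wrapped, and activation checkpointing is
    # applied to the layers selected by grad_checkpoint_check_fn.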
    fsdp_config = FSDPConfig(
        sharding_strategy=training_args.sharding_strategy,
        backward_prefetch=training_args.backward_prefetch,
        cpu_offload=training_args.cpu_offload,
        num_replicate=training_args.num_replicate,
        num_shard=training_args.num_shard,
    )
    ema_model = deepcopy(model)
    model, ema_model = FSDPCheckpoint.try_load_ckpt(
        resume_from, logger, model, ema_model, resume_from_ema=finetune_from_ema
    )
    ema_model = fsdp_ema_setup(ema_model, fsdp_config)
    fsdp_model = fsdp_wrapper(model, fsdp_config)
    apply_activation_checkpointing(
        fsdp_model,
        checkpoint_wrapper_fn=functools.partial(
            checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT
        ),
        check_fn=grad_checkpoint_check_fn
    )

    if dist.get_rank() == 0:
        print(fsdp_model)
        for name, param in model.named_parameters():
            print(name, param.requires_grad)
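
    # Optimizer and LR schedule: AdamW with weight_decay=0, plus either a constant or a
    # cosine-to-min_lr schedule after warmup_steps of linear warm-up. Optimizer, scheduler,
    # and the data cursor are restored from the checkpoint unless only model weights are loaded.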
    optimizer = torch.optim.AdamW(
        fsdp_model.parameters(),
        lr=training_args.lr,
        betas=(training_args.beta1, training_args.beta2),
        eps=training_args.eps,
        weight_decay=0
    )
    if training_args.lr_scheduler == 'cosine':
        scheduler = get_cosine_with_min_lr_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=training_args.warmup_steps,
            num_training_steps=training_args.total_steps,
            min_lr=training_args.min_lr,
        )
    elif training_args.lr_scheduler == 'constant':
        scheduler = get_constant_schedule_with_warmup(
            optimizer=optimizer, num_warmup_steps=training_args.warmup_steps
        )
    else:
        raise ValueError(f"Unsupported lr_scheduler: {training_args.lr_scheduler!r}")

    if resume_model_only:
        train_step = 0
        data_status = None
    else:
        optimizer, scheduler, train_step, data_status = FSDPCheckpoint.try_load_train_state(
            resume_from, optimizer, scheduler, fsdp_config,
        )
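
    # Data pipeline. PackedDataset concatenates samples from the YAML-configured dataset groups
    # into token-packed batches of roughly expected_num_tokens tokens, so the DataLoader uses
    # batch_size=1: each yielded item is already a full packed batch. data_status carries the
    # per-dataset / per-worker cursor so a resumed run continues where it stopped.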
    with open(data_args.dataset_config_file, "r") as stream:
        dataset_meta = yaml.safe_load(stream)
    print(dataset_meta)
    dataset_config = DataConfig(grouped_datasets=dataset_meta)
    if training_args.visual_und:
        dataset_config.vit_patch_size = model_args.vit_patch_size
        dataset_config.max_num_patch_per_side = model_args.vit_max_num_patch_per_side
    if training_args.visual_gen:
        vae_image_downsample = model_args.latent_patch_size * vae_config.downsample
        dataset_config.vae_image_downsample = vae_image_downsample
        dataset_config.max_latent_size = model_args.max_latent_size
        dataset_config.text_cond_dropout_prob = model_args.text_cond_dropout_prob
        dataset_config.vae_cond_dropout_prob = model_args.vae_cond_dropout_prob
        dataset_config.vit_cond_dropout_prob = model_args.vit_cond_dropout_prob
    train_dataset = PackedDataset(
        dataset_config,
        tokenizer=tokenizer,
        special_tokens=new_token_ids,
        local_rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        num_workers=data_args.num_workers,
        expected_num_tokens=training_args.expected_num_tokens,
        max_num_tokens_per_sample=data_args.max_num_tokens_per_sample,
        max_num_tokens=data_args.max_num_tokens,
        max_buffer_size=data_args.max_buffer_size,
        prefer_buffer_before=data_args.prefer_buffer_before,
        interpolate_pos=model_args.interpolate_pos,
        use_flex=training_args.use_flex,
        data_status=data_status,
    )
    train_dataset.set_epoch(data_args.data_seed)
    train_loader = DataLoader(
        train_dataset,
        batch_size=1,
        num_workers=data_args.num_workers,
        pin_memory=True,
        collate_fn=collate_wrapper(),
        drop_last=True,
        prefetch_factor=data_args.prefetch_factor,
    )

    if training_args.visual_gen:
        vae_model.to(device).eval()
    fsdp_model.train()
    ema_model.eval()
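
    # Training loop. Each iteration consumes one packed batch: the frozen VAE encodes target
    # images to latents under no_grad, the forward pass runs in bfloat16 autocast, and the CE /
    # MSE losses are combined with ce_weight / mse_weight.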
    start_time = time()
    logger.info(f"Training for {training_args.total_steps} steps, starting at {train_step}...")
    for curr_step, data in enumerate(train_loader, start=train_step):
        data = data.cuda(device).to_dict()
        data_indexes = data.pop('batch_data_indexes', None)
        ce_loss_weights = data.pop('ce_loss_weights', None)
        with torch.amp.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            if training_args.visual_gen:
                with torch.no_grad():
                    data['padded_latent'] = vae_model.encode(data.pop('padded_images'))
            loss_dict = fsdp_model(**data)

        loss = 0
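
        # Loss normalisation: each rank forms sum(local per-token losses) * world_size /
        # total_tokens, where total_tokens is the all-reduced global token count. Because
        # gradients are averaged (divided by world_size) across ranks, the effective objective
        # is sum(all per-token losses) / total_tokens, i.e. a true global per-token mean.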
        ce = loss_dict["ce"]
        if ce is not None:
            total_ce_tokens = torch.tensor(len(data['ce_loss_indexes']), device=device)
            dist.all_reduce(total_ce_tokens, op=dist.ReduceOp.SUM)
            if training_args.ce_loss_reweighting:
                ce = ce * ce_loss_weights
                total_ce_loss_weights = ce_loss_weights.sum()
                dist.all_reduce(total_ce_loss_weights, op=dist.ReduceOp.SUM)
                ce = ce.sum() * dist.get_world_size() / total_ce_loss_weights
            else:
                ce = ce.sum() * dist.get_world_size() / total_ce_tokens
            loss_dict["ce"] = ce.detach()
            loss = loss + ce * training_args.ce_weight
        else:
            assert not training_args.visual_und
            loss_dict["ce"] = torch.tensor(0, device=device)
            total_ce_tokens = torch.tensor(0, device=device)

        if training_args.visual_gen:
            mse = loss_dict["mse"]
            total_mse_tokens = torch.tensor(len(data['mse_loss_indexes']), device=device)
            dist.all_reduce(total_mse_tokens, op=dist.ReduceOp.SUM)
            mse = mse.mean(dim=-1).sum() * dist.get_world_size() / total_mse_tokens
            loss_dict["mse"] = mse.detach()
            loss = loss + mse * training_args.mse_weight
        else:
            assert not training_args.visual_gen
            loss_dict["mse"] = torch.tensor(0, device=device)
            total_mse_tokens = torch.tensor(0, device=device)

        if torch.isnan(loss):
            logger.warning(f"Step {curr_step}: Loss is NaN, skipping this step")
            continue

        optimizer.zero_grad()
        loss.backward()
        total_norm = fsdp_model.clip_grad_norm_(training_args.max_grad_norm)
        optimizer.step()
        scheduler.step()
        fsdp_ema_update(ema_model, fsdp_model, decay=training_args.ema)
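
        # Periodic logging: per-key losses are averaged over ranks, token / sample counts are
        # summed, peak GPU memory is reported as the max across ranks, and rank 0 writes to W&B.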
        if curr_step % training_args.log_every == 0:
            total_samples = torch.tensor(len(data['sample_lens']), device=device)
            dist.all_reduce(total_samples, op=dist.ReduceOp.SUM)

            torch.cuda.synchronize()
            end_time = time()
            steps_per_sec = training_args.log_every / (end_time - start_time)
            message = f"(step={curr_step:07d}) "
            wandb_log = {}
            for key, value in loss_dict.items():
                avg_loss = torch.tensor(value.item(), device=device)
                dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM)
                avg_loss = avg_loss.item() / dist.get_world_size()
                message += f"Train Loss {key}: {avg_loss:.4f}, "
                wandb_log[key] = avg_loss
            message += f"Train Steps/Sec: {steps_per_sec:.2f}, "
            logger.info(message)

            wandb_log['lr'] = optimizer.param_groups[0]['lr']
            wandb_log['total_mse_tokens'] = total_mse_tokens.item()
            wandb_log['total_ce_tokens'] = total_ce_tokens.item()
            wandb_log['total_norm'] = total_norm.item()
            wandb_log['total_samples'] = total_samples.item()

            mem_allocated = torch.tensor(torch.cuda.max_memory_allocated() / 1024**2, device=device)
            dist.all_reduce(mem_allocated, op=dist.ReduceOp.MAX)
            wandb_log['mem_allocated'] = mem_allocated.item()
            mem_cache = torch.tensor(torch.cuda.max_memory_reserved() / 1024**2, device=device)
            dist.all_reduce(mem_cache, op=dist.ReduceOp.MAX)
            wandb_log['mem_cache'] = mem_cache.item()

            if dist.get_rank() == 0:
                wandb.log(wandb_log, step=curr_step)
            start_time = time()
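
        # Record how far each (dataset, worker) pair has advanced, shaped as
        # {dataset_name: {worker_id: data_indexes}}, so checkpoints can store a resumable
        # data cursor.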
        if data_status is None:
            data_status = {}
        for item in data_indexes:
            if item['dataset_name'] not in data_status.keys():
                data_status[item['dataset_name']] = {}
            data_status[item['dataset_name']][item['worker_id']] = item['data_indexes']
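
        # Checkpointing: gather every rank's data cursor to rank 0, save model / EMA / optimizer /
        # scheduler state, then prune old step folders beyond max_checkpoints on rank 0.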
        if curr_step > 0 and curr_step % training_args.save_every == 0:
            if dist.get_rank() == 0:
                gather_list = [None] * dist.get_world_size()
            else:
                gather_list = None
            dist.gather_object(data_status, gather_list, dst=0)

            FSDPCheckpoint.fsdp_save_ckpt(
                ckpt_dir=training_args.checkpoint_dir,
                train_steps=curr_step,
                model=fsdp_model,
                ema_model=ema_model,
                optimizer=optimizer,
                scheduler=scheduler,
                logger=logger,
                fsdp_config=fsdp_config,
                data_status=gather_list
            )

            if dist.get_rank() == 0:
                try:
                    checkpoint_dirs = [d for d in os.listdir(training_args.checkpoint_dir)
                                       if os.path.isdir(os.path.join(training_args.checkpoint_dir, d))
                                       and d.isdigit()]
                    checkpoint_dirs.sort(key=int)
                    logger.info(f"Found {len(checkpoint_dirs)} checkpoint directories: {checkpoint_dirs}")

                    while len(checkpoint_dirs) > training_args.max_checkpoints:
                        oldest_checkpoint = checkpoint_dirs.pop(0)
                        old_path = os.path.join(training_args.checkpoint_dir, oldest_checkpoint)
                        shutil.rmtree(old_path)
                        logger.info(f"Deleted old checkpoint folder: {old_path}")

                    logger.info(f"Now keeping {len(checkpoint_dirs)} checkpoints (max: {training_args.max_checkpoints})")
                except Exception as e:
                    logger.warning(f"Failed to delete old checkpoints: {e}")
logger.info("Done!") |
|
|
if dist.get_rank() == 0: |
|
|
wandb.finish() |
|
|
dist.destroy_process_group() |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|