import math
import warnings
from collections.abc import Sequence
from functools import partial
from typing import Optional, Tuple, Union

import torch
from torch import nn

from .norm import NORM_CLASS_REGISTRY

def torch_default_param_init_fn_(module: nn.Module, verbose: int = 0, **kwargs):
    del kwargs
    if verbose > 1:
        warnings.warn("Initializing network using module's reset_parameters attribute")
    if hasattr(module, 'reset_parameters'):
        module.reset_parameters()

def fused_init_helper_(module: nn.Module, init_fn_):
    # Initialize each slice of a fused weight separately, using the
    # (dim, splits) metadata stashed on the module as ``_fused``.
    _fused = getattr(module, '_fused', None)
    if _fused is None:
        raise RuntimeError('Internal logic error')
    dim, splits = _fused
    splits = (0, *splits, module.weight.size(dim))
    for s, e in zip(splits[:-1], splits[1:]):
        slice_indices = [slice(None)] * module.weight.ndim
        slice_indices[dim] = slice(s, e)
        # index with a tuple of slices (indexing with a list of slices is deprecated)
        init_fn_(module.weight[tuple(slice_indices)])
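# Illustrative sketch (not part of this module): a layer that fuses the Q, K and V
# projections into one weight could expose the metadata consumed above as, e.g.,
#
#   fused_qkv = nn.Linear(d_model, 3 * d_model)
#   fused_qkv._fused = (0, (d_model, 2 * d_model))  # split dim 0 at d_model and 2 * d_model
#
# fused_init_helper_ would then call init_fn_ separately on rows [0, d_model),
# [d_model, 2*d_model) and [2*d_model, 3*d_model), so each projection gets its own
# init statistics. Only the attribute name ``_fused`` is taken from this file; the
# example values are hypothetical.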

def generic_param_init_fn_(module: nn.Module, init_fn_, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, **kwargs):
    del kwargs
    if verbose > 1:
        warnings.warn('If model has bias parameters they are initialized to 0.')

    # Resolve the divisor applied to weights of layers flagged as ``_is_residual``;
    # True selects the default sqrt(2 * n_layers) scaling.
    if init_div_is_residual is False:
        div_is_residual = 1.0
    elif init_div_is_residual is True:
        div_is_residual = math.sqrt(2 * n_layers)
    elif isinstance(init_div_is_residual, (float, int)):
        div_is_residual = init_div_is_residual
    elif isinstance(init_div_is_residual, str) and init_div_is_residual.isnumeric():
        div_is_residual = float(init_div_is_residual)
    else:
        raise ValueError(f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}')
    if init_div_is_residual is not False:
        if verbose > 1:
            warnings.warn(
                f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. '
                f'Set `init_div_is_residual: false` in init config to disable this.'
            )
    if isinstance(module, nn.Linear):
        if hasattr(module, '_fused'):
            fused_init_helper_(module, init_fn_)
        else:
            init_fn_(module.weight)
        if module.bias is not None:
            torch.nn.init.zeros_(module.bias)
        if init_div_is_residual is not False and getattr(module, '_is_residual', False):
            with torch.no_grad():
                module.weight.div_(div_is_residual)
    elif isinstance(module, nn.Embedding):
        if emb_init_std is not None:
            std = emb_init_std
            if std == 0:
                warnings.warn('Embedding layer initialized to 0.')
            emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std)
            if verbose > 1:
                warnings.warn(f'Embedding layer initialized using normal distribution with mean=0 and std={std!r}.')
        elif emb_init_uniform_lim is not None:
            lim = emb_init_uniform_lim
            if isinstance(lim, Sequence):
                if len(lim) > 2:
                    raise ValueError(f'Uniform init requires a min and a max limit. User input: {lim}.')
                if lim[0] == lim[1]:
                    warnings.warn(f'Embedding layer initialized to {lim[0]}.')
            else:
                if lim == 0:
                    warnings.warn('Embedding layer initialized to 0.')
                lim = [-lim, lim]
            a, b = lim
            emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b)
            if verbose > 1:
                warnings.warn(f'Embedding layer initialized using uniform distribution in range {lim}.')
        else:
            emb_init_fn_ = init_fn_
        emb_init_fn_(module.weight)
    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
        if verbose > 1:
            warnings.warn('Norm weights are set to 1. If norm layer has a bias it is initialized to 0.')
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.ones_(module.weight)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)
    elif isinstance(module, nn.MultiheadAttention):
        if module._qkv_same_embed_dim:
            assert module.in_proj_weight is not None
            assert module.q_proj_weight is None and module.k_proj_weight is None and module.v_proj_weight is None
            assert d_model is not None
            # in_proj_weight stacks the Q, K and V projections; initialize each block separately.
            _d = d_model
            splits = (0, _d, 2 * _d, 3 * _d)
            for s, e in zip(splits[:-1], splits[1:]):
                init_fn_(module.in_proj_weight[s:e])
        else:
            assert module.q_proj_weight is not None and module.k_proj_weight is not None and module.v_proj_weight is not None
            assert module.in_proj_weight is None
            init_fn_(module.q_proj_weight)
            init_fn_(module.k_proj_weight)
            init_fn_(module.v_proj_weight)
        if module.in_proj_bias is not None:
            torch.nn.init.zeros_(module.in_proj_bias)
        if module.bias_k is not None:
            torch.nn.init.zeros_(module.bias_k)
        if module.bias_v is not None:
            torch.nn.init.zeros_(module.bias_v)
        init_fn_(module.out_proj.weight)
        if init_div_is_residual is not False and getattr(module.out_proj, '_is_residual', False):
            with torch.no_grad():
                module.out_proj.weight.div_(div_is_residual)
        if module.out_proj.bias is not None:
            torch.nn.init.zeros_(module.out_proj.bias)
    else:
        # Only error out if the module has parameters of its own that were not handled above.
        for _ in module.parameters(recurse=False):
            raise NotImplementedError(f'{module.__class__.__name__} parameters are not initialized by param_init_fn.')
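# Usage note (a sketch of the intended usage inferred from this file, not asserted
# by it): generic_param_init_fn_ handles one module at a time, so it is meant to be
# applied with ``nn.Module.apply``. A projection feeding the residual stream opts in
# to the extra down-scaling by carrying an ``_is_residual`` attribute, e.g.
#
#   out_proj = nn.Linear(d_model, d_model)
#   out_proj._is_residual = True  # weight divided by div_is_residual after init_fn_
#   model.apply(partial(generic_param_init_fn_, init_fn_=nn.init.xavier_normal_, n_layers=n_layers))
#
# Here ``model``, ``d_model`` and ``n_layers`` stand for the caller's own objects.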

def _normal_init_(std, mean=0.0):
    return partial(torch.nn.init.normal_, mean=mean, std=std)

def _normal_param_init_fn_(module: nn.Module, std: float, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, **kwargs):
    del kwargs
    init_fn_ = _normal_init_(std=std)
    if verbose > 1:
        warnings.warn(f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}')
    generic_param_init_fn_(module=module, init_fn_=init_fn_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)

def baseline_param_init_fn_(module: nn.Module, init_std: Optional[float], n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, **kwargs):
    del kwargs
    if init_std is None:
        raise ValueError("You must set model.init_config['init_std'] to a float value to use the default initialization scheme.")
    _normal_param_init_fn_(module=module, std=init_std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)

def small_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, **kwargs):
    del kwargs
    std = math.sqrt(2 / (5 * d_model))
    _normal_param_init_fn_(module=module, std=std, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
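# Worked example (added for intuition, not from the original source): with
# d_model = 2048 this gives std = sqrt(2 / (5 * 2048)) = sqrt(2 / 10240) ≈ 0.014.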

def neox_param_init_fn_(module: nn.Module, n_layers: int, d_model: int, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, verbose: int = 0, **kwargs):
    """From section 2.3.1 of GPT-NeoX-20B: An Open-Source Autoregressive Language Model, Black et al. (2022).

    See https://github.com/EleutherAI/gpt-neox/blob/9610391ab319403cef079b438edd016a2443af54/megatron/model/init_functions.py#L151
    and https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/transformer.py
    """
    del kwargs
    residual_div = n_layers / math.sqrt(10)
    if verbose > 1:
        warnings.warn(f'setting init_div_is_residual to {residual_div}')
    small_param_init_fn_(module=module, d_model=d_model, n_layers=n_layers, init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)
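# Worked example (added for intuition, not from the original source): with
# n_layers = 32 this gives residual_div = 32 / sqrt(10) ≈ 10.1. Because the value is
# numeric, generic_param_init_fn_ uses it directly as the divisor for ``_is_residual``
# weights in place of the default sqrt(2 * n_layers).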

def kaiming_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, fan_mode: str = 'fan_in', init_nonlinearity: str = 'leaky_relu', verbose: int = 0, **kwargs):
    del kwargs
    if verbose > 1:
        warnings.warn(
            f'Using nn.init.kaiming_uniform_ init fn with parameters: '
            f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )
    kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
    generic_param_init_fn_(module=module, init_fn_=kaiming_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)

def kaiming_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, fan_mode: str = 'fan_in', init_nonlinearity: str = 'leaky_relu', verbose: int = 0, **kwargs):
    del kwargs
    if verbose > 1:
        warnings.warn(
            f'Using nn.init.kaiming_normal_ init fn with parameters: '
            f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}'
        )
    kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, nonlinearity=init_nonlinearity)
    generic_param_init_fn_(module=module, init_fn_=kaiming_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)

def xavier_uniform_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, verbose: int = 0, **kwargs):
    del kwargs
    xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain)
    if verbose > 1:
        warnings.warn(f'Using torch.nn.init.xavier_uniform_ init fn with parameters: gain={init_gain}')
    generic_param_init_fn_(module=module, init_fn_=xavier_uniform_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)

def xavier_normal_param_init_fn_(module: nn.Module, n_layers: int, d_model: Optional[int] = None, init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, verbose: int = 0, **kwargs):
    del kwargs
    xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain)
    if verbose > 1:
        warnings.warn(f'Using torch.nn.init.xavier_normal_ init fn with parameters: gain={init_gain}')
    generic_param_init_fn_(module=module, init_fn_=xavier_normal_, d_model=d_model, n_layers=n_layers, init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, verbose=verbose)

MODEL_INIT_REGISTRY = {
    'default_': torch_default_param_init_fn_,
    'baseline_': baseline_param_init_fn_,
    'kaiming_uniform_': kaiming_uniform_param_init_fn_,
    'kaiming_normal_': kaiming_normal_param_init_fn_,
    'neox_init_': neox_param_init_fn_,
    'small_init_': small_param_init_fn_,
    'xavier_uniform_': xavier_uniform_param_init_fn_,
    'xavier_normal_': xavier_normal_param_init_fn_,
}
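
if __name__ == '__main__':
    # Minimal smoke test (an illustrative sketch, not part of the original module).
    # Because of the relative import above, run it as a module, e.g.
    # ``python -m <package>.param_init_fns``; the package and file name are assumptions.
    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    model[1]._is_residual = True  # second projection takes the sqrt(2 * n_layers) scaling
    model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], n_layers=2, d_model=8))
    print('residual-layer weight std:', model[1].weight.std().item())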