File size: 5,156 Bytes
de5c30a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
"""Helper functions for computing parameter counts for MPT model.
Use these helpers instead of a generic `sum(p.numel() for p in self.parameters())`
style computation, which does not account for MoE parameter sharding.
The helper functions in this file account for MoE parameter
sharding in the parameter count calculation. The functions below
calculate the total parameter count and the active parameter count.
Note: MPT has both n_total_params and n_active_params methods.
"""
from typing import Union
from torch import Tensor, nn
from torch.distributed._tensor import DTensor
from .layers_registry import ffns_with_megablocks
def module_n_params(module: nn.Module) -> int:
    """Count the parameters registered directly on ``module``.

    Parameters that belong to child modules are not included, because
    ``parameters`` is called with ``recurse=False``.

    Args:
        module (nn.Module): Module whose own parameters are counted.

    Returns:
        An int for the total number of elements across the module's direct
        parameters (0 if it has none).
    """
    own_params = module.parameters(recurse=False)
    return sum(param.numel() for param in own_params)
def _dtensor_safe_check_numel(tensor: Union[Tensor, DTensor]) -> int:
    """Return the element count of ``tensor`` as held locally.

    For a ``DTensor`` the count is taken from its ``_local_tensor`` (the
    locally stored piece) rather than the DTensor itself. Plain tensors
    report their own ``numel()``.
    """
    local_view = tensor._local_tensor if isinstance(tensor, DTensor) else tensor
    return local_view.numel()
def megablocks_n_total_params(mpt_model) -> int:
    """Calculates the number of parameters in a MegaBlocks enabled MPT model.

    MoE experts are sharded across workers. This function scans for MegaBlocks
    MLP modules (``SparseMLP`` / ``MLP``) and multiplies the locally held
    element count of their expert weights (``w1``, ``w2`` and, when present,
    ``v1``) by the MoE world size. Every other module contributes its own
    directly registered parameters.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            parameters is calculated.

    Returns:
        An int for the total number of parameters in this MPT model.
    """
    import megablocks

    # Expert weights are split across ``moe_world_size`` ranks. A missing or
    # None entry is treated as 1 (no expert parallelism) instead of crashing
    # with ``TypeError`` on ``int * None``.
    moe_world_size = mpt_model.config.ffn_config.get('moe_world_size') or 1
    expert_mlp_types = (
        megablocks.layers.mlp.SparseMLP,
        megablocks.layers.mlp.MLP,
    )

    n_total_params = 0
    for module in mpt_model.modules():
        if isinstance(module, expert_mlp_types):
            n_local_expert = _dtensor_safe_check_numel(module.w1)
            n_local_expert += _dtensor_safe_check_numel(module.w2)
            # ``v1`` only exists on some MegaBlocks MLP variants.
            if hasattr(module, 'v1'):
                n_local_expert += _dtensor_safe_check_numel(module.v1)
            n_total_params += n_local_expert * moe_world_size
        else:
            n_total_params += module_n_params(module)
    return n_total_params
def megablocks_n_active_params(mpt_model) -> int:
    """Calculates the number of active parameters in a MegaBlocks enabled MPT.

    This requires we calculate the number of elements per expert and
    multiply this by top k. Each rank holds
    ``moe_num_experts / moe_world_size`` local experts. For each expert
    weight (``w1``, ``w2`` and, when present, ``v1``) the local element count
    is divided by that number and multiplied by ``moe_top_k``. Every other
    module contributes its own directly registered parameters in full.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            active parameters is calculated.

    Returns:
        An int for the active number of parameters in this MPT model.
    """
    import megablocks

    ffn_config = mpt_model.config.ffn_config
    moe_num_experts = ffn_config.get('moe_num_experts', 1)
    # A missing or None world size is treated as 1 (no expert parallelism)
    # instead of failing with ``TypeError``.
    moe_world_size = ffn_config.get('moe_world_size') or 1
    moe_top_k = ffn_config.get('moe_top_k', 1)
    expert_mlp_types = (
        megablocks.layers.mlp.SparseMLP,
        megablocks.layers.mlp.MLP,
    )

    def _active_numel(weight) -> int:
        # Same quantity as
        # ``int(numel / (moe_num_experts / moe_world_size) * moe_top_k)``,
        # but in exact integer arithmetic, so float rounding cannot truncate
        # a large count to one below the true value.
        n_local = _dtensor_safe_check_numel(weight)
        return n_local * moe_world_size * moe_top_k // moe_num_experts

    n_active_params = 0
    for module in mpt_model.modules():
        if isinstance(module, expert_mlp_types):
            n_active_params += _active_numel(module.w1)
            n_active_params += _active_numel(module.w2)
            # ``v1`` only exists on some MegaBlocks MLP variants.
            if hasattr(module, 'v1'):
                n_active_params += _active_numel(module.v1)
        else:
            n_active_params += module_n_params(module)
    return n_active_params
def mpt_get_total_params(mpt_model) -> int:
    """Calculates the total parameter count of an MPT model.

    For FFN types listed in ``ffns_with_megablocks`` the count is delegated to
    ``megablocks_n_total_params``, which accounts for expert sharding.
    Otherwise it is the sum of ``numel()`` over all model parameters.

    Note: Must be called before model parameters are sharded by FSDP.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the total number
            of parameters is calculated.

    Returns:
        An int for the total number of parameters in this MPT model.
    """
    if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
        return megablocks_n_total_params(mpt_model)
    return sum(p.numel() for p in mpt_model.parameters())
def mpt_get_active_params(mpt_model) -> int:
    """Calculates the active parameter count of an MPT model.

    For FFN types listed in ``ffns_with_megablocks`` the count comes from
    ``megablocks_n_active_params`` (per-expert sizes scaled by top k).
    Otherwise it is the sum of ``numel()`` over all model parameters. When
    ``tie_word_embeddings`` is false, the element count of the input
    embedding weight (``wte.weight``) is subtracted from the result.

    Note: Must be called before model parameters are sharded by FSDP.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            active parameters is calculated.

    Returns:
        An int for the active number of parameters in this MPT model.
    """
    if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
        params = megablocks_n_active_params(mpt_model)
    else:
        params = sum(p.numel() for p in mpt_model.parameters())
    if not mpt_model.model.transformer.config.tie_word_embeddings:
        # NOTE(review): presumably the untied input embedding is excluded
        # because it acts as a lookup table rather than a matmul. When tied,
        # the same weight is also the output projection and so stays counted.
        # Confirm intent.
        params -= _dtensor_safe_check_numel(
            mpt_model.model.transformer.wte.weight,
        )
    return params