"""Helper functions for computing parameter counts for MPT model.

Use if generic `sum(p.numel() for p in self.parameters())`
style computation does not account for MoE parameter sharding.
The helper functions in this file account for MoE parameter
sharding in the parameter count calculation. The functions below
calculate the total parameter count and the active parameter count.
Note: MPT has both n_total_params and n_active_params methods.
"""
from typing import Union
from torch import Tensor, nn
from torch.distributed._tensor import DTensor
from .layers_registry import ffns_with_megablocks

def module_n_params(module: nn.Module) -> int:
    """Gets the number of parameters in this module excluding child modules.

    Args:
        module (nn.Module): Module of which we get the number of parameters.

    Returns:
        An int for the number of parameters in this module.
    """
    n_params = 0
    for p in module.parameters(recurse=False):
        n_params += p.numel()
    return n_params

def _dtensor_safe_check_numel(tensor: Union[Tensor, DTensor]) -> int:
    if isinstance(tensor, DTensor):
        tensor = tensor._local_tensor
    return tensor.numel()

def megablocks_n_total_params(mpt_model) -> int:
    """Calculates the number of parameters in a MegaBlocks enabled MPT model.

    MoE experts are sharded across workers, so a naive count only sees the
    local shard. This function scans for MegaBlocks modules and multiplies
    their expert parameter counts by the MoE world size to recover the
    global total.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            parameters is calculated.

    Returns:
        An int for the total number of parameters in this MPT model.
    """
    import megablocks
    # Assumption: default to 1 (no expert sharding) if moe_world_size is unset.
    moe_world_size = mpt_model.config.ffn_config.get('moe_world_size', 1)
    n_total_params = 0
    for module in mpt_model.modules():
        if isinstance(module, (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)):
            n_w1 = _dtensor_safe_check_numel(module.w1)
            n_total_params += n_w1 * moe_world_size
            n_w2 = _dtensor_safe_check_numel(module.w2)
            n_total_params += n_w2 * moe_world_size
            if hasattr(module, 'v1'):
                n_v1 = _dtensor_safe_check_numel(module.v1)
                n_total_params += n_v1 * moe_world_size
        else:
            n_total_params += module_n_params(module)
    return n_total_params
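
# Illustrative arithmetic for megablocks_n_total_params (hypothetical numbers):
# with moe_world_size=2 and a local w1 shard of 4,194,304 elements on each
# rank, the loop above counts 4,194,304 * 2 = 8,388,608 parameters for w1,
# i.e. the full, unsharded expert weight.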

def megablocks_n_active_params(mpt_model) -> int:
    """Calculates the number of active parameters in a MegaBlocks enabled MPT.

    Because only the top-k routed experts run per token, the active count is
    the number of elements per expert multiplied by the top-k value.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            active parameters is calculated.

    Returns:
        An int for the active number of parameters in this MPT model.
    """
    import megablocks
    moe_num_experts = mpt_model.config.ffn_config.get('moe_num_experts', 1)
    # Assumption: default to 1 (no expert sharding) if moe_world_size is unset.
    moe_world_size = mpt_model.config.ffn_config.get('moe_world_size', 1)
    local_experts = moe_num_experts / moe_world_size
    moe_top_k = mpt_model.config.ffn_config.get('moe_top_k', 1)
    n_active_params = 0
    for module in mpt_model.modules():
        if isinstance(module, (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)):
            n_w1 = _dtensor_safe_check_numel(module.w1)
            n_active_params += int(n_w1 / local_experts * moe_top_k)
            n_w2 = _dtensor_safe_check_numel(module.w2)
            n_active_params += int(n_w2 / local_experts * moe_top_k)
            if hasattr(module, 'v1'):
                n_v1 = _dtensor_safe_check_numel(module.v1)
                n_active_params += int(n_v1 / local_experts * moe_top_k)
        else:
            n_active_params += module_n_params(module)
    return n_active_params
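
# Illustrative arithmetic for megablocks_n_active_params (hypothetical
# numbers): with moe_num_experts=8 and moe_world_size=2, each rank holds
# local_experts = 4 experts. With moe_top_k=2 and a local w1 shard of
# 4,194,304 elements, the loop above counts int(4,194,304 / 4 * 2) =
# 2,097,152 active parameters for w1, since only the top-k routed experts
# participate in each forward pass.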

def mpt_get_total_params(mpt_model) -> int:
    """Calculates the total parameter count of an MPT model.

    Note: Must be called before model parameters are sharded by FSDP.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the total number
            of parameters is calculated.

    Returns:
        An int for the total number of parameters in this MPT model.
    """
    if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
        return megablocks_n_total_params(mpt_model)
    else:
        return sum(p.numel() for p in mpt_model.parameters())

def mpt_get_active_params(mpt_model) -> int:
    """Calculates the total parameter count of an MPT model.

    Note: Must be called before model parameters are sharded by FSDP.

    Args:
        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
            active parameters is calculated.

    Returns:
        An int for the active number of parameters in this MPT model.
    """
    if mpt_model.config.ffn_config['ffn_type'] in ffns_with_megablocks:
        params = megablocks_n_active_params(mpt_model)
    else:
        params = sum(p.numel() for p in mpt_model.parameters())
    if not mpt_model.model.transformer.config.tie_word_embeddings:
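        # Embedding lookups do not contribute matmul FLOPs, so the untied
        # input embedding weight is excluded from the active-parameter count.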
        params -= _dtensor_safe_check_numel(mpt_model.model.transformer.wte.weight)
    return params
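
# Example usage (a minimal, illustrative sketch; `build_mpt_model` below is a
# hypothetical constructor standing in for however the ComposerMPTCausalLM is
# built in your setup). Both helpers must run before FSDP shards the weights:
#
#     mpt_model = build_mpt_model(config)  # hypothetical
#     n_total = mpt_get_total_params(mpt_model)
#     n_active = mpt_get_active_params(mpt_model)
#     print(f'total params: {n_total:,} | active params: {n_active:,}')
#
# For a dense (non-MoE) model the two counts differ only by the untied input
# embedding subtracted in mpt_get_active_params.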