"""Helper function to configure MPT with MoEs."""
import inspect
from typing import Callable, Optional, Union
import torch
from packaging import version
from torch import distributed
from torch.distributed._tensor import DeviceMesh
from .layers_registry import ffns_with_megablocks
from .ffn import resolve_ffn_hidden_size
def create_process_group_ranks(ranks: tuple[int, ...]):
"""Creates a new distributed group.
Used in create_set_process_group and create_mod_process_group methods below.
This function is an alternative to `distributed.new_group(ranks)`.
Args:
ranks (tuple[int, ...]): Tuple of ranks of group members.
Returns:
A handle of distributed group that can be given to collective calls.
"""
    ranks_gather_list = [None for _ in range(distributed.get_world_size())]
    distributed.all_gather_object(ranks_gather_list, ranks)
    ranks_per_subgroup = list(set(ranks_gather_list))
    group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
        ranks_per_subgroup,
    )
    return group


def create_set_process_group(k: int):
"""Creates a new distributed group using sets of k GPUs.
For example, if you have 16 GPUs and input k=4, the resulting process groups
will have ranks:
process group 0 ranks: [ 0, 1, 2, 3]
process group 1 ranks: [ 4, 5, 6, 7]
process group 2 ranks: [ 8, 9, 10, 11]
process group 3 ranks: [12, 13, 14, 15]
Args:
k (int): Number of GPUs to use in set size.
Returns:
A handle of distributed group that can be given to collective calls.
"""
world_size = distributed.get_world_size()
if world_size % k != 0:
raise RuntimeError(f'world_size={world_size!r} must be divisible by k={k!r}.')
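    # Each rank joins the contiguous block of k ranks that contains it,
    # e.g. rank 6 with k=4 maps to start=4 and ranks=(4, 5, 6, 7).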
    start = distributed.get_rank() // k * k
    ranks = tuple(range(start, start + k))
    return create_process_group_ranks(ranks)


def get_megablocks_device_mesh(
    device_mesh_cfg: Optional[tuple[int, ...]],
    moe_world_size: int,
    world_size: int,
) -> DeviceMesh:
"""Helper function to get the device mesh for MegaBlocks MoE.
Args:
device_mesh_cfg (Optional[tuple[int, ...]]): The device mesh configuration specification.
moe_world_size (int): The MoE world size.
world_size (int): The world size.
Raises:
ValueError: If the device mesh configuration is not valid.
Returns:
The device mesh for MegaBlocks MoE.
"""
from torch.distributed._tensor.device_mesh import init_device_mesh
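    # The mesh is two-dimensional: a 'weight_parallel' dim of size world_size // moe_world_size
    # and an 'expert_parallel' dim of size moe_world_size.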
    if device_mesh_cfg is None or len(device_mesh_cfg) == 1:
        if device_mesh_cfg is not None:
            world_size = device_mesh_cfg[0]
        sharding_group_dim = world_size // moe_world_size
        device_mesh = init_device_mesh(
            'cuda',
            (sharding_group_dim, moe_world_size),
            mesh_dim_names=('weight_parallel', 'expert_parallel'),
        )
    else:
        raise ValueError(f'device_mesh_cfg={device_mesh_cfg!r} must be None or have length 1.')
    return device_mesh


def config_megablocks_moe_args(
    ffn_config: dict,
    d_model: int,
    expansion_ratio: Union[int, float],
    n_layers: int,
    get_device_mesh: Callable,
) -> dict:
"""Configures `ffn_config` for MegaBlocks MoE.
We prepare all necessary arguments for `megablocks.layers.arguments.Arguments` so that process
groups can be initialized and shared across all blocks in the network.
Args:
ffn_config (dict): FFN configuration before the MegaBlocks MoE is configured.
d_model (int): Hidden size of the network.
expansion_ratio (Union[int, float]): Expansion ratio in FFN.
n_layers (int): Number of blocks used in the network.
get_device_mesh (Callable): Function to get the device mesh. Takes in the device mesh config and the MoE world size.
Returns:
ffn_config (dict): FFN configuration with MegaBlocks MoE configured.
"""
try:
import megablocks
except:
raise RuntimeError('Requirements for MegaBlocks not installed; see install instructions in `README.md`.')
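    # Default the precision flags to off unless the caller has already set them; like the
    # rest of `ffn_config`, they are only forwarded if `Arguments` accepts them (see the
    # signature filter below).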
    ffn_config.setdefault('fp16', False)
    ffn_config.setdefault('bf16', False)
    ffn_config['num_layers'] = n_layers
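    # Pop fields that `megablocks.layers.arguments.Arguments` does not accept; they are
    # re-attached to the returned config at the end of this function.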
    ffn_type = ffn_config.pop('ffn_type')
    fc_type = ffn_config.pop('fc_type')
    ffn_act_fn = ffn_config.pop('ffn_act_fn', None)
    world_size = 1
    moe_world_size = ffn_config.pop('moe_world_size')
    device_mesh = None
    device_mesh_cfg = ffn_config.pop('device_mesh', None)
    if moe_world_size > 1:
        if version.parse(torch.__version__.split('.dev')[0]) < version.parse('2.2.0'):
            raise RuntimeError(
                f'MoE world size > 1 requires torch >= 2.2.0, but found torch version {torch.__version__}.',
            )
        world_size = distributed.get_world_size()
        if world_size < moe_world_size or world_size % moe_world_size != 0:
            raise ValueError(
                f'Invalid world size configuration: world_size={world_size!r} and moe_world_size={moe_world_size!r}',
            )
        device_mesh = get_device_mesh(
            device_mesh_cfg=device_mesh_cfg,
            moe_world_size=moe_world_size,
            world_size=world_size,
        )
        ffn_config['moe_expert_model_parallelism'] = True
        ffn_config['expert_parallel_group'] = device_mesh['expert_parallel'].get_group(0)
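
    # `lbl_process_group` selects which ranks the MoE load-balancing loss is averaged over:
    # a ready-made process group, the string 'expert_group' or 'global_group', or an int
    # giving a group size (anything else raises the ValueError below).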
    lbl_process_group = ffn_config.get('lbl_process_group', None)
    if lbl_process_group is not None:
        if lbl_process_group == 'expert_group':
            lbl_process_group = ffn_config['expert_parallel_group']
        elif lbl_process_group == 'global_group':
            lbl_process_group = distributed.group.WORLD
        elif isinstance(lbl_process_group, int):
            if lbl_process_group > 1:
                lbl_process_group = create_set_process_group(lbl_process_group)
            else:
                lbl_process_group = None
        elif not isinstance(lbl_process_group, distributed.ProcessGroup):
            raise ValueError(
                f'Unknown lbl_process_group={lbl_process_group!r}. Options are: '
                'none | a process group | ``expert_group`` | ``global_group`` | <GROUP_SIZE>.',
            )
        ffn_config['lbl_process_group'] = lbl_process_group

    ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio)
    ffn_config.setdefault('ffn_hidden_size', ffn_hidden_size)
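    # Keep only the keys that `megablocks.layers.arguments.Arguments` accepts, so unrelated
    # config entries do not break its constructor.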
    args_to_keep_in_ffn_config = inspect.signature(
        megablocks.layers.arguments.Arguments,
    ).parameters
    ffn_config = {
        k: v for k, v in ffn_config.items() if k in args_to_keep_in_ffn_config
    }
    args = megablocks.layers.arguments.Arguments(hidden_size=d_model, **ffn_config)
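    # Re-attach the constructed `args`, the device mesh, and the fields popped above so that
    # downstream block construction can read them from the returned config.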
    ffn_config['args'] = args
    ffn_config['device_mesh'] = device_mesh
    ffn_config['moe_world_size'] = moe_world_size
    ffn_config['ffn_type'] = ffn_type
    ffn_config['fc_type'] = fc_type
    ffn_config['ffn_act_fn'] = ffn_act_fn
    return ffn_config


def config_moe_args(
    ffn_config: dict,
    d_model: int,
    expansion_ratio: Union[int, float],
    n_layers: int,
) -> dict:
"""Configures `ffn_config` for MoE.
Args:
ffn_config (dict): FFN configuration before the MoE is configured.
d_model (int): Hidden size of the network.
expansion_ratio (int, float): Expansion ratio in FFN.
n_layers (int): Number of blocks used in the network.
Returns:
ffn_config (dict): FFN configuration with MoE configured.
"""
    if ffn_config['ffn_type'] in ffns_with_megablocks:
        return config_megablocks_moe_args(
            ffn_config=ffn_config,
            d_model=d_model,
            expansion_ratio=expansion_ratio,
            n_layers=n_layers,
            get_device_mesh=get_megablocks_device_mesh,
        )
    else:
        raise ValueError(f"Invalid ffn_type ({ffn_config['ffn_type']}).")