tvar-demo-test / dist.py
michellemoorre's picture
Initial commit
6c4dee3
"""
Helpers for distributed training.
"""
import os
import socket
import torch as th
import torch.distributed as dist
from torch.distributed import barrier, is_initialized, broadcast
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3
import datetime
import os
import socket
from contextlib import closing
def find_free_port() -> int:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
def check_if_port_open(port: int) -> bool:
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
try:
s.bind(("", port))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return True
except OSError:
return False
def initialized():
return dist.is_initialized()
def finalize():
if dist.is_initialized():
dist.destroy_process_group()
def initialize():
is_mpirun = not (
"RANK" in os.environ
and "WORLD_SIZE" in os.environ
and "MASTER_ADDR" in os.environ
and "MASTER_PORT" in os.environ
)
if is_mpirun:
from mpi4py import MPI
import subprocess
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
world_size = comm.Get_size()
master_addr = None
master_port = None
if rank == 0:
hostname_cmd = ["hostname -I"]
result = subprocess.check_output(hostname_cmd, shell=True)
master_addr = result.decode("utf-8").split()[0]
base_port = os.environ.get(
"MASTER_PORT", "29500"
) # TORCH_DISTRIBUTED_DEFAULT_PORT
if check_if_port_open(int(base_port)):
master_port = base_port
else:
master_port = find_free_port()
master_addr = comm.bcast(master_addr, root=0)
master_port = comm.bcast(master_port, root=0)
# Determine local rank by assuming hostnames are unique
proc_name = MPI.Get_processor_name()
all_procs = comm.allgather(proc_name)
local_rank = sum([i == proc_name for i in all_procs[:rank]])
uniq_proc_names = set(all_procs)
host_rank = sorted(uniq_proc_names).index(proc_name)
os.environ["LOCAL_RANK"] = str(local_rank)
os.environ["HOST_RANK"] = str(host_rank)
os.environ["NUM_HOSTS"] = str(len(uniq_proc_names))
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = str(master_port)
os.environ["OMP_NUM_THREADS"] = "1"
# Initialize torch distributed
backend = "gloo" if not th.cuda.is_available() else "nccl"
dist.init_process_group(backend=backend, timeout=datetime.timedelta(0, 3600))
th.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0')))
if is_mpirun and dist.get_rank() == 0:
print("Distributed setup")
print("LOCAL_RANK", os.environ['LOCAL_RANK'])
print("HOST_RANK", os.environ['HOST_RANK'])
print("NUM_HOSTS", os.environ['NUM_HOSTS'])
print("WORLD_SIZE", os.environ['WORLD_SIZE'])
def local_host_gather(data):
from mpi4py import MPI
comm = MPI.COMM_WORLD
host_rank = os.environ["HOST_RANK"]
all_data = comm.allgather((host_rank, data))
return [d[1] for d in all_data if d[0] == host_rank]
def in_distributed_mode():
return dist is not None
def is_master():
return get_rank() == 0
def is_local_master():
return get_local_rank() == 0
def get_rank():
return dist.get_rank() if in_distributed_mode() else 0
def get_local_rank():
return int(os.environ["LOCAL_RANK"])
def worker_host_idx():
return int(os.environ["HOST_RANK"])
def num_hosts():
return int(os.environ['NUM_HOSTS'])
def get_world_size():
return dist.get_world_size() if in_distributed_mode() else 1
def gpu_visible_device_list():
return str(dist.get_rank()) if in_distributed_mode() else None
def get_device():
"""
Get the device to use for torch.distributed.
"""
if th.cuda.is_available():
return th.device("cuda")
return th.device("cpu")
def sync_params(params):
"""
Synchronize a sequence of Tensors across ranks from rank 0.
"""
for p in params:
with th.no_grad():
dist.broadcast(p, 0)
def print0(*args, **kwargs):
if get_rank() == 0:
print(*args, **kwargs)
def allreduce(t: th.Tensor, async_op=False):
if dist.is_initialized():
if not t.is_cuda:
cu = t.detach().cuda()
ret = dist.all_reduce(cu, async_op=async_op)
t.copy_(cu.cpu())
else:
ret = dist.all_reduce(t, async_op=async_op)
return ret
return None
def allgather(t: th.Tensor, cat=True):
if dist.is_initialized():
if not t.is_cuda:
t = t.cuda()
ls = [th.empty_like(t) for _ in range(get_world_size())]
dist.all_gather(ls, t)
else:
ls = [t]
if cat:
ls = th.cat(ls, dim=0)
return ls