Spaces:
Runtime error
Runtime error
"""
Helpers for distributed training: process-group setup (torchrun- or
mpirun-launched), rank/device queries, and simple collectives.
"""
import os | |
import socket | |
import torch as th | |
import torch.distributed as dist | |
from torch.distributed import barrier, is_initialized, broadcast | |
# Cluster layout: edit GPUS_PER_NODE to match your hardware. The intended
# mapping is GPU index = rank % GPUS_PER_NODE.
# NOTE(review): initialize() pins devices via the LOCAL_RANK env var and does
# not read this constant.
GPUS_PER_NODE: int = 8
# Presumably the retry budget for distributed setup; it is not referenced in
# this module's visible code.
SETUP_RETRY_COUNT: int = 3
import datetime | |
import os | |
import socket | |
from contextlib import closing | |
def find_free_port() -> int:
    """Return a TCP port number the OS reports as currently unused.

    The probe binds a throwaway IPv4 socket to port 0 on all interfaces, reads
    back the port the kernel assigned, and closes the socket. Another process
    may still grab the port before the caller uses it.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("", 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _host, port = sock.getsockname()
    return port
def check_if_port_open(port: int) -> bool:
    """Report whether ``port`` can be bound right now on all IPv4 interfaces.

    Args:
        port: TCP port number to probe.

    Returns:
        True if a fresh socket could bind to the port, False if binding raised
        ``OSError`` (e.g. the port is already in use).
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        try:
            probe.bind(("", port))
            probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        except OSError:
            return False
        return True
def initialized():
    """Return True if the default torch.distributed process group exists."""
    return is_initialized()
def finalize():
    """Destroy the default process group; does nothing if none is active."""
    if not dist.is_initialized():
        return
    dist.destroy_process_group()
def _setup_env_from_mpi():
    # Derive torch.distributed rendezvous env vars (and host layout) from MPI.
    from mpi4py import MPI
    import subprocess

    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    world_size = comm.Get_size()

    master_addr = None
    master_port = None
    if rank == 0:
        # First address printed by `hostname -I` (a Linux hostname flag).
        # Passed as an argv list so no shell is involved.
        result = subprocess.check_output(["hostname", "-I"])
        master_addr = result.decode("utf-8").split()[0]
        base_port = os.environ.get(
            "MASTER_PORT", "29500"
        )  # TORCH_DISTRIBUTED_DEFAULT_PORT
        if check_if_port_open(int(base_port)):
            master_port = base_port
        else:
            master_port = find_free_port()
    master_addr = comm.bcast(master_addr, root=0)
    master_port = comm.bcast(master_port, root=0)

    # Determine local rank by assuming hostnames are unique per node: count the
    # lower-numbered ranks that report the same processor name.
    proc_name = MPI.Get_processor_name()
    all_procs = comm.allgather(proc_name)
    local_rank = sum(name == proc_name for name in all_procs[:rank])
    uniq_proc_names = set(all_procs)
    host_rank = sorted(uniq_proc_names).index(proc_name)

    os.environ["LOCAL_RANK"] = str(local_rank)
    os.environ["HOST_RANK"] = str(host_rank)
    os.environ["NUM_HOSTS"] = str(len(uniq_proc_names))
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    os.environ["OMP_NUM_THREADS"] = "1"


def initialize():
    """
    Initialize the default torch.distributed process group.

    If RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are all present in the
    environment (e.g. a torchrun-style launch), they are used as-is.
    Otherwise the job is assumed to be launched with mpirun. In that case
    those variables, plus LOCAL_RANK, HOST_RANK, NUM_HOSTS and
    OMP_NUM_THREADS, are first derived through mpi4py.

    The NCCL backend is used when CUDA is available, and this process is then
    pinned to GPU ``LOCAL_RANK`` (default 0). Otherwise the gloo backend runs
    on CPU. The process-group timeout is one hour.
    """
    required = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")
    is_mpirun = not all(key in os.environ for key in required)
    if is_mpirun:
        _setup_env_from_mpi()

    # Initialize torch distributed
    use_cuda = th.cuda.is_available()
    backend = "nccl" if use_cuda else "gloo"
    dist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=3600))
    if use_cuda:
        # Only pin a GPU when one exists: set_device raises on CPU-only hosts,
        # which are exactly the hosts that got the gloo fallback above.
        th.cuda.set_device(int(os.environ.get("LOCAL_RANK", "0")))

    if is_mpirun and dist.get_rank() == 0:
        print("Distributed setup")
        print("LOCAL_RANK", os.environ['LOCAL_RANK'])
        print("HOST_RANK", os.environ['HOST_RANK'])
        print("NUM_HOSTS", os.environ['NUM_HOSTS'])
        print("WORLD_SIZE", os.environ['WORLD_SIZE'])
def local_host_gather(data):
    """Collect ``data`` from every MPI rank whose HOST_RANK matches this one's.

    Collective call: performs an MPI ``allgather`` over COMM_WORLD of
    ``(HOST_RANK, data)`` pairs. HOST_RANK values are compared as the raw
    environment strings (``initialize`` sets them under mpirun).

    Returns:
        List of ``data`` payloads from same-host ranks, in the order returned
        by ``allgather``.

    Raises:
        KeyError: if HOST_RANK is not set in the environment.
    """
    from mpi4py import MPI

    my_host = os.environ["HOST_RANK"]
    tagged = MPI.COMM_WORLD.allgather((my_host, data))
    return [payload for host, payload in tagged if host == my_host]
def in_distributed_mode() -> bool:
    """Return True when torch.distributed is available and its default
    process group has been initialized.

    The rank/world-size helpers use this to fall back to single-process
    defaults. ``dist is not None`` cannot serve that purpose: ``dist`` is an
    imported module and is never None.
    """
    return dist.is_available() and dist.is_initialized()
def is_master():
    """Return True when this process's global rank (per ``get_rank``) is 0."""
    rank = get_rank()
    return rank == 0
def is_local_master():
    """Return True when the node-local rank (LOCAL_RANK, via
    ``get_local_rank``) is 0."""
    local_rank = get_local_rank()
    return local_rank == 0
def get_rank():
    """Global rank of this process, or 0 when ``in_distributed_mode()`` is
    false."""
    if not in_distributed_mode():
        return 0
    return dist.get_rank()
def get_local_rank():
    """Node-local rank parsed from the LOCAL_RANK environment variable.

    Raises:
        KeyError: if LOCAL_RANK is not set.
        ValueError: if LOCAL_RANK is not an integer string.
    """
    raw_value = os.environ["LOCAL_RANK"]
    return int(raw_value)
def worker_host_idx():
    """Index of this process's host, parsed from the HOST_RANK env var.

    Raises:
        KeyError: if HOST_RANK is not set.
    """
    raw_value = os.environ["HOST_RANK"]
    return int(raw_value)
def num_hosts():
    """Number of distinct hosts, parsed from the NUM_HOSTS env var.

    Raises:
        KeyError: if NUM_HOSTS is not set.
    """
    raw_value = os.environ["NUM_HOSTS"]
    return int(raw_value)
def get_world_size():
    """Number of processes in the default group, or 1 when
    ``in_distributed_mode()`` is false."""
    if in_distributed_mode():
        return dist.get_world_size()
    return 1
def gpu_visible_device_list():
    """Global rank as a string when ``in_distributed_mode()`` is true,
    otherwise None."""
    if not in_distributed_mode():
        return None
    return str(dist.get_rank())
def get_device():
    """
    Device for torch.distributed work: the unindexed ``cuda`` device when CUDA
    is available, otherwise ``cpu``.
    """
    device_kind = "cuda" if th.cuda.is_available() else "cpu"
    return th.device(device_kind)
def sync_params(params):
    """
    Overwrite each tensor in ``params`` in place with rank 0's values.

    Tensors are broadcast one at a time from source rank 0, each under
    ``th.no_grad()`` so parameter leaves can be written in place. This needs an
    initialized process group unless ``params`` is empty.
    """
    for tensor in params:
        with th.no_grad():
            dist.broadcast(tensor, src=0)
def print0(*args, **kwargs):
    """Forward all arguments to ``print``, but only on global rank 0."""
    if get_rank() != 0:
        return
    print(*args, **kwargs)
def allreduce(t: th.Tensor, async_op=False):
    """
    All-reduce ``t`` in place across ranks with ``dist.all_reduce``'s default
    reduction op (SUM).

    When no process group is initialized this is a no-op that returns None.
    Under the NCCL backend, which only handles CUDA tensors, a CPU tensor is
    staged through a temporary CUDA copy and the result is copied back into
    ``t``. Other backends reduce ``t`` directly; that includes gloo, which
    ``initialize`` selects when CUDA is unavailable.

    Args:
        t: Tensor to reduce; modified in place.
        async_op: Forwarded to ``dist.all_reduce``.

    Returns:
        The handle returned by ``dist.all_reduce``. That is None for
        synchronous calls and a work object when ``async_op`` is True. For a
        staged CPU tensor the work object has already been waited on, because
        the copy-back needs the finished result.
    """
    if not dist.is_initialized():
        return None
    if t.is_cuda or dist.get_backend() != "nccl":
        return dist.all_reduce(t, async_op=async_op)
    staged = t.detach().cuda()
    work = dist.all_reduce(staged, async_op=async_op)
    if work is not None:
        # An async collective may still be running; without this wait the
        # copy-back below would read pre-reduction values from ``staged``.
        work.wait()
    t.copy_(staged.cpu())
    return work
def allgather(t: th.Tensor, cat=True):
    """
    Gather ``t`` from every rank.

    When no process group is initialized the result is built from ``t`` alone.
    Under the NCCL backend a CPU tensor is first moved to CUDA, because NCCL
    only handles CUDA tensors, so the gathered results are on CUDA. Other
    backends gather ``t`` where it is; that includes gloo, which ``initialize``
    selects when CUDA is unavailable. ``dist.all_gather`` expects every rank to
    pass a tensor of the same shape and dtype.

    Args:
        t: This rank's contribution.
        cat: If True, concatenate the per-rank tensors along dim 0; otherwise
            return them as a list ordered by rank.

    Returns:
        A tensor when ``cat`` is True, else a list of tensors.
    """
    if dist.is_initialized():
        if not t.is_cuda and dist.get_backend() == "nccl":
            t = t.cuda()
        gathered = [th.empty_like(t) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, t)
    else:
        gathered = [t]
    return th.cat(gathered, dim=0) if cat else gathered