# gene-hoi-denoising/utils/dist_util.py
"""
Helpers for distributed training.
"""
import socket
import torch as th
import torch.distributed as dist
# Change this to reflect your cluster layout.
# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3
# Device index chosen by setup_dist(); read by dev() to select the CUDA device.
used_device = 0


def setup_dist(device=0):
    """
    Set up a distributed process group.

    Note: the actual process-group initialization below is commented out;
    currently this only records which device index this process should use.
    """
    global used_device
    used_device = device
    if dist.is_initialized():
        return
    # os.environ["CUDA_VISIBLE_DEVICES"] = str(device)  # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
    # comm = MPI.COMM_WORLD
    # backend = "gloo" if not th.cuda.is_available() else "nccl"
    # if backend == "gloo":
    #     hostname = "localhost"
    # else:
    #     hostname = socket.gethostbyname(socket.getfqdn())
    # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
    # os.environ["RANK"] = str(comm.rank)
    # os.environ["WORLD_SIZE"] = str(comm.size)
    # port = comm.bcast(_find_free_port(), root=used_device)
    # os.environ["MASTER_PORT"] = str(port)
    # dist.init_process_group(backend=backend, init_method="env://")
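
# A minimal sketch (hypothetical, not part of the original module) of how the
# commented-out initialization above could be re-enabled under torchrun, which
# exports RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for init_method="env://":
#
#     backend = "nccl" if th.cuda.is_available() else "gloo"
#     dist.init_process_group(backend=backend, init_method="env://")
#     th.cuda.set_device(used_device)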


def dev():
    """
    Get the device to use for torch.distributed.
    """
    global used_device
    if th.cuda.is_available() and used_device >= 0:
        return th.device(f"cuda:{used_device}")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch checkpoint from disk. (Currently a thin wrapper around
    th.load; extra keyword arguments are forwarded unchanged.)
    """
    return th.load(path, **kwargs)
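
# Example (sketch): place a checkpoint's tensors on this rank's device while
# loading. The path "model.pt" is a placeholder, not a file from this repo:
#
#     state = load_state_dict("model.pt", map_location=dev())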


def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            dist.broadcast(p, 0)
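
# Example (sketch, assumes `model` is a torch.nn.Module built on every rank):
# make all ranks start from rank 0's weights before training.
#
#     sync_params(model.parameters())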


def _find_free_port():
    # Create the socket outside the try block so `s` is always defined
    # when the finally clause closes it.
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]
    finally:
        s.close()


def reduce_mean(tensor, nprocs):
    """
    Average a tensor across all processes (all-reduce sum, divided by nprocs).
    """
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt = rt / nprocs
    return rt
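
# Example (sketch, assumes the process group is initialized and `loss` is a
# scalar tensor computed on this rank): average the loss across ranks for logging.
#
#     avg_loss = reduce_mean(loss.detach(), dist.get_world_size())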