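"""Data-parallel helpers with dictionary support: gathering dict outputs
from multiple GPUs, and user-controlled scattering with asynchronous,
stream-based device copies."""
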
import collections.abc

import torch
import torch.cuda as cuda
import torch.nn as nn
from torch.nn.parallel._functions import Gather

__all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to']


def async_copy_to(obj, dev, main_stream=None):
    """Recursively copies tensors in `obj` (a tensor, mapping, or sequence)
    to device `dev` using non-blocking transfers. If `main_stream` is given,
    the copy is recorded on it so the caching allocator does not reuse the
    memory prematurely."""
    if torch.is_tensor(obj):
        v = obj.cuda(dev, non_blocking=True)
        if main_stream is not None:
            v.record_stream(main_stream)
        return v
    elif isinstance(obj, collections.abc.Mapping):
        return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()}
    # Note: strings are sequences too; exclude them so they fall through to
    # the identity case below instead of being split into characters.
    elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [async_copy_to(o, dev, main_stream) for o in obj]
    else:
        return obj


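# Example usage of `async_copy_to` (a minimal sketch; the device index and
# the batch layout below are illustrative):
#
#   batch = {'img': torch.randn(2, 3, 224, 224), 'label': torch.zeros(2).long()}
#   batch_gpu = async_copy_to(batch, 0)  # nested dict/list structure is kept
#   assert batch_gpu['img'].is_cuda

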
def dict_gather(outputs, target_device, dim=0):
    """
    Gathers tensors from different GPUs onto a specified device
    (-1 means the CPU), with support for dicts and (nested) sequences.
    """
    def gather_map(outputs):
        out = outputs[0]
        if torch.is_tensor(out):
            # `Gather` cannot handle 0-dim tensors, so promote them to 1-dim.
            if out.dim() == 0:
                outputs = [o.unsqueeze(0) for o in outputs]
            return Gather.apply(target_device, dim, *outputs)
        elif out is None:
            return None
        elif isinstance(out, collections.abc.Mapping):
            return {k: gather_map([o[k] for o in outputs]) for k in out}
        elif isinstance(out, collections.abc.Sequence):
            return type(out)(map(gather_map, zip(*outputs)))
    return gather_map(outputs)


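# Example of what `dict_gather` produces (a sketch assuming two CUDA devices;
# the per-GPU output dicts are hypothetical):
#
#   out0 = {'loss': torch.tensor([0.5], device='cuda:0')}
#   out1 = {'loss': torch.tensor([0.7], device='cuda:1')}
#   merged = dict_gather([out0, out1], target_device=0)
#   # merged['loss'] is tensor([0.5, 0.7]) on cuda:0

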
class DictGatherDataParallel(nn.DataParallel):
    """`nn.DataParallel` whose gather step also understands dict outputs."""

    def gather(self, outputs, output_device):
        return dict_gather(outputs, output_device, dim=self.dim)


class UserScatteredDataParallel(DictGatherDataParallel):
    """`DataParallel` variant where the user pre-splits the batch: `forward`
    expects a single argument, a list with one sub-batch per device, so
    scatter only performs the asynchronous device copies."""

    def scatter(self, inputs, kwargs, device_ids):
        assert len(inputs) == 1
        inputs = inputs[0]
        # Copy each user-provided sub-batch to its device on a side stream.
        inputs = _async_copy_stream(inputs, device_ids)
        # Wrap each sub-batch so every replica receives it as a single
        # positional argument.
        inputs = [[i] for i in inputs]
        assert len(kwargs) == 0
        kwargs = [{} for _ in range(len(inputs))]

        return inputs, kwargs


def user_scattered_collate(batch):
    """Identity collate function: returns the list of samples unchanged so
    that `UserScatteredDataParallel.scatter` can split them across devices."""
    return batch


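# Putting the pieces together (a sketch, not a full training script; the
# `dataset`, `net`, and `device_ids` names are placeholders):
#
#   loader = torch.utils.data.DataLoader(
#       dataset, batch_size=len(device_ids), shuffle=True,
#       collate_fn=user_scattered_collate)  # keeps per-device samples unsplit
#   model = UserScatteredDataParallel(net.cuda(), device_ids=device_ids)
#   for batch in loader:
#       out = model(batch)  # scatter() copies one sub-batch to each GPU

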
def _async_copy(inputs, device_ids):
    """Copies one input per device using each device's current stream."""
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    for i, dev in zip(inputs, device_ids):
        with cuda.device(dev):
            outputs.append(async_copy_to(i, dev))

    return tuple(outputs)


def _async_copy_stream(inputs, device_ids):
    """Copies one input per device on a dedicated background stream, so the
    host-to-device transfers can overlap with compute on the main stream."""
    nr_devs = len(device_ids)
    assert type(inputs) in (tuple, list)
    assert len(inputs) == nr_devs

    outputs = []
    streams = [_get_stream(d) for d in device_ids]
    for i, dev, stream in zip(inputs, device_ids, streams):
        with cuda.device(dev):
            main_stream = cuda.current_stream()
            with cuda.stream(stream):
                outputs.append(async_copy_to(i, dev, main_stream=main_stream))
            # Block the main stream until the copy stream finishes, so later
            # kernels only run once the data has fully arrived.
            main_stream.wait_stream(stream)

    return outputs


"""Adapted from: torch/nn/parallel/_functions.py""" |
|
|
|
_streams = None |
|
|
|
|
|
def _get_stream(device): |
|
"""Gets a background stream for copying between CPU and GPU""" |
|
global _streams |
|
if device == -1: |
|
return None |
|
if _streams is None: |
|
_streams = [None] * cuda.device_count() |
|
if _streams[device] is None: _streams[device] = cuda.Stream(device) |
|
return _streams[device] |
|
|
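if __name__ == '__main__':
    # Tiny smoke test (a sketch assuming at least one CUDA device; the tensor
    # shapes and dict keys are illustrative only).
    if torch.cuda.is_available():
        data = {'x': torch.randn(4, 8), 'ids': [1, 2, 3, 4]}
        data_gpu = async_copy_to(data, 0)
        print(data_gpu['x'].device)  # expected: cuda:0
        print(data_gpu['ids'])       # non-tensor leaves pass through unchanged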