# Source: lhzstar, commit 6bc94ac ("initial commits"), 1.15 kB.
import collections

import torch
# Module-level slots that ``data_parallel_workaround`` overwrites on every
# call with the per-device outputs and model replicas of that call.
# NOTE(review): they appear to exist only to keep those objects referenced
# until the next call (e.g. so they are not collected early) — confirm.
_output_ref = _replicas_ref = None
def data_parallel_workaround(model, *input):
    """Run ``model`` on ``input`` split across every visible CUDA device.

    This is a manual version of the ``torch.nn.DataParallel`` forward pass.
    The model is replicated onto each device, and the positional arguments
    are scattered across devices. The replicas then run in parallel, and
    their outputs are gathered onto the first device.

    If no CUDA device is visible, the model is called directly on the
    arguments. Previously this case failed with an ``IndexError``.

    Args:
        model: the module to run (replicated onto each device).
        *input: positional arguments forwarded to ``model``. Tensor
            arguments are split across devices by
            ``torch.nn.parallel.scatter``.

    Returns:
        The gathered model output, placed on device 0. When no CUDA device
        is available, this is the direct result of ``model(*input)``.
    """
    global _output_ref
    global _replicas_ref
    device_ids = list(range(torch.cuda.device_count()))
    if not device_ids:
        # Nothing to replicate or scatter onto: run the model as-is rather
        # than crash on ``device_ids[0]`` below.
        return model(*input)
    output_device = device_ids[0]
    replicas = torch.nn.parallel.replicate(model, device_ids)
    # ``input`` is the tuple of positional args (arg_0, arg_1, ...).
    # scatter() returns one tuple of args per device, each holding that
    # device's chunk of every tensor argument.
    inputs = torch.nn.parallel.scatter(input, device_ids)
    # scatter() can yield fewer per-device tuples than devices (e.g. a batch
    # smaller than the device count); only keep replicas that get work.
    replicas = replicas[:len(inputs)]
    outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
    y_hat = torch.nn.parallel.gather(outputs, output_device)
    # Hold module-level references to this call's outputs and replicas until
    # the next call. NOTE(review): presumably this keeps them alive for the
    # backward pass — confirm.
    _output_ref = outputs
    _replicas_ref = replicas
    return y_hat
class ValueWindow:
    """Sliding window holding the most recent ``window_size`` values.

    Appending beyond capacity evicts the oldest value. The window exposes
    the sum, count and mean of the values it currently holds.
    """

    def __init__(self, window_size=100):
        """Create an empty window.

        Args:
            window_size: maximum number of recent values kept; must be >= 1.

        Raises:
            ValueError: if ``window_size`` is less than 1.
        """
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self._window_size = window_size
        # A bounded deque drops the oldest entry in O(1) on append. The old
        # list-slice approach copied the list on every append and never
        # trimmed when window_size == 1, because -(1 - 1) == 0 selects the
        # whole list.
        self._values = collections.deque(maxlen=window_size)

    def append(self, x):
        """Add ``x`` as the newest value, evicting the oldest if full."""
        self._values.append(x)

    @property
    def sum(self):
        """Sum of the values currently in the window (0 when empty)."""
        return sum(self._values)

    @property
    def count(self):
        """Number of values currently in the window."""
        return len(self._values)

    @property
    def average(self):
        """Mean of the values in the window; 0 when the window is empty."""
        # max(1, ...) avoids division by zero on an empty window.
        return self.sum / max(1, self.count)

    def reset(self):
        """Discard all values, keeping the configured window size."""
        self._values.clear()