Spaces:
Build error
Build error
from __future__ import division | |
from __future__ import unicode_literals | |
import torch | |
def get_param_buffer_for_ema(model,
                             update_buffer=False,
                             required_buffers=('running_mean', 'running_var')):
    """
    Collect the tensors of `model` that should be tracked by an EMA.

    Args:
        model: Object exposing `parameters()` and `named_buffers()`
            (e.g. a `torch.nn.Module`).
        update_buffer: If True, also append the buffers selected by
            `required_buffers` after the parameters.
        required_buffers: Iterable of substrings (or a single string); a
            buffer is selected when its name contains any of them.

    Returns:
        A list holding every parameter with `requires_grad` set, followed
        (when `update_buffer` is True) by the selected buffers in
        `named_buffers()` order.
    """
    all_param_buffer = [p for p in model.parameters() if p.requires_grad]
    if update_buffer:
        # A bare string would otherwise be iterated character by character
        # and match almost every buffer name.
        if isinstance(required_buffers, str):
            required_buffers = (required_buffers,)
        all_param_buffer.extend(
            value for key, value in model.named_buffers()
            if any(buffer_name in key for buffer_name in required_buffers))
    return all_param_buffer
class ExponentialMovingAverage:
    """
    Maintains (exponential) moving average of a set of parameters.

    The averages live in `shadow_params`; `store`/`copy_to`/`restore` allow
    temporarily swapping them into the live parameters and back.
    """

    def __init__(self, parameters, decay, use_num_updates=True):
        """
        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the result
                of `model.parameters()`.
            decay: The exponential decay; must lie in [0, 1].
            use_num_updates: Whether to use number of updates when computing
                averages.

        Raises:
            ValueError: If `decay` is outside [0, 1].
        """
        if decay < 0.0 or decay > 1.0:
            raise ValueError('Decay must be between 0 and 1')
        self.decay = decay
        # None disables the update-count based cap applied in `update`.
        self.num_updates = 0 if use_num_updates else None
        self.shadow_params = [p.clone().detach() for p in parameters]
        self.collected_params = []

    def update(self, parameters):
        """
        Update currently maintained parameters.

        Call this every time the parameters are updated, such as the result
        of the `optimizer.step()` call.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; usually the same
                set of parameters used to initialize this object.
        """
        decay = self.decay
        if self.num_updates is not None:
            self.num_updates += 1
            # Cap the decay by (1 + n) / (10 + n), so early updates use a
            # smaller decay than the configured one.
            decay = min(decay,
                        (1 + self.num_updates) / (10 + self.num_updates))
        one_minus_decay = 1.0 - decay
        with torch.no_grad():
            for s_param, param in zip(self.shadow_params, parameters):
                # In place: shadow <- decay * shadow + (1 - decay) * param.
                s_param.sub_(one_minus_decay * (s_param - param))

    def copy_to(self, parameters):
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to
                be updated with the stored moving averages.
        """
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    def store(self, parameters):
        """
        Save the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to
                be temporarily stored.
        """
        # Detach first so the backups are plain value copies that are not
        # tracked by autograd.
        self.collected_params = [param.detach().clone()
                                 for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.

        Useful to validate the model with EMA parameters without affecting
        the original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters. The stored copies are released
        afterwards, so restoring again without a new `store` call leaves
        the parameters untouched.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to
                be updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
        # Release the backups but keep the attribute, so the object returns
        # to the state `__init__` leaves it in.
        self.collected_params = []