from ..dist_utils import allreduce_params
from .hook import HOOKS, Hook
|
|
@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN at
    the end of each epoch.

    Args:
        distributed (bool): Whether distributed training is used. Buffers are
            only synchronized when this is True. Defaults to True.
    """
|
    def __init__(self, distributed=True):
        self.distributed = distributed
|
    def after_epoch(self, runner):
        """All-reduce model buffers at the end of each epoch."""
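        # Synchronize buffers (e.g. BN running_mean / running_var) across all
        # ranks so every process ends the epoch with consistent values.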
        if self.distributed:
            allreduce_params(runner.model.buffers())
|
|
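# A minimal usage sketch, assuming the standard mmcv ``Runner`` and its
# ``custom_hooks`` config mechanism (not defined in this module):
#
#     # enabled via the training config
#     custom_hooks = [dict(type='SyncBuffersHook', distributed=True)]
#
#     # or registered on a runner instance directly
#     runner.register_hook(SyncBuffersHook(distributed=True))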