File size: 5,451 Bytes
8683813 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 |
# Copyright (c) OpenMMLab. All rights reserved.
import numbers
from abc import ABCMeta, abstractmethod
import numpy as np
import torch
from ..hook import Hook
class LoggerHook(Hook):
    """Base class for logger hooks.

    Subclasses implement :meth:`log`; this class decides *when* the log
    buffer is averaged, logged and cleared, and provides helpers that turn
    the runner state into loggable tags.

    Args:
        interval (int): Logging interval (every k iterations).
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
        reset_flag (bool): Whether to clear the output buffer after logging.
        by_epoch (bool): Whether EpochBasedRunner is used.
    """

    # NOTE(review): on Python 3 a ``__metaclass__`` class attribute does not
    # set the metaclass, so ``@abstractmethod`` below is not enforced at
    # instantiation time. Kept as-is to avoid changing which subclasses can
    # be instantiated.
    __metaclass__ = ABCMeta

    def __init__(self,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        self.interval = interval
        self.ignore_last = ignore_last
        self.reset_flag = reset_flag
        self.by_epoch = by_epoch

    @abstractmethod
    def log(self, runner):
        """Emit the content of ``runner.log_buffer.output``.

        Must be overridden by concrete logger hooks.
        """
        pass

    @staticmethod
    def is_scalar(val, include_np=True, include_torch=True):
        """Tell the input variable is a scalar or not.

        Args:
            val: Input variable.
            include_np (bool): Whether include 0-d np.ndarray as a scalar.
            include_torch (bool): Whether include a single-element
                torch.Tensor (0-d, or e.g. shape ``(1,)``) as a scalar.

        Returns:
            bool: True or False.
        """
        if isinstance(val, numbers.Number):
            return True
        if include_np and isinstance(val, np.ndarray) and val.ndim == 0:
            return True
        # ``numel()`` is used instead of ``len()``: ``len()`` raises
        # TypeError on 0-d tensors and only inspects the first dimension,
        # so it would accept e.g. a (1, 5) tensor as a scalar.
        if (include_torch and isinstance(val, torch.Tensor)
                and val.numel() == 1):
            return True
        return False

    def get_mode(self, runner):
        """Return the mode ('train' or 'val') the current log belongs to.

        Raises:
            ValueError: If ``runner.mode`` is neither 'train' nor 'val'.
        """
        if runner.mode == 'train':
            # Output without a 'time' entry is reported as 'val' even while
            # the runner is in train mode (presumably evaluation results
            # written during the train workflow -- confirm against callers).
            if 'time' in runner.log_buffer.output:
                mode = 'train'
            else:
                mode = 'val'
        elif runner.mode == 'val':
            mode = 'val'
        else:
            raise ValueError(f"runner mode should be 'train' or 'val', "
                             f'but got {runner.mode}')
        return mode

    def get_epoch(self, runner):
        """Return the 1-based epoch number for the current log.

        Raises:
            ValueError: If ``runner.mode`` is neither 'train' nor 'val'.
        """
        if runner.mode == 'train':
            epoch = runner.epoch + 1
        elif runner.mode == 'val':
            # normal val mode
            # runner.epoch += 1 has been done before val workflow
            epoch = runner.epoch
        else:
            raise ValueError(f"runner mode should be 'train' or 'val', "
                             f'but got {runner.mode}')
        return epoch

    def get_iter(self, runner, inner_iter=False):
        """Get the current training iteration step.

        Args:
            runner: Runner exposing ``iter`` and ``inner_iter``.
            inner_iter (bool): If True and ``self.by_epoch`` is set, return
                ``runner.inner_iter + 1``; otherwise ``runner.iter + 1``.

        Returns:
            int: 1-based iteration index.
        """
        if self.by_epoch and inner_iter:
            current_iter = runner.inner_iter + 1
        else:
            current_iter = runner.iter + 1
        return current_iter

    def get_lr_tags(self, runner):
        """Build learning-rate tags from ``runner.current_lr()``.

        Only the first entry of each learning-rate sequence is logged. A
        dict result yields one ``learning_rate/<name>`` tag per key.

        Returns:
            dict: Mapping from tag name to learning rate.
        """
        tags = {}
        lrs = runner.current_lr()
        if isinstance(lrs, dict):
            for name, value in lrs.items():
                tags[f'learning_rate/{name}'] = value[0]
        else:
            tags['learning_rate'] = lrs[0]
        return tags

    def get_momentum_tags(self, runner):
        """Build momentum tags from ``runner.current_momentum()``.

        Only the first entry of each momentum sequence is logged. A dict
        result yields one ``momentum/<name>`` tag per key.

        Returns:
            dict: Mapping from tag name to momentum.
        """
        tags = {}
        momentums = runner.current_momentum()
        if isinstance(momentums, dict):
            for name, value in momentums.items():
                tags[f'momentum/{name}'] = value[0]
        else:
            tags['momentum'] = momentums[0]
        return tags

    def get_loggable_tags(self,
                          runner,
                          allow_scalar=True,
                          allow_text=False,
                          add_mode=True,
                          tags_to_skip=('time', 'data_time')):
        """Collect the entries of the log buffer that should be logged.

        Args:
            runner: Runner whose ``log_buffer.output`` is read.
            allow_scalar (bool): Whether scalar values (see
                :meth:`is_scalar`) are kept.
            allow_text (bool): Whether ``str`` values are kept.
            add_mode (bool): Whether to prefix each buffer key with
                ``'<mode>/'`` as returned by :meth:`get_mode`.
            tags_to_skip (tuple[str]): Buffer keys that are never logged.

        Returns:
            dict: Selected buffer entries plus the learning-rate and
            momentum tags (which are added without a mode prefix).
        """
        tags = {}
        for var, val in runner.log_buffer.output.items():
            if var in tags_to_skip:
                continue
            if self.is_scalar(val) and not allow_scalar:
                continue
            if isinstance(val, str) and not allow_text:
                continue
            if add_mode:
                var = f'{self.get_mode(runner)}/{var}'
            tags[var] = val
        tags.update(self.get_lr_tags(runner))
        tags.update(self.get_momentum_tags(runner))
        return tags

    def before_run(self, runner):
        """Set ``reset_flag`` on the last LoggerHook in ``runner.hooks``.

        Scanning in reverse and stopping at the first match means only the
        last registered logger hook is told to clear the output buffer.
        """
        for hook in runner.hooks[::-1]:
            if isinstance(hook, LoggerHook):
                hook.reset_flag = True
                break

    def before_epoch(self, runner):
        """Clear the log buffer at the start of each epoch."""
        runner.log_buffer.clear()  # clear logs of last epoch

    def after_train_iter(self, runner):
        """Average the buffer on logging iterations, then log if ready."""
        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
            runner.log_buffer.average(self.interval)
        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
            runner.log_buffer.average(self.interval)
        elif self.end_of_epoch(runner) and not self.ignore_last:
            # not precise but more stable
            runner.log_buffer.average(self.interval)

        if runner.log_buffer.ready:
            self.log(runner)
            if self.reset_flag:
                runner.log_buffer.clear_output()

    def after_train_epoch(self, runner):
        """Log any output that is ready at the end of a training epoch."""
        if runner.log_buffer.ready:
            self.log(runner)
            if self.reset_flag:
                runner.log_buffer.clear_output()

    def after_val_epoch(self, runner):
        """Average the whole buffer and log it after a validation epoch."""
        runner.log_buffer.average()
        self.log(runner)
        if self.reset_flag:
            runner.log_buffer.clear_output()
|