File size: 1,694 Bytes
8683813 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
# Copyright (c) OpenMMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module()
class WandbLoggerHook(LoggerHook):
    """Logger hook that forwards the runner's loggable tags to wandb.

    The ``wandb`` package must be installed; it is imported lazily when the
    hook is constructed so that the module itself can be imported without it.

    Args:
        init_kwargs (dict, optional): Keyword arguments passed verbatim to
            ``wandb.init`` in :meth:`before_run`. ``None`` or an empty dict
            results in ``wandb.init()`` being called without arguments.
        interval: Forwarded to ``LoggerHook.__init__``. Default: 10.
        ignore_last: Forwarded to ``LoggerHook.__init__``. Default: True.
        reset_flag: Forwarded to ``LoggerHook.__init__``. Default: False.
        commit (bool): Passed as ``commit=`` to every ``wandb.log`` call.
            Default: True.
        by_epoch: Forwarded to ``LoggerHook.__init__``. Default: True.
        with_step (bool): If True, the current iteration is passed to
            ``wandb.log`` as ``step=``. Otherwise it is stored in the logged
            dict under the ``'global_step'`` key and no explicit step is
            given. Default: True.
    """

    def __init__(self,
                 init_kwargs=None,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 commit=True,
                 by_epoch=True,
                 with_step=True):
        super(WandbLoggerHook, self).__init__(interval, ignore_last,
                                              reset_flag, by_epoch)
        self.import_wandb()
        self.init_kwargs = init_kwargs
        self.commit = commit
        self.with_step = with_step

    def import_wandb(self):
        """Import ``wandb`` and store the module on ``self.wandb``.

        Raises:
            ImportError: If ``wandb`` is not installed.
        """
        try:
            import wandb
        except ImportError:
            raise ImportError(
                'Please run "pip install wandb" to install wandb')
        self.wandb = wandb

    @master_only
    def before_run(self, runner):
        """Start the wandb run using the configured ``init_kwargs``."""
        super(WandbLoggerHook, self).before_run(runner)
        if self.wandb is None:
            self.import_wandb()
        # ``None`` and ``{}`` both mean "use wandb's defaults".
        init_kwargs = self.init_kwargs or {}
        self.wandb.init(**init_kwargs)

    @master_only
    def log(self, runner):
        """Send the current loggable tags to wandb.

        Nothing is sent when ``get_loggable_tags`` returns an empty result.
        """
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        if self.with_step:
            self.wandb.log(
                tags, step=self.get_iter(runner), commit=self.commit)
        else:
            # Let wandb manage its own step counter and record the training
            # iteration as an ordinary metric instead.
            tags['global_step'] = self.get_iter(runner)
            self.wandb.log(tags, commit=self.commit)

    @master_only
    def after_run(self, runner):
        """Close the wandb run.

        ``wandb.join`` is a deprecated alias of ``wandb.finish``; prefer
        ``finish`` and only fall back to ``join`` on wandb releases that
        predate it.
        """
        finish = getattr(self.wandb, 'finish', None)
        if finish is None:
            finish = self.wandb.join
        finish()
|