|
|
|
from ...dist_utils import master_only |
|
from ..hook import HOOKS |
|
from .base import LoggerHook |
|
|
|
|
|
@HOOKS.register_module() |
|
class WandbLoggerHook(LoggerHook): |
|
|
|
def __init__(self, |
|
init_kwargs=None, |
|
interval=10, |
|
ignore_last=True, |
|
reset_flag=False, |
|
commit=True, |
|
by_epoch=True, |
|
with_step=True): |
|
super(WandbLoggerHook, self).__init__(interval, ignore_last, |
|
reset_flag, by_epoch) |
|
self.import_wandb() |
|
self.init_kwargs = init_kwargs |
|
self.commit = commit |
|
self.with_step = with_step |
|
|
|
def import_wandb(self): |
|
try: |
|
import wandb |
|
except ImportError: |
|
raise ImportError( |
|
'Please run "pip install wandb" to install wandb') |
|
self.wandb = wandb |
|
|
|
@master_only |
|
def before_run(self, runner): |
|
super(WandbLoggerHook, self).before_run(runner) |
|
if self.wandb is None: |
|
self.import_wandb() |
|
if self.init_kwargs: |
|
self.wandb.init(**self.init_kwargs) |
|
else: |
|
self.wandb.init() |
|
|
|
@master_only |
|
def log(self, runner): |
|
tags = self.get_loggable_tags(runner) |
|
if tags: |
|
if self.with_step: |
|
self.wandb.log( |
|
tags, step=self.get_iter(runner), commit=self.commit) |
|
else: |
|
tags['global_step'] = self.get_iter(runner) |
|
self.wandb.log(tags, commit=self.commit) |
|
|
|
@master_only |
|
def after_run(self, runner): |
|
self.wandb.join() |
|
|