import os
import shutil

from mmcv import Config
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.seed import seed_everything
import wandb

from risk_biased.utils.callbacks import (
    DrawCallbackParams,
    HistogramCallback,
    PlotTrajCallback,
    SwitchTrainingModeCallback,
)
from risk_biased.utils.load_model import load_from_config
from scripts.scripts_utils.load_utils import get_config


def create_log_dir():
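    """Create a "logs" directory next to this script if it does not exist and return its path."""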
    working_dir = os.path.dirname(os.path.realpath(__file__))
    log_dir = os.path.join(working_dir, "logs")
    if not os.path.exists(log_dir):
        os.mkdir(log_dir)
    return log_dir


def save_log_config(cfg: Config, predictor):
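    """Dump the resolved config and the files listed in `cfg.files_to_log` to the
    wandb run directory, and optionally watch the predictor's weights and gradients.
    """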
    # Save and log the resolved config (not just a copy of the config file,
    # because settings may have been overridden by argparse).
    log_config_path = os.path.join(wandb.run.dir, "learning_config.py")
    cfg.dump(log_config_path)
    wandb.save(log_config_path)
    # Copy the files listed in the config into the current wandb run dir and log them.
    for file_name in cfg.files_to_log:
        dest_path = os.path.join(wandb.run.dir, os.path.basename(file_name))
        shutil.copy(file_name, dest_path)
        wandb.save(dest_path)
    if cfg.log_weights_and_grads:
        wandb.watch(predictor, log="all", log_freq=100)


def create_callbacks(cfg: Config, log_dir: str, is_interaction: bool) -> list:
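    """Build the PyTorch Lightning callbacks: checkpointing (last-run dir and wandb
    run dir), optional histogram/trajectory plotting in the non-interaction setting,
    optional early stopping, and the switch from CVAE training to bias training
    after `cfg.num_epochs_cvae` epochs.
    """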
    # Save checkpoints of the last run in a dedicated directory.
    last_run_checkpoint_callback = ModelCheckpoint(
        monitor="val/minfde/prior",
        mode="min",
        filename="epoch={epoch:02d}-step={step}-val_minfde_prior={val/minfde/prior:.2f}",
        auto_insert_metric_name=False,
        dirpath=os.path.join(log_dir, "checkpoints_last_run"),
        save_last=True,
    )
    # Save checkpoints of the current run in the current wandb run dir.
    checkpoint_callback = ModelCheckpoint(
        monitor="val/minfde/prior",
        mode="min",
        filename="epoch={epoch:02d}-step={step}-val_minfde_prior={val/minfde/prior:.2f}",
        auto_insert_metric_name=False,
        dirpath=wandb.run.dir,
        save_last=True,
    )
    callbacks = [
        last_run_checkpoint_callback,
        checkpoint_callback,
    ]
    if not is_interaction:
        histogram_callback = HistogramCallback(
            params=DrawCallbackParams.from_config(cfg),
            n_samples=1000,
        )
        plot_callback = PlotTrajCallback(
            params=DrawCallbackParams.from_config(cfg), n_samples=10
        )
        callbacks.append(histogram_callback)
        callbacks.append(plot_callback)
    if cfg.early_stopping:
        # Note: the negative min_delta makes the criterion lenient, so small
        # regressions of the monitored metric still count as improvement.
        early_stopping_callback = EarlyStopping(
            monitor="val/minfde/prior",
            min_delta=-0.2,
            patience=5,
            verbose=False,
            mode="min",
        )
        callbacks.append(early_stopping_callback)
    switch_mode_callback = SwitchTrainingModeCallback(
        switch_at_epoch=cfg.num_epochs_cvae
    )
    callbacks.append(switch_mode_callback)
    return callbacks


def get_trainer(cfg: Config, logger: WandbLogger, callbacks: list) -> Trainer:
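    """Create a PyTorch Lightning Trainer covering both the CVAE and the bias training epochs."""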
    num_epochs = cfg.num_epochs_cvae + cfg.num_epochs_bias
    return Trainer(
        gpus=cfg.gpus,
        max_epochs=num_epochs,
        logger=logger,
        val_check_interval=float(cfg.val_check_interval_epoch),
        accumulate_grad_batches=cfg.accumulate_grad_batches,
        callbacks=callbacks,
    )


def main(is_interaction: bool = False):
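    """Load the config, model, and dataloaders, set up wandb logging and callbacks, then train."""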
    log_dir = create_log_dir()
    cfg = get_config(log_dir, is_interaction)
    predictor, dataloaders, cfg = load_from_config(cfg)
    if cfg.seed is not None:
        seed_everything(cfg.seed)
    save_log_config(cfg, predictor)
    logger = WandbLogger(
        project=cfg.project, log_model=True, save_dir=log_dir, id=wandb.run.id
    )
    callbacks = create_callbacks(cfg, log_dir, is_interaction)
    trainer = get_trainer(cfg, logger, callbacks)
    trainer.fit(
        predictor,
        train_dataloaders=dataloaders.train_dataloader(),
        val_dataloaders=dataloaders.val_dataloader(),
    )
    wandb.finish()


if __name__ == "__main__":
    main(is_interaction=True)