gene-hoi-denoising / common /abstract_pl.py
meow
init
d6d3a5b
raw
history blame
6.25 kB
import time
import numpy as np
import pytorch_lightning as pl
import torch
import torch.optim as optim
import common.pl_utils as pl_utils
from common.comet_utils import log_dict
from common.pl_utils import avg_losses_cpu, push_checkpoint_metric
from common.xdict import xdict
class AbstractPL(pl.LightningModule):
    """Base LightningModule with shared train/val/test bookkeeping.

    This class reads several attributes that it never assigns itself:
    ``num_vis_train``, ``num_vis_val``, ``vis_fns`` and ``renderer``.
    Subclasses must set them, and must implement
    ``forward(inputs, targets, meta_info, mode)`` where ``mode`` is one of
    ``"train"``, ``"test"`` or ``"vis"``:

    * ``"train"``: the result is indexed with ``["loss"]``, which must be a
      mapping from loss name to tensor.
    * ``"test"``: the result is unpacked as a 2-tuple ``(out, loss)``.
    * ``"vis"``: the result is handed unchanged to every callable in
      ``vis_fns``.
    """

    def __init__(
        self,
        args,
        push_images_fn,
        tracked_metric,
        metric_init_val,
        high_loss_val,
    ):
        """Store the configuration and initialise the bookkeeping state.

        Args:
            args: Run configuration. This class reads ``experiment``,
                ``log_every``, ``interface_p``, ``no_vis``, ``lr``,
                ``lr_dec_epoch`` and ``lr_decay`` from it.
            push_images_fn: Callable invoked as
                ``push_images_fn(experiment, im_list, global_step)`` at the
                end of ``visualize_batches``.
            tracked_metric: Key of the metric reported through
                ``push_checkpoint_metric`` and ``self.log``.
            metric_init_val: Value reported for ``tracked_metric`` when an
                inference epoch ends before any training step has run.
            high_loss_val: Stored on the instance; not read by this class.
        """
        super().__init__()
        self.experiment = args.experiment
        self.args = args
        self.tracked_metric = tracked_metric
        self.metric_init_val = metric_init_val
        self.started_training = False
        # Sliding window of the most recent per-step loss dicts (for logging).
        self.loss_dict_vec = []
        self.push_images = push_images_fn
        # Batches cached for rendering at the end of an inference epoch.
        self.vis_train_batches = []
        self.vis_val_batches = []
        self.high_loss_val = high_loss_val
        self.max_vis_examples = 20
        # Per-step outputs accumulated between epoch-end hooks.
        self.val_step_outputs = []
        self.test_step_outputs = []

    def set_training_flags(self):
        """Record that at least one training step has been executed."""
        self.started_training = True

    def load_from_ckpt(self, ckpt_path):
        """Load the ``"state_dict"`` entry of a checkpoint into this module.

        The result of ``load_state_dict`` (missing/unexpected keys) is
        printed.

        Args:
            ckpt_path: Path to a checkpoint readable by ``torch.load`` that
                contains a ``"state_dict"`` entry.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # Loading onto the CPU lets a checkpoint saved on a GPU be restored on
        # a machine without that device; load_state_dict then copies the
        # values into the module's existing parameters.
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        print(self.load_state_dict(checkpoint["state_dict"]))

    def training_step(self, batch, batch_idx):
        """Run one optimisation step and periodically log running losses.

        Args:
            batch: An ``(inputs, targets, meta_info)`` triple.
            batch_idx: Index of the batch within the current epoch.

        Returns:
            A dict with ``"loss"`` (graph-attached sum of all loss terms),
            ``"total_loss"`` (the same value, detached) and every individual
            loss term (detached).
        """
        self.set_training_flags()
        if len(self.vis_train_batches) < self.num_vis_train:
            self.vis_train_batches.append(batch)

        inputs, targets, meta_info = batch
        out = self.forward(inputs, targets, meta_info, "train")

        # Reduce every loss term to a 1-element tensor so the terms can be
        # summed here and concatenated later.
        loss_terms = {
            name: term.mean().view(-1) for name, term in out["loss"].items()
        }
        total_loss = sum(loss_terms.values())

        loss_dict = {"total_loss": total_loss, "loss": total_loss}
        loss_dict.update(loss_terms)
        # Only "loss" keeps its autograd graph; everything else is reporting.
        loss_dict = {
            name: (value if name == "loss" else value.detach())
            for name, value in loss_dict.items()
        }

        log_every = self.args.log_every
        # Store fully detached copies so the logging window does not keep
        # autograd graphs of earlier iterations alive.
        self.loss_dict_vec.append(
            {name: value.detach() for name, value in loss_dict.items()}
        )
        # Keep only the most recent `log_every` entries. Slicing from
        # `len(...) - log_every` goes negative while the list is still
        # short and then drops entries from the front too early, which
        # limited the window to roughly half its intended size.
        self.loss_dict_vec = self.loss_dict_vec[-log_every:]

        if batch_idx % log_every == 0 and batch_idx != 0:
            running_loss_dict = avg_losses_cpu(self.loss_dict_vec)
            running_loss_dict = xdict(running_loss_dict).postfix("__train")
            log_dict(self.experiment, running_loss_dict, step=self.global_step)
        return loss_dict

    def on_train_epoch_end(self):
        """Notify the experiment logger that a training epoch has finished."""
        self.experiment.log_epoch_end(self.current_epoch)

    def validation_step(self, batch, batch_idx):
        """Run inference on one validation batch and store the result.

        Up to ``num_vis_val`` batches are also cached for visualisation.
        """
        if len(self.vis_val_batches) < self.num_vis_val:
            self.vis_val_batches.append(batch)
        out = self.inference_step(batch, batch_idx)
        self.val_step_outputs.append(out)
        return out

    def on_validation_epoch_end(self):
        """Aggregate the stored validation outputs and log ``loss__val``.

        NOTE(review): ``inference_epoch_end`` may return only the tracked
        metric when it is called before training has started, so this
        method relies on ``"loss__val"`` being present in that result --
        confirm that ``tracked_metric`` is configured accordingly.
        """
        outputs = self.inference_epoch_end(self.val_step_outputs, postfix="__val")
        self.log("loss__val", outputs["loss__val"])
        self.val_step_outputs.clear()  # free memory
        return outputs

    def inference_step(self, batch, batch_idx):
        """Run a gradient-free forward pass in ``"test"`` mode.

        Returns:
            ``{"out_dict": out, "loss": loss}`` built from the 2-tuple that
            ``forward`` returns.
        """
        if self.training:
            self.eval()
        with torch.no_grad():
            inputs, targets, meta_info = batch
            out, loss = self.forward(inputs, targets, meta_info, "test")
            return {"out_dict": out, "loss": loss}

    def inference_epoch_end(self, out_list, postfix):
        """Aggregate per-step inference outputs, log them and render images.

        If no training step has run yet (presumably Lightning's sanity
        validation pass -- TODO confirm), only the initial tracked-metric
        value is returned and nothing is aggregated or logged.

        Args:
            out_list: Per-step outputs as produced by ``inference_step``.
            postfix: Suffix appended to every logged key (e.g. ``"__val"``).
                A postfix containing ``"test"`` switches the return value to
                the test-time triple and skips tracked-metric reporting.

        Returns:
            * Before training has started: the result of
              ``push_checkpoint_metric(tracked_metric, metric_init_val)``.
            * When ``"test"`` is in ``postfix``: the tuple ``(outputs,
              {"per_img_metric_dict": ...}, metric_dict)``.
            * Otherwise: the postfixed dict of averaged metrics and losses.
        """
        if not self.started_training:
            self.started_training = True
            return push_checkpoint_metric(self.tracked_metric, self.metric_init_val)

        # unpack
        outputs, loss_dict = pl_utils.reform_outputs(out_list)
        is_test = "test" in postfix

        # Per-sample values of every "metric." entry, and their NaN-ignoring
        # means.
        per_img_metric_dict = {
            key: np.array(values)
            for key, values in outputs.items()
            if "metric." in key
        }
        metric_dict = {
            key: np.nanmean(values) for key, values in per_img_metric_dict.items()
        }

        loss_metric_dict = {}
        loss_metric_dict.update(metric_dict)
        loss_metric_dict.update(loss_dict)
        loss_metric_dict = xdict(loss_metric_dict).postfix(postfix)

        log_dict(self.experiment, loss_metric_dict, step=self.global_step)

        if self.args.interface_p is None and not is_test:
            result = push_checkpoint_metric(
                self.tracked_metric, loss_metric_dict[self.tracked_metric]
            )
            self.log(self.tracked_metric, result[self.tracked_metric])

        if not self.args.no_vis:
            print("Rendering train images")
            self.visualize_batches(self.vis_train_batches, "_train", False)
            print("Rendering val images")
            self.visualize_batches(self.vis_val_batches, "_val", False)

        if is_test:
            return (
                outputs,
                {"per_img_metric_dict": per_img_metric_dict},
                metric_dict,
            )
        return loss_metric_dict

    def configure_optimizers(self):
        """Create an Adam optimizer with a MultiStepLR schedule.

        The learning rate, the milestone epochs and the decay factor are
        read from ``args.lr``, ``args.lr_dec_epoch`` and ``args.lr_decay``.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.args.lr)
        # NOTE(review): `verbose` is deprecated in recent PyTorch releases --
        # confirm against the pinned torch version before upgrading.
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, self.args.lr_dec_epoch, gamma=self.args.lr_decay, verbose=True
        )
        return [optimizer], [scheduler]

    def visualize_batches(self, batches, postfix, no_tqdm=True):
        """Render the given batches with every function in ``vis_fns``.

        The collected images are forwarded to ``push_images`` together with
        the experiment handle and the current global step.

        Args:
            batches: Iterable of ``(inputs, targets, meta_info)`` triples.
            postfix: Passed through to every visualisation function.
            no_tqdm: Passed through to every visualisation function.

        Returns:
            The list of all images produced by the visualisation functions.
        """
        im_list = []
        if self.training:
            self.eval()

        tic = time.time()
        num_batches = len(batches)
        for batch_number, batch in enumerate(batches, start=1):
            with torch.no_grad():
                inputs, targets, meta_info = batch
                vis_dict = self.forward(inputs, targets, meta_info, "vis")
                for vis_fn in self.vis_fns:
                    im_list += vis_fn(
                        vis_dict,
                        self.max_vis_examples,
                        self.renderer,
                        postfix=postfix,
                        no_tqdm=no_tqdm,
                    )
                print("Rendering: %d/%d" % (batch_number, num_batches))

        self.push_images(self.experiment, im_list, self.global_step)
        print("Done rendering (%.1fs)" % (time.time() - tic))
        return im_list