import time

import numpy as np
import pytorch_lightning as pl
import torch
import torch.optim as optim

import common.pl_utils as pl_utils
from common.comet_utils import log_dict
from common.pl_utils import avg_losses_cpu, push_checkpoint_metric
from common.xdict import xdict


class AbstractPL(pl.LightningModule):
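    """Task-agnostic PyTorch Lightning base module.

    Handles running-loss logging, checkpoint-metric bookkeeping, and image
    visualization. Subclasses are expected to define `forward` and to set
    `num_vis_train`, `num_vis_val`, `vis_fns`, and `renderer`, which are
    referenced below but not created here.
    """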

    def __init__(
        self,
        args,
        push_images_fn,
        tracked_metric,
        metric_init_val,
        high_loss_val,
    ):
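        """Store the experiment handle, tracked-metric settings, and buffers."""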
        super().__init__()
        self.experiment = args.experiment
        self.args = args
        self.tracked_metric = tracked_metric
        self.metric_init_val = metric_init_val
        self.started_training = False
        self.loss_dict_vec = []  # sliding window of recent training losses
        self.push_images = push_images_fn
        self.vis_train_batches = []  # batches cached for train-set visualization
        self.vis_val_batches = []  # batches cached for val-set visualization
        self.high_loss_val = high_loss_val
        self.max_vis_examples = 20  # cap on examples rendered per batch
        self.val_step_outputs = []
        self.test_step_outputs = []

    def set_training_flags(self):
        self.started_training = True

    def load_from_ckpt(self, ckpt_path):
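        """Load weights from a Lightning checkpoint into this module."""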
        sd = torch.load(ckpt_path)["state_dict"]
        # print the result so missing/unexpected keys are visible
        print(self.load_state_dict(sd))

    def training_step(self, batch, batch_idx):
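        """Run a forward pass in "train" mode and log a running-average loss."""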
        self.set_training_flags()
        if len(self.vis_train_batches) < self.num_vis_train:
            self.vis_train_batches.append(batch)
        inputs, targets, meta_info = batch
        out = self.forward(inputs, targets, meta_info, "train")
        loss = out["loss"]

        # reduce each loss term to a scalar and sum them for the total
        loss = {k: loss[k].mean().view(-1) for k in loss}
        total_loss = sum(loss[k] for k in loss)

        # "loss" must stay attached for backprop; everything else is logged only
        loss_dict = {"total_loss": total_loss, "loss": total_loss}
        loss_dict.update(loss)
        for k, v in loss_dict.items():
            if k != "loss":
                loss_dict[k] = v.detach()

        # keep a sliding window of the last `log_every` loss dicts
        log_every = self.args.log_every
        self.loss_dict_vec.append(loss_dict)
        self.loss_dict_vec = self.loss_dict_vec[-log_every:]
        if batch_idx % log_every == 0 and batch_idx != 0:
            running_loss_dict = avg_losses_cpu(self.loss_dict_vec)
            running_loss_dict = xdict(running_loss_dict).postfix("__train")
            log_dict(self.experiment, running_loss_dict, step=self.global_step)
        return loss_dict

    def on_train_epoch_end(self):
        self.experiment.log_epoch_end(self.current_epoch)

    def validation_step(self, batch, batch_idx):
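        """Cache batches for visualization and delegate to `inference_step`."""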
        if len(self.vis_val_batches) < self.num_vis_val:
            self.vis_val_batches.append(batch)
        out = self.inference_step(batch, batch_idx)
        self.val_step_outputs.append(out)
        return out

    def on_validation_epoch_end(self):
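        """Aggregate validation outputs and log the epoch-level metrics."""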
        outputs = self.val_step_outputs
        outputs = self.inference_epoch_end(outputs, postfix="__val")
        self.log("loss__val", outputs["loss__val"])
        self.val_step_outputs.clear()  # free memory
        return outputs

    def inference_step(self, batch, batch_idx):
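        """Run a gradient-free forward pass in "test" mode."""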
        if self.training:
            self.eval()
        with torch.no_grad():
            inputs, targets, meta_info = batch
            # in "test" mode, forward returns an (outputs, loss) tuple
            out, loss = self.forward(inputs, targets, meta_info, "test")
            return {"out_dict": out, "loss": loss}

    def inference_epoch_end(self, out_list, postfix):
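        """Aggregate step outputs into epoch metrics, log them, and render images."""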
        if not self.started_training:
            # sanity-check epoch before training: report the metric's init value
            self.started_training = True
            result = push_checkpoint_metric(self.tracked_metric, self.metric_init_val)
            return result

        # unpack step outputs into flat output and loss dicts
        outputs, loss_dict = pl_utils.reform_outputs(out_list)

        if "test" in postfix:
            # keep per-example metrics around for the test-time return value
            per_img_metric_dict = {}
            for k, v in outputs.items():
                if "metric." in k:
                    per_img_metric_dict[k] = np.array(v)

        # average each metric over the epoch, ignoring NaNs
        metric_dict = {}
        for k, v in outputs.items():
            if "metric." in k:
                metric_dict[k] = np.nanmean(np.array(v))

        loss_metric_dict = {}
        loss_metric_dict.update(metric_dict)
        loss_metric_dict.update(loss_dict)
        loss_metric_dict = xdict(loss_metric_dict).postfix(postfix)

        log_dict(
            self.experiment,
            loss_metric_dict,
            step=self.global_step,
        )

        if self.args.interface_p is None and "test" not in postfix:
            result = push_checkpoint_metric(
                self.tracked_metric, loss_metric_dict[self.tracked_metric]
            )
            self.log(self.tracked_metric, result[self.tracked_metric])

        if not self.args.no_vis:
            print("Rendering train images")
            self.visualize_batches(self.vis_train_batches, "_train", False)
            print("Rendering val images")
            self.visualize_batches(self.vis_val_batches, "_val", False)

        if "test" in postfix:
            return (
                outputs,
                {"per_img_metric_dict": per_img_metric_dict},
                metric_dict,
            )
        return loss_metric_dict

    def configure_optimizers(self):
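        """Adam with a multi-step learning-rate decay schedule."""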
        optimizer = optim.Adam(self.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, self.args.lr_dec_epoch, gamma=self.args.lr_decay, verbose=True
        )
        return [optimizer], [scheduler]

    def visualize_batches(self, batches, postfix, no_tqdm=True):
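        """Render cached batches with each visualization function and push images."""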
        im_list = []
        if self.training:
            self.eval()

        tic = time.time()
        for batch_idx, batch in enumerate(batches):
            with torch.no_grad():
                inputs, targets, meta_info = batch
                vis_dict = self.forward(inputs, targets, meta_info, "vis")
                for vis_fn in self.vis_fns:
                    curr_im_list = vis_fn(
                        vis_dict,
                        self.max_vis_examples,
                        self.renderer,
                        postfix=postfix,
                        no_tqdm=no_tqdm,
                    )
                    im_list += curr_im_list
                print("Rendering: %d/%d" % (batch_idx + 1, len(batches)))

        self.push_images(self.experiment, im_list, self.global_step)
        print("Done rendering (%.1fs)" % (time.time() - tic))
        return im_list
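

# A minimal usage sketch (hypothetical: `MyModelPL`, `push_images`, the metric
# name, and the `args` fields shown here are assumptions, not part of this module):
#
#     class MyModelPL(AbstractPL):
#         def __init__(self, args):
#             super().__init__(
#                 args,
#                 push_images_fn=push_images,
#                 tracked_metric="loss__val",
#                 metric_init_val=float("inf"),
#                 high_loss_val=float("inf"),
#             )
#             self.num_vis_train = 1  # AbstractPL expects subclasses to set these
#             self.num_vis_val = 1
#
#         def forward(self, inputs, targets, meta_info, mode):
#             ...  # return {"loss": {...}} in "train" mode; (out, loss) in "test" mode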