import time

import numpy as np
import pytorch_lightning as pl
import torch
import torch.optim as optim

import common.pl_utils as pl_utils
from common.comet_utils import log_dict
from common.pl_utils import avg_losses_cpu, push_checkpoint_metric
from common.xdict import xdict


class AbstractPL(pl.LightningModule):
    """Abstract LightningModule shared by concrete models.

    Handles running-loss logging, checkpoint-metric tracking, and visualization.
    Subclasses are expected to implement `forward` and to set `vis_fns`,
    `renderer`, `num_vis_train`, and `num_vis_val`.
    """

    def __init__(
        self,
        args,
        push_images_fn,
        tracked_metric,
        metric_init_val,
        high_loss_val,
    ):
        super().__init__()
        self.experiment = args.experiment
        self.args = args
        self.tracked_metric = tracked_metric
        self.metric_init_val = metric_init_val
        self.started_training = False
        self.loss_dict_vec = []
        self.push_images = push_images_fn
        self.vis_train_batches = []
        self.vis_val_batches = []
        self.high_loss_val = high_loss_val
        self.max_vis_examples = 20
        self.val_step_outputs = []
        self.test_step_outputs = []

    def set_training_flags(self):
        self.started_training = True

    def load_from_ckpt(self, ckpt_path):
        sd = torch.load(ckpt_path)["state_dict"]
        print(self.load_state_dict(sd))

    def training_step(self, batch, batch_idx):
        self.set_training_flags()
        if len(self.vis_train_batches) < self.num_vis_train:
            self.vis_train_batches.append(batch)
        inputs, targets, meta_info = batch
        out = self.forward(inputs, targets, meta_info, "train")
        loss = out["loss"]

        loss = {k: loss[k].mean().view(-1) for k in loss}
        total_loss = sum(loss[k] for k in loss)

        loss_dict = {"total_loss": total_loss, "loss": total_loss}
        loss_dict.update(loss)

        for k, v in loss_dict.items():
            if k != "loss":
                loss_dict[k] = v.detach()

        log_every = self.args.log_every
        self.loss_dict_vec.append(loss_dict)
        # cap the running-loss buffer at the most recent `log_every` entries
        self.loss_dict_vec = self.loss_dict_vec[len(self.loss_dict_vec) - log_every :]
        if batch_idx % log_every == 0 and batch_idx != 0:
            running_loss_dict = avg_losses_cpu(self.loss_dict_vec)
            running_loss_dict = xdict(running_loss_dict).postfix("__train")
            log_dict(self.experiment, running_loss_dict, step=self.global_step)
        return loss_dict

    def on_train_epoch_end(self):
        self.experiment.log_epoch_end(self.current_epoch)

    def validation_step(self, batch, batch_idx):
        if len(self.vis_val_batches) < self.num_vis_val:
            self.vis_val_batches.append(batch)
        out = self.inference_step(batch, batch_idx)
        self.val_step_outputs.append(out)
        return out

    def on_validation_epoch_end(self):
        outputs = self.val_step_outputs
        outputs = self.inference_epoch_end(outputs, postfix="__val")
        self.log("loss__val", outputs["loss__val"])
        self.val_step_outputs.clear()  # free memory
        return outputs

    def inference_step(self, batch, batch_idx):
        if self.training:
            self.eval()
        with torch.no_grad():
            inputs, targets, meta_info = batch
            out, loss = self.forward(inputs, targets, meta_info, "test")
            return {"out_dict": out, "loss": loss}

    def inference_epoch_end(self, out_list, postfix):
        if not self.started_training:
            # validation ran before any training step (e.g. Lightning's sanity
            # check); just push the initial value of the tracked metric
            self.started_training = True
            result = push_checkpoint_metric(self.tracked_metric, self.metric_init_val)
            return result

        # unpack
        outputs, loss_dict = pl_utils.reform_outputs(out_list)

        if "test" in postfix:
            per_img_metric_dict = {}
            for k, v in outputs.items():
                if "metric." in k:
                    per_img_metric_dict[k] = np.array(v)

        metric_dict = {}
        for k, v in outputs.items():
            if "metric." in k:
                metric_dict[k] = np.nanmean(np.array(v))

        loss_metric_dict = {}
        loss_metric_dict.update(metric_dict)
        loss_metric_dict.update(loss_dict)
        loss_metric_dict = xdict(loss_metric_dict).postfix(postfix)

        log_dict(
            self.experiment,
            loss_metric_dict,
            step=self.global_step,
        )

        if self.args.interface_p is None and "test" not in postfix:
            result = push_checkpoint_metric(
                self.tracked_metric, loss_metric_dict[self.tracked_metric]
            )
            self.log(self.tracked_metric, result[self.tracked_metric])

        if not self.args.no_vis:
            print("Rendering train images")
            self.visualize_batches(self.vis_train_batches, "_train", False)
            print("Rendering val images")
            self.visualize_batches(self.vis_val_batches, "_val", False)

        if "test" in postfix:
            return (
                outputs,
                {"per_img_metric_dict": per_img_metric_dict},
                metric_dict,
            )
        return loss_metric_dict

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args.lr)
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer, self.args.lr_dec_epoch, gamma=self.args.lr_decay, verbose=True
        )
        return [optimizer], [scheduler]

    def visualize_batches(self, batches, postfix, no_tqdm=True):
        im_list = []
        if self.training:
            self.eval()

        tic = time.time()
        for batch_idx, batch in enumerate(batches):
            with torch.no_grad():
                inputs, targets, meta_info = batch
                vis_dict = self.forward(inputs, targets, meta_info, "vis")
                # `vis_fns` and `renderer` are expected to be set by the subclass
                for vis_fn in self.vis_fns:
                    curr_im_list = vis_fn(
                        vis_dict,
                        self.max_vis_examples,
                        self.renderer,
                        postfix=postfix,
                        no_tqdm=no_tqdm,
                    )
                    im_list += curr_im_list
            print("Rendering: %d/%d" % (batch_idx + 1, len(batches)))

        self.push_images(self.experiment, im_list, self.global_step)
        print("Done rendering (%.1fs)" % (time.time() - tic))
        return im_list
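

# --- Usage sketch (illustrative only, not part of the original module) ---
# A concrete model subclasses AbstractPL, implements `forward`, and sets the
# attributes the base class reads (`vis_fns`, `renderer`, `num_vis_train`,
# `num_vis_val`). `forward` receives a mode string: in "train"/"vis" mode it
# returns a dict (with a "loss" sub-dict during training), while in "test"
# mode it returns an `(out, loss)` tuple, as consumed by `inference_step`
# above. The names `MyModel`, `my_vis_fn`, and `my_renderer` below are
# hypothetical placeholders, not identifiers from this codebase.
#
# class MyModel(AbstractPL):
#     def __init__(self, args, push_images_fn):
#         super().__init__(
#             args,
#             push_images_fn,
#             tracked_metric="loss__val",
#             metric_init_val=float("inf"),
#             high_loss_val=float("inf"),
#         )
#         self.vis_fns = [my_vis_fn]
#         self.renderer = my_renderer
#         self.num_vis_train = 1
#         self.num_vis_val = 1
#
#     def forward(self, inputs, targets, meta_info, mode):
#         out = {}  # model-specific predictions go here
#         loss = {"loss/example": torch.zeros(1)}
#         if mode == "test":
#             return out, loss
#         out["loss"] = loss
#         return out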