fix: reset lr on load ckpt
Browse files- detector/model.py +5 -2
- train.py +1 -0
detector/model.py
CHANGED
@@ -130,6 +130,7 @@ class FontDetector(ptl.LightningModule):
|
|
130 |
betas: Tuple[float, float],
|
131 |
num_warmup_iters: int,
|
132 |
num_iters: int,
|
|
|
133 |
):
|
134 |
super().__init__()
|
135 |
self.model = model
|
@@ -156,6 +157,7 @@ class FontDetector(ptl.LightningModule):
|
|
156 |
self.betas = betas
|
157 |
self.num_warmup_iters = num_warmup_iters
|
158 |
self.num_iters = num_iters
|
|
|
159 |
self.load_step = 0
|
160 |
|
161 |
def forward(self, x):
|
@@ -240,7 +242,8 @@ class FontDetector(ptl.LightningModule):
|
|
240 |
self.scheduler = CosineWarmupScheduler(
|
241 |
optimizer, self.num_warmup_iters, self.num_iters
|
242 |
)
|
243 |
-
|
|
|
244 |
self.scheduler.step()
|
245 |
print("Current learning rate set to:", self.scheduler.get_last_lr())
|
246 |
return optimizer
|
@@ -261,4 +264,4 @@ class FontDetector(ptl.LightningModule):
|
|
261 |
self.scheduler.step()
|
262 |
|
263 |
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
264 |
-
self.
|
|
|
130 |
betas: Tuple[float, float],
|
131 |
num_warmup_iters: int,
|
132 |
num_iters: int,
|
133 |
+
num_epochs: int,
|
134 |
):
|
135 |
super().__init__()
|
136 |
self.model = model
|
|
|
157 |
self.betas = betas
|
158 |
self.num_warmup_iters = num_warmup_iters
|
159 |
self.num_iters = num_iters
|
160 |
+
self.num_epochs = num_epochs
|
161 |
self.load_step = 0
|
162 |
|
163 |
def forward(self, x):
|
|
|
242 |
self.scheduler = CosineWarmupScheduler(
|
243 |
optimizer, self.num_warmup_iters, self.num_iters
|
244 |
)
|
245 |
+
print("Load epoch:", self.load_epoch)
|
246 |
+
for _ in range(self.num_iters * self.load_epoch // self.num_epochs):
|
247 |
self.scheduler.step()
|
248 |
print("Current learning rate set to:", self.scheduler.get_last_lr())
|
249 |
return optimizer
|
|
|
264 |
self.scheduler.step()
|
265 |
|
266 |
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
|
267 |
+
self.load_epoch = checkpoint["epoch"]
|
train.py
CHANGED
@@ -87,6 +87,7 @@ detector = FontDetector(
|
|
87 |
betas=(b1, b2),
|
88 |
num_warmup_iters=num_warmup_iter,
|
89 |
num_iters=num_iters,
|
|
|
90 |
)
|
91 |
|
92 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|
|
|
87 |
betas=(b1, b2),
|
88 |
num_warmup_iters=num_warmup_iter,
|
89 |
num_iters=num_iters,
|
90 |
+
num_epochs=num_epochs,
|
91 |
)
|
92 |
|
93 |
trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
|