gyrojeff committed on
Commit
912d566
·
1 Parent(s): 49d8194

fix: reset lr on load ckpt

Browse files
Files changed (2) hide show
  1. detector/model.py +5 -2
  2. train.py +1 -0
detector/model.py CHANGED
@@ -130,6 +130,7 @@ class FontDetector(ptl.LightningModule):
130
  betas: Tuple[float, float],
131
  num_warmup_iters: int,
132
  num_iters: int,
 
133
  ):
134
  super().__init__()
135
  self.model = model
@@ -156,6 +157,7 @@ class FontDetector(ptl.LightningModule):
156
  self.betas = betas
157
  self.num_warmup_iters = num_warmup_iters
158
  self.num_iters = num_iters
 
159
  self.load_step = 0
160
 
161
  def forward(self, x):
@@ -240,7 +242,8 @@ class FontDetector(ptl.LightningModule):
240
  self.scheduler = CosineWarmupScheduler(
241
  optimizer, self.num_warmup_iters, self.num_iters
242
  )
243
- for _ in range(self.load_step):
 
244
  self.scheduler.step()
245
  print("Current learning rate set to:", self.scheduler.get_last_lr())
246
  return optimizer
@@ -261,4 +264,4 @@ class FontDetector(ptl.LightningModule):
261
  self.scheduler.step()
262
 
263
  def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
264
- self.load_step = checkpoint["global_step"]
 
130
  betas: Tuple[float, float],
131
  num_warmup_iters: int,
132
  num_iters: int,
133
+ num_epochs: int,
134
  ):
135
  super().__init__()
136
  self.model = model
 
157
  self.betas = betas
158
  self.num_warmup_iters = num_warmup_iters
159
  self.num_iters = num_iters
160
+ self.num_epochs = num_epochs
161
  self.load_step = 0
162
 
163
  def forward(self, x):
 
242
  self.scheduler = CosineWarmupScheduler(
243
  optimizer, self.num_warmup_iters, self.num_iters
244
  )
245
+ print("Load epoch:", self.load_epoch)
246
+ for _ in range(self.num_iters * self.load_epoch // self.num_epochs):
247
  self.scheduler.step()
248
  print("Current learning rate set to:", self.scheduler.get_last_lr())
249
  return optimizer
 
264
  self.scheduler.step()
265
 
266
  def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
267
+ self.load_epoch = checkpoint["epoch"]
train.py CHANGED
@@ -87,6 +87,7 @@ detector = FontDetector(
87
  betas=(b1, b2),
88
  num_warmup_iters=num_warmup_iter,
89
  num_iters=num_iters,
 
90
  )
91
 
92
  trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)
 
87
  betas=(b1, b2),
88
  num_warmup_iters=num_warmup_iter,
89
  num_iters=num_iters,
90
+ num_epochs=num_epochs,
91
  )
92
 
93
  trainer.fit(detector, datamodule=data_module, ckpt_path=args.checkpoint)