gyrojeff commited on
Commit
69c8e55
·
1 Parent(s): 3f401e5

feat: sync_dist for log

Browse files
Files changed (1) hide show
  1. detector/model.py +8 -4
detector/model.py CHANGED
@@ -105,7 +105,7 @@ class FontDetector(ptl.LightningModule):
105
  X, y = batch
106
  y_hat = self.forward(X)
107
  loss = self.loss(y_hat, y)
108
- self.log("train_loss", loss, prog_bar=True)
109
  return {"loss": loss, "pred": y_hat, "target": y}
110
 
111
  def training_step_end(self, outputs):
@@ -114,12 +114,14 @@ class FontDetector(ptl.LightningModule):
114
  self.log(
115
  "train_font_accur",
116
  self.font_accur_train(y_hat[..., : config.FONT_COUNT], y[..., 0]),
 
117
  )
118
  self.log(
119
  "train_direction_accur",
120
  self.direction_accur_train(
121
  y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
122
  ),
 
123
  )
124
 
125
  def on_train_epoch_end(self) -> None:
@@ -132,7 +134,7 @@ class FontDetector(ptl.LightningModule):
132
  X, y = batch
133
  y_hat = self.forward(X)
134
  loss = self.loss(y_hat, y)
135
- self.log("val_loss", loss, prog_bar=True)
136
  self.font_accur_val.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
137
  self.direction_accur_val.update(
138
  y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
@@ -140,8 +142,10 @@ class FontDetector(ptl.LightningModule):
140
  return {"loss": loss, "pred": y_hat, "target": y}
141
 
142
  def on_validation_epoch_end(self):
143
- self.log("val_font_accur", self.font_accur_val.compute())
144
- self.log("val_direction_accur", self.direction_accur_val.compute())
 
 
145
  self.font_accur_val.reset()
146
  self.direction_accur_val.reset()
147
 
 
105
  X, y = batch
106
  y_hat = self.forward(X)
107
  loss = self.loss(y_hat, y)
108
+ self.log("train_loss", loss, prog_bar=True, sync_dist=True)
109
  return {"loss": loss, "pred": y_hat, "target": y}
110
 
111
  def training_step_end(self, outputs):
 
114
  self.log(
115
  "train_font_accur",
116
  self.font_accur_train(y_hat[..., : config.FONT_COUNT], y[..., 0]),
117
+ sync_dist=True,
118
  )
119
  self.log(
120
  "train_direction_accur",
121
  self.direction_accur_train(
122
  y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
123
  ),
124
+ sync_dist=True,
125
  )
126
 
127
  def on_train_epoch_end(self) -> None:
 
134
  X, y = batch
135
  y_hat = self.forward(X)
136
  loss = self.loss(y_hat, y)
137
+ self.log("val_loss", loss, prog_bar=True, sync_dist=True)
138
  self.font_accur_val.update(y_hat[..., : config.FONT_COUNT], y[..., 0])
139
  self.direction_accur_val.update(
140
  y_hat[..., config.FONT_COUNT : config.FONT_COUNT + 2], y[..., 1]
 
142
  return {"loss": loss, "pred": y_hat, "target": y}
143
 
144
  def on_validation_epoch_end(self):
145
+ self.log("val_font_accur", self.font_accur_val.compute(), sync_dist=True)
146
+ self.log(
147
+ "val_direction_accur", self.direction_accur_val.compute(), sync_dist=True
148
+ )
149
  self.font_accur_val.reset()
150
  self.direction_accur_val.reset()
151