gyrojeff commited on
Commit
68dd12a
·
1 Parent(s): 3b25c65

feat: add tanh

Browse files
Files changed (3) hide show
  1. detector/data.py +9 -4
  2. detector/model.py +6 -2
  3. train.py +4 -1
detector/data.py CHANGED
@@ -15,9 +15,10 @@ from PIL import Image
15
 
16
 
17
  class FontDataset(Dataset):
18
- def __init__(self, path: str, config_path: str = "configs/font.yml"):
19
  self.path = path
20
  self.fonts = load_font_with_exclusion(config_path)
 
21
 
22
  self.images = [
23
  os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
@@ -50,6 +51,9 @@ class FontDataset(Dataset):
50
  out[7:10] = out[2:5]
51
  out[10] = label.line_spacing / label.image_width
52
  out[11] = label.angle / 180.0 + 0.5
 
 
 
53
 
54
  return out
55
 
@@ -87,6 +91,7 @@ class FontDataModule(LightningDataModule):
87
  train_shuffle: bool = True,
88
  val_shuffle: bool = False,
89
  test_shuffle: bool = False,
 
90
  **kwargs,
91
  ):
92
  super().__init__()
@@ -94,9 +99,9 @@ class FontDataModule(LightningDataModule):
94
  self.train_shuffle = train_shuffle
95
  self.val_shuffle = val_shuffle
96
  self.test_shuffle = test_shuffle
97
- self.train_dataset = FontDataset(train_path, config_path)
98
- self.val_dataset = FontDataset(val_path, config_path)
99
- self.test_dataset = FontDataset(test_path, config_path)
100
 
101
  def get_train_num_iter(self, num_device: int) -> int:
102
  return math.ceil(
 
15
 
16
 
17
  class FontDataset(Dataset):
18
+ def __init__(self, path: str, config_path: str = "configs/font.yml", regression_use_tanh: bool=False):
19
  self.path = path
20
  self.fonts = load_font_with_exclusion(config_path)
21
+ self.regression_use_tanh = regression_use_tanh
22
 
23
  self.images = [
24
  os.path.join(path, f) for f in os.listdir(path) if f.endswith(".jpg")
 
51
  out[7:10] = out[2:5]
52
  out[10] = label.line_spacing / label.image_width
53
  out[11] = label.angle / 180.0 + 0.5
54
+
55
+ if self.regression_use_tanh:
56
+ out[2:12] = out[2:12] * 2 - 1
57
 
58
  return out
59
 
 
91
  train_shuffle: bool = True,
92
  val_shuffle: bool = False,
93
  test_shuffle: bool = False,
94
+ regression_use_tanh: bool = False,
95
  **kwargs,
96
  ):
97
  super().__init__()
 
99
  self.train_shuffle = train_shuffle
100
  self.val_shuffle = val_shuffle
101
  self.test_shuffle = test_shuffle
102
+ self.train_dataset = FontDataset(train_path, config_path, regression_use_tanh)
103
+ self.val_dataset = FontDataset(val_path, config_path, regression_use_tanh)
104
+ self.test_dataset = FontDataset(test_path, config_path, regression_use_tanh)
105
 
106
  def get_train_num_iter(self, num_device: int) -> int:
107
  return math.ceil(
detector/model.py CHANGED
@@ -11,15 +11,19 @@ import pytorch_lightning as ptl
11
 
12
 
13
  class ResNet18Regressor(nn.Module):
14
- def __init__(self):
15
  super().__init__()
16
  self.model = torchvision.models.resnet18(weights=False)
17
  self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
 
18
 
19
  def forward(self, X):
20
  X = self.model(X)
21
  # [0, 1]
22
- X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
 
 
 
23
  return X
24
 
25
 
 
11
 
12
 
13
  class ResNet18Regressor(nn.Module):
14
+ def __init__(self, regression_use_tanh: bool=False):
15
  super().__init__()
16
  self.model = torchvision.models.resnet18(weights=False)
17
  self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
18
+ self.regression_use_tanh = regression_use_tanh
19
 
20
  def forward(self, X):
21
  X = self.model(X)
22
  # [0, 1]
23
+ if not self.regression_use_tanh:
24
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].sigmoid()
25
+ else:
26
+ X[..., config.FONT_COUNT + 2 :] = X[..., config.FONT_COUNT + 2 :].tanh()
27
  return X
28
 
29
 
train.py CHANGED
@@ -24,6 +24,8 @@ lambda_font = 4.0
24
  lambda_direction = 0.5
25
  lambda_regression = 1.0
26
 
 
 
27
  num_warmup_epochs = 1
28
  num_epochs = 100
29
 
@@ -38,6 +40,7 @@ data_module = FontDataModule(
38
  train_shuffle=True,
39
  val_shuffle=False,
40
  test_shuffle=False,
 
41
  )
42
 
43
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs
@@ -62,7 +65,7 @@ trainer = ptl.Trainer(
62
  deterministic=True,
63
  )
64
 
65
- model = ResNet18Regressor()
66
 
67
  detector = FontDetector(
68
  model=model,
 
24
  lambda_direction = 0.5
25
  lambda_regression = 1.0
26
 
27
+ regression_use_tanh = True
28
+
29
  num_warmup_epochs = 1
30
  num_epochs = 100
31
 
 
40
  train_shuffle=True,
41
  val_shuffle=False,
42
  test_shuffle=False,
43
+ regression_use_tanh=regression_use_tanh,
44
  )
45
 
46
  num_iters = data_module.get_train_num_iter(num_device) * num_epochs
 
65
  deterministic=True,
66
  )
67
 
68
+ model = ResNet18Regressor(regression_use_tanh=regression_use_tanh)
69
 
70
  detector = FontDetector(
71
  model=model,