gyrojeff committed on
Commit
edd29d3
·
1 Parent(s): 705feb9

feat: add deepfont baseline

Browse files
Files changed (2) hide show
  1. detector/model.py +31 -0
  2. train.py +6 -1
detector/model.py CHANGED
@@ -10,6 +10,37 @@ import torch.nn as nn
10
  import pytorch_lightning as ptl
11
 
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  class ResNet18Regressor(nn.Module):
14
  def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
15
  super().__init__()
 
10
  import pytorch_lightning as ptl
11
 
12
 
13
class DeepFontBaseline(nn.Module):
    """Plain convolutional classifier used as the DeepFont baseline.

    The network is a single ``nn.Sequential`` stored in ``self.model``:
    two conv/batch-norm/ReLU/max-pool stages, three further conv/ReLU
    layers, then a three-layer fully connected head producing
    ``config.FONT_COUNT`` outputs.

    The first linear layer is hard-wired to a ``256 x 12 x 12`` feature
    map, which is what the convolutional part yields for a 3-channel
    105 x 105 input (105 -> 48 -> 24 -> 12). Other spatial sizes will fail
    at the flatten/linear boundary.
    """

    def __init__(self) -> None:
        super().__init__()

        # Down-sampling stem: each stage ends in a 2x2 max-pool.
        stem_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        ]

        # Resolution-preserving 3x3 convolutions (padding keeps H and W).
        trunk_layers = [
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ]

        # Fully connected head; the last layer has one output per font.
        flattened_features = 256 * 12 * 12
        head_layers = [
            nn.Flatten(),
            nn.Linear(flattened_features, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, config.FONT_COUNT),
        ]

        # Unpacked into one flat Sequential so the sub-module indices (and
        # therefore state_dict keys) follow the layer order listed above.
        self.model = nn.Sequential(*stem_layers, *trunk_layers, *head_layers)

    def forward(self, X):
        """Run ``X`` through the sequential network and return its output."""
        outputs = self.model(X)
        return outputs
42
+
43
+
44
  class ResNet18Regressor(nn.Module):
45
  def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
46
  super().__init__()
train.py CHANGED
@@ -39,7 +39,7 @@ parser.add_argument(
39
  "--model",
40
  type=str,
41
  default="resnet18",
42
- choices=["resnet18", "resnet34", "resnet50", "resnet101"],
43
  help="Model to use (default: resnet18)",
44
  )
45
  parser.add_argument(
@@ -181,6 +181,11 @@ elif args.model == "resnet101":
181
  model = ResNet101Regressor(
182
  pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
183
  )
 
 
 
 
 
184
  else:
185
  raise NotImplementedError()
186
 
 
39
  "--model",
40
  type=str,
41
  default="resnet18",
42
+ choices=["resnet18", "resnet34", "resnet50", "resnet101", "deepfont"],
43
  help="Model to use (default: resnet18)",
44
  )
45
  parser.add_argument(
 
181
  model = ResNet101Regressor(
182
  pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
183
  )
184
+ elif args.model == "deepfont":
185
+ assert args.pretrained is False
186
+ assert args.size == 105
187
+ assert args.font_classification_only is True
188
+ model = DeepFontBaseline()
189
  else:
190
  raise NotImplementedError()
191