feat: add deepfont baseline
Browse files- detector/model.py +31 -0
- train.py +6 -1
detector/model.py
CHANGED
@@ -10,6 +10,37 @@ import torch.nn as nn
|
|
10 |
import pytorch_lightning as ptl
|
11 |
|
12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
class ResNet18Regressor(nn.Module):
|
14 |
def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
|
15 |
super().__init__()
|
|
|
10 |
import pytorch_lightning as ptl
|
11 |
|
12 |
|
13 |
+
class DeepFontBaseline(nn.Module):
    """DeepFont-style CNN baseline that scores an image against every font class.

    The network is a single ``nn.Sequential`` stored as ``self.model``: five
    convolution layers (the first two followed by batch norm and 2x2 max
    pooling), then a three-layer fully connected head that ends in
    ``config.FONT_COUNT`` outputs. No softmax is applied, so ``forward``
    returns the raw output of the last linear layer.

    The first linear layer expects ``256 * 12 * 12`` flattened features. For a
    3-channel 105x105 input the spatial size goes 105 -> 48 (11x11 conv,
    stride 2) -> 24 (pool) -> 12 (pool), which matches that figure; other
    input sizes will not fit the head.
    """

    def __init__(self) -> None:
        super().__init__()

        # Layers are instantiated in execution order and unpacked into one flat
        # Sequential, so sub-module indices (and therefore state_dict keys)
        # are the same as if they were listed inline.
        conv_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        ]

        # fc
        fc_layers = [
            nn.Flatten(),
            nn.Linear(256 * 12 * 12, 4096),
            nn.ReLU(),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Linear(4096, config.FONT_COUNT),
        ]

        self.model = nn.Sequential(*conv_layers, *fc_layers)

    def forward(self, X):
        """Run ``X`` through the full stack and return the per-font scores."""
        scores = self.model(X)
        return scores
|
42 |
+
|
43 |
+
|
44 |
class ResNet18Regressor(nn.Module):
|
45 |
def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
|
46 |
super().__init__()
|
train.py
CHANGED
@@ -39,7 +39,7 @@ parser.add_argument(
|
|
39 |
"--model",
|
40 |
type=str,
|
41 |
default="resnet18",
|
42 |
-
choices=["resnet18", "resnet34", "resnet50", "resnet101"],
|
43 |
help="Model to use (default: resnet18)",
|
44 |
)
|
45 |
parser.add_argument(
|
@@ -181,6 +181,11 @@ elif args.model == "resnet101":
|
|
181 |
model = ResNet101Regressor(
|
182 |
pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
|
183 |
)
|
|
|
|
|
|
|
|
|
|
|
184 |
else:
|
185 |
raise NotImplementedError()
|
186 |
|
|
|
39 |
"--model",
|
40 |
type=str,
|
41 |
default="resnet18",
|
42 |
+
choices=["resnet18", "resnet34", "resnet50", "resnet101", "deepfont"],
|
43 |
help="Model to use (default: resnet18)",
|
44 |
)
|
45 |
parser.add_argument(
|
|
|
181 |
model = ResNet101Regressor(
|
182 |
pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
|
183 |
)
|
184 |
+
elif args.model == "deepfont":
|
185 |
+
assert args.pretrained is False
|
186 |
+
assert args.size == 105
|
187 |
+
assert args.font_classification_only is True
|
188 |
+
model = DeepFontBaseline()
|
189 |
else:
|
190 |
raise NotImplementedError()
|
191 |
|