gyrojeff commited on
Commit
416c7bb
·
1 Parent(s): 2928b04

feat: add cli support for switching model

Browse files
Files changed (2) hide show
  1. detector/model.py +12 -8
  2. train.py +54 -4
detector/model.py CHANGED
@@ -11,9 +11,10 @@ import pytorch_lightning as ptl
11
 
12
 
13
  class ResNet18Regressor(nn.Module):
14
- def __init__(self, regression_use_tanh: bool = False):
15
  super().__init__()
16
- self.model = torchvision.models.resnet18(weights=False)
 
17
  self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
18
  self.regression_use_tanh = regression_use_tanh
19
 
@@ -28,9 +29,10 @@ class ResNet18Regressor(nn.Module):
28
 
29
 
30
  class ResNet34Regressor(nn.Module):
31
- def __init__(self, regression_use_tanh: bool = False):
32
  super().__init__()
33
- self.model = torchvision.models.resnet34(weights=False)
 
34
  self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
35
  self.regression_use_tanh = regression_use_tanh
36
 
@@ -45,9 +47,10 @@ class ResNet34Regressor(nn.Module):
45
 
46
 
47
  class ResNet50Regressor(nn.Module):
48
- def __init__(self, regression_use_tanh: bool = False):
49
  super().__init__()
50
- self.model = torchvision.models.resnet50(weights=False)
 
51
  self.model.fc = nn.Linear(2048, config.FONT_COUNT + 12)
52
  self.regression_use_tanh = regression_use_tanh
53
 
@@ -62,9 +65,10 @@ class ResNet50Regressor(nn.Module):
62
 
63
 
64
  class ResNet101Regressor(nn.Module):
65
- def __init__(self, regression_use_tanh: bool = False):
66
  super().__init__()
67
- self.model = torchvision.models.resnet101(weights=False)
 
68
  self.model.fc = nn.Linear(2048, config.FONT_COUNT + 12)
69
  self.regression_use_tanh = regression_use_tanh
70
 
 
11
 
12
 
13
  class ResNet18Regressor(nn.Module):
14
+ def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
15
  super().__init__()
16
+ weights = torchvision.models.ResNet18_Weights.DEFAULT if pretrained else None
17
+ self.model = torchvision.models.resnet18(weights=weights)
18
  self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
19
  self.regression_use_tanh = regression_use_tanh
20
 
 
29
 
30
 
31
  class ResNet34Regressor(nn.Module):
32
+ def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
33
  super().__init__()
34
+ weights = torchvision.models.ResNet34_Weights.DEFAULT if pretrained else None
35
+ self.model = torchvision.models.resnet34(weights=weights)
36
  self.model.fc = nn.Linear(512, config.FONT_COUNT + 12)
37
  self.regression_use_tanh = regression_use_tanh
38
 
 
47
 
48
 
49
  class ResNet50Regressor(nn.Module):
50
+ def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
51
  super().__init__()
52
+ weights = torchvision.models.ResNet50_Weights.DEFAULT if pretrained else None
53
+ self.model = torchvision.models.resnet50(weights=weights)
54
  self.model.fc = nn.Linear(2048, config.FONT_COUNT + 12)
55
  self.regression_use_tanh = regression_use_tanh
56
 
 
65
 
66
 
67
  class ResNet101Regressor(nn.Module):
68
+ def __init__(self, pretrained: bool = False, regression_use_tanh: bool = False):
69
  super().__init__()
70
+ weights = torchvision.models.ResNet101_Weights.DEFAULT if pretrained else None
71
+ self.model = torchvision.models.resnet101(weights=weights)
72
  self.model.fc = nn.Linear(2048, config.FONT_COUNT + 12)
73
  self.regression_use_tanh = regression_use_tanh
74
 
train.py CHANGED
@@ -12,9 +12,42 @@ from utils import get_current_tag
12
  torch.set_float32_matmul_precision("high")
13
 
14
  parser = argparse.ArgumentParser()
15
- parser.add_argument("-d", "--devices", nargs="*", type=int, default=[0])
16
- parser.add_argument("-b", "--single-batch-size", type=int, default=64)
17
- parser.add_argument("-c", "--checkpoint", type=str, default=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  args = parser.parse_args()
20
 
@@ -76,7 +109,24 @@ trainer = ptl.Trainer(
76
  deterministic=True,
77
  )
78
 
79
- model = ResNet50Regressor(regression_use_tanh=regression_use_tanh)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  detector = FontDetector(
82
  model=model,
 
12
  torch.set_float32_matmul_precision("high")
13
 
14
  parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ "-d",
17
+ "--devices",
18
+ nargs="*",
19
+ type=int,
20
+ default=[0],
21
+ help="GPU devices to use (default: [0])",
22
+ )
23
+ parser.add_argument(
24
+ "-b",
25
+ "--single-batch-size",
26
+ type=int,
27
+ default=64,
28
+ help="Batch size of single device (default: 64)",
29
+ )
30
+ parser.add_argument(
31
+ "-c",
32
+ "--checkpoint",
33
+ type=str,
34
+ default=None,
35
+ help="Trainer checkpoint path (default: None)",
36
+ )
37
+ parser.add_argument(
38
+ "-m",
39
+ "--model",
40
+ type=str,
41
+ default="resnet18",
42
+ choices=["resnet18", "resnet34", "resnet50", "resnet101"],
43
+ help="Model to use (default: resnet18)",
44
+ )
45
+ parser.add_argument(
46
+ "-p",
47
+ "--pretrained",
48
+ action="store_true",
49
+ help="Use pretrained model for ResNet (default: False)",
50
+ )
51
 
52
  args = parser.parse_args()
53
 
 
109
  deterministic=True,
110
  )
111
 
112
+ if args.model == "resnet18":
113
+ model = ResNet18Regressor(
114
+ pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
115
+ )
116
+ elif args.model == "resnet34":
117
+ model = ResNet34Regressor(
118
+ pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
119
+ )
120
+ elif args.model == "resnet50":
121
+ model = ResNet50Regressor(
122
+ pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
123
+ )
124
+ elif args.model == "resnet101":
125
+ model = ResNet101Regressor(
126
+ pretrained=args.pretrained, regression_use_tanh=regression_use_tanh
127
+ )
128
+ else:
129
+ raise NotImplementedError()
130
 
131
  detector = FontDetector(
132
  model=model,