Spaces:
Herta83
/
Running

Herta83 commited on
Commit
d22f1af
·
verified ·
1 Parent(s): 8b188c0

Delete model.py

Browse files
Files changed (1) hide show
  1. model.py +0 -158
model.py DELETED
@@ -1,158 +0,0 @@
1
- import os
2
- import glob
3
- import torch
4
- import torch.jit
5
- import torch.nn as nn
6
-
7
-
8
- class Model(torch.jit.ScriptModule):
9
- CHECKPOINT_FILENAME_PATTERN = "model-{}.pth"
10
-
11
- __constants__ = [
12
- "_hidden1",
13
- "_hidden2",
14
- "_hidden3",
15
- "_hidden4",
16
- "_hidden5",
17
- "_hidden6",
18
- "_hidden7",
19
- "_hidden8",
20
- "_hidden9",
21
- "_hidden10",
22
- "_features",
23
- "_classifier",
24
- "_digit_length",
25
- "_digit1",
26
- "_digit2",
27
- "_digit3",
28
- "_digit4",
29
- "_digit5",
30
- ]
31
-
32
- def __init__(self):
33
- super(Model, self).__init__()
34
-
35
- self._hidden1 = nn.Sequential(
36
- nn.Conv2d(in_channels=3, out_channels=48, kernel_size=5, padding=2),
37
- nn.BatchNorm2d(num_features=48),
38
- nn.ReLU(),
39
- nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
40
- nn.Dropout(0.2),
41
- )
42
- self._hidden2 = nn.Sequential(
43
- nn.Conv2d(in_channels=48, out_channels=64, kernel_size=5, padding=2),
44
- nn.BatchNorm2d(num_features=64),
45
- nn.ReLU(),
46
- nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
47
- nn.Dropout(0.2),
48
- )
49
- self._hidden3 = nn.Sequential(
50
- nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, padding=2),
51
- nn.BatchNorm2d(num_features=128),
52
- nn.ReLU(),
53
- nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
54
- nn.Dropout(0.2),
55
- )
56
- self._hidden4 = nn.Sequential(
57
- nn.Conv2d(in_channels=128, out_channels=160, kernel_size=5, padding=2),
58
- nn.BatchNorm2d(num_features=160),
59
- nn.ReLU(),
60
- nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
61
- nn.Dropout(0.2),
62
- )
63
- self._hidden5 = nn.Sequential(
64
- nn.Conv2d(in_channels=160, out_channels=192, kernel_size=5, padding=2),
65
- nn.BatchNorm2d(num_features=192),
66
- nn.ReLU(),
67
- nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
68
- nn.Dropout(0.2),
69
- )
70
- self._hidden6 = nn.Sequential(
71
- nn.Conv2d(in_channels=192, out_channels=192, kernel_size=5, padding=2),
72
- nn.BatchNorm2d(num_features=192),
73
- nn.ReLU(),
74
- nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
75
- nn.Dropout(0.2),
76
- )
77
- self._hidden7 = nn.Sequential(
78
- nn.Conv2d(in_channels=192, out_channels=192, kernel_size=5, padding=2),
79
- nn.BatchNorm2d(num_features=192),
80
- nn.ReLU(),
81
- nn.MaxPool2d(kernel_size=2, stride=2, padding=1),
82
- nn.Dropout(0.2),
83
- )
84
- self._hidden8 = nn.Sequential(
85
- nn.Conv2d(in_channels=192, out_channels=192, kernel_size=5, padding=2),
86
- nn.BatchNorm2d(num_features=192),
87
- nn.ReLU(),
88
- nn.MaxPool2d(kernel_size=2, stride=1, padding=1),
89
- nn.Dropout(0.2),
90
- )
91
- self._hidden9 = nn.Sequential(nn.Linear(192 * 7 * 7, 3072), nn.ReLU())
92
- self._hidden10 = nn.Sequential(nn.Linear(3072, 3072), nn.ReLU())
93
-
94
- self._digit_length = nn.Sequential(nn.Linear(3072, 7))
95
- self._digit1 = nn.Sequential(nn.Linear(3072, 11))
96
- self._digit2 = nn.Sequential(nn.Linear(3072, 11))
97
- self._digit3 = nn.Sequential(nn.Linear(3072, 11))
98
- self._digit4 = nn.Sequential(nn.Linear(3072, 11))
99
- self._digit5 = nn.Sequential(nn.Linear(3072, 11))
100
-
101
- @torch.jit.script_method
102
- def forward(self, x):
103
- x = self._hidden1(x)
104
- x = self._hidden2(x)
105
- x = self._hidden3(x)
106
- x = self._hidden4(x)
107
- x = self._hidden5(x)
108
- x = self._hidden6(x)
109
- x = self._hidden7(x)
110
- x = self._hidden8(x)
111
- x = x.view(x.size(0), 192 * 7 * 7)
112
- x = self._hidden9(x)
113
- x = self._hidden10(x)
114
-
115
- length_logits = self._digit_length(x)
116
- digit1_logits = self._digit1(x)
117
- digit2_logits = self._digit2(x)
118
- digit3_logits = self._digit3(x)
119
- digit4_logits = self._digit4(x)
120
- digit5_logits = self._digit5(x)
121
-
122
- return (
123
- length_logits,
124
- digit1_logits,
125
- digit2_logits,
126
- digit3_logits,
127
- digit4_logits,
128
- digit5_logits,
129
- )
130
-
131
- def store(self, path_to_dir, step, maximum=5):
132
- path_to_models = glob.glob(
133
- os.path.join(path_to_dir, Model.CHECKPOINT_FILENAME_PATTERN.format("*"))
134
- )
135
- if len(path_to_models) == maximum:
136
- min_step = min(
137
- [
138
- int(path_to_model.split("\\")[-1][6:-4])
139
- for path_to_model in path_to_models
140
- ]
141
- )
142
- path_to_min_step_model = os.path.join(
143
- path_to_dir, Model.CHECKPOINT_FILENAME_PATTERN.format(min_step)
144
- )
145
- os.remove(path_to_min_step_model)
146
-
147
- path_to_checkpoint_file = os.path.join(
148
- path_to_dir, Model.CHECKPOINT_FILENAME_PATTERN.format(step)
149
- )
150
- torch.save(self.state_dict(), path_to_checkpoint_file)
151
- return path_to_checkpoint_file
152
-
153
- def restore(self, path_to_checkpoint_file):
154
- self.load_state_dict(
155
- torch.load(path_to_checkpoint_file, map_location=torch.device("cpu"))
156
- )
157
- step = int(path_to_checkpoint_file.split("model-")[-1][:-4])
158
- return step