Chintan-Shah committed on
Commit 7c9d3ad · verified · 1 Parent(s): 84c7b51

Required files

resnet.py ADDED
@@ -0,0 +1,262 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10
import torchvision.transforms.v2 as transforms
from torchmetrics import Accuracy
from pytorch_lightning.callbacks import Callback

import os

AVAIL_GPUS = min(1, torch.cuda.device_count())
BATCH_SIZE = 256 if AVAIL_GPUS else 64

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Identity shortcut by default; switch to a 1x1 projection whenever the
        # spatial size or channel count changes so the residual addition is valid.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
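A quick shape check of the projection shortcut (a standalone sketch, not part of the committed file): a stride-2 block halves the spatial resolution and changes the channel count, and the 1x1 shortcut keeps the residual addition consistent.

block = BasicBlock(in_planes=64, planes=128, stride=2)
x = torch.randn(1, 64, 32, 32)
print(block(x).shape)  # torch.Size([1, 128, 16, 16])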

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        # Only the first block of a stage downsamples; the rest keep stride 1.
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34():
    return ResNet(BasicBlock, [3, 4, 6, 3])


def test():
    net = ResNet18()
    y = net(torch.randn(1, 3, 32, 32))
    print(y.size())

class LitResNet18(pl.LightningModule):
    def __init__(self, data_dir, num_classes=10, learning_rate=0.01, max_lr=1.45E-03):
        super(LitResNet18, self).__init__()
        self.in_planes = 64
        self.data_dir = data_dir

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)
        self.linear = nn.Linear(512*BasicBlock.expansion, num_classes)

        self.learning_rate = learning_rate
        self.max_lr = max_lr
        self.num_classes = num_classes
        self.steps_per_epoch = 50000 // BATCH_SIZE  # CIFAR-10 has 50,000 training images
        self.ds_mean = (0.4914, 0.4822, 0.4465)
        self.ds_std = (0.247, 0.243, 0.261)

        self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

        # Train-time augmentation: crop/flip, then pad out to 64x64 with the
        # dataset mean, erase a random patch, and crop back to 32x32.
        self.train_transforms = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Pad(16, self.ds_mean, 'constant'),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(self.ds_mean, self.ds_std),
            transforms.RandomErasing(scale=(0.125, 0.125), ratio=(1, 1), value=self.ds_mean, inplace=False),
            transforms.CenterCrop(32),
        ])

        # Test data transformations
        self.test_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(self.ds_mean, self.ds_std),
        ])
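To sanity-check that the augmentation pipeline round-trips to a 32x32 tensor, one can run it on a dummy image (a standalone sketch; the random PIL image stands in for a CIFAR-10 sample, and exact transform behavior can vary across torchvision versions):

import numpy as np
from PIL import Image

model = LitResNet18(data_dir="./data")
img = Image.fromarray(np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8))
out = model.train_transforms(img)
print(out.shape, out.dtype)  # torch.Size([3, 32, 32]) torch.float32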

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.parameters(), lr=self.learning_rate, momentum=0.9, weight_decay=5e-4)
        # Derive steps per epoch from the trainer instead of hard-coding it.
        steps_per_epoch = int(self.trainer.estimated_stepping_batches / self.trainer.max_epochs)
        pct_start = 0.3
        print("max_lr:", self.max_lr, "epochs:", self.trainer.max_epochs, "steps_per_epoch:", steps_per_epoch)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.max_lr,
            epochs=self.trainer.max_epochs,
            steps_per_epoch=steps_per_epoch,
            pct_start=pct_start,
            div_factor=10,
            final_div_factor=10,
            three_phase=False,
            anneal_strategy='linear'
        )
        # Step the scheduler every batch, as OneCycleLR expects.
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]
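With these settings the learning rate rises linearly from max_lr/div_factor to max_lr over the first 30% of steps, then anneals linearly down to initial_lr/final_div_factor. A standalone sketch of the resulting schedule (the dummy parameter and the small step counts are assumptions for illustration):

param = nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.01, momentum=0.9)
sched = torch.optim.lr_scheduler.OneCycleLR(
    opt, max_lr=1.45e-3, epochs=2, steps_per_epoch=5,
    pct_start=0.3, div_factor=10, final_div_factor=10,
    three_phase=False, anneal_strategy='linear')
for step in range(10):  # total_steps = epochs * steps_per_epoch = 10
    opt.step()
    sched.step()
    print(step, f"{sched.get_last_lr()[0]:.6f}")  # peaks at 0.001450 at step 2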


    def training_step(self, batch, batch_idx):
        x, y = batch
        output = self(x)
        log_probs = F.log_softmax(output, dim=1)
        preds = torch.argmax(log_probs, dim=1)
        self.accuracy(preds, y)

        # The outputs above are already log-probabilities, so use NLL loss;
        # CrossEntropyLoss would apply log_softmax a second time.
        loss = F.nll_loss(log_probs, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.accuracy, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        return self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self.evaluate(batch, 'test')

    def evaluate(self, batch, stage):
        x, y = batch
        output = self(x)
        log_probs = F.log_softmax(output, dim=1)
        preds = torch.argmax(log_probs, dim=1)
        self.accuracy(preds, y)

        # As in training_step: NLL loss on log-probabilities is equivalent to
        # cross-entropy on the raw network outputs.
        loss = F.nll_loss(log_probs, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", self.accuracy, prog_bar=True)
        return log_probs

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # download
        CIFAR10(root=self.data_dir, train=True, download=True)
        CIFAR10(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders.
        # Note: validation runs on the CIFAR-10 test split here; the commented
        # random_split line would carve a held-out set from the training data instead.
        if stage == "fit" or stage is None:
            self.cifar10_train = CIFAR10(self.data_dir, train=True, transform=self.train_transforms)
            self.cifar10_val = CIFAR10(self.data_dir, train=False, transform=self.test_transforms)
            # self.cifar10_train, self.cifar10_val = random_split(cifar10_full, [45000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.cifar10_test = CIFAR10(self.data_dir, train=False, transform=self.test_transforms)

    def train_dataloader(self):
        # Shuffle training batches each epoch.
        return DataLoader(self.cifar10_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=os.cpu_count())

    def val_dataloader(self):
        return DataLoader(self.cifar10_val, batch_size=BATCH_SIZE, num_workers=os.cpu_count())

    def test_dataloader(self):
        return DataLoader(self.cifar10_test, batch_size=BATCH_SIZE, num_workers=os.cpu_count())

class MisclassifiedCollector(Callback):
    """Collects a small sample of misclassified test images for inspection."""

    def __init__(self):
        super().__init__()
        self.misclassified_data = []
        self.origData = None

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        data, target = batch

        # `outputs` is whatever test_step returned: per-class log-probabilities.
        pred_batch = outputs.argmax(dim=1).cpu().tolist()
        actual_batch = target.cpu().tolist()

        # Keep at most 20 misclassified samples.
        if len(self.misclassified_data) < 20:
            for i in range(data.shape[0]):
                if pred_batch[i] != actual_batch[i]:
                    self.misclassified_data.append({
                        'pred': pred_batch[i],
                        'actual': actual_batch[i],
                        'data': data[i].detach().cpu().numpy()
                    })
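For context, a minimal sketch of how this module and callback might be trained and checkpointed. None of this is part of the commit; the Trainer settings, the max_epochs value, and the ModelCheckpoint filename pattern are assumptions, chosen so that saved files match the checkpoint names below.

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    filename="sample-cifar10-epoch{epoch:02d}-val_acc{val_acc:.2f}",
    auto_insert_metric_name=False,
)
model = LitResNet18(data_dir="./data")
trainer = pl.Trainer(
    max_epochs=24,  # assumption; the commit does not record the epoch budget
    accelerator="gpu" if AVAIL_GPUS else "cpu",
    devices=max(AVAIL_GPUS, 1),
    callbacks=[checkpoint_cb, MisclassifiedCollector()],
)
trainer.fit(model)
trainer.test(model)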
sample-cifar10-epoch00-val_acc0.36.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4393eb3af9d95c114bf98630e3285621980d9e8bacfffefdb2a2e6cdb45f81d
size 89492160
sample-cifar10-epoch00-val_acc0.37.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e2fb0bef4f23b3a6d07325b725082990d8b40523e9d207ade215c8bf5b708703
size 89492160