nofl committed on
Commit 335ddff · verified · 1 Parent(s): 8f85551

Update App.py

Files changed (1)
  1. App.py +107 -8
App.py CHANGED
@@ -1,13 +1,112 @@
- # ops
  from aim import Run
  from aim.pytorch import track_gradients_dists, track_params_dists
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torchvision import datasets, transforms
+ from tqdm import tqdm
+
+ # Hyperparameters
+ batch_size = 64
+ epochs = 10
+ learning_rate = 0.01
 
- # Initialize a new Run
  aim_run = Run()
- ...
- items = {'accuracy': acc, 'loss': loss}
- aim_run.track(items, epoch=epoch, context={'subset': 'train'})
 
- # Track weights and gradients distributions
- track_params_dists(model, aim_run)
- track_gradients_dists(model, aim_run)
+ class CNN(nn.Module):
+     def __init__(self):
+         super(CNN, self).__init__()
+         # padding=1 keeps the 28x28 feature maps, so two 2x2 poolings yield 7x7,
+         # matching the 64 * 7 * 7 input expected by fc1
+         self.conv1 = nn.Conv2d(1, 32, 3, 1, padding=1)
+         self.conv2 = nn.Conv2d(32, 64, 3, 1, padding=1)
+         self.pool = nn.MaxPool2d(2, 2)
+         self.fc1 = nn.Linear(64 * 7 * 7, 128)
+         self.fc2 = nn.Linear(128, 10)
+
+     def forward(self, x):
+         x = self.pool(torch.relu(self.conv1(x)))
+         x = self.pool(torch.relu(self.conv2(x)))
+         x = torch.flatten(x, 1)
+         x = torch.relu(self.fc1(x))
+         x = self.fc2(x)
+         return x
+
+ train_dataset = datasets.MNIST(root='./data',
+                                train=True,
+                                transform=transforms.ToTensor(),
+                                download=True)
+
+ test_dataset = datasets.MNIST(root='./data',
+                               train=False,
+                               transform=transforms.ToTensor())
+
+
+ train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
+                                            batch_size=batch_size,
+                                            shuffle=True)
+
+ test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
+                                           batch_size=batch_size,
+                                           shuffle=False)
+
+
+ model = CNN()
+ optimizer = optim.Adam(model.parameters(), lr=learning_rate)
+ criterion = nn.CrossEntropyLoss()
+
+ # Training loop
+ for epoch in range(epochs):
+     model.train()
+     train_loss = 0
+     correct = 0
+     total = 0
+
+     for batch_idx, (data, target) in enumerate(tqdm(train_loader, desc="Training", leave=False)):
+         optimizer.zero_grad()
+         output = model(data)
+         loss = criterion(output, target)
+         loss.backward()
+         optimizer.step()
+
+         train_loss += loss.item()
+         _, predicted = torch.max(output.data, 1)
+         total += target.size(0)
+         correct += (predicted == target).sum().item()
+
+     # Track training metrics and distributions
+     acc = correct / total
+     items = {'accuracy': acc, 'loss': train_loss / len(train_loader)}
+     aim_run.track(items, epoch=epoch, context={'subset': 'train'})
+
+     # Track weight and gradient distributions (these helpers take only the model and the run)
+     track_params_dists(model, aim_run)
+     track_gradients_dists(model, aim_run)
+
+     # Evaluation on the test set
+     model.eval()
+     test_loss = 0
+     correct = 0
+     total = 0
+     with torch.no_grad():
+         for batch_idx, (data, target) in enumerate(tqdm(test_loader, desc="Testing", leave=False)):
+             output = model(data)
+             loss = criterion(output, target)
+             test_loss += loss.item()
+             _, predicted = torch.max(output.data, 1)
+             total += target.size(0)
+             correct += (predicted == target).sum().item()
+
+     # Track test metrics and distributions
+     acc = correct / total
+     items = {'accuracy': acc, 'loss': test_loss / len(test_loader)}
+     aim_run.track(items, epoch=epoch, context={'subset': 'test'})
+
+     track_params_dists(model, aim_run)
+     track_gradients_dists(model, aim_run)
+
+ torch.save(model.state_dict(), 'mnist_cnn.pth')