Ubuntu
committed on
Commit
·
2959565
1
Parent(s):
d759493
Added augmentations as per the original paper
Browse files
- README.md +26 -0
- resnet_execute.py +18 -7
README.md
CHANGED
@@ -6,6 +6,32 @@
|
|
6 |
|
7 |
## Data Augmentations
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
|
11 |
## Model Results
|
|
|
6 |
|
7 |
## Data Augmentations
|
8 |
|
9 |
+
To enhance the model's robustness and generalization capabilities, we apply a series of data augmentations to the training dataset. These augmentations are inspired by the original ResNet paper and implemented using the albumentations library. The augmentations include random resized cropping, horizontal flipping, and color jittering, followed by normalization. These transformations help the model learn invariant features and improve performance on unseen data.
|
10 |
+
|
11 |
+
### Augmentations and Hyperparameters
|
12 |
+
|
13 |
+
1. **Random Resized Crop:**
|
14 |
+
- Height: 224
|
15 |
+
- Width: 224
|
16 |
+
- Scale: (0.08, 1.0)
|
17 |
+
- Aspect Ratio: (3/4, 4/3)
|
18 |
+
- Probability: 1.0
|
19 |
+
|
20 |
+
2. **Horizontal Flip:**
|
21 |
+
- Probability: 0.5
|
22 |
+
|
23 |
+
3. **Color Jitter:**
|
24 |
+
- Brightness: 0.4
|
25 |
+
- Contrast: 0.4
|
26 |
+
- Saturation: 0.4
|
27 |
+
- Hue: 0.1
|
28 |
+
- Probability: 0.8
|
29 |
+
|
30 |
+
4. **Normalization:**
|
31 |
+
- Mean: (0.485, 0.456, 0.406)
|
32 |
+
- Standard Deviation: (0.229, 0.224, 0.225)
|
33 |
+
|
34 |
+
These augmentations are applied only to the training dataset. The test dataset is instead resized to 256×256, center-cropped to 224×224, and normalized with the same mean and standard deviation, so that evaluation metrics are consistent across runs.
|
35 |
|
36 |
|
37 |
## Model Results
|
resnet_execute.py
CHANGED
@@ -10,20 +10,31 @@ from torchvision import datasets
|
|
10 |
from checkpoint import save_checkpoint, load_checkpoint
|
11 |
import matplotlib.pyplot as plt
|
12 |
from torchvision.utils import make_grid
|
|
|
|
|
|
|
13 |
|
14 |
# Define transformations
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
])
|
21 |
|
22 |
# Train dataset and loader
|
23 |
-
trainset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/train', transform=
|
24 |
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
|
25 |
|
26 |
-
testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=
|
27 |
testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=16, pin_memory=True)
|
28 |
|
29 |
# Initialize model, loss function, and optimizer
|
|
|
10 |
from checkpoint import save_checkpoint, load_checkpoint
|
11 |
import matplotlib.pyplot as plt
|
12 |
from torchvision.utils import make_grid
|
13 |
+
import albumentations as A
|
14 |
+
from albumentations.pytorch import ToTensorV2
|
15 |
+
import numpy as np
|
16 |
|
17 |
# Define transformations
|
18 |
+
train_transform = A.Compose([
|
19 |
+
A.RandomResizedCrop(height=224, width=224, scale=(0.08, 1.0), ratio=(3/4, 4/3), p=1.0),
|
20 |
+
A.HorizontalFlip(p=0.5),
|
21 |
+
A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
|
22 |
+
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
23 |
+
ToTensorV2()
|
24 |
+
])
|
25 |
+
|
26 |
+
test_transform = A.Compose([
|
27 |
+
A.Resize(height=256, width=256),
|
28 |
+
A.CenterCrop(height=224, width=224),
|
29 |
+
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
30 |
+
ToTensorV2()
|
31 |
])
|
32 |
|
33 |
# Train dataset and loader
|
34 |
+
trainset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/train', transform=lambda img: train_transform(image=np.array(img))['image'])
|
35 |
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
|
36 |
|
37 |
+
testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=lambda img: test_transform(image=np.array(img))['image'])
|
38 |
testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=16, pin_memory=True)
|
39 |
|
40 |
# Initialize model, loss function, and optimizer
|