Ubuntu committed on
Commit 2959565 · 1 Parent(s): d759493

Added augmentations as per the original paper

Files changed (2)
  1. README.md +26 -0
  2. resnet_execute.py +18 -7
README.md CHANGED
@@ -6,6 +6,32 @@
 
  ## Data Augmentations
 
+ To enhance the model's robustness and generalization capabilities, we apply a series of data augmentations to the training dataset. These augmentations are inspired by the original ResNet paper and implemented using the albumentations library. The augmentations include random resized cropping, horizontal flipping, and color jittering, followed by normalization. These transformations help the model learn invariant features and improve performance on unseen data.
+
+ ### Augmentations and Hyperparameters
+
+ 1. **Random Resized Crop:**
+ - Height: 224
+ - Width: 224
+ - Scale: (0.08, 1.0)
+ - Aspect Ratio: (3/4, 4/3)
+ - Probability: 1.0
+
+ 2. **Horizontal Flip:**
+ - Probability: 0.5
+
+ 3. **Color Jitter:**
+ - Brightness: 0.4
+ - Contrast: 0.4
+ - Saturation: 0.4
+ - Hue: 0.1
+ - Probability: 0.8
+
+ 4. **Normalization:**
+ - Mean: (0.485, 0.456, 0.406)
+ - Standard Deviation: (0.229, 0.224, 0.225)
+
+ These augmentations are applied only to the training dataset, while the test dataset undergoes resizing and normalization to ensure consistent evaluation metrics.
 
 
  ## Model Results
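The Scale and Aspect Ratio entries in the README hunk above are easier to read with a concrete sampling procedure in mind. The following is a minimal sketch, assuming the commonly described random-resized-crop scheme (sample an area fraction uniformly from `scale`, sample an aspect ratio log-uniformly from `ratio`, retry a few times, then fall back to a centered crop). The function name `sample_crop_box` and the `max_tries` parameter are hypothetical, and the albumentations implementation may differ in detail. The sampled box would afterwards be resized to 224x224.

```python
import math
import random


def sample_crop_box(height, width, scale=(0.08, 1.0), ratio=(3 / 4, 4 / 3),
                    rng=None, max_tries=10):
    """Return (top, left, crop_h, crop_w) for one random resized crop.

    Hypothetical helper for illustration only; not taken from the commit.
    """
    rng = rng or random.Random()
    area = height * width
    log_ratio = (math.log(ratio[0]), math.log(ratio[1]))

    for _ in range(max_tries):
        # Fraction of the image area to keep, e.g. between 8% and 100%.
        target_area = area * rng.uniform(scale[0], scale[1])
        # Aspect ratio (width / height), sampled log-uniformly.
        aspect = math.exp(rng.uniform(log_ratio[0], log_ratio[1]))

        crop_w = int(round(math.sqrt(target_area * aspect)))
        crop_h = int(round(math.sqrt(target_area / aspect)))

        # Accept the box only if it fits inside the image.
        if 0 < crop_w <= width and 0 < crop_h <= height:
            top = rng.randint(0, height - crop_h)
            left = rng.randint(0, width - crop_w)
            return top, left, crop_h, crop_w

    # Fallback: centered crop with the aspect ratio clamped into `ratio`.
    in_ratio = width / height
    if in_ratio < ratio[0]:
        crop_w = width
        crop_h = int(round(crop_w / ratio[0]))
    elif in_ratio > ratio[1]:
        crop_h = height
        crop_w = int(round(crop_h * ratio[1]))
    else:
        crop_h, crop_w = height, width
    return (height - crop_h) // 2, (width - crop_w) // 2, crop_h, crop_w


if __name__ == "__main__":
    # Example: one crop box for a 375x500 (height x width) image.
    print(sample_crop_box(375, 500, rng=random.Random(0)))
```

With `scale=(0.08, 1.0)` a crop can cover as little as 8% of the image, which is why this augmentation is considerably more aggressive than the resize-and-center-crop pipeline it replaces for training.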
resnet_execute.py CHANGED
@@ -10,20 +10,31 @@ from torchvision import datasets
  from checkpoint import save_checkpoint, load_checkpoint
  import matplotlib.pyplot as plt
  from torchvision.utils import make_grid
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ import numpy as np
 
  # Define transformations
- transform = transforms.Compose([
-     transforms.Resize(256), # Resize the smaller side to 256 pixels while keeping aspect ratio
-     transforms.CenterCrop(224), # Then crop to 224x224 pixels from the center
-     transforms.ToTensor(),
-     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet normalization
+ train_transform = A.Compose([
+     A.RandomResizedCrop(height=224, width=224, scale=(0.08, 1.0), ratio=(3/4, 4/3), p=1.0),
+     A.HorizontalFlip(p=0.5),
+     A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
+     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+     ToTensorV2()
+ ])
+
+ test_transform = A.Compose([
+     A.Resize(height=256, width=256),
+     A.CenterCrop(height=224, width=224),
+     A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
+     ToTensorV2()
  ])
 
  # Train dataset and loader
- trainset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/train', transform=transform)
+ trainset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/train', transform=lambda img: train_transform(image=np.array(img))['image'])
  trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=16, pin_memory=True)
 
- testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=transform )
+ testset = datasets.ImageFolder(root='/mnt/imagenet/ILSVRC/Data/CLS-LOC/val', transform=lambda img: test_transform(image=np.array(img))['image'])
  testloader = DataLoader(testset, batch_size=1000, shuffle=False, num_workers=16, pin_memory=True)
 
  # Initialize model, loss function, and optimizer
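One detail in this hunk may be worth checking: the removed `transforms.Resize(256)` scales the shorter side to 256 pixels and keeps the aspect ratio, whereas the added `A.Resize(height=256, width=256)` resizes to exactly 256x256, which stretches or squashes non-square images before the center crop. The sketch below only illustrates that geometric difference with plain arithmetic. The helper names are hypothetical, and the rounding in `shorter_side_resize` (truncation) is an assumption that may not match every library version.

```python
def shorter_side_resize(height, width, target=256):
    """Output (height, width) when the shorter side is scaled to `target`.

    Hypothetical helper; the aspect ratio is preserved and the longer side
    is truncated to an integer.
    """
    if height <= width:
        return target, int(target * width / height)
    return int(target * height / width), target


def fixed_resize(height, width, target=256):
    """Output (height, width) when both sides are forced to `target`."""
    return target, target


def aspect_change(before_hw, after_hw):
    """Ratio of the new width/height aspect to the original one (1.0 = undistorted)."""
    before = before_hw[1] / before_hw[0]
    after = after_hw[1] / after_hw[0]
    return after / before


if __name__ == "__main__":
    original = (375, 500)  # height x width of a typical landscape image

    kept = shorter_side_resize(*original)
    forced = fixed_resize(*original)

    print(kept)                                      # (256, 341)
    print(forced)                                    # (256, 256)
    print(round(aspect_change(original, forced), 2)) # 0.75
```

If matching the previous evaluation preprocessing is the goal, albumentations provides `A.SmallestMaxSize(max_size=256)`, which, to the best of my understanding, performs the shorter-side resize; this should be verified against the installed version before relying on it.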