File size: 4,660 Bytes
fd52b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image

from pathlib import Path
from random import randint

from utils import *

"""
    Dataset class for storing stamps data.
    
    Arguments:
    data -- list of dictionaries containing file_path (path to the image), box_nb (number of boxes on the image), and boxes of shape (4,)
    image_folder -- path to folder containing images
    transforms -- transforms from albumentations package
"""
class StampDataset(Dataset):
    def __init__(
            self, 
            data=read_data(), 
            image_folder=Path(IMAGE_FOLDER),
            transforms=None):
        self.data = data
        self.image_folder = image_folder
        self.transforms = transforms

    def __getitem__(self, idx):
        item = self.data[idx]
        image_fn = self.image_folder / item['file_path']
        boxes = item['boxes']
        box_nb = item['box_nb']
        labels = torch.zeros((box_nb, 2), dtype=torch.int64)
        labels[:, 0] = 1

        img = np.array(Image.open(image_fn))

        try:
            if self.transforms:
                sample = self.transforms(**{
                    "image":img,
                    "bboxes": boxes,
                    "labels": labels,
                })
                img = sample['image']
                boxes = torch.stack(tuple(map(torch.tensor, zip(*sample['bboxes'])))).permute(1, 0)
        except:
            return self.__getitem__(randint(0, len(self.data)-1))

        target_tensor = boxes_to_tensor(boxes.type(torch.float32))
        return img, target_tensor

    def __len__(self):
        return len(self.data)

def collate_fn(batch):
    """Transpose a batch: a list of sample tuples becomes a tuple of
    per-field tuples, e.g. [(img, tgt), ...] -> ((img, ...), (tgt, ...))."""
    transposed = zip(*batch)
    return tuple(transposed)


def get_datasets(data_path=ANNOTATIONS_PATH, train_transforms=None, val_transforms=None):
    """
        Creates StampDataset objects for training, validation and testing.

        Arguments:
        data_path -- string or Path, specifying path to annotations file
        train_transforms -- albumentations transforms applied during training;
            defaults to a crop/resize/flip/normalize pipeline with COCO bboxes
        val_transforms -- albumentations transforms applied during validation
            and testing; defaults to resize/normalize with COCO bboxes

        Returns:
        (train_dataset, val_dataset, test_dataset) -- tuple of three
        StampDataset objects (splits: 90% train+val / 10% test, then the
        train+val part is split 80/20; splits are shuffled and random)
    """
    data = read_data(data_path)
    if train_transforms is None:
        train_transforms = A.Compose([
            A.RandomCropNearBBox(max_part_shift=0.6, p=0.4),
            A.Resize(height=448, width=448),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ],
        bbox_params={
            "format": "coco",
            'label_fields': ['labels']
        })

    if val_transforms is None:
        val_transforms = A.Compose([
            A.Resize(height=448, width=448),
            A.Normalize(),
            ToTensorV2(p=1.0),
        ],
        bbox_params={
            "format": "coco",
            'label_fields': ['labels']
        })

    # Hold out 10% as the test split, then carve 20% of the remainder
    # off as validation. NOTE(review): no random_state, so splits differ
    # between runs — confirm this is intended before relying on it.
    train, test_data = train_test_split(data, test_size=0.1, shuffle=True)
    train_data, val_data = train_test_split(train, test_size=0.2, shuffle=True)

    train_dataset = StampDataset(train_data, transforms=train_transforms)
    val_dataset = StampDataset(val_data, transforms=val_transforms)
    test_dataset = StampDataset(test_data, transforms=val_transforms)

    return train_dataset, val_dataset, test_dataset


def get_loaders(batch_size=8, data_path=ANNOTATIONS_PATH, num_workers=0, train_transforms=None, val_transforms=None):
    """
        Creates DataLoader objects for training and validation.

        Arguments:
        batch_size -- integer specifying the number of images in the batch
        data_path -- string or Path, specifying path to annotations file
        num_workers -- number of DataLoader worker processes
        train_transforms -- transforms to be applied during training
        val_transforms -- transforms to be applied during validation

        Returns:
        (train_loader, val_loader) -- tuple of DataLoader for training and
        validation (the test dataset is created but not wrapped in a loader)
    """
    # BUG FIX: the transform arguments were previously accepted but never
    # forwarded, so custom transforms were silently ignored.
    train_dataset, val_dataset, _ = get_datasets(
        data_path,
        train_transforms=train_transforms,
        val_transforms=val_transforms)

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
        drop_last=True)

    # Validation keeps the default shuffle=False and no drop_last;
    # num_workers is now forwarded here as well.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        collate_fn=collate_fn)

    return train_loader, val_loader