# File size: 3,115 Bytes | 803ef9e
import os
import numpy as np
from PIL import Image
from os.path import join
from collections import defaultdict
import torch.utils.data as data
# Default dataset root; the loaders below read files under '<root>/data/'.
DATA_ROOTS = 'data/Aircraft'

# Dataset setup notes (FGVC-Aircraft 2013b):
#   url = 'http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz'
#   Download from a shell:
#     wget http://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz
#   Extract from a Python session:
#     from torchvision.datasets.utils import extract_archive
#     extract_archive("fgvc-aircraft-2013b.tar.gz")
# Download and preprocess reference:
#   https://github.com/lvyilin/pytorch-fgvc-dataset/blob/master/aircraft.py
# Values noted from that reference:
#   class_types = ('variant', 'family', 'manufacturer')
#   splits = ('train', 'val', 'trainval', 'test')
#   img_folder = os.path.join('fgvc-aircraft-2013b', 'data', 'images')


class Aircraft(data.Dataset):
    """Aircraft images labelled by variant and cropped to their bounding box.

    The index is built once at construction time from two text files under
    ``<root>/data``: ``images_variant_<split>.txt`` (image name -> variant)
    and ``images_box.txt`` (image name -> bounding box). Images themselves
    are only opened in ``__getitem__``.

    Args:
        root: Dataset root directory; files are read from ``<root>/data``.
        train: If true, use the ``trainval`` split, otherwise ``test``.
        image_transforms: Optional callable applied to the cropped RGB image.

    Attributes set by ``__init__`` (parallel lists, one entry per sample):
        paths: Image file paths (``<root>/data/images/<name>.jpg``).
        bboxes: ``[xmin, ymin, xmax, ymax]`` integer lists.
        labels: Integer class ids (index of the variant in sorted order).
    """

    def __init__(self, root=DATA_ROOTS, train=True, image_transforms=None):
        super().__init__()
        self.root = root
        self.train = train
        self.image_transforms = image_transforms
        paths, bboxes, labels = self.load_images()
        self.paths = paths
        self.bboxes = bboxes
        self.labels = labels

    def load_images(self):
        """Build the sample index for the selected split.

        Returns:
            A ``(paths, bboxes, labels)`` tuple of equally long lists.
            Samples are grouped by variant (variants in sorted order, label
            = position in that order) and sorted by image name within each
            variant.

        Raises:
            OSError: If an annotation file cannot be opened.
            ValueError: If a non-blank annotation line is malformed.
            KeyError: If an image of the split has no bounding-box entry.
        """
        split = 'trainval' if self.train else 'test'
        variant_path = os.path.join(self.root, 'data', 'images_variant_%s.txt' % split)

        names_to_variants = {}
        with open(variant_path, 'r') as f:
            for line in f:
                line = line.rstrip('\n')
                if not line.strip():
                    # Tolerate blank / trailing empty lines.
                    continue
                # Split once only: everything after the first space is the
                # variant name, which may itself contain spaces.
                name, variant = line.split(' ', 1)
                names_to_variants[name] = variant

        variants_to_names = defaultdict(list)
        for name, variant in names_to_variants.items():
            variants_to_names[variant].append(name)

        names_to_bboxes = self.get_bounding_boxes()

        split_files, split_labels, split_bboxes = [], [], []
        # Sorted variant order makes the label ids deterministic.
        for variant_id, variant in enumerate(sorted(variants_to_names)):
            names = sorted(variants_to_names[variant])
            split_files.extend(
                join(self.root, 'data', 'images', '%s.jpg' % name) for name in names
            )
            split_bboxes.extend(names_to_bboxes[name] for name in names)
            split_labels.extend([variant_id] * len(names))
        return split_files, split_bboxes, split_labels

    def get_bounding_boxes(self):
        """Parse ``<root>/data/images_box.txt``.

        Each non-blank line must hold five whitespace-separated fields:
        ``name xmin ymin xmax ymax``.

        Returns:
            Dict mapping image name to ``[xmin, ymin, xmax, ymax]`` (ints).

        Raises:
            OSError: If the file cannot be opened.
            ValueError: If a line has the wrong number of fields or a
                coordinate is not an integer.
        """
        bboxes_path = os.path.join(self.root, 'data', 'images_box.txt')
        names_to_bboxes = {}
        with open(bboxes_path, 'r') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank / trailing empty lines.
                    continue
                name, xmin, ymin, xmax, ymax = fields
                names_to_bboxes[name] = [int(xmin), int(ymin), int(xmax), int(ymax)]
        return names_to_bboxes

    def __len__(self):
        """Return the number of samples in the split."""
        return len(self.paths)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the sample at ``index``.

        The image is loaded as RGB, cropped to its bounding box and, if
        ``image_transforms`` was given, passed through it.
        """
        path = self.paths[index]
        bbox = tuple(self.bboxes[index])
        label = self.labels[index]
        # convert() returns an independent in-memory image, so the file can
        # be closed right away instead of waiting for garbage collection.
        with Image.open(path) as raw:
            image = raw.convert(mode='RGB')
        image = image.crop(bbox)
        # Compare against None so that falsy-but-callable transforms
        # (e.g. an empty container module) are still applied.
        if self.image_transforms is not None:
            image = self.image_transforms(image)
        return image, label