import numpy as np
import pathlib
import cv2
import pandas as pd
import copy
import os
import logging
class OpenImagesDataset:
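    """Object-detection dataset for annotations in Open Images CSV format.

    A sketch of the assumed layout under `root`, inferred from the paths used below:

        <root>/sub-<dataset_type>-annotations-bbox.csv   # ImageID, ClassName, XMin, YMin, XMax, YMax
        <root>/<dataset_type>/<ImageID>.jpg              # the corresponding images

    Box coordinates in the CSV are normalized to [0, 1] and are converted to
    absolute pixel coordinates when an item is fetched.
    """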
def __init__(self, root,
transform=None, target_transform=None,
dataset_type="train", balance_data=False):
self.root = pathlib.Path(root)
self.transform = transform
self.target_transform = target_transform
self.dataset_type = dataset_type.lower()
self.data, self.class_names, self.class_dict = self._read_data()
self.balance_data = balance_data
self.min_image_num = -1
if self.balance_data:
self.data = self._balance_data()
self.ids = [info['image_id'] for info in self.data]
self.class_stat = None
def _getitem(self, index):
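        """Load one example and return (image_id, image, boxes, labels).

        Boxes are scaled to absolute pixel coordinates before the optional
        transform and target_transform are applied.
        """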
image_info = self.data[index]
image = self._read_image(image_info['image_id'])
        # copy the boxes so the in-place scaling below doesn't modify the cached annotations,
        # then convert from normalized [0, 1] coordinates to absolute pixel coordinates
        boxes = copy.copy(image_info['boxes'])
        boxes[:, 0] *= image.shape[1]
        boxes[:, 1] *= image.shape[0]
        boxes[:, 2] *= image.shape[1]
        boxes[:, 3] *= image.shape[0]
        # likewise copy the labels so downstream transforms can't corrupt the cached dataset
        labels = copy.copy(image_info['labels'])
if self.transform:
image, boxes, labels = self.transform(image, boxes, labels)
if self.target_transform:
boxes, labels = self.target_transform(boxes, labels)
return image_info['image_id'], image, boxes, labels
def __getitem__(self, index):
_, image, boxes, labels = self._getitem(index)
return image, boxes, labels
def get_annotation(self, index):
"""To conform the eval_ssd implementation that is based on the VOC dataset."""
image_id, image, boxes, labels = self._getitem(index)
is_difficult = np.zeros(boxes.shape[0], dtype=np.uint8)
return image_id, (boxes, labels, is_difficult)
def get_image(self, index):
image_info = self.data[index]
image = self._read_image(image_info['image_id'])
if self.transform:
image, _ = self.transform(image)
return image
def _read_data(self):
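        """Parse the annotation CSV and group boxes/labels per image.

        Images referenced by the CSV but missing on disk are dropped with an
        error message. Returns (data, class_names, class_dict), where index 0
        of class_names is the synthetic 'BACKGROUND' class.
        """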
annotation_file = f"{self.root}/sub-{self.dataset_type}-annotations-bbox.csv"
logging.info(f'loading annotations from: {annotation_file}')
annotations = pd.read_csv(annotation_file)
logging.info(f'annotations loaded from: {annotation_file}')
class_names = ['BACKGROUND'] + sorted(list(annotations['ClassName'].unique()))
class_dict = {class_name: i for i, class_name in enumerate(class_names)}
data = []
for image_id, group in annotations.groupby("ImageID"):
img_path = os.path.join(self.root, self.dataset_type, image_id + '.jpg')
            if not os.path.isfile(img_path):
                logging.error(f'missing ImageID {image_id}.jpg - dropping from annotations')
continue
boxes = group.loc[:, ["XMin", "YMin", "XMax", "YMax"]].values.astype(np.float32)
# make labels 64 bits to satisfy the cross_entropy function
labels = np.array([class_dict[name] for name in group["ClassName"]], dtype='int64')
#print('found image {:s} ({:d})'.format(img_path, len(data)))
data.append({
'image_id': image_id,
'boxes': boxes,
'labels': labels
})
        logging.info(f'num images: {len(data)}')
return data, class_names, class_dict
def __len__(self):
return len(self.data)
def __repr__(self):
if self.class_stat is None:
self.class_stat = {name: 0 for name in self.class_names[1:]}
for example in self.data:
for class_index in example['labels']:
class_name = self.class_names[class_index]
self.class_stat[class_name] += 1
content = ["Dataset Summary:"
f"Number of Images: {len(self.data)}",
f"Minimum Number of Images for a Class: {self.min_image_num}",
"Label Distribution:"]
for class_name, num in self.class_stat.items():
content.append(f"\t{class_name}: {num}")
return "\n".join(content)
def _read_image(self, image_id):
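        """Read <root>/<dataset_type>/<image_id>.jpg and return it as an RGB ndarray."""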
image_file = self.root / self.dataset_type / f"{image_id}.jpg"
        image = cv2.imread(str(image_file))
        if image is None or image.size == 0:
            raise IOError(f'failed to load image: {image_file}')
        if len(image.shape) == 2 or image.shape[2] == 1:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
else:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
return image
def _balance_data(self):
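        """Roughly balance classes by image count.

        For each foreground class, randomly keep at most `min_image_num` images
        (the image count of the rarest class) and return the union of those samples.
        """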
logging.info('balancing data')
label_image_indexes = [set() for _ in range(len(self.class_names))]
for i, image in enumerate(self.data):
for label_id in image['labels']:
label_image_indexes[label_id].add(i)
label_stat = [len(s) for s in label_image_indexes]
self.min_image_num = min(label_stat[1:])
sample_image_indexes = set()
for image_indexes in label_image_indexes[1:]:
image_indexes = np.array(list(image_indexes))
sub = np.random.permutation(image_indexes)[:self.min_image_num]
sample_image_indexes.update(sub)
sample_data = [self.data[i] for i in sample_image_indexes]
return sample_data
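

if __name__ == '__main__':
    # Minimal usage sketch. Assumes a dataset root laid out as described in the
    # class docstring; the 'data/open_images' path below is hypothetical.
    logging.basicConfig(level=logging.INFO)
    dataset = OpenImagesDataset('data/open_images', dataset_type='train')
    print(dataset)                        # per-class label distribution
    image, boxes, labels = dataset[0]     # RGB HWC image, pixel boxes, int64 labels
    print(image.shape, boxes.shape, labels.shape)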