mambazjp's picture
Upload 57 files
8b79d57
raw
history blame
3.56 kB
import os
import numpy as np
import random
import json
import os
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import functional as F
from PIL import Image
class CocoPanopticDataset(Dataset):
    """COCO panoptic dataset yielding (RGB image, binary mask) pairs.

    Only images containing at least one 'person' segment (category 1) are
    kept. The mask marks person / backpack / tie pixels with 255 and
    everything else with 0.

    Args:
        imgdir: Directory with the JPEG images.
        anndir: Directory with the panoptic PNG annotation images.
        annfile: Path to the panoptic JSON annotation file.
        transform: Optional callable applied as ``transform(img, seg)``;
            must return the transformed pair.
    """

    def __init__(self,
                 imgdir: str,
                 anndir: str,
                 annfile: str,
                 transform=None):
        with open(annfile) as f:
            self.data = json.load(f)['annotations']
        # Keep only entries that contain at least one person segment.
        self.data = list(filter(lambda data: any(info['category_id'] == 1 for info in data['segments_info']), self.data))
        self.imgdir = imgdir
        self.anndir = anndir
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        img = self._load_img(data)
        seg = self._load_seg(data)
        if self.transform is not None:
            img, seg = self.transform(img, seg)
        return img, seg

    def _load_img(self, data):
        # Annotation entries list '.png' file names; the actual photos are
        # JPEGs. Swap only the extension — str.replace('.png', '.jpg') would
        # also rewrite a '.png' occurring anywhere else in the name.
        name = os.path.splitext(data['file_name'])[0] + '.jpg'
        with Image.open(os.path.join(self.imgdir, name)) as img:
            return img.convert('RGB')

    def _load_seg(self, data):
        with Image.open(os.path.join(self.anndir, data['file_name'])) as ann:
            ann.load()  # force-read pixel data before the file handle closes
        # np.asarray instead of np.array(..., copy=False): under NumPy >= 2.0
        # copy=False raises ValueError whenever a copy is required (it always
        # is when converting from a PIL image); asarray keeps the original
        # copy-only-if-needed semantics. int32 so the 24-bit id cannot overflow.
        ann = np.asarray(ann).astype(np.int32)
        # COCO panoptic packs the segment id into RGB as R + 256*G + 256^2*B.
        ann = ann[:, :, 0] + 256 * ann[:, :, 1] + 256 * 256 * ann[:, :, 2]
        seg = np.zeros(ann.shape, np.uint8)
        for segments_info in data['segments_info']:
            if segments_info['category_id'] in [1, 27, 32]:  # person, backpack, tie
                seg[ann == segments_info['id']] = 255
        return Image.fromarray(seg)
class CocoPanopticTrainAugmentation:
    """Paired train-time augmentation for a (PIL image, PIL mask) pair.

    Applies a shared random affine, a shared random resized crop, a shared
    random horizontal flip, then photometric jitter on the image only, and
    finally converts both to tensors. Geometric transforms reuse identical
    parameters for image and mask; the mask is resampled with
    nearest-neighbor so label values stay exact.
    """

    def __init__(self, size):
        # Target (h, w) of the output crop.
        self.size = size
        self.jitter = transforms.ColorJitter(0.1, 0.1, 0.1, 0.1)

    def __call__(self, img, seg):
        # Random rotation / translation / shear, same parameters for both.
        affine_params = transforms.RandomAffine.get_params(
            degrees=(-20, 20), translate=(0.1, 0.1),
            scale_ranges=(1, 1), shears=(-10, 10), img_size=img.size)
        img = F.affine(img, *affine_params, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.affine(seg, *affine_params, interpolation=F.InterpolationMode.NEAREST)

        # Random crop-and-resize down to self.size, same window for both.
        crop_params = transforms.RandomResizedCrop.get_params(
            img, scale=(0.5, 1), ratio=(0.7, 1.3))
        img = F.resized_crop(img, *crop_params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.resized_crop(seg, *crop_params, self.size, interpolation=F.InterpolationMode.NEAREST)

        # Mirror both with probability 0.5.
        if random.random() < 0.5:
            img, seg = F.hflip(img), F.hflip(seg)

        # Photometric jitter touches the image only; labels stay untouched.
        img = self.jitter(img)

        return F.to_tensor(img), F.to_tensor(seg)
class CocoPanopticValidAugmentation:
    """Deterministic eval-time counterpart of the train augmentation.

    Resizes a (PIL image, PIL mask) pair to ``size`` and converts both to
    tensors — bilinear for the image, nearest-neighbor for the mask so
    label values are preserved.
    """

    def __init__(self, size):
        # Target (h, w) of the output.
        self.size = size

    def __call__(self, img, seg):
        # With scale=(1, 1) and ratio=(1., 1.) the "random" crop is pinned
        # down to a fixed window, so this is effectively a plain resize;
        # both images share the same crop parameters.
        crop_params = transforms.RandomResizedCrop.get_params(
            img, scale=(1, 1), ratio=(1., 1.))
        img = F.resized_crop(img, *crop_params, self.size, interpolation=F.InterpolationMode.BILINEAR)
        seg = F.resized_crop(seg, *crop_params, self.size, interpolation=F.InterpolationMode.NEAREST)
        return F.to_tensor(img), F.to_tensor(seg)