# Source: SAMmodel/aot/dataloaders/train_datasets.py
# Uploaded by aikenml via huggingface_hub ("Upload folder using huggingface_hub"),
# revision a69d385, 24.6 kB.
from __future__ import division
import os
from glob import glob
import json
import random
import cv2
from PIL import Image
import numpy as np
import torch
from torch.utils.data import Dataset
import torchvision.transforms as TF
import dataloaders.image_transforms as IT
# Module-level side effect: disable OpenCV's internal threading (per the
# OpenCV docs, a thread count of 0 makes OpenCV run its functions
# sequentially). NOTE(review): presumably done so DataLoader worker processes
# do not oversubscribe the CPU -- confirm.
cv2.setNumThreads(0)
def _get_images(sample):
    """Collect every frame of a training sample into one flat list.

    The result starts with the reference frame and the previous frame,
    followed by the entries of ``sample['curr_img']`` in their stored order.
    """
    leading_keys = ('ref_img', 'prev_img')
    return [sample[key] for key in leading_keys] + sample['curr_img']
def _get_labels(sample):
    """Collect every label map of a training sample into one flat list.

    The result starts with the reference label and the previous label,
    followed by the entries of ``sample['curr_label']`` in their stored order.
    """
    leading_keys = ('ref_label', 'prev_label')
    return [sample[key] for key in leading_keys] + sample['curr_label']
def _merge_sample(sample1, sample2, min_obj_pixels=100, max_obj_n=10):
    """Paste the foreground of ``sample2`` onto ``sample1``, frame by frame.

    Wherever a frame of ``sample2`` has label > 0, its image and label
    overwrite the matching frame of ``sample1``. ``sample2``'s ids are shifted
    by ``max_obj_n`` so they do not collide with ``sample1``'s ids. NOTE:
    ``sample1`` ids in ``max_obj_n + 1 .. 2 * max_obj_n`` would share a
    channel with shifted ``sample2`` ids. Ids outside ``0 .. 2 * max_obj_n``
    match no channel and end up as background.

    Objects are selected once, on the merged reference frame. An object is
    kept if it covers strictly more than ``min_obj_pixels`` pixels there;
    background is always kept. If more than ``max_obj_n`` objects survive, a
    random subset of ``max_obj_n`` is kept. Kept objects are re-indexed to
    consecutive ids ``1..k``. Pixels of dropped objects become background
    (0) in every frame.

    Args:
        sample1: base sample dict with ``ref_/prev_/curr_`` ``img``/``label``
            entries and ``meta``. Labels are indexed with ``size()[1]`` and
            ``size()[2]`` as height/width, i.e. presumably (1, H, W) tensors.
        sample2: sample whose foreground is pasted on top of ``sample1``.
        min_obj_pixels: pixel-count threshold for keeping an object.
        max_obj_n: maximum number of objects in the merged sample.

    Returns:
        A new sample dict with merged images and (1, H, W) long labels.
        ``meta`` is a shallow copy of ``sample1['meta']`` with ``obj_num``
        set to ``min(obj_num, max_obj_n)``. ``sample1`` itself is not
        modified.
    """
    sample1_images = _get_images(sample1)
    sample2_images = _get_images(sample2)
    sample1_labels = _get_labels(sample1)
    sample2_labels = _get_labels(sample2)

    # One one-hot channel per possible id: 0 (background), 1..max_obj_n for
    # sample1 and max_obj_n+1..2*max_obj_n for the shifted sample2 ids.
    obj_idx = torch.arange(0, max_obj_n * 2 + 1).view(max_obj_n * 2 + 1, 1, 1)

    selected_idx = None
    selected_obj = None
    all_img = []
    all_mask = []
    for idx, (s1_img, s2_img, s1_label, s2_label) in enumerate(
            zip(sample1_images, sample2_images, sample1_labels,
                sample2_labels)):
        height, width = s1_label.size()[1], s1_label.size()[2]
        s2_fg = (s2_label > 0).float()
        s2_bg = 1 - s2_fg
        merged_img = s1_img * s2_bg + s2_img * s2_fg
        merged_mask = s1_label * s2_bg.long() + (
            (s2_label + max_obj_n) * s2_fg.long())
        # One-hot encode into (2 * max_obj_n + 1, H, W).
        merged_mask = (merged_mask == obj_idx).float()
        if idx == 0:
            # The set of kept objects is decided once, on the reference frame.
            after_merge_pixels = merged_mask.sum(dim=(1, 2), keepdim=True)
            selected_idx = after_merge_pixels > min_obj_pixels
            selected_idx[0] = True  # background channel is always kept
            obj_num = selected_idx.sum().int().item() - 1
            selected_idx = selected_idx.expand(-1, height, width)
            if obj_num > max_obj_n:
                selected_obj = list(range(1, obj_num + 1))
                random.shuffle(selected_obj)
                selected_obj = [0] + selected_obj[:max_obj_n]
        merged_mask = merged_mask[selected_idx].view(obj_num + 1, height,
                                                     width)
        if obj_num > max_obj_n:
            merged_mask = merged_mask[selected_obj]
        # Bias the background so it strictly wins wherever no kept object is
        # present (e.g. pixels of dropped objects).
        merged_mask[0] += 0.1
        merged_mask = torch.argmax(merged_mask, dim=0, keepdim=True).long()
        all_img.append(merged_img)
        all_mask.append(merged_mask)

    sample = {
        'ref_img': all_img[0],
        'prev_img': all_img[1],
        'curr_img': all_img[2:],
        'ref_label': all_mask[0],
        'prev_label': all_mask[1],
        'curr_label': all_mask[2:]
    }
    # Copy so that updating obj_num does not rewrite the caller's sample1.
    sample['meta'] = dict(sample1['meta'])
    sample['meta']['obj_num'] = min(obj_num, max_obj_n)
    return sample
class StaticTrain(Dataset):
    """Pre-training dataset that turns static image/mask pairs into clips.

    Image/mask pairs are gathered from ``root/JPEGImages/<name>`` and
    ``root/Annotations/<name>`` for a fixed list of dataset names. Each item
    builds a ``seq_len``-frame clip from a single pair. The pair is randomly
    flipped once. Then every frame gets colour jitter and a random resized
    crop, plus random grayscale and blur for ``aug_type='v2'``. Frames after
    the first additionally get a random horizontal flip and a random affine
    transform. With ``dynamic_merge`` the clip may be merged with a clip
    built from another random pair (see ``merge_sample``).

    Args:
        root: dataset root directory.
        output_size: crop size given to ``IT.RandomResizedCrop``. Its ratio
            ``output_size[1] / output_size[0]`` centres the crop-ratio range.
        seq_len: frames per clip. Must be >= 2, because frame 1 is used as
            the "previous" frame.
        max_obj_n: maximum number of objects per clip.
        dynamic_merge: allow merging with a second random clip.
        merge_prob: merge probability. A merge is forced when the clip's
            ``meta['obj_num'] == 0``.
        aug_type: ``'v1'`` uses ColorJitter(0.1, 0.1, 0.1, 0.03). ``'v2'``
            applies a stronger ColorJitter with p=0.8, plus random grayscale
            and Gaussian blur.

    Raises:
        NotImplementedError: if ``aug_type`` is neither ``'v1'`` nor ``'v2'``.
    """

    def __init__(self,
                 root,
                 output_size,
                 seq_len=5,
                 max_obj_n=10,
                 dynamic_merge=True,
                 merge_prob=1.0,
                 aug_type='v1'):
        # Fail fast on an unknown aug_type. Otherwise self.color_jitter
        # would be left undefined and sampling would crash later.
        if aug_type not in ('v1', 'v2'):
            raise NotImplementedError(
                'Unsupported aug_type: {}'.format(aug_type))

        self.root = root
        self.clip_n = seq_len
        self.output_size = output_size
        self.max_obj_n = max_obj_n
        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob

        self.img_list = list()
        self.mask_list = list()
        dataset_list = list()
        dataset_names = ['COCO', 'ECSSD', 'MSRA10K', 'PASCAL-S', 'PASCALVOC2012']
        for dataset_name in dataset_names:
            img_dir = os.path.join(root, 'JPEGImages', dataset_name)
            mask_dir = os.path.join(root, 'Annotations', dataset_name)
            # NOTE(review): images and masks are paired purely by position.
            # If one folder mixes .jpg and .png images, the concatenated
            # image list is not ordered like the mask list and pairs would be
            # mismatched. Confirm each folder uses a single image extension.
            img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) + \
                sorted(glob(os.path.join(img_dir, '*.png')))
            mask_list = sorted(glob(os.path.join(mask_dir, '*.png')))
            if len(img_list) == 0:
                print(
                    f'\tPreTrain dataset {dataset_name} doesn\'t exist. Skip.')
            elif len(img_list) != len(mask_list):
                print(
                    f'\tPreTrain dataset {dataset_name} has {len(img_list)} imgs and {len(mask_list)} annots. Not match! Skip.'
                )
            else:
                dataset_list.append(dataset_name)
                self.img_list += img_list
                self.mask_list += mask_list
                print(f'\t{dataset_name}: {len(img_list)} imgs.')
        print(
            f'{len(self.img_list)} imgs are used for PreTrain. They are from {dataset_list}.'
        )

        self.aug_type = aug_type
        self.pre_random_horizontal_flip = IT.RandomHorizontalFlip(0.5)
        self.random_horizontal_flip = IT.RandomHorizontalFlip(0.3)
        if self.aug_type == 'v1':
            self.color_jitter = TF.ColorJitter(0.1, 0.1, 0.1, 0.03)
        else:  # 'v2' (validated at the top of __init__)
            self.color_jitter = TF.RandomApply(
                [TF.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8)
            self.gray_scale = TF.RandomGrayscale(p=0.2)
            self.blur = TF.RandomApply([IT.GaussianBlur([.1, 2.])], p=0.3)
        self.random_affine = IT.RandomAffine(degrees=20,
                                             translate=(0.1, 0.1),
                                             scale=(0.9, 1.1),
                                             shear=10,
                                             resample=Image.BICUBIC,
                                             fillcolor=(124, 116, 104))
        base_ratio = float(output_size[1]) / output_size[0]
        self.random_resize_crop = IT.RandomResizedCrop(
            output_size, (0.8, 1),
            ratio=(base_ratio * 3. / 4., base_ratio * 4. / 3.),
            interpolation=Image.BICUBIC)
        self.to_tensor = TF.ToTensor()
        self.to_onehot = IT.ToOnehot(max_obj_n, shuffle=True)
        self.normalize = TF.Normalize((0.485, 0.456, 0.406),
                                      (0.229, 0.224, 0.225))

    def __len__(self):
        return len(self.img_list)

    def load_image_in_PIL(self, path, mode='RGB'):
        """Open ``path`` with PIL, decode it fully and convert it to ``mode``."""
        with Image.open(path) as img:
            # Decode now, while the file is open. The original author notes
            # this is very important for loading large images.
            img.load()
            return img.convert(mode)

    def sample_sequence(self, idx):
        """Build one augmented ``self.clip_n``-frame clip from pair ``idx``.

        Returns:
            dict with normalized image tensors (``ref_img``, ``prev_img``,
            ``curr_img`` list), label maps from ``argmax`` over the one-hot
            mask (``ref_label``, ``prev_label``, ``curr_label`` list), and
            ``meta`` (``seq_name`` = image path, ``frame_num`` = 1,
            ``obj_num`` = length of the object list from ``to_onehot`` on
            frame 0).
        """
        img_pil = self.load_image_in_PIL(self.img_list[idx], 'RGB')
        mask_pil = self.load_image_in_PIL(self.mask_list[idx], 'P')

        frames = []
        masks = []

        img_pil, mask_pil = self.pre_random_horizontal_flip(img_pil, mask_pil)

        for i in range(self.clip_n):
            img, mask = img_pil, mask_pil
            if i > 0:
                img, mask = self.random_horizontal_flip(img, mask)
                img, mask = self.random_affine(img, mask)

            img = self.color_jitter(img)
            img, mask = self.random_resize_crop(img, mask)

            if self.aug_type == 'v2':
                img = self.gray_scale(img)
                img = self.blur(img)

            mask = np.array(mask, np.uint8)

            # Frame 0 fixes the object list. Later frames reuse it so ids
            # stay consistent across the clip.
            if i == 0:
                mask, obj_list = self.to_onehot(mask)
                obj_num = len(obj_list)
            else:
                mask, _ = self.to_onehot(mask, obj_list)

            mask = torch.argmax(mask, dim=0, keepdim=True)

            frames.append(self.normalize(self.to_tensor(img)))
            masks.append(mask)

        sample = {
            'ref_img': frames[0],
            'prev_img': frames[1],
            'curr_img': frames[2:],
            'ref_label': masks[0],
            'prev_label': masks[1],
            'curr_label': masks[2:]
        }
        sample['meta'] = {
            'seq_name': self.img_list[idx],
            'frame_num': 1,
            'obj_num': obj_num
        }
        return sample

    def __getitem__(self, idx):
        sample1 = self.sample_sequence(idx)
        # The merge partner must be a different image. With a single image
        # the rejection loop below could never terminate, so skip merging.
        can_merge = self.dynamic_merge and len(self.img_list) > 1
        if can_merge and (sample1['meta']['obj_num'] == 0
                          or random.random() < self.merge_prob):
            rand_idx = np.random.randint(len(self.img_list))
            while rand_idx == idx:
                rand_idx = np.random.randint(len(self.img_list))
            sample2 = self.sample_sequence(rand_idx)
            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1
        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        """Merge two clips via the module-level ``_merge_sample`` helper."""
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)
class VOSTrain(Dataset):
    """Generic video-object-segmentation training dataset.

    Each item is a clip from one sequence: a reference frame, a "previous"
    frame and a list of current frames. Consecutive current frames are
    spaced by random gaps in ``[1, rand_gap]``. Optionally the clip is
    merged with a clip from another sequence (see ``merge_sample``).

    Frame counts: without ``enable_prev_frame`` a clip has ``seq_len``
    frames (ref + prev + ``seq_len - 2`` current). With ``enable_prev_frame``
    it has ref + prev + ``seq_len - 1`` current frames.
    NOTE(review): this asymmetry may be unintended -- confirm.

    Args:
        image_root: directory with one sub-directory of frames per sequence.
        label_root: directory with one sub-directory of label files per
            sequence.
        imglistdic: ``{seq_name: (image_files, label_files)}``. The two
            lists are indexed in parallel.
        transform: optional callable applied to the sample dict.
        rgb: reorder the channels returned by ``cv2.imread`` with
            ``[2, 1, 0]`` (BGR -> RGB).
        repeat_time: length multiplier; ``__len__`` is
            ``len(seqs) * repeat_time``.
        rand_gap: maximum gap between consecutive sampled frames.
        seq_len: clip length parameter (see frame counts above).
        rand_reverse: randomly reverse a sequence's frame order.
        dynamic_merge: allow merging with a second random sequence.
        enable_prev_frame: sample the previous frame independently of the
            reference frame.
        merge_prob: merge probability. A merge is forced when the first
            clip's ``meta['obj_num'] == 0``.
        max_obj_n: maximum number of objects kept after merging.
    """

    def __init__(self,
                 image_root,
                 label_root,
                 imglistdic,
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 merge_prob=0.3,
                 max_obj_n=10):
        self.image_root = image_root
        self.label_root = label_root
        self.rand_gap = rand_gap
        self.seq_len = seq_len
        self.rand_reverse = rand_reverse
        self.repeat_time = repeat_time
        self.transform = transform
        self.dynamic_merge = dynamic_merge
        self.merge_prob = merge_prob
        self.enable_prev_frame = enable_prev_frame
        self.max_obj_n = max_obj_n
        self.rgb = rgb
        self.imglistdic = imglistdic
        self.seqs = list(self.imglistdic.keys())
        print('Video Num: {} X {}'.format(len(self.seqs), self.repeat_time))

    def __len__(self):
        return int(len(self.seqs) * self.repeat_time)

    def reverse_seq(self, imagelist, lablist):
        """With probability 1/2, reverse both frame lists together."""
        if np.random.randint(2) == 1:
            imagelist = imagelist[::-1]
            lablist = lablist[::-1]
        return imagelist, lablist

    def _load_label(self, seqname, label_name):
        """Read label file ``label_name`` of ``seqname`` as a uint8 array."""
        label_path = os.path.join(self.label_root, seqname, label_name)
        with Image.open(label_path) as label:
            return np.array(label, dtype=np.uint8)

    def get_ref_index(self,
                      seqname,
                      lablist,
                      objs,
                      min_fg_pixels=200,
                      max_try=5):
        """Randomly pick a reference frame index (legacy sampler).

        Draws up to ``max_try`` random indices. Returns the first one whose
        label has more than ``min_fg_pixels`` non-zero pixels and no non-zero
        id outside ``objs``. If none qualifies, returns the last drawn index.
        """
        bad_indices = []
        for _ in range(max_try):
            ref_index = np.random.randint(len(lablist))
            if ref_index in bad_indices:
                continue
            ref_label = self._load_label(seqname, lablist[ref_index])
            ref_objs = list(np.unique(ref_label))
            is_consistent = all(obj == 0 or obj in objs for obj in ref_objs)
            if np.count_nonzero(ref_label) > min_fg_pixels and is_consistent:
                break
            bad_indices.append(ref_index)
        return ref_index

    def get_ref_index_v2(self,
                         seqname,
                         lablist,
                         min_fg_pixels=200,
                         max_try=20,
                         total_gap=0):
        """Randomly pick a reference index in ``[0, len(lablist) - total_gap)``.

        Returns 0 when that range has at most one entry. Otherwise draws up
        to ``max_try`` indices and returns the first whose label has more
        than ``min_fg_pixels`` non-zero pixels (or the last one drawn).
        """
        search_range = len(lablist) - total_gap
        if search_range <= 1:
            return 0
        bad_indices = []
        for _ in range(max_try):
            ref_index = np.random.randint(search_range)
            if ref_index in bad_indices:
                continue
            ref_label = self._load_label(seqname, lablist[ref_index])
            if np.count_nonzero(ref_label) > min_fg_pixels:
                break
            bad_indices.append(ref_index)
        return ref_index

    def get_curr_gaps(self, seq_len, max_gap=999, max_try=10):
        """Draw ``seq_len`` gaps uniformly from ``[1, self.rand_gap]``.

        Retries up to ``max_try`` times until the total is ``<= max_gap``.
        The last draw is returned regardless.

        Returns:
            ``(gaps, total_gap)``.
        """
        for _ in range(max_try):
            curr_gaps = []
            total_gap = 0
            for _ in range(seq_len):
                gap = int(np.random.randint(self.rand_gap) + 1)
                total_gap += gap
                curr_gaps.append(gap)
            if total_gap <= max_gap:
                break
        return curr_gaps, total_gap

    def get_prev_index(self, lablist, total_gap):
        """Random start index in ``[0, len(lablist) - total_gap)``.

        Returns 0 if that range has at most one entry.
        """
        search_range = len(lablist) - total_gap
        if search_range > 1:
            prev_index = np.random.randint(search_range)
        else:
            prev_index = 0
        return prev_index

    def check_index(self, total_len, index, allow_reflect=True):
        """Map ``index`` into ``[0, total_len)``.

        Out-of-range indices are mirrored at both ends, or clamped when
        ``allow_reflect`` is False. Lengths <= 1 always map to 0.
        """
        if total_len <= 1:
            return 0

        if index < 0:
            if allow_reflect:
                index = -index
                index = self.check_index(total_len, index, True)
            else:
                index = 0
        elif index >= total_len:
            if allow_reflect:
                index = 2 * (total_len - 1) - index
                index = self.check_index(total_len, index, True)
            else:
                index = total_len - 1

        return index

    def get_curr_indices(self, lablist, prev_index, gaps):
        """Return the cumulative positions ``prev_index + gaps[:k]``.

        Each position is mapped into range with ``check_index``.
        """
        total_len = len(lablist)
        curr_indices = []
        now_index = prev_index
        for gap in gaps:
            now_index += gap
            curr_indices.append(self.check_index(total_len, now_index))
        return curr_indices

    def get_image_label(self, seqname, imagelist, lablist, index):
        """Load frame ``index`` of ``seqname``.

        Returns:
            A float32 image and a uint8 label array.

        Raises:
            OSError: if OpenCV cannot read the image file.
        """
        image_path = os.path.join(self.image_root, seqname, imagelist[index])
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            # np.array(None, float32) would silently give a 0-d NaN array.
            raise OSError('cv2.imread could not read image: {}'.format(
                image_path))
        image = np.array(image, dtype=np.float32)

        if self.rgb:
            image = image[:, :, [2, 1, 0]]

        label = self._load_label(seqname, lablist[index])
        return image, label

    def _load_frames(self, seqname, imagelist, lablist, indices):
        """Load the frames at ``indices``.

        Returns:
            ``(images, labels, objs)``, where ``objs`` concatenates the unique
            label ids of every loaded frame.
        """
        images, labels, objs = [], [], []
        for index in indices:
            image, label = self.get_image_label(seqname, imagelist, lablist,
                                                index)
            images.append(image)
            labels.append(label)
            objs.extend(list(np.unique(label)))
        return images, labels, objs

    def sample_sequence(self, idx):
        """Sample one clip from sequence ``idx % len(self.seqs)``.

        Makes up to 5 attempts to find a clip in which every object id of
        the prev/current frames also appears in the reference frame. The
        last attempt is used even if it is inconsistent.

        Returns:
            dict with ``ref_img``, ``prev_img``, ``curr_img`` (list),
            ``ref_label``, ``prev_label``, ``curr_label`` (list) and ``meta``.
            ``meta`` holds ``seq_name``, ``frame_num`` and ``obj_num`` (the
            largest id in the reference label). The dict is passed through
            ``self.transform`` when one is set.
        """
        idx = idx % len(self.seqs)
        seqname = self.seqs[idx]
        imagelist, lablist = self.imglistdic[seqname]
        frame_num = len(imagelist)
        if self.rand_reverse:
            imagelist, lablist = self.reverse_seq(imagelist, lablist)

        is_consistent = False
        max_try = 5
        try_step = 0
        while is_consistent is False and try_step < max_try:
            try_step += 1

            # Random gaps for the seq_len - 1 frames sampled forward.
            curr_gaps, total_gap = self.get_curr_gaps(self.seq_len - 1)

            if self.enable_prev_frame:  # prev frame is randomly sampled
                prev_index = self.get_prev_index(lablist, total_gap)
                prev_image, prev_label = self.get_image_label(
                    seqname, imagelist, lablist, prev_index)
                prev_objs = list(np.unique(prev_label))

                curr_indices = self.get_curr_indices(lablist, prev_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = self._load_frames(
                    seqname, imagelist, lablist, curr_indices)

                objs = list(np.unique(prev_objs + curr_objs))

                # Re-draw the reference (up to max_try times) while it lies
                # inside the sampled window (start_index, end_index].
                start_index = prev_index
                end_index = max(curr_indices)
                _try_step = 0
                ref_index = self.get_ref_index_v2(seqname, lablist)
                while (ref_index > start_index and ref_index <= end_index
                       and _try_step < max_try):
                    _try_step += 1
                    ref_index = self.get_ref_index_v2(seqname, lablist)
                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))
            else:  # prev frame = first frame sampled forward from the ref
                ref_index = self.get_ref_index_v2(seqname, lablist)
                ref_image, ref_label = self.get_image_label(
                    seqname, imagelist, lablist, ref_index)
                ref_objs = list(np.unique(ref_label))

                curr_indices = self.get_curr_indices(lablist, ref_index,
                                                     curr_gaps)
                curr_images, curr_labels, curr_objs = self._load_frames(
                    seqname, imagelist, lablist, curr_indices)

                objs = list(np.unique(curr_objs))
                prev_image, prev_label = curr_images[0], curr_labels[0]
                curr_images, curr_labels = curr_images[1:], curr_labels[1:]

            is_consistent = all(obj == 0 or obj in ref_objs for obj in objs)

        # ref_objs comes from np.unique, so it is sorted; the last entry is
        # the largest object id in the reference label.
        obj_num = ref_objs[-1]

        sample = {
            'ref_img': ref_image,
            'prev_img': prev_image,
            'curr_img': curr_images,
            'ref_label': ref_label,
            'prev_label': prev_label,
            'curr_label': curr_labels
        }
        sample['meta'] = {
            'seq_name': seqname,
            'frame_num': frame_num,
            'obj_num': obj_num
        }

        if self.transform is not None:
            sample = self.transform(sample)

        return sample

    def __getitem__(self, idx):
        sample1 = self.sample_sequence(idx)
        # The merge partner must be a different sequence. With a single
        # sequence the rejection loop below could never terminate.
        can_merge = self.dynamic_merge and len(self.seqs) > 1
        if can_merge and (sample1['meta']['obj_num'] == 0
                          or random.random() < self.merge_prob):
            rand_idx = np.random.randint(len(self.seqs))
            while rand_idx == (idx % len(self.seqs)):
                rand_idx = np.random.randint(len(self.seqs))

            sample2 = self.sample_sequence(rand_idx)
            sample = self.merge_sample(sample1, sample2)
        else:
            sample = sample1

        return sample

    def merge_sample(self, sample1, sample2, min_obj_pixels=100):
        """Merge two clips via the module-level ``_merge_sample`` helper."""
        return _merge_sample(sample1, sample2, min_obj_pixels, self.max_obj_n)
class DAVIS2017_Train(VOSTrain):
    """DAVIS training set.

    Sequence names are read from ``<root>/ImageSets/<year>/<split>.txt`` for
    each entry of ``split``. A plain string is treated as a single split
    name. Frame and label lists are the sorted directory listings of
    ``JPEGImages/<resolution>/<seq>`` and ``Annotations/<resolution>/<seq>``.
    ``Full-Resolution`` is used when requested and present; otherwise
    ``480p`` is used. All other arguments are forwarded to ``VOSTrain``.
    """

    def __init__(self,
                 split=('train',),
                 root='./DAVIS',
                 transform=None,
                 rgb=True,
                 repeat_time=1,
                 full_resolution=True,
                 year=2017,
                 rand_gap=3,
                 seq_len=5,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        if isinstance(split, str):
            # A bare string would otherwise be iterated character by character.
            split = [split]

        resolution = '480p'
        if full_resolution:
            if os.path.exists(
                    os.path.join(root, 'JPEGImages', 'Full-Resolution')):
                resolution = 'Full-Resolution'
            else:
                print('No Full-Resolution, use 480p instead.')
        image_root = os.path.join(root, 'JPEGImages', resolution)
        label_root = os.path.join(root, 'Annotations', resolution)

        seq_names = []
        for spt in split:
            split_file = os.path.join(root, 'ImageSets', str(year),
                                      spt + '.txt')
            with open(split_file) as f:
                # Skip blank lines. An empty name would make os.listdir below
                # list the resolution folder itself as if it were frames.
                seq_names.extend(line.strip() for line in f if line.strip())

        imglistdic = {}
        for seq_name in seq_names:
            images = sorted(os.listdir(os.path.join(image_root, seq_name)))
            labels = sorted(os.listdir(os.path.join(label_root, seq_name)))
            imglistdic[seq_name] = (images, labels)

        super(DAVIS2017_Train, self).__init__(image_root,
                                              label_root,
                                              imglistdic,
                                              transform,
                                              rgb,
                                              repeat_time,
                                              rand_gap,
                                              seq_len,
                                              rand_reverse,
                                              dynamic_merge,
                                              enable_prev_frame,
                                              merge_prob=merge_prob,
                                              max_obj_n=max_obj_n)
class YOUTUBEVOS_Train(VOSTrain):
    """YouTube-VOS training set read from ``<root>/<year>/train``.

    Frame lists come from ``meta.json``. For every video, the frames of all
    objects annotated in at least two frames are pooled: ``.jpg`` for images
    and ``.png`` for labels, deduplicated and sorted. Videos left with fewer
    than two frames are skipped. ``repeat_time`` is fixed to 1; the other
    arguments are forwarded to ``VOSTrain``.

    Raises:
        FileNotFoundError: if ``meta.json`` does not exist.
    """

    def __init__(self,
                 root='./datasets/YTB',
                 year=2019,
                 transform=None,
                 rgb=True,
                 rand_gap=3,
                 seq_len=3,
                 rand_reverse=True,
                 dynamic_merge=True,
                 enable_prev_frame=False,
                 max_obj_n=10,
                 merge_prob=0.3):
        root = os.path.join(root, str(year), 'train')
        image_root = os.path.join(root, 'JPEGImages')
        label_root = os.path.join(root, 'Annotations')
        self.seq_list_file = os.path.join(root, 'meta.json')
        if not self._check_preprocess():
            # Without meta.json there is no video list. Fail here with a clear
            # error instead of an AttributeError on self.ann_f below.
            raise FileNotFoundError('No such file: {}.'.format(
                self.seq_list_file))

        imglistdic = {}
        for seq_name, seq_meta in self.ann_f.items():
            objects = seq_meta['objects']
            images = []
            labels = []
            for obj_n, obj_meta in objects.items():
                frames = list(obj_meta["frames"])
                if len(frames) < 2:
                    print("Short object: " + seq_name + '-' + obj_n)
                    continue
                images += [frame + '.jpg' for frame in frames]
                labels += [frame + '.png' for frame in frames]
            # np.unique both deduplicates and sorts.
            images = np.unique(images)
            labels = np.unique(labels)
            if len(images) < 2:
                print("Short video: " + seq_name)
                continue
            imglistdic[seq_name] = (images, labels)

        super(YOUTUBEVOS_Train, self).__init__(image_root,
                                               label_root,
                                               imglistdic,
                                               transform,
                                               rgb,
                                               1,
                                               rand_gap,
                                               seq_len,
                                               rand_reverse,
                                               dynamic_merge,
                                               enable_prev_frame,
                                               merge_prob=merge_prob,
                                               max_obj_n=max_obj_n)

    def _check_preprocess(self):
        """Load ``meta.json``'s ``'videos'`` dict into ``self.ann_f``.

        Returns:
            True on success. False, after printing a message, if the file is
            missing.
        """
        if not os.path.isfile(self.seq_list_file):
            print('No such file: {}.'.format(self.seq_list_file))
            return False
        with open(self.seq_list_file, 'r') as f:
            self.ann_f = json.load(f)['videos']
        return True
class TEST(Dataset):
    """Synthetic dataset of constant frames (presumably for smoke tests).

    Every item contains one all-zero float32 image of shape (800, 800, 3) and
    one all-ones uint8 label of shape (800, 800). The same image and label
    objects are reused for the reference, previous and every current frame.
    There are ``seq_len - 2`` current frames per item.
    """

    def __init__(
        self,
        seq_len=3,
        obj_num=3,
        transform=None,
    ):
        self.seq_len = seq_len
        self.obj_num = obj_num
        self.transform = transform

    def __len__(self):
        return 3000

    def __getitem__(self, idx):
        frame = np.zeros((800, 800, 3), dtype=np.float32)
        mask = np.ones((800, 800), dtype=np.uint8)
        n_curr = self.seq_len - 2
        result = dict(
            ref_img=frame,
            prev_img=frame,
            curr_img=[frame] * n_curr,
            ref_label=mask,
            prev_label=mask,
            curr_label=[mask] * n_curr,
            meta={
                'seq_name': 'test',
                'frame_num': 100,
                'obj_num': self.obj_num
            },
        )
        if self.transform is None:
            return result
        return self.transform(result)