Spaces:
Runtime error
Runtime error
import numpy as np | |
from random import choice as rchoice | |
from random import randint | |
import random | |
import cv2, traceback, imageio | |
import os.path as osp | |
from typing import Optional, List, Union, Tuple, Dict | |
from utils.io_utils import imread_nogrey_rgb, json2dict | |
from .transforms import rotate_image | |
from utils.logger import LOGGER | |
class NameSampler:
    """Randomly sample names according to a probability table.

    A lookup table of about ``sample_num`` slots is built once: a name with
    probability ``p`` occupies ``int(p * sample_num)`` slots. When the
    probabilities sum to less than 1, the unused slots are assigned to the
    placeholder name ``'_'``, so it is drawn with the leftover probability.
    """

    def __init__(self, name_prob_dict, sample_num=2048) -> None:
        """
        Args:
            name_prob_dict: mapping from name to its probability; every value must be <= 1.
            sample_num: size of the lookup table; larger values approximate the
                probabilities more finely.
        """
        self.name_prob_dict = name_prob_dict
        self._id2name = list(name_prob_dict.keys())
        self.sample_ids = []

        total_prob = 0.
        for ii, prob in enumerate(name_prob_dict.values()):
            assert prob <= 1, f'probability of {self._id2name[ii]!r} is {prob} > 1'
            total_prob += prob
            tgt_num = int(prob * sample_num)
            if tgt_num > 0:
                self.sample_ids += [ii] * tgt_num

        # Decide on the *total* probability mass (with a tolerance for float
        # rounding). Slots lost to int() truncation when the total is 1 are not
        # padded with '_', so '_' is never drawn for a complete distribution.
        # An empty dict yields a table containing only '_'.
        nsamples = len(self.sample_ids)
        if total_prob < 1 - 1e-6 and nsamples < sample_num:
            self.sample_ids += [len(self._id2name)] * (sample_num - nsamples)
            self._id2name.append('_')

    def sample(self) -> str:
        """Return one name (possibly the placeholder ``'_'``) drawn from the lookup table."""
        return self._id2name[rchoice(self.sample_ids)]
class PossionSampler:
    """Sample integers from a Poisson distribution restricted to ``[min_val, max_val]``.

    A pool of ``num_samples`` Poisson draws is generated once. Draws outside the
    inclusive range ``[min_val, max_val]`` are replaced by uniform integers from
    that same inclusive range. ``sample`` then picks uniformly from the pool.
    """

    def __init__(self, lam=3, min_val=1, max_val=8, num_samples=1024) -> None:
        self._distr = np.random.poisson(lam, num_samples)
        invalid = np.where(np.logical_or(self._distr < min_val, self._distr > max_val))
        # np.random.randint's upper bound is exclusive; use max_val + 1 so that
        # max_val itself can be drawn, matching the inclusive range accepted
        # above (this also makes min_val == max_val valid).
        self._distr[invalid] = np.random.randint(min_val, max_val + 1, len(invalid[0]))

    def sample(self) -> int:
        """Return one value drawn uniformly from the precomputed pool."""
        return rchoice(self._distr)
class NormalSampler:
    """Sample from a truncated normal distribution via a precomputed pool.

    4096 draws from N(loc, std) are generated once. Only those strictly between
    ``min_scale`` and ``max_scale`` are kept, multiplied by ``scalar`` and
    optionally truncated to int32. ``sample`` picks uniformly from that pool.
    """

    def __init__(self, loc=0.33, std=0.2, min_scale=0.15, max_scale=0.85, scalar=1, to_int = True):
        draws = np.random.normal(loc, std, 4096)
        in_range = (draws > min_scale) & (draws < max_scale)
        pool = draws[in_range] * scalar
        self._distr = pool.astype(np.int32) if to_int else pool

    def sample(self) -> int:
        """Return one value drawn uniformly from the pool."""
        return rchoice(self._distr)
class PersonBBoxSampler:
    """Sample groups of person-like boxes and fit random foregrounds into them.

    Box groups are read from JSON file(s). Each group is an array of rows
    ``[x, y, w, h]``, presumably normalized to the unit canvas (TODO confirm
    against the JSON producer). Each group is shifted so its minimum x and y
    are 0.
    """

    def __init__(self, sample_path: Union[str, List]='data/cocoperson_bbox_samples.json', fg_info_list: List = None, fg_transform=None, is_train=True) -> None:
        """
        Args:
            sample_path: one JSON path or a list of them; each must hold a list of
                box groups (a list of rows, or a dict with a ``'bboxes'`` key).
            fg_info_list: foreground entries passed to ``random_load_nfg``.
            fg_transform: optional callable used as ``fg_transform(image=img)['image']``.
            is_train: selects the maximum random rotation (20 when True, 15 otherwise).
        """
        if isinstance(sample_path, str):
            sample_path = [sample_path]
        self.bbox_list = []
        for sp in sample_path:
            for bboxes in json2dict(sp):
                normalized = self._normalize_bboxes(bboxes)
                # Empty groups cannot be indexed as (N, 4) and never yield boxes; skip them.
                if normalized is not None:
                    self.bbox_list.append(normalized)
        self.fg_info_list = fg_info_list
        self.fg_transform = fg_transform
        self.is_train = is_train

    @staticmethod
    def _normalize_bboxes(bboxes) -> Optional[np.ndarray]:
        """Convert one JSON box group to an array shifted so min x/y is 0; None if the group is empty."""
        if isinstance(bboxes, dict):
            bboxes = bboxes['bboxes']
        bboxes = np.array(bboxes)
        if bboxes.size == 0:
            return None
        bboxes[:, [0, 1]] -= bboxes[:, [0, 1]].min(axis=0)
        return bboxes

    def sample(self, tgt_size: int, scale_range=(1, 1), size_thres=(0.02, 0.85)) -> List[np.ndarray]:
        """Sample one box group, scale it to a ``tgt_size`` canvas and randomly shift it.

        Args:
            tgt_size: side length of the (square) target canvas in pixels.
            scale_range: extra uniform scale factor range; not sampled when both ends are 1.
            size_thres: (min, max) fractions of ``tgt_size``. A box is kept only if the
                shorter side of its canvas-clipped extent is > min and the longer side is < max.

        Returns:
            List of int32 ``[x, y, w, h]`` rows. ``w``/``h`` are the unclipped sizes.
        """
        bboxes_normalized = rchoice(self.bbox_list)
        if scale_range[0] != 1 or scale_range[1] != 1:
            bbox_scale = random.uniform(scale_range[0], scale_range[1])
        else:
            bbox_scale = 1
        bboxes = (bboxes_normalized * tgt_size * bbox_scale).astype(np.int32)

        # Right/bottom extent of the whole group; a random shift is only applied
        # while the group still fits inside the canvas.
        x_max = (bboxes[:, 0] + bboxes[:, 2]).max()
        y_max = (bboxes[:, 1] + bboxes[:, 3]).max()
        x_shift = tgt_size - x_max
        x_shift = randint(0, x_shift) if x_shift > 0 else 0
        y_shift = tgt_size - y_max
        y_shift = randint(0, y_shift) if y_shift > 0 else 0
        bboxes[:, [0, 1]] += [x_shift, y_shift]

        valid_bboxes = []
        max_size = size_thres[1] * tgt_size
        min_size = size_thres[0] * tgt_size
        for bbox in bboxes:
            # Size of the part of the box that lies on the canvas.
            w = min(bbox[2], tgt_size - bbox[0])
            h = min(bbox[3], tgt_size - bbox[1])
            if max(h, w) < max_size and min(h, w) > min_size:
                valid_bboxes.append(bbox)
        return valid_bboxes

    def sample_matchfg(self, tgt_size: int):
        """Load one random foreground per sampled box, fit it to the box, and set its position.

        Returns:
            The foreground dicts from ``random_load_nfg``. Each dict's ``'image'`` is
            replaced by the resized (and optionally transformed) image, and ``'pos'``
            is set to the top-left ``(px, py)`` that centers it on its box. Dicts whose
            image's longer side exceeds ``0.55 * tgt_size`` come first; the rest follow
            in shuffled order.
        """
        # Resample (scaled by 1.1-1.8x) until at least one box passes the size filter.
        while True:
            bboxes = self.sample(tgt_size, (1.1, 1.8))
            if len(bboxes) > 0:
                break

        MIN_FG_SIZE = 20
        num_fg = len(bboxes)
        rotate = 20 if self.is_train else 15
        fgs = random_load_nfg(num_fg, self.fg_info_list, random_rotate_prob=0.33, random_rotate=rotate)
        assert len(fgs) == num_fg

        # Pair boxes and foregrounds by aspect ratio (w / h) so shapes roughly match.
        bboxes.sort(key=lambda x: x[2] / x[3])
        fgs.sort(key=lambda x: x['asp_ratio'])
        for fg, bbox in zip(fgs, bboxes):
            x, y, w, h = bbox
            img = fg['image']
            im_h, im_w = img.shape[:2]
            if im_h < h and im_w < w:
                # Smaller than the box on both axes: upscale uniformly to fit inside it.
                scale = min(h / im_h, w / im_w)
                new_h, new_w = int(scale * im_h), int(scale * im_w)
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            else:
                # Shrink by the mean of the per-axis shrink factors (each capped at 1),
                # never going below MIN_FG_SIZE on either side.
                scale_h, scale_w = min(1, h / im_h), min(1, w / im_w)
                scale = (scale_h + scale_w) / 2
                if scale < 1:
                    new_h, new_w = max(int(scale * im_h), MIN_FG_SIZE), max(int(scale * im_w), MIN_FG_SIZE)
                    img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

            if self.fg_transform is not None:
                img = self.fg_transform(image=img)['image']
            # Always re-read the size after resize/transform so the position below
            # is computed from the final image dimensions.
            im_h, im_w = img.shape[:2]
            fg['image'] = img
            px = int(x + w / 2 - im_w / 2)
            py = int(y + h / 2 - im_h / 2)
            fg['pos'] = (px, py)

        random.shuffle(fgs)
        large_size = int(tgt_size * 0.55)
        llist = [fg for fg in fgs if max(fg['image'].shape[:2]) > large_size]
        slist = [fg for fg in fgs if max(fg['image'].shape[:2]) <= large_size]
        return llist + slist
def random_load_nfg(num_fg: int, fg_info_list: List[Union[Dict, str]], random_rotate=0, random_rotate_prob=0.):
    """Load ``num_fg`` foreground records, possibly repeating some images.

    Each loaded image is rotated with probability ``random_rotate_prob`` by a random
    integer angle in ``[-random_rotate, random_rotate]`` (via ``rotate_image`` with
    ``alpha_crop=True``). After each load, another copy of the same image is appended
    with probability 0.12, repeatedly, until ``num_fg`` records exist. Copies share
    the same image array object.

    Returns:
        List of dicts with keys ``'image'``, ``'asp_ratio'`` (width / height) and ``'fginfo'``.
    """
    loaded = []
    while len(loaded) < num_fg:
        image, info = random_load_valid_fg(fg_info_list)
        if random.random() < random_rotate_prob:
            image = rotate_image(image, randint(-random_rotate, random_rotate), alpha_crop=True)
        height, width = image.shape[:2]
        ratio = width / height
        # Append once, then keep duplicating while there is room and the 12% coin flip succeeds.
        while True:
            loaded.append({'image': image, 'asp_ratio': ratio, 'fginfo': info})
            if len(loaded) >= num_fg or random.random() >= 0.12:
                break
    return loaded
def random_load_valid_fg(fg_info_list: List[Union[Dict, str]]) -> Tuple[np.ndarray, Dict]:
    """Pick a random foreground with 4 channels, load it and crop it to its visible region.

    Entries that fail to load, do not have 4 channels (alpha), or whose alpha mask is
    empty are logged and removed from ``fg_info_list`` in place; another entry is then
    tried. The crop box is cached in ``fginfo['xyxy']``, so later picks of the same
    entry skip the mask computation.

    Args:
        fg_info_list: candidate entries. Each is indexed with ``'file_path'`` and
            optionally ``'root_dir'`` / ``'xyxy'``, so entries must in practice be dicts.
            Mutated in place.

    Returns:
        ``(fg, fginfo)``: the cropped image and the chosen entry.

    Raises:
        IndexError: if ``fg_info_list`` is, or becomes, empty.
    """
    while True:
        if not fg_info_list:
            raise IndexError('random_load_valid_fg: no valid foreground left in fg_info_list')
        fginfo = rchoice(fg_info_list)
        file_path = fginfo['file_path']
        if 'root_dir' in fginfo and fginfo['root_dir']:
            file_path = osp.join(fginfo['root_dir'], file_path)
        try:
            fg = imageio.imread(file_path)
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            LOGGER.error(traceback.format_exc())
            LOGGER.error(f'invalid fg: {file_path}')
            fg_info_list.remove(fginfo)
            continue

        c = fg.shape[-1] if len(fg.shape) == 3 else 1
        if c != 4:
            LOGGER.warning(f'fg {file_path} doesnt have alpha channel')
            fg_info_list.remove(fginfo)
            continue

        if 'xyxy' in fginfo:
            x1, y1, x2, y2 = fginfo['xyxy']
        else:
            oh, ow = fg.shape[:2]
            ksize = 5
            mask = cv2.blur(fg[..., 3], (ksize, ksize))
            _, mask = cv2.threshold(mask, 20, 255, cv2.THRESH_BINARY)
            points = cv2.findNonZero(mask)
            if points is None:
                # findNonZero returns None for an all-zero mask, and boundingRect(None)
                # would raise; treat a fully transparent image as invalid instead.
                LOGGER.warning(f'fg {file_path} is fully transparent')
                fg_info_list.remove(fginfo)
                continue
            x1, y1, w, h = cv2.boundingRect(points)
            x2, y2 = x1 + w, y1 + h
            # Only crop when the visible region is noticeably (>15 px) smaller than the image.
            if oh - h <= 15 and ow - w <= 15:
                x1 = y1 = 0
                x2, y2 = ow, oh
            fginfo['xyxy'] = [x1, y1, x2, y2]
        return fg[y1:y2, x1:x2], fginfo
def random_load_valid_bg(bg_list: List[str]) -> np.ndarray:
    """Load a random background image, dropping paths that fail to load.

    Paths for which ``imread_nogrey_rgb`` raises are logged and removed from
    ``bg_list`` in place, and another path is tried.

    Args:
        bg_list: candidate image paths; mutated in place.

    Returns:
        The image returned by ``imread_nogrey_rgb``.

    Raises:
        IndexError: if ``bg_list`` is, or becomes, empty.
    """
    while True:
        # Choose outside the try block: an empty list must surface as a clear
        # IndexError, not be swallowed by the handler below (where `bgp` would be
        # unbound or stale).
        if not bg_list:
            raise IndexError('random_load_valid_bg: no valid background left in bg_list')
        bgp = rchoice(bg_list)
        try:
            return imread_nogrey_rgb(bgp)
        except Exception:  # not bare: let KeyboardInterrupt/SystemExit propagate
            LOGGER.error(traceback.format_exc())
            LOGGER.error(f'invalid bg: {bgp}')
            bg_list.remove(bgp)