Spaces:
Runtime error
Runtime error
import numpy as np | |
from typing import List, Union, Tuple, Dict | |
import random | |
from PIL import Image | |
import cv2 | |
import os.path as osp | |
from tqdm import tqdm | |
from panopticapi.utils import rgb2id, id2rgb | |
from time import time | |
import traceback | |
from utils.io_utils import bbox_overlap_area | |
from utils.logger import LOGGER | |
from utils.constants import COLOR_PALETTE | |
class PartitionTree:
    """Recursive partition of a canvas into free rectangular regions.

    Every node holds a box ``(bleft, btop, bright, bbottom)``. Leaves are the
    regions still considered free. ``new_partition(rect)`` turns a leaf into an
    inner node whose four children are the strips left of, above, right of and
    below the occupied ``rect`` ([x1, y1, x2, y2]); neighbouring strips overlap
    at the corners.
    """

    def __init__(self, bleft: int, btop: int, bright: int, bbottom: int, parent=None) -> None:
        # The four children are created together by new_partition; None means "leaf".
        self.left: PartitionTree = None
        self.right: PartitionTree = None
        self.top: PartitionTree = None
        self.bottom: PartitionTree = None
        # Collapse inverted boxes to zero width/height rather than negative.
        if bright < bleft:
            bright = bleft
        if bbottom < btop:
            bbottom = btop
        self.bleft = bleft
        self.bright = bright
        self.btop = btop
        self.bbottom = bbottom
        self.parent: PartitionTree = parent

    def is_leaf(self):
        """Return True if this node has not been partitioned."""
        # Children are always created as a group, so checking one is enough.
        return self.left is None

    def new_partition(self, new_rect: List):
        """Split this node around the occupied ``new_rect`` = [x1, y1, x2, y2]."""
        self.left = PartitionTree(self.bleft, self.btop, new_rect[0], self.bbottom, self)
        self.top = PartitionTree(self.bleft, self.btop, self.bright, new_rect[1], self)
        self.right = PartitionTree(new_rect[2], self.btop, self.bright, self.bbottom, self)
        self.bottom = PartitionTree(self.bleft, new_rect[3], self.bright, self.bbottom, self)
        # For the root, the four new children are the whole tree, so there are
        # no other leaves that could overlap new_rect.
        if self.parent is not None:
            self.root_update_rect(new_rect)

    def root_update_rect(self, rect):
        """Shrink every leaf of the whole tree so it no longer overlaps ``rect``."""
        self.get_root().update_child_rect(rect)

    def update_child_rect(self, rect: List):
        """Apply update_from_rect(rect) to every leaf in this subtree."""
        if self.is_leaf():
            self.update_from_rect(rect)
            return
        for child in (self.left, self.right, self.top, self.bottom):
            child.update_child_rect(rect)

    def get_root(self):
        """Return the top-most ancestor of this node (itself if it has no parent)."""
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def update_from_rect(self, rect: List):
        """Shrink this leaf so it stops overlapping ``rect`` = [x1, y1, x2, y2].

        Two candidates are built: one cut along x (keep the part left of
        ``rect`` if its left edge falls inside the leaf, else the part to its
        right) and one cut along y (same rule with top/bottom). The candidate
        with the larger signed area is kept. Non-leaves and leaves that do not
        overlap ``rect`` are left unchanged.
        """
        if not self.is_leaf():
            return
        ix = min(self.bright, rect[2]) - max(self.bleft, rect[0])
        iy = min(self.bbottom, rect[3]) - max(self.btop, rect[1])
        if not (ix > 0 and iy > 0):
            return
        cut_x = np.array([self.bleft, self.btop, self.bright, self.bbottom])
        cut_y = cut_x.copy()
        if self.bleft < rect[0] < self.bright:
            cut_x[2] = rect[0]
        else:
            cut_x[0] = rect[2]
        if self.btop < rect[1] < self.bbottom:
            cut_y[3] = rect[1]
        else:
            cut_y[1] = rect[3]
        area_x = (cut_x[2:] - cut_x[:2]).prod()
        area_y = (cut_y[2:] - cut_y[:2]).prod()
        chosen = cut_x if area_x > area_y else cut_y
        # tolist() stores plain Python ints rather than numpy scalars.
        self.bleft, self.btop, self.bright, self.bbottom = chosen.tolist()
        # When rect extends past this leaf's far edge the cut can invert the
        # box; collapse it to zero size, as __init__ does.
        self.bright = max(self.bright, self.bleft)
        self.bbottom = max(self.bbottom, self.btop)

    def width(self) -> int:
        """Return the box width (a method, not a property)."""
        return self.bright - self.bleft

    def height(self) -> int:
        """Return the box height (a method, not a property)."""
        return self.bbottom - self.btop

    def prefer_partition(self, tgt_h: int, tgt_w: int):
        """Return ``(leaf, score)`` for the leaf that best fits a tgt_h x tgt_w object.

        A leaf's score is the product of its width and height fit ratios,
        each capped at 1.2, so regions much larger than the object do not
        score without bound. Ties go to the first child in the order
        left, right, top, bottom.
        """
        if self.is_leaf():
            score = min(self.width() / tgt_w, 1.2) * min(self.height() / tgt_h, 1.2)
            return self, score
        candidates = [
            child.prefer_partition(tgt_h, tgt_w)
            for child in (self.left, self.right, self.top, self.bottom)
        ]
        # max() keeps the first of equal scores (same as a stable descending sort).
        return max(candidates, key=lambda pair: pair[1])

    def new_random_pos(self, fg_h: int, fg_w: int, im_h: int, im_w: int, random_sample: bool = False):
        """Choose a top-left paste position for a fg_h x fg_w object in this region.

        The allowed range reaches a third of the (original) object size past
        each edge of the region. It is clamped so the object overhangs the
        ``im_w`` x ``im_h`` canvas by at most a tenth of its original size.
        When the enlarged region is smaller than the object, the object size
        is scaled by ``downscale_ratio`` (never below 0.8) before positioning.

        Returns:
            (px, py, downscale_ratio). Each coordinate is sampled uniformly
            when ``random_sample`` is True, otherwise it is the midpoint of the
            allowed range. An empty range yields its lower bound.
            ``downscale_ratio`` may exceed 1; sizes are only shrunk here when
            it is < 1.
        """
        extx, exty = int(fg_w / 3), int(fg_h / 3)
        extxb, extyb = int(fg_w / 10), int(fg_h / 10)
        region_w, region_h = self.width() + extx, self.height() + exty
        downscale_ratio = max(min(region_w / fg_w, region_h / fg_h), 0.8)
        if downscale_ratio < 1:
            fg_h = int(downscale_ratio * fg_h)
            fg_w = int(downscale_ratio * fg_w)
        max_x = min(im_w + extxb - fg_w, self.bright + extx - fg_w)
        max_y = min(im_h + extyb - fg_h, self.bbottom + exty - fg_h)
        min_x = max(min(self.bright + extx - fg_w, self.bleft - extx), -extx)
        min_x = max(-extxb, min_x)
        min_y = max(min(self.bbottom + exty - fg_h, self.btop - exty), -exty)
        min_y = max(-extyb, min_y)
        px, py = min_x, min_y
        if min_x < max_x:
            px = random.randint(min_x, max_x) if random_sample else int((min_x + max_x) / 2)
        if min_y < max_y:
            py = random.randint(min_y, max_y) if random_sample else int((min_y + max_y) / 2)
        return px, py, downscale_ratio

    def drawpartition(self, image: np.ndarray, color=None):
        """Debug helper: draw every partitioned (non-leaf) node's box onto ``image``.

        Drawing is done in place with cv2.rectangle (thickness 2). Leaves are
        not drawn. The four children of a node receive one shared random color.
        """
        if color is None:
            color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        if self.is_leaf():
            return
        cv2.rectangle(image, (self.bleft, self.btop), (self.bright, self.bbottom), color, 2)
        child_color = (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))
        for child in (self.left, self.right, self.top, self.bottom):
            child.drawpartition(image, child_color)
def paste_one_fg(fg_pil: Image, bg: Image, segments: np.ndarray, px: int, py: int, seg_color: Tuple, cal_area=True):
    """Alpha-paste one foreground onto ``bg`` and record it in ``segments``.

    ``bg`` and ``segments`` are modified in place. ``fg_pil`` serves as its own
    paste mask. In ``segments``, pixels whose channel-3 value exceeds 30 are
    painted with ``seg_color``. This assumes ``fg_pil`` is RGBA (TODO confirm
    for all callers). ``(px, py)`` may lie partly or fully outside ``bg``; only
    the visible part is recorded.

    Returns:
        area: number of painted mask pixels, or the placeholder 1 when
            ``cal_area`` is False.
        bbox: visible region as [x, y, w, h] (Python ints, w and h >= 0).
        xyxy: the same region as [x1, y1, x2, y2].
    """
    px, py = int(px), int(py)
    fg_h, fg_w = fg_pil.height, fg_pil.width
    im_h, im_w = bg.height, bg.width
    bg.paste(fg_pil, (px, py), mask=fg_pil)

    # Visible region in background coordinates, clipped to the canvas. The
    # outer max() keeps x2 >= x1 and y2 >= y1 when the foreground is entirely
    # off-canvas, so the region is simply empty.
    bgx1, bgy1 = max(px, 0), max(py, 0)
    bgx2 = max(min(px + fg_w, im_w), bgx1)
    bgy2 = max(min(py + fg_h, im_h), bgy1)
    # The same region in foreground-local coordinates. All values are
    # non-negative, so no negative-index slicing is involved.
    fgx1, fgx2 = bgx1 - px, bgx2 - px
    fgy1, fgy2 = bgy1 - py, bgy2 - py

    fg_mask = np.asarray(fg_pil)[fgy1:fgy2, fgx1:fgx2, 3] > 30
    # Basic slicing yields a view, so the boolean-mask assignment writes into segments.
    segments[bgy1:bgy2, bgx1:bgx2][fg_mask] = seg_color
    area = int(fg_mask.sum()) if cal_area else 1
    bbox = [bgx1, bgy1, bgx2 - bgx1, bgy2 - bgy1]
    return area, bbox, [bgx1, bgy1, bgx2, bgy2]
def partition_paste(fg_list, bg: Image):
    """Paste foregrounds onto ``bg`` in free regions tracked by a PartitionTree.

    ``fg_list`` is sorted in place, largest image (height * width) first.
    Each foreground goes to the best-scoring free leaf, at a random position
    chosen by that leaf. It is downscaled when the leaf reports a ratio
    below 1. The pasted rect then partitions the leaf.

    Returns:
        (segments_info, segments): one dict per foreground with 'id', 'bbox'
        ([x, y, w, h]) and 'area', plus an (H, W, 3) uint8 segment map.
        'area' is the placeholder that paste_one_fg returns for
        cal_area=False, not a real pixel count.
    """
    fg_list.sort(key=lambda item: item['image'].shape[0] * item['image'].shape[1], reverse=True)
    canvas_h, canvas_w = bg.height, bg.width
    tree = PartitionTree(0, 0, canvas_w, canvas_h)
    segments = np.zeros((canvas_h, canvas_w, 3), np.uint8)
    segments_info = []
    for idx, fg_dict in enumerate(fg_list):
        image = fg_dict['image']
        src_h, src_w = image.shape[:2]
        leaf, _ = tree.prefer_partition(src_h, src_w)
        px, py, ratio = leaf.new_random_pos(src_h, src_w, canvas_h, canvas_w, True)
        fg_pil = Image.fromarray(image)
        if ratio < 1:
            scaled_size = (int(src_w * ratio), int(src_h * ratio))
            fg_pil = fg_pil.resize(scaled_size, resample=Image.Resampling.LANCZOS)
        seg_color = COLOR_PALETTE[idx]
        area, bbox, xyxy = paste_one_fg(fg_pil, bg, segments, px, py, seg_color, cal_area=False)
        # Mark the pasted rect as occupied so later foregrounds avoid it.
        leaf.new_partition(xyxy)
        segments_info.append({'id': rgb2id(seg_color), 'bbox': bbox, 'area': area})
    return segments_info, segments
def gen_fg_regbboxes(fg_list: List[Dict], tgt_size: int, min_overlap=0.15, max_overlap=0.8, tgt_w: int = None):
    """Lay out foregrounds in one horizontal row with random partial overlaps.

    Adds 'pos' (a numpy [x, y] row) and 'depth' (float) to every dict in
    ``fg_list``, then sorts ``fg_list`` in place by descending depth. A
    smaller depth is treated as closer (see ``isfg``), so the final order is
    back to front. Each dict's 'image' is read only for ``shape[:2]``, taken
    as (height, width).

    Args:
        fg_list: foreground dicts; an empty list is left untouched.
        tgt_size: canvas height, used for vertical centering. It is also the
            canvas width unless ``tgt_w`` is given.
        min_overlap, max_overlap: bounds on neighbour overlap, as a fraction
            of a foreground's width.
        tgt_w: optional canvas width for the horizontal layout.
    """
    if not fg_list:
        return
    if tgt_w is None:
        tgt_w = tgt_size

    def _sample_y(h):
        # Roughly vertically centered, jittered by up to a quarter of h.
        y = (tgt_size - h) // 2
        if y <= 0:
            return 0
        yrange = min(y, h // 4)
        return y + random.randint(-yrange, yrange)

    shape_list = np.array([fg_dict['image'].shape[:2] for fg_dict in fg_list])
    depth_list = np.random.random(len(fg_list))
    # Wide foregrounds get a larger depth, i.e. they are pushed behind the rest.
    depth_list[shape_list[:, 1] > 0.6 * tgt_w] += 1

    pos_list = [[0, _sample_y(shape_list[0][0])]]
    # Fraction of the previous item's area already covered by its own
    # predecessor (0 when it was placed in front).
    pre_overlap = 0
    for ii, ((h, w), d) in enumerate(zip(shape_list[1:], depth_list[1:])):
        (preh, prew), predepth, (prex, prey) = shape_list[ii], depth_list[ii], pos_list[ii]
        isfg = d < predepth
        y = _sample_y(h)
        x = prex + prew  # start just right of the previous item
        if isfg:
            # In front: cover up to (max_overlap - pre_overlap) and at least
            # (min_overlap - pre_overlap) of the previous item's width.
            min_x = max_x = x
            if pre_overlap < max_overlap:
                min_x = int(min_x - (max_overlap - pre_overlap) * prew)
            if pre_overlap < min_overlap:
                max_x = int(max_x - (min_overlap - pre_overlap) * prew)
            x = random.randint(min_x, max_x)
            pre_overlap = 0
        else:
            # Behind: slide left by a random fraction of this item's width and
            # record how much of it the previous item covers. Boxes are
            # passed as [x, y, w, h]; presumably bbox_overlap_area returns the
            # intersection area (TODO confirm).
            overlap = random.uniform(min_overlap, max_overlap)
            x -= int(overlap * w)
            overlap_area = bbox_overlap_area([x, y, w, h], [prex, prey, prew, preh])
            pre_overlap = overlap_area / (h * w)
        pos_list.append([x, y])

    pos_list = np.array(pos_list)
    # Fit the whole row (its true horizontal extent, not just up to the last
    # item) into [0, tgt_w]: shift it randomly if it fits, otherwise center
    # the overflow.
    row_x1 = int(pos_list[:, 0].min())
    row_x2 = int((pos_list[:, 0] + shape_list[:, 1]).max())
    slack = tgt_w - (row_x2 - row_x1)
    shift = random.randint(0, slack) if slack > 0 else slack // 2
    pos_list[:, 0] += shift - row_x1

    for pos, fg_dict, depth in zip(pos_list, fg_list, depth_list):
        fg_dict['pos'] = pos
        fg_dict['depth'] = depth
    fg_list.sort(key=lambda item: item['depth'], reverse=True)
def regular_paste(fg_list, bg: Image, regen_bboxes=False):
    """Paste foregrounds onto ``bg`` at the positions stored in their dicts.

    Each dict must provide an 'image' array and a 'pos' (x, y); 'pos' is
    popped from the dict. When ``regen_bboxes`` is True, ``fg_list`` is
    shuffled in place and new positions are generated by gen_fg_regbboxes,
    which also re-sorts the list. Items are pasted in list order, so later
    ones land on top of earlier ones.

    Returns:
        (segments_info, segments): one dict per foreground with 'id', 'bbox'
        ([x, y, w, h]) and 'area' (painted mask pixels), plus an (H, W, 3)
        uint8 segment map.
    """
    segments_info = []
    canvas_h, canvas_w = bg.height, bg.width
    if regen_bboxes:
        random.shuffle(fg_list)
        # NOTE(review): only the canvas height is passed as the layout size,
        # so non-square backgrounds get a row sized to the height. Confirm
        # this is intended.
        gen_fg_regbboxes(fg_list, canvas_h)
    segments = np.zeros((canvas_h, canvas_w, 3), np.uint8)
    for idx, fg_dict in enumerate(fg_list):
        image = fg_dict['image']
        px, py = fg_dict.pop('pos')
        fg_pil = Image.fromarray(image)
        seg_color = COLOR_PALETTE[idx]
        area, bbox, _ = paste_one_fg(fg_pil, bg, segments, px, py, seg_color, cal_area=True)
        segments_info.append({'id': rgb2id(seg_color), 'bbox': bbox, 'area': area})
    return segments_info, segments