import mmcv, torch
from tqdm import tqdm
from einops import rearrange
import os
import os.path as osp
import cv2
import gc
import math
from .anime_instances import AnimeInstances
import numpy as np
from typing import List, Tuple, Union, Optional, Callable
from mmengine import Config
from mmengine.model.utils import revert_sync_batchnorm
from mmdet.utils import register_all_modules, get_test_pipeline_cfg
from mmdet.apis import init_detector
from mmdet.registry import MODELS
from mmdet.structures import DetDataSample, SampleList
from mmdet.structures.bbox.transforms import scale_boxes, get_box_wh
from mmdet.models.dense_heads.rtmdet_ins_head import RTMDetInsHead
from pycocotools.coco import COCO
from mmcv.transforms import Compose
from mmdet.models.detectors.single_stage import SingleStageDetector
from utils.logger import LOGGER
from utils.io_utils import square_pad_resize, find_all_imgs, imglist2grid, mask2rle, dict2json, scaledown_maxsize, resize_pad
from utils.constants import DEFAULT_DEVICE, CATEGORIES
from utils.booru_tagger import Tagger
from .models.animeseg_refine import AnimeSegmentation, load_refinenet, get_mask
from .models.rtmdet_inshead_custom import RTMDetInsSepBNHeadCustom
from torchvision.ops.boxes import box_iou
import torch.nn.functional as F
def prepare_refine_batch(segmentations: np.ndarray, img: np.ndarray, max_batch_size: int = 4, device: str = 'cpu', input_size: int = 720):
    """Yield refinement-network input batches for one image and its instance masks.

    The image is passed through ``resize_pad`` once, scaled to [0, 1] and made
    channel-first.  Every segmentation goes through ``resize_pad`` as well and is
    stacked as one extra channel after the image channels.  A batch is emitted as
    soon as it holds ``max_batch_size`` items, and the remainder is emitted when the
    last segmentation has been consumed.

    Yields:
        ``(batch, (pt, pb, pl, pr))``: the batch as a tensor on ``device`` and the
        padding tuple returned by ``resize_pad`` for the image.
    """
    padded_img, (pt, pb, pl, pr) = resize_pad(img, input_size, pad_value=(0, 0, 0))
    chw_img = padded_img.transpose((2, 0, 1)).astype(np.float32) / 255.

    last_index = len(segmentations) - 1
    pending = []
    for index, segmentation in enumerate(segmentations):
        padded_seg, _ = resize_pad(segmentation, input_size, 0)
        pending.append(np.concatenate((chw_img, padded_seg[None, ...])))

        reached_end = index == last_index
        if reached_end or len(pending) >= max_batch_size:
            yield torch.from_numpy(np.array(pending)).to(device), (pt, pb, pl, pr)
            if not reached_end:
                pending = []
# Refinement methods accepted by ``AnimeInsSeg.set_refine_method``.
VALID_REFINEMETHODS = {'animeseg', 'refinenet_isnet', 'none'}

# Register mmdet's modules and initialise 'mmdet' as the default registry scope.
# The detector and test pipeline are built from the config stored in the
# checkpoint (``MODELS.build`` / ``Compose``), whose un-prefixed type names rely
# on that scope.
register_all_modules()
def single_image_preprocess(img: Union[str, np.ndarray], pipeline: Compose):
    """Run one image through ``pipeline`` and shape the result as a batch of one.

    Args:
        img: an image path (read with ``mmcv.imread``) or an already decoded array.
        pipeline: test-time transform pipeline called with ``dict(img=..., img_id=0)``.

    Returns:
        ``(data, img)``.  ``data`` is the pipeline output with its ``inputs`` and
        ``data_samples`` entries wrapped in single-element lists.  ``img`` is the
        decoded array.

    Raises:
        NotImplementedError: if ``img`` is neither a ``str`` nor an ``np.ndarray``.
    """
    if not isinstance(img, (str, np.ndarray)):
        raise NotImplementedError
    if isinstance(img, str):
        img = mmcv.imread(img)

    sample = pipeline(dict(img=img, img_id=0))
    # Wrap in lists so downstream code sees a collated batch of size one.
    sample['inputs'] = [sample['inputs']]
    sample['data_samples'] = [sample['data_samples']]
    return sample, img
def animeseg_refine(det_pred: DetDataSample, img: np.ndarray, net: AnimeSegmentation, to_rgb=True, input_size: int = 1024):
    """Refine detected instance masks with a whole-image foreground mask.

    ``net`` is run once on the full image through ``get_mask``, and its first output
    channel is thresholded at 0.5.  Each instance mask is intersected with that
    foreground mask.  The intersection replaces the instance mask only when it keeps
    more than 30% of the original mask area; otherwise the original mask is kept.
    ``det_pred.pred_instances.masks`` is updated in place.  Tensor masks keep their
    dtype and device.

    Args:
        det_pred: detection result whose ``pred_instances.masks`` are refined.
        img: image fed to ``get_mask``.  It is converted BGR->RGB first when
            ``to_rgb`` is set.
        net: foreground segmentation network passed to ``get_mask``.
        to_rgb: convert ``img`` from BGR to RGB before running ``net``.
        input_size: inference size forwarded to ``get_mask`` as ``s``.
    """
    num_pred = len(det_pred.pred_instances)
    if num_pred < 1:
        return

    seg_thr = 0.5
    with torch.no_grad():
        if to_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        mask = get_mask(net, img, s=input_size)[..., 0] > seg_thr

    ins_masks = det_pred.pred_instances.masks
    to_tensor = isinstance(ins_masks, torch.Tensor)
    if to_tensor:
        tensor_device, tensor_dtype = ins_masks.device, ins_masks.dtype
        ins_masks = ins_masks.cpu().numpy()

    area_original = np.sum(ins_masks, axis=(1, 2))
    masks_refined = np.bitwise_and(ins_masks, mask[None, ...])
    area_refined = np.sum(masks_refined, axis=(1, 2))
    # Compare with a product instead of a ratio.  Empty masks then need no division,
    # and they never pass the test (0 > 0 is False, just as nan > 0.3 was).
    keep_refined = area_refined > 0.3 * area_original
    ins_masks[keep_refined] = masks_refined[keep_refined]
    ins_masks = np.ascontiguousarray(ins_masks)

    if to_tensor:
        ins_masks = torch.from_numpy(ins_masks).to(dtype=tensor_dtype, device=tensor_device)
    det_pred.pred_instances.masks = ins_masks
def read_imglst_from_txt(filep) -> List[str]:
    """Read image paths from a UTF-8 text file, one path per line.

    Line terminators are removed.  Blank (whitespace-only) lines are skipped, so a
    trailing newline or an empty separator line is never returned as an image path.

    Args:
        filep: path of the text file.

    Returns:
        The listed paths, in file order.
    """
    with open(filep, 'r', encoding='utf8') as f:
        lines = [line for line in f.read().splitlines() if line.strip()]
    return lines
class AnimeInsSeg:
    """Anime instance segmentation built on an mmdet single-stage detector.

    The detector is built from the config stored inside the checkpoint.  Inference
    runs on images, image lists, directories, ``.txt`` path lists or COCO ``.json``
    files.  Masks can optionally be refined with a separate network, and instances
    can optionally be tagged with a booru tagger.
    """

    def __init__(self, ckpt: str, default_det_size: int = 640, device: str = None,
                 refine_kwargs: dict = {'refine_method': 'refinenet_isnet'},
                 tagger_path: str = 'models/wd-v1-4-swinv2-tagger-v2/model.onnx', mask_thr=0.3) -> None:
        """Load the detector checkpoint and build the test pipeline.

        Args:
            ckpt: checkpoint path.  ``meta['cfg']`` must hold the mmdet config text and
                ``state_dict`` the weights.  NOTE: ``torch.load`` unpickles, so only
                load trusted checkpoints.
            default_det_size: stored as ``self.default_det_size``.
            device: torch device; ``DEFAULT_DEVICE`` when None.
            refine_kwargs: keyword arguments forwarded to ``set_refine_method``.  Pass
                None to leave refinement unconfigured.  The default dict is only read,
                never mutated.
            tagger_path: tagger model path, loaded lazily by ``init_tagger``.
            mask_thr: probability threshold applied to refined masks.
        """
        self.ckpt = ckpt
        self.default_det_size = default_det_size
        self.device = DEFAULT_DEVICE if device is None else device

        # Build the detector in mmdet's way, from the config embedded in the checkpoint.
        checkpoint = torch.load(ckpt, map_location='cpu')
        # Older configs use ``file_client_args``; current mmdet expects ``backend_args``.
        cfg = Config.fromstring(checkpoint['meta']['cfg'].replace('file_client_args', 'backend_args'), file_format='.py')
        cfg.visualizer = []
        cfg.vis_backends = {}

        model = MODELS.build(cfg.model)
        # SyncBN only works in distributed runs; swap it for regular BN for inference.
        model = revert_sync_batchnorm(model)
        model.load_state_dict(checkpoint['state_dict'], strict=False)
        self.model = model.to(self.device).eval()
        del checkpoint
        self.cfg = cfg.copy()

        test_pipeline = get_test_pipeline_cfg(self.cfg.copy())
        # Images reach the pipeline as decoded arrays, not file paths.
        test_pipeline[0].type = 'mmdet.LoadImageFromNDArray'
        self.default_data_pipeline = Compose(test_pipeline)

        self.refinenet = None
        self.refinenet_animeseg: AnimeSegmentation = None
        self.postprocess_refine: Callable = None
        self.mask_thr = mask_thr
        if refine_kwargs is not None:
            self.set_refine_method(**refine_kwargs)

        self.tagger = None
        self.tagger_path = tagger_path

    def init_tagger(self, tagger_path: str = None):
        """Load the tagger from ``tagger_path``, or from ``self.tagger_path`` if None."""
        tagger_path = self.tagger_path if tagger_path is None else tagger_path
        self.tagger = Tagger(tagger_path)

    def infer_tags(self, instances: AnimeInstances, img: np.ndarray, infer_grey: bool = False):
        """Tag each instance with the booru tagger, loading the tagger on first use.

        Each instance is cropped by its (x, y, w, h) bbox, and pixels outside its
        mask are painted white.  The crop is labelled with ``label_cv2_bgr``.  Tags
        describing the white fill ('simple_background', 'white_background') are
        dropped.  Results go to ``instances.tags[i]`` (space-joined) and
        ``instances.character_tags[i]``.
        """
        if self.tagger is None:
            self.init_tagger()
        if infer_grey:
            # Tag a 3-channel greyscale version of the image.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)[..., None][..., [0, 0, 0]]

        exclude_tags = {'simple_background', 'white_background'}
        for ii in range(len(instances)):
            bbox = instances.bboxes[ii]
            mask = instances.masks[ii]
            if isinstance(bbox, torch.Tensor):
                bbox = bbox.cpu().numpy()
            if isinstance(mask, torch.Tensor):
                mask = mask.cpu().numpy()
            x, y, w, h = bbox.astype(np.int32)[:4]
            crop = img[y: y + h, x: x + w].copy()
            crop[mask[y: y + h, x: x + w] == 0] = 255
            tags, character_tags = self.tagger.label_cv2_bgr(crop)
            instances.tags[ii] = ' '.join(tag for tag in tags if tag not in exclude_tags)
            instances.character_tags[ii] = character_tags

    def infer_embeddings(self, imgs, det_size=None):
        """Run the detector on the first image and return pre-NMS candidates.

        ``SingleStageDetector.predict`` and
        ``RTMDetInsSepBNHeadCustom._bbox_mask_post_process`` are patched temporarily.
        Post-processing then stops after score rescaling and minimum-size filtering,
        so no NMS and no mask decoding happen.  The candidate results and the image's
        mask feature map are returned instead.  Both patches are restored even if
        inference raises.

        Returns:
            ``(img, instance_data, mask_feat)`` for the first resolved image.
        """

        def hijack_bbox_mask_post_process(
                self,
                results,
                mask_feat,
                cfg,
                rescale: bool = False,
                with_nms: bool = True,
                img_meta: Optional[dict] = None):
            if rescale:
                assert img_meta.get('scale_factor') is not None
                scale_factor = [1 / s for s in img_meta['scale_factor']]
                results.bboxes = scale_boxes(results.bboxes, scale_factor)

            if hasattr(results, 'score_factors'):
                # TODO: Add sqrt operation in order to be consistent with
                # the paper.
                score_factors = results.pop('score_factors')
                results.scores = results.scores * score_factors

            # filter small size bboxes
            if cfg.get('min_bbox_size', -1) >= 0:
                w, h = get_box_wh(results.bboxes)
                valid_mask = (w > cfg.min_bbox_size) & (h > cfg.min_bbox_size)
                if not valid_mask.all():
                    results = results[valid_mask]
            return results, mask_feat

        def hijack_detector_predict(self: SingleStageDetector,
                                    batch_inputs: torch.Tensor,
                                    batch_data_samples: SampleList,
                                    rescale: bool = True) -> SampleList:
            x = self.extract_feat(batch_inputs)
            bbox_head: RTMDetInsSepBNHeadCustom = self.bbox_head
            batch_img_metas = [data_samples.metainfo for data_samples in batch_data_samples]
            old_postprocess = RTMDetInsSepBNHeadCustom._bbox_mask_post_process
            RTMDetInsSepBNHeadCustom._bbox_mask_post_process = hijack_bbox_mask_post_process
            try:
                outs = bbox_head(x)
                # Each entry is the (results, mask_feat) tuple returned by the hijack above.
                return bbox_head.predict_by_feat(
                    *outs, batch_img_metas=batch_img_metas, rescale=rescale)
            finally:
                RTMDetInsSepBNHeadCustom._bbox_mask_post_process = old_postprocess

        test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
        model = self.model
        # Only the first image is used, so no progress bar.  A tqdm wrapper is not
        # indexable.
        data_, img = test_pipeline(imgs[0])

        old_predict = SingleStageDetector.predict
        SingleStageDetector.predict = hijack_detector_predict
        try:
            with torch.no_grad():
                data = model.data_preprocessor(data_, False)
                instance_data, mask_feat = model(**data, mode='predict')[0]
        finally:
            SingleStageDetector.predict = old_predict
        return img, instance_data, mask_feat

    def segment_with_bboxes(self, img, bboxes: torch.Tensor, instance_data, mask_feat: torch.Tensor):
        """Predict one mask per query box by reusing the best-matching candidate.

        Each query box (x1, y1, x2, y2) is matched to the candidate in
        ``instance_data`` (from ``infer_embeddings``) with the highest IoU.  That
        candidate's kernel and prior are decoded against ``mask_feat`` into an
        image-sized mask thresholded at 0.5.

        Returns:
            AnimeInstances with the matched candidates' boxes as int32 (x, y, w, h).
            Masks are refined with ``_postprocess_refine`` when a refinement network
            is loaded.
        """
        tgt_bboxes = instance_data.bboxes
        im_h, im_w = img.shape[:2]
        long_side = max(im_h, im_w)
        bbox_head: RTMDetInsSepBNHeadCustom = self.model.bbox_head
        priors, kernels = instance_data.priors, instance_data.kernels
        stride = bbox_head.prior_generator.strides[0][0]

        ins_bboxes, ins_segs, scores = [], [], []
        if len(tgt_bboxes) > 0:
            if isinstance(bboxes, torch.Tensor):
                query = bboxes
            else:
                query = torch.from_numpy(np.asarray(bboxes, dtype=np.float32))
            query = query.reshape(-1, 4).to(dtype=tgt_bboxes.dtype, device=tgt_bboxes.device)
            with torch.no_grad():
                matched_ids = box_iou(query, tgt_bboxes).argmax(dim=1)
                for matched_idx in matched_ids:
                    mask_logits = bbox_head._mask_predict_by_feat_single(
                        mask_feat, kernels[matched_idx][None, ...], priors[matched_idx][None, ...])
                    mask_logits = F.interpolate(
                        mask_logits.unsqueeze(0), scale_factor=stride, mode='bilinear')
                    # Upsample to the square long-side canvas, then crop away the padding.
                    mask_logits = F.interpolate(
                        mask_logits,
                        size=[long_side, long_side],
                        mode='bilinear',
                        align_corners=False)[..., :im_h, :im_w]
                    mask = (mask_logits.sigmoid().squeeze() > 0.5).cpu().numpy()
                    ins_segs.append(mask)
                    ins_bboxes.append(tgt_bboxes[matched_idx].cpu().numpy())
                    scores.append(float(instance_data.scores[matched_idx]))

        if len(ins_bboxes) > 0:
            ins_bboxes = np.array(ins_bboxes).astype(np.int32)
            ins_bboxes[:, 2:] -= ins_bboxes[:, :2]
            ins_segs = np.array(ins_segs)
            scores = np.array(scores, dtype=np.float32)
        instances = AnimeInstances(ins_segs, ins_bboxes, scores)
        if self.refinenet is not None:
            self._postprocess_refine(instances, img)
        return instances

    def set_detect_size(self, det_size: Union[int, Tuple]):
        """Set the Resize scale and Pad size of the test pipeline (int means square)."""
        if isinstance(det_size, int):
            det_size = (det_size, det_size)
        self.default_data_pipeline.transforms[1].scale = det_size
        self.default_data_pipeline.transforms[2].size = det_size

    def infer(self, imgs: Union[List, str, np.ndarray],
              pred_score_thr: float = 0.3,
              refine_kwargs: dict = None,
              output_type: str="tensor",
              det_size: int = None,
              save_dir: str = '',
              save_visualization: bool = False,
              save_annotation: str = '',
              infer_tags: bool = False,
              obj_id_start: int = -1,
              img_id_start: int = -1,
              verbose: bool = False,
              infer_grey: bool = False,
              save_mask_only: bool = False,
              max_instances: int = 100,
              **kwargs) -> Union[List[AnimeInstances], AnimeInstances, None]:
        """Detect, optionally refine and tag instances in one or more images.

        Args:
            imgs (str, ndarray, Sequence[str/ndarray]):
                Either image files or loaded images.  A str may also be a directory
                or a ``.txt`` file listing one image path per line.  When saving
                annotations, it may also be a COCO ``.json`` file.
            pred_score_thr: minimum detection score kept.
            refine_kwargs: if given, forwarded to ``set_refine_method`` first.
            output_type: 'tensor' or 'numpy'.
            det_size: if given, applied with ``set_detect_size`` first.
            save_dir, save_visualization, save_annotation, obj_id_start,
            img_id_start, save_mask_only: see ``_infer_save_annotations``.
            infer_tags: also run the tagger on every instance.
            infer_grey: tag a greyscale version of the image.
            max_instances: maximum detections per image (head ``max_per_img``).
            **kwargs: ``val_dir`` is forwarded to ``_infer_save_annotations``.

        Returns:
            :obj:`AnimeInstances` or list[:obj:`AnimeInstances`]:
                If save_annotation or save_visualization is set, return None.
        """
        if det_size is not None:
            self.set_detect_size(det_size)
        if refine_kwargs is not None:
            self.set_refine_method(**refine_kwargs)
        self.set_max_instance(max_instances)

        if isinstance(imgs, str) and imgs.endswith('.txt'):
            imgs = read_imglst_from_txt(imgs)

        if save_annotation or save_visualization:
            return self._infer_save_annotations(
                imgs, pred_score_thr, det_size, save_dir, save_visualization,
                save_annotation, infer_tags, obj_id_start, img_id_start,
                save_mask_only=save_mask_only, val_dir=kwargs.get('val_dir'))
        return self._infer_simple(imgs, pred_score_thr, det_size, output_type, infer_tags, verbose=verbose, infer_grey=infer_grey)

    def _det_forward(self, img, test_pipeline, pred_score_thr: float = 0.3) -> Tuple[AnimeInstances, np.ndarray]:
        """Detect on one image; return (instances with (x, y, w, h) bboxes, decoded img)."""
        data_, img = test_pipeline(img)
        with torch.no_grad():
            results: DetDataSample = self.model.test_step(data_)[0]
        del data_

        pred_instances = results.pred_instances
        pred_instances = pred_instances[pred_instances.scores > pred_score_thr]
        if len(pred_instances) < 1:
            return AnimeInstances(), img

        bboxes = pred_instances.bboxes
        # (x1, y1, x2, y2) -> (x, y, w, h)
        bboxes[:, 2:] -= bboxes[:, :2]
        return AnimeInstances(pred_instances.masks, bboxes, pred_instances.scores), img

    def _infer_simple(self, imgs: Union[List, str, np.ndarray],
                      pred_score_thr: float = 0.3,
                      det_size: int = None,
                      output_type: str = "tensor",
                      infer_tags: bool = False,
                      infer_grey: bool = False,
                      verbose: bool = False) -> Union[DetDataSample, List[DetDataSample]]:
        """Run detection + post-processing and return the instances in memory.

        A list is returned when ``imgs`` is a list or a directory.  Otherwise the
        single result is returned, or None if nothing was found to process.
        """
        return_list = isinstance(imgs, list) or (isinstance(imgs, str) and osp.isdir(imgs))
        if output_type not in {'tensor', 'numpy'}:
            raise ValueError(f'Invalid output_type: {output_type}')

        test_pipeline, imgs, _ = self.prepare_data_pipeline(imgs, det_size)
        iterator = tqdm(imgs) if len(imgs) > 1 else imgs

        predictions = []
        for img in iterator:
            instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
            self.postprocess_results(instances, img)
            if infer_tags:
                self.infer_tags(instances, img, infer_grey)
            if output_type == 'numpy':
                instances.to_numpy()
            predictions.append(instances)

        if return_list:
            return predictions
        return predictions[0] if predictions else None

    def _infer_save_annotations(self, imgs: Union[List, str, np.ndarray],
                                pred_score_thr: float = 0.3,
                                det_size: int = None,
                                save_dir: str = '',
                                save_visualization: bool = False,
                                save_annotation: str = '',
                                infer_tags: bool = False,
                                obj_id_start: int = 100000000000,
                                img_id_start: int = 100000000000,
                                save_mask_only: bool = False,
                                val_dir = None,
                                **kwargs) -> None:
        """Run inference and write visualizations, masks and/or a COCO-style json.

        If ``imgs`` is a COCO ``.json`` file, its images are read from ``val_dir``
        (default: a sibling ``val`` folder of the json's parent directory).  They
        keep their original image ids.  Otherwise image ids count up from
        ``img_id_start + 1``.  Object ids count up from ``obj_id_start + 1``.
        """
        coco_api = None
        imgp2ids = {}
        if isinstance(imgs, str) and imgs.endswith('.json'):
            coco_api = COCO(imgs)
            if val_dir is None:
                val_dir = osp.join(osp.dirname(osp.dirname(imgs)), 'val')
            imgps = []
            for imgid in coco_api.getImgIds():
                imeta = coco_api.loadImgs(imgid)[0]
                imgp = osp.join(val_dir, imeta['file_name'])
                imgp2ids[imgp] = imgid
                imgps.append(imgp)
            imgs = imgps

        test_pipeline, imgs, target_dir = self.prepare_data_pipeline(imgs, det_size)
        if save_dir == '':
            ckpt_name = osp.basename(self.ckpt).replace('.ckpt', '').replace('.pth', '').replace('.pt', '')
            save_dir = osp.join(target_dir, ckpt_name)
        os.makedirs(save_dir, exist_ok=True)

        det_annotations = []
        image_meta = []
        obj_id = obj_id_start + 1
        image_id = img_id_start + 1

        for ii, img in enumerate(tqdm(imgs)):
            if isinstance(img, str):
                img_name = osp.basename(img)
            else:
                img_name = f'{ii}'.zfill(12) + '.jpg'
            if coco_api is not None:
                image_id = imgp2ids[img]

            try:
                instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
            except torch.cuda.OutOfMemoryError:
                # Release cached GPU memory and retry once before skipping this image.
                gc.collect()
                torch.cuda.empty_cache()
                try:
                    instances, img = self._det_forward(img, test_pipeline, pred_score_thr)
                except torch.cuda.OutOfMemoryError:
                    LOGGER.warning(f'cuda out of memory: {img_name}')
                    if isinstance(img, str):
                        img = cv2.imread(img)
                    instances = None

            if instances is not None:
                self.postprocess_results(instances, img)
                if infer_tags:
                    self.infer_tags(instances, img)

            if save_visualization and instances is not None:
                self.save_visualization(osp.join(save_dir, img_name), img, instances)

            if save_annotation:
                im_h, im_w = img.shape[:2]
                image_meta.append({"id": image_id, "height": im_h, "width": im_w, "file_name": img_name})
                if instances is not None:
                    for jj in range(len(instances)):
                        mask = instances.masks[jj]
                        if isinstance(mask, torch.Tensor):
                            mask = mask.cpu().numpy()
                        segmentation = np.squeeze(mask).astype(np.uint8)
                        area = int(segmentation.sum())  # plain int so it is JSON serialisable
                        segmentation *= 255
                        if save_mask_only:
                            cv2.imwrite(osp.join(save_dir, 'mask_' + str(jj).zfill(3) + '_' + img_name + '.png'), segmentation)
                            continue

                        score = instances.scores[jj]
                        if isinstance(score, torch.Tensor):
                            score = score.item()
                        score = float(score)
                        bbox = instances.bboxes[jj]
                        if isinstance(bbox, torch.Tensor):
                            bbox = bbox.cpu().numpy()
                        bbox = np.asarray(bbox).astype(np.float32).tolist()
                        det_annotations.append({'id': obj_id, 'category_id': 0, 'iscrowd': 0, 'score': score,
                                                'segmentation': mask2rle(segmentation), 'image_id': image_id, 'area': area,
                                                'tag_string': instances.tags[jj],
                                                'tag_string_character': instances.character_tags[jj], 'bbox': bbox})
                        obj_id += 1
            image_id += 1

        if save_annotation != '' and not save_mask_only:
            det_meta = {"info": {}, "licenses": [], "images": image_meta,
                        "annotations": det_annotations, "categories": CATEGORIES}
            detp = save_annotation
            dict2json(det_meta, detp)
            LOGGER.info(f'annotations saved to {detp}')

    def set_refine_method(self, refine_method: str = 'none', refine_size: int = 720):
        """Select mask refinement: 'none', 'animeseg' or 'refinenet_isnet'.

        The network is loaded once on first use.  ``refine_size`` is the input size
        passed to the refinement step.

        Raises:
            NotImplementedError: for any other ``refine_method``.
        """
        if refine_method == 'none':
            self.postprocess_refine = None
        elif refine_method == 'animeseg':
            if self.refinenet_animeseg is None:
                self.refinenet_animeseg = load_refinenet(refine_method)
            self.postprocess_refine = lambda det_pred, img: \
                animeseg_refine(det_pred, img, self.refinenet_animeseg, True, refine_size)
        elif refine_method == 'refinenet_isnet':
            if self.refinenet is None:
                self.refinenet = load_refinenet(refine_method)
            self.postprocess_refine = lambda instances, img: \
                self._postprocess_refine(instances, img, refine_size=refine_size)
        else:
            raise NotImplementedError(f'Invalid refine method: {refine_method}')

    def _postprocess_refine(self, instances: AnimeInstances, img: np.ndarray, refine_size: int = 720, max_refine_batch: int = 4, **kwargs):
        """Refine ``instances.masks`` in place with ``self.refinenet``.

        Refined probabilities are thresholded at ``self.mask_thr``.  The masks keep
        their tensor/numpy kind.
        """
        if instances.is_empty:
            return
        segs = instances.masks
        is_tensor = instances.is_tensor
        if is_tensor:
            segs = segs.cpu().numpy()
        segs = segs.astype(np.float32)
        im_h, im_w = img.shape[:2]

        masks = []
        with torch.no_grad():
            for batch, (pt, pb, pl, pr) in prepare_refine_batch(segs, img, max_refine_batch, self.device, refine_size):
                preds = self.refinenet(batch)[0][0].sigmoid()
                # Crop the resize_pad padding with explicit end indices.  This is correct
                # when a pad is 0 and when the image is smaller than refine_size.
                pad_h, pad_w = preds.shape[-2], preds.shape[-1]
                preds = preds[..., pt: pad_h - pb, pl: pad_w - pr]
                preds = F.interpolate(preds, (im_h, im_w), mode='bilinear', align_corners=True)
                masks.append(preds.cpu()[:, 0])
        masks = (torch.concat(masks, dim=0) > self.mask_thr).to(self.device)
        if not is_tensor:
            masks = masks.cpu().numpy()
        instances.masks = masks

    def prepare_data_pipeline(self, imgs: Union[str, np.ndarray, List], det_size: int) -> Tuple[Compose, List, str]:
        """Normalise ``imgs`` into a list and return (pipeline fn, image list, output dir).

        ``det_size`` is accepted for interface compatibility only.  ``infer`` applies
        it beforehand through ``set_detect_size``.

        Raises:
            FileNotFoundError: ``imgs`` is a str that is neither a file nor a directory.
            NotImplementedError: unsupported input type.
        """
        target_dir = './workspace/output'
        if isinstance(imgs, str):
            if osp.isdir(imgs):
                target_dir = imgs
                imgs = find_all_imgs(imgs, abs_path=True)
            elif osp.isfile(imgs):
                target_dir = osp.dirname(imgs)
                imgs = [imgs]
            else:
                raise FileNotFoundError(f'No such file or directory: {imgs}')
        elif isinstance(imgs, np.ndarray):
            imgs = [imgs]
        elif isinstance(imgs, list):
            if len(imgs) > 0 and not isinstance(imgs[0], (np.ndarray, str)):
                raise NotImplementedError
        else:
            raise NotImplementedError

        test_pipeline = lambda img: single_image_preprocess(img, pipeline=self.default_data_pipeline)
        return test_pipeline, imgs, target_dir

    def save_visualization(self, out_file: str, img: np.ndarray, instances: AnimeInstances):
        """Draw ``instances`` on ``img`` and write the result to ``out_file``."""
        drawed = instances.draw_instances(img)
        mmcv.imwrite(drawed, out_file)

    def postprocess_results(self, results: DetDataSample, img: np.ndarray) -> None:
        """Apply the configured refinement, if any, to ``results`` in place."""
        if self.postprocess_refine is not None:
            self.postprocess_refine(results, img)

    def set_mask_threshold(self, mask_thr: float):
        """Set the detector head's binary mask threshold (``mask_thr_binary``)."""
        self.model.bbox_head.test_cfg['mask_thr_binary'] = mask_thr

    def set_max_instance(self, num_ins):
        """Set the maximum number of detections kept per image (``max_per_img``)."""
        self.model.bbox_head.test_cfg['max_per_img'] = num_ins