|
import folder_paths |
|
from impact.core import * |
|
import os |
|
|
|
import mmcv |
|
from mmdet.apis import (inference_detector, init_detector) |
|
from mmdet.evaluation import get_classes |
|
|
|
|
|
def load_mmdet(model_path):
    """Initialise an MMDetection detector from a checkpoint file.

    The config file path is derived from the checkpoint path by swapping the
    file extension for ``.py`` (for example ``foo.pth`` -> ``foo.py``).

    Args:
        model_path: Path to the checkpoint file.

    Returns:
        Whatever ``init_detector`` returns for that config/checkpoint pair,
        requested with ``device="cpu"``.
    """
    checkpoint_stem, _extension = os.path.splitext(model_path)
    config_path = checkpoint_stem + ".py"
    return init_detector(config_path, model_path, device="cpu")
|
|
|
|
|
def inference_segm_old(model, image, conf_threshold):
    """Run segmentation with a detector that returns the tuple-style result.

    The detector output is unpacked as ``(bbox_results, segm_results)``, each
    of which is iterated per class. The last column of every bbox row is
    compared against ``conf_threshold``.

    Args:
        model: Detector passed straight to ``inference_detector``.
        image: Object exposing ``.numpy()``; the first item along axis 0 is
            used and multiplied by 255 (presumably a batch of float images in
            the 0..1 range -- TODO confirm against callers).
        conf_threshold: Detections whose score is not strictly greater than
            this value are dropped.

    Returns:
        A list of three parallel lists: ``[labels, bboxes, segms]``. Labels
        are formatted as ``"A-<class name>"`` using the "coco" class names.
    """
    image = image.numpy()[0] * 255
    bbox_results, segm_results = inference_detector(model, image)
    label = "A"

    classes = get_classes("coco")

    # Count detections over *all* classes. Looking only at the first class
    # would throw away every detection of the other classes whenever the
    # first class happens to be empty.
    total_detections = sum(class_bboxes.shape[0] for class_bboxes in bbox_results)
    if total_detections == 0:
        return [[], [], []]

    # One class index per detection, in the same order as the stacked bboxes.
    labels = np.concatenate([
        np.full(class_bboxes.shape[0], class_idx, dtype=np.int32)
        for class_idx, class_bboxes in enumerate(bbox_results)
    ])
    bboxes = np.vstack(bbox_results)
    segms = mmcv.concat_list(segm_results)

    results = [[], [], []]
    for i in np.where(bboxes[:, -1] > conf_threshold)[0]:
        results[0].append(label + "-" + classes[labels[i]])
        results[1].append(bboxes[i])
        results[2].append(segms[i])

    return results
|
|
|
|
|
def inference_segm(image, modelname, conf_thres, lab="A"):
    """Run instance segmentation and keep detections above a score threshold.

    Args:
        image: Object exposing ``.numpy()``; the first item along axis 0 is
            used and multiplied by 255 (presumably a batch of float images in
            the 0..1 range -- TODO confirm against callers).
        modelname: Detector passed to ``inference_detector`` (despite the
            name, it is used as the model object, not as a string).
        conf_thres: Detections whose score is not strictly greater than this
            value are dropped.
        lab: Prefix for the generated labels.

    Returns:
        A list of four parallel lists: ``[labels, bboxes, masks, scores]``.
        Labels are formatted as ``"<lab>-<class name>"`` using the "coco"
        class names. All four lists are empty when nothing is detected.
    """
    image = image.numpy()[0] * 255
    predictions = inference_detector(modelname, image).pred_instances

    # Convert every field once and work on the arrays from here on.
    # ``.cpu()`` leaves CPU tensors untouched and keeps ``.numpy()`` from
    # failing should the detector live on another device.
    bboxes = predictions.bboxes.cpu().numpy()
    segms = predictions.masks.cpu().numpy()
    scores = predictions.scores.cpu().numpy()
    labels = predictions.labels.cpu().numpy()

    classes = get_classes("coco")

    if bboxes.shape[0] == 0:
        return [[], [], [], []]

    results = [[], [], [], []]
    for i in np.where(scores > conf_thres)[0]:
        results[0].append(lab + "-" + classes[labels[i]])
        results[1].append(bboxes[i])
        results[2].append(segms[i])
        results[3].append(scores[i])

    return results
|
|
|
|
|
def inference_bbox(modelname, image, conf_threshold):
    """Run bbox detection and return one filled-rectangle mask per detection.

    Args:
        modelname: Detector passed to ``inference_detector`` (despite the
            name, it is used as the model object, not as a string).
        image: Object exposing ``.numpy()``; the first item along axis 0 is
            used and multiplied by 255 (presumably a batch of float images in
            the 0..1 range -- TODO confirm against callers).
        conf_threshold: Detections whose score is not strictly greater than
            this value are dropped.

    Returns:
        A list of four parallel lists: ``[labels, bboxes, masks, scores]``.
        Every label is ``"A"``. Each mask is a boolean array covering the
        first two image dimensions, true inside the detection's rectangle.
        All four lists are empty when nothing is detected.
    """
    image = image.numpy()[0] * 255
    label = "A"
    output = inference_detector(modelname, image).pred_instances

    # ``.cpu()`` leaves CPU tensors untouched and keeps ``.numpy()`` from
    # failing should the detector live on another device.
    bboxes = output.bboxes.cpu().numpy()
    scores = output.scores.cpu().numpy()

    results = [[], [], [], []]
    if bboxes.shape[0] == 0:
        return results

    # Only the spatial size is needed for the masks, so read it directly
    # instead of converting the whole image to grayscale first.
    mask_shape = image.shape[:2]

    # Rasterise masks only for detections that survive the threshold.
    for i in np.where(scores > conf_threshold)[0]:
        x0, y0, x1, y1 = bboxes[i]
        box_mask = np.zeros(mask_shape, np.uint8)
        # Thickness -1 asks OpenCV for a filled rectangle.
        cv2.rectangle(box_mask, (int(x0), int(y0)), (int(x1), int(y1)), 255, -1)

        results[0].append(label)
        results[1].append(bboxes[i])
        results[2].append(box_mask.astype(bool))
        results[3].append(scores[i])

    return results
|
|
|
|
|
class BBoxDetector:
    """Detector wrapper that turns bbox detections into cropped SEG items.

    The helpers used here (``create_segmasks``, ``dilate_masks``,
    ``make_crop_region``, ``crop_image``, ``crop_ndarray2``, ``combine_masks``
    and ``SEG``) are not defined in this file.
    """

    # Class-level fallback used when an instance never assigns the attribute.
    bbox_model = None

    def __init__(self, bbox_model):
        """Store the detector model handed to ``inference_bbox``."""
        self.bbox_model = bbox_model

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Detect objects and build one SEG per sufficiently large detection.

        Args:
            image: Input image; ``image.shape[1]`` and ``image.shape[2]`` are
                read as height and width.
            threshold: Score threshold forwarded to ``inference_bbox``.
            dilation: When greater than 0, masks are passed through
                ``dilate_masks`` with this value.
            crop_factor: Forwarded to ``make_crop_region``.
            drop_size: Detections must exceed this extent on both axes to be
                kept; values below 1 are raised to 1.
            detailer_hook: Optional object; if it has a ``post_detection``
                attribute, the result is passed through it before returning.

        Returns:
            A ``(shape, items)`` pair where ``shape`` is
            ``(image.shape[1], image.shape[2])`` and ``items`` is the list of
            SEG objects, or whatever ``detailer_hook.post_detection`` returns
            for that pair.
        """
        drop_size = max(drop_size, 1)
        mmdet_results = inference_bbox(self.bbox_model, image, threshold)
        segmasks = create_segmasks(mmdet_results)

        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        h = image.shape[1]
        w = image.shape[2]

        items = []
        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]

            # NOTE(review): the size test below is symmetric in both axes, so
            # the result does not depend on whether the bbox is stored as
            # (y1, x1, y2, x2) or (x1, y1, x2, y2).
            y1, x1, y2, x2 = item_bbox

            if x2 - x1 > drop_size and y2 - y1 > drop_size:
                crop_region = make_crop_region(w, h, item_bbox, crop_factor)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(item_mask, crop_region)
                confidence = x[2]

                items.append(
                    SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
                )

        segs = (h, w), items

        # Honour the hook the same way SegmDetector.detect does; previously
        # the argument was accepted but never used.
        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    def detect_combined(self, image, threshold, dilation):
        """Detect objects and merge all their masks with ``combine_masks``.

        Args:
            image: Input image forwarded to ``inference_bbox``.
            threshold: Score threshold forwarded to ``inference_bbox``.
            dilation: When greater than 0, masks are passed through
                ``dilate_masks`` with this value before combining.

        Returns:
            The value returned by ``combine_masks``.
        """
        mmdet_results = inference_bbox(self.bbox_model, image, threshold)
        segmasks = create_segmasks(mmdet_results)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        return combine_masks(segmasks)

    def setAux(self, x):
        """Accept an auxiliary value and ignore it (intentional no-op)."""
        pass
|
|
|
|
|
class SegmDetector(BBoxDetector):
    """Detector wrapper that turns segmentation results into cropped SEG items.

    The helpers used here (``create_segmasks``, ``dilate_masks``,
    ``make_crop_region``, ``crop_image``, ``crop_ndarray2``, ``combine_masks``
    and ``SEG``) are not defined in this file.
    """

    # Class-level fallback used when an instance never assigns the attribute.
    segm_model = None

    def __init__(self, segm_model):
        """Store the segmentation model handed to ``inference_segm``.

        The parent initialiser is not called, so ``bbox_model`` keeps the
        class-level default inherited from ``BBoxDetector``.
        """
        self.segm_model = segm_model

    def detect(self, image, threshold, dilation, crop_factor, drop_size=1, detailer_hook=None):
        """Segment objects and build one SEG per sufficiently large detection.

        Args:
            image: Input image; ``image.shape[1]`` and ``image.shape[2]`` are
                read as height and width.
            threshold: Score threshold forwarded to ``inference_segm``.
            dilation: When greater than 0, masks are passed through
                ``dilate_masks`` with this value.
            crop_factor: Forwarded to ``make_crop_region``.
            drop_size: Detections must exceed this extent on both axes to be
                kept; values below 1 are raised to 1.
            detailer_hook: Optional object; if it has a ``post_detection``
                attribute, the result is passed through it before returning.

        Returns:
            An ``(image.shape, items)`` pair, or whatever
            ``detailer_hook.post_detection`` returns for that pair.
        """
        drop_size = max(drop_size, 1)
        mmdet_results = inference_segm(image, self.segm_model, threshold)
        segmasks = create_segmasks(mmdet_results)

        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        h = image.shape[1]
        w = image.shape[2]

        items = []
        for x in segmasks:
            item_bbox = x[0]
            item_mask = x[1]

            # NOTE(review): the size test below is symmetric in both axes, so
            # the result does not depend on whether the bbox is stored as
            # (y1, x1, y2, x2) or (x1, y1, x2, y2).
            y1, x1, y2, x2 = item_bbox

            if x2 - x1 > drop_size and y2 - y1 > drop_size:
                crop_region = make_crop_region(w, h, item_bbox, crop_factor)
                cropped_image = crop_image(image, crop_region)
                cropped_mask = crop_ndarray2(item_mask, crop_region)
                confidence = x[2]

                items.append(
                    SEG(cropped_image, cropped_mask, confidence, crop_region, item_bbox, None, None)
                )

        # NOTE(review): the full ``image.shape`` is returned here, while
        # BBoxDetector.detect returns (height, width) -- confirm which form
        # consumers expect.
        segs = image.shape, items

        if detailer_hook is not None and hasattr(detailer_hook, "post_detection"):
            segs = detailer_hook.post_detection(segs)

        return segs

    def detect_combined(self, image, threshold, dilation):
        """Segment objects and merge all their masks with ``combine_masks``.

        Args:
            image: Input image forwarded to ``inference_segm``.
            threshold: Score threshold forwarded to ``inference_segm``.
            dilation: When greater than 0, masks are passed through
                ``dilate_masks`` with this value before combining.

        Returns:
            The value returned by ``combine_masks``.
        """
        # Use the segmentation model this instance actually holds;
        # ``self.bbox_model`` is never assigned here and would be None.
        mmdet_results = inference_segm(image, self.segm_model, threshold)
        segmasks = create_segmasks(mmdet_results)
        if dilation > 0:
            segmasks = dilate_masks(segmasks, dilation)

        return combine_masks(segmasks)

    def setAux(self, x):
        """Accept an auxiliary value and ignore it (intentional no-op)."""
        pass
|
|
|
|
|
class MMDetDetectorProvider:
    """Node-style loader that yields a bbox detector and a segm detector.

    ``INPUT_TYPES``, ``RETURN_TYPES``, ``FUNCTION`` and ``CATEGORY`` are
    presumably read by the host node framework -- that consumer is not shown
    in this file.
    """

    @classmethod
    def INPUT_TYPES(s):
        """List the selectable models, prefixed with ``bbox/`` or ``segm/``."""
        bboxs = ["bbox/" + x for x in folder_paths.get_filename_list("mmdets_bbox")]
        segms = ["segm/" + x for x in folder_paths.get_filename_list("mmdets_segm")]
        return {"required": {"model_name": (bboxs + segms, )}}

    RETURN_TYPES = ("BBOX_DETECTOR", "SEGM_DETECTOR")
    FUNCTION = "load_mmdet"

    CATEGORY = "ImpactPack"

    def load_mmdet(self, model_name):
        """Load the selected model and wrap it in the matching detector.

        Args:
            model_name: One of the names produced by ``INPUT_TYPES``; the
                ``bbox`` prefix selects the bbox branch, anything else the
                segm branch.

        Returns:
            A ``(bbox_detector, segm_detector)`` pair. The slot that does not
            match the model kind is filled with ``NO_SEGM_DETECTOR()`` or
            ``NO_BBOX_DETECTOR()`` respectively.
        """
        mmdet_path = folder_paths.get_full_path("mmdets", model_name)
        # Inside a method this name resolves to the module-level function,
        # not to this method.
        model = load_mmdet(mmdet_path)

        if model_name.startswith("bbox"):
            return BBoxDetector(model), NO_SEGM_DETECTOR()

        # Wrap the model so the SEGM_DETECTOR slot carries a detector object,
        # mirroring the bbox branch, instead of the raw model.
        return NO_BBOX_DETECTOR(), SegmDetector(model)