|
import torch |
|
from PIL import Image |
|
from rembg import remove, new_session |
|
|
|
import folder_paths |
|
|
|
from ..session.CustomSession import CustomAbstractSession |
|
from ..session.CustomSession import CustomSessionContainer |
|
from ..session.ModnetPhotographicSession import ModnetPhotographicSession |
|
from ..session.ModnetWebcamSession import ModnetWebcamSession |
|
|
|
|
|
class ImageSegmentation:
    """ComfyUI node that passes every image of a batch through rembg's ``remove``."""

    # ComfyUI node metadata: output socket types, entry-point method name
    # and the menu category the node is listed under.
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/segmentation"

    def __init__(self):
        """Nothing to initialise; the node keeps no per-instance state."""
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs.

        Returns:
            A dict with a single ``"required"`` entry mapping each input
            name to its type descriptor. Key order is significant because
            it matches the parameter order of :meth:`node`.
        """
        model_names = [
            "u2net",
            "u2netp",
            "u2net_human_seg",
            "u2net_cloth_seg",
            "silueta",
            "isnet-general-use",
            "isnetis",
            "modnet-p",
            "modnet-w"
        ]

        required = {
            "images": ("IMAGE",),
            "model": (model_names,),
            "alpha_matting": (["true", "false"],),
            "alpha_matting_foreground_threshold": (
                "INT", {"default": 240, "max": 250, "step": 5}
            ),
            "alpha_matting_background_threshold": (
                "INT", {"default": 20, "max": 250, "step": 5}
            ),
            "alpha_matting_erode_size": ("INT", {"default": 10, "step": 1}),
            "post_process_mask": (["false", "true"],),
        }

        return {"required": required}

    def node(
            self,
            images,
            model,
            alpha_matting,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_size,
            post_process_mask,
            session=None
    ):
        """Run ``remove`` on each image and stack the results.

        Args:
            images: Indexable batch of images; every element must offer a
                ``tensor_to_image()`` method (not defined in this file --
                presumably patched onto ``torch.Tensor`` elsewhere in the
                project; TODO confirm).
            model: Model name, used to build a session when ``session`` is
                not supplied.
            alpha_matting: ``"true"`` enables alpha matting; any other
                value disables it.
            alpha_matting_foreground_threshold: Forwarded to ``remove``.
            alpha_matting_background_threshold: Forwarded to ``remove``.
            alpha_matting_erode_size: Forwarded to ``remove``.
            post_process_mask: ``"true"`` enables mask post-processing; any
                other value disables it.
            session: Optional ready-made session. When ``None`` one is
                created from ``model``.

        Returns:
            A one-element tuple holding the ``torch.stack`` of the
            processed images.
        """
        if session is None:
            if model == "modnet-p":
                session = ModnetPhotographicSession(model)
            elif model == "modnet-w":
                session = ModnetWebcamSession(model)
            else:
                # The UI label "isnetis" is an alias for rembg's "isnet-anime".
                session = new_session("isnet-anime" if model == "isnetis" else model)

        # The UI delivers booleans as "true"/"false" strings; convert once.
        use_alpha_matting = alpha_matting == "true"
        use_post_process = post_process_mask == "true"

        def strip_background(tensor):
            """Convert one batch element, run ``remove``, convert back."""
            picture = tensor.tensor_to_image()

            # Positional order as in the original call; the literal False
            # sits between the session and post-process arguments.
            cutout = remove(
                picture,
                use_alpha_matting,
                alpha_matting_foreground_threshold,
                alpha_matting_background_threshold,
                alpha_matting_erode_size,
                session,
                False,
                use_post_process
            )

            # image_to_tensor() is likewise assumed to be a project-side
            # patch on the returned image object -- TODO confirm.
            return cutout.image_to_tensor()

        processed = [strip_background(item) for item in images]

        return (torch.stack(processed),)
|
|
|
|
|
class ImageSegmentationCustom:
    """ComfyUI node that segments images with a user-supplied ONNX model.

    One ``mean``, one ``std`` and one ``size`` value are reused for all
    three mean slots, all three std slots and both size slots of the
    ``CustomSessionContainer``.
    """

    # ComfyUI node metadata: output socket types, entry-point method name
    # and the menu category the node is listed under.
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/segmentation"

    def __init__(self):
        """Nothing to initialise; the node keeps no per-instance state."""
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs.

        The ``model`` choices come from
        ``folder_paths.get_filename_list("onnx")`` and are looked up on
        every call.

        Returns:
            A dict with a single ``"required"`` entry whose keys match the
            parameter names of :meth:`node`.
        """
        required = {
            "images": ("IMAGE",),
            "model": (folder_paths.get_filename_list("onnx"),),
            "alpha_matting": (["true", "false"],),
            "alpha_matting_foreground_threshold": (
                "INT", {"default": 240, "max": 250, "step": 5}
            ),
            "alpha_matting_background_threshold": (
                "INT", {"default": 20, "max": 250, "step": 5}
            ),
            "alpha_matting_erode_size": ("INT", {"default": 10, "step": 1}),
            "post_process_mask": (["false", "true"],),
            "mean": ("FLOAT", {"default": 0.485, "max": 1.0, "step": 0.01}),
            "std": ("FLOAT", {"default": 1.0, "max": 1.0, "step": 0.01}),
            "size": ("INT", {"default": 1024, "step": 8}),
        }

        return {"required": required}

    def node(
            self,
            images,
            model,
            alpha_matting,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_size,
            post_process_mask,
            mean,
            std,
            size
    ):
        """Build a custom session for ``model`` and delegate to ``ImageSegmentation``.

        Args:
            images: Image batch, forwarded unchanged.
            model: Entry chosen from the ``"onnx"`` file list.
            alpha_matting: Forwarded unchanged.
            alpha_matting_foreground_threshold: Forwarded unchanged.
            alpha_matting_background_threshold: Forwarded unchanged.
            alpha_matting_erode_size: Forwarded unchanged.
            post_process_mask: Forwarded unchanged.
            mean: Value placed in all three mean slots of the container.
            std: Value placed in all three std slots of the container.
            size: Value placed in both size slots of the container.

        Returns:
            Whatever ``ImageSegmentation().node`` returns for this session.
        """
        means = (mean, mean, mean)
        stds = (std, std, std)
        container = CustomSessionContainer(*means, *stds, size, size)

        # Defined per call so the class can close over ``model``.
        class CustomSession(CustomAbstractSession):
            def __init__(self):
                super().__init__(model)

            @classmethod
            def name(cls, *args, **kwargs):
                return model

        session = CustomSession().from_container(container)
        segmenter = ImageSegmentation()

        return segmenter.node(
            images,
            model,
            alpha_matting,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_size,
            post_process_mask,
            session
        )
|
|
|
|
|
class ImageSegmentationCustomAdvanced:
    """ComfyUI node that segments images with a user-supplied ONNX model.

    Unlike the single-value variant, each of the three mean slots, three
    std slots and the width/height of the ``CustomSessionContainer`` is
    exposed as its own input.
    """

    # ComfyUI node metadata: output socket types, entry-point method name
    # and the menu category the node is listed under.
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/segmentation"

    def __init__(self):
        """Nothing to initialise; the node keeps no per-instance state."""
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's required inputs.

        ComfyUI invokes ``node`` with keyword arguments named after these
        keys, so every key must equal a parameter name of :meth:`node`.
        The mean/std keys were previously ``mean_r`` ... ``std_b`` while
        the method expects ``mean_x`` ... ``std_z``, which made every call
        fail with ``TypeError``; the keys now match the signature.

        The ``model`` choices come from
        ``folder_paths.get_filename_list("onnx")`` and are looked up on
        every call.

        Returns:
            A dict with a single ``"required"`` entry mapping each input
            name to its type descriptor.
        """
        return {
            "required": {
                "images": ("IMAGE",),
                "model": (folder_paths.get_filename_list("onnx"),),
                "alpha_matting": (["true", "false"],),
                "alpha_matting_foreground_threshold": ("INT", {
                    "default": 240,
                    "max": 250,
                    "step": 5
                }),
                "alpha_matting_background_threshold": ("INT", {
                    "default": 20,
                    "max": 250,
                    "step": 5
                }),
                "alpha_matting_erode_size": ("INT", {
                    "default": 10,
                    "step": 1
                }),
                "post_process_mask": (["false", "true"],),
                "mean_x": ("FLOAT", {
                    "default": 0.485,
                    "max": 1.0,
                    "step": 0.01
                }),
                "mean_y": ("FLOAT", {
                    "default": 0.456,
                    "max": 1.0,
                    "step": 0.01
                }),
                "mean_z": ("FLOAT", {
                    "default": 0.406,
                    "max": 1.0,
                    "step": 0.01
                }),
                "std_x": ("FLOAT", {
                    "default": 1.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "std_y": ("FLOAT", {
                    "default": 1.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "std_z": ("FLOAT", {
                    "default": 1.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "width": ("INT", {
                    "default": 1024,
                    "step": 8
                }),
                "height": ("INT", {
                    "default": 1024,
                    "step": 8
                }),
            },
        }

    def node(
            self,
            images,
            model,
            alpha_matting,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_size,
            post_process_mask,
            mean_x,
            mean_y,
            mean_z,
            std_x,
            std_y,
            std_z,
            width,
            height
    ):
        """Build a custom session for ``model`` and delegate to ``ImageSegmentation``.

        Args:
            images: Image batch, forwarded unchanged.
            model: Entry chosen from the ``"onnx"`` file list.
            alpha_matting: Forwarded unchanged.
            alpha_matting_foreground_threshold: Forwarded unchanged.
            alpha_matting_background_threshold: Forwarded unchanged.
            alpha_matting_erode_size: Forwarded unchanged.
            post_process_mask: Forwarded unchanged.
            mean_x: First mean argument of ``CustomSessionContainer``.
            mean_y: Second mean argument of ``CustomSessionContainer``.
            mean_z: Third mean argument of ``CustomSessionContainer``.
            std_x: First std argument of ``CustomSessionContainer``.
            std_y: Second std argument of ``CustomSessionContainer``.
            std_z: Third std argument of ``CustomSessionContainer``.
            width: Seventh argument of ``CustomSessionContainer``.
            height: Eighth argument of ``CustomSessionContainer``.

        Returns:
            Whatever ``ImageSegmentation().node`` returns for this session.
        """
        container = CustomSessionContainer(
            mean_x, mean_y, mean_z,
            std_x, std_y, std_z,
            width, height
        )

        # Defined per call so the class can close over ``model``.
        class CustomSession(CustomAbstractSession):
            def __init__(self):
                super().__init__(model)

            @classmethod
            def name(cls, *args, **kwargs):
                return model

        session = CustomSession().from_container(container)

        return ImageSegmentation().node(
            images,
            model,
            alpha_matting,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_size,
            post_process_mask,
            session
        )
|
|
|
|
|
# Registry of the node classes defined in this module, keyed by class name
# (each key is identical to the class's own ``__name__``).
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (
        ImageSegmentation,
        ImageSegmentationCustom,
        ImageSegmentationCustomAdvanced,
    )
}
|
|