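# NOTE(review): the next three lines appear to be residue copied from a web
# page (not code); kept as comments so the module can be parsed.
# 3v324v23's picture
# lfs
# 1e3b872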
import torch
from PIL import Image
from rembg import remove, new_session
import folder_paths
from ..session.CustomSession import CustomAbstractSession
from ..session.CustomSession import CustomSessionContainer
from ..session.ModnetPhotographicSession import ModnetPhotographicSession
from ..session.ModnetWebcamSession import ModnetWebcamSession
class ImageSegmentation:
    """ComfyUI node that removes image backgrounds with a rembg session.

    Every image in the incoming batch is converted to a PIL image, passed
    through ``rembg.remove`` and converted back to a tensor; the results are
    stacked into a new batch.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs for the ComfyUI front end."""
        return {
            "required": {
                "images": ("IMAGE",),
                "model": ([
                    "u2net",
                    "u2netp",
                    "u2net_human_seg",
                    "u2net_cloth_seg",
                    "silueta",
                    "isnet-general-use",
                    "isnetis",
                    "modnet-p",
                    "modnet-w"
                ],),
                "alpha_matting": (["true", "false"],),
                "alpha_matting_foreground_threshold": ("INT", {
                    "default": 240,
                    "max": 250,
                    "step": 5
                }),
                "alpha_matting_background_threshold": ("INT", {
                    "default": 20,
                    "max": 250,
                    "step": 5
                }),
                "alpha_matting_erode_size": ("INT", {
                    "default": 10,
                    "step": 1
                }),
                "post_process_mask": (["false", "true"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/segmentation"

    def node(
        self,
        images,
        model,
        alpha_matting,
        alpha_matting_foreground_threshold,
        alpha_matting_background_threshold,
        alpha_matting_erode_size,
        post_process_mask,
        session=None
    ):
        """Remove the background from every image in ``images``.

        Args:
            images: Image batch; iterated along its first dimension.
            model: Model name used to build a session when ``session`` is None.
            alpha_matting: ``"true"`` enables rembg alpha matting.
            alpha_matting_foreground_threshold: Forwarded to ``rembg.remove``.
            alpha_matting_background_threshold: Forwarded to ``rembg.remove``.
            alpha_matting_erode_size: Forwarded to ``rembg.remove``.
            post_process_mask: ``"true"`` enables rembg mask post-processing.
            session: Pre-built session; when given, ``model`` is not used to
                create one.

        Returns:
            A one-element tuple holding the processed images joined with
            ``torch.stack``.
        """
        if session is None:
            if model == "isnetis":
                # The UI entry "isnetis" is mapped to rembg's "isnet-anime" model.
                session = new_session("isnet-anime")
            elif model == "modnet-p":
                session = ModnetPhotographicSession(model)
            elif model == "modnet-w":
                session = ModnetWebcamSession(model)
            else:
                session = new_session(model)

        # The combo widgets deliver strings; convert them once, not per image.
        use_alpha_matting = alpha_matting == "true"
        use_post_process = post_process_mask == "true"

        def segment(image):
            # NOTE(review): tensor_to_image()/image_to_tensor() are not torch or
            # PIL APIs; presumably patched onto those types elsewhere in the
            # project — confirm.
            img: Image.Image = image.tensor_to_image()
            # Keyword arguments keep this call independent of the positional
            # parameter order of rembg.remove.
            result = remove(
                img,
                alpha_matting=use_alpha_matting,
                alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
                alpha_matting_background_threshold=alpha_matting_background_threshold,
                alpha_matting_erode_size=alpha_matting_erode_size,
                session=session,
                only_mask=False,
                post_process_mask=use_post_process,
            )
            return result.image_to_tensor()

        return (torch.stack([segment(image) for image in images]),)
class ImageSegmentationCustom:
    """Background-removal node that runs a user-selected ONNX model.

    One ``mean`` value fills all three mean slots of ``CustomSessionContainer``,
    one ``std`` value fills all three std slots, and ``size`` is passed for both
    of its dimension slots. The actual segmentation is delegated to
    ``ImageSegmentation.node``.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe inputs; the model choices come from the "onnx" file list."""
        required = {
            "images": ("IMAGE",),
            "model": (folder_paths.get_filename_list("onnx"),),
            "alpha_matting": (["true", "false"],),
            "alpha_matting_foreground_threshold": ("INT", {"default": 240, "max": 250, "step": 5}),
            "alpha_matting_background_threshold": ("INT", {"default": 20, "max": 250, "step": 5}),
            "alpha_matting_erode_size": ("INT", {"default": 10, "step": 1}),
            "post_process_mask": (["false", "true"],),
            "mean": ("FLOAT", {"default": 0.485, "max": 1.0, "step": 0.01}),
            "std": ("FLOAT", {"default": 1.0, "max": 1.0, "step": 0.01}),
            "size": ("INT", {"default": 1024, "step": 8}),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/segmentation"

    def node(
        self,
        images,
        model,
        alpha_matting,
        alpha_matting_foreground_threshold,
        alpha_matting_background_threshold,
        alpha_matting_erode_size,
        post_process_mask,
        mean,
        std,
        size
    ):
        """Build a custom session for ``model`` and segment ``images`` with it.

        Returns the one-element tuple produced by ``ImageSegmentation.node``.
        """
        model_name = model
        settings = CustomSessionContainer(mean, mean, mean, std, std, std, size, size)

        # Session subclass bound to the chosen model file; both the
        # constructor argument and name() report that file name.
        class CustomSession(CustomAbstractSession):
            def __init__(self):
                super().__init__(model_name)

            @classmethod
            def name(cls, *args, **kwargs):
                return model_name

        prepared = CustomSession().from_container(settings)
        return ImageSegmentation().node(
            images=images,
            model=model,
            alpha_matting=alpha_matting,
            alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
            alpha_matting_background_threshold=alpha_matting_background_threshold,
            alpha_matting_erode_size=alpha_matting_erode_size,
            post_process_mask=post_process_mask,
            session=prepared,
        )
class ImageSegmentationCustomAdvanced:
    """Background-removal node for a user-selected ONNX model with per-channel
    normalisation and independent width/height.

    The values are packed into a ``CustomSessionContainer`` and segmentation
    is delegated to ``ImageSegmentation.node``.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Describe inputs; the model choices come from the "onnx" file list."""
        return {
            "required": {
                "images": ("IMAGE",),
                "model": (folder_paths.get_filename_list("onnx"),),
                "alpha_matting": (["true", "false"],),
                "alpha_matting_foreground_threshold": ("INT", {
                    "default": 240,
                    "max": 250,
                    "step": 5
                }),
                "alpha_matting_background_threshold": ("INT", {
                    "default": 20,
                    "max": 250,
                    "step": 5
                }),
                "alpha_matting_erode_size": ("INT", {
                    "default": 10,
                    "step": 1
                }),
                "post_process_mask": (["false", "true"],),
                "mean_r": ("FLOAT", {
                    "default": 0.485,
                    "max": 1.0,
                    "step": 0.01
                }),
                "mean_g": ("FLOAT", {
                    "default": 0.456,
                    "max": 1.0,
                    "step": 0.01
                }),
                "mean_b": ("FLOAT", {
                    "default": 0.406,
                    "max": 1.0,
                    "step": 0.01
                }),
                "std_r": ("FLOAT", {
                    "default": 1.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "std_g": ("FLOAT", {
                    "default": 1.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "std_b": ("FLOAT", {
                    "default": 1.0,
                    "max": 1.0,
                    "step": 0.01
                }),
                "width": ("INT", {
                    "default": 1024,
                    "step": 8
                }),
                "height": ("INT", {
                    "default": 1024,
                    "step": 8
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/segmentation"

    def node(
        self,
        images,
        model,
        alpha_matting,
        alpha_matting_foreground_threshold,
        alpha_matting_background_threshold,
        alpha_matting_erode_size,
        post_process_mask,
        mean_r,
        mean_g,
        mean_b,
        std_r,
        std_g,
        std_b,
        width,
        height
    ):
        """Build a custom session for ``model`` and segment ``images`` with it.

        The parameter names must match the ``INPUT_TYPES`` keys, because
        ComfyUI passes the inputs as keyword arguments. The previous names
        (``mean_x``..``std_z``) made every run fail with ``TypeError``. The
        positional order is unchanged.

        Returns:
            The one-element tuple produced by ``ImageSegmentation.node``.
        """
        container = CustomSessionContainer(
            mean_r, mean_g, mean_b, std_r, std_g, std_b, width, height
        )

        # Session subclass bound to the chosen model file; both the
        # constructor argument and name() report that file name.
        class CustomSession(CustomAbstractSession):
            def __init__(self):
                super().__init__(model)

            @classmethod
            def name(cls, *args, **kwargs):
                return model

        session = CustomSession().from_container(container)
        return ImageSegmentation().node(
            images,
            model,
            alpha_matting,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_size,
            post_process_mask,
            session
        )
# Node type name -> implementing class, exported under the name
# NODE_CLASS_MAPPINGS.
NODE_CLASS_MAPPINGS = dict(
    ImageSegmentation=ImageSegmentation,
    ImageSegmentationCustom=ImageSegmentationCustom,
    ImageSegmentationCustomAdvanced=ImageSegmentationCustomAdvanced,
)