|
import torch
|
|
|
|
|
|
class AlphaChanelAdd:
    """Append a fully opaque alpha channel to a batch of images.

    Inputs that already have four channels are passed through unchanged.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images):
        """Return ``(images_with_alpha,)``.

        Args:
            images: tensor indexed as (batch, height, width, channels).

        Returns:
            A one-element tuple holding the images with an alpha channel of
            ones appended on the last axis, or the unchanged input when it
            already has four channels.
        """
        batch, height, width, channels = images.shape

        if channels == 4:
            # Already RGBA. Wrap in a tuple like the other return path so the
            # caller always receives a tuple matching RETURN_TYPES.
            return (images,)

        # Allocate alpha on the input's device and dtype; a default CPU/float32
        # tensor would make torch.cat fail (or up-cast) for GPU/half images.
        alpha = torch.ones(
            (batch, height, width, 1), dtype=images.dtype, device=images.device
        )
        return (torch.cat((images, alpha), dim=-1),)
|
|
|
|
|
|
class AlphaChanelAddByMask:
    """Build RGBA images whose alpha channel comes from a mask.

    With ``method == "default"`` the alpha channel is ``1 - mask``; any other
    method value uses the mask itself as alpha.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "mask": ("MASK",),
                "method": (["default", "invert"],),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images, mask, method):
        """Return ``(rgba_images,)`` with alpha taken from ``mask``.

        Args:
            images: tensor indexed as (count, height, width, channels); only
                the first three channels are kept.
            mask: tensor of shape (count, height, width) or (height, width).
                A mask batch of one is broadcast over all images.
            method: ``"default"`` gives alpha = 1 - mask; any other value
                gives alpha = mask.

        Returns:
            A one-element tuple holding a (count, height, width, 4) tensor on
            the images' device and dtype.

        Raises:
            ValueError: if the mask's spatial size differs from the images'
                (other than the 64x64 placeholder case), or if the mask count
                is neither 1 nor the image count.
        """
        img_count, img_height, img_width = images.shape[:3]

        # Treat an un-batched (height, width) mask as a batch of one.
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        mask_count, mask_height, mask_width = mask.shape

        if (mask_height, mask_width) == (64, 64) and (img_height, img_width) != (64, 64):
            # A 64x64 mask that does not fit the images is treated as an empty
            # placeholder (NOTE(review): presumably the default mask supplied
            # when no real mask exists -- confirm). Only substitute when the
            # sizes differ, so a genuine mask for a 64x64 image is honoured.
            mask = torch.zeros(
                (img_count, img_height, img_width),
                dtype=images.dtype,
                device=images.device,
            )
        else:
            if (img_height, img_width) != (mask_height, mask_width):
                raise ValueError(
                    "[AlphaChanelByMask]: Size of images not equals size of mask. "
                    f"Images: [{img_width}, {img_height}] - "
                    f"Mask: [{mask_width}, {mask_height}]."
                )
            if mask_count != img_count:
                if mask_count != 1:
                    # expand() can only broadcast a singleton batch dimension.
                    raise ValueError(
                        f"[AlphaChanelByMask]: Count of masks ({mask_count}) "
                        f"must be 1 or equal to count of images ({img_count})."
                    )
                mask = mask.expand((img_count, -1, -1))
            # Align device/dtype with the images so torch.cat works for GPU or
            # half-precision inputs.
            mask = mask.to(device=images.device, dtype=images.dtype)

        alpha = 1.0 - mask if method == "default" else mask
        # Vectorized over the whole batch instead of stacking image by image.
        return (torch.cat((images[..., :3], alpha.unsqueeze(-1)), dim=-1),)
|
|
|
|
|
|
class AlphaChanelAsMask:
    """Expose the alpha channel of the first image in a batch as a mask."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        methods = ["default", "invert"]
        return {"required": {"images": ("IMAGE",), "method": (methods,)}}

    RETURN_TYPES = ("MASK",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images, method):
        """Return ``(mask,)`` built from channel 3 of ``images[0]``.

        ``"default"`` yields ``1 - alpha`` and ``"invert"`` yields the raw
        alpha. Only the first image of the batch is used.

        Raises:
            ValueError: if the images do not have four channels, or if
                ``method`` is not one of the two known values.
        """
        first_pixel = images[0, 0, 0]
        if first_pixel.shape[0] != 4:
            raise ValueError("Alpha chanel not exist.")

        alpha = images[0, :, :, 3]
        if method == "default":
            return (1.0 - alpha,)
        if method == "invert":
            return (alpha,)
        raise ValueError("Unexpected method.")
|
|
|
|
|
|
class AlphaChanelRestore:
    """Reset the alpha channel of RGBA images to fully opaque (1.0).

    Inputs without four channels are passed through unchanged.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images):
        """Return ``(images_with_opaque_alpha,)``.

        Args:
            images: tensor indexed as (batch, height, width, channels).

        Returns:
            A one-element tuple. For 4-channel input it holds a detached copy
            whose channel 3 is set to 1.0 (the input is not modified);
            otherwise it holds the input unchanged.
        """
        channels = images.shape[-1]

        if channels != 4:
            # Wrap in a tuple like the other return path so the caller always
            # receives a tuple matching RETURN_TYPES.
            return (images,)

        tensor = images.clone().detach()
        # A scalar fill avoids allocating a separate (CPU) ones tensor and
        # keeps everything on the input's device and dtype.
        tensor[:, :, :, 3] = 1.0
        return (tensor,)
|
|
|
|
|
|
class AlphaChanelRemove:
    """Drop every channel after the first three (e.g. RGBA -> RGB)."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"images": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/alpha"

    def node(self, images):
        """Return ``(rgb,)``: a view of ``images`` keeping channels 0..2."""
        rgb = images[:, :, :, :3]
        return (rgb,)
|
|
|
|
|
|
# Registry of the node classes defined in this module, keyed by class name.
# NOTE: keys are derived from ``__name__``, so renaming a class also changes
# its registered key.
NODE_CLASS_MAPPINGS = {
    node_cls.__name__: node_cls
    for node_cls in (
        AlphaChanelAdd,
        AlphaChanelAddByMask,
        AlphaChanelAsMask,
        AlphaChanelRestore,
        AlphaChanelRemove,
    )
}
|
|
|