# 3v324v23's picture
# lfs
# 1e3b872
import re
import torch
class ImageBatchGet:
    """Node that picks one image out of a batch by its 1-based position."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification: an image batch and a 1-based index."""
        index_options = {
            "default": 1,
            "min": 1,
            "step": 1,
        }
        required = {
            "images": ("IMAGE",),
            "index": ("INT", index_options),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, index):
        """Return the selected image as a batch of one.

        Args:
            images: Tensor whose first dimension is the batch dimension
                (presumably (batch, height, width, channels) -- confirm
                against the caller).
            index: 1-based position of the image to extract. Values larger
                than the batch size select the last image.

        Returns:
            A 1-tuple holding the selected image with a leading batch
            dimension of size 1.
        """
        count = images.shape[0]
        # Clamp to the batch size first, then shift to 0-based indexing.
        position = min(count, index) - 1
        picked = images[position]
        return (picked.unsqueeze(0),)
class ImageBatchCopy:
    """Node that builds a batch made of repeated copies of one image."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification for the node.

        The ``quantity`` minimum is 1 so that the declared default (1) lies
        inside the allowed range; the previous minimum of 2 contradicted
        the default.
        """
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {
                    "default": 1,
                    "min": 1,
                    "step": 1,
                }),
                "quantity": ("INT", {
                    "default": 1,
                    "min": 1,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, index, quantity):
        """Return ``quantity`` copies of the selected image as one batch.

        Args:
            images: Tensor whose first dimension is the batch dimension.
                Each image is assumed to be 3-dimensional so that the
                4-argument ``repeat`` below yields a 4-dimensional batch
                -- confirm against the caller.
            index: 1-based position of the image to copy. Values larger
                than the batch size select the last image.
            quantity: Number of copies to place in the output batch.

        Returns:
            A 1-tuple holding the batch of repeated images.
        """
        batch = images.shape[0]
        # Clamp to the batch size first, then shift to 0-based indexing.
        index = min(batch, index) - 1
        # repeat() with one more size than the tensor has dimensions
        # prepends a new leading dimension of length `quantity`.
        return (images[index].repeat(quantity, 1, 1, 1),)
class ImageBatchRemove:
    """Node that drops one image from a batch by its 1-based position."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification: an image batch and a 1-based index."""
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {
                    "default": 1,
                    "min": 1,
                    "step": 1,
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, index):
        """Return the batch with the image at ``index`` removed.

        Args:
            images: Tensor whose first dimension is the batch dimension.
            index: 1-based position of the image to remove. Values larger
                than the batch size remove the last image; values below 1
                remove the first image.

        Returns:
            A 1-tuple holding the remaining images concatenated along the
            batch dimension. Removing the only image of a batch yields an
            empty batch.
        """
        batch = images.shape[0]
        # Clamp the 1-based index into [1, batch] before converting to
        # 0-based. Clamping after the subtraction (as before) allowed the
        # out-of-range position `batch`, which removed nothing, and a
        # negative position, which duplicated images through the slices.
        index = max(1, min(batch, index)) - 1
        return (torch.cat((images[:index], images[index + 1:]), dim=0),)
class ImageBatchFork:
    """Node that splits a batch into a leading part and a trailing part."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification: an image batch and a priority choice."""
        required = {
            "images": ("IMAGE",),
            "priority": (["first", "second"],),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE", "IMAGE")
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, priority):
        """Split ``images`` along the batch dimension into two outputs.

        A batch of one is returned unchanged on both outputs. An even batch
        is split into equal halves. For an odd batch, ``priority`` names
        the output that receives the extra image.

        Args:
            images: Tensor whose first dimension is the batch dimension.
            priority: ``"first"`` or ``"second"``; only consulted when the
                batch size is odd and greater than one.

        Returns:
            A pair ``(first_part, second_part)``.

        Raises:
            ValueError: If the batch size is odd (and not 1) and
                ``priority`` is neither ``"first"`` nor ``"second"``.
        """
        total = images.shape[0]
        if total == 1:
            return images, images

        half, leftover = divmod(total, 2)
        if leftover == 0:
            head_len = tail_len = half
        elif priority == "first":
            head_len, tail_len = half + 1, half
        elif priority == "second":
            head_len, tail_len = half, half + 1
        else:
            raise ValueError("Not existing priority.")

        return images[:head_len], images[-tail_len:]
class ImageBatchJoin:
    """Node that concatenates two image batches of matching image size."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification: two image batches."""
        required = {
            "images_a": ("IMAGE",),
            "images_b": ("IMAGE",),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images_a, images_b):
        """Concatenate ``images_b`` after ``images_a`` along the first dimension.

        The first image of each batch must have the same height, width and
        channel count; the checks run in that order and the first mismatch
        is reported.

        Args:
            images_a: First batch; each image is unpacked as
                (height, width, channels).
            images_b: Second batch with the same per-image layout.

        Returns:
            A 1-tuple holding the concatenated batch.

        Raises:
            ValueError: If height, width or channel count differ.
        """
        height_a, width_a, channels_a = images_a[0].shape
        height_b, width_b, channels_b = images_b[0].shape

        # (value from a, value from b, message raised on mismatch)
        checks = (
            (
                height_a,
                height_b,
                "Height of images_a not equals of images_b. You can use ImageTransformResize for fix it.",
            ),
            (
                width_a,
                width_b,
                "Width of images_a not equals of images_b. You can use ImageTransformResize for fix it.",
            ),
            (
                channels_a,
                channels_b,
                "Channels of images_a not equals of images_b. Your can add or delete alpha channels with AlphaChanel module.",
            ),
        )
        for size_a, size_b, message in checks:
            if size_a != size_b:
                raise ValueError(message)

        return (torch.cat((images_a, images_b), dim=0),)
class ImageBatchPermute:
    """Node that reorders (and may repeat or drop) images of a batch."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification for the node."""
        return {
            "required": {
                "images": ("IMAGE",),
                "permute": ("STRING", {"multiline": False}),
                "start_with_zero": ("BOOLEAN",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, permute, start_with_zero):
        """Select images from the batch in the order listed in ``permute``.

        Every run of digits in ``permute`` is read as one position; any
        other characters act as separators. Positions outside the batch
        are clamped to the first or last image.

        Args:
            images: Tensor whose first dimension is the batch dimension.
            permute: Text containing the desired positions, e.g. "3, 1, 2".
            start_with_zero: When true the positions are 0-based, otherwise
                they are 1-based.

        Returns:
            A 1-tuple holding the selected images. If ``permute`` contains
            no digits the input batch is returned unchanged.
        """
        offset = 0 if start_with_zero else 1
        order = [int(num) - offset for num in re.findall(r"\d+", permute)]

        # Without any positions there is nothing to reorder. Building a
        # tensor from an empty list would also produce a float tensor,
        # which index_select rejects as an index.
        if not order:
            return (images,)

        # index_select needs an integer index located on the same device
        # as the tensor being indexed.
        order = torch.tensor(order, dtype=torch.long, device=images.device)
        order = order.clamp(0, images.shape[0] - 1)
        return (images.index_select(0, order),)
# Maps each node's registered name to its implementing class. Every node is
# registered under its own class name, so the key is taken from `__name__`;
# note that renaming a class would therefore also change its key.
NODE_CLASS_MAPPINGS = {
    node_class.__name__: node_class
    for node_class in (
        ImageBatchGet,
        ImageBatchCopy,
        ImageBatchRemove,
        ImageBatchFork,
        ImageBatchJoin,
        ImageBatchPermute,
    )
}