|
import re |
|
|
|
import torch |
|
|
|
|
|
class ImageBatchGet:
    """Select a single image out of an IMAGE batch by its 1-based position."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        # 1-based position widget; the UI will not go below 1.
        index_widget = {"default": 1, "min": 1, "step": 1}
        required = {
            "images": ("IMAGE",),
            "index": ("INT", index_widget),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, index):
        """Return the image at 1-based ``index`` as a batch of one.

        An ``index`` larger than the batch size selects the last image.
        """
        batch_size = images.shape[0]
        # Clamp to the batch size first, then shift from 1-based to 0-based.
        zero_based = min(index, batch_size) - 1
        picked = images[zero_based]
        # Indexing dropped the batch axis; put it back.
        return (picked.unsqueeze(0),)
|
|
|
|
|
class ImageBatchCopy:
    """Build a new batch made of ``quantity`` copies of one image."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {
                    "default": 1,
                    "min": 1,
                    "step": 1
                }),
                # "min" must not exceed "default", otherwise the widget's
                # own default value is rejected as out of range. A minimum
                # of 1 keeps every previously accepted value valid.
                "quantity": ("INT", {
                    "default": 1,
                    "min": 1,
                    "step": 1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, index, quantity):
        """Return ``quantity`` copies of the image at 1-based ``index``.

        An ``index`` larger than the batch size selects the last image.
        Returns a 1-tuple holding a batch of ``quantity`` images.
        """
        batch = images.shape[0]
        index = min(batch, index) - 1
        # images[index] has lost the batch axis; passing repeat() one more
        # factor than the tensor has dims prepends a new leading axis of
        # size ``quantity``.
        return (images[index].repeat(quantity, 1, 1, 1),)
|
|
|
|
|
class ImageBatchRemove:
    """Drop one image, chosen by 1-based position, from an IMAGE batch."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "index": ("INT", {
                    "default": 1,
                    "min": 1,
                    "step": 1
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, index):
        """Return the batch without the image at 1-based ``index``.

        An ``index`` larger than the batch size removes the last image,
        matching ImageBatchGet / ImageBatchCopy. Removing from a batch of
        one yields an empty batch.
        """
        batch = images.shape[0]
        # Clamp before converting to 0-based. The previous
        # min(batch, index - 1) produced ``batch`` for out-of-range indices,
        # which selected nothing and silently removed no image.
        index = min(batch, index) - 1
        return (torch.cat((images[:index], images[index + 1:]), dim=0),)
|
|
|
|
|
class ImageBatchFork:
    """Split an IMAGE batch into two halves.

    For an odd-sized batch, ``priority`` decides which half receives the
    extra image. A batch of one is passed through to both outputs.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        required = {
            "images": ("IMAGE",),
            "priority": (["first", "second"],),
        }
        return {"required": required}

    RETURN_TYPES = ("IMAGE", "IMAGE")
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, priority):
        """Return ``(head, tail)`` batches whose sizes sum to the input size."""
        count = images.shape[0]
        if count == 1:
            # Nothing to split; both outputs get the single image.
            return images, images

        half, odd = divmod(count, 2)
        head_size = tail_size = half
        if odd:
            # Exactly one half has to absorb the leftover image.
            if priority == "first":
                head_size += 1
            elif priority == "second":
                tail_size += 1
            else:
                raise ValueError("Not existing priority.")

        return images[:head_size], images[-tail_size:]
|
|
|
|
|
class ImageBatchJoin:
    """Concatenate two IMAGE batches along the batch axis."""

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images_a": ("IMAGE",),
                "images_b": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images_a, images_b):
        """Return ``images_a`` followed by ``images_b`` as one batch.

        Raises:
            ValueError: if height, width or channel count differ.
        """
        # Read per-image dims from shape[1:] instead of images[0].shape so
        # that an empty batch does not raise IndexError.
        height_a, width_a, channels_a = images_a.shape[1:]
        height_b, width_b, channels_b = images_b.shape[1:]

        if height_a != height_b:
            raise ValueError("Height of images_a does not match images_b. You can use ImageTransformResize to fix it.")

        if width_a != width_b:
            raise ValueError("Width of images_a does not match images_b. You can use ImageTransformResize to fix it.")

        if channels_a != channels_b:
            raise ValueError("Channels of images_a do not match images_b. You can add or delete alpha channels with AlphaChanel module.")

        return (torch.cat((images_a, images_b), dim=0),)
|
|
|
|
|
class ImageBatchPermute:
    """Reorder (and optionally repeat or drop) images in an IMAGE batch.

    ``permute`` is free text; every run of digits in it is read as one
    position. Any other characters, including minus signs, act only as
    separators.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "permute": ("STRING", {"multiline": False}),
                "start_with_zero": ("BOOLEAN",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "node"
    CATEGORY = "image/batch"

    def node(self, images, permute, start_with_zero):
        """Return the images in the order listed by ``permute``.

        Positions are 1-based unless ``start_with_zero`` is true. Positions
        outside the batch are clamped to the first/last image. If
        ``permute`` contains no numbers, or the batch is empty, ``images``
        is returned unchanged.
        """
        offset = 0 if start_with_zero else 1
        order = [int(num) - offset for num in re.findall(r'\d+', permute)]

        # torch.tensor([]) would infer float32, which index_select rejects;
        # treat "no positions given" as the identity ordering instead.
        # An empty batch has no valid index to clamp to either.
        if not order or images.shape[0] == 0:
            return (images,)

        # index_select needs an integer index tensor on the same device as
        # the images.
        index = torch.tensor(order, dtype=torch.long, device=images.device)
        index = index.clamp(0, images.shape[0] - 1)

        return (images.index_select(0, index),)
|
|
|
|
|
# Registry consumed by the node host: maps each public node name to its class.
NODE_CLASS_MAPPINGS = dict(
    ImageBatchGet=ImageBatchGet,
    ImageBatchCopy=ImageBatchCopy,
    ImageBatchRemove=ImageBatchRemove,
    ImageBatchFork=ImageBatchFork,
    ImageBatchJoin=ImageBatchJoin,
    ImageBatchPermute=ImageBatchPermute,
)
|
|