"""Batched VAE decode/encode nodes for the Video Helper Suite.

Both nodes run the VAE over the input in chunks of ``per_batch`` frames so
long image sequences can be processed without exhausting VRAM.
"""
import torch

from nodes import VAEEncode

from comfy.utils import ProgressBar


class VAEDecodeBatched:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1})
            }
        }

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    def decode(self, vae, samples, per_batch):
        # Decode the latents in chunks of per_batch frames so the whole
        # sequence never has to pass through the VAE at once.
        decoded = []
        pbar = ProgressBar(samples["samples"].shape[0])
        for start_idx in range(0, samples["samples"].shape[0], per_batch):
            decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch]))
            pbar.update(per_batch)
        return (torch.cat(decoded, dim=0), )


class VAEEncodeBatched:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pixels": ("IMAGE", ), "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1})
            }
        }

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    def encode(self, vae, pixels, per_batch):
        t = []
        pbar = ProgressBar(pixels.shape[0])
        for start_idx in range(0, pixels.shape[0], per_batch):
            # Prefer the crop-to-valid-size helper on the VAE object when it is
            # available; fall back to the VAEEncode node's static helper on
            # ComfyUI versions where it is not.
            try:
                sub_pixels = vae.vae_encode_crop_pixels(pixels[start_idx:start_idx+per_batch])
            except:
                sub_pixels = VAEEncode.vae_encode_crop_pixels(pixels[start_idx:start_idx+per_batch])
            # Keep only the RGB channels (drop alpha) before encoding.
            t.append(vae.encode(sub_pixels[:, :, :, :3]))
            pbar.update(per_batch)
        return ({"samples": torch.cat(t, dim=0)}, )
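
# Registration sketch, assuming this module is loaded directly as a custom node
# package; the suite may instead register these classes from its central
# nodes.py. The mapping keys and display names here are illustrative.
NODE_CLASS_MAPPINGS = {
    "VHS_VAEDecodeBatched": VAEDecodeBatched,
    "VHS_VAEEncodeBatched": VAEEncodeBatched,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "VHS_VAEDecodeBatched": "VAE Decode Batched",
    "VHS_VAEEncodeBatched": "VAE Encode Batched",
}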