import torch
from nodes import VAEEncode
from comfy.utils import ProgressBar
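
# Batched wrappers around ComfyUI's core VAE operations: long latent/image
# sequences are processed per_batch items at a time so peak VRAM use stays
# bounded regardless of sequence length.
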
class VAEDecodeBatched:
    """Decode latents through the VAE in fixed-size chunks to limit peak memory."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1})
            }
        }

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    def decode(self, vae, samples, per_batch):
        decoded = []
        pbar = ProgressBar(samples["samples"].shape[0])
        # Decode at most per_batch latents at a time, then stitch the results.
        for start_idx in range(0, samples["samples"].shape[0], per_batch):
            decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch]))
            pbar.update(per_batch)
        return (torch.cat(decoded, dim=0), )

class VAEEncodeBatched:
    """Encode images through the VAE in fixed-size chunks to limit peak memory."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pixels": ("IMAGE", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1})
            }
        }

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    def encode(self, vae, pixels, per_batch):
        t = []
        pbar = ProgressBar(pixels.shape[0])
        for start_idx in range(0, pixels.shape[0], per_batch):
            # Crop pixels to VAE-compatible dimensions; newer ComfyUI exposes
            # this helper on the VAE object, older versions only on the
            # VAEEncode node, so fall back on AttributeError.
            try:
                sub_pixels = vae.vae_encode_crop_pixels(pixels[start_idx:start_idx+per_batch])
            except AttributeError:
                sub_pixels = VAEEncode.vae_encode_crop_pixels(pixels[start_idx:start_idx+per_batch])
            # Drop any alpha channel before encoding.
            t.append(vae.encode(sub_pixels[:,:,:,:3]))
            pbar.update(per_batch)
        return ({"samples": torch.cat(t, dim=0)}, )