from concurrent.futures import ThreadPoolExecutor
import glob
import json
import math
import os
import random
import time
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
from safetensors.torch import save_file, load_file
from safetensors import safe_open
from PIL import Image
import cv2
import av

from utils import safetensors_utils
from utils.model_utils import dtype_to_str

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

try:
    import pillow_avif

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
    pass

# JPEG-XL on Linux
try:
    from jxlpy import JXLImagePlugin

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass

# JPEG-XL on Windows
try:
    import pillow_jxl

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass

VIDEO_EXTENSIONS = [".mp4", ".avi", ".mov", ".webm", ".MP4", ".AVI", ".MOV", ".WEBM"]  # some of these are not tested

ARCHITECTURE_HUNYUAN_VIDEO = "hv"


def glob_images(directory, base="*"):
    img_paths = []
    for ext in IMAGE_EXTENSIONS:
        if base == "*":
            img_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
        else:
            img_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
    img_paths = list(set(img_paths))  # remove duplicates
    img_paths.sort()
    return img_paths


def glob_videos(directory, base="*"):
    video_paths = []
    for ext in VIDEO_EXTENSIONS:
        if base == "*":
            video_paths.extend(glob.glob(os.path.join(glob.escape(directory), base + ext)))
        else:
            video_paths.extend(glob.glob(glob.escape(os.path.join(directory, base + ext))))
    video_paths = list(set(video_paths))  # remove duplicates
    video_paths.sort()
    return video_paths


def divisible_by(num: int, divisor: int) -> int:
    return num - num % divisor
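
# Hedged usage sketch (illustrative only, not called anywhere in this module): the
# bucketing code below uses divisible_by to round sizes down to a multiple of the
# resolution step.
def _example_divisible_by() -> None:
    assert divisible_by(1085, 16) == 1072  # 1085 - (1085 % 16)
    assert divisible_by(544, 16) == 544  # already a multiple, unchanged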
""" is_pil_image = isinstance(image, Image.Image) if is_pil_image: image_width, image_height = image.size else: image_height, image_width = image.shape[:2] if bucket_reso == (image_width, image_height): return np.array(image) if is_pil_image else image bucket_width, bucket_height = bucket_reso if bucket_width == image_width or bucket_height == image_height: image = np.array(image) if is_pil_image else image else: # resize the image to the bucket resolution to match the short side scale_width = bucket_width / image_width scale_height = bucket_height / image_height scale = max(scale_width, scale_height) image_width = int(image_width * scale + 0.5) image_height = int(image_height * scale + 0.5) if scale > 1: image = Image.fromarray(image) if not is_pil_image else image image = image.resize((image_width, image_height), Image.LANCZOS) image = np.array(image) else: image = np.array(image) if is_pil_image else image image = cv2.resize(image, (image_width, image_height), interpolation=cv2.INTER_AREA) # crop the image to the bucket resolution crop_left = (image_width - bucket_width) // 2 crop_top = (image_height - bucket_height) // 2 image = image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width] return image class ItemInfo: def __init__( self, item_key: str, caption: str, original_size: tuple[int, int], bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None, frame_count: Optional[int] = None, content: Optional[np.ndarray] = None, latent_cache_path: Optional[str] = None, ) -> None: self.item_key = item_key self.caption = caption self.original_size = original_size self.bucket_size = bucket_size self.frame_count = frame_count self.content = content self.latent_cache_path = latent_cache_path self.text_encoder_output_cache_path: Optional[str] = None def __str__(self) -> str: return ( f"ItemInfo(item_key={self.item_key}, caption={self.caption}, " + f"original_size={self.original_size}, bucket_size={self.bucket_size}, " + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path})" ) def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor): assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)" metadata = { "architecture": "hunyuan_video", "width": f"{item_info.original_size[0]}", "height": f"{item_info.original_size[1]}", "format_version": "1.0.0", } if item_info.frame_count is not None: metadata["frame_count"] = f"{item_info.frame_count}" _, F, H, W = latent.shape dtype_str = dtype_to_str(latent.dtype) sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()} latent_dir = os.path.dirname(item_info.latent_cache_path) os.makedirs(latent_dir, exist_ok=True) save_file(sd, item_info.latent_cache_path, metadata=metadata) def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool): assert ( embed.dim() == 1 or embed.dim() == 2 ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}" assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}" metadata = { "architecture": "hunyuan_video", "caption1": item_info.caption, "format_version": "1.0.0", } sd = {} if os.path.exists(item_info.text_encoder_output_cache_path): # load existing cache and update metadata with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f: existing_metadata = f.metadata() for key in f.keys(): sd[key] = f.get_tensor(key) assert existing_metadata["architecture"] == 
metadata["architecture"], "architecture mismatch" if existing_metadata["caption1"] != metadata["caption1"]: logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite") # TODO verify format_version existing_metadata.pop("caption1", None) existing_metadata.pop("format_version", None) metadata.update(existing_metadata) # copy existing metadata else: text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path) os.makedirs(text_encoder_output_dir, exist_ok=True) dtype_str = dtype_to_str(embed.dtype) text_encoder_type = "llm" if is_llm else "clipL" sd[f"{text_encoder_type}_{dtype_str}"] = embed.detach().cpu() if mask is not None: sd[f"{text_encoder_type}_mask"] = mask.detach().cpu() safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata) class BucketSelector: RESOLUTION_STEPS_HUNYUAN = 16 def __init__(self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False): self.resolution = resolution self.bucket_area = resolution[0] * resolution[1] self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN if not enable_bucket: # only define one bucket self.bucket_resolutions = [resolution] self.no_upscale = False else: # prepare bucket resolution self.no_upscale = no_upscale sqrt_size = int(math.sqrt(self.bucket_area)) min_size = divisible_by(sqrt_size // 2, self.reso_steps) self.bucket_resolutions = [] for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps): h = divisible_by(self.bucket_area // w, self.reso_steps) self.bucket_resolutions.append((w, h)) self.bucket_resolutions.append((h, w)) self.bucket_resolutions = list(set(self.bucket_resolutions)) self.bucket_resolutions.sort() # calculate aspect ratio to find the nearest resolution self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions]) def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]: """ return the bucket resolution for the given image size, (width, height) """ area = image_size[0] * image_size[1] if self.no_upscale and area <= self.bucket_area: w, h = image_size w = divisible_by(w, self.reso_steps) h = divisible_by(h, self.reso_steps) return w, h aspect_ratio = image_size[0] / image_size[1] ar_errors = self.aspect_ratios - aspect_ratio bucket_id = np.abs(ar_errors).argmin() return self.bucket_resolutions[bucket_id] def load_video( video_path: str, start_frame: Optional[int] = None, end_frame: Optional[int] = None, bucket_selector: Optional[BucketSelector] = None, ) -> list[np.ndarray]: container = av.open(video_path) video = [] bucket_reso = None for i, frame in enumerate(container.decode(video=0)): if start_frame is not None and i < start_frame: continue if end_frame is not None and i >= end_frame: break frame = frame.to_image() if bucket_selector is not None and bucket_reso is None: bucket_reso = bucket_selector.get_bucket_resolution(frame.size) if bucket_reso is not None: frame = resize_image_to_bucket(frame, bucket_reso) else: frame = np.array(frame) video.append(frame) container.close() return video class BucketBatchManager: def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int): self.batch_size = batch_size self.buckets = bucketed_item_info self.bucket_resos = list(self.buckets.keys()) self.bucket_resos.sort() self.bucket_batch_indices = [] for bucket_reso in self.bucket_resos: bucket = self.buckets[bucket_reso] num_batches = math.ceil(len(bucket) / self.batch_size) for i in 

class BucketBatchManager:
    def __init__(self, bucketed_item_info: dict[tuple[int, int], list[ItemInfo]], batch_size: int):
        self.batch_size = batch_size
        self.buckets = bucketed_item_info
        self.bucket_resos = list(self.buckets.keys())
        self.bucket_resos.sort()

        self.bucket_batch_indices = []
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            num_batches = math.ceil(len(bucket) / self.batch_size)
            for i in range(num_batches):
                self.bucket_batch_indices.append((bucket_reso, i))

        self.shuffle()

    def show_bucket_info(self):
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")

        logger.info(f"total batches: {len(self)}")

    def shuffle(self):
        for bucket in self.buckets.values():
            random.shuffle(bucket)
        random.shuffle(self.bucket_batch_indices)

    def __len__(self):
        return len(self.bucket_batch_indices)

    def __getitem__(self, idx):
        bucket_reso, batch_idx = self.bucket_batch_indices[idx]
        bucket = self.buckets[bucket_reso]
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, len(bucket))

        latents = []
        llm_embeds = []
        llm_masks = []
        clip_l_embeds = []
        for item_info in bucket[start:end]:
            sd = load_file(item_info.latent_cache_path)
            latent = None
            for key in sd.keys():
                if key.startswith("latents_"):
                    latent = sd[key]
                    break
            latents.append(latent)

            sd = load_file(item_info.text_encoder_output_cache_path)
            llm_embed = llm_mask = clip_l_embed = None
            for key in sd.keys():
                if key.startswith("llm_mask"):
                    llm_mask = sd[key]
                elif key.startswith("llm_"):
                    llm_embed = sd[key]
                elif key.startswith("clipL_mask"):
                    pass
                elif key.startswith("clipL_"):
                    clip_l_embed = sd[key]
            llm_embeds.append(llm_embed)
            llm_masks.append(llm_mask)
            clip_l_embeds.append(clip_l_embed)

        latents = torch.stack(latents)
        llm_embeds = torch.stack(llm_embeds)
        llm_masks = torch.stack(llm_masks)
        clip_l_embeds = torch.stack(clip_l_embeds)

        return latents, llm_embeds, llm_masks, clip_l_embeds


class ContentDatasource:
    def __init__(self):
        self.caption_only = False

    def set_caption_only(self, caption_only: bool):
        self.caption_only = caption_only

    def is_indexable(self):
        return False

    def get_caption(self, idx: int) -> tuple[str, str]:
        """
        Returns caption. May not be called if is_indexable() returns False.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError


class ImageDatasource(ContentDatasource):
    def __init__(self):
        super().__init__()

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """
        Returns image data as a tuple of image path, image, and caption for the given index.
        Key must be unique and valid as a file name.
        May not be called if is_indexable() returns False.
        """
        raise NotImplementedError


class ImageDirectoryDatasource(ImageDatasource):
    def __init__(self, image_directory: str, caption_extension: Optional[str] = None):
        super().__init__()
        self.image_directory = image_directory
        self.caption_extension = caption_extension
        self.current_idx = 0

        # glob images
        logger.info(f"glob images in {self.image_directory}")
        self.image_paths = glob_images(self.image_directory)
        logger.info(f"found {len(self.image_paths)} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.image_paths)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")

        _, caption = self.get_caption(idx)

        return image_path, image, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        image_path = self.image_paths[idx]
        caption_path = (os.path.splitext(image_path)[0] + self.caption_extension) if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        """
        Returns a fetcher function that returns image data.
        """
        if self.current_idx >= len(self.image_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_image_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_image_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
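
# Hedged iteration sketch: datasources yield zero-argument fetchers rather than
# decoded items, so callers can hand the actual decoding to a thread pool (as the
# caching methods further below do). "/path/to/images" is a placeholder.
def _example_iterate_image_directory() -> None:
    datasource = ImageDirectoryDatasource("/path/to/images", caption_extension=".txt")
    for fetch_op in datasource:
        image_path, image, caption = fetch_op()  # the image is decoded here, lazily
        logger.info(f"{image_path}: size={image.size}")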
""" if self.current_idx >= len(self.image_paths): raise StopIteration if self.caption_only: def create_caption_fetcher(index): return lambda: self.get_caption(index) fetcher = create_caption_fetcher(self.current_idx) else: def create_image_fetcher(index): return lambda: self.get_image_data(index) fetcher = create_image_fetcher(self.current_idx) self.current_idx += 1 return fetcher class ImageJsonlDatasource(ImageDatasource): def __init__(self, image_jsonl_file: str): super().__init__() self.image_jsonl_file = image_jsonl_file self.current_idx = 0 # load jsonl logger.info(f"load image jsonl from {self.image_jsonl_file}") self.data = [] with open(self.image_jsonl_file, "r", encoding="utf-8") as f: for line in f: data = json.loads(line) self.data.append(data) logger.info(f"loaded {len(self.data)} images") def is_indexable(self): return True def __len__(self): return len(self.data) def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]: data = self.data[idx] image_path = data["image_path"] image = Image.open(image_path).convert("RGB") caption = data["caption"] return image_path, image, caption def get_caption(self, idx: int) -> tuple[str, str]: data = self.data[idx] image_path = data["image_path"] caption = data["caption"] return image_path, caption def __iter__(self): self.current_idx = 0 return self def __next__(self) -> callable: if self.current_idx >= len(self.data): raise StopIteration if self.caption_only: def create_caption_fetcher(index): return lambda: self.get_caption(index) fetcher = create_caption_fetcher(self.current_idx) else: def create_fetcher(index): return lambda: self.get_image_data(index) fetcher = create_fetcher(self.current_idx) self.current_idx += 1 return fetcher class VideoDatasource(ContentDatasource): def __init__(self): super().__init__() # None means all frames self.start_frame = None self.end_frame = None self.bucket_selector = None def __len__(self): raise NotImplementedError def get_video_data_from_path( self, video_path: str, start_frame: Optional[int] = None, end_frame: Optional[int] = None, bucket_selector: Optional[BucketSelector] = None, ) -> tuple[str, list[Image.Image], str]: # this method can resize the video if bucket_selector is given to reduce the memory usage start_frame = start_frame if start_frame is not None else self.start_frame end_frame = end_frame if end_frame is not None else self.end_frame bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector video = load_video(video_path, start_frame, end_frame, bucket_selector) return video def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]): self.start_frame = start_frame self.end_frame = end_frame def set_bucket_selector(self, bucket_selector: BucketSelector): self.bucket_selector = bucket_selector def __iter__(self): raise NotImplementedError def __next__(self): raise NotImplementedError class VideoDirectoryDatasource(VideoDatasource): def __init__(self, video_directory: str, caption_extension: Optional[str] = None): super().__init__() self.video_directory = video_directory self.caption_extension = caption_extension self.current_idx = 0 # glob images logger.info(f"glob images in {self.video_directory}") self.video_paths = glob_videos(self.video_directory) logger.info(f"found {len(self.video_paths)} videos") def is_indexable(self): return True def __len__(self): return len(self.video_paths) def get_video_data( self, idx: int, start_frame: Optional[int] = None, end_frame: Optional[int] = None, bucket_selector: 

class VideoDirectoryDatasource(VideoDatasource):
    def __init__(self, video_directory: str, caption_extension: Optional[str] = None):
        super().__init__()
        self.video_directory = video_directory
        self.caption_extension = caption_extension
        self.current_idx = 0

        # glob videos
        logger.info(f"glob videos in {self.video_directory}")
        self.video_paths = glob_videos(self.video_directory)
        logger.info(f"found {len(self.video_paths)} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.video_paths)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[np.ndarray], str]:
        video_path = self.video_paths[idx]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        _, caption = self.get_caption(idx)

        return video_path, video, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        video_path = self.video_paths[idx]
        caption_path = (os.path.splitext(video_path)[0] + self.caption_extension) if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        if self.current_idx >= len(self.video_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher


class VideoJsonlDatasource(VideoDatasource):
    def __init__(self, video_jsonl_file: str):
        super().__init__()
        self.video_jsonl_file = video_jsonl_file
        self.current_idx = 0

        # load jsonl
        logger.info(f"load video jsonl from {self.video_jsonl_file}")
        self.data = []
        with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[np.ndarray], str]:
        data = self.data[idx]
        video_path = data["video_path"]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        caption = data["caption"]

        return video_path, video, caption

    def get_caption(self, idx: int) -> tuple[str, str]:
        data = self.data[idx]
        video_path = data["video_path"]
        caption = data["caption"]
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
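
# Hedged usage sketch: bounding the decode window and resizing at decode time keeps
# memory usage manageable for long clips. "videos.jsonl" is a placeholder file with
# one {"video_path": ..., "caption": ...} object per line.
def _example_video_decode_window() -> None:
    ds = VideoJsonlDatasource("videos.jsonl")
    ds.set_start_and_end_frame(0, 73)  # decode at most the first 73 frames
    ds.set_bucket_selector(BucketSelector((960, 544)))  # resize frames on decode
    video_path, frames, caption = ds.get_video_data(0)
    logger.info(f"{video_path}: {len(frames)} frames, first frame shape={frames[0].shape}")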

class BaseDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        resolution: Tuple[int, int] = (960, 544),
        caption_extension: Optional[str] = None,
        batch_size: int = 1,
        enable_bucket: bool = False,
        bucket_no_upscale: bool = False,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        self.resolution = resolution
        self.caption_extension = caption_extension
        self.batch_size = batch_size
        self.enable_bucket = enable_bucket
        self.bucket_no_upscale = bucket_no_upscale
        self.cache_directory = cache_directory
        self.debug_dataset = debug_dataset
        self.seed = None
        self.current_epoch = 0

        if not self.enable_bucket:
            self.bucket_no_upscale = False

    def get_metadata(self) -> dict:
        metadata = {
            "resolution": self.resolution,
            "caption_extension": self.caption_extension,
            "batch_size_per_device": self.batch_size,
            "enable_bucket": bool(self.enable_bucket),
            "bucket_no_upscale": bool(self.bucket_no_upscale),
        }
        return metadata

    def get_latent_cache_path(self, item_info: ItemInfo) -> str:
        w, h = item_info.original_size
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")

    def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors")

    def retrieve_latent_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def prepare_for_training(self):
        pass

    def set_seed(self, seed: int):
        self.seed = seed

    def set_current_epoch(self, epoch):
        if not self.current_epoch == epoch:  # shuffle buckets when epoch is incremented
            if epoch > self.current_epoch:
                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                num_epochs = epoch - self.current_epoch
                for _ in range(num_epochs):
                    self.current_epoch += 1
                    self.shuffle_buckets()
                # self.current_epoch seems to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
            else:
                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def set_max_train_steps(self, max_train_steps):
        self.max_train_steps = max_train_steps

    def shuffle_buckets(self):
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
        datasource.set_caption_only(True)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        data: list[ItemInfo] = []
        futures = []

        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    item_key, caption = future.result()
                    item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
                    item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
                    data.append(item_info)
                    futures.remove(future)

        def submit_batch(flush: bool = False):
            nonlocal data
            if len(data) >= batch_size or (len(data) > 0 and flush):
                batch = data[0:batch_size]
                if len(data) > batch_size:
                    data = data[batch_size:]
                else:
                    data = []
                return batch
            return None

        for fetch_op in datasource:
            future = executor.submit(fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                batch = submit_batch()
                if batch is None:
                    break
                yield batch

        aggregate_future(consume_all=True)
        while True:
            batch = submit_batch(flush=True)
            if batch is None:
                break
            yield batch

        executor.shutdown()
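
# Hedged illustration of the cache naming scheme implemented above: for an item
# "/data/img.png" with original size 1920x1080 and cache_directory "/cache", the
# latent cache is "/cache/img_1920x1080_hv.safetensors" and the text encoder output
# cache is "/cache/img_hv_te.safetensors". All paths here are made up.
def _example_cache_paths() -> None:
    dataset = ImageDataset((960, 544), ".txt", 1, False, False, image_directory="/data", cache_directory="/cache")
    info = ItemInfo("/data/img.png", "a caption", (1920, 1080))
    logger.info(dataset.get_latent_cache_path(info))
    logger.info(dataset.get_text_encoder_output_cache_path(info))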

class ImageDataset(BaseDataset):
    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        image_directory: Optional[str] = None,
        image_jsonl_file: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
    ):
        super(ImageDataset, self).__init__(
            resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset
        )
        self.image_directory = image_directory
        self.image_jsonl_file = image_jsonl_file
        if image_directory is not None:
            self.datasource = ImageDirectoryDatasource(image_directory, caption_extension)
        elif image_jsonl_file is not None:
            self.datasource = ImageJsonlDatasource(image_jsonl_file)
        else:
            raise ValueError("image_directory or image_jsonl_file must be specified")

        if self.cache_directory is None:
            self.cache_directory = self.image_directory

        self.batch_manager = None
        self.num_train_items = 0

    def get_metadata(self):
        metadata = super().get_metadata()
        if self.image_directory is not None:
            metadata["image_directory"] = os.path.basename(self.image_directory)
        if self.image_jsonl_file is not None:
            metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
        return metadata

    def get_total_image_count(self):
        return len(self.datasource) if self.datasource.is_indexable() else None

    def retrieve_latent_cache_batches(self, num_workers: int):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        batches: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        futures = []

        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    original_size, item_key, image, caption = future.result()
                    bucket_height, bucket_width = image.shape[:2]
                    bucket_reso = (bucket_width, bucket_height)
                    item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
                    item_info.latent_cache_path = self.get_latent_cache_path(item_info)
                    if bucket_reso not in batches:
                        batches[bucket_reso] = []
                    batches[bucket_reso].append(item_info)
                    futures.remove(future)

        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch
            return None, None

        for fetch_op in self.datasource:

            def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, np.ndarray, str]:
                image_key, image, caption = op()
                image: Image.Image

                image_size = image.size
                bucket_reso = bucket_selector.get_bucket_resolution(image_size)
                image = resize_image_to_bucket(image, bucket_reso)

                return image_size, image_key, image, caption

            future = executor.submit(fetch_and_resize, fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale)

        # glob cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors"))

        # assign cache files to item info
        bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        for cache_file in latent_cache_files:
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # 0000x0000
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            item_key = "_".join(tokens[:-2])
            text_encoder_output_cache_file = os.path.join(
                self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors"
            )
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            return 100  # dummy value
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]
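
# Hedged end-to-end sketch for images, assuming latent and text encoder caches have
# already been written by the caching pipeline: prepare_for_training() scans the
# cache directory and builds bucketed batches, and __getitem__ returns a full batch.
# "/data/images" is a placeholder.
def _example_image_training_setup() -> None:
    dataset = ImageDataset((960, 544), ".txt", 4, True, False, image_directory="/data/images")
    dataset.set_seed(42)
    dataset.prepare_for_training()  # scans *_hv.safetensors caches
    dataset.set_current_epoch(1)  # increments the epoch and shuffles buckets
    latents, llm_embeds, llm_masks, clip_l_embeds = dataset[0]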
f"*_{ARCHITECTURE_HUNYUAN_VIDEO}.safetensors")) # assign cache files to item info bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {} # (width, height) -> [ItemInfo] for cache_file in latent_cache_files: tokens = os.path.basename(cache_file).split("_") image_size = tokens[-2] # 0000x0000 image_width, image_height = map(int, image_size.split("x")) image_size = (image_width, image_height) item_key = "_".join(tokens[:-2]) text_encoder_output_cache_file = os.path.join( self.cache_directory, f"{item_key}_{ARCHITECTURE_HUNYUAN_VIDEO}_te.safetensors" ) if not os.path.exists(text_encoder_output_cache_file): logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}") continue bucket_reso = bucket_selector.get_bucket_resolution(image_size) item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file) item_info.text_encoder_output_cache_path = text_encoder_output_cache_file bucket = bucketed_item_info.get(bucket_reso, []) bucket.append(item_info) bucketed_item_info[bucket_reso] = bucket # prepare batch manager self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size) self.batch_manager.show_bucket_info() self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()]) def shuffle_buckets(self): # set random seed for this epoch random.seed(self.seed + self.current_epoch) self.batch_manager.shuffle() def __len__(self): if self.batch_manager is None: return 100 # dummy value return len(self.batch_manager) def __getitem__(self, idx): return self.batch_manager[idx] class VideoDataset(BaseDataset): def __init__( self, resolution: Tuple[int, int], caption_extension: Optional[str], batch_size: int, enable_bucket: bool, bucket_no_upscale: bool, frame_extraction: Optional[str] = "head", frame_stride: Optional[int] = 1, frame_sample: Optional[int] = 1, target_frames: Optional[list[int]] = None, video_directory: Optional[str] = None, video_jsonl_file: Optional[str] = None, cache_directory: Optional[str] = None, debug_dataset: bool = False, ): super(VideoDataset, self).__init__( resolution, caption_extension, batch_size, enable_bucket, bucket_no_upscale, cache_directory, debug_dataset ) self.video_directory = video_directory self.video_jsonl_file = video_jsonl_file self.target_frames = target_frames self.frame_extraction = frame_extraction self.frame_stride = frame_stride self.frame_sample = frame_sample if video_directory is not None: self.datasource = VideoDirectoryDatasource(video_directory, caption_extension) elif video_jsonl_file is not None: self.datasource = VideoJsonlDatasource(video_jsonl_file) if self.frame_extraction == "uniform" and self.frame_sample == 1: self.frame_extraction = "head" logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.") if self.frame_extraction == "head": # head extraction. 
we can limit the number of frames to be extracted self.datasource.set_start_and_end_frame(0, max(self.target_frames)) if self.cache_directory is None: self.cache_directory = self.video_directory self.batch_manager = None self.num_train_items = 0 def get_metadata(self): metadata = super().get_metadata() if self.video_directory is not None: metadata["video_directory"] = os.path.basename(self.video_directory) if self.video_jsonl_file is not None: metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file) metadata["frame_extraction"] = self.frame_extraction metadata["frame_stride"] = self.frame_stride metadata["frame_sample"] = self.frame_sample metadata["target_frames"] = self.target_frames return metadata def retrieve_latent_cache_batches(self, num_workers: int): buckset_selector = BucketSelector(self.resolution) self.datasource.set_bucket_selector(buckset_selector) executor = ThreadPoolExecutor(max_workers=num_workers) # key: (width, height, frame_count), value: [ItemInfo] batches: dict[tuple[int, int, int], list[ItemInfo]] = {} futures = [] def aggregate_future(consume_all: bool = False): while len(futures) >= num_workers or (consume_all and len(futures) > 0): completed_futures = [future for future in futures if future.done()] if len(completed_futures) == 0: if len(futures) >= num_workers or consume_all: # to avoid adding too many futures time.sleep(0.1) continue else: break # submit batch if possible for future in completed_futures: original_frame_size, video_key, video, caption = future.result() frame_count = len(video) video = np.stack(video, axis=0) height, width = video.shape[1:3] bucket_reso = (width, height) # already resized crop_pos_and_frames = [] if self.frame_extraction == "head": for target_frame in self.target_frames: if frame_count >= target_frame: crop_pos_and_frames.append((0, target_frame)) elif self.frame_extraction == "chunk": # split by target_frames for target_frame in self.target_frames: for i in range(0, frame_count, target_frame): if i + target_frame <= frame_count: crop_pos_and_frames.append((i, target_frame)) elif self.frame_extraction == "slide": # slide window for target_frame in self.target_frames: if frame_count >= target_frame: for i in range(0, frame_count - target_frame + 1, self.frame_stride): crop_pos_and_frames.append((i, target_frame)) elif self.frame_extraction == "uniform": # select N frames uniformly for target_frame in self.target_frames: if frame_count >= target_frame: frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int) for i in frame_indices: crop_pos_and_frames.append((i, target_frame)) else: raise ValueError(f"frame_extraction {self.frame_extraction} is not supported") for crop_pos, target_frame in crop_pos_and_frames: cropped_video = video[crop_pos : crop_pos + target_frame] body, ext = os.path.splitext(video_key) item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}" batch_key = (*bucket_reso, target_frame) # bucket_reso with frame_count item_info = ItemInfo( item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video ) item_info.latent_cache_path = self.get_latent_cache_path(item_info) batch = batches.get(batch_key, []) batch.append(item_info) batches[batch_key] = batch futures.remove(future) def submit_batch(flush: bool = False): for key in batches: if len(batches[key]) >= self.batch_size or flush: batch = batches[key][0 : self.batch_size] if len(batches[key]) > self.batch_size: batches[key] = batches[key][self.batch_size :] else: del batches[key] 
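
# Hedged illustration of the frame_extraction modes implemented in
# retrieve_latent_cache_batches above; this reproduces the crop-start logic for a
# single target_frame value with made-up defaults.
def _example_frame_extraction(frame_count: int = 100, target_frame: int = 25, frame_stride: int = 10, frame_sample: int = 3) -> None:
    head = [0] if frame_count >= target_frame else []
    chunk = [i for i in range(0, frame_count, target_frame) if i + target_frame <= frame_count]
    slide = list(range(0, frame_count - target_frame + 1, frame_stride))
    uniform = np.linspace(0, frame_count - target_frame, frame_sample, dtype=int).tolist()
    logger.info(f"head={head}, chunk={chunk}, slide={slide}, uniform={uniform}")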

class DatasetGroup(torch.utils.data.ConcatDataset):
    def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
        super().__init__(datasets)
        self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
        self.num_train_items = 0
        for dataset in self.datasets:
            self.num_train_items += dataset.num_train_items

    def set_current_epoch(self, epoch):
        for dataset in self.datasets:
            dataset.set_current_epoch(epoch)

    def set_current_step(self, step):
        for dataset in self.datasets:
            dataset.set_current_step(step)

    def set_max_train_steps(self, max_train_steps):
        for dataset in self.datasets:
            dataset.set_max_train_steps(max_train_steps)
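
# Hedged usage sketch: a DatasetGroup lets the trainer drive epoch/step bookkeeping
# across image and video datasets uniformly. Paths are placeholders and the caches
# are assumed to exist already.
if __name__ == "__main__":
    image_ds = ImageDataset((960, 544), ".txt", 4, True, False, image_directory="/data/images")
    video_ds = VideoDataset((960, 544), ".txt", 1, True, False, frame_extraction="head", target_frames=[1, 25, 45], video_directory="/data/videos")
    for ds in (image_ds, video_ds):
        ds.set_seed(42)
        ds.prepare_for_training()
    group = DatasetGroup([image_ds, video_ds])
    group.set_current_epoch(1)
    logger.info(f"total train items: {group.num_train_items}")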