import yaml import os from torch.hub import download_url_to_file, get_dir from urllib.parse import urlparse import torch import typing import traceback import einops import gc import torchvision.transforms.functional as transform from comfy.model_management import soft_empty_cache, get_torch_device import numpy as np BASE_MODEL_DOWNLOAD_URLS = [ "https://github.com/styler00dollar/VSGAN-tensorrt-docker/releases/download/models/", "https://github.com/Fannovel16/ComfyUI-Frame-Interpolation/releases/download/models/", "https://github.com/dajes/frame-interpolation-pytorch/releases/download/v1.0.0/" ] config_path = os.path.join(os.path.dirname(__file__), "./config.yaml") if os.path.exists(config_path): config = yaml.load(open(config_path, "r"), Loader=yaml.FullLoader) else: raise Exception("config.yaml file is neccessary, plz recreate the config file by downloading it from https://github.com/Fannovel16/ComfyUI-Frame-Interpolation") DEVICE = get_torch_device() class InterpolationStateList(): def __init__(self, frame_indices: typing.List[int], is_skip_list: bool): self.frame_indices = frame_indices self.is_skip_list = is_skip_list def is_frame_skipped(self, frame_index): is_frame_in_list = frame_index in self.frame_indices return self.is_skip_list and is_frame_in_list or not self.is_skip_list and not is_frame_in_list class MakeInterpolationStateList: @classmethod def INPUT_TYPES(s): return { "required": { "frame_indices": ("STRING", {"multiline": True, "default": "1,2,3"}), "is_skip_list": ("BOOLEAN", {"default": True},), }, } RETURN_TYPES = ("INTERPOLATION_STATES",) FUNCTION = "create_options" CATEGORY = "ComfyUI-Frame-Interpolation/VFI" def create_options(self, frame_indices: str, is_skip_list: bool): frame_indices_list = [int(item) for item in frame_indices.split(',')] interpolation_state_list = InterpolationStateList( frame_indices=frame_indices_list, is_skip_list=is_skip_list, ) return (interpolation_state_list,) def get_ckpt_container_path(model_type): return os.path.abspath(os.path.join(os.path.dirname(__file__), config["ckpts_path"], model_type)) def load_file_from_url(url, model_dir=None, progress=True, file_name=None): """Load file form http url, will download models if necessary. Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py Args: url (str): URL to be downloaded. model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. Default: None. progress (bool): Whether to show the download progress. Default: True. file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. Returns: str: The path to the downloaded file. """ if model_dir is None: # use the pytorch hub_dir hub_dir = get_dir() model_dir = os.path.join(hub_dir, 'checkpoints') os.makedirs(model_dir, exist_ok=True) parts = urlparse(url) file_name = os.path.basename(parts.path) if file_name is not None: file_name = file_name cached_file = os.path.abspath(os.path.join(model_dir, file_name)) if not os.path.exists(cached_file): print(f'Downloading: "{url}" to {cached_file}\n') download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) return cached_file def load_file_from_github_release(model_type, ckpt_name): error_strs = [] for i, base_model_download_url in enumerate(BASE_MODEL_DOWNLOAD_URLS): try: return load_file_from_url(base_model_download_url + ckpt_name, get_ckpt_container_path(model_type)) except Exception: traceback_str = traceback.format_exc() if i < len(BASE_MODEL_DOWNLOAD_URLS) - 1: print("Failed! Trying another endpoint.") error_strs.append(f"Error when downloading from: {base_model_download_url + ckpt_name}\n\n{traceback_str}") error_str = '\n\n'.join(error_strs) raise Exception(f"Tried all GitHub base urls to download {ckpt_name} but no suceess. Below is the error log:\n\n{error_str}") def load_file_from_direct_url(model_type, url): return load_file_from_url(url, get_ckpt_container_path(model_type)) def preprocess_frames(frames): return einops.rearrange(frames[..., :3], "n h w c -> n c h w") def postprocess_frames(frames): return einops.rearrange(frames, "n c h w -> n h w c")[..., :3].cpu() def assert_batch_size(frames, batch_size=2, vfi_name=None): subject_verb = "Most VFI models require" if vfi_name is None else f"VFI model {vfi_name} requires" assert len(frames) >= batch_size, f"{subject_verb} at least {batch_size} frames to work with, only found {frames.shape[0]}. Please check the frame input using PreviewImage." def _generic_frame_loop( frames, clear_cache_after_n_frames, multiplier: typing.Union[typing.SupportsInt, typing.List], return_middle_frame_function, *return_middle_frame_function_args, interpolation_states: InterpolationStateList = None, use_timestep=True, dtype=torch.float16, final_logging=True): #https://github.com/hzwer/Practical-RIFE/blob/main/inference_video.py#L169 def non_timestep_inference(frame0, frame1, n): middle = return_middle_frame_function(frame0, frame1, None, *return_middle_frame_function_args) if n == 1: return [middle] first_half = non_timestep_inference(frame0, middle, n=n//2) second_half = non_timestep_inference(middle, frame1, n=n//2) if n%2: return [*first_half, middle, *second_half] else: return [*first_half, *second_half] output_frames = torch.zeros(multiplier*frames.shape[0], *frames.shape[1:], dtype=dtype, device="cpu") out_len = 0 number_of_frames_processed_since_last_cleared_cuda_cache = 0 for frame_itr in range(len(frames) - 1): # Skip the final frame since there are no frames after it frame0 = frames[frame_itr:frame_itr+1] output_frames[out_len] = frame0 # Start with first frame out_len += 1 # Ensure that input frames are in fp32 - the same dtype as model frame0 = frame0.to(dtype=torch.float32) frame1 = frames[frame_itr+1:frame_itr+2].to(dtype=torch.float32) if interpolation_states is not None and interpolation_states.is_frame_skipped(frame_itr): continue # Generate and append a batch of middle frames middle_frame_batches = [] if use_timestep: for middle_i in range(1, multiplier): timestep = middle_i/multiplier middle_frame = return_middle_frame_function( frame0.to(DEVICE), frame1.to(DEVICE), timestep, *return_middle_frame_function_args ).detach().cpu() middle_frame_batches.append(middle_frame.to(dtype=dtype)) else: middle_frames = non_timestep_inference(frame0.to(DEVICE), frame1.to(DEVICE), multiplier - 1) middle_frame_batches.extend(torch.cat(middle_frames, dim=0).detach().cpu().to(dtype=dtype)) # Copy middle frames to output for middle_frame in middle_frame_batches: output_frames[out_len] = middle_frame out_len += 1 number_of_frames_processed_since_last_cleared_cuda_cache += 1 # Try to avoid a memory overflow by clearing cuda cache regularly if number_of_frames_processed_since_last_cleared_cuda_cache >= clear_cache_after_n_frames: print("Comfy-VFI: Clearing cache...", end=' ') soft_empty_cache() number_of_frames_processed_since_last_cleared_cuda_cache = 0 print("Done cache clearing") gc.collect() if final_logging: print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}") # Append final frame output_frames[out_len] = frames[-1:] out_len += 1 # clear cache for courtesy if final_logging: print("Comfy-VFI: Final clearing cache...", end = ' ') soft_empty_cache() if final_logging: print("Done cache clearing") return output_frames[:out_len] def generic_frame_loop( model_name, frames, clear_cache_after_n_frames, multiplier: typing.Union[typing.SupportsInt, typing.List], return_middle_frame_function, *return_middle_frame_function_args, interpolation_states: InterpolationStateList = None, use_timestep=True, dtype=torch.float32): assert_batch_size(frames, vfi_name=model_name.replace('_', ' ').replace('VFI', '')) if type(multiplier) == int: return _generic_frame_loop( frames, clear_cache_after_n_frames, multiplier, return_middle_frame_function, *return_middle_frame_function_args, interpolation_states=interpolation_states, use_timestep=use_timestep, dtype=dtype ) if type(multiplier) == list: multipliers = list(map(int, multiplier)) multipliers += [2] * (len(frames) - len(multipliers) - 1) frame_batches = [] for frame_itr in range(len(frames) - 1): multiplier = multipliers[frame_itr] if multiplier == 0: continue frame_batch = _generic_frame_loop( frames[frame_itr:frame_itr+2], clear_cache_after_n_frames, multiplier, return_middle_frame_function, *return_middle_frame_function_args, interpolation_states=interpolation_states, use_timestep=use_timestep, dtype=dtype, final_logging=False ) if frame_itr != len(frames) - 2: # Not append last frame unless this batch is the last one frame_batch = frame_batch[:-1] frame_batches.append(frame_batch) output_frames = torch.cat(frame_batches) print(f"Comfy-VFI done! {len(output_frames)} frames generated at resolution: {output_frames[0].shape}") return output_frames raise NotImplementedError(f"multipiler of {type(multiplier)}") class FloatToInt: @classmethod def INPUT_TYPES(s): return { "required": { "float": ("FLOAT", {"default": 0, 'min': 0, 'step': 0.01}) } } RETURN_TYPES = ("INT",) FUNCTION = "convert" CATEGORY = "ComfyUI-Frame-Interpolation" def convert(self, float): if hasattr(float, "__iter__"): return (list(map(int, float)),) return (int(float),) """ def generic_4frame_loop( frames, clear_cache_after_n_frames, multiplier: typing.SupportsInt, return_middle_frame_function, *return_middle_frame_function_args, interpolation_states: InterpolationStateList = None, use_timestep=False): if use_timestep: raise NotImplementedError("Timestep 4 frame VFI model") def non_timestep_inference(frame_0, frame_1, frame_2, frame_3, n): middle = return_middle_frame_function(frame_0, frame_1, None, *return_middle_frame_function_args) if n == 1: return [middle] first_half = non_timestep_inference(frame_0, middle, n=n//2) second_half = non_timestep_inference(middle, frame_1, n=n//2) if n%2: return [*first_half, middle, *second_half] else: return [*first_half, *second_half] """