import os import time from pathlib import Path from loguru import logger from datetime import datetime from hyvideo.utils.file_utils import save_videos_grid from hyvideo.config import parse_args from hyvideo.inference import HunyuanVideoSampler def main(): args = parse_args() print(args) models_root_path = Path(args.model_base) if not models_root_path.exists(): raise ValueError(f"`models_root` not exists: {models_root_path}") # Create save folder to save the samples save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}' if not os.path.exists(args.save_path): os.makedirs(save_path, exist_ok=True) models_root_path = "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_bf16.safetensors" text_encoder_filename = "ckpts/hunyuan-video-t2v-720p/text_encoder/llava-llama-3-8b_fp16.safetensors" import json with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "r", encoding="utf-8") as reader: text = reader.read() vae_config= json.loads(text) # reduce time window used by the VAE for temporal splitting (former time windows is too large for 24 GB) if vae_config["sample_tsize"] == 64: vae_config["sample_tsize"] = 32 with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "w", encoding="utf-8") as writer: writer.write(json.dumps(vae_config)) # Load models hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path,text_encoder_filename, args=args) from mmgp import offload, profile_type pipe = hunyuan_video_sampler.pipeline offload.profile(pipe, profile_no= profile_type.HighRAM_LowVRAM_Fast) # Get the updated args args = hunyuan_video_sampler.args # Start sampling # TODO: batch inference check outputs = hunyuan_video_sampler.predict( prompt=args.prompt, height=args.video_size[0], width=args.video_size[1], video_length=args.video_length, seed=args.seed, negative_prompt=args.neg_prompt, infer_steps=args.infer_steps, guidance_scale=args.cfg_scale, num_videos_per_prompt=args.num_videos, flow_shift=args.flow_shift, batch_size=args.batch_size, embedded_guidance_scale=args.embedded_cfg_scale ) samples = outputs['samples'] # Save samples if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: for i, sample in enumerate(samples): sample = samples[i].unsqueeze(0) time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S") save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4" save_videos_grid(sample, save_path, fps=24) logger.info(f'Sample save to: {save_path}') if __name__ == "__main__": main()