# Spaces: Runtime error
import json
import os
import time
from datetime import datetime
from pathlib import Path

from loguru import logger

from hyvideo.config import parse_args
from hyvideo.inference import HunyuanVideoSampler
from hyvideo.utils.file_utils import save_videos_grid
def main(): | |
args = parse_args() | |
print(args) | |
models_root_path = Path(args.model_base) | |
if not models_root_path.exists(): | |
raise ValueError(f"`models_root` not exists: {models_root_path}") | |
# Create save folder to save the samples | |
save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}' | |
if not os.path.exists(args.save_path): | |
os.makedirs(save_path, exist_ok=True) | |
models_root_path = "ckpts/hunyuan-video-t2v-720p/transformers/hunyuan_video_720_bf16.safetensors" | |
text_encoder_filename = "ckpts/hunyuan-video-t2v-720p/text_encoder/llava-llama-3-8b_fp16.safetensors" | |
import json | |
with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "r", encoding="utf-8") as reader: | |
text = reader.read() | |
vae_config= json.loads(text) | |
# reduce time window used by the VAE for temporal splitting (former time windows is too large for 24 GB) | |
if vae_config["sample_tsize"] == 64: | |
vae_config["sample_tsize"] = 32 | |
with open("./ckpts/hunyuan-video-t2v-720p/vae/config.json", "w", encoding="utf-8") as writer: | |
writer.write(json.dumps(vae_config)) | |
# Load models | |
hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path,text_encoder_filename, args=args) | |
from mmgp import offload, profile_type | |
pipe = hunyuan_video_sampler.pipeline | |
offload.profile(pipe, profile_no= profile_type.HighRAM_LowVRAM_Fast) | |
# Get the updated args | |
args = hunyuan_video_sampler.args | |
# Start sampling | |
# TODO: batch inference check | |
outputs = hunyuan_video_sampler.predict( | |
prompt=args.prompt, | |
height=args.video_size[0], | |
width=args.video_size[1], | |
video_length=args.video_length, | |
seed=args.seed, | |
negative_prompt=args.neg_prompt, | |
infer_steps=args.infer_steps, | |
guidance_scale=args.cfg_scale, | |
num_videos_per_prompt=args.num_videos, | |
flow_shift=args.flow_shift, | |
batch_size=args.batch_size, | |
embedded_guidance_scale=args.embedded_cfg_scale | |
) | |
samples = outputs['samples'] | |
# Save samples | |
if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0: | |
for i, sample in enumerate(samples): | |
sample = samples[i].unsqueeze(0) | |
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S") | |
save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4" | |
save_videos_grid(sample, save_path, fps=24) | |
logger.info(f'Sample save to: {save_path}') | |
# Script entry point: generate videos only when run directly, not on import.
if __name__ == '__main__':
    main()