Fabrice-TIERCELIN commited on
Commit
da7deef
·
verified ·
1 Parent(s): 71ee741

' instead of "

Browse files
Files changed (1) hide show
  1. sample_video.py +58 -58
sample_video.py CHANGED
@@ -1,58 +1,58 @@
1
- import os
2
- import time
3
- from pathlib import Path
4
- from loguru import logger
5
- from datetime import datetime
6
-
7
- from hyvideo.utils.file_utils import save_videos_grid
8
- from hyvideo.config import parse_args
9
- from hyvideo.inference import HunyuanVideoSampler
10
-
11
-
12
- def main():
13
- args = parse_args()
14
- print(args)
15
- models_root_path = Path(args.model_base)
16
- if not models_root_path.exists():
17
- raise ValueError(f"`models_root` not exists: {models_root_path}")
18
-
19
- # Create save folder to save the samples
20
- save_path = args.save_path if args.save_path_suffix=="" else f'{args.save_path}_{args.save_path_suffix}'
21
- if not os.path.exists(args.save_path):
22
- os.makedirs(save_path, exist_ok=True)
23
-
24
- # Load models
25
- hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
26
-
27
- # Get the updated args
28
- args = hunyuan_video_sampler.args
29
-
30
- # Start sampling
31
- # TODO: batch inference check
32
- outputs = hunyuan_video_sampler.predict(
33
- prompt=args.prompt,
34
- height=args.video_size[0],
35
- width=args.video_size[1],
36
- video_length=args.video_length,
37
- seed=args.seed,
38
- negative_prompt=args.neg_prompt,
39
- infer_steps=args.infer_steps,
40
- guidance_scale=args.cfg_scale,
41
- num_videos_per_prompt=args.num_videos,
42
- flow_shift=args.flow_shift,
43
- batch_size=args.batch_size,
44
- embedded_guidance_scale=args.embedded_cfg_scale
45
- )
46
- samples = outputs['samples']
47
-
48
- # Save samples
49
- if 'LOCAL_RANK' not in os.environ or int(os.environ['LOCAL_RANK']) == 0:
50
- for i, sample in enumerate(samples):
51
- sample = samples[i].unsqueeze(0)
52
- time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
53
- save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
54
- save_videos_grid(sample, save_path, fps=24)
55
- logger.info(f'Sample save to: {save_path}')
56
-
57
- if __name__ == "__main__":
58
- main()
 
1
+ import os
2
+ import time
3
+ from pathlib import Path
4
+ from loguru import logger
5
+ from datetime import datetime
6
+
7
+ from hyvideo.utils.file_utils import save_videos_grid
8
+ from hyvideo.config import parse_args
9
+ from hyvideo.inference import HunyuanVideoSampler
10
+
11
+
12
+ def main():
13
+ args = parse_args()
14
+ print(args)
15
+ models_root_path = Path(args.model_base)
16
+ if not models_root_path.exists():
17
+ raise ValueError(f"`models_root` not exists: {models_root_path}")
18
+
19
+ # Create save folder to save the samples
20
+ save_path = args.save_path if args.save_path_suffix=="" else f"{args.save_path}_{args.save_path_suffix}"
21
+ if not os.path.exists(args.save_path):
22
+ os.makedirs(save_path, exist_ok=True)
23
+
24
+ # Load models
25
+ hunyuan_video_sampler = HunyuanVideoSampler.from_pretrained(models_root_path, args=args)
26
+
27
+ # Get the updated args
28
+ args = hunyuan_video_sampler.args
29
+
30
+ # Start sampling
31
+ # TODO: batch inference check
32
+ outputs = hunyuan_video_sampler.predict(
33
+ prompt=args.prompt,
34
+ height=args.video_size[0],
35
+ width=args.video_size[1],
36
+ video_length=args.video_length,
37
+ seed=args.seed,
38
+ negative_prompt=args.neg_prompt,
39
+ infer_steps=args.infer_steps,
40
+ guidance_scale=args.cfg_scale,
41
+ num_videos_per_prompt=args.num_videos,
42
+ flow_shift=args.flow_shift,
43
+ batch_size=args.batch_size,
44
+ embedded_guidance_scale=args.embedded_cfg_scale
45
+ )
46
+ samples = outputs["samples"]
47
+
48
+ # Save samples
49
+ if "LOCAL_RANK" not in os.environ or int(os.environ["LOCAL_RANK"]) == 0:
50
+ for i, sample in enumerate(samples):
51
+ sample = samples[i].unsqueeze(0)
52
+ time_flag = datetime.fromtimestamp(time.time()).strftime("%Y-%m-%d-%H:%M:%S")
53
+ save_path = f"{save_path}/{time_flag}_seed{outputs['seeds'][i]}_{outputs['prompts'][i][:100].replace('/','')}.mp4"
54
+ save_videos_grid(sample, save_path, fps=24)
55
+ logger.info(f"Sample save to: {save_path}")
56
+
57
+ if __name__ == "__main__":
58
+ main()