# Wan2.1-sajjad / generate.py
# Uploaded by seokochin ("Upload 4 files"), commit feedcc2 (verified), 1.04 kB
import argparse
import os
import subprocess
import sys

import torch
"""Command-line wrapper that runs WAN 2.1 inference via ``run_model.py``.

Parses generation options, picks CUDA when available (CPU otherwise),
launches ``run_model.py`` as a child process and reports whether
``output.mp4`` was produced by that run.
"""

# Path the child process is expected to write.
# NOTE(review): assumed from the original success check; confirm against
# run_model.py.
OUTPUT_PATH = "output.mp4"


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``task``, ``size``, ``frame_num``,
        ``sample_steps``, ``ckpt_dir`` and ``prompt``.

    Raises:
        SystemExit: If the arguments are invalid (e.g. ``--prompt`` missing).
    """
    parser = argparse.ArgumentParser(description="Generate a video with WAN 2.1.")
    parser.add_argument("--task", type=str, default="t2v-1.3B")
    parser.add_argument("--size", type=str, default="832*480")
    parser.add_argument("--frame_num", type=int, default=60)
    parser.add_argument("--sample_steps", type=int, default=30)
    parser.add_argument("--ckpt_dir", type=str, default="./Wan2.1-T2V-1.3B")
    parser.add_argument("--prompt", type=str, required=True)
    return parser.parse_args(argv)


def build_command(args, device):
    """Return the argv list that runs ``run_model.py`` for *args* on *device*.

    The command is a list (no shell). The prompt, a ``*`` in ``--size``, and
    paths containing spaces are therefore passed through verbatim. A shell
    string would instead be quoted, glob-expanded or injected.
    ``sys.executable`` keeps the child on the same interpreter/venv as this
    script.
    """
    return [
        sys.executable,
        "run_model.py",
        "--task", args.task,
        "--size", args.size,
        "--frame_num", str(args.frame_num),
        "--sample_steps", str(args.sample_steps),
        "--ckpt_dir", args.ckpt_dir,
        "--prompt", args.prompt,
        "--device", device,
    ]


def _mtime_ns(path):
    """Return *path*'s modification time in ns, or ``None`` if it is absent."""
    try:
        return os.stat(path).st_mtime_ns
    except FileNotFoundError:
        return None


def main(argv=None):
    """Run one generation and report the outcome.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``0`` when the child exited successfully and (re)wrote
        ``OUTPUT_PATH``, otherwise ``1``.
    """
    args = parse_args(argv)

    # Check GPU availability.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Remember any pre-existing output so a stale file from an earlier run
    # is not mistaken for success.
    before = _mtime_ns(OUTPUT_PATH)
    result = subprocess.run(build_command(args, device))
    after = _mtime_ns(OUTPUT_PATH)

    if result.returncode == 0 and after is not None and after != before:
        print("✅ Video generated successfully: output.mp4")
        return 0
    print("❌ Error generating video.")
    return 1


if __name__ == "__main__":
    sys.exit(main())