|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import datetime |
|
import os |
|
import os.path as osp |
|
import subprocess |
|
|
|
from termcolor import colored |
|
|
|
|
|
def supports_gpus_per_node():
    """Tell whether ``--gpus-per-node`` should be passed to ``srun``.

    Returns ``False`` when the ``VILA_DATASETS`` environment variable contains
    the substring ``"eos"`` (compared case-insensitively) and ``True``
    otherwise, including when the variable is unset.
    """
    # NOTE(review): presumably "eos" identifies a cluster whose Slurm setup
    # rejects the option -- confirm with whoever maintains the environment.
    datasets_root = os.environ.get("VILA_DATASETS", "").lower()
    return "eos" not in datasets_root
|
|
|
|
|
def _parse_time_limit(text: str) -> datetime.timedelta:
    """Parse a ``[days-]hours:minutes:seconds`` time limit into a timedelta.

    Unlike ``strptime("%H:%M:%S")`` the hour field is not capped at 23 and an
    optional ``days-`` prefix is accepted.

    Raises:
        ValueError: If ``text`` is not in the expected format.
    """
    days_text, dash, clock = text.rpartition("-")
    fields = ([days_text] if dash else ["0"]) + clock.split(":")
    if len(fields) != 4 or not all(field.isdecimal() for field in fields):
        raise ValueError(f"Invalid time limit {text!r}; expected [days-]hours:minutes:seconds")
    days, hours, minutes, seconds = map(int, fields)
    if minutes > 59 or seconds > 59:
        raise ValueError(f"Invalid time limit {text!r}; minutes and seconds must be below 60")
    return datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


def main() -> None:
    """Run a command through ``srun`` and re-run it on time-out or failure.

    The command (everything after the recognised options) is wrapped in
    ``timeout`` so that it stops ``--timedelta`` minutes before the ``--time``
    limit. The child process receives ``RUN_NAME`` and ``OUTPUT_DIR`` in its
    environment. When ``SLURM_JOB_ID`` is already set, the command is run
    directly, without ``srun`` or ``timeout``.

    Retry policy: exit status 124 (reported as a time-out) resets the failure
    counter and always retries; any other non-zero status counts as a failure
    and is retried at most ``--max-retry`` times (default 3 for mode
    ``train``, 0 otherwise). The process exits with the last return code.

    Raises:
        ValueError: If ``--time`` is malformed, shorter than one minute, or
            does not exceed ``--timedelta`` by at least one minute, or if
            ``VILA_SLURM_ACCOUNT`` / ``VILA_SLURM_PARTITION`` is not set.
        SystemExit: Always on completion, carrying the final return code;
            also raised by argparse on invalid arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--job-name", "-J", type=str, required=True)
    parser.add_argument("--nodes", "-N", type=int, default=1)
    parser.add_argument("--gpus-per-node", type=int, default=8)
    parser.add_argument("--mode", "-m", type=str, default="train")
    parser.add_argument("--time", "-t", type=str, default="4:00:00")
    parser.add_argument("--timedelta", type=int, default=5)
    parser.add_argument("--output-dir", type=str)
    parser.add_argument("--max-retry", type=int, default=-1)
    parser.add_argument("--pty", action="store_true")
    parser.add_argument("cmd", nargs=argparse.REMAINDER)
    args = parser.parse_args()

    # argparse.REMAINDER may keep a leading "--" separator; drop it so it is
    # not executed as the program name.
    if args.cmd and args.cmd[0] == "--":
        args.cmd = args.cmd[1:]
    if not args.cmd:
        parser.error("no command given to run")

    # A negative --max-retry means "pick the default for this mode".
    if args.max_retry < 0:
        args.max_retry = 3 if args.mode == "train" else 0

    # Generate the run name and output directory.
    if "%t" in args.job_name:
        args.job_name = args.job_name.replace("%t", datetime.datetime.now().strftime("%Y%m%d%H%M%S"))
    if args.output_dir is None:
        args.output_dir = os.path.join("runs", args.mode, args.job_name)
    output_dir = os.path.expanduser(args.output_dir)

    # Compute the `timeout` duration: the time limit minus the safety margin,
    # in whole minutes. Durations are handled as timedeltas so a margin larger
    # than the limit is detected instead of wrapping around to the previous day.
    time_limit = _parse_time_limit(args.time)
    if time_limit < datetime.timedelta(minutes=1):
        raise ValueError("Time must be at least 1 minutes")
    remaining = time_limit - datetime.timedelta(minutes=args.timedelta)
    timeout_minutes = remaining // datetime.timedelta(minutes=1)
    if timeout_minutes < 1:
        # A zero duration would disable `timeout` altogether.
        raise ValueError(
            f"--time ({args.time}) must exceed --timedelta ({args.timedelta} minutes) by at least 1 minute"
        )
    timeout = f"{timeout_minutes}m"

    # Get the Slurm account and partition.
    if "VILA_SLURM_ACCOUNT" not in os.environ or "VILA_SLURM_PARTITION" not in os.environ:
        raise ValueError("`VILA_SLURM_ACCOUNT` and `VILA_SLURM_PARTITION` must be set in the environment.")
    account = os.environ["VILA_SLURM_ACCOUNT"]
    partition = os.environ["VILA_SLURM_PARTITION"]

    # Environment handed to the child process.
    env = os.environ.copy()
    env["RUN_NAME"] = args.job_name
    env["OUTPUT_DIR"] = output_dir

    # Compose the srun command line.
    cmd = ["srun"]
    cmd += ["--account", account]
    cmd += ["--partition", partition]
    cmd += ["--job-name", f"{account}:{args.mode}/{args.job_name}"]
    if not args.pty:
        # Without --pty, stdout/stderr are redirected to per-job log files.
        cmd += ["--output", f"{output_dir}/slurm/%J.out"]
        cmd += ["--error", f"{output_dir}/slurm/%J.err"]
    cmd += ["--nodes", str(args.nodes)]
    if supports_gpus_per_node():
        cmd += ["--gpus-per-node", str(args.gpus_per_node)]
    cmd += ["--time", args.time]
    cmd += ["--exclusive"]
    cmd += ["timeout", timeout]
    cmd += args.cmd

    # The pieces are joined with plain spaces and executed through the shell,
    # so shell syntax inside the wrapped command is interpreted, not quoted.
    full_cmd = " ".join(cmd)
    if os.environ.get("SLURM_JOB_ID"):
        # Already inside an allocation: run the bare command instead.
        print(colored("Running inside slurm nodes detected", "yellow"))
        full_cmd = " ".join(args.cmd)
    print(colored(full_cmd, attrs=["bold"]))

    # Run the job, retrying on time-out (always) or failure (bounded).
    fail_times = 0
    while True:
        returncode = subprocess.run(full_cmd, env=env, shell=True).returncode
        print(f"returncode: {returncode}")
        if returncode == 0:
            print("Job finished successfully!")
            break
        if returncode != 124:
            fail_times += 1
            if fail_times > args.max_retry:
                break
            print(f"Job failed, retrying {fail_times} / {args.max_retry}")
        else:
            # Time-outs do not count against the retry budget.
            fail_times = 0
            print("Job timed out, retrying...")

    # Propagate the final return code to the caller.
    print(f"Job finished with exit code {returncode}")
    raise SystemExit(returncode)
|
|
|
|
|
if __name__ == "__main__":
    # Start the launcher only when this file is executed as a script,
    # not when it is imported as a module.
    main()
|
|