Spaces:
Runtime error
Runtime error
File size: 4,782 Bytes
d6d3a5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import json
import os
import os.path as op
import time
import comet_ml
import numpy as np
import torch
from loguru import logger
from tqdm import tqdm
from src.datasets.dataset_utils import copy_repo_arctic
# Experiment key (and therefore ./logs/ folder name) used for debugging:
# init_experiment() assigns it when args.fast_dev_run is set, and add_paths()
# does not auto-resume from a last.ckpt stored under this key.
DUMMY_EXP = "xxxxxxxxx"
def add_paths(args):
    """Attach the paths derived from ``args.exp_key`` to ``args`` and return it.

    Sets ``ckpt_p``, ``log_dir`` and ``args_p`` unconditionally,
    ``interface_p`` when ``args.infer_ckpt`` is non-empty, and ``run_p``,
    ``submit_p`` and ``repo_p`` when ``args.cluster`` is truthy.

    Args:
        args: Namespace-like object with ``exp_key``, ``resume_ckpt``,
            ``infer_ckpt`` and ``cluster`` attributes; modified in place.

    Returns:
        The same ``args`` object.
    """
    key = args.exp_key
    last_ckpt = f"./logs/{key}/checkpoints/last.ckpt"

    # The last checkpoint is only picked up when it exists on disk and does
    # not belong to the debugging experiment.
    resumable = op.exists(last_ckpt) and DUMMY_EXP not in last_ckpt
    checkpoint = last_ckpt if resumable else ""

    # An explicitly requested checkpoint always takes precedence.
    if args.resume_ckpt != "":
        checkpoint = args.resume_ckpt

    args.ckpt_p = checkpoint
    args.log_dir = f"./logs/{key}"

    if args.infer_ckpt != "":
        # Keep the first two "/"-separated components of the checkpoint path
        # and swap the ".ckpt" suffix of its file name for ".params.pt".
        root = "/".join(args.infer_ckpt.split("/")[:2])
        params_name = op.basename(args.infer_ckpt).replace(".ckpt", ".params.pt")
        args.interface_p = op.join(root, params_name)

    args.args_p = f"./logs/{key}/args.json"

    if args.cluster:
        condor_dir = op.join(args.log_dir, "condor")
        args.run_p = op.join(condor_dir, "run.sh")
        args.submit_p = op.join(condor_dir, "submit.sub")
        args.repo_p = op.join(args.log_dir, "repo")

    return args
def save_args(args, save_keys):
    """Write the entries of ``args`` whose keys are in ``save_keys`` as JSON.

    Args:
        args: Mapping-like object providing ``items()`` and an ``args_p``
            attribute that holds the output file path.
        save_keys: Container of the keys to keep.
    """
    selected = {name: value for name, value in args.items() if name in save_keys}
    with open(args.args_p, "w") as out_file:
        json.dump(selected, out_file, indent=4)
    logger.info(f"Saved args at {args.args_p}")
def create_files(args):
    """Create the log directory and, for cluster runs, the job directory.

    Args:
        args: Namespace-like object with ``log_dir`` and ``cluster``; cluster
            runs additionally read ``run_p`` and ``exp_key``.
    """
    os.makedirs(args.log_dir, exist_ok=True)
    if not args.cluster:
        return

    # Cluster runs also need the folder that holds the run script, and hand
    # the experiment key to copy_repo_arctic.
    os.makedirs(op.dirname(args.run_p), exist_ok=True)
    copy_repo_arctic(args.exp_key)
def log_exp_meta(args):
    """Name the experiment, tag it with the method, and log all arguments.

    Args:
        args: Namespace-like object with ``method``, ``exp_key`` and an
            ``experiment`` exposing ``set_name``, ``add_tags`` and
            ``log_parameters``.
    """
    method_tags = [args.method]
    logger.info(f"Experiment tags: {method_tags}")

    exp = args.experiment
    exp.set_name(args.exp_key)
    exp.add_tags(method_tags)
    exp.log_parameters(args)
def init_experiment(args):
    """Resolve the experiment key, prepare its files, and set up Comet logging.

    The experiment key is chosen in this order: ``DUMMY_EXP`` when
    ``args.fast_dev_run`` is set, otherwise the second "/"-separated component
    of ``args.resume_ckpt`` when that is non-empty, otherwise the existing
    ``args.exp_key``, and finally a freshly generated key when it is empty.

    Args:
        args: Object supporting both attribute access and mapping access
            (``args["comet_key"]``, ``args.keys()``). Reads ``resume_ckpt``,
            ``fast_dev_run``, ``exp_key``, ``project``, ``mute`` and
            ``cluster``, plus whatever ``add_paths`` and ``create_files``
            require.

    Returns:
        Tuple ``(experiment, args)``. ``experiment`` is ``None`` when
        ``args.cluster`` is truthy; otherwise it is a Comet experiment and
        ``args.gpu`` holds the CUDA device name, or ``"cpu"`` when CUDA is
        unavailable. ``args.experiment`` is set to the same object.

    Raises:
        KeyError: If ``COMET_API_KEY`` or ``COMET_WORKSPACE`` is not set in
            the environment.
    """
    if args.resume_ckpt != "":
        args.exp_key = args.resume_ckpt.split("/")[1]
    if args.fast_dev_run:
        args.exp_key = DUMMY_EXP
    if args.exp_key == "":
        args.exp_key = generate_exp_key()
    args = add_paths(args)

    # Reuse the Comet key stored by an earlier run of the same experiment so
    # that logging continues in the existing Comet experiment.
    if op.exists(args.args_p) and args.exp_key != DUMMY_EXP:
        with open(args.args_p, "r") as f:
            args_disk = json.load(f)
        if "comet_key" in args_disk.keys():
            args.comet_key = args_disk["comet_key"]

    create_files(args)

    project_name = args.project
    disabled = args.mute
    comet_url = args["comet_key"] if "comet_key" in args.keys() else None

    api_key = os.environ["COMET_API_KEY"]
    workspace = os.environ["COMET_WORKSPACE"]

    if not args.cluster:
        if comet_url is None:
            experiment = comet_ml.Experiment(
                api_key=api_key,
                workspace=workspace,
                project_name=project_name,
                disabled=disabled,
                display_summary_level=0,
            )
            args.comet_key = experiment.get_key()
        else:
            experiment = comet_ml.ExistingExperiment(
                previous_experiment=comet_url,
                api_key=api_key,
                project_name=project_name,
                workspace=workspace,
                disabled=disabled,
                display_summary_level=0,
            )

        logger.add(
            os.path.join(args.log_dir, "train.log"),
            level="INFO",
            colorize=True,
        )

        # torch.cuda.get_device_properties only accepts CUDA devices, so it
        # must not be called with "cpu" on hosts without a GPU.
        if torch.cuda.is_available():
            gpu_props = torch.cuda.get_device_properties("cuda")
            logger.info(gpu_props)
            args.gpu = gpu_props.name
        else:
            logger.info("CUDA is not available; running on CPU")
            args.gpu = "cpu"
    else:
        experiment = None

    args.experiment = experiment
    return experiment, args
def log_dict(experiment, metric_dict, step, postfix=None):
    """Log the scalar entries of ``metric_dict`` to ``experiment``.

    Single-element tensors are converted with ``.item()``. After that, only
    values that are ``int``, ``float`` or ``np.float32`` are logged; every
    other entry is skipped silently.

    Args:
        experiment: Object exposing ``log_metric(key, value, step=...)``.
            When ``None`` the call does nothing.
        metric_dict: Mapping from metric name to value.
        step: Step forwarded to ``log_metric``.
        postfix: Optional string appended to every metric name.
    """
    if experiment is None:
        return

    for key, value in metric_dict.items():
        name = key if postfix is None else key + postfix

        # numel() counts elements without reshaping; view(-1) would raise on
        # a non-contiguous multi-element tensor instead of letting it be skipped.
        if isinstance(value, torch.Tensor) and value.numel() == 1:
            value = value.item()

        if isinstance(value, (int, float, np.float32)):
            experiment.log_metric(name, value, step=step)
def generate_exp_key():
    """Return a random experiment key of 9 lowercase hexadecimal characters."""
    import random

    random_bits = random.getrandbits(128)
    # Render as a zero-padded 32-digit hex string and keep the leading 9 digits.
    return format(random_bits, "032x")[:9]
def push_images(experiment, all_im_list, global_step=None, no_tqdm=False, verbose=True):
    """Upload a list of images to ``experiment``.

    Args:
        experiment: Object exposing ``log_image(image, name, step=...)``.
            When ``None`` (as ``init_experiment`` returns for cluster runs)
            the call does nothing, matching ``log_dict``.
        all_im_list: Iterable of dicts. Each holds the image under ``"im"``
            (converted with ``np.array``) and optionally a name under
            ``"fig_name"``; images without a name are logged as "unnamed".
        global_step: Step forwarded to ``log_image``.
        no_tqdm: If true, iterate without a progress bar.
        verbose: If true, print start/end messages with the elapsed time.
    """
    if experiment is None:
        return

    if verbose:
        print("Pushing PIL images")
        tic = time.time()

    iterator = all_im_list if no_tqdm else tqdm(all_im_list)
    for im in iterator:
        im_np = np.array(im["im"])
        fig_name = im["fig_name"] if "fig_name" in im else "unnamed"
        experiment.log_image(im_np, fig_name, step=global_step)

    if verbose:
        toc = time.time()
        print("Done pushing PIL images (%.1fs)" % (toc - tic))
|