meow
init
d6d3a5b
raw
history blame
No virus
4.78 kB
import json
import os
import os.path as op
import time
import comet_ml
import numpy as np
import torch
from loguru import logger
from tqdm import tqdm
from src.datasets.dataset_utils import copy_repo_arctic
# folder used for debugging
DUMMY_EXP = "xxxxxxxxx"
def add_paths(args):
exp_key = args.exp_key
args_p = f"./logs/{exp_key}/args.json"
ckpt_p = f"./logs/{exp_key}/checkpoints/last.ckpt"
if not op.exists(ckpt_p) or DUMMY_EXP in ckpt_p:
ckpt_p = ""
if args.resume_ckpt != "":
ckpt_p = args.resume_ckpt
args.ckpt_p = ckpt_p
args.log_dir = f"./logs/{exp_key}"
if args.infer_ckpt != "":
basedir = "/".join(args.infer_ckpt.split("/")[:2])
basename = op.basename(args.infer_ckpt).replace(".ckpt", ".params.pt")
args.interface_p = op.join(basedir, basename)
args.args_p = args_p
if args.cluster:
args.run_p = op.join(args.log_dir, "condor", "run.sh")
args.submit_p = op.join(args.log_dir, "condor", "submit.sub")
args.repo_p = op.join(args.log_dir, "repo")
return args
def save_args(args, save_keys):
args_save = {}
for key, val in args.items():
if key in save_keys:
args_save[key] = val
with open(args.args_p, "w") as f:
json.dump(args_save, f, indent=4)
logger.info(f"Saved args at {args.args_p}")
def create_files(args):
os.makedirs(args.log_dir, exist_ok=True)
if args.cluster:
os.makedirs(op.dirname(args.run_p), exist_ok=True)
copy_repo_arctic(args.exp_key)
def log_exp_meta(args):
tags = [args.method]
logger.info(f"Experiment tags: {tags}")
args.experiment.set_name(args.exp_key)
args.experiment.add_tags(tags)
args.experiment.log_parameters(args)
def init_experiment(args):
if args.resume_ckpt != "":
args.exp_key = args.resume_ckpt.split("/")[1]
if args.fast_dev_run:
args.exp_key = DUMMY_EXP
if args.exp_key == "":
args.exp_key = generate_exp_key()
args = add_paths(args)
if op.exists(args.args_p) and args.exp_key not in [DUMMY_EXP]:
with open(args.args_p, "r") as f:
args_disk = json.load(f)
if "comet_key" in args_disk.keys():
args.comet_key = args_disk["comet_key"]
create_files(args)
project_name = args.project
disabled = args.mute
comet_url = args["comet_key"] if "comet_key" in args.keys() else None
api_key = os.environ["COMET_API_KEY"]
workspace = os.environ["COMET_WORKSPACE"]
if not args.cluster:
if comet_url is None:
experiment = comet_ml.Experiment(
api_key=api_key,
workspace=workspace,
project_name=project_name,
disabled=disabled,
display_summary_level=0,
)
args.comet_key = experiment.get_key()
else:
experiment = comet_ml.ExistingExperiment(
previous_experiment=comet_url,
api_key=api_key,
project_name=project_name,
workspace=workspace,
disabled=disabled,
display_summary_level=0,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.add(
os.path.join(args.log_dir, "train.log"),
level="INFO",
colorize=True,
)
logger.info(torch.cuda.get_device_properties(device))
args.gpu = torch.cuda.get_device_properties(device).name
else:
experiment = None
args.experiment = experiment
return experiment, args
def log_dict(experiment, metric_dict, step, postfix=None):
if experiment is None:
return
for key, value in metric_dict.items():
if postfix is not None:
key = key + postfix
if isinstance(value, torch.Tensor) and len(value.view(-1)) == 1:
value = value.item()
if isinstance(value, (int, float, np.float32)):
experiment.log_metric(key, value, step=step)
def generate_exp_key():
import random
hash = random.getrandbits(128)
key = "%032x" % (hash)
key = key[:9]
return key
def push_images(experiment, all_im_list, global_step=None, no_tqdm=False, verbose=True):
if verbose:
print("Pushing PIL images")
tic = time.time()
iterator = all_im_list if no_tqdm else tqdm(all_im_list)
for im in iterator:
im_np = np.array(im["im"])
if "fig_name" in im.keys():
experiment.log_image(im_np, im["fig_name"], step=global_step)
else:
experiment.log_image(im_np, "unnamed", step=global_step)
if verbose:
toc = time.time()
print("Done pushing PIL images (%.1fs)" % (toc - tic))