import json
import os
import os.path as op
import time

import comet_ml
import numpy as np
import torch
from loguru import logger
from tqdm import tqdm

from src.datasets.dataset_utils import copy_repo_arctic

# folder used for debugging
DUMMY_EXP = "xxxxxxxxx"


def add_paths(args):
    """Attach experiment paths (logs, checkpoint, args file, cluster scripts) to args."""
    exp_key = args.exp_key
    args_p = f"./logs/{exp_key}/args.json"
    ckpt_p = f"./logs/{exp_key}/checkpoints/last.ckpt"
    # ignore missing checkpoints and checkpoints from the debug experiment
    if not op.exists(ckpt_p) or DUMMY_EXP in ckpt_p:
        ckpt_p = ""
    if args.resume_ckpt != "":
        ckpt_p = args.resume_ckpt
    args.ckpt_p = ckpt_p
    args.log_dir = f"./logs/{exp_key}"

    if args.infer_ckpt != "":
        basedir = "/".join(args.infer_ckpt.split("/")[:2])
        basename = op.basename(args.infer_ckpt).replace(".ckpt", ".params.pt")
        args.interface_p = op.join(basedir, basename)
    args.args_p = args_p

    if args.cluster:
        args.run_p = op.join(args.log_dir, "condor", "run.sh")
        args.submit_p = op.join(args.log_dir, "condor", "submit.sub")
        args.repo_p = op.join(args.log_dir, "repo")
    return args


def save_args(args, save_keys):
    """Dump the subset of args listed in save_keys to args.args_p as JSON."""
    args_save = {}
    for key, val in args.items():
        if key in save_keys:
            args_save[key] = val
    with open(args.args_p, "w") as f:
        json.dump(args_save, f, indent=4)
    logger.info(f"Saved args at {args.args_p}")


def create_files(args):
    """Create the log directory; on the cluster, also prepare condor dirs and a repo copy."""
    os.makedirs(args.log_dir, exist_ok=True)
    if args.cluster:
        os.makedirs(op.dirname(args.run_p), exist_ok=True)
        copy_repo_arctic(args.exp_key)


def log_exp_meta(args):
    """Set the Comet experiment name, tags, and hyper-parameters."""
    tags = [args.method]
    logger.info(f"Experiment tags: {tags}")
    args.experiment.set_name(args.exp_key)
    args.experiment.add_tags(tags)
    args.experiment.log_parameters(args)


def init_experiment(args):
    """Resolve the experiment key, create log files, and set up Comet logging.

    On the cluster no Comet experiment is created (experiment is None); otherwise a
    new or previously created Comet experiment is attached to args.
    """
    if args.resume_ckpt != "":
        args.exp_key = args.resume_ckpt.split("/")[1]
    if args.fast_dev_run:
        args.exp_key = DUMMY_EXP
    if args.exp_key == "":
        args.exp_key = generate_exp_key()
    args = add_paths(args)

    # reuse the Comet key saved by a previous run of this experiment, if any
    if op.exists(args.args_p) and args.exp_key not in [DUMMY_EXP]:
        with open(args.args_p, "r") as f:
            args_disk = json.load(f)
        if "comet_key" in args_disk.keys():
            args.comet_key = args_disk["comet_key"]

    create_files(args)

    project_name = args.project
    disabled = args.mute
    prev_comet_key = args["comet_key"] if "comet_key" in args.keys() else None

    api_key = os.environ["COMET_API_KEY"]
    workspace = os.environ["COMET_WORKSPACE"]
    if not args.cluster:
        if prev_comet_key is None:
            experiment = comet_ml.Experiment(
                api_key=api_key,
                workspace=workspace,
                project_name=project_name,
                disabled=disabled,
                display_summary_level=0,
            )
            args.comet_key = experiment.get_key()
        else:
            experiment = comet_ml.ExistingExperiment(
                previous_experiment=prev_comet_key,
                api_key=api_key,
                project_name=project_name,
                workspace=workspace,
                disabled=disabled,
                display_summary_level=0,
            )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.add(
            os.path.join(args.log_dir, "train.log"),
            level="INFO",
            colorize=True,
        )
        # note: assumes a CUDA device; get_device_properties fails on CPU
        logger.info(torch.cuda.get_device_properties(device))
        args.gpu = torch.cuda.get_device_properties(device).name
    else:
        experiment = None
    args.experiment = experiment
    return experiment, args


def log_dict(experiment, metric_dict, step, postfix=None):
    """Log scalar entries of metric_dict to Comet; non-scalar values are skipped."""
    if experiment is None:
        return
    for key, value in metric_dict.items():
        if postfix is not None:
            key = key + postfix
        # convert single-element tensors to Python scalars
        if isinstance(value, torch.Tensor) and len(value.view(-1)) == 1:
            value = value.item()
        if isinstance(value, (int, float, np.float32)):
            experiment.log_metric(key, value, step=step)


def generate_exp_key():
    """Generate a random 9-character hex experiment key."""
    import random

    rand_bits = random.getrandbits(128)
    key = "%032x" % rand_bits
    return key[:9]


def push_images(experiment, all_im_list, global_step=None, no_tqdm=False, verbose=True):
    """Upload PIL images (dicts with "im" and optional "fig_name") to Comet."""
    if verbose:
        print("Pushing PIL images")
        tic = time.time()
    iterator = all_im_list if no_tqdm else tqdm(all_im_list)
    for im in iterator:
        im_np = np.array(im["im"])
        if "fig_name" in im.keys():
            experiment.log_image(im_np, im["fig_name"], step=global_step)
        else:
            experiment.log_image(im_np, "unnamed", step=global_step)
    if verbose:
        toc = time.time()
        print("Done pushing PIL images (%.1fs)" % (toc - tic))