File size: 4,782 Bytes
d6d3a5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import json
import os
import os.path as op
import time

import comet_ml
import numpy as np
import torch
from loguru import logger
from tqdm import tqdm

from src.datasets.dataset_utils import copy_repo_arctic

# folder used for debugging
DUMMY_EXP = "xxxxxxxxx"


def add_paths(args):
    """Attach derived filesystem paths (checkpoint, log dir, args file, ...) to args."""
    log_dir = f"./logs/{args.exp_key}"

    # Resume from the last checkpoint if one is on disk; the dummy/debug
    # experiment never resumes.
    last_ckpt = f"{log_dir}/checkpoints/last.ckpt"
    if DUMMY_EXP in last_ckpt or not op.exists(last_ckpt):
        last_ckpt = ""
    if args.resume_ckpt != "":
        # An explicitly requested resume checkpoint always wins.
        last_ckpt = args.resume_ckpt
    args.ckpt_p = last_ckpt
    args.log_dir = log_dir

    if args.infer_ckpt != "":
        # "logs/<key>/.../x.ckpt" -> interface params saved as "logs/<key>/x.params.pt"
        base_dir = "/".join(args.infer_ckpt.split("/")[:2])
        params_name = op.basename(args.infer_ckpt).replace(".ckpt", ".params.pt")
        args.interface_p = op.join(base_dir, params_name)
    args.args_p = f"{log_dir}/args.json"

    if args.cluster:
        # Condor submission artifacts for cluster runs.
        args.run_p = op.join(log_dir, "condor", "run.sh")
        args.submit_p = op.join(log_dir, "condor", "submit.sub")
        args.repo_p = op.join(log_dir, "repo")

    return args


def save_args(args, save_keys):
    """Write the subset of args whose keys are in save_keys to args.args_p as JSON."""
    # args is dict-like here (supports .items()) while also exposing
    # attribute access for args.args_p.
    to_save = {key: val for key, val in args.items() if key in save_keys}
    with open(args.args_p, "w") as f:
        json.dump(to_save, f, indent=4)
    logger.info(f"Saved args at {args.args_p}")


def create_files(args):
    """Create the experiment's on-disk folders; cluster runs also snapshot the repo."""
    os.makedirs(args.log_dir, exist_ok=True)
    if not args.cluster:
        return
    os.makedirs(op.dirname(args.run_p), exist_ok=True)
    copy_repo_arctic(args.exp_key)

def log_exp_meta(args):
    """Name the comet experiment, tag it with the method, and log all hyper-parameters."""
    experiment = args.experiment
    exp_tags = [args.method]
    logger.info(f"Experiment tags: {exp_tags}")
    experiment.set_name(args.exp_key)
    experiment.add_tags(exp_tags)
    experiment.log_parameters(args)


def init_experiment(args):
    """Resolve the experiment key/paths and create (or re-attach to) a comet.ml run.

    Returns (experiment, args); `experiment` is None on cluster runs (where comet
    logging is skipped). The experiment object is also stored on args.experiment.

    NOTE(review): args is dict-like with attribute access (it supports both
    args.exp_key and args.keys()) -- confirm the exact container type with callers.
    """
    # Resuming: recover the experiment key from the checkpoint path,
    # e.g. "logs/<exp_key>/checkpoints/last.ckpt" -> "<exp_key>".
    if args.resume_ckpt != "":
        args.exp_key = args.resume_ckpt.split("/")[1]
    # Fast-dev runs always land in the dummy/debug experiment folder.
    if args.fast_dev_run:
        args.exp_key = DUMMY_EXP
    if args.exp_key == "":
        args.exp_key = generate_exp_key()
    args = add_paths(args)
    # If this (non-dummy) experiment already has an args.json on disk, recover
    # its comet key so logging continues in the same comet experiment.
    if op.exists(args.args_p) and args.exp_key not in [DUMMY_EXP]:
        with open(args.args_p, "r") as f:
            args_disk = json.load(f)
            if "comet_key" in args_disk.keys():
                args.comet_key = args_disk["comet_key"]

    create_files(args)

    project_name = args.project
    disabled = args.mute
    # A recovered comet key means we attach to an existing comet experiment.
    comet_url = args["comet_key"] if "comet_key" in args.keys() else None

    # Raises KeyError if the comet credentials are not exported in the environment.
    api_key = os.environ["COMET_API_KEY"]
    workspace = os.environ["COMET_WORKSPACE"]
    if not args.cluster:
        if comet_url is None:
            # Fresh run: create a new comet experiment and remember its key.
            experiment = comet_ml.Experiment(
                api_key=api_key,
                workspace=workspace,
                project_name=project_name,
                disabled=disabled,
                display_summary_level=0,
            )
            args.comet_key = experiment.get_key()
        else:
            # Continue streaming metrics into the previous comet experiment.
            experiment = comet_ml.ExistingExperiment(
                previous_experiment=comet_url,
                api_key=api_key,
                project_name=project_name,
                workspace=workspace,
                disabled=disabled,
                display_summary_level=0,
            )

        # Mirror logger output into the run's train.log and record which GPU
        # (if any) this run uses. NOTE(review): get_device_properties is called
        # even when device == "cpu", which would raise on a CPU-only machine.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.add(
            os.path.join(args.log_dir, "train.log"),
            level="INFO",
            colorize=True,
        )
        logger.info(torch.cuda.get_device_properties(device))
        args.gpu = torch.cuda.get_device_properties(device).name
    else:
        # Cluster jobs do their own logging; no comet experiment is created here.
        experiment = None
    args.experiment = experiment
    return experiment, args


def log_dict(experiment, metric_dict, step, postfix=None):
    """Push scalar metrics from metric_dict to the comet experiment at `step`.

    Single-element tensors are unwrapped via .item(); non-scalar or non-numeric
    values are skipped silently. No-op when experiment is None (cluster runs).
    """
    if experiment is None:
        return
    for name, val in metric_dict.items():
        metric_name = name if postfix is None else name + postfix
        if isinstance(val, torch.Tensor) and val.numel() == 1:
            val = val.item()
        if isinstance(val, (int, float, np.float32)):
            experiment.log_metric(metric_name, val, step=step)


def generate_exp_key():
    """Return a random 9-character lowercase-hex experiment key.

    Draws 128 random bits, formats them as 32 zero-padded hex digits, and keeps
    the first 9 characters.
    """
    import random

    # Renamed from `hash` to avoid shadowing the builtin of the same name.
    rand_bits = random.getrandbits(128)
    return f"{rand_bits:032x}"[:9]


def push_images(experiment, all_im_list, global_step=None, no_tqdm=False, verbose=True):
    """Upload images (dicts with "im" and optional "fig_name") to the experiment."""
    if verbose:
        print("Pushing PIL images")
        start = time.time()
    # Optionally wrap the list in a tqdm progress bar.
    iterable = all_im_list if no_tqdm else tqdm(all_im_list)
    for entry in iterable:
        pixels = np.array(entry["im"])
        fig_name = entry["fig_name"] if "fig_name" in entry.keys() else "unnamed"
        experiment.log_image(pixels, fig_name, step=global_step)
    if verbose:
        elapsed = time.time() - start
        print("Done pushing PIL images (%.1fs)" % elapsed)