|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
import re |
|
from collections import OrderedDict |
|
from datetime import datetime |
|
|
|
import numpy as np |
|
import pytz |
|
import torch.distributed as dist |
|
from mmcv.utils.logging import logger_initialized |
|
from termcolor import colored |
|
|
|
from .dist_utils import is_local_master |
|
|
|
|
|
def get_root_logger(
    log_file=None, log_level=logging.INFO, name=colored("[Sana]", attrs=["bold"]), timezone="Asia/Shanghai"
):
    """Fetch the project-wide root logger.

    Args:
        log_file (str, optional): Path of the log file. Defaults to None, in
            which case file output is discarded (``/dev/null``).
        log_level (int, optional): Logging level. Defaults to ``logging.INFO``.
        name (str): Logger name (a bold, colored "[Sana]" tag by default).
        timezone (str): Timezone used for log timestamps.

    Returns:
        :obj:`logging.Logger`: The obtained logger.
    """
    # Route file output into the void when no path was given.
    target_file = "/dev/null" if log_file is None else log_file
    return get_logger(name=name, log_file=target_file, log_level=log_level, timezone=timezone)
|
|
|
|
|
class TimezoneFormatter(logging.Formatter):
    """``logging.Formatter`` that renders timestamps in a fixed timezone.

    Args:
        fmt (str, optional): Log message format string.
        datefmt (str, optional): ``strftime`` format for timestamps.
        tz (str, optional): Timezone name (e.g. "Asia/Shanghai"). When
            omitted, timestamps fall back to the local timezone.
    """

    def __init__(self, fmt=None, datefmt=None, tz=None):
        super().__init__(fmt, datefmt)
        # pytz is only consulted when an explicit timezone was requested.
        self.tz = None if not tz else pytz.timezone(tz)

    def formatTime(self, record, datefmt=None):
        """Return ``record.created`` rendered in the configured timezone."""
        stamp = datetime.fromtimestamp(record.created, self.tz)
        return stamp.strftime(datefmt) if datefmt else stamp.isoformat()
|
|
|
|
|
def get_logger(name, log_file=None, log_level=logging.INFO, timezone="UTC"):
    """Initialize and get a logger by name.

    A first call for a given name attaches a StreamHandler (always) and a
    FileHandler (only on rank 0, and only when `log_file` is given). Later
    calls for the same name — or a child of an initialized name — return the
    existing logger untouched.

    Args:
        name (str): Logger name.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the logger.
        log_level (int): The logger level. Note that only the process of
            rank 0 is affected, and other processes will set the level to
            "Error" thus be silent most of the time.
        timezone (str): Timezone for the log timestamps.

    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    # Keep messages from bubbling up to ancestor loggers (avoids duplicates).
    logger.propagate = False

    # Already configured — either this exact name or an initialized prefix of
    # it (child loggers inherit the parent's handlers) — so reuse as-is.
    if name in logger_initialized or any(name.startswith(prefix) for prefix in logger_initialized):
        return logger

    handlers = [logging.StreamHandler()]

    # Rank is 0 outside of (or before) distributed initialization.
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0

    # Only the global rank-0 process writes the log file (append mode).
    if rank == 0 and log_file is not None:
        handlers.append(logging.FileHandler(log_file, "a"))

    formatter = TimezoneFormatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", tz=timezone
    )
    for handler in handlers:
        handler.setFormatter(formatter)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-(local-)master ranks are silenced down to ERROR.
    logger.setLevel(log_level if is_local_master() else logging.ERROR)

    logger_initialized[name] = True

    return logger
|
|
|
|
|
def rename_file_with_creation_time(file_path):
    """Rename a file by appending its creation timestamp to the stem.

    For example ``/tmp/train.log`` becomes
    ``/tmp/train_2024-01-02_03-04-05.log``.

    Args:
        file_path (str): Path of the file to rename.

    Returns:
        str: The file's new path after renaming.
    """
    # NOTE(review): on Linux, getctime() is the inode-change time rather than
    # the true creation time — acceptable for log rotation, but worth knowing.
    stamp = datetime.fromtimestamp(os.path.getctime(file_path)).strftime("%Y-%m-%d_%H-%M-%S")

    directory, base_name = os.path.split(file_path)
    stem, suffix = os.path.splitext(base_name)
    renamed_path = os.path.join(directory, f"{stem}_{stamp}{suffix}")

    os.rename(file_path, renamed_path)

    return renamed_path
|
|
|
|
|
# NOTE(review): this class is a byte-for-byte duplicate of the
# TimezoneFormatter defined earlier in this module; because the names match,
# this second definition silently rebinds the earlier one. Consider deleting
# one of the two copies.
class TimezoneFormatter(logging.Formatter):
    """``logging.Formatter`` that renders timestamps in a given timezone."""

    def __init__(self, fmt=None, datefmt=None, tz=None):
        super().__init__(fmt, datefmt)
        # Resolve the timezone name once; None means "use local time".
        self.tz = pytz.timezone(tz) if tz else None

    def formatTime(self, record, datefmt=None):
        """Return ``record.created`` formatted in ``self.tz``."""
        dt = datetime.fromtimestamp(record.created, self.tz)
        if datefmt:
            s = dt.strftime(datefmt)
        else:
            s = dt.isoformat()
        return s
|
|
|
|
|
class LogBuffer:
    """Accumulates scalar values per key and computes weighted averages."""

    def __init__(self):
        # Per-key lists of recorded values and their sample counts.
        self.val_history = OrderedDict()
        self.n_history = OrderedDict()
        # Latest averaged results; only meaningful while ``ready`` is True.
        self.output = OrderedDict()
        self.ready = False

    def clear(self) -> None:
        """Drop all recorded history along with any computed averages."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self) -> None:
        """Discard the averaged results without touching the history."""
        self.output.clear()
        self.ready = False

    def update(self, vars: dict, count: int = 1) -> None:
        """Record one batch of scalar values, each weighted by ``count``."""
        assert isinstance(vars, dict)
        for key, value in vars.items():
            self.val_history.setdefault(key, [])
            self.n_history.setdefault(key, [])
            self.val_history[key].append(value)
            self.n_history[key].append(count)

    def average(self, n: int = 0) -> None:
        """Average latest n values or all values."""
        assert n >= 0
        for key, history in self.val_history.items():
            # n == 0 slices the whole history ([-0:] is the full list).
            values = np.array(history[-n:])
            counts = np.array(self.n_history[key][-n:])
            self.output[key] = np.sum(values * counts) / np.sum(counts)
        self.ready = True
|
|
|
|
|
def tracker(args, result_dict, label="", pattern="epoch_step", metric="FID"):
    """Push evaluation metrics (e.g. FID) to an experiment tracker.

    Only Weights & Biases (``args.report_to == "wandb"``) is supported; any
    other backend merely prints a notice.

    Args:
        args: Experiment config; must expose ``report_to``, ``log_metric``,
            ``name`` and ``tracker_project_name``.
        result_dict (dict): Maps experiment names (which encode the x-axis
            value, e.g. "...step20_scale..." or "...step20_size...") to
            metric values.
        label (str): Suffix appended to the metric key logged to wandb.
        pattern (str): Which x-axis to extract from experiment names:
            "step" selects sampling steps, "epoch_step" selects training
            steps; any other value is used as a literal regex prefix.
        metric (str): Metric name used in the wandb key.
    """
    if args.report_to != "wandb":
        print(f"{args.report_to} is not supported")
        return

    import wandb

    wandb_name = f"[{args.log_metric}]_{args.name}"
    # NOTE(review): wandb expects ``tags`` to be a sequence of strings;
    # the bare string "metrics" may not be treated as a single tag — verify.
    wandb.init(project=args.tracker_project_name, name=wandb_name, resume="allow", id=wandb_name, tags="metrics")
    run = wandb.run

    # Map the user-facing pattern names onto the internal x-axis keys.
    if pattern == "step":
        pattern = "sample_steps"
    elif pattern == "epoch_step":
        pattern = "step"
    custom_name = f"custom_{pattern}"
    run.define_metric(custom_name)

    # Plot the metric against the custom x-axis instead of wandb's own step.
    run.define_metric(f"{metric}_{label}", step_metric=custom_name)

    def extract_value(regex, exp_name):
        # First capture group of the regex, or "unknown" when there is no match.
        match = re.search(regex, exp_name)
        return match.group(1) if match else "unknown"

    steps = []
    results = []
    for exp_name, result_value in result_dict.items():
        if pattern == "step":
            custom_x = extract_value(r".*step(\d+)_scale.*", exp_name)
        elif pattern == "sample_steps":
            custom_x = extract_value(r".*step(\d+)_size.*", exp_name)
        else:
            custom_x = extract_value(rf"{pattern}(\d+(\.\d+)?)", exp_name)
            # Free-form patterns default to x=1 when nothing matched.
            custom_x = 1 if custom_x == "unknown" else custom_x

        assert custom_x != "unknown"
        steps.append(float(custom_x))
        results.append(result_value)

    # Sort once and log in ascending x order so wandb draws a monotone curve.
    # (Previously the pairs were sorted, unpacked, and sorted a second time,
    # and an empty result_dict crashed on ``zip(*sorted_data)``.)
    for step, result in sorted(zip(steps, results)):
        run.log({f"{metric}_{label}": result, custom_name: step})
|
|