Spaces:
Runtime error
Runtime error
import glob | |
import os | |
import subprocess | |
import warnings | |
from argparse import Namespace | |
from importlib.util import find_spec | |
from pathlib import Path | |
from typing import Callable | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import soundfile as sf | |
import torch | |
from omegaconf import DictConfig | |
from diff_ttsg.utils import pylogger, rich_utils | |
# Module-level logger from the project's `pylogger` helper, named after this module;
# the utilities below report progress and errors through it.
log = pylogger.get_pylogger(__name__)
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities enabled under ``cfg.extras``.

    Each utility is controlled by a boolean-like flag in ``cfg.extras`` and runs
    in this fixed order when its flag is truthy:

    - ``ignore_warnings``: silence all Python warnings;
    - ``enforce_tags``: prompt for run tags on the command line if the config
      provides none (delegated to ``rich_utils.enforce_tags``);
    - ``print_config``: pretty-print the resolved config tree with Rich.

    When ``cfg.extras`` is missing or empty, a warning is logged and nothing runs.

    Args:
        cfg: The full run configuration.
    """
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # (flag name, announcement, action) -- evaluated in order; the flag is re-read
    # from `cfg.extras` right before its action, exactly as a chain of `if`s would.
    pipeline = (
        (
            "ignore_warnings",
            "Disabling python warnings! <cfg.extras.ignore_warnings=True>",
            lambda: warnings.filterwarnings("ignore"),
        ),
        (
            "enforce_tags",
            "Enforcing tags! <cfg.extras.enforce_tags=True>",
            lambda: rich_utils.enforce_tags(cfg, save_to_file=True),
        ),
        (
            "print_config",
            "Printing config tree with Rich! <cfg.extras.print_config=True>",
            lambda: rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True),
        ),
    )
    for flag_name, announcement, run_step in pipeline:
        if cfg.extras.get(flag_name):
            log.info(announcement)
            run_step()
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    The decorated ``task_func`` is called as ``task_func(cfg=cfg)`` and must return a
    ``(metric_dict, object_dict)`` tuple, which the wrapper passes through unchanged.
    Around that call the wrapper:

    - logs any exception raised by the task, with its traceback, via ``log.exception``,
      and then re-raises it;
    - always logs the output directory ``cfg.paths.output_dir``, on success or failure;
    - always finishes the active wandb run when wandb is installed and a run is open,
      so a failing task does not leave the run dangling (prevents multirun failure).

    The returned wrapper keeps ``task_func``'s ``__name__``, ``__doc__`` and exposes it
    as ``__wrapped__`` (via ``functools.wraps``).

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[dict, dict]:
        ...
        return metric_dict, object_dict
    ```

    Args:
        task_func: The task entry point to wrap.

    Returns:
        The wrapping function, which takes a single ``cfg`` argument.
    """
    from functools import wraps

    @wraps(task_func)
    def wrap(cfg: DictConfig):
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # log the exception together with its traceback
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure.
            # A bare `raise` re-raises the original exception with its traceback intact.
            raise

        # things to always do after either success or exception
        finally:
            # NOTE(review): if accessing `cfg.paths.output_dir` raises, that error supersedes
            # the task's exception (which then survives only as `__context__`).
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
def get_metric_value(metric_dict: dict, metric_name: str) -> "float | None":
    """Safely retrieves value of the metric logged in LightningModule.

    Args:
        metric_dict: Mapping from metric name to logged value. Values exposing
            ``.item()`` (e.g. 0-d tensors, NumPy scalars) are converted with it;
            plain Python numbers are returned as-is.
        metric_name: Key to look up. When falsy (``None`` or ``""``), retrieval
            is skipped.

    Returns:
        The metric value, or ``None`` when ``metric_name`` is falsy.

    Raises:
        Exception: If ``metric_name`` is not a key of ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    raw_value = metric_dict[metric_name]
    # Logged metrics are usually 0-d tensors; `.item()` turns them into Python scalars.
    # Plain numbers have no `.item()` and are used directly instead of crashing.
    metric_value = raw_value.item() if hasattr(raw_value, "item") else raw_value
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after every element of ``lst``.

    Used to add a blank symbol around each input symbol, e.g.
    ``intersperse([1, 2], 0) == [0, 1, 0, 2, 0]``.
    """
    # Pre-size with the filler, then drop each element into the odd positions.
    interleaved = [item] * (2 * len(lst) + 1)
    for position, element in enumerate(lst):
        interleaved[2 * position + 1] = element
    return interleaved
def parse_filelist(filelist_path, split_char="|"):
    """Read a UTF-8 filelist and split every line into fields.

    Each line is stripped of surrounding whitespace and split on ``split_char``;
    the result is one list of string fields per line, in file order (blank lines
    yield ``[""]``).
    """
    rows = []
    with open(filelist_path, encoding='utf-8') as handle:
        for raw_line in handle:
            rows.append(raw_line.strip().split(split_char))
    return rows
def save_figure_to_numpy(fig):
    """Convert an already-drawn Matplotlib figure into an RGB ``uint8`` array.

    The caller must have rendered the figure first (``fig.canvas.draw()``), and the
    canvas must be Agg-based (it must provide ``buffer_rgba()``).

    Returns:
        A new, C-contiguous ``uint8`` array of shape ``(height, width, 3)`` in
        pixel units, independent of the canvas' internal buffer.
    """
    # `tostring_rgb` + `np.fromstring` are deprecated/removed in current Matplotlib and
    # NumPy. `buffer_rgba()` exposes the rendered pixels already shaped (H, W, 4) in
    # physical pixels, so no manual reshape from the logical canvas size is needed.
    rgba = np.asarray(fig.canvas.buffer_rgba())
    # Drop alpha; ascontiguousarray copies the strided slice, detaching it from the canvas.
    return np.ascontiguousarray(rgba[:, :, :3])
def plot_tensor(tensor):
    """Render ``tensor`` as a heatmap with a colorbar and return the image as a NumPy array.

    Args:
        tensor: Array-like accepted by ``Axes.imshow`` (typically 2-D; presumably a
            spectrogram or alignment matrix -- TODO confirm with callers). Row 0 is
            drawn at the bottom (``origin="lower"``).

    Returns:
        The rendered figure as an RGB ``uint8`` array of shape (height, width, 3),
        as produced by ``save_figure_to_numpy``.

    Note:
        Calls ``plt.style.use('default')``, which resets the global Matplotlib style.
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        # Operate on `fig` explicitly instead of pyplot's implicit "current figure".
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
    finally:
        # Close this exact figure even on failure so repeated calls don't accumulate
        # open figures (pyplot keeps every unclosed figure alive).
        plt.close(fig)
    return data
def save_plot(tensor, savepath):
    """Render ``tensor`` as a heatmap with a colorbar and save it to ``savepath``.

    Args:
        tensor: Array-like accepted by ``Axes.imshow`` (typically 2-D; presumably a
            spectrogram or alignment matrix -- TODO confirm with callers). Row 0 is
            drawn at the bottom (``origin="lower"``).
        savepath: Destination passed to ``Figure.savefig``; the output format is
            inferred by Matplotlib from the file extension.

    Note:
        Calls ``plt.style.use('default')``, which resets the global Matplotlib style.
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation='none')
        # Operate on `fig` explicitly instead of pyplot's implicit "current figure".
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # `savefig` renders the figure itself, so no separate `canvas.draw()` is needed.
        fig.savefig(savepath)
    finally:
        # Always release the figure, even if rendering or writing the file fails.
        plt.close(fig)