import os
import sys
import warnings
from importlib.util import find_spec
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple

import gdown
import matplotlib.pyplot as plt
import numpy as np
import torch
import wget
from omegaconf import DictConfig

from matcha.utils import pylogger, rich_utils

log = pylogger.get_pylogger(__name__)


def extras(cfg: DictConfig) -> None:
    """Applies optional utilities before the task is started.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    if cfg.extras.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    if cfg.extras.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    if cfg.extras.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)

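# A minimal sketch of the `extras` config consumed above; the key names come from the
# `cfg.extras.get(...)` lookups, the values are illustrative:
#
#   extras:
#     ignore_warnings: False
#     enforce_tags: True
#     print_config: True

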
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function.
    """

    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        except Exception as ex:
            # log the full exception traceback
            log.exception("")

            # re-raise; when using hparam search plugins like Optuna you might want to
            # suppress this to avoid failing the whole multirun
            raise ex

        # things to always do, whether the task succeeded or failed
        finally:
            # display the output dir path in the terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close the wandb run so a crash does not break the multirun
            if find_spec("wandb"):
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap


def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves the value of the metric logged in LightningModule.

    :param metric_dict: A dict containing metric values.
    :param metric_name: The name of the metric to retrieve.
    :return: The value of the metric, or None if no metric name was given.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise ValueError(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    metric_value = metric_dict[metric_name].item()
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value

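# Illustrative usage, assuming the run logged a metric named "val/loss" (e.g. via the
# Lightning Trainer's `callback_metrics` dict):
#   metric_value = get_metric_value(trainer.callback_metrics, "val/loss")

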
def intersperse(lst, item):
    """Insert `item` between (and around) every element of `lst`."""
    # pre-allocate the output, then drop the original elements into the odd indices
    result = [item] * (len(lst) * 2 + 1)
    result[1::2] = lst
    return result

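# Example: interspersing an item (e.g. a blank token id) between phoneme ids:
#   intersperse([5, 9, 2], 0)  ->  [0, 5, 0, 9, 0, 2, 0]

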
def save_figure_to_numpy(fig):
    """Convert a drawn matplotlib figure into an (H, W, 3) uint8 numpy array."""
    # np.fromstring is deprecated for binary data; frombuffer reads the rendered RGB canvas directly
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_tensor(tensor):
    """Render a 2D tensor (e.g. a mel spectrogram) and return the figure as a numpy image."""
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


def save_plot(tensor, savepath):
    """Render a 2D tensor (e.g. a mel spectrogram) and save the figure to `savepath`."""
    plt.style.use("default")
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(tensor, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.tight_layout()
    fig.canvas.draw()
    plt.savefig(savepath)
    plt.close()

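# Illustrative usage; `mel` stands for a hypothetical 2D CPU tensor of shape (n_mels, n_frames):
#   img = plot_tensor(mel)        # returns an (H, W, 3) uint8 array, e.g. for image logging
#   save_plot(mel, "mel.png")     # writes the same figure to disk

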
def to_numpy(tensor):
    """Convert a torch.Tensor, list, or numpy array to a numpy array."""
    if isinstance(tensor, np.ndarray):
        return tensor
    elif isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    elif isinstance(tensor, list):
        return np.array(tensor)
    else:
        raise TypeError("Unsupported type for conversion to numpy array")


def get_user_data_dir(appname="matcha_tts"):
    """Return (and create if necessary) the per-user data directory for the application.

    Args:
        appname (str): Name of application

    Returns:
        Path: path to user data directory
    """
    MATCHA_HOME = os.environ.get("MATCHA_HOME")
    if MATCHA_HOME is not None:
        ans = Path(MATCHA_HOME).expanduser().resolve(strict=False)
    elif sys.platform == "win32":
        import winreg

        key = winreg.OpenKey(
            winreg.HKEY_CURRENT_USER,
            r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders",
        )
        dir_, _ = winreg.QueryValueEx(key, "Local AppData")
        ans = Path(dir_).resolve(strict=False)
    elif sys.platform == "darwin":
        ans = Path("~/Library/Application Support/").expanduser()
    else:
        ans = Path.home().joinpath(".local/share")

    final_path = ans.joinpath(appname)
    final_path.mkdir(parents=True, exist_ok=True)
    return final_path

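# Illustrative result on Linux with MATCHA_HOME unset (the directory is created if missing):
#   get_user_data_dir()  ->  PosixPath('/home/<user>/.local/share/matcha_tts')

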
def assert_model_downloaded(checkpoint_path, url, use_wget=True):
    """Download the checkpoint from `url` to `checkpoint_path` unless it already exists."""
    if Path(checkpoint_path).exists():
        log.debug(f"[+] Model already present at {checkpoint_path}!")
        print(f"[+] Model already present at {checkpoint_path}!")
        return
    log.info(f"[-] Model not found at {checkpoint_path}! Will download it")
    print(f"[-] Model not found at {checkpoint_path}! Will download it")
    checkpoint_path = str(checkpoint_path)
    if not use_wget:
        gdown.download(url=url, output=checkpoint_path, quiet=False, fuzzy=True)
    else:
        wget.download(url=url, out=checkpoint_path)

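# Illustrative usage with a hypothetical checkpoint name and URL:
#   ckpt_path = get_user_data_dir() / "matcha_ljspeech.ckpt"
#   assert_model_downloaded(ckpt_path, "https://example.com/matcha_ljspeech.ckpt")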