Spaces:
Running
on
A10G
Running
on
A10G
import os | |
import re | |
import random | |
import time | |
import torch | |
import torch.nn as nn | |
import logging | |
import numpy as np | |
from os import path as osp | |
def constant_init(module, val, bias=0):
    """Fill a module's ``weight`` and ``bias`` with constant values in place.

    Each of the two attributes is skipped when the module does not have it or
    when it is ``None``.

    Args:
        module (nn.Module): Module whose parameters are overwritten.
        val (float): Constant written into ``module.weight``.
        bias (float): Constant written into ``module.bias``. Default: 0.
    """
    for attr_name, fill_value in (('weight', val), ('bias', bias)):
        param = getattr(module, attr_name, None)
        if param is None:
            continue
        nn.init.constant_(param, fill_value)
# Names of loggers already configured by get_root_logger (used as a set).
initialized_logger = {}


def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger, configuring it on first use.

    The first call for a given ``logger_name`` configures the logger:
    - attaches a StreamHandler;
    - disables propagation to ancestor loggers;
    - sets the logger level to ``log_level``;
    - if ``log_file`` is given, attaches a FileHandler that appends to that
      file.

    Later calls with the same name return the already-configured logger
    unchanged, ignoring ``log_level`` and ``log_file``.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_level (int): Level applied to the logger, and to the file handler
            if one is added. Default: logging.INFO.
        log_file (str | None): The log filename. If specified, a FileHandler
            appending to it is added to the logger. Default: None.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # Already configured by an earlier call: return as-is so handlers are not
    # attached a second time (which would duplicate every message).
    if logger_name in initialized_logger:
        return logger

    format_str = '%(asctime)s %(levelname)s: %(message)s'
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    logger.addHandler(stream_handler)
    logger.propagate = False
    # Set the level unconditionally so ``log_level`` takes effect even without
    # a log file; a NOTSET logger would otherwise inherit its ancestors' level.
    logger.setLevel(log_level)

    if log_file is not None:
        # Append ('a') instead of truncating so logs from previous runs are kept.
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(logging.Formatter(format_str))
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    initialized_logger[logger_name] = True
    return logger
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ | |
torch.__version__)[0][:3])] >= [1, 12, 0] | |
def gpu_is_available():
    """Report whether an accelerator backend is usable.

    Returns True when MPS is available (probed only if IS_HIGH_VERSION is
    true), or when CUDA is available together with cuDNN; otherwise False.
    """
    mps_ready = IS_HIGH_VERSION and torch.backends.mps.is_available()
    if mps_ready:
        return True
    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return True if cuda_ready else False
def get_device(gpu_id=None):
    """Choose the torch device to run on.

    Preference order: MPS (probed only if IS_HIGH_VERSION is true), then CUDA
    (requires cuDNN as well), then CPU.

    Args:
        gpu_id (int | None): Optional device index, appended as ``:<id>`` to
            the MPS/CUDA device string. Ignored for CPU. Default: None.

    Returns:
        torch.device: The selected device.

    Raises:
        TypeError: If ``gpu_id`` is neither None nor an int.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    index_suffix = '' if gpu_id is None else f':{gpu_id}'

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        backend = 'mps'
    elif torch.cuda.is_available() and torch.backends.cudnn.is_available():
        backend = 'cuda'
    else:
        return torch.device('cpu')
    return torch.device(backend + index_suffix)
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and CUDA).

    Args:
        seed (int): Seed value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)
def get_time_str():
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string."""
    now = time.localtime()
    return time.strftime('%Y%m%d_%H%M%S', now)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Files whose names start with '.' are skipped. With ``recursive=True``,
    every subdirectory is descended into, including ones whose names start
    with '.'. Entries that are neither files nor directories (e.g. broken
    symlinks) are ignored.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files. Paths are relative to
        ``dir_path`` unless ``full_path`` is True.

    Raises:
        TypeError: If ``suffix`` is not None, a str, or a tuple. Raised
            immediately at call time, before iteration starts.
    """
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(dir_path, suffix, recursive):
        # ``with`` closes the directory handle even if the consumer stops
        # iterating early (generator close / garbage collection).
        with os.scandir(dir_path) as entries:
            for entry in entries:
                if entry.is_file():
                    if entry.name.startswith('.'):
                        continue  # skip hidden files
                    return_path = entry.path if full_path else osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    # Only descend into directories: calling os.scandir on a
                    # hidden file or a broken symlink would raise.
                    yield from _scandir(entry.path, suffix=suffix, recursive=recursive)

    return _scandir(dir_path, suffix=suffix, recursive=recursive)