Spaces:
Running
on
A10G
Running
on
A10G
File size: 4,637 Bytes
320e465 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import re
import random
import time
import torch
import torch.nn as nn
import logging
import numpy as np
from os import path as osp
def constant_init(module, val, bias=0):
    """Fill a module's weight and bias with constant values.

    Attributes that are missing or set to None are left untouched.

    Args:
        module: Object whose ``weight`` / ``bias`` attributes are filled
            in place via ``nn.init.constant_``.
        val: Constant written into ``module.weight``.
        bias: Constant written into ``module.bias``. Default: 0.
    """
    weight_param = getattr(module, 'weight', None)
    if weight_param is not None:
        nn.init.constant_(weight_param, val)

    bias_param = getattr(module, 'bias', None)
    if bias_param is not None:
        nn.init.constant_(bias_param, bias)
# Names of the loggers that get_root_logger has already configured; used to
# avoid attaching duplicate handlers on repeated calls.
initialized_logger = {}


def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
    """Get the root logger.

    The logger is configured the first time it is requested under a given
    name: a StreamHandler is always added and, if `log_file` is specified, a
    FileHandler (opened in append mode) is added as well. Later calls with
    the same name return the already configured logger unchanged, so their
    `log_level` and `log_file` arguments are ignored.

    Args:
        logger_name (str): root logger name. Default: 'basicsr'.
        log_level (int): The root logger level. It is applied to the logger
            itself and to the file handler. Default: logging.INFO.
        log_file (str | None): The log filename. If specified, a FileHandler
            will be added to the root logger.

    Returns:
        logging.Logger: The root logger.
    """
    logger = logging.getLogger(logger_name)
    # if the logger has been initialized, just return it
    if logger_name in initialized_logger:
        return logger

    formatter = logging.Formatter('%(asctime)s %(levelname)s: %(message)s')

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)
    # This logger owns its handlers; do not forward records to ancestors too.
    logger.propagate = False
    # Apply the requested level even without a log file. Otherwise the logger
    # inherits the root logger's level and `log_level` is silently ignored.
    logger.setLevel(log_level)

    if log_file is not None:
        # Append instead of truncating so that previous logs are kept.
        file_handler = logging.FileHandler(log_file, 'a')
        file_handler.setFormatter(formatter)
        file_handler.setLevel(log_level)
        logger.addHandler(file_handler)

    initialized_logger[logger_name] = True
    return logger
def _torch_version_at_least(version, minimum=(1, 12, 0)):
    """Return True if the leading numeric part of `version` is >= `minimum`.

    Only the leading ``major.minor[.patch]`` digits are parsed, so any
    suffix (pre-release tags, ``.devYYYYMMDD``, ``+cu118``, ``+<hash>``) is
    accepted. A string with no parsable numeric prefix yields False.
    """
    match = re.match(r'\s*(\d+)\.(\d+)(?:\.(\d+))?', str(version))
    if match is None:
        return False
    parsed = tuple(int(part) if part is not None else 0 for part in match.groups())
    return parsed >= tuple(minimum)


# True when the installed torch is at least 1.12.0. Code in this file checks
# this flag before touching `torch.backends.mps`.
IS_HIGH_VERSION = _torch_version_at_least(torch.__version__)
def gpu_is_available():
    """Report whether an accelerator backend can be used.

    Returns:
        bool: True if the MPS backend is available (only queried when
        IS_HIGH_VERSION is set), or if both CUDA and cuDNN are available;
        False otherwise.
    """
    mps_ready = IS_HIGH_VERSION and torch.backends.mps.is_available()
    if mps_ready:
        return True
    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    return bool(cuda_ready)
def get_device(gpu_id=None):
    """Build the torch.device to run on.

    The MPS backend is preferred when IS_HIGH_VERSION is set and MPS is
    available; otherwise CUDA is used when both CUDA and cuDNN are
    available, and the CPU is the final fallback.

    Args:
        gpu_id (int | None): Device index appended to the 'mps' / 'cuda'
            device name as ``:<gpu_id>``. None means no index. Default: None.

    Returns:
        torch.device: The selected device.

    Raises:
        TypeError: If `gpu_id` is neither None nor an int.
    """
    if gpu_id is not None and not isinstance(gpu_id, int):
        raise TypeError('Input should be int value.')
    index_suffix = '' if gpu_id is None else f':{gpu_id}'

    if IS_HIGH_VERSION and torch.backends.mps.is_available():
        return torch.device('mps' + index_suffix)

    cuda_ready = torch.cuda.is_available() and torch.backends.cudnn.is_available()
    device_name = 'cuda' + index_suffix if cuda_ready else 'cpu'
    return torch.device(device_name)
def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed (int): Seed value passed to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def get_time_str():
    """Return the current local time formatted as ``YYYYmmdd_HHMMSS``."""
    timestamp_format = '%Y%m%d_%H%M%S'
    now = time.localtime()
    return time.strftime(timestamp_format, now)
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
    """Scan a directory to find the files of interest.

    Files whose name starts with '.' are never yielded. With
    `recursive=True` every sub-directory is descended into; entries that
    are neither regular files nor directories are skipped.

    Args:
        dir_path (str): Path of the directory.
        suffix (str | tuple(str), optional): File suffix that we are
            interested in. Default: None.
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator over the matching files. Paths are relative to
        `dir_path` unless `full_path` is True.

    Raises:
        TypeError: If `suffix` is neither None, a string, nor a tuple.
    """
    # Validate eagerly: this outer function is not a generator, so a bad
    # `suffix` fails at call time rather than on first iteration.
    if (suffix is not None) and not isinstance(suffix, (str, tuple)):
        raise TypeError('"suffix" must be a string or tuple of strings')

    root = dir_path

    def _scandir(current_dir):
        # The context manager closes the directory handle even if the
        # consumer stops iterating early.
        with os.scandir(current_dir) as entries:
            for entry in entries:
                if entry.is_file():
                    if entry.name.startswith('.'):
                        # Hidden file: skip it. It must never be treated as
                        # a directory to recurse into.
                        continue
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)
                    if suffix is None or return_path.endswith(suffix):
                        yield return_path
                elif recursive and entry.is_dir():
                    yield from _scandir(entry.path)

    return _scandir(dir_path)