import logging
import os
import re
import time

import cv2
import numpy as np
import torch
import yaml
from matplotlib import colors
from matplotlib import pyplot as plt
from torch import Tensor, nn
from torch.utils.data import ConcatDataset


class CharsetMapper(object):
    """A simple class to map ids into strings.

    It works only when the character set is a 1:1 mapping between individual
    characters and individual ids.
    """

    def __init__(self,
                 filename='',
                 max_length=30,
                 null_char=u'\u2591'):
        """Creates a lookup table.

        Args:
            filename: path to the charset file which maps characters to ids.
            max_length: the max length of ids and string.
            null_char: a unicode character used to replace the '<null>' character;
                the default value is a light shade block '░'.
        """
        self.null_char = null_char
        self.max_length = max_length

        self.label_to_char = self._read_charset(filename)
        self.char_to_label = dict(map(reversed, self.label_to_char.items()))
        self.num_classes = len(self.label_to_char)

    def _read_charset(self, filename):
        """Reads a charset definition from a tab-separated text file.

        Args:
            filename: a path to the charset file.

        Returns:
            a dictionary mapping ids (keys) to unicode characters (values).
        """
        pattern = re.compile(r'(\d+)\t(.+)')
        charset = {}
        self.null_label = 0
        charset[self.null_label] = self.null_char
        with open(filename, 'r') as f:
            for i, line in enumerate(f):
                m = pattern.match(line)
                assert m, f'Incorrect charset file. line #{i}: {line}'
                label = int(m.group(1)) + 1  # shift by 1: label 0 is reserved for null
                char = m.group(2)
                charset[label] = char
        return charset

    def trim(self, text):
        assert isinstance(text, str)
        return text.replace(self.null_char, '')

    def get_text(self, labels, length=None, padding=True, trim=False):
        """Returns a string corresponding to a sequence of character ids."""
        length = length if length else self.max_length
        labels = [l.item() if isinstance(l, Tensor) else int(l) for l in labels]
        if padding:
            labels = labels + [self.null_label] * (length - len(labels))
        text = ''.join([self.label_to_char[label] for label in labels])
        if trim:
            text = self.trim(text)
        return text

    def get_labels(self, text, length=None, padding=True, case_sensitive=False):
        """Returns the labels of the corresponding text."""
        length = length if length else self.max_length
        if padding:
            text = text + self.null_char * (length - len(text))
        if not case_sensitive:
            text = text.lower()
        labels = [self.char_to_label[char] for char in text]
        return labels

    def pad_labels(self, labels, length=None):
        length = length if length else self.max_length
        return labels + [self.null_label] * (length - len(labels))

    @property
    def digits(self):
        return '0123456789'

    @property
    def digit_labels(self):
        return self.get_labels(self.digits, padding=False)

    @property
    def alphabets(self):
        all_chars = list(self.char_to_label.keys())
        valid_chars = [c for c in all_chars
                       if c in 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ']
        return ''.join(valid_chars)

    @property
    def alphabet_labels(self):
        return self.get_labels(self.alphabets, padding=False)
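
# Example usage of CharsetMapper (a minimal sketch; the charset path and its
# tab-separated contents, e.g. "0\ta", are assumptions for illustration):
#
#   charset = CharsetMapper('data/charset_36.txt', max_length=25)
#   labels = charset.get_labels('hello', padding=True)  # ids padded with null_label
#   text = charset.get_text(labels, trim=True)          # -> 'hello'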


class Timer(object):
    """A simple timer that tracks data-loading and model-running time separately."""

    def __init__(self):
        self.data_time = 0.
        self.data_diff = 0.
        self.data_total_time = 0.
        self.data_call = 0
        self.running_time = 0.
        self.running_diff = 0.
        self.running_total_time = 0.
        self.running_call = 0

    def tic(self):
        self.start_time = time.time()
        self.running_time = self.start_time

    def toc_data(self):
        self.data_time = time.time()
        self.data_diff = self.data_time - self.running_time
        self.data_total_time += self.data_diff
        self.data_call += 1

    def toc_running(self):
        self.running_time = time.time()
        self.running_diff = self.running_time - self.data_time
        self.running_total_time += self.running_diff
        self.running_call += 1

    def total_time(self):
        return self.data_total_time + self.running_total_time

    def average_time(self):
        return self.average_data_time() + self.average_running_time()

    def average_data_time(self):
        return self.data_total_time / (self.data_call or 1)

    def average_running_time(self):
        return self.running_total_time / (self.running_call or 1)
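
# Example usage of Timer in a training loop (a minimal sketch; `loader` and
# `model` are hypothetical):
#
#   timer = Timer()
#   timer.tic()
#   for batch in loader:
#       timer.toc_data()     # accumulate time spent fetching the batch
#       model(batch)
#       timer.toc_running()  # accumulate time spent in the forward pass
#   print(timer.average_time())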


class Logger(object):
    _handle = None
    _root = None

    @staticmethod
    def init(output_dir, name, phase):
        fmt = '[%(asctime)s %(filename)s:%(lineno)d %(levelname)s {}] ' \
              '%(message)s'.format(name)
        logging.basicConfig(level=logging.INFO, format=fmt)

        os.makedirs(output_dir, exist_ok=True)
        log_path = os.path.join(output_dir, f'{phase}.txt')
        Logger._handle = logging.FileHandler(log_path)
        Logger._root = logging.getLogger()

    @staticmethod
    def enable_file():
        if Logger._handle is None or Logger._root is None:
            raise RuntimeError('Invoke Logger.init() first!')
        Logger._root.addHandler(Logger._handle)

    @staticmethod
    def disable_file():
        if Logger._handle is None or Logger._root is None:
            raise RuntimeError('Invoke Logger.init() first!')
        Logger._root.removeHandler(Logger._handle)
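
# Example usage of Logger (a minimal sketch; the directory and names are
# assumptions for illustration):
#
#   Logger.init('workdir/demo', name='demo', phase='train')
#   Logger.enable_file()   # mirror log records to workdir/demo/train.txt
#   logging.info('training started')
#   Logger.disable_file()  # back to console-only logging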


class Config(object):

    def __init__(self, config_path, host=True):
        def __dict2attr(d, prefix=''):
            for k, v in d.items():
                if isinstance(v, dict):
                    __dict2attr(v, f'{prefix}{k}_')
                else:
                    if k == 'phase':
                        assert v in ['train', 'test']
                    if k == 'stage':
                        assert v in ['pretrain-vision', 'pretrain-language',
                                     'train-semi-super', 'train-super']
                    self.__setattr__(f'{prefix}{k}', v)

        assert os.path.exists(config_path), f'{config_path} does not exist!'
        with open(config_path) as file:
            config_dict = yaml.load(file, Loader=yaml.FullLoader)
        with open('configs/template.yaml') as file:
            default_config_dict = yaml.load(file, Loader=yaml.FullLoader)
        __dict2attr(default_config_dict)
        __dict2attr(config_dict)  # user config overrides the template defaults
        self.global_workdir = os.path.join(self.global_workdir, self.global_name)

    def __getattr__(self, item):
        attr = self.__dict__.get(item)
        if attr is None:
            # Gather flattened keys sharing the prefix back into a dict,
            # e.g. `config.dataset` collects all `dataset_*` attributes.
            attr = dict()
            prefix = f'{item}_'
            for k, v in self.__dict__.items():
                if k.startswith(prefix):
                    n = k.replace(prefix, '')
                    attr[n] = v
            return attr if len(attr) > 0 else None
        else:
            return attr

    def __repr__(self):
        s = 'ModelConfig(\n'
        for i, (k, v) in enumerate(sorted(vars(self).items())):
            s += f'\t({i}): {k} = {v}\n'
        s += ')'
        return s
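
# Example usage of Config (a minimal sketch; the YAML path is an assumption,
# and 'configs/template.yaml' must exist since __init__ reads it for defaults):
#
#   config = Config('configs/train.yaml')
#   config.global_workdir  # flattened from the nested YAML keys
#   config.dataset         # __getattr__ gathers all `dataset_*` attrs into a dict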


def blend_mask(image, mask, alpha=0.5, cmap='jet', color='b', color_alpha=1.0):
    # Normalize the mask to [0, 1]; eps guards against a constant mask.
    mask = (mask - mask.min()) / (mask.max() - mask.min() + np.finfo(float).eps)
    if mask.shape != image.shape[:2]:  # compare spatial dims only
        mask = cv2.resize(mask, (image.shape[1], image.shape[0]))

    # Map the scalar mask through a matplotlib colormap, dropping alpha.
    color_map = plt.get_cmap(cmap)
    mask = color_map(mask)[:, :, :3]
    mask = (mask * 255).astype(dtype=np.uint8)

    # Tint the image with a solid base color before overlaying the heatmap.
    basic_color = np.array(colors.to_rgb(color)) * 255
    basic_color = np.tile(basic_color, [image.shape[0], image.shape[1], 1])
    basic_color = basic_color.astype(dtype=np.uint8)
    blended_img = cv2.addWeighted(image, color_alpha, basic_color, 1 - color_alpha, 0)

    blended_img = cv2.addWeighted(blended_img, alpha, mask, 1 - alpha, 0)
    return blended_img
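
# Example usage of blend_mask (a minimal sketch; the image path is an
# assumption, and the random mask just exercises the function):
#
#   image = cv2.imread('demo.jpg')  # H x W x 3, uint8
#   attention = np.random.rand(image.shape[0] // 4, image.shape[1] // 4)
#   overlay = blend_mask(image, attention, alpha=0.5)
#   cv2.imwrite('demo_overlay.jpg', overlay)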


def onehot(label, depth, device=None):
    """
    Args:
        label: Tensor (or array-like) of shape (n1, n2, ...)
        depth: a scalar, the number of classes

    Returns:
        onehot: Tensor of shape (n1, n2, ..., depth)
    """
    if not isinstance(label, torch.Tensor):
        label = torch.tensor(label, device=device)
    # Scatter 1s into a new trailing class dimension.
    onehot = torch.zeros(label.size() + torch.Size([depth]), device=device)
    onehot = onehot.scatter_(-1, label.unsqueeze(-1), 1)
    return onehot
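
# Example usage of onehot (a minimal sketch):
#
#   labels = torch.tensor([[0, 2], [1, 0]])
#   onehot(labels, depth=3).shape  # -> torch.Size([2, 2, 3])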


class MyDataParallel(nn.DataParallel):

    def gather(self, outputs, target_device):
        r"""
        Gathers tensors from different GPUs on a specified device
        (-1 means the CPU). Unlike the stock gather, this version also
        passes through strings, numbers, and lists of strings.
        """
        def gather_map(outputs):
            out = outputs[0]
            if isinstance(out, (str, int, float)):
                return out
            if isinstance(out, list) and isinstance(out[0], str):
                return [o for out in outputs for o in out]
            if isinstance(out, torch.Tensor):
                return torch.nn.parallel._functions.Gather.apply(target_device, self.dim, *outputs)
            if out is None:
                return None
            if isinstance(out, dict):
                if not all((len(out) == len(d) for d in outputs)):
                    raise ValueError('All dicts must have the same number of keys')
                return type(out)(((k, gather_map([d[k] for d in outputs]))
                                  for k in out))
            return type(out)(map(gather_map, zip(*outputs)))

        # Setting the recursive closure to None afterwards breaks the
        # reference cycle, so an exception does not keep tensors alive.
        try:
            res = gather_map(outputs)
        finally:
            gather_map = None
        return res
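
# Example usage of MyDataParallel (a minimal sketch; `model` and `images`
# are hypothetical):
#
#   model = MyDataParallel(model, device_ids=[0, 1])
#   output = model(images)  # string/dict outputs now gather without errors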


class MyConcatDataset(ConcatDataset):
    def __getattr__(self, k):
        # Delegate attributes missing on the concat wrapper (e.g. a charset)
        # to the first sub-dataset.
        return getattr(self.datasets[0], k)
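
# Example usage of MyConcatDataset (a minimal sketch; `MyDataset` is a
# hypothetical dataset class exposing a `charset` attribute):
#
#   dataset = MyConcatDataset([MyDataset('a.lmdb'), MyDataset('b.lmdb')])
#   dataset.charset  # forwarded from dataset.datasets[0]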