|
import argparse |
|
import logging |
|
import os |
|
import glob |
|
import tqdm |
|
import torch |
|
import PIL |
|
import cv2 |
|
import numpy as np |
|
import torch.nn.functional as F |
|
from torchvision import transforms |
|
from utils import Config, Logger, CharsetMapper |
|
|
|
def get_model(config):
    """Instantiate the model class named by ``config.model_name`` in eval mode.

    ``config.model_name`` is a dotted path such as ``package.module.ClassName``.
    Everything before the last dot is imported as a module, and the final
    component is looked up on it. The class is constructed with ``config`` as
    its only argument, logged, and switched to evaluation mode.

    Returns:
        The constructed model, after ``.eval()``.
    """
    import importlib

    *module_parts, class_name = config.model_name.split('.')
    module = importlib.import_module('.'.join(module_parts))
    model_cls = getattr(module, class_name)

    model = model_cls(config)
    logging.info(model)
    return model.eval()
|
|
|
def preprocess(img, width, height):
    """Resize an image and turn it into a normalized single-image batch tensor.

    Args:
        img: image convertible with ``np.array`` (for example a PIL image).
        width: target width, passed to ``cv2.resize`` as part of ``(width, height)``.
        height: target height.

    Returns:
        A float tensor with a leading batch dimension of 1, created by
        ``transforms.ToTensor`` and normalized per channel with the mean/std
        values below (the commonly used ImageNet statistics).
    """
    resized = cv2.resize(np.array(img), (width, height))
    batch = transforms.ToTensor()(resized)[None]
    # Reshape to (3, 1, 1) so the statistics broadcast over height and width.
    channel_mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    channel_std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return (batch - channel_mean) / channel_std
|
|
|
def postprocess(output, charset, model_eval):
    """Greedily decode one model output into text.

    Args:
        output: model output. Either a single dict with a ``'logits'`` entry,
            or a tuple/list of such dicts, each carrying a ``'name'`` key. In
            the tuple/list case, the (last) entry whose name equals
            ``model_eval`` is decoded.
        charset: character mapper exposing ``get_text``, ``null_char`` and
            ``max_length``.
        model_eval: name of the output to decode when ``output`` is a
            tuple/list. Ignored otherwise.

    Returns:
        ``(pt_text, pt_scores, pt_lengths)``, one element per sample:
        - the decoded string, cut at the first ``charset.null_char``;
        - the per-position maximum softmax probability;
        - ``len(text) + 1``, capped at ``charset.max_length``.

    Raises:
        ValueError: if ``output`` is a tuple/list and no entry is named
            ``model_eval``.
    """
    def _select_output(last_output, model_eval):
        """Pick the output dict to decode (the last one named ``model_eval``)."""
        if not isinstance(last_output, (tuple, list)):
            return last_output
        matches = [res for res in last_output if res['name'] == model_eval]
        if not matches:
            # Previously this fell through to an UnboundLocalError.
            available = [res['name'] for res in last_output]
            raise ValueError(
                f'No model output named {model_eval!r}; available: {available}')
        return matches[-1]

    def _decode(logit):
        """Greedy decode: take the arg-max class at every position."""
        # Softmax over dim 2 and argmax over dim 1 of each sample imply logits
        # shaped (batch, positions, classes). Assumed, not checked here.
        probs = F.softmax(logit, dim=2)
        pt_text, pt_scores, pt_lengths = [], [], []
        for prob in probs:
            text = charset.get_text(prob.argmax(dim=1), padding=False, trim=False)
            # Discard everything from the first null character onward.
            text = text.split(charset.null_char)[0]
            pt_text.append(text)
            pt_scores.append(prob.max(dim=1)[0])
            # +1 counts the terminating null character; never exceed max_length.
            pt_lengths.append(min(len(text) + 1, charset.max_length))
        return pt_text, pt_scores, pt_lengths

    selected = _select_output(output, model_eval)
    return _decode(selected['logits'])
|
|
|
def load(model, file, device=None, strict=True):
    """Load weights from a checkpoint file into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the weights.
        file: path to a checkpoint readable by ``torch.load``.
        device: ``None`` (map tensors to CPU), an ``int`` CUDA device index,
            or any other value accepted by ``torch.load``'s ``map_location``.
        strict: forwarded to ``model.load_state_dict``.

    Returns:
        ``model`` itself, with the weights loaded.

    Raises:
        FileNotFoundError: if ``file`` is not an existing regular file.
    """
    if device is None:
        device = 'cpu'
    elif isinstance(device, int):
        device = torch.device('cuda', device)
    # An explicit raise (not an assert) keeps this check under `python -O`.
    if not os.path.isfile(file):
        raise FileNotFoundError(f'Checkpoint file not found: {file}')
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    state = torch.load(file, map_location=device)
    # Checkpoints holding exactly {'model', 'opt'} wrap the weights under
    # 'model' ('opt' is presumably optimizer state); unwrap to the state dict.
    if set(state.keys()) == {'model', 'opt'}:
        state = state['model']
    model.load_state_dict(state, strict=strict)
    return model
|
|
|
def main():
    """Command-line entry point: recognize the text in one or more images.

    Parses CLI arguments, builds the model described by the config file, loads
    the checkpoint, runs the model on every input image and logs the text
    decoded from the output selected by ``--model_eval``.
    """
    # `import PIL` alone does not guarantee the PIL.Image submodule is loaded.
    from PIL import Image

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default='configs/train_abinet.yaml',
                        help='path to config file')
    parser.add_argument('--input', type=str, default='figs/test',
                        help='image directory, or an image path / glob pattern')
    parser.add_argument('--cuda', type=int, default=-1,
                        help='CUDA device index; a negative value runs on CPU')
    parser.add_argument('--checkpoint', type=str,
                        default='workdir/train-abinet/best-train-abinet.pth',
                        help='path to the model checkpoint')
    parser.add_argument('--model_eval', type=str, default='alignment',
                        choices=['alignment', 'vision', 'language'],
                        help='name of the model output to decode')
    args = parser.parse_args()

    config = Config(args.config)
    if args.checkpoint is not None:
        config.model_checkpoint = args.checkpoint
    if args.model_eval is not None:
        config.model_eval = args.model_eval
    config.global_phase = 'test'
    # Weights come from the single checkpoint above, not per-submodule files.
    config.model_vision_checkpoint, config.model_language_checkpoint = None, None
    device = 'cpu' if args.cuda < 0 else f'cuda:{args.cuda}'

    Logger.init(config.global_workdir, config.global_name, config.global_phase)
    Logger.enable_file()
    logging.info(config)

    logging.info('Construct model.')
    model = get_model(config).to(device)
    model = load(model, config.model_checkpoint, device=device)
    charset = CharsetMapper(filename=config.dataset_charset_path,
                            max_length=config.dataset_max_length + 1)

    if os.path.isdir(args.input):
        candidates = [os.path.join(args.input, fname) for fname in os.listdir(args.input)]
    else:
        candidates = glob.glob(os.path.expanduser(args.input))
    # Skip sub-directories and other non-file entries; Image.open cannot read them.
    paths = sorted(p for p in candidates if os.path.isfile(p))
    if not paths:
        parser.error(f'The input path(s) was not found: {args.input}')

    # Pure inference: disable autograd so no computation graph is kept per image.
    with torch.no_grad():
        for path in tqdm.tqdm(paths):
            # Close the file handle promptly; convert() returns an independent copy.
            with Image.open(path) as pil_img:
                img = pil_img.convert('RGB')
            img = preprocess(img, config.dataset_image_width, config.dataset_image_height)
            img = img.to(device)
            res = model(img)
            pt_text, _, _ = postprocess(res, charset, config.model_eval)
            logging.info(f'{path}: {pt_text[0]}')
|
|
|
if __name__ == '__main__':
    # Run the demo only when this file is executed as a script, not on import.
    main()
|
|