import os

import torch
import torch.nn as nn
from torch.nn import init

from util.util import to_device
from .networks import *
from params import *


class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        # input: (T, b, nIn); the bidirectional LSTM yields (T, b, nHidden * 2)
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        # Project every time step down to nOut scores
        output = self.embedding(t_rec)
        output = output.view(T, b, -1)

        return output

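# Shape sketch (illustrative values, not from the original source): a
# BidirectionalLSTM(512, 256, 80) maps a (T, b, 512) feature sequence to
# (T, b, 80) per-step scores, e.g.:
#
#   rnn = BidirectionalLSTM(512, 256, 80)
#   scores = rnn(torch.randn(26, 4, 512))   # scores.shape == (26, 4, 80)

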
class CRNN(nn.Module):

    def __init__(self, leakyRelu=False):
        super(CRNN, self).__init__()
        self.name = 'OCR'

        # Per-layer kernel sizes, paddings, strides, and channel counts
        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()
        nh = 256
        dealwith_lossnone = False  # if True, register a hook that zeroes NaN gradients

        def convRelu(i, batchNormalization=False):
            nIn = 1 if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        # Stack of conv blocks; the asymmetric (2, 1)-stride pools halve the
        # height while roughly preserving the width, so a text-line image is
        # squeezed toward a height-1 feature sequence.
        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(4, True)
        if resolution == 63:
            cnn.add_module('pooling{0}'.format(3),
                           nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(5)
        cnn.add_module('pooling{0}'.format(4),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))
        convRelu(6, True)
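        # Height trace, assuming a 32-pixel-high input (an illustrative
        # assumption; `resolution` comes from params): 32 -> 16 (pool0)
        # -> 8 (pool1) -> 4 (pool2) -> 2 (pool4) -> 1 (conv6, k=2, p=0).
        # The `resolution == 63` branch adds one more halving for taller inputs.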

        self.cnn = cnn
        self.use_rnn = False
        if self.use_rnn:
            # The second BiLSTM must emit VOCAB_SIZE scores per step; the
            # original source left its nOut argument off (a TypeError at
            # construction), so VOCAB_SIZE is filled in here to mirror the
            # linear branch below.
            self.rnn = nn.Sequential(
                BidirectionalLSTM(512, nh, nh),
                BidirectionalLSTM(nh, nh, VOCAB_SIZE))
        else:
            self.linear = nn.Linear(512, VOCAB_SIZE)

        if dealwith_lossnone:
            self.register_backward_hook(self.backward_hook)

        self.device = torch.device('cuda:{}'.format(0))
        self.init = 'N02'

        self = init_weights(self, self.init)

    def forward(self, input):
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        assert h == 1, "the height of conv must be 1"
        # Collapse the height dimension and reorder to (w, b, c) so that the
        # width axis becomes the time axis expected by CTC-style decoders
        conv = conv.squeeze(2)
        conv = conv.permute(2, 0, 1)

        if self.use_rnn:
            output = self.rnn(conv)
        else:
            output = self.linear(conv)
        return output

    def backward_hook(self, module, grad_input, grad_output):
        # Replace NaN gradients with zero (g != g is True exactly for NaNs)
        for g in grad_input:
            g[g != g] = 0

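# Usage sketch (illustrative; assumes params defines resolution and VOCAB_SIZE,
# and that a 32-pixel-high grayscale batch reduces to height 1 as traced above):
#
#   model = CRNN()
#   images = torch.randn(4, 1, 32, 128)   # (b, 1, H, W)
#   logits = model(images)                # (W', 4, VOCAB_SIZE), W' ~ W/4 + 1
#   log_probs = logits.log_softmax(2)     # ready for nn.CTCLoss

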
class OCRLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # '-' renders the CTC blank in raw decoding

        self.dict = {}
        for i, char in enumerate(alphabet):
            # Index 0 is reserved for the CTC blank, so characters start at 1
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        for item in text:
            # Items arrive as UTF-8 bytes; decode before indexing the alphabet
            item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)

        text = result
        return (torch.IntTensor(text), torch.IntTensor(length))

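    # Round-trip sketch (hypothetical alphabet; bytes input, as the loop
    # above expects):
    #
    #   conv = OCRLabelConverter('abc')
    #   t, l = conv.encode([b'ab', b'ca'])   # t = [1, 2, 3, 1], l = [2, 2]
    #   conv.decode(t, l)                    # -> ['ab', 'ca']
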
    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and their lengths do not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                          length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                # CTC collapse: drop blanks (0) and repeated labels
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # Batch mode: split the flat tensor by per-item lengths and recurse
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts

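# Raw vs. collapsed decoding sketch (hypothetical tensors; 0 is the CTC blank):
#
#   conv = OCRLabelConverter('abc')
#   t = torch.IntTensor([1, 1, 0, 2])                # a, a, blank, b
#   conv.decode(t, torch.IntTensor([4]), raw=True)   # -> 'aa-b'
#   conv.decode(t, torch.IntTensor([4]))             # -> 'ab'

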
class strLabelConverter(object):
    """Convert between str and label.

    NOTE:
        Insert `blank` to the alphabet for CTC.

    Args:
        alphabet (str): set of the possible characters.
        ignore_case (bool, default=False): whether or not to ignore all of the case.
    """

    def __init__(self, alphabet, ignore_case=False):
        self._ignore_case = ignore_case
        if self._ignore_case:
            alphabet = alphabet.lower()
        self.alphabet = alphabet + '-'  # '-' renders the CTC blank in raw decoding

        self.dict = {}
        for i, char in enumerate(alphabet):
            # Index 0 is reserved for the CTC blank, so characters start at 1
            self.dict[char] = i + 1

    def encode(self, text):
        """Support batch or single str.

        Args:
            text (str or list of str): texts to convert.

        Returns:
            torch.LongTensor [n, max_length]: encoded texts, zero-padded per batch.
            torch.IntTensor [n]: length of each text.
        """
        length = []
        result = []
        results = []
        for item in text:
            # Items arrive as UTF-8 bytes; decode before indexing the alphabet
            item = item.decode('utf-8', 'strict')
            length.append(len(item))
            for char in item:
                index = self.dict[char]
                result.append(index)
            results.append(result)
            result = []

        # Unlike OCRLabelConverter.encode, return a zero-padded (n, max_length)
        # batch rather than one flat concatenated tensor
        return (torch.nn.utils.rnn.pad_sequence([torch.LongTensor(text) for text in results], batch_first=True),
                torch.IntTensor(length))

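    # Padded-encode sketch (hypothetical alphabet; bytes input):
    #
    #   conv = strLabelConverter('abc')
    #   t, l = conv.encode([b'ab', b'c'])
    #   # t = [[1, 2], [3, 0]] (zero-padded), l = [2, 1]
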
    def decode(self, t, length, raw=False):
        """Decode encoded texts back into strs.

        Note that decode expects the flat concatenated layout (as produced by
        OCRLabelConverter.encode); the `length.sum()` assertion below rejects
        the padded batch returned by this class's encode.

        Args:
            torch.IntTensor [length_0 + length_1 + ... length_{n - 1}]: encoded texts.
            torch.IntTensor [n]: length of each text.

        Raises:
            AssertionError: when the texts and their lengths do not match.

        Returns:
            text (str or list of str): texts to convert.
        """
        if length.numel() == 1:
            length = length[0]
            assert t.numel() == length, "text with length: {} does not match declared length: {}".format(t.numel(),
                                                                                                          length)
            if raw:
                return ''.join([self.alphabet[i - 1] for i in t])
            else:
                # CTC collapse: drop blanks (0) and repeated labels
                char_list = []
                for i in range(length):
                    if t[i] != 0 and (not (i > 0 and t[i - 1] == t[i])):
                        char_list.append(self.alphabet[t[i] - 1])
                return ''.join(char_list)
        else:
            # Batch mode: split the flat tensor by per-item lengths and recurse
            assert t.numel() == length.sum(), "texts with length: {} does not match declared length: {}".format(
                t.numel(), length.sum())
            texts = []
            index = 0
            for i in range(length.numel()):
                l = length[i]
                texts.append(
                    self.decode(
                        t[index:index + l], torch.IntTensor([l]), raw=raw))
                index += l
            return texts
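

# End-to-end CTC training-step sketch (comments only; `alphabet` and `labels`
# are hypothetical names, and params is assumed to define VOCAB_SIZE and
# resolution; targets are byte strings as the converters expect):
#
#   model = CRNN()
#   converter = OCRLabelConverter(alphabet)
#   images = torch.randn(4, 1, 32, 128)
#   log_probs = model(images).log_softmax(2)           # (T, b, VOCAB_SIZE)
#   targets, target_lengths = converter.encode(labels)
#   input_lengths = torch.full((4,), log_probs.size(0), dtype=torch.long)
#   loss = nn.CTCLoss(blank=0)(log_probs, targets, input_lengths, target_lengths)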