""" A modified version of create_dataset.py from the CRNN Torch repository:
https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """

import os

import cv2
import fire
import lmdb
import numpy as np


def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:  # bytes could not be decoded as an image
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True


def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            txn.put(k, v)


def createDataset(inputPath, gtFile, outputPath, checkValid=True):
    """
    Create LMDB dataset for training and evaluation.
    ARGS:
        inputPath  : root folder that the image paths in gtFile are relative to
        gtFile     : ground-truth file; each line holds a tab-separated image path and label
        outputPath : LMDB output path
        checkValid : if true, check the validity of every image
    """
    os.makedirs(outputPath, exist_ok=True)
    env = lmdb.open(outputPath, map_size=1099511627776)  # map_size: 1 TiB
    cache = {}
    cnt = 1

    with open(gtFile, 'r', encoding='utf-8') as data:
        datalist = data.readlines()

    nSamples = len(datalist)
    for i in range(nSamples):
        imagePath, label = datalist[i].strip('\n').split('\t')
        imagePath = os.path.join(inputPath, imagePath)

        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            try:
                if not checkImageIsValid(imageBin):
                    print('%s is not a valid image' % imagePath)
                    continue
            except Exception:
                print('error occurred', i)
                with open(os.path.join(outputPath, 'error_image_log.txt'), 'a') as log:
                    log.write('%s-th image data caused an error\n' % str(i))
                continue

        imageKey = 'image-%09d'.encode() % cnt
        labelKey = 'label-%09d'.encode() % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label.encode()

        # flush the cache to LMDB every 1000 samples to bound memory use
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'.encode()] = str(nSamples).encode()
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)
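

# A minimal read-back sketch (not part of the original script) showing how the
# keys written above can be retrieved; indices are 1-based, and 'outputPath' is
# assumed to be the directory created by createDataset:
#
#   env = lmdb.open(outputPath, readonly=True, lock=False)
#   with env.begin(write=False) as txn:
#       nSamples = int(txn.get('num-samples'.encode()))
#       imageBin = txn.get('image-%09d'.encode() % 1)
#       label = txn.get('label-%09d'.encode() % 1).decode()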


if __name__ == '__main__':
    fire.Fire(createDataset)
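
# Example invocation via fire (the script file name is an assumption; adjust to
# your checkout):
#
#   python create_lmdb_dataset.py --inputPath data/ --gtFile data/gt.txt \
#       --outputPath result/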