# Provenance (repository page header, not code): jacklishufan, "init commit", 844f7c0
import os
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
from os import path, makedirs, listdir
import sys
import numpy as np
np.random.seed(1)
import random
random.seed(1)
import timeit
import cv2
from skimage.morphology import square, dilation
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
# Folders holding the classification predictions that get averaged together.
# Every backbone prefix appears in three numbered "_tuned" variants (0, 1, 2),
# listed backbone by backbone.
_cls_backbones = ('dpn92cls', 'res34cls2', 'res50cls', 'se154cls')
pred_folders = [
    '{0}_{1}_tuned'.format(backbone, variant)
    for backbone in _cls_backbones
    for variant in range(3)
]

# Folders holding the localization predictions that get averaged together.
loc_folders = [
    'pred50_loc_tuned',
    'pred92_loc_tuned',
    'pred34_loc',
    'pred154_loc',
]

# Cut-offs applied to the averaged localization score:
#   _thr[0] - accepted regardless of the predicted class
#   _thr[1] - accepted only where the predicted class is 2 or 3
#   _thr[2] - accepted only where the predicted class is greater than 1
_thr = [0.38, 0.13, 0.14]
def _read_mask(mask_path):
    """Read an image file unchanged via OpenCV, failing loudly if it cannot be read.

    cv2.imread reports a missing or undecodable file by returning None instead
    of raising; without this check the failure would only show up later as an
    unrelated-looking NumPy error.

    Raises:
        OSError: if OpenCV could not read ``mask_path``.
    """
    img = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise OSError('Cannot read prediction image (missing or unreadable): {0}'.format(mask_path))
    return img


def _promote_class1_near_class2(msk_dmg, near_class2):
    """Relabel class-1 pixels as class 2 wherever ``near_class2`` is True.

    ``msk_dmg`` is modified in place and also returned. Pixels holding any
    other label (0, 3, 4, ...) are left untouched.

    NOTE: the comparison must be parenthesised. ``a & b == 1`` parses as
    ``(a & b) == 1`` because ``&`` binds tighter than ``==``; with a boolean
    ``a`` that selects every *odd* label (1 and 3), not only label 1.
    """
    msk_dmg[near_class2 & (msk_dmg == 1)] = 2
    return msk_dmg


if __name__ == '__main__':
    t0 = timeit.default_timer()

    if len(sys.argv) < 5:
        sys.exit('Usage: {0} <pre_file> <post_file> <loc_pred_file> <cls_pred_file>'.format(sys.argv[0]))

    # argv[1] and argv[2] (pre/post image paths) are part of the command line
    # but are not read by this script; only the two output paths are used.
    loc_pred_file = sys.argv[3]
    cls_pred_file = sys.argv[4]

    # Per-model predictions are looked up by the output files' base names.
    loc_fn = os.path.basename(loc_pred_file) + '_part1.png'
    cls_fn = os.path.basename(cls_pred_file)

    # Average the classification predictions of all models. Each prediction is
    # stored as two PNGs; the first channel of part 2 is dropped before the
    # channels are joined.
    preds = []
    for d in pred_folders:
        msk1 = _read_mask(path.join(d, cls_fn + '_part1.png'))
        msk2 = _read_mask(path.join(d, cls_fn + '_part2.png'))
        preds.append(np.concatenate([msk1, msk2[..., 1:]], axis=2))
    # The /255 presumably rescales 8-bit images to [0, 1] - TODO confirm.
    preds = np.asarray(preds).astype('float').sum(axis=0) / len(pred_folders) / 255

    # Average the localization predictions of all models.
    loc_preds = [_read_mask(path.join(d, loc_fn)) for d in loc_folders]
    loc_preds = np.asarray(loc_preds).astype('float').sum(axis=0) / len(loc_folders) / 255

    # Per-pixel class = best-scoring channel among channels 1.. (channel 0 is
    # excluded), so labels start at 1.
    msk_dmg = preds[..., 1:].argmax(axis=2) + 1

    # A pixel is kept if its localization score clears the main threshold, or
    # clears a lower threshold while the predicted class is above 1.
    confident = loc_preds > _thr[0]
    mid_class = (loc_preds > _thr[1]) & (msk_dmg > 1) & (msk_dmg < 4)
    any_above_1 = (loc_preds > _thr[2]) & (msk_dmg > 1)
    msk_loc = (confident | mid_class | any_above_1).astype('uint8')

    # Zero out the class wherever the pixel was not kept.
    msk_dmg = msk_dmg * msk_loc

    # Grow the class-2 regions by a 5x5 square and pull neighbouring class-1
    # pixels into class 2.
    class2 = (msk_dmg == 2)
    if class2.any():
        class2 = dilation(class2, square(5))
        _promote_class1_near_class2(msk_dmg, class2)

    msk_dmg = msk_dmg.astype('uint8')

    cv2.imwrite(loc_pred_file, msk_loc, [cv2.IMWRITE_PNG_COMPRESSION, 9])
    cv2.imwrite(cls_pred_file, msk_dmg, [cv2.IMWRITE_PNG_COMPRESSION, 9])

    elapsed = timeit.default_timer() - t0
    print('Time: {:.3f} min'.format(elapsed / 60))