import gradio as gr
import cv2
import joblib
import numpy as np
from skimage.feature import hog
import tensorflow as tf
from tensorflow import keras
from transformers import AutoModelForImageTextToText, TrOCRProcessor
import torch
from PIL import Image
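
# The app exposes three OCR back-ends behind a single Gradio interface; the
# dropdown labels below are matched by name in the ocr() dispatcher.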
MODEL_TYPES = ["HOG & Logistic Regression", "CRNN CTC", "Fine Tuned TrOCR"]

# Classical pipeline: logistic regression over HOG features.
clf_hog = joblib.load('./HOG_LogRes.pkl')

# CRNN trained with CTC loss, plus the decoder that maps predicted indices back to characters.
clf_crnn = tf.keras.models.load_model('./crnn_ctc.keras')
num_to_char = joblib.load('./decoder.joblib')

# TrOCR fine-tuned on Indonesian license plates; the base processor handles
# image preprocessing and token decoding.
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
clf_trocr = AutoModelForImageTextToText.from_pretrained("ChronoStellar/TrOCR_IndonesianLPR")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clf_trocr.to(device)
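

# Model 1: classical pipeline. Characters are segmented with OpenCV contours,
# each candidate region is resized and described with HOG features, and a
# logistic regression classifier predicts one character per region.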
def ocr_model_1(file_path):
    im = cv2.imread(file_path)
    im_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    ret, im_th = cv2.threshold(im_gray, 120, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    ctrs, hier = cv2.findContours(im_th, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    bboxes = [cv2.boundingRect(c) for c in ctrs]
    # Sort the bounding boxes left to right so the characters are read in order.
    sorted_bboxes = sorted(bboxes, key=lambda b: b[0])

    plate_char = []
    image_height, image_width = im.shape[:2]
    height_threshold = image_height * 0.3
    width_threshold = image_width * 0.3

    for x, y, w, h in sorted_bboxes:
        # Keep contours that are tall enough to be characters but not so wide
        # that they span most of the plate.
        if h > height_threshold and w < width_threshold:
            roi = im_gray[y:y + h, x:x + w]
            roi = cv2.resize(roi, (64, 128), interpolation=cv2.INTER_AREA)
            roi_hog_fd = hog(roi, orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1))
            nbr = clf_hog.predict(np.array([roi_hog_fd]))
            plate_char.append(str(nbr[0]))

    return ''.join(plate_char)
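

# Model 2: CRNN decoded with CTC. The constants below describe the expected
# input size and maximum label length and must match how crnn_ctc.keras was trained.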
max_length = 9
img_width = 200
img_height = 50


def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    # Greedy CTC decode, then keep at most max_length characters per prediction.
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    output_text = []
    for res in results:
        # Map the predicted indices back to characters and drop any unknown tokens.
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        res = res.replace('[UNK]', '')
        output_text.append(res)
    return output_text
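

# Preprocessing mirrors the CRNN input format: a single-channel image scaled to
# [0, 1], resized to 50x200, then transposed so the width axis becomes the
# sequence dimension.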
def ocr_model_2(file_path):
    img = tf.io.read_file(file_path)
    img = tf.io.decode_png(img, channels=1)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, [img_height, img_width])
    img = tf.transpose(img, perm=[1, 0, 2])
    img = tf.expand_dims(img, axis=0)
    preds = clf_crnn.predict(img)
    pred_text = decode_batch_predictions(preds)
    return pred_text[0]
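

# Model 3: fine-tuned TrOCR. The processor converts the image into pixel_values
# and decodes the generated token ids back into text.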
def ocr_model_3(file_path):
    pil_image = Image.open(file_path).convert("RGB")
    pixel_values = processor(pil_image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    clf_trocr.eval()
    with torch.no_grad():
        generated_ids = clf_trocr.generate(pixel_values)
    predicted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return predicted_text
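

# Dispatch to whichever model was selected in the dropdown.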
def ocr(file_path, model_name):
    if model_name == MODEL_TYPES[0]:
        return ocr_model_1(file_path)
    elif model_name == MODEL_TYPES[1]:
        return ocr_model_2(file_path)
    elif model_name == MODEL_TYPES[2]:
        return ocr_model_3(file_path)
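

# Build the Gradio UI. The example rows assume the two sample images sit next
# to this script.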
interface = gr.Interface(
    fn=ocr,
    inputs=[
        gr.Image(type="filepath"),
        gr.Dropdown(choices=MODEL_TYPES, label="Choose Model")
    ],
    outputs=gr.Textbox(label="Predicted License Plate"),
    title="Automatic License Plate Recognition",
    description=(
        "Provide a license plate image, choose a model, and the system will "
        "predict the text on the plate. All three models were trained on the "
        "same dataset, so one may perform better than another."
    ),
    examples=[
        ['./B8837NR.jpg', 'Fine Tuned TrOCR'],
        ['./E5105OD.jpg', 'Fine Tuned TrOCR']
    ]
)

interface.launch()