import gradio as gr
import cv2
import joblib
import numpy as np
from skimage.feature import hog
import tensorflow as tf
from tensorflow import keras
from transformers import AutoModelForImageTextToText, TrOCRProcessor
import torch
from PIL import Image
# Load the three OCR models and their supporting artifacts
MODEL_TYPES = ["HOG & Logistic Regression", "CRNN CTC", "Fine Tuned TrOCR"]
clf_hog = joblib.load('./HOG_LogRes.pkl')
clf_crnn = tf.keras.models.load_model('./crnn_ctc.keras')
num_to_char = joblib.load('./decoder.joblib')
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
clf_trocr = AutoModelForImageTextToText.from_pretrained("ChronoStellar/TrOCR_IndonesianLPR")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
clf_trocr.to(device)
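# Assumption about the serialized artifacts: HOG_LogRes.pkl is a fitted
# scikit-learn classifier (e.g. LogisticRegression) whose predict() returns
# character labels, and decoder.joblib is an inverted keras StringLookup
# mapping CRNN class indices back to characters.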
# Preprocessing and prediction functions for each model

# Model 1: segment characters via contours, then classify each ROI on HOG features
def ocr_model_1(file_path):
    im = cv2.imread(file_path)
    im_gray = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    # Binarize with Otsu (the fixed 120 threshold is ignored when THRESH_OTSU is set)
    ret, im_th = cv2.threshold(im_gray, 120, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    ctrs, hier = cv2.findContours(im_th, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    # Sort candidate boxes left to right so characters come out in reading order
    bboxes = [cv2.boundingRect(c) for c in ctrs]
    sorted_bboxes = sorted(bboxes, key=lambda b: b[0])
    plate_char = []
    image_height, image_width = im.shape[:2]
    # Keep only tall, narrow regions, which are likely individual characters
    height_threshold = image_height * 0.3
    width_threshold = image_width * 0.3
    for x, y, w, h in sorted_bboxes:
        if h > height_threshold and w < width_threshold:
            roi = im_gray[y:y + h, x:x + w]
            roi = cv2.resize(roi, (64, 128), interpolation=cv2.INTER_AREA)
            roi_hog_fd = hog(roi, orientations=9, pixels_per_cell=(8, 8), cells_per_block=(1, 1))
            nbr = clf_hog.predict(np.array([roi_hog_fd]))
            plate_char.append(str(nbr[0]))
    return ''.join(plate_char)
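# Usage sketch with one of the bundled examples:
#   ocr_model_1('./B8837NR.jpg')  # -> 'B8837NR' if every character segments cleanly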
# CRNN input geometry and maximum decoded label length
max_length = 9
img_width = 200
img_height = 50

# Greedy CTC decoding: collapse repeats and blanks, then map indices to characters
def decode_batch_predictions(pred):
    input_len = np.ones(pred.shape[0]) * pred.shape[1]
    results = keras.backend.ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
        :, :max_length
    ]
    output_text = []
    for res in results:
        res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
        res = res.replace('[UNK]', '')
        output_text.append(res)
    return output_text
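# Note: ctc_decode pads shorter sequences with -1, and indices outside the
# vocabulary decode to '[UNK]' via num_to_char, hence the strip above.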
# Model 2: CRNN over the whole plate image, decoded with CTC
def ocr_model_2(file_path):
    img = tf.io.read_file(file_path)
    # decode_image handles PNG and JPEG alike (decode_png would fail on JPEG uploads)
    img = tf.io.decode_image(img, channels=1, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)
    img = tf.image.resize(img, [img_height, img_width])
    # Transpose so the width axis becomes the time dimension the CRNN expects
    img = tf.transpose(img, perm=[1, 0, 2])
    img = tf.expand_dims(img, axis=0)
    preds = clf_crnn.predict(img)
    pred_text = decode_batch_predictions(preds)
    return pred_text[0]
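# Shape sketch: (50, 200, 1) resized image -> (200, 50, 1) after transpose ->
# (1, 200, 50, 1) after batching; the CRNN emits one character distribution
# per width step, which decode_batch_predictions collapses into a string.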
# Model 3: fine-tuned TrOCR (vision encoder-decoder)
def ocr_model_3(file_path):
    pil_image = Image.open(file_path).convert("RGB")
    pixel_values = processor(pil_image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)
    clf_trocr.eval()
    with torch.no_grad():
        generated_ids = clf_trocr.generate(pixel_values)
    predicted_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return predicted_text
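# Decoding uses generate()'s defaults; a cap such as
# clf_trocr.generate(pixel_values, max_new_tokens=16) would be a reasonable
# guard for short plate strings (hypothetical value, not from the original).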
# Master OCR function that dispatches to the chosen pipeline
def ocr(file_path, model_name):
    if model_name == MODEL_TYPES[0]:
        return ocr_model_1(file_path)
    elif model_name == MODEL_TYPES[1]:
        return ocr_model_2(file_path)
    elif model_name == MODEL_TYPES[2]:
        return ocr_model_3(file_path)
    raise ValueError(f"Unknown model: {model_name}")
# Create the Gradio interface
interface = gr.Interface(
    fn=ocr,
    inputs=[
        gr.Image(type="filepath", label="License Plate Image"),
        gr.Dropdown(choices=MODEL_TYPES, label="Choose Model")
    ],
    outputs=gr.Textbox(label="Predicted License Plate"),
    title="Automatic License Plate Recognition",
    description="Upload a license plate image and choose a model; the system predicts the text on the plate. All three models were trained on the same dataset, so one may perform better than another.",
    examples=[
        ['./B8837NR.jpg', 'Fine Tuned TrOCR'],
        ['./E5105OD.jpg', 'Fine Tuned TrOCR']
    ]
)
# Launch the Gradio app
interface.launch()