# license_plate / app.py
# Uploaded by Sompote ("Upload app.py"), commit 58bc27b (verified)
import streamlit as st
import os
import numpy as np
import cv2
from PIL import Image
import torch
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
# Page configuration: browser-tab title/icon and a centered layout. This is the
# first st.* call made by the script.
st.set_page_config(page_title="Thai License Plate Detection", page_icon="🚗", layout="centered")

# Per-session flag recording whether the models have been stored in
# st.session_state; it is only initialised once, so reruns keep its value.
if "models_loaded" not in st.session_state:
    st.session_state.models_loaded = False
def load_ocr_models():
    """Load the Thai TrOCR processor and encoder-decoder model onto the CPU.

    Returns:
        ``(processor, ocr_model)`` on success. If anything raises, the error
        message and full traceback are shown in the app via ``st.error`` /
        ``st.code`` and ``(None, None)`` is returned instead.
    """
    try:
        # Silence the tokenizers parallelism warning.
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'

        repo_id = 'openthaigpt/thai-trocr'
        # Shared Hugging Face Hub options for both downloads.
        # NOTE(review): ``use_auth_token`` is deprecated in recent transformers
        # releases in favour of ``token``; ``trust_remote_code=True`` lets the
        # repo execute its own Python code -- confirm that is really needed.
        hub_options = {
            'revision': 'main',
            'use_auth_token': False,
            'trust_remote_code': True,
            'local_files_only': False,
        }
        processor = TrOCRProcessor.from_pretrained(repo_id, **hub_options)
        ocr_model = VisionEncoderDecoderModel.from_pretrained(repo_id, **hub_options).to('cpu')
    except Exception as e:
        import traceback
        st.error(f"Error loading OCR models: {str(e)}")
        st.error("Detailed error information:")
        st.code(traceback.format_exc())
        return None, None
    return processor, ocr_model
# Load models
def _models_are_usable(models):
    """Return True if a cached ``load_models()`` result holds three real models.

    Used as the ``st.cache_resource`` validator: when it returns False, the
    cached ``(None, None, None)`` failure result is discarded and
    ``load_models`` runs again on its next call. Without it, a single transient
    failure (e.g. a network error while downloading the OCR model) would be
    served from the cache until the server restarts.
    """
    return all(model is not None for model in models)


@st.cache_resource(validate=_models_are_usable)
def load_models():
    """Load the YOLO detector and the TrOCR processor/model.

    The result is cached process-wide by ``st.cache_resource``. A failed result
    is rejected by ``_models_are_usable`` and therefore retried on the next call.

    Returns:
        ``(processor, ocr_model, yolo_model)`` on success, otherwise
        ``(None, None, None)`` after reporting the problem with ``st.error``.
    """
    try:
        # The detector weights are looked up relative to the working directory.
        if not os.path.exists('best.pt'):
            st.error("YOLO model weights (best.pt) not found in the current directory!")
            return None, None, None

        try:
            yolo_model = YOLO('best.pt', task='detect')
        except Exception as yolo_error:
            st.error(f"Error loading YOLO model: {str(yolo_error)}")
            return None, None, None

        # load_ocr_models() reports its own errors and returns (None, None).
        processor, ocr_model = load_ocr_models()
        if processor is None or ocr_model is None:
            return None, None, None

        return processor, ocr_model, yolo_model
    except Exception as e:
        st.error(f"Error in model loading: {str(e)}")
        st.error("Detailed error information:")
        import traceback
        st.code(traceback.format_exc())
        return None, None, None
# All 77 Thai provinces (76 provinces plus Bangkok), in Thai alphabetical
# order. Noisy OCR province text is snapped to the closest entry of this list,
# so a missing province can never be reported correctly.
thai_provinces = [
    "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา",
    "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก",
    "นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ",
    "บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา", "พังงา", "พัทลุง",
    "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
    "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย",
    "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
    "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ",
    "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง",
]
def get_closest_province(input_text, provinces):
    """Find the candidate in *provinces* nearest to *input_text* by edit distance.

    Args:
        input_text: OCR output to match.
        provinces: Candidate names, scanned in order.

    Returns:
        ``(province, distance)`` where ``distance`` is the Levenshtein
        distance. On ties the earliest candidate wins. When ``provinces`` is
        empty the result is ``(None, float('inf'))``.
    """
    scored = ((name, Levenshtein.distance(input_text, name)) for name in provinces)
    # min() keeps the first of several equal minima, matching a strict "<" scan.
    return min(scored, key=lambda pair: pair[1], default=(None, float('inf')))
def _enhance_for_detection(image):
    """Return *image* as a CLAHE-enhanced BGR array ready for YOLO/OpenCV."""
    if isinstance(image, Image.Image):
        # Normalise palette ("P"), grayscale ("L"/"LA"), RGBA and CMYK uploads
        # to 3-channel RGB; COLOR_RGB2BGR rejects single-channel arrays and
        # would treat CMYK's 4th channel as alpha.
        image = image.convert("RGB")
    bgr = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    # Equalise only the L channel of LAB; a/b are passed through unchanged.
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB)
    lightness, a_chan, b_chan = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = cv2.merge((clahe.apply(lightness), a_chan, b_chan))
    return cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)


def _best_box_per_class(results, conf_threshold):
    """Map class_id -> (confidence, (x1, y1, x2, y2)) for the most confident box."""
    best = {}
    for result in results:
        for box in result.boxes:
            confidence = float(box.conf)
            if confidence < conf_threshold:
                continue
            class_id = int(box.cls.item())
            if class_id not in best or confidence > best[class_id][0]:
                x1, y1, x2, y2 = map(int, box.xyxy.flatten())
                best[class_id] = (confidence, (x1, y1, x2, y2))
    return best


def _read_text(crop_bgr, processor, ocr_model):
    """Binarise a BGR crop and return the TrOCR transcription."""
    gray = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2GRAY)
    # Inverted Gaussian adaptive threshold (block 11, C=2), then a 2x2
    # morphological close to fill small gaps in the binary image.
    binary = cv2.adaptiveThreshold(
        gray,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV,
        11,
        2,
    )
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel)
    rgb = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
    resized = cv2.resize(rgb, (128, 32))  # cv2 dsize is (width, height)
    pixel_values = processor(resized, return_tensors="pt").pixel_values
    generated_ids = ocr_model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def process_image(image, processor, ocr_model, yolo_model, conf_threshold=0.2):
    """Detect the plate-number and province regions in *image* and OCR them.

    Args:
        image: Uploaded picture, normally a ``PIL.Image`` of any mode (it is
            converted to RGB first). An RGB ``numpy`` array is also accepted.
        processor: TrOCR processor used to build pixel values and decode ids.
        ocr_model: Encoder-decoder model whose ``generate`` produces token ids.
        yolo_model: YOLO detector; class 0 is treated as the plate-number
            region and class 1 as the province region.
        conf_threshold: Detections with a lower confidence are ignored. When
            several boxes of one class pass, only the most confident is used.

    Returns:
        Dict with keys ``plate_number``, ``province`` (closest entry of
        ``thai_provinces``), ``raw_province`` (the unmatched OCR text),
        ``plate_crop`` and ``province_crop`` (PIL crops or ``None``). Fields
        stay ``""``/``None`` when the corresponding region is not found.
    """
    plate_class, province_class = 0, 1
    data = {"plate_number": "", "province": "", "raw_province": "", "plate_crop": None, "province_crop": None}

    image = _enhance_for_detection(image)
    height, width = image.shape[:2]
    detections = _best_box_per_class(yolo_model(image), conf_threshold)

    for class_id in sorted(detections):
        if class_id not in (plate_class, province_class):
            continue  # no output field for other classes; skip the OCR cost
        _confidence, (x1, y1, x2, y2) = detections[class_id]
        # Clamp to the image so a box edge outside it never becomes a
        # negative (wrap-around) slice index.
        x1, x2 = max(0, x1), min(width, x2)
        y1, y2 = max(0, y1), min(height, y2)
        cropped_image = image[y1:y2, x1:x2]
        if cropped_image.size == 0:
            continue

        generated_text = _read_text(cropped_image, processor, ocr_model)
        cropped_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))

        if class_id == plate_class:
            data["plate_number"] = generated_text
            data["plate_crop"] = cropped_pil
        else:
            matched_province, _distance = get_closest_province(generated_text, thai_provinces)
            data["raw_province"] = generated_text
            data["province"] = matched_province
            data["province_crop"] = cropped_pil
    return data
# Main app
st.title("Thai License Plate Detection 🚗")

# Load the models once per browser session. load_models() is itself cached by
# st.cache_resource; session_state keeps this session's references.
try:
    if not st.session_state['models_loaded']:
        with st.spinner("Loading models... (this may take a minute)"):
            processor, ocr_model, yolo_model = load_models()
        # load_models() reports its own errors and returns (None, None, None)
        # on failure. Only mark the session as loaded when every model exists,
        # so a later rerun retries instead of reusing missing models.
        if all(model is not None for model in (processor, ocr_model, yolo_model)):
            st.session_state['models_loaded'] = True
            st.session_state['processor'] = processor
            st.session_state['ocr_model'] = ocr_model
            st.session_state['yolo_model'] = yolo_model
except Exception as e:
    st.error(f"Error loading models: {str(e)}")
    st.stop()

if not st.session_state['models_loaded']:
    # Halt here rather than letting an upload fail later on a None model.
    st.error("Models are not available, so images cannot be processed. See the errors above.")
    st.stop()
def _show_detection_results(results):
    """Render a process_image() result dict into the current Streamlit container."""
    if not results["plate_number"]:
        st.error("No license plate detected in the image.")
        return

    st.success("Detection successful!")
    st.write("📝 License Plate:", results['plate_number'])
    plate_crop = results['plate_crop']
    if plate_crop is not None:
        st.subheader("Cropped License Plate")
        st.image(plate_crop, caption="Detected License Plate Region")

    raw_province = results['raw_province']
    if not raw_province:
        st.write("⚠️ No province text detected")
        return

    st.write("🔍 Detected Province Text:", raw_province)
    matched_province = results['province']
    # NOTE(review): get_closest_province() only yields a falsy province for an
    # empty candidate list, so the "no close match" branch is effectively dead.
    if matched_province:
        st.write("🏠 Matched Province:", matched_province)
    else:
        st.write("⚠️ No close province match found")
    province_crop = results['province_crop']
    if province_crop is not None:
        st.subheader("Cropped Province")
        st.image(province_crop, caption="Detected Province Region")


# File uploader: show the upload on the left and the detection output on the
# right; any failure is reported as a single error message.
uploaded_file = st.file_uploader("Upload an image of a Thai license plate", type=["jpg", "jpeg", "png"])
if uploaded_file is not None:
    try:
        left_col, right_col = st.columns(2)
        with left_col:
            st.subheader("Uploaded Image")
            image = Image.open(uploaded_file)
            st.image(image, use_column_width=True)
        with right_col:
            st.subheader("Detection Results")
            with st.spinner("Processing image..."):
                results = process_image(
                    image,
                    st.session_state['processor'],
                    st.session_state['ocr_model'],
                    st.session_state['yolo_model'],
                )
            _show_detection_results(results)
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
# Usage instructions.
st.markdown("---")
st.markdown("### Instructions")
st.markdown("""
1. Upload an image containing a Thai license plate
2. Wait for the processing to complete
3. View the detected license plate number and province
""")

# Footer with credits and a GitHub link.
st.markdown("---")
# Plain text rather than "[...]": a Markdown link label with no URL renders
# with literal square brackets.
st.markdown("Made with ❤️ by AI Research Group KMUTT")
# TODO: replace this placeholder repository URL with the project's real repo.
st.markdown("Check out the [GitHub Repository](https://github.com/yourusername/your-repo) for more information")