"""Streamlit front-end for Thai license plate recognition.

A YOLO model locates the plate, a second YOLO model locates the province
line, a character-level YOLO model reads the plate number and TrOCR reads the
province text, which is then snapped to the nearest known province name.
"""
import html
import io
import logging
import os
import tempfile

import streamlit as st
from PIL import Image
import cv2
import numpy as np
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO
import Levenshtein
import yaml
# NOTE(review): `io` and `torchvision` are not referenced in this file; they
# are kept in case something relies on their import side effects — confirm
# before removing.
import torchvision

logger = logging.getLogger(__name__)

# Default destination of the annotated detection image written by
# LicensePlateProcessor.process_image().
DEFAULT_OUTPUT_PATH = 'output_detection.jpg'


class LicensePlateProcessor:
    """Detect a Thai license plate in an image and read its number and province."""

    def __init__(self, conf_threshold=0.3):
        """Load all models and lookup tables from the working directory.

        Args:
            conf_threshold: Minimum detection confidence used for plate,
                province and character detections.
        """
        # YOLO models for plate localisation, province-line localisation and
        # per-character reading.
        self.yolo_detector = YOLO('detect_plate.pt')
        self.province_detector = YOLO('best.pt')
        self.char_reader = YOLO('read_char.pt')

        # TrOCR model used to read the province text.
        self.processor_plate = TrOCRProcessor.from_pretrained('openthaigpt/thai-trocr')
        self.model_plate = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr')

        # Class names of the character reader and an optional class -> char map.
        with open('data.yaml', 'r', encoding='utf-8') as f:
            data_config = yaml.safe_load(f)
        # `or {}` also covers an explicit `char_mapping: null` entry.
        self.char_mapping = data_config.get('char_mapping') or {}
        self.names = data_config['names']

        # All 77 Thai provinces (76 provinces + Bangkok), used to snap noisy
        # OCR output to a valid name.
        self.thai_provinces = [
            "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร",
            "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา", "ชลบุรี", "ชัยนาท", "ชัยภูมิ",
            "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก",
            "นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์",
            "นนทบุรี", "นราธิวาส", "น่าน", "บึงกาฬ", "บุรีรัมย์", "ปทุมธานี",
            "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา",
            "พะเยา", "พังงา", "พัทลุง", "พิจิตร", "พิษณุโลก", "เพชรบูรณ์",
            "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
            "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี",
            "ลำปาง", "ลำพูน", "เลย", "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล",
            "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี",
            "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์",
            "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ", "อุดรธานี", "อุตรดิตถ์",
            "อุทัยธานี", "อุบลราชธานี", "อ่างทอง",
        ]
        self.CONF_THRESHOLD = conf_threshold

    def _map_class_to_char(self, class_name):
        """Return the display character for a character-detector class name.

        ``str(class_name)`` is looked up in ``self.char_mapping``; unmapped
        names are returned unchanged. The result is always a string so that
        plate characters can be joined.
        """
        key = str(class_name)
        return str(self.char_mapping.get(key, key))

    def get_closest_province(self, input_text):
        """Snap OCR output to the nearest known province name.

        Args:
            input_text: Raw text recognised from the province line.

        Returns:
            ``(province, distance)``. ``province`` is the entry of
            ``self.thai_provinces`` with the smallest Levenshtein edit distance
            to ``input_text``; ties go to the earliest entry. ``distance`` is
            that edit distance. Returns ``(None, inf)`` if the province list is
            empty.
        """
        scored = ((Levenshtein.distance(input_text, province), province)
                  for province in self.thai_provinces)
        # min() with a key keeps the first minimal element, i.e. the earliest
        # province on ties.
        distance, province = min(scored, key=lambda pair: pair[0],
                                 default=(float('inf'), None))
        return province, distance

    def read_plate_characters(self, plate_image):
        """Read the characters on a cropped plate image, left to right.

        Runs the character detector with ``self.CONF_THRESHOLD``, maps each
        detected class to its display character, and concatenates the
        characters ordered by the left edge of their bounding boxes.
        """
        results = self.char_reader.predict(plate_image, conf=self.CONF_THRESHOLD)
        detections = []
        for result in results:
            for box in result.boxes:
                left_edge = int(box.xyxy[0][0])
                class_id = int(box.cls[0])
                detections.append((left_edge, self._map_class_to_char(self.names[class_id])))
        # Stable sort on x only: characters sharing a left edge keep detector order.
        detections.sort(key=lambda item: item[0])
        return ''.join(char for _, char in detections)

    def _confident_boxes(self, results, image_shape):
        """Yield ``(confidence, (x1, y1, x2, y2))`` for boxes above the threshold.

        Coordinates are clamped to the image so that slicing never wraps
        around through negative indices, and boxes with no area after
        clamping are skipped.
        """
        height, width = image_shape[:2]
        for result in results:
            for box in result.boxes:
                confidence = float(box.conf)
                if confidence < self.CONF_THRESHOLD:
                    continue
                x1, y1, x2, y2 = map(int, box.xyxy.flatten())
                x1, y1 = max(0, x1), max(0, y1)
                x2, y2 = min(width, x2), min(height, y2)
                if x2 <= x1 or y2 <= y1:
                    continue
                yield confidence, (x1, y1, x2, y2)

    def _recognize_province_text(self, province_crop):
        """Run TrOCR on a BGR crop of the province line and return the raw text.

        The crop is converted to grayscale and histogram-equalised. It is then
        binarised with an inverted fixed threshold of 65, converted back to
        3-channel RGB and resized to 128x32 before TrOCR sees it.
        """
        gray = cv2.cvtColor(province_crop, cv2.COLOR_BGR2GRAY)
        equalized = cv2.equalizeHist(gray)
        _, binary = cv2.threshold(equalized, 65, 255, cv2.THRESH_BINARY_INV)
        rgb = cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB)
        resized = cv2.resize(rgb, (128, 32))
        pixel_values = self.processor_plate(resized, return_tensors="pt").pixel_values
        generated_ids = self.model_plate.generate(pixel_values)
        return self.processor_plate.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def process_image(self, image_path, output_path=DEFAULT_OUTPUT_PATH):
        """Detect the plate and province in an image and read both.

        Args:
            image_path: Path to an image file readable by ``cv2.imread``, or an
                already-decoded BGR ``numpy`` array. The array is not modified.
            output_path: Where the annotated visualisation is written. Plate
                detections get a green box and province detections a blue box.

        Returns:
            A dict with the following keys. Fields stay empty strings when the
            corresponding region was not detected.

            - ``plate_number``: characters read from the most confident plate
              detection.
            - ``province``: the closest entry of ``self.thai_provinces``.
            - ``raw_province``: the unmatched TrOCR text.

            Returns ``None`` if the image cannot be read or processing fails;
            the error is logged.
        """
        try:
            if isinstance(image_path, np.ndarray):
                image = image_path
            else:
                image = cv2.imread(image_path)
                if image is None:
                    logger.error("Could not read image from %s", image_path)
                    return None
            if image.size == 0:
                logger.error("Input image is empty")
                return None

            plate_boxes = list(self._confident_boxes(self.yolo_detector(image), image.shape))
            province_boxes = list(self._confident_boxes(self.province_detector(image), image.shape))

            data = {"plate_number": "", "province": "", "raw_province": ""}

            # Draw on a copy so the crops below come from the clean image.
            output_image = image.copy()
            for _, (x1, y1, x2, y2) in plate_boxes:
                cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2)  # green (BGR)
            for _, (x1, y1, x2, y2) in province_boxes:
                cv2.rectangle(output_image, (x1, y1), (x2, y2), (255, 0, 0), 2)  # blue (BGR)

            # Read only the most confident detection of each kind rather than
            # letting whichever box comes last overwrite the result.
            if plate_boxes:
                _, (x1, y1, x2, y2) = max(plate_boxes, key=lambda item: item[0])
                data["plate_number"] = self.read_plate_characters(image[y1:y2, x1:x2])

            if province_boxes:
                _, (x1, y1, x2, y2) = max(province_boxes, key=lambda item: item[0])
                raw_text = self._recognize_province_text(image[y1:y2, x1:x2])
                data["raw_province"] = raw_text
                data["province"], _ = self.get_closest_province(raw_text)

            cv2.imwrite(output_path, output_image)
            return data
        except Exception:
            # Top-level boundary for the UI: log with traceback, signal failure.
            logger.exception("Error processing image")
            return None


_ABOUT_MARKDOWN = """
### Thai License Plate Recognition System

This application uses advanced computer vision and deep learning to:

- Detect license plates in images using YOLO
- Read Thai license plate numbers using character recognition
- Identify province names using TrOCR
- Provide visual detection results

#### How to Use:

1. Click the 'Browse files' button above
2. Select an image containing a Thai license plate
3. Wait for the processing to complete
4. View the results and detection visualization

#### Technologies Used:

- YOLO for license plate detection
- Custom YOLO model for character recognition
- TrOCR for province text recognition
"""


def _render_results(results, output_path, vis_column):
    """Show recognised text and the annotated image for one processed upload.

    Args:
        results: Return value of ``LicensePlateProcessor.process_image``.
        output_path: File the annotated image was written to.
        vis_column: Streamlit column that receives the visualisation.
    """
    if results is None:
        st.error("Error processing image. See the server log for details.")
        return

    if not (results['plate_number'] or results['province']):
        st.error("No license plate detected in the image.")
    else:
        st.subheader("Detection Results")
        # Values come from model output on an untrusted upload and are
        # rendered with unsafe_allow_html, so HTML-escape them.
        plate = html.escape(str(results['plate_number']))
        province = html.escape(str(results['province']))
        raw_province = html.escape(str(results['raw_province']))
        with st.container():
            st.markdown(
                "\n\n"
                f"License Plate: {plate}\n\n"
                f"Province: {province}\n\n"
                f"Raw Province Text: {raw_province}\n\n",
                unsafe_allow_html=True,
            )

    with vis_column:
        st.subheader("Detection Visualization")
        output_image = cv2.imread(output_path)
        if output_image is not None:
            st.image(cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB), use_column_width=True)


def _handle_upload(processor, uploaded_file):
    """Run recognition on one uploaded file and render the two-column view."""
    col1, col2 = st.columns(2)

    with col1:
        st.subheader("Original Image")
        try:
            image = Image.open(uploaded_file)
            # Normalise any PIL mode (L, LA, P, RGBA, CMYK, ...) to 3-channel
            # RGB; np.array() of those modes is not always HxWx3.
            image_rgb = image.convert("RGB")
        except (OSError, Image.DecompressionBombError) as e:
            st.error(f"Could not read the uploaded image: {e}")
            return
        st.image(image, use_column_width=True)

    # OpenCV and the detectors work on BGR arrays; pass the array directly
    # instead of re-encoding it to a shared temp JPEG.
    image_cv = cv2.cvtColor(np.array(image_rgb), cv2.COLOR_RGB2BGR)

    with st.spinner("Processing image..."):
        # Per-request scratch directory: unique across concurrent sessions and
        # removed automatically, even if rendering raises.
        with tempfile.TemporaryDirectory() as work_dir:
            output_path = os.path.join(work_dir, 'output_detection.jpg')
            try:
                results = processor.process_image(image_cv, output_path=output_path)
                _render_results(results, output_path, col2)
            except Exception as e:
                st.error(f"Error processing image: {str(e)}")


def main():
    """Streamlit entry point: upload an image, run recognition, show results."""
    st.set_page_config(
        page_title="Thai License Plate Recognition",
        layout="wide",
    )
    st.title("Thai License Plate Recognition")
    st.write("Upload an image to detect and read Thai license plates")

    @st.cache_resource
    def load_processor():
        # Cached by Streamlit so the models are not reloaded on every rerun.
        return LicensePlateProcessor()

    processor = load_processor()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        _handle_upload(processor, uploaded_file)

    with st.expander("About This Application"):
        st.markdown(_ABOUT_MARKDOWN)


if __name__ == "__main__":
    main()