import streamlit as st from PIL import Image import cv2 import numpy as np from transformers import TrOCRProcessor, VisionEncoderDecoderModel from ultralytics import YOLO import Levenshtein import yaml import os import io import tempfile import torchvision class LicensePlateProcessor: def __init__(self): # Load models for plate detection self.yolo_detector = YOLO('detect_plate.pt') # For license plate detection self.province_detector = YOLO('best.pt') # For province detection self.char_reader = YOLO('read_char.pt') # For character reading # Load TrOCR for province detection self.processor_plate = TrOCRProcessor.from_pretrained('openthaigpt/thai-trocr') self.model_plate = VisionEncoderDecoderModel.from_pretrained('openthaigpt/thai-trocr') # Load character mapping from yaml with open('data.yaml', 'r', encoding='utf-8') as f: data_config = yaml.safe_load(f) self.char_mapping = data_config.get('char_mapping', {}) self.names = data_config['names'] # Load province list self.thai_provinces = [ "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร", "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา", "ชลบุรี", "ชัยนาท", "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง", "ตราด", "ตาก", "นครนายก", "นครปฐม", "นครพนม", "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นราธิวาส", "น่าน", "บึงกาฬ", "บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์", "ปราจีนบุรี", "ปัตตานี", "พะเยา", "พังงา", "พัทลุง", "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี", "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน", "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง", "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย", "ศรีสะเกษ", "สกลนคร", "สงขลา", "สมุทรปราการ", "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี", "สิงห์บุรี", "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย", "หนองบัวลำภู", "อำนาจเจริญ", "อุดรธานี", "อุทัยธานี", "อุบลราชธานี", "อ่างทอง" ] self.CONF_THRESHOLD = 0.3 def _map_class_to_char(self, class_name): """Map class to character using yaml mapping""" if str(class_name) in self.char_mapping: return self.char_mapping[str(class_name)] return str(class_name) def get_closest_province(self, input_text): """Find closest matching province""" min_distance = float('inf') closest_province = None for province in self.thai_provinces: distance = Levenshtein.distance(input_text, province) if distance < min_distance: min_distance = distance closest_province = province return closest_province, min_distance def read_plate_characters(self, plate_image): """Read characters from plate image""" results = self.char_reader.predict(plate_image, conf=0.3) detections = [] for r in results: boxes = r.boxes for box in boxes: x1, y1, x2, y2 = map(int, box.xyxy[0]) confidence = float(box.conf[0]) class_id = int(box.cls[0]) mapped_char = self._map_class_to_char(self.names[class_id]) detections.append({ 'char': mapped_char, 'confidence': confidence, 'bbox': (x1, y1, x2, y2) }) # Sort detections left to right detections.sort(key=lambda x: x['bbox'][0]) # Combine characters plate_text = ''.join(det['char'] for det in detections) return plate_text def process_image(self, image_path: str): try: # Read image image = cv2.imread(image_path) if image is None: print(f"Error: Could not read image from {image_path}") return None # Detect license plate location plate_results = self.yolo_detector(image) province_results = self.province_detector(image) data = {"plate_number": "", "province": "", "raw_province": ""} # Save visualization output_image = image.copy() # Process license plate detections for result in plate_results: for box in result.boxes: confidence = float(box.conf) if confidence < self.CONF_THRESHOLD: continue x1, y1, x2, y2 = map(int, box.xyxy.flatten()) cropped_image = image[y1:y2, x1:x2] # Draw rectangle on output image (green for plate) cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 255, 0), 2) # Read characters using YOLO character reader data["plate_number"] = self.read_plate_characters(cropped_image) # Process province detections for result in province_results: for box in result.boxes: confidence = float(box.conf) if confidence < self.CONF_THRESHOLD: continue x1, y1, x2, y2 = map(int, box.xyxy.flatten()) cropped_image = image[y1:y2, x1:x2] # Draw rectangle on output image (blue for province) cv2.rectangle(output_image, (x1, y1), (x2, y2), (255, 0, 0), 2) # Process province using TrOCR cropped_image_gray = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2GRAY) equalized_image = cv2.equalizeHist(cropped_image_gray) _, thresh_image = cv2.threshold(equalized_image, 65, 255, cv2.THRESH_BINARY_INV) cropped_image_3d = cv2.cvtColor(thresh_image, cv2.COLOR_GRAY2RGB) resized_image = cv2.resize(cropped_image_3d, (128, 32)) pixel_values = self.processor_plate(resized_image, return_tensors="pt").pixel_values generated_ids = self.model_plate.generate(pixel_values) generated_text = self.processor_plate.batch_decode(generated_ids, skip_special_tokens=True)[0] generated_province, _ = self.get_closest_province(generated_text) data["raw_province"] = generated_text data["province"] = generated_province # Save the output image cv2.imwrite('output_detection.jpg', output_image) return data except Exception as e: print(f"Error processing image: {str(e)}") return None def main(): st.set_page_config( page_title="Thai License Plate Recognition", layout="wide" ) st.title("Thai License Plate Recognition") st.write("Upload an image to detect and read Thai license plates") # Initialize processor @st.cache_resource def load_processor(): return LicensePlateProcessor() processor = load_processor() # File uploader uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: # Create columns for side-by-side display col1, col2 = st.columns(2) # Display original image with col1: st.subheader("Original Image") image = Image.open(uploaded_file) st.image(image, use_column_width=True) # Convert PIL Image to OpenCV format for processing image_array = np.array(image) if len(image_array.shape) == 3 and image_array.shape[2] == 4: # Convert RGBA to RGB if needed image_array = cv2.cvtColor(image_array, cv2.COLOR_RGBA2RGB) image_cv = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR) # Process image with st.spinner("Processing image..."): try: # Save the OpenCV image for processing temp_path = 'temp_input.jpg' cv2.imwrite(temp_path, image_cv) # Process the image using the processor results = processor.process_image(temp_path) # Clean up temporary input file os.remove(temp_path) if results: # Display results st.subheader("Detection Results") # Create a styled container for results results_container = st.container() with results_container: st.markdown(f"""
Raw Province Text: {results['raw_province']}