"""Streamlit app that detects a Thai license plate in an uploaded image.

Pipeline: a YOLO detector locates the plate-number and province regions, each
region is binarised and read with a TrOCR model, and the province text is
snapped to the closest official province name by Levenshtein distance.
"""
import os
import traceback

import cv2
import Levenshtein
import numpy as np
import streamlit as st
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from ultralytics import YOLO

# Page config (must be the first Streamlit command executed by the script).
st.set_page_config(
    page_title="Thai License Plate Detection",
    page_icon="🚗",
    layout="centered"
)

# Initialize session state for models
if 'models_loaded' not in st.session_state:
    st.session_state['models_loaded'] = False

YOLO_WEIGHTS_PATH = 'best.pt'
OCR_MODEL_NAME = 'openthaigpt/thai-trocr'

# Detector class ids, as this script interprets them.
CLASS_PLATE = 0
CLASS_PROVINCE = 1


def load_ocr_models():
    """Load the TrOCR processor and model used to read the cropped regions.

    Returns:
        ``(processor, ocr_model)`` on success, or ``(None, None)`` if loading
        failed; in that case the error and traceback are shown in the UI.
    """
    try:
        # Set environment variables to suppress tokenizer fork warnings
        os.environ['TOKENIZERS_PARALLELISM'] = 'false'

        # NOTE(review): trust_remote_code=True allows code shipped with the
        # model repository to run locally. Confirm this is required for this
        # checkpoint and consider pinning `revision` to a specific commit.
        processor = TrOCRProcessor.from_pretrained(
            OCR_MODEL_NAME,
            revision='main',
            use_auth_token=False,
            trust_remote_code=True,
            local_files_only=False
        )

        ocr_model = VisionEncoderDecoderModel.from_pretrained(
            OCR_MODEL_NAME,
            revision='main',
            use_auth_token=False,
            trust_remote_code=True,
            local_files_only=False
        )

        # Move model to CPU explicitly
        ocr_model = ocr_model.to('cpu')

        return processor, ocr_model
    except Exception as e:
        st.error(f"Error loading OCR models: {str(e)}")
        st.error("Detailed error information:")
        st.code(traceback.format_exc())
        return None, None


@st.cache_resource
def load_models():
    """Load the YOLO detector and the OCR models.

    The result is cached by Streamlit, so the models are loaded once per
    server process. Failures are reported in the UI rather than raised.

    Returns:
        ``(processor, ocr_model, yolo_model)`` on success, or
        ``(None, None, None)`` if any model could not be loaded. Callers that
        receive the failure tuple should call ``load_models.clear()`` so the
        failure is not served from the cache on the next run.
    """
    try:
        # Check if YOLO weights exist
        if not os.path.exists(YOLO_WEIGHTS_PATH):
            st.error("YOLO model weights (best.pt) not found in the current directory!")
            return None, None, None

        # Load YOLO model
        try:
            yolo_model = YOLO(YOLO_WEIGHTS_PATH, task='detect')
        except Exception as yolo_error:
            st.error(f"Error loading YOLO model: {str(yolo_error)}")
            return None, None, None

        # Load OCR models
        processor, ocr_model = load_ocr_models()
        if processor is None or ocr_model is None:
            return None, None, None

        return processor, ocr_model, yolo_model
    except Exception as e:
        st.error(f"Error in model loading: {str(e)}")
        st.error("Detailed error information:")
        st.code(traceback.format_exc())
        return None, None, None


# Thai provinces list (the 76 provinces plus Bangkok), in Thai script so it
# can be compared directly with the OCR output.
thai_provinces = [
    "กรุงเทพมหานคร", "กระบี่", "กาญจนบุรี", "กาฬสินธุ์", "กำแพงเพชร",
    "ขอนแก่น", "จันทบุรี", "ฉะเชิงเทรา", "ชลบุรี", "ชัยนาท",
    "ชัยภูมิ", "ชุมพร", "เชียงราย", "เชียงใหม่", "ตรัง",
    "ตราด", "ตาก", "นครนายก", "นครปฐม", "นครพนม",
    "นครราชสีมา", "นครศรีธรรมราช", "นครสวรรค์", "นนทบุรี", "นราธิวาส",
    "น่าน", "บึงกาฬ", "บุรีรัมย์", "ปทุมธานี", "ประจวบคีรีขันธ์",
    "ปราจีนบุรี", "ปัตตานี", "พระนครศรีอยุธยา", "พะเยา", "พังงา",
    "พัทลุง", "พิจิตร", "พิษณุโลก", "เพชรบูรณ์", "เพชรบุรี",
    "แพร่", "ภูเก็ต", "มหาสารคาม", "มุกดาหาร", "แม่ฮ่องสอน",
    "ยโสธร", "ยะลา", "ร้อยเอ็ด", "ระนอง", "ระยอง",
    "ราชบุรี", "ลพบุรี", "ลำปาง", "ลำพูน", "เลย",
    "ศรีสะเกษ", "สกลนคร", "สงขลา", "สตูล", "สมุทรปราการ",
    "สมุทรสงคราม", "สมุทรสาคร", "สระแก้ว", "สระบุรี", "สิงห์บุรี",
    "สุโขทัย", "สุพรรณบุรี", "สุราษฎร์ธานี", "สุรินทร์", "หนองคาย",
    "หนองบัวลำภู", "อำนาจเจริญ", "อุดรธานี", "อุตรดิตถ์", "อุทัยธานี",
    "อุบลราชธานี", "อ่างทอง",
]


def get_closest_province(input_text, provinces):
    """Return the province closest to ``input_text`` by Levenshtein distance.

    Args:
        input_text: Text to match (typically raw OCR output).
        provinces: Iterable of candidate province names.

    Returns:
        ``(closest_province, min_distance)``. Ties go to the candidate that
        appears first. If ``provinces`` is empty the result is
        ``(None, float('inf'))``.
    """
    min_distance = float('inf')
    closest_province = None

    for province in provinces:
        distance = Levenshtein.distance(input_text, province)
        if distance < min_distance:
            min_distance = distance
            closest_province = province

    return closest_province, min_distance


def _enhance_contrast(image_bgr):
    """Apply CLAHE to the lightness channel of a BGR image and return BGR."""
    lab = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2LAB)
    lightness, a_channel, b_channel = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    enhanced = cv2.merge((clahe.apply(lightness), a_channel, b_channel))
    return cv2.cvtColor(enhanced, cv2.COLOR_LAB2BGR)


def _prepare_crop_for_ocr(crop_bgr):
    """Binarise a BGR crop and return a 128x32 three-channel image for OCR."""
    gray = cv2.cvtColor(crop_bgr, cv2.COLOR_BGR2GRAY)
    thresh = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY_INV, 11, 2
    )
    # Morphological closing fills small gaps inside character strokes.
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
    thresh = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
    return cv2.resize(cv2.cvtColor(thresh, cv2.COLOR_GRAY2RGB), (128, 32))


def _read_text(crop_bgr, processor, ocr_model):
    """Run the OCR model on a BGR crop and return the decoded text."""
    ocr_input = _prepare_crop_for_ocr(crop_bgr)
    pixel_values = processor(ocr_input, return_tensors="pt").pixel_values
    # Inference only: no gradients are needed.
    with torch.no_grad():
        generated_ids = ocr_model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def _best_detections(results, width, height, conf_threshold):
    """Return ``{class_id: (x1, y1, x2, y2)}`` with the best box per class.

    Boxes below ``conf_threshold`` are dropped, coordinates are clamped to the
    image, and boxes that are empty after clamping are ignored. When several
    boxes share a class, the one with the highest confidence is kept.
    """
    best = {}
    for result in results:
        for box in result.boxes:
            confidence = float(box.conf.item())
            if confidence < conf_threshold:
                continue
            class_id = int(box.cls.item())
            if class_id in best and best[class_id][0] >= confidence:
                continue

            x1, y1, x2, y2 = (int(v) for v in box.xyxy.flatten().tolist())
            # Clamp so that a negative coordinate cannot wrap the slice
            # around to the opposite edge of the image.
            x1, x2 = max(0, x1), min(width, x2)
            y1, y2 = max(0, y1), min(height, y2)
            if x2 <= x1 or y2 <= y1:
                continue
            best[class_id] = (confidence, (x1, y1, x2, y2))

    return {class_id: coords for class_id, (_, coords) in best.items()}


def process_image(image, processor, ocr_model, yolo_model, conf_threshold=0.2):
    """Detect and read the plate number and province in an image.

    Args:
        image: PIL image (any mode) or an RGB array-like.
        processor: TrOCR processor used to prepare and decode OCR tensors.
        ocr_model: Model exposing ``generate`` for OCR.
        yolo_model: Callable detector returning results with ``boxes``.
        conf_threshold: Minimum detection confidence to keep a box.

    Returns:
        Dict with keys ``plate_number``, ``province`` (closest official
        name), ``raw_province`` (OCR text as read), ``plate_crop`` and
        ``province_crop`` (PIL images or ``None``). Text fields are empty
        strings when the corresponding region was not detected.
    """
    data = {
        "plate_number": "",
        "province": "",
        "raw_province": "",
        "plate_crop": None,
        "province_crop": None,
    }

    # Uploaded PNGs can be RGBA, palette or grayscale; the RGB->BGR
    # conversion below needs exactly three channels.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    image = _enhance_contrast(image)

    # YOLO detection
    results = yolo_model(image)
    height, width = image.shape[:2]
    detections = _best_detections(results, width, height, conf_threshold)

    for class_id in sorted(detections):
        if class_id not in (CLASS_PLATE, CLASS_PROVINCE):
            # Other classes are not used, so skip the OCR pass for them.
            continue

        x1, y1, x2, y2 = detections[class_id]
        cropped_image = image[y1:y2, x1:x2]
        generated_text = _read_text(cropped_image, processor, ocr_model)

        # Convert crop to PIL for display
        cropped_pil = Image.fromarray(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))

        if class_id == CLASS_PLATE:
            data["plate_number"] = generated_text
            data["plate_crop"] = cropped_pil
        else:
            generated_province, _ = get_closest_province(generated_text, thai_provinces)
            data["raw_province"] = generated_text
            data["province"] = generated_province
            data["province_crop"] = cropped_pil

    return data


# Main app
st.title("Thai License Plate Detection 🚗")

# Load models once per session. st.stop() is deliberately called outside the
# try/except so that the handler cannot interfere with it.
if not st.session_state['models_loaded']:
    load_failed = False
    try:
        with st.spinner("Loading models... (this may take a minute)"):
            processor, ocr_model, yolo_model = load_models()
    except Exception as e:
        st.error(f"Error loading models: {str(e)}")
        load_failed = True
    else:
        if processor is None or ocr_model is None or yolo_model is None:
            # load_models() has already reported the reason. Drop the cached
            # failure so the next run retries instead of replaying it.
            load_models.clear()
            load_failed = True
        else:
            st.session_state['models_loaded'] = True
            st.session_state['processor'] = processor
            st.session_state['ocr_model'] = ocr_model
            st.session_state['yolo_model'] = yolo_model

    if load_failed:
        st.stop()

# File uploader
uploaded_file = st.file_uploader("Upload an image of a Thai license plate", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Display the uploaded image
        col1, col2 = st.columns(2)

        with col1:
            st.subheader("Uploaded Image")
            image = Image.open(uploaded_file)
            # NOTE(review): newer Streamlit releases deprecate
            # use_column_width in favour of use_container_width; confirm
            # against the pinned Streamlit version before switching.
            st.image(image, use_column_width=True)

        # Process the image
        with col2:
            st.subheader("Detection Results")
            with st.spinner("Processing image..."):
                results = process_image(
                    image,
                    st.session_state['processor'],
                    st.session_state['ocr_model'],
                    st.session_state['yolo_model']
                )

            if results["plate_number"]:
                st.success("Detection successful!")
                st.write("📝 License Plate:", results['plate_number'])
                if results['plate_crop'] is not None:
                    st.subheader("Cropped License Plate")
                    st.image(results['plate_crop'], caption="Detected License Plate Region")

                if results['raw_province']:
                    st.write("🔍 Detected Province Text:", results['raw_province'])
                    if results['province']:
                        st.write("🏠 Matched Province:", results['province'])
                    else:
                        st.write("⚠️ No close province match found")

                    if results['province_crop'] is not None:
                        st.subheader("Cropped Province")
                        st.image(results['province_crop'], caption="Detected Province Region")
                else:
                    st.write("⚠️ No province text detected")
            else:
                st.error("No license plate detected in the image.")

    except Exception as e:
        st.error(f"An error occurred: {str(e)}")

st.markdown("---")
st.markdown("### Instructions")
st.markdown("""
1. Upload an image containing a Thai license plate
2. Wait for the processing to complete
3. View the detected license plate number and province
""")

# Add footer with GitHub link
st.markdown("---")
st.markdown("Made with ❤️ by [AI Research Group KMUTT]")
st.markdown("Check out the [GitHub Repository](https://github.com/yourusername/your-repo) for more information")