Spaces:
Sleeping
Sleeping
TheKnight115
commited on
Update processor.py
Browse files- processor.py +101 -22
processor.py
CHANGED
@@ -4,7 +4,7 @@ import cv2
|
|
4 |
import numpy as np
|
5 |
import os
|
6 |
from ultralytics import YOLO
|
7 |
-
from transformers import
|
8 |
from PIL import Image, ImageDraw, ImageFont
|
9 |
import re
|
10 |
import smtplib
|
@@ -12,6 +12,11 @@ from email.mime.text import MIMEText
|
|
12 |
from email.mime.multipart import MIMEMultipart
|
13 |
from email.mime.base import MIMEBase
|
14 |
from email import encoders
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
# Email credentials (Use environment variables for security)
|
17 |
FROM_EMAIL = os.getenv("FROM_EMAIL")
|
@@ -20,7 +25,7 @@ TO_EMAIL = os.getenv("TO_EMAIL")
|
|
20 |
SMTP_SERVER = 'smtp.gmail.com'
|
21 |
SMTP_PORT = 465
|
22 |
|
23 |
-
# Arabic dictionary
|
24 |
arabic_dict = {
|
25 |
"0": "٠", "1": "١", "2": "٢", "3": "٣", "4": "٤", "5": "٥",
|
26 |
"6": "٦", "7": "٧", "8": "٨", "9": "٩", "A": "ا", "B": "ب",
|
@@ -28,6 +33,8 @@ arabic_dict = {
|
|
28 |
"E": "ع", "G": "ق", "K": "ك", "L": "ل", "Z": "م", "N": "ن",
|
29 |
"H": "ه", "U": "و", "V": "ي", " ": " "
|
30 |
}
|
|
|
|
|
31 |
class_colors = {
|
32 |
0: (0, 255, 0), # Green (Helmet)
|
33 |
1: (255, 0, 0), # Blue (License Plate)
|
@@ -37,44 +44,55 @@ class_colors = {
|
|
37 |
5: (0, 255, 255), # Yellow (Person)
|
38 |
}
|
39 |
|
40 |
-
#
|
41 |
-
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
44 |
|
45 |
-
# Define lane area
|
46 |
red_lane = np.array([[2,1583],[1,1131],[1828,1141],[1912,1580]], np.int32)
|
47 |
|
48 |
-
#
|
49 |
violations_dict = {}
|
50 |
|
51 |
-
def filter_license_plate_text(
|
52 |
-
|
53 |
-
|
|
|
54 |
return f"{match.group(1)} {match.group(2)}" if match else None
|
55 |
|
56 |
-
def convert_to_arabic(
|
57 |
-
|
|
|
58 |
|
59 |
-
def send_email(license_text, violation_image_path, violation_type):
|
60 |
-
|
|
|
|
|
61 |
'No Helmet, In Red Lane': 'تنبيه مخالفة: عدم ارتداء خوذة ودخول المسار الأيسر',
|
62 |
'In Red Lane': 'تنبيه مخالفة: دخول المسار الأيسر',
|
63 |
'No Helmet': 'تنبيه مخالفة: عدم ارتداء خوذة'
|
64 |
-
}
|
65 |
-
|
66 |
-
body = {
|
67 |
'No Helmet, In Red Lane': f"لعدم ارتداء الخوذة ولدخولها المسار الأيسر ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة",
|
68 |
'In Red Lane': f"لدخولها المسار الأيسر ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة",
|
69 |
'No Helmet': f"لعدم ارتداء الخوذة ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة"
|
70 |
-
}
|
71 |
-
|
|
|
|
|
|
|
|
|
72 |
msg = MIMEMultipart()
|
73 |
msg['From'] = FROM_EMAIL
|
74 |
msg['To'] = TO_EMAIL
|
75 |
-
msg['Subject'] = subject
|
76 |
msg.attach(MIMEText(body, 'plain'))
|
77 |
|
|
|
78 |
if os.path.exists(violation_image_path):
|
79 |
with open(violation_image_path, 'rb') as attachment_file:
|
80 |
part = MIMEBase('application', 'octet-stream')
|
@@ -83,24 +101,29 @@ def send_email(license_text, violation_image_path, violation_type):
|
|
83 |
part.add_header('Content-Disposition', f'attachment; filename={os.path.basename(violation_image_path)}')
|
84 |
msg.attach(part)
|
85 |
|
|
|
86 |
try:
|
87 |
with smtplib.SMTP_SSL(SMTP_SERVER, SMTP_PORT) as server:
|
88 |
server.login(FROM_EMAIL, EMAIL_PASSWORD)
|
89 |
server.sendmail(FROM_EMAIL, TO_EMAIL, msg.as_string())
|
|
|
90 |
except Exception as e:
|
91 |
print(f"Failed to send email: {e}")
|
92 |
|
93 |
def draw_text_pil(img, text, position, font_path, font_size, color):
|
|
|
94 |
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
95 |
draw = ImageDraw.Draw(img_pil)
|
96 |
try:
|
97 |
font = ImageFont.truetype(font_path, size=font_size)
|
98 |
except IOError:
|
|
|
99 |
font = ImageFont.load_default()
|
100 |
draw.text(position, text, font=font, fill=color)
|
101 |
return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
|
102 |
|
103 |
def process_frame(frame, font_path, violation_image_path='violation.jpg'):
|
|
|
104 |
results = model.track(frame)
|
105 |
for box in results[0].boxes:
|
106 |
x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
|
@@ -108,6 +131,11 @@ def process_frame(frame, font_path, violation_image_path='violation.jpg'):
|
|
108 |
color = class_colors.get(int(box.cls), (255, 255, 255))
|
109 |
confidence = box.conf[0].item()
|
110 |
|
|
|
|
|
|
|
|
|
|
|
111 |
if label == 'MotorbikeDelivery' and confidence >= 0.4:
|
112 |
motorbike_crop = frame[max(0, y1 - 50):y2, x1:x2]
|
113 |
delivery_center = ((x1 + x2) // 2, y2)
|
@@ -116,7 +144,9 @@ def process_frame(frame, font_path, violation_image_path='violation.jpg'):
|
|
116 |
if in_red_lane >= 0:
|
117 |
violation_types.append("In Red Lane")
|
118 |
|
|
|
119 |
sub_results = model(motorbike_crop)
|
|
|
120 |
for sub_box in sub_results[0].boxes:
|
121 |
sub_x1, sub_y1, sub_x2, sub_y2 = map(int, sub_box.xyxy[0].cpu().numpy())
|
122 |
sub_label = model.names[int(sub_box.cls)]
|
@@ -125,10 +155,19 @@ def process_frame(frame, font_path, violation_image_path='violation.jpg'):
|
|
125 |
elif sub_label == 'License_plate':
|
126 |
license_crop = motorbike_crop[sub_y1:sub_y2, sub_x1:sub_x2]
|
127 |
if violation_types:
|
|
|
128 |
cv2.imwrite(violation_image_path, frame)
|
129 |
license_plate_pil = Image.fromarray(cv2.cvtColor(license_crop, cv2.COLOR_BGR2RGB))
|
130 |
license_plate_pil.save('license_plate.png')
|
131 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
filtered_text = filter_license_plate_text(license_plate_text)
|
133 |
if filtered_text:
|
134 |
if filtered_text not in violations_dict:
|
@@ -141,7 +180,47 @@ def process_frame(frame, font_path, violation_image_path='violation.jpg'):
|
|
141 |
if updated != current:
|
142 |
violations_dict[filtered_text] = list(updated)
|
143 |
send_email(filtered_text, violation_image_path, ', '.join(updated))
|
|
|
144 |
arabic_text = convert_to_arabic(filtered_text)
|
145 |
frame = draw_text_pil(frame, filtered_text, (x1, y2 + 30), font_path, 30, (255, 255, 255))
|
146 |
frame = draw_text_pil(frame, arabic_text, (x1, y2 + 60), font_path, 30, (0, 255, 0))
|
147 |
return frame
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
import numpy as np
|
5 |
import os
|
6 |
from ultralytics import YOLO
|
7 |
+
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
|
8 |
from PIL import Image, ImageDraw, ImageFont
|
9 |
import re
|
10 |
import smtplib
|
|
|
12 |
from email.mime.multipart import MIMEMultipart
|
13 |
from email.mime.base import MIMEBase
|
14 |
from email import encoders
|
15 |
+
import torch
|
16 |
+
from dotenv import load_dotenv
|
17 |
+
|
18 |
+
# Load environment variables
|
19 |
+
load_dotenv()
|
20 |
|
21 |
# Email credentials (Use environment variables for security)
|
22 |
FROM_EMAIL = os.getenv("FROM_EMAIL")
|
|
|
25 |
SMTP_SERVER = 'smtp.gmail.com'
|
26 |
SMTP_PORT = 465
|
27 |
|
28 |
+
# Arabic dictionary for converting license plate text
|
29 |
arabic_dict = {
|
30 |
"0": "٠", "1": "١", "2": "٢", "3": "٣", "4": "٤", "5": "٥",
|
31 |
"6": "٦", "7": "٧", "8": "٨", "9": "٩", "A": "ا", "B": "ب",
|
|
|
33 |
"E": "ع", "G": "ق", "K": "ك", "L": "ل", "Z": "م", "N": "ن",
|
34 |
"H": "ه", "U": "و", "V": "ي", " ": " "
|
35 |
}
|
36 |
+
|
37 |
+
# Define class colors
|
38 |
class_colors = {
|
39 |
0: (0, 255, 0), # Green (Helmet)
|
40 |
1: (255, 0, 0), # Blue (License Plate)
|
|
|
44 |
5: (0, 255, 255), # Yellow (Person)
|
45 |
}
|
46 |
|
47 |
+
# Initialize the OCR pipeline
|
48 |
+
# Replace 'stepfun-ai/GOT-OCR2_0' with the correct model name if different
|
49 |
+
ocr_pipeline = pipeline("image-to-text", model="stepfun-ai/GOT-OCR2_0", trust_remote_code=True, device=0 if torch.cuda.is_available() else -1)
|
50 |
+
|
51 |
+
# Load YOLO model
|
52 |
+
# Ensure the path to the model is correct
|
53 |
+
model = YOLO('yolov8_Medium.pt') # Update the path as needed
|
54 |
|
55 |
+
# Define lane area coordinates (example coordinates)
|
56 |
red_lane = np.array([[2,1583],[1,1131],[1828,1141],[1912,1580]], np.int32)
|
57 |
|
58 |
+
# Dictionary to track violations per license plate
|
59 |
violations_dict = {}
|
60 |
|
61 |
+
def filter_license_plate_text(license_plate_text):
|
62 |
+
"""Filter and format the license plate text."""
|
63 |
+
license_plate_text = re.sub(r'[^A-Z0-9]+', "", license_plate_text)
|
64 |
+
match = re.search(r'(\d{4})\s*([A-Z]{2})', license_plate_text)
|
65 |
return f"{match.group(1)} {match.group(2)}" if match else None
|
66 |
|
67 |
+
def convert_to_arabic(license_plate_text):
|
68 |
+
"""Convert license plate text from Latin to Arabic script."""
|
69 |
+
return "".join(arabic_dict.get(char, char) for char in license_plate_text)
|
70 |
|
71 |
+
def send_email(license_text, violation_image_path, violation_type):
|
72 |
+
"""Send an email notification with violation details and image attachment."""
|
73 |
+
# Define the subject and body based on violation type
|
74 |
+
subjects = {
|
75 |
'No Helmet, In Red Lane': 'تنبيه مخالفة: عدم ارتداء خوذة ودخول المسار الأيسر',
|
76 |
'In Red Lane': 'تنبيه مخالفة: دخول المسار الأيسر',
|
77 |
'No Helmet': 'تنبيه مخالفة: عدم ارتداء خوذة'
|
78 |
+
}
|
79 |
+
bodies = {
|
|
|
80 |
'No Helmet, In Red Lane': f"لعدم ارتداء الخوذة ولدخولها المسار الأيسر ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة",
|
81 |
'In Red Lane': f"لدخولها المسار الأيسر ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة",
|
82 |
'No Helmet': f"لعدم ارتداء الخوذة ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة"
|
83 |
+
}
|
84 |
+
|
85 |
+
subject = subjects.get(violation_type, 'تنبيه مخالفة')
|
86 |
+
body = bodies.get(violation_type, f"تم تغريم دراجة نارية التي تحمل لوحة ({license_text}) بسبب مخالفة.")
|
87 |
+
|
88 |
+
# Create the email message
|
89 |
msg = MIMEMultipart()
|
90 |
msg['From'] = FROM_EMAIL
|
91 |
msg['To'] = TO_EMAIL
|
92 |
+
msg['Subject'] = subject
|
93 |
msg.attach(MIMEText(body, 'plain'))
|
94 |
|
95 |
+
# Attach the violation image
|
96 |
if os.path.exists(violation_image_path):
|
97 |
with open(violation_image_path, 'rb') as attachment_file:
|
98 |
part = MIMEBase('application', 'octet-stream')
|
|
|
101 |
part.add_header('Content-Disposition', f'attachment; filename={os.path.basename(violation_image_path)}')
|
102 |
msg.attach(part)
|
103 |
|
104 |
+
# Send the email using SMTP
|
105 |
try:
|
106 |
with smtplib.SMTP_SSL(SMTP_SERVER, SMTP_PORT) as server:
|
107 |
server.login(FROM_EMAIL, EMAIL_PASSWORD)
|
108 |
server.sendmail(FROM_EMAIL, TO_EMAIL, msg.as_string())
|
109 |
+
print("Email with attachment sent successfully!")
|
110 |
except Exception as e:
|
111 |
print(f"Failed to send email: {e}")
|
112 |
|
113 |
def draw_text_pil(img, text, position, font_path, font_size, color):
|
114 |
+
"""Draw text on an image using PIL for better font support."""
|
115 |
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
116 |
draw = ImageDraw.Draw(img_pil)
|
117 |
try:
|
118 |
font = ImageFont.truetype(font_path, size=font_size)
|
119 |
except IOError:
|
120 |
+
print(f"Font file not found at {font_path}. Using default font.")
|
121 |
font = ImageFont.load_default()
|
122 |
draw.text(position, text, font=font, fill=color)
|
123 |
return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
|
124 |
|
125 |
def process_frame(frame, font_path, violation_image_path='violation.jpg'):
|
126 |
+
"""Process a single video frame for violations."""
|
127 |
results = model.track(frame)
|
128 |
for box in results[0].boxes:
|
129 |
x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
|
|
|
131 |
color = class_colors.get(int(box.cls), (255, 255, 255))
|
132 |
confidence = box.conf[0].item()
|
133 |
|
134 |
+
# Draw bounding box and label
|
135 |
+
cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
|
136 |
+
cv2.putText(frame, f'{label}: {confidence:.2f}', (x1, y1 - 10),
|
137 |
+
cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
|
138 |
+
|
139 |
if label == 'MotorbikeDelivery' and confidence >= 0.4:
|
140 |
motorbike_crop = frame[max(0, y1 - 50):y2, x1:x2]
|
141 |
delivery_center = ((x1 + x2) // 2, y2)
|
|
|
144 |
if in_red_lane >= 0:
|
145 |
violation_types.append("In Red Lane")
|
146 |
|
147 |
+
# Detect sub-objects within the motorbike crop
|
148 |
sub_results = model(motorbike_crop)
|
149 |
+
|
150 |
for sub_box in sub_results[0].boxes:
|
151 |
sub_x1, sub_y1, sub_x2, sub_y2 = map(int, sub_box.xyxy[0].cpu().numpy())
|
152 |
sub_label = model.names[int(sub_box.cls)]
|
|
|
155 |
elif sub_label == 'License_plate':
|
156 |
license_crop = motorbike_crop[sub_y1:sub_y2, sub_x1:sub_x2]
|
157 |
if violation_types:
|
158 |
+
# Save violation image
|
159 |
cv2.imwrite(violation_image_path, frame)
|
160 |
license_plate_pil = Image.fromarray(cv2.cvtColor(license_crop, cv2.COLOR_BGR2RGB))
|
161 |
license_plate_pil.save('license_plate.png')
|
162 |
+
|
163 |
+
# Perform OCR
|
164 |
+
try:
|
165 |
+
ocr_result = ocr_pipeline(Image.open('license_plate.png'))
|
166 |
+
license_plate_text = ocr_result[0]['generated_text'] if ocr_result else ""
|
167 |
+
except Exception as e:
|
168 |
+
print(f"OCR failed: {e}")
|
169 |
+
license_plate_text = ""
|
170 |
+
|
171 |
filtered_text = filter_license_plate_text(license_plate_text)
|
172 |
if filtered_text:
|
173 |
if filtered_text not in violations_dict:
|
|
|
180 |
if updated != current:
|
181 |
violations_dict[filtered_text] = list(updated)
|
182 |
send_email(filtered_text, violation_image_path, ', '.join(updated))
|
183 |
+
|
184 |
arabic_text = convert_to_arabic(filtered_text)
|
185 |
frame = draw_text_pil(frame, filtered_text, (x1, y2 + 30), font_path, 30, (255, 255, 255))
|
186 |
frame = draw_text_pil(frame, arabic_text, (x1, y2 + 60), font_path, 30, (0, 255, 0))
|
187 |
return frame
|
188 |
+
|
189 |
+
def process_image(image_path, font_path, violation_image_path='violation.jpg'):
|
190 |
+
"""Process an uploaded image and return the processed image."""
|
191 |
+
frame = cv2.imread(image_path)
|
192 |
+
if frame is None:
|
193 |
+
print("Error loading image")
|
194 |
+
return None
|
195 |
+
|
196 |
+
processed = process_frame(frame, font_path, violation_image_path)
|
197 |
+
return processed
|
198 |
+
|
199 |
+
def process_video(video_path, font_path, violation_image_path='violation.jpg'):
|
200 |
+
"""Process a video file and save the processed video."""
|
201 |
+
cap = cv2.VideoCapture(video_path)
|
202 |
+
if not cap.isOpened():
|
203 |
+
print("Error opening video file")
|
204 |
+
return None
|
205 |
+
|
206 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
207 |
+
output_video_path = 'output_violation.mp4'
|
208 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
209 |
+
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
210 |
+
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
211 |
+
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
|
212 |
+
|
213 |
+
while cap.isOpened():
|
214 |
+
ret, frame = cap.read()
|
215 |
+
if not ret:
|
216 |
+
break
|
217 |
+
|
218 |
+
# Optionally draw red lane
|
219 |
+
cv2.polylines(frame, [red_lane], isClosed=True, color=(0, 0, 255), thickness=3)
|
220 |
+
|
221 |
+
processed_frame = process_frame(frame, font_path, violation_image_path)
|
222 |
+
out.write(processed_frame)
|
223 |
+
|
224 |
+
cap.release()
|
225 |
+
out.release()
|
226 |
+
return output_video_path
|