TheKnight115 committed on
Commit
0a0fad6
·
verified ·
1 Parent(s): f622f25

Update processor.py

Browse files
Files changed (1) hide show
  1. processor.py +101 -22
processor.py CHANGED
@@ -4,7 +4,7 @@ import cv2
4
  import numpy as np
5
  import os
6
  from ultralytics import YOLO
7
- from transformers import AutoModel, AutoProcessor
8
  from PIL import Image, ImageDraw, ImageFont
9
  import re
10
  import smtplib
@@ -12,6 +12,11 @@ from email.mime.text import MIMEText
12
  from email.mime.multipart import MIMEMultipart
13
  from email.mime.base import MIMEBase
14
  from email import encoders
 
 
 
 
 
15
 
16
  # Email credentials (Use environment variables for security)
17
  FROM_EMAIL = os.getenv("FROM_EMAIL")
@@ -20,7 +25,7 @@ TO_EMAIL = os.getenv("TO_EMAIL")
20
  SMTP_SERVER = 'smtp.gmail.com'
21
  SMTP_PORT = 465
22
 
23
- # Arabic dictionary
24
  arabic_dict = {
25
  "0": "٠", "1": "١", "2": "٢", "3": "٣", "4": "٤", "5": "٥",
26
  "6": "٦", "7": "٧", "8": "٨", "9": "٩", "A": "ا", "B": "ب",
@@ -28,6 +33,8 @@ arabic_dict = {
28
  "E": "ع", "G": "ق", "K": "ك", "L": "ل", "Z": "م", "N": "ن",
29
  "H": "ه", "U": "و", "V": "ي", " ": " "
30
  }
 
 
31
  class_colors = {
32
  0: (0, 255, 0), # Green (Helmet)
33
  1: (255, 0, 0), # Blue (License Plate)
@@ -37,44 +44,55 @@ class_colors = {
37
  5: (0, 255, 255), # Yellow (Person)
38
  }
39
 
40
- # Load models
41
- processor = AutoProcessor.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True)
42
- model_ocr = AutoModel.from_pretrained("stepfun-ai/GOT-OCR2_0", trust_remote_code=True).to('cuda')
43
- model = YOLO('yolov8_Medium.pt') # Update path as needed
 
 
 
44
 
45
- # Define lane area
46
  red_lane = np.array([[2,1583],[1,1131],[1828,1141],[1912,1580]], np.int32)
47
 
48
- # Violation tracking
49
  violations_dict = {}
50
 
51
- def filter_license_plate_text(text):
52
- text = re.sub(r'[^A-Z0-9]+', "", text)
53
- match = re.search(r'(\d{4})\s*([A-Z]{2})', text)
 
54
  return f"{match.group(1)} {match.group(2)}" if match else None
55
 
56
- def convert_to_arabic(text):
57
- return "".join(arabic_dict.get(char, char) for char in text)
 
58
 
59
- def send_email(license_text, violation_image_path, violation_type):
60
- subject = {
 
 
61
  'No Helmet, In Red Lane': 'تنبيه مخالفة: عدم ارتداء خوذة ودخول المسار الأيسر',
62
  'In Red Lane': 'تنبيه مخالفة: دخول المسار الأيسر',
63
  'No Helmet': 'تنبيه مخالفة: عدم ارتداء خوذة'
64
- }.get(violation_type, 'تنبيه مخالفة')
65
-
66
- body = {
67
  'No Helmet, In Red Lane': f"لعدم ارتداء الخوذة ولدخولها المسار الأيسر ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة",
68
  'In Red Lane': f"لدخولها المسار الأيسر ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة",
69
  'No Helmet': f"لعدم ارتداء الخوذة ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة"
70
- }.get(violation_type, f"تم تغريم دراجة نارية التي تحمل لوحة ({license_text})")
71
-
 
 
 
 
72
  msg = MIMEMultipart()
73
  msg['From'] = FROM_EMAIL
74
  msg['To'] = TO_EMAIL
75
- msg['Subject'] = subject
76
  msg.attach(MIMEText(body, 'plain'))
77
 
 
78
  if os.path.exists(violation_image_path):
79
  with open(violation_image_path, 'rb') as attachment_file:
80
  part = MIMEBase('application', 'octet-stream')
@@ -83,24 +101,29 @@ def send_email(license_text, violation_image_path, violation_type):
83
  part.add_header('Content-Disposition', f'attachment; filename={os.path.basename(violation_image_path)}')
84
  msg.attach(part)
85
 
 
86
  try:
87
  with smtplib.SMTP_SSL(SMTP_SERVER, SMTP_PORT) as server:
88
  server.login(FROM_EMAIL, EMAIL_PASSWORD)
89
  server.sendmail(FROM_EMAIL, TO_EMAIL, msg.as_string())
 
90
  except Exception as e:
91
  print(f"Failed to send email: {e}")
92
 
93
  def draw_text_pil(img, text, position, font_path, font_size, color):
 
94
  img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
95
  draw = ImageDraw.Draw(img_pil)
96
  try:
97
  font = ImageFont.truetype(font_path, size=font_size)
98
  except IOError:
 
99
  font = ImageFont.load_default()
100
  draw.text(position, text, font=font, fill=color)
101
  return cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
102
 
103
  def process_frame(frame, font_path, violation_image_path='violation.jpg'):
 
104
  results = model.track(frame)
105
  for box in results[0].boxes:
106
  x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
@@ -108,6 +131,11 @@ def process_frame(frame, font_path, violation_image_path='violation.jpg'):
108
  color = class_colors.get(int(box.cls), (255, 255, 255))
109
  confidence = box.conf[0].item()
110
 
 
 
 
 
 
111
  if label == 'MotorbikeDelivery' and confidence >= 0.4:
112
  motorbike_crop = frame[max(0, y1 - 50):y2, x1:x2]
113
  delivery_center = ((x1 + x2) // 2, y2)
@@ -116,7 +144,9 @@ def process_frame(frame, font_path, violation_image_path='violation.jpg'):
116
  if in_red_lane >= 0:
117
  violation_types.append("In Red Lane")
118
 
 
119
  sub_results = model(motorbike_crop)
 
120
  for sub_box in sub_results[0].boxes:
121
  sub_x1, sub_y1, sub_x2, sub_y2 = map(int, sub_box.xyxy[0].cpu().numpy())
122
  sub_label = model.names[int(sub_box.cls)]
@@ -125,10 +155,19 @@ def process_frame(frame, font_path, violation_image_path='violation.jpg'):
125
  elif sub_label == 'License_plate':
126
  license_crop = motorbike_crop[sub_y1:sub_y2, sub_x1:sub_x2]
127
  if violation_types:
 
128
  cv2.imwrite(violation_image_path, frame)
129
  license_plate_pil = Image.fromarray(cv2.cvtColor(license_crop, cv2.COLOR_BGR2RGB))
130
  license_plate_pil.save('license_plate.png')
131
- license_plate_text = model_ocr.chat(processor, temp_image_path, ocr_type='ocr')
 
 
 
 
 
 
 
 
132
  filtered_text = filter_license_plate_text(license_plate_text)
133
  if filtered_text:
134
  if filtered_text not in violations_dict:
@@ -141,7 +180,47 @@ def process_frame(frame, font_path, violation_image_path='violation.jpg'):
141
  if updated != current:
142
  violations_dict[filtered_text] = list(updated)
143
  send_email(filtered_text, violation_image_path, ', '.join(updated))
 
144
  arabic_text = convert_to_arabic(filtered_text)
145
  frame = draw_text_pil(frame, filtered_text, (x1, y2 + 30), font_path, 30, (255, 255, 255))
146
  frame = draw_text_pil(frame, arabic_text, (x1, y2 + 60), font_path, 30, (0, 255, 0))
147
  return frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import numpy as np
5
  import os
6
  from ultralytics import YOLO
7
+ from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
8
  from PIL import Image, ImageDraw, ImageFont
9
  import re
10
  import smtplib
 
12
  from email.mime.multipart import MIMEMultipart
13
  from email.mime.base import MIMEBase
14
  from email import encoders
15
+ import torch
16
+ from dotenv import load_dotenv
17
+
18
+ # Load environment variables
19
+ load_dotenv()
20
 
21
  # Email credentials (Use environment variables for security)
22
  FROM_EMAIL = os.getenv("FROM_EMAIL")
 
25
  SMTP_SERVER = 'smtp.gmail.com'
26
  SMTP_PORT = 465
27
 
28
+ # Arabic dictionary for converting license plate text
29
  arabic_dict = {
30
  "0": "٠", "1": "١", "2": "٢", "3": "٣", "4": "٤", "5": "٥",
31
  "6": "٦", "7": "٧", "8": "٨", "9": "٩", "A": "ا", "B": "ب",
 
33
  "E": "ع", "G": "ق", "K": "ك", "L": "ل", "Z": "م", "N": "ن",
34
  "H": "ه", "U": "و", "V": "ي", " ": " "
35
  }
36
+
37
+ # Define class colors
38
  class_colors = {
39
  0: (0, 255, 0), # Green (Helmet)
40
  1: (255, 0, 0), # Blue (License Plate)
 
44
  5: (0, 255, 255), # Yellow (Person)
45
  }
46
 
47
+ # Initialize the OCR pipeline
48
+ # Replace 'stepfun-ai/GOT-OCR2_0' with the correct model name if different
49
+ ocr_pipeline = pipeline("image-to-text", model="stepfun-ai/GOT-OCR2_0", trust_remote_code=True, device=0 if torch.cuda.is_available() else -1)
50
+
51
+ # Load YOLO model
52
+ # Ensure the path to the model is correct
53
+ model = YOLO('yolov8_Medium.pt') # Update the path as needed
54
 
55
+ # Define lane area coordinates (example coordinates)
56
  red_lane = np.array([[2,1583],[1,1131],[1828,1141],[1912,1580]], np.int32)
57
 
58
+ # Dictionary to track violations per license plate
59
  violations_dict = {}
60
 
61
def filter_license_plate_text(license_plate_text):
    """Extract a plate number of the form ``"1234 AB"`` from raw OCR text.

    Every character other than ``A-Z`` and ``0-9`` is discarded first, so
    separators, whitespace and lowercase letters never take part in the
    match. The first run of four digits immediately followed by two
    uppercase letters in the cleaned text is then reported.

    Args:
        license_plate_text: Raw text to search. ``None`` or any other
            non-string value is treated as "nothing recognised".

    Returns:
        The plate formatted as ``"<4 digits> <2 letters>"``, or ``None``
        when no such sequence is present.
    """
    # Guard against None / non-text input: re.sub would raise TypeError,
    # whereas "no plate found" is the answer callers can act on.
    if not isinstance(license_plate_text, str):
        return None

    cleaned = re.sub(r'[^A-Z0-9]+', "", license_plate_text)
    # Whitespace was stripped by the substitution above, so the digits and
    # the letters are always adjacent in `cleaned`.
    match = re.search(r'(\d{4})([A-Z]{2})', cleaned)
    return f"{match.group(1)} {match.group(2)}" if match else None
66
 
67
def convert_to_arabic(license_plate_text):
    """Map plate text through ``arabic_dict``, one character at a time.

    Characters that have no entry in ``arabic_dict`` are copied through
    unchanged.
    """
    converted = []
    for symbol in license_plate_text:
        converted.append(arabic_dict.get(symbol, symbol))
    return "".join(converted)
70
 
71
+ def send_email(license_text, violation_image_path, violation_type):
72
+ """Send an email notification with violation details and image attachment."""
73
+ # Define the subject and body based on violation type
74
+ subjects = {
75
  'No Helmet, In Red Lane': 'تنبيه مخالفة: عدم ارتداء خوذة ودخول المسار الأيسر',
76
  'In Red Lane': 'تنبيه مخالفة: دخول المسار الأيسر',
77
  'No Helmet': 'تنبيه مخالفة: عدم ارتداء خوذة'
78
+ }
79
+ bodies = {
 
80
  'No Helmet, In Red Lane': f"لعدم ارتداء الخوذة ولدخولها المسار الأيسر ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة",
81
  'In Red Lane': f"لدخولها المسار الأيسر ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة",
82
  'No Helmet': f"لعدم ارتداء الخوذة ({license_text}) تم تغريم دراجة نارية التي تحمل لوحة"
83
+ }
84
+
85
+ subject = subjects.get(violation_type, 'تنبيه مخالفة')
86
+ body = bodies.get(violation_type, f"تم تغريم دراجة نارية التي تحمل لوحة ({license_text}) بسبب مخالفة.")
87
+
88
+ # Create the email message
89
  msg = MIMEMultipart()
90
  msg['From'] = FROM_EMAIL
91
  msg['To'] = TO_EMAIL
92
+ msg['Subject'] = subject
93
  msg.attach(MIMEText(body, 'plain'))
94
 
95
+ # Attach the violation image
96
  if os.path.exists(violation_image_path):
97
  with open(violation_image_path, 'rb') as attachment_file:
98
  part = MIMEBase('application', 'octet-stream')
 
101
  part.add_header('Content-Disposition', f'attachment; filename={os.path.basename(violation_image_path)}')
102
  msg.attach(part)
103
 
104
+ # Send the email using SMTP
105
  try:
106
  with smtplib.SMTP_SSL(SMTP_SERVER, SMTP_PORT) as server:
107
  server.login(FROM_EMAIL, EMAIL_PASSWORD)
108
  server.sendmail(FROM_EMAIL, TO_EMAIL, msg.as_string())
109
+ print("Email with attachment sent successfully!")
110
  except Exception as e:
111
  print(f"Failed to send email: {e}")
112
 
113
def draw_text_pil(img, text, position, font_path, font_size, color):
    """Render ``text`` onto an OpenCV image via PIL and return a new image.

    The input is converted BGR -> RGB for PIL, drawn on, and converted back
    RGB -> BGR before being returned. When the font at ``font_path`` cannot
    be loaded, a notice is printed and PIL's default font is used instead.
    """
    rgb_pixels = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    canvas = Image.fromarray(rgb_pixels)
    pen = ImageDraw.Draw(canvas)

    try:
        chosen_font = ImageFont.truetype(font_path, size=font_size)
    except IOError:
        print(f"Font file not found at {font_path}. Using default font.")
        chosen_font = ImageFont.load_default()

    pen.text(position, text, font=chosen_font, fill=color)

    annotated = np.array(canvas)
    return cv2.cvtColor(annotated, cv2.COLOR_RGB2BGR)
124
 
125
  def process_frame(frame, font_path, violation_image_path='violation.jpg'):
126
+ """Process a single video frame for violations."""
127
  results = model.track(frame)
128
  for box in results[0].boxes:
129
  x1, y1, x2, y2 = map(int, box.xyxy[0].cpu().numpy())
 
131
  color = class_colors.get(int(box.cls), (255, 255, 255))
132
  confidence = box.conf[0].item()
133
 
134
+ # Draw bounding box and label
135
+ cv2.rectangle(frame, (x1, y1), (x2, y2), color, 3)
136
+ cv2.putText(frame, f'{label}: {confidence:.2f}', (x1, y1 - 10),
137
+ cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2)
138
+
139
  if label == 'MotorbikeDelivery' and confidence >= 0.4:
140
  motorbike_crop = frame[max(0, y1 - 50):y2, x1:x2]
141
  delivery_center = ((x1 + x2) // 2, y2)
 
144
  if in_red_lane >= 0:
145
  violation_types.append("In Red Lane")
146
 
147
+ # Detect sub-objects within the motorbike crop
148
  sub_results = model(motorbike_crop)
149
+
150
  for sub_box in sub_results[0].boxes:
151
  sub_x1, sub_y1, sub_x2, sub_y2 = map(int, sub_box.xyxy[0].cpu().numpy())
152
  sub_label = model.names[int(sub_box.cls)]
 
155
  elif sub_label == 'License_plate':
156
  license_crop = motorbike_crop[sub_y1:sub_y2, sub_x1:sub_x2]
157
  if violation_types:
158
+ # Save violation image
159
  cv2.imwrite(violation_image_path, frame)
160
  license_plate_pil = Image.fromarray(cv2.cvtColor(license_crop, cv2.COLOR_BGR2RGB))
161
  license_plate_pil.save('license_plate.png')
162
+
163
+ # Perform OCR
164
+ try:
165
+ ocr_result = ocr_pipeline(Image.open('license_plate.png'))
166
+ license_plate_text = ocr_result[0]['generated_text'] if ocr_result else ""
167
+ except Exception as e:
168
+ print(f"OCR failed: {e}")
169
+ license_plate_text = ""
170
+
171
  filtered_text = filter_license_plate_text(license_plate_text)
172
  if filtered_text:
173
  if filtered_text not in violations_dict:
 
180
  if updated != current:
181
  violations_dict[filtered_text] = list(updated)
182
  send_email(filtered_text, violation_image_path, ', '.join(updated))
183
+
184
  arabic_text = convert_to_arabic(filtered_text)
185
  frame = draw_text_pil(frame, filtered_text, (x1, y2 + 30), font_path, 30, (255, 255, 255))
186
  frame = draw_text_pil(frame, arabic_text, (x1, y2 + 60), font_path, 30, (0, 255, 0))
187
  return frame
188
+
189
def process_image(image_path, font_path, violation_image_path='violation.jpg'):
    """Read an image from ``image_path`` and run ``process_frame`` on it.

    Returns whatever ``process_frame`` returns, or ``None`` (after printing
    an error) when ``cv2.imread`` could not load the file.
    """
    frame = cv2.imread(image_path)
    if frame is not None:
        return process_frame(frame, font_path, violation_image_path)

    print("Error loading image")
    return None
198
+
199
def process_video(video_path, font_path, violation_image_path='violation.jpg'):
    """Run ``process_frame`` over every frame of a video and save the result.

    Each frame gets the ``red_lane`` polygon drawn on it, is passed through
    ``process_frame``, and is appended to ``output_violation.mp4`` (written
    with the ``mp4v`` fourcc at the source's reported FPS and frame size).

    Args:
        video_path: Path of the input video.
        font_path: Font file path forwarded to ``process_frame``.
        violation_image_path: Image path forwarded to ``process_frame``.

    Returns:
        The output video path, or ``None`` (after printing an error) when
        either the input video or the output writer could not be opened.
    """
    output_video_path = 'output_violation.mp4'
    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            print("Error opening video file")
            return None

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        # A writer that failed to open drops every frame silently; report it
        # instead of returning the path of a file that was never written.
        if not out.isOpened():
            print("Error opening output video file")
            return None

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Outline the monitored lane on the frame before detection runs.
            cv2.polylines(frame, [red_lane], isClosed=True, color=(0, 0, 255), thickness=3)

            processed_frame = process_frame(frame, font_path, violation_image_path)
            out.write(processed_frame)
    finally:
        # Release both handles even if processing a frame raised, so the
        # input file is unlocked and the output container is finalised.
        cap.release()
        if out is not None:
            out.release()

    return output_video_path