hilmantm commited on
Commit
047bbb8
·
1 Parent(s): 2d8940f

feat: implement detect accident from video (not yet done)

Browse files
Files changed (3) hide show
  1. app.py +110 -11
  2. requirements.txt +4 -1
  3. video/README.md +0 -0
app.py CHANGED
@@ -7,6 +7,10 @@ import cv2
7
  import torch
8
  import supervision as sv
9
  import numpy as np
 
 
 
 
10
 
11
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
12
  CHECKPOINT = 'facebook/detr-resnet-50'
@@ -14,6 +18,8 @@ CHECKPOINT_ACCIDENT_DETECTION = 'hilmantm/detr-traffic-accident-detection'
14
  CONFIDENCE_TRESHOLD = 0.5
15
  IOU_TRESHOLD = 0.8
16
  NMS_TRESHOLD = 0.5
 
 
17
  fdic = {
18
  "family" : "Impact",
19
  "style" : "italic",
@@ -26,9 +32,6 @@ image_processor = DetrImageProcessor.from_pretrained(CHECKPOINT)
26
  model = DetrForObjectDetection.from_pretrained(CHECKPOINT_ACCIDENT_DETECTION)
27
  model.to(DEVICE)
28
 
29
- # use this function only for DETR Algorithm
30
- # def detect_object(model, test_image_path, nms_treshold = 0.5):
31
-
32
  def inference_from_image(pil_image):
33
 
34
  box_annotator = sv.BoxAnnotator()
@@ -65,6 +68,101 @@ def inference_from_image(pil_image):
65
  print("No object detected")
66
  return None
67
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  with gr.Blocks() as demo:
69
  gr.Markdown(
70
  """
@@ -76,21 +174,22 @@ with gr.Blocks() as demo:
76
  with gr.Row():
77
  with gr.Column():
78
  input_image = gr.Image(label="Input image", type="pil")
79
- inp = gr.Textbox(label="Image URL", placeholder="You have image from URL? Drop here")
80
  with gr.Column():
81
  output_image = gr.Image(label="Output image with predicted accident", type="pil")
82
 
83
  detect_image_btn = gr.Button(value="Detect Accident")
84
  detect_image_btn.click(fn=inference_from_image, inputs=[input_image], outputs=[output_image])
85
 
86
- gr.Markdown("## Detect Accident from Video")
87
- with gr.Row():
88
- with gr.Column():
89
- inp = gr.Textbox(label="Youtube URL", placeholder="You should upload video to youtube and drop the link here")
90
- with gr.Column():
91
- output_image = gr.Image(label="Output image with predicted accident", type="pil")
 
92
 
93
- gr.Button(value="Detect Accident")
 
94
 
95
 
96
  demo.launch(debug=True)
 
7
  import torch
8
  import supervision as sv
9
  import numpy as np
10
+ from pytube import YouTube
11
+ import uuid
12
+ import os
13
+ from moviepy.editor import VideoFileClip
14
 
15
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
16
  CHECKPOINT = 'facebook/detr-resnet-50'
 
18
  CONFIDENCE_TRESHOLD = 0.5
19
  IOU_TRESHOLD = 0.8
20
  NMS_TRESHOLD = 0.5
21
+ VIDEO_PATH = os.path.join("video")
22
+ VIDEO_INFRENCE = False
23
  fdic = {
24
  "family" : "Impact",
25
  "style" : "italic",
 
32
  model = DetrForObjectDetection.from_pretrained(CHECKPOINT_ACCIDENT_DETECTION)
33
  model.to(DEVICE)
34
 
 
 
 
35
  def inference_from_image(pil_image):
36
 
37
  box_annotator = sv.BoxAnnotator()
 
68
  print("No object detected")
69
  return None
70
 
71
+ def convert_to_h264(file_path, output_file):
72
+ clip = VideoFileClip(file_path)
73
+ clip.write_videofile(output_file, codec="libx264")
74
+ clip.close()
75
+
76
+ def inference_from_video(url):
77
+ box_annotator = sv.BoxAnnotator()
78
+
79
+ # Define the YouTube video URL
80
+ video_url = url
81
+
82
+ # Create a YouTube object and get the video stream
83
+ yt = YouTube(video_url)
84
+ yt_stream = yt.streams.filter(progressive=True, file_extension='mp4').first()
85
+
86
+ # Download the video to a file
87
+ unique_id = uuid.uuid4().hex[:6].upper()
88
+ video_folder = os.path.join(VIDEO_PATH, unique_id)
89
+ video_filename = os.path.join(video_folder, f"{unique_id}.mp4")
90
+ result_video_filename = os.path.join(video_folder, f"{unique_id}_result.mp4")
91
+ result_video_filename_temp = os.path.join(video_folder, f"{unique_id}_result_temp.mp4")
92
+
93
+ os.mkdir(video_folder)
94
+ yt_stream.download(filename=video_filename)
95
+
96
+ # Load the video
97
+ cap = cv2.VideoCapture(video_filename)
98
+
99
+ # Get the video frame dimensions
100
+ frame_width = int(cap.get(3))
101
+ frame_height = int(cap.get(4))
102
+
103
+ # Define the codec and create a VideoWriter object
104
+ out = cv2.VideoWriter(result_video_filename_temp, cv2.VideoWriter_fourcc(*'mp4v'), 30, (frame_width, frame_height))
105
+
106
+ while True:
107
+ ret, image = cap.read()
108
+ if not ret:
109
+ break
110
+
111
+ # inference
112
+ with torch.no_grad():
113
+
114
+ # load image and predict
115
+ inputs = image_processor(images=image, return_tensors='pt').to(DEVICE)
116
+ outputs = model(**inputs)
117
+
118
+ # post-process
119
+ target_sizes = torch.tensor([image.shape[:2]]).to(DEVICE)
120
+ results = image_processor.post_process_object_detection(
121
+ outputs=outputs,
122
+ threshold=CONFIDENCE_TRESHOLD,
123
+ target_sizes=target_sizes
124
+ )[0]
125
+
126
+ print("transformer result", results)
127
+
128
+ if results['scores'].shape[0] != 0 or results['labels'].shape[0] != 0:
129
+ # annotate
130
+ detections = sv.Detections.from_transformers(transformers_results=results).with_nms(threshold=NMS_TRESHOLD)
131
+ labels = [
132
+ f"{model.config.id2label[class_id]} {confidence:0.2f}"
133
+ for _, confidence, class_id, _
134
+ in detections
135
+ ]
136
+ frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
137
+ out.write(frame)
138
+ else:
139
+ out.write(image)
140
+
141
+ cap.release()
142
+ out.release()
143
+
144
+ convert_to_h264(result_video_filename_temp, result_video_filename)
145
+
146
+ # delete temp file
147
+ os.remove(result_video_filename_temp)
148
+
149
+ return result_video_filename
150
+
151
+
152
+ def testing(file):
153
+ unique_id = "39EE5A"
154
+ video_folder = os.path.join(VIDEO_PATH, unique_id)
155
+ video_filename = os.path.join(video_folder, f"{unique_id}.mp4")
156
+ result_video_filename = os.path.join(video_folder, f"{unique_id}_result.mp4")
157
+ result_video_filename_temp = os.path.join(video_folder, f"{unique_id}_result_temp.mp4")
158
+
159
+ convert_to_h264(result_video_filename_temp, result_video_filename)
160
+
161
+ os.remove(result_video_filename_temp)
162
+
163
+ return result_video_filename
164
+
165
+
166
  with gr.Blocks() as demo:
167
  gr.Markdown(
168
  """
 
174
  with gr.Row():
175
  with gr.Column():
176
  input_image = gr.Image(label="Input image", type="pil")
 
177
  with gr.Column():
178
  output_image = gr.Image(label="Output image with predicted accident", type="pil")
179
 
180
  detect_image_btn = gr.Button(value="Detect Accident")
181
  detect_image_btn.click(fn=inference_from_image, inputs=[input_image], outputs=[output_image])
182
 
183
+ if VIDEO_INFRENCE:
184
+ gr.Markdown("## Detect Accident from Video")
185
+ with gr.Row():
186
+ with gr.Column():
187
+ inp = gr.Textbox(label="Youtube URL", placeholder="You should upload video to youtube and drop the link here")
188
+ with gr.Column():
189
+ output_video = gr.Video(label="Output image with predicted accident", format="mp4")
190
 
191
+ detect_video_btn = gr.Button(value="Detect Accident")
192
+ detect_video_btn.click(fn=inference_from_video, inputs=[inp], outputs=[output_video])
193
 
194
 
195
  demo.launch(debug=True)
requirements.txt CHANGED
@@ -4,4 +4,7 @@ supervision==0.3.0
4
  pytorch-lightning
5
  roboflow
6
  timm
7
- numpy
 
 
 
 
4
  pytorch-lightning
5
  roboflow
6
  timm
7
+ numpy
8
+ pytube
9
+ ffmpeg
10
+ moviepy
video/README.md ADDED
File without changes