|
import spaces |
|
import gradio as gr |
|
import cv2 |
|
from PIL import Image |
|
import torch |
|
import time |
|
import numpy as np |
|
import uuid |
|
|
|
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor |
|
|
|
from draw_boxes import draw_bounding_boxes |
|
|
|
# Hugging Face checkpoint id shared by the image processor and the detector.
_CHECKPOINT = "PekingU/rtdetr_r50vd"

# Loaded once at import time and reused for every request.
image_processor = RTDetrImageProcessor.from_pretrained(_CHECKPOINT)
model = RTDetrForObjectDetection.from_pretrained(_CHECKPOINT)

# Only every SUBSAMPLE-th input frame is kept for detection; the output
# frame rate is divided by the same factor.
SUBSAMPLE = 2
|
|
|
|
|
def _write_annotated_segment(batch, conf_threshold, video_codec, fps, width, height):
    """Detect objects in ``batch`` and write the annotated frames to a new MP4.

    Args:
        batch: List of RGB frames (arrays of size ``width`` x ``height``).
        conf_threshold: Minimum detection score; forwarded to the
            post-processing step and to ``draw_bounding_boxes``.
        video_codec: FourCC code passed to ``cv2.VideoWriter``.
        fps: Frame rate of the written segment.
        width: Frame width in pixels.
        height: Frame height in pixels.

    Returns:
        The file name of the segment that was written.
    """
    # Put the inputs on whatever device the model lives on instead of
    # hard-coding "cuda"; a mismatch makes the forward pass fail.
    inputs = image_processor(images=batch, return_tensors="pt").to(model.device)

    print(f"starting batch of size {len(batch)}")
    start = time.time()
    with torch.no_grad():
        outputs = model(**inputs)
    end = time.time()
    print("time taken for inference", end - start)

    start = time.time()
    boxes = image_processor.post_process_object_detection(
        outputs,
        target_sizes=torch.tensor([(height, width)] * len(batch)),
        threshold=conf_threshold,
    )

    # The writer is created only when there is something to write, so no
    # empty, never-released file is left behind after the last segment.
    name = f"output_{uuid.uuid4()}.mp4"
    segment_file = cv2.VideoWriter(name, video_codec, fps, (width, height))
    try:
        for array, box in zip(batch, boxes):
            pil_image = draw_bounding_boxes(
                Image.fromarray(array), box, model, conf_threshold
            )
            # Reverse the channel axis (RGB -> BGR) before handing the
            # frame to OpenCV.
            segment_file.write(np.array(pil_image)[:, :, ::-1].copy())
    finally:
        segment_file.release()
    end = time.time()
    print("time taken for processing boxes", end - start)
    return name


def stream_object_detection(video, conf_threshold):
    """Yield annotated MP4 segments of ``video`` as soon as each is ready.

    The input is read with OpenCV, downscaled to half resolution and
    subsampled by ``SUBSAMPLE``. Frames are grouped into batches worth two
    seconds of output video; each batch is run through the detector,
    annotated, written to its own ``output_<uuid>.mp4`` file and the file
    name is yielded. Frames left over at the end of the video are flushed
    as a final, shorter segment.

    Args:
        video: Video source understood by ``cv2.VideoCapture``.
        conf_threshold: Minimum detection score for a box to be kept.

    Yields:
        File names of the written video segments, in playback order.
    """
    cap = cv2.VideoCapture(video)
    try:
        video_codec = cv2.VideoWriter_fourcc(*"mp4v")
        fps = int(cap.get(cv2.CAP_PROP_FPS))

        # fps // SUBSAMPLE is 0 when the reported fps is below SUBSAMPLE;
        # a segment length of 0 would mean no batch is ever flushed.
        desired_fps = max(fps // SUBSAMPLE, 1)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) // 2
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) // 2
        frames_per_segment = 2 * desired_fps

        batch = []
        n_frames = 0
        iterating, frame = cap.read()
        while iterating:
            # Only frames that survive subsampling are resized and converted.
            if n_frames % SUBSAMPLE == 0:
                # Resize to the exact writer dimensions: scaling by a 0.5
                # factor may round odd sizes differently from the floor
                # division above, and the writer expects frames of the
                # size it was opened with.
                frame = cv2.resize(frame, (width, height))
                batch.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                if len(batch) == frames_per_segment:
                    yield _write_annotated_segment(
                        batch, conf_threshold, video_codec, desired_fps, width, height
                    )
                    batch = []

            iterating, frame = cap.read()
            n_frames += 1

        # Flush the trailing partial batch so the end of the video (or a
        # clip shorter than one segment) is not silently dropped.
        if batch:
            yield _write_annotated_segment(
                batch, conf_threshold, video_codec, desired_fps, width, height
            )
    finally:
        # Runs on normal exhaustion and when the consumer closes the
        # generator early.
        cap.release()
|
|
|
|
|
# Static header markup shown at the top of the page.
_TITLE_HTML = """
<h1 style='text-align: center'>
Video Object Detection with RT-DETR
</h1>
"""

_LINKS_HTML = """
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://huggingface.co/PekingU/rtdetr_r101vd_coco_o365' target='_blank'>github</a>
</h3>
"""

with gr.Blocks() as app:
    # Page header: title, then the reference links.
    gr.HTML(_TITLE_HTML)
    gr.HTML(_LINKS_HTML)

    with gr.Row():
        # Input side: source video and detection threshold.
        with gr.Column():
            video = gr.Video(label="Video Source")
            conf_threshold = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
                label="Confidence Threshold",
            )
        # Output side: player configured for streaming with autoplay.
        with gr.Column():
            output_video = gr.Video(
                label="Processed Video",
                streaming=True,
                autoplay=True,
            )

    # Uploading a video runs the detection generator with the current
    # slider value and routes what it yields to the output player.
    video.upload(
        stream_object_detection,
        inputs=[video, conf_threshold],
        outputs=[output_video],
    )
|
|
|
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch()
|
|