# freddyaboulton's picture
# Update app.py
# 3cf5d84 verified
import spaces
import gradio as gr
import cv2
from PIL import Image
import torch
import time
import numpy as np
from gradio_webrtc import WebRTC
import os
from twilio.rest import Client
from transformers import RTDetrForObjectDetection, RTDetrImageProcessor
from draw_boxes import draw_bounding_boxes
# RT-DETR checkpoint used for every frame. The image processor and the model
# must come from the same checkpoint, so the ID is defined once.
MODEL_ID = "PekingU/rtdetr_r50vd"
image_processor = RTDetrImageProcessor.from_pretrained(MODEL_ID)
# Weights are moved to the GPU; inference inputs are moved to "cuda" as well.
model = RTDetrForObjectDetection.from_pretrained(MODEL_ID).to("cuda")
# Optional TURN relay through Twilio's Network Traversal Service. When the
# credentials are not set, the WebRTC component gets rtc_configuration=None.
account_sid = os.environ.get("TWILIO_ACCOUNT_SID")
auth_token = os.environ.get("TWILIO_AUTH_TOKEN")
if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        # Restrict ICE candidates to the TURN relay.
        "iceTransportPolicy": "relay",
    }
else:
    rtc_configuration = None
# Log only whether a relay is configured: the ICE server entries returned by
# Twilio include TURN usernames/credentials, which must not end up in logs.
print("RTC_CONFIGURATION", "twilio relay" if rtc_configuration else None)

# Frame subsampling factor. Not used by the streaming code visible in this
# module; kept as a module-level name in case something imports it
# (TODO confirm before removing).
SUBSAMPLE = 2
@spaces.GPU
def stream_object_detection(video, conf_threshold):
    """Run RT-DETR over a video and yield annotated frames for streaming.

    Frames are read from ``video``, downscaled by half, converted to RGB and
    grouped into batches of ``fps`` frames (about one second of footage).
    Each batch goes through the model in a single forward pass; detections
    are drawn on the frames, which are then yielded one at a time.

    Args:
        video: Source accepted by ``cv2.VideoCapture`` (here, the path of the
            uploaded video from the ``gr.Video`` component).
        conf_threshold: Minimum score kept by post-processing; also passed
            to ``draw_bounding_boxes``.

    Yields:
        numpy.ndarray: Annotated frame in BGR channel order (contiguous copy).
    """
    cap = cv2.VideoCapture(video)
    try:
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        # CAP_PROP_FPS can be reported as 0 when the container lacks the
        # information; a zero-sized batch would never fill, so assume 30 fps.
        batch_size = fps if fps > 0 else 30
        batch = []
        while True:
            ok, frame = cap.read()
            if not ok:
                # End of stream or unreadable source: `frame` is None here and
                # must not reach cv2.resize.
                break
            frame = cv2.resize(frame, (0, 0), fx=0.5, fy=0.5)
            batch.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if len(batch) == batch_size:
                yield from _annotate_batch(batch, conf_threshold)
                batch = []
        # Flush the trailing partial batch so the last frames are not dropped.
        if batch:
            yield from _annotate_batch(batch, conf_threshold)
    finally:
        # Also runs when the consumer closes the generator early.
        cap.release()


def _annotate_batch(batch, conf_threshold):
    """Detect objects in a list of RGB frames and yield annotated BGR frames."""
    inputs = image_processor(images=batch, return_tensors="pt").to("cuda")
    print(f"starting batch of size {len(batch)}")
    start = time.time()
    with torch.no_grad():
        outputs = model(**inputs)
    end = time.time()
    print("time taken for inference", end - start)

    start = time.time()
    # Rescale boxes to the (height, width) of the frames actually fed in.
    # cv2.resize rounds odd dimensions, so halving the capture size with //
    # could be off by one.
    target_sizes = torch.tensor([rgb.shape[:2] for rgb in batch])
    detections = image_processor.post_process_object_detection(
        outputs,
        target_sizes=target_sizes,
        threshold=conf_threshold)
    for rgb, detection in zip(batch, detections):
        pil_image = draw_bounding_boxes(Image.fromarray(rgb), detection, model, conf_threshold)
        # RGB -> BGR; copy() turns the reversed view into a contiguous array.
        yield np.array(pil_image)[:, :, ::-1].copy()
    end = time.time()
    print("time taken for processing boxes", end - start)
# Gradio UI: pick a video and a confidence threshold on the left; clicking
# "Detect" streams annotated frames into the receive-only WebRTC component.
with gr.Blocks() as app:
    gr.HTML(
        """
<h1 style='text-align: center'>
Video Object Detection with RT-DETR (Powered by WebRTC ⚡️)
</h1>
""")
    # Paper, source repository and the checkpoint actually loaded by this app.
    gr.HTML(
        """
<h3 style='text-align: center'>
<a href='https://arxiv.org/abs/2304.08069' target='_blank'>arXiv</a> | <a href='https://github.com/lyuwenyu/RT-DETR' target='_blank'>github</a> | <a href='https://huggingface.co/PekingU/rtdetr_r50vd' target='_blank'>model</a>
</h3>
""")
    with gr.Row():
        with gr.Column():
            video = gr.Video(label="Video Source")
            conf_threshold = gr.Slider(
                label="Confidence Threshold",
                minimum=0.0,
                maximum=1.0,
                step=0.05,
                value=0.30,
            )
        with gr.Column():
            output = WebRTC(label="WebRTC Stream",
                            rtc_configuration=rtc_configuration,
                            mode="receive",
                            modality="video")
            detect = gr.Button("Detect", variant="primary")
            # Each frame yielded by stream_object_detection is pushed to the
            # stream; the button click starts the run.
            output.stream(
                fn=stream_object_detection,
                inputs=[video, conf_threshold],
                outputs=[output],
                trigger=detect.click
            )
    gr.Examples(examples=["video_example.mp4"],
                inputs=[video])
if __name__ == '__main__':
    # Start the Gradio server only when executed as a script, not on import.
    app.launch()