import spaces
import gradio as gr
import cv2
from PIL import Image, ImageDraw, ImageFont
import torch
from transformers import Owlv2Processor, Owlv2ForObjectDetection
import numpy as np
import os
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: plots are saved to files, never shown
import matplotlib.pyplot as plt
import tempfile
import shutil

# Fall back to CPU so the script also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16")
model = model.to(device)
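

# Sample the uploaded video once per second, run OWLv2 zero-shot detection for
# the target phrase on each sampled frame, draw the detections, and save the
# annotated frames to a temporary directory. Returns (frame_paths, per-frame
# max confidence scores, error message or None).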
@spaces.GPU  # acquire a GPU on Hugging Face ZeroGPU Spaces; a no-op elsewhere
def process_video(video_path, target, progress=gr.Progress()):
    if video_path is None:
        return None, None, "Error: No video uploaded"

    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None, None, f"Error: Unable to open video file at {video_path}"

    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    original_fps = cap.get(cv2.CAP_PROP_FPS)
    if frame_count <= 0 or original_fps <= 0:
        cap.release()
        return None, None, "Error: Unable to read video metadata (frame count/FPS)"

    output_fps = 1  # analyze one frame per second of video
    frame_duration = 1 / output_fps  # seconds between sampled frames
    video_duration = frame_count / original_fps

    frame_scores = []
    temp_dir = tempfile.mkdtemp()
    frame_paths = []

    batch_size = 1
    batch_frames = []
    batch_indices = []
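
    # Walk through the video at the sampled timestamps; with batch_size=1 each
    # decoded frame is sent to the model immediately.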
    timestamps = np.arange(0, video_duration, frame_duration)
    for i, t in enumerate(progress.tqdm(timestamps)):
        frame_number = int(t * original_fps)
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
        ret, img = cap.read()
        if not ret:
            break

        # OpenCV decodes to BGR; convert to RGB for PIL.
        pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        batch_frames.append(pil_img)
        batch_indices.append(i)

        if len(batch_frames) == batch_size or i == len(timestamps) - 1:
            # One query list per image, as the processor expects for batched inputs.
            inputs = processor(
                text=[[target] for _ in batch_frames],
                images=batch_frames,
                return_tensors="pt",
                padding=True,
            ).to(device)

            with torch.no_grad():
                outputs = model(**inputs)

            # Rescale boxes using each frame's own (height, width), not the last frame's.
            target_sizes = torch.tensor([frame.size[::-1] for frame in batch_frames]).to(device)
            results = processor.post_process_object_detection(outputs, target_sizes=target_sizes)
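
            # Annotate every detection at or above the 0.5 confidence threshold,
            # then save the frame and record its best score for the heatmap.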
            for idx, (frame, result) in enumerate(zip(batch_frames, results)):
                # "RGBA" draw mode lets the label background blend with the frame.
                draw = ImageDraw.Draw(frame, "RGBA")
                max_score = 0

                boxes, scores, labels = result["boxes"], result["scores"], result["labels"]

                for box, score, label in zip(boxes, scores, labels):
                    if score.item() >= 0.5:
                        box = [round(coord, 2) for coord in box.tolist()]
                        object_label = target
                        confidence = round(score.item(), 3)
                        annotation = f"{object_label}: {confidence}"

                        draw.rectangle(box, outline="red", width=3)

                        # Scale the label font with the frame size; fall back to the
                        # default bitmap font if Arial is not installed.
                        img_width, img_height = frame.size
                        font_size = int(min(img_width, img_height) * 0.03)
                        try:
                            font = ImageFont.truetype("arial.ttf", font_size)
                        except IOError:
                            font = ImageFont.load_default()

                        text_bbox = draw.textbbox((0, 0), annotation, font=font)
                        text_width = text_bbox[2] - text_bbox[0]
                        text_height = text_bbox[3] - text_bbox[1]

                        text_position = (box[0], box[1])

                        # Semi-transparent backdrop so the label stays readable.
                        draw.rectangle([text_position[0], text_position[1],
                                        text_position[0] + text_width, text_position[1] + text_height],
                                       fill=(0, 0, 0, 128))

                        draw.text(text_position, annotation, fill="red", font=font)

                        max_score = max(max_score, confidence)

                frame_path = os.path.join(temp_dir, f"frame_{batch_indices[idx]:04d}.png")
                frame.save(frame_path)
                frame_paths.append(frame_path)
                frame_scores.append(max_score)

            batch_frames = []
            batch_indices = []

        # Periodically release cached GPU memory during long videos.
        if device == "cuda" and i % 10 == 0:
            torch.cuda.empty_cache()

    cap.release()
    return frame_paths, frame_scores, None
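

# Render the per-frame max scores as a single-row heatmap, mark the currently
# selected frame with a dashed line, and return the path of the saved PNG.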
def create_heatmap(frame_scores, current_frame):
    plt.figure(figsize=(16, 4))
    plt.imshow([frame_scores], cmap='hot_r', aspect='auto')
    plt.title('Object Detection Heatmap', fontsize=14)
    plt.xlabel('Frame', fontsize=12)
    plt.yticks([])

    num_frames = len(frame_scores)
    step = max(1, num_frames // 20)
    frame_numbers = range(0, num_frames, step)
    plt.xticks(frame_numbers, [str(i) for i in frame_numbers], rotation=90, ha='right')

    plt.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)

    plt.tight_layout()

    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
        plt.savefig(tmp_file.name, format='png', dpi=400, bbox_inches='tight')
        plt.close()

    return tmp_file.name
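

# Grab one RGB frame from the bundled sample video for the preview thumbnail.
# target_frame indexes the processed (processing_fps) sequence and is mapped
# back to a source frame number at original_fps.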
def load_sample_frame(video_path, target_frame=87, original_fps=30, processing_fps=1):
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        return None

    original_frame_number = int(target_frame * (original_fps / processing_fps))
    cap.set(cv2.CAP_PROP_POS_FRAMES, original_frame_number)

    ret, frame = cap.read()
    cap.release()
    if not ret:
        return None
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return frame_rgb
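

# Slider callback: load the saved annotated frame and redraw the heatmap with
# the cursor at the chosen index.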
def update_frame_and_heatmap(frame_index, frame_paths, scores):
    frame_index = int(frame_index)  # slider values can arrive as floats
    if frame_paths and 0 <= frame_index < len(frame_paths):
        frame = Image.open(frame_paths[frame_index])
        heatmap_path = create_heatmap(scores, frame_index)
        return np.array(frame), heatmap_path
    return None, None
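

# Build the Gradio Blocks UI and wire up the upload, slider, and sample-video
# events.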
def gradio_app():
    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")

        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        heatmap_output = gr.Image(label="Detection Heatmap")
        error_output = gr.Textbox(label="Error Messages", visible=False)
        # Deferred components (render=False) are placed into the Row below via
        # .render(); merely referencing a component inside a layout does not move it.
        output_image = gr.Image(label="Processed Frame", render=False)
        sample_video_frame = gr.Image(
            value=load_sample_frame("Drone Video of African Wildlife Wild Botswan.mp4", target_frame=87),
            label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame (Frame 87 at 1 FPS)",
            render=False,
        )
        use_sample_button = gr.Button("Use Sample Video", render=False)

        frame_paths = gr.State([])
        frame_scores = gr.State([])
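
        # Run detection, then push the first annotated frame, the heatmap, and
        # the updated slider range to the UI in one shot.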
        def process_and_update(video, target, progress=gr.Progress()):
            paths, scores, error = process_video(video, target, progress)
            if paths:
                heatmap_path = create_heatmap(scores, 0)
                first_frame = Image.open(paths[0])
                return paths, scores, np.array(first_frame), heatmap_path, error, gr.Slider(maximum=len(paths) - 1, value=0)
            return None, None, None, None, error, gr.Slider(maximum=100, value=0)

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])

        frame_slider.change(update_frame_and_heatmap,
                            inputs=[frame_slider, frame_paths, frame_scores],
                            outputs=[output_image, heatmap_output])

        def use_sample_video(progress=gr.Progress()):
            sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"
            return process_and_update(sample_video_path, "Elephant", progress)

        use_sample_button.click(use_sample_video,
                                inputs=None,
                                outputs=[frame_paths, frame_scores, output_image, heatmap_output, error_output, frame_slider])
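
        # Place the deferred components: processed frame on the left, sample
        # preview and button on the right.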
        with gr.Row():
            with gr.Column(scale=2):
                output_image.render()
            with gr.Column(scale=1):
                sample_video_frame.render()
                use_sample_button.render()

    return app

def cleanup(frame_paths):
    # Remove the saved frame images and their temporary directory.
    for path in frame_paths:
        if os.path.exists(path):
            os.remove(path)
    if frame_paths:
        temp_dir = os.path.dirname(frame_paths[0])
        if os.path.exists(temp_dir):
            shutil.rmtree(temp_dir, ignore_errors=True)


if __name__ == "__main__":
    app = gradio_app()
    app.launch()