Spaces:
Runtime error
Runtime error
import spaces | |
import gradio as gr | |
import cv2 | |
from PIL import Image, ImageDraw, ImageFont | |
import torch | |
from transformers import Owlv2Processor, Owlv2ForObjectDetection | |
import numpy as np | |
import os | |
import matplotlib.pyplot as plt | |
import tempfile | |
import shutil | |
# Checkpoint shared by the processor and the detection model.
OWLV2_CHECKPOINT = "google/owlv2-base-patch16"

# Prefer the GPU, but fall back to CPU so that importing this module does not
# crash on hosts where CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor prepares text queries / images and post-processes raw outputs;
# the model is loaded once at import time and moved to the selected device.
processor = Owlv2Processor.from_pretrained(OWLV2_CHECKPOINT)
model = Owlv2ForObjectDetection.from_pretrained(OWLV2_CHECKPOINT)
model = model.to(device)
def _load_annotation_font(image_size):
    """Return a label font sized to 3% of the smaller image dimension."""
    font_size = max(1, int(min(image_size) * 0.03))
    try:
        return ImageFont.truetype("arial.ttf", font_size)
    except OSError:
        # arial.ttf is not installed on every host; use PIL's built-in font.
        return ImageFont.load_default()


def _annotate_frame(pil_img, result, target, font, score_threshold=0.5):
    """Draw every detection scoring >= score_threshold onto pil_img in place.

    Returns the highest (rounded) confidence that was drawn, or 0 when no
    detection reached the threshold.
    """
    draw = ImageDraw.Draw(pil_img)
    max_score = 0
    for box, score in zip(result["boxes"], result["scores"]):
        if not score.item() >= score_threshold:
            continue
        box = [round(coord, 2) for coord in box.tolist()]
        confidence = round(score.item(), 3)
        annotation = f"{target}: {confidence}"

        draw.rectangle(box, outline="red", width=3)

        # Measure the label so a background patch can be drawn behind it.
        text_bbox = draw.textbbox((0, 0), annotation, font=font)
        text_width = text_bbox[2] - text_bbox[0]
        text_height = text_bbox[3] - text_bbox[1]

        # Label sits at the top-left corner of the bounding box.
        text_position = (box[0], box[1])
        draw.rectangle(
            [
                text_position[0],
                text_position[1],
                text_position[0] + text_width,
                text_position[1] + text_height,
            ],
            fill=(0, 0, 0, 128),
        )
        draw.text(text_position, annotation, fill="red", font=font)
        max_score = max(max_score, confidence)
    return max_score


def _detect_batch(batch_frames, target):
    """Run the detector on a list of PIL frames and return per-frame results."""
    inputs = processor(
        text=[target] * len(batch_frames),
        images=batch_frames,
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    # One (height, width) row per frame; PIL's .size is (width, height).
    # NOTE(review): whether the post-processor expects the original or the
    # padded-square size depends on the installed transformers version - confirm.
    target_sizes = torch.tensor(
        [frame.size[::-1] for frame in batch_frames], dtype=torch.float32
    ).to(device)
    return processor.post_process_object_detection(outputs, target_sizes=target_sizes)


def process_video(video_path, target, progress=gr.Progress()):
    """Sample a video at 1 frame per second and annotate detections of `target`.

    Each sampled frame is run through the OWLv2 detector, detections with a
    score of at least 0.5 are drawn on the frame, and the annotated frame is
    saved as a PNG in a fresh temporary directory (removed at interpreter exit).

    Args:
        video_path: Path to the video file, or None.
        target: Text query describing the object to detect.
        progress: Gradio progress tracker used to wrap the sampling loop.

    Returns:
        A tuple ``(frame_paths, frame_scores, error)``. On success
        ``frame_paths`` is the list of saved PNG paths, ``frame_scores`` the
        best confidence per frame (0 when nothing was detected) and ``error``
        is None. On failure the first two items are None and ``error`` is a
        message string.
    """
    import atexit

    if video_path is None:
        return None, None, "Error: No video uploaded"
    if not target or not str(target).strip():
        return None, None, "Error: No target object specified"
    if not os.path.exists(video_path):
        return None, None, f"Error: Video file not found at {video_path}"

    cap = cv2.VideoCapture(video_path)
    frame_paths = []
    frame_scores = []
    try:
        if not cap.isOpened():
            return None, None, f"Error: Unable to open video file at {video_path}"

        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Keep the fractional FPS: truncating 29.97 to 29 skews the sampled
        # positions, and truncating a value below 1 to 0 divides by zero.
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        if not original_fps > 0 or frame_count <= 0:
            return None, None, f"Error: Unable to read frame rate or frame count from {video_path}"

        output_fps = 1
        frame_duration = 1 / output_fps
        video_duration = frame_count / original_fps

        temp_dir = tempfile.mkdtemp()
        # The saved frames must outlive this call (the UI reads them later),
        # so the directory is removed only when the interpreter exits.
        atexit.register(shutil.rmtree, temp_dir, ignore_errors=True)

        batch_size = 1
        batch_frames = []
        batch_indices = []
        font = None  # loaded once; all frames of a video share one size

        def flush_batch():
            """Detect, annotate and save every frame currently in the batch."""
            nonlocal font
            if not batch_frames:
                return
            results = _detect_batch(batch_frames, target)
            for frame_index, frame, result in zip(batch_indices, batch_frames, results):
                if font is None:
                    font = _load_annotation_font(frame.size)
                max_score = _annotate_frame(frame, result, target, font)
                frame_path = os.path.join(temp_dir, f"frame_{frame_index:04d}.png")
                frame.save(frame_path)
                frame_paths.append(frame_path)
                frame_scores.append(max_score)
            batch_frames.clear()
            batch_indices.clear()

        sample_times = np.arange(0, video_duration, frame_duration)
        for i, timestamp in enumerate(progress.tqdm(sample_times)):
            frame_number = int(timestamp * original_fps)
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_number)
            ret, img = cap.read()
            if not ret:
                break

            # Convert to RGB without resizing
            batch_frames.append(Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)))
            batch_indices.append(i)

            if len(batch_frames) >= batch_size:
                flush_batch()

            # Clear GPU cache every 10 frames
            if i % 10 == 0:
                torch.cuda.empty_cache()

        # Process whatever is left when the loop ends or a read fails early.
        flush_batch()
    finally:
        cap.release()

    if not frame_paths:
        return None, None, f"Error: No frames could be read from {video_path}"
    return frame_paths, frame_scores, None
def create_heatmap(frame_scores, current_frame):
    """Render the per-frame scores as a one-row heatmap PNG.

    Args:
        frame_scores: Sequence of per-frame confidence values.
        current_frame: Index at which a dashed blue vertical marker is drawn.

    Returns:
        Path of the PNG written to a new temporary file. The file is not
        deleted by this function.
    """
    # Use an explicit figure instead of pyplot's implicit "current figure" so
    # that the figure is always closed, even if drawing or saving raises.
    fig, ax = plt.subplots(figsize=(16, 4))
    try:
        ax.imshow([frame_scores], cmap='hot_r', aspect='auto')
        ax.set_title('Object Detection Heatmap', fontsize=14)
        ax.set_xlabel('Frame', fontsize=12)
        ax.set_yticks([])

        # Show at most about 20 tick labels regardless of video length.
        num_frames = len(frame_scores)
        step = max(1, num_frames // 20)
        frame_numbers = list(range(0, num_frames, step))
        ax.set_xticks(frame_numbers)
        ax.set_xticklabels([str(n) for n in frame_numbers], rotation=90, ha='right')

        ax.axvline(x=current_frame, color='blue', linestyle='--', linewidth=2)
        fig.tight_layout()

        # Reserve a unique file name, then write after the handle is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as tmp_file:
            heatmap_path = tmp_file.name
        fig.savefig(heatmap_path, format='png', dpi=400, bbox_inches='tight')
    finally:
        plt.close(fig)
    return heatmap_path
def load_sample_frame(video_path, target_frame=87, original_fps=30, processing_fps=1):
    """Read one frame of a video and return it as an RGB array.

    ``target_frame`` is an index on the processing timeline (``processing_fps``
    frames per second); it is converted to an index in the source video using
    ``original_fps``. Returns None when the video cannot be opened or the
    frame cannot be read.
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return None

    # Map the processed-frame index onto the source video's frame numbering.
    source_index = int(target_frame * (original_fps / processing_fps))
    capture.set(cv2.CAP_PROP_POS_FRAMES, source_index)

    ok, bgr_frame = capture.read()
    capture.release()

    # OpenCV decodes to BGR; convert to RGB before returning.
    return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) if ok else None
def update_frame_and_heatmap(frame_index, frame_paths, scores):
    """Return the selected annotated frame and a heatmap marking its position.

    Args:
        frame_index: Index of the frame to show; may arrive as a float or None.
        frame_paths: List of saved frame image paths (may be empty or None).
        scores: Per-frame scores passed to ``create_heatmap``.

    Returns:
        ``(frame_array, heatmap_path)``, or ``(None, None)`` when there are no
        frames or the index is missing or out of range.
    """
    if not frame_paths or frame_index is None:
        return None, None

    # Coerce to int so a float index can still be used to index the list.
    frame_index = int(frame_index)
    if not 0 <= frame_index < len(frame_paths):
        return None, None

    # Load the pixels eagerly and close the file handle right away.
    with Image.open(frame_paths[frame_index]) as frame:
        frame_array = np.array(frame)
    heatmap_path = create_heatmap(scores, frame_index)
    return frame_array, heatmap_path
def gradio_app():
    """Build and return the Gradio Blocks UI for video object detection.

    The UI takes a video and a text target, runs ``process_video`` on upload
    (or on the sample-video button), and lets the user scrub through the
    annotated frames with a slider while a heatmap shows per-frame scores.
    """
    sample_video_path = "Drone Video of African Wildlife Wild Botswan.mp4"

    with gr.Blocks() as app:
        gr.Markdown("# Video Object Detection with Owlv2")
        video_input = gr.Video(label="Upload Video")
        target_input = gr.Textbox(label="Target Object", value="Elephant")
        frame_slider = gr.Slider(minimum=0, maximum=100, step=1, label="Frame", value=0)
        heatmap_output = gr.Image(label="Detection Heatmap")
        output_image = gr.Image(label="Processed Frame")
        error_output = gr.Textbox(label="Error Messages", visible=False)
        sample_video_frame = gr.Image(
            value=load_sample_frame(sample_video_path, target_frame=87),
            label="Drone Video of African Wildlife Wild Botswan by wildimagesonline.com - Sample Video Frame (Frame 87 at 1 FPS)"
        )
        use_sample_button = gr.Button("Use Sample Video")
        frame_paths = gr.State([])
        frame_scores = gr.State([])

        # Gradio only tracks progress when gr.Progress() is declared as the
        # default value of an event handler's parameter.
        def process_and_update(video, target, progress=gr.Progress()):
            """Process a video and return updates for every bound output."""
            paths, scores, error = process_video(video, target, progress)
            if not paths:
                # Covers both an explicit error and an empty frame list, which
                # would otherwise fail on paths[0] below.
                message = error or "Error: No frames were processed"
                return (
                    [],
                    [],
                    None,
                    None,
                    # Reveal the otherwise hidden textbox so the message shows.
                    gr.Textbox(value=message, visible=True),
                    gr.Slider(maximum=100, value=0),
                )

            heatmap_path = create_heatmap(scores, 0)
            with Image.open(paths[0]) as first_frame:
                first_frame_array = np.array(first_frame)
            return (
                paths,
                scores,
                first_frame_array,
                heatmap_path,
                gr.Textbox(value="", visible=False),
                gr.Slider(maximum=len(paths) - 1, value=0),
            )

        result_outputs = [
            frame_paths,
            frame_scores,
            output_image,
            heatmap_output,
            error_output,
            frame_slider,
        ]

        video_input.upload(process_and_update,
                           inputs=[video_input, target_input],
                           outputs=result_outputs)
        frame_slider.change(update_frame_and_heatmap,
                            inputs=[frame_slider, frame_paths, frame_scores],
                            outputs=[output_image, heatmap_output])

        def use_sample_video(progress=gr.Progress()):
            """Run detection on the bundled sample video for "Elephant"."""
            return process_and_update(sample_video_path, "Elephant", progress)

        use_sample_button.click(use_sample_video,
                                inputs=None,
                                outputs=result_outputs)

        # Layout
        # NOTE(review): the bare component names below are no-op expressions;
        # the components were already created above, so these containers look
        # like they stay empty - confirm the intended layout.
        with gr.Row():
            with gr.Column(scale=2):
                output_image
            with gr.Column(scale=1):
                sample_video_frame
                use_sample_button
    return app
# Cleanup temporary files
def cleanup(frame_paths=None, temp_dir=None):
    """Delete saved frame images and, optionally, their temporary directory.

    Args:
        frame_paths: Iterable of file paths to remove, or an object exposing
            such an iterable as ``.value``. None means no files.
        temp_dir: Directory to remove recursively. None means no directory.

    Paths that no longer exist are skipped silently.
    """
    # Accept either a plain list or a holder object with a ``.value`` list.
    paths = getattr(frame_paths, "value", frame_paths) or []
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
    if temp_dir and os.path.exists(temp_dir):
        shutil.rmtree(temp_dir, ignore_errors=True)


# NOTE: cleanup() is not called automatically. Call it with the frame paths
# and temporary directory to delete when the app is closed; how to hook that
# up depends on how the app is run.
if __name__ == "__main__":
    app = gradio_app()
    app.launch()