# Hugging Face Spaces status shown when this file was captured: "Runtime error".
import os

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import DetrImageProcessor, DetrForObjectDetection
# Load the fine-tuned DETR checkpoint and its matching image processor once, at
# import time, so every request reuses the same objects.
# NOTE(review): if this checkpoint was fine-tuned from a timm-backbone DETR
# (e.g. facebook/detr-resnet-50), loading it also requires the `timm` package —
# confirm it is listed in the Space's requirements.
MODEL_ID = "hilmantm/detr-traffic-accident-detection"
model = DetrForObjectDetection.from_pretrained(MODEL_ID)
processor = DetrImageProcessor.from_pretrained(MODEL_ID)
def detect_accident(image, *, threshold: float = 0.9):
    """Run DETR accident detection on ``image`` and return an annotated copy.

    Args:
        image: A PIL image (as produced by ``gr.Image(type="pil")``), or
            ``None``, e.g. when the form is submitted without an upload.
        threshold: Minimum detection score for a box to be kept; forwarded to
            ``processor.post_process_object_detection``. Defaults to 0.9, the
            value this function previously hard-coded.

    Returns:
        A new RGB PIL image with a red rectangle and a "label: score" caption
        drawn for each kept detection, or ``None`` if ``image`` is ``None``.
        The input image itself is never modified.
    """
    if image is None:
        return None

    # convert() always returns a new image, so the drawing below never mutates
    # the caller's object; it also guarantees a 3-channel input for the model.
    annotated = image.convert("RGB")

    inputs = processor(images=annotated, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's .size is (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([annotated.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    draw = ImageDraw.Draw(annotated)
    id2label = model.config.id2label
    for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
        # Hand PIL plain Python numbers rather than 0-d tensors.
        x_min, y_min, x_max, y_max = box.tolist()
        label_id = label.item()
        # Fall back to the numeric id if the config has no name for it.
        label_name = id2label.get(label_id, str(label_id))
        draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
        draw.text((x_min, y_min), f"{label_name}: {score.item():.2f}", fill="red")
    return annotated
# Gradio callback: forwards each uploaded image to detect_accident.
def process_image(image):
    """Gradio callback: pass the uploaded PIL image to ``detect_accident``.

    Returns exactly what ``detect_accident`` returns (the annotated image).
    """
    return detect_accident(image)
# Build the Gradio UI: one PIL image in, one annotated PIL image out.
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)

# Start the server only when run as a script, so importing this module does not
# block on launch().
if __name__ == "__main__":
    # Honour Gradio's standard GRADIO_SERVER_NAME / GRADIO_SERVER_PORT
    # environment variables (an explicit server_port would otherwise override
    # them); fall back to the previously hard-coded 0.0.0.0:8000.
    # NOTE(review): Hugging Face Gradio Spaces normally serve on port 7860 —
    # confirm 8000 is the intended fallback for this deployment.
    interface.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "8000")),
    )
# from fastapi import FastAPI, File, UploadFile | |
# from fastapi.responses import StreamingResponse, JSONResponse | |
# from fastapi.middleware.cors import CORSMiddleware | |
# from transformers import DetrImageProcessor, DetrForObjectDetection | |
# from PIL import Image, ImageDraw | |
# import io | |
# import torch | |
# # Initialize FastAPI app | |
# app = FastAPI() | |
# # Add CORS middleware to allow communication with external clients | |
# app.add_middleware( | |
# CORSMiddleware, | |
# allow_origins=["*"], # Change this to specific domains in production | |
# allow_methods=["*"], | |
# allow_headers=["*"], | |
# ) | |
# # Load the model and processor | |
# model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection") | |
# processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection") | |
# def detect_accident(image): | |
# """Runs accident detection on the input image.""" | |
# inputs = processor(images=image, return_tensors="pt") | |
# outputs = model(**inputs) | |
# # Post-process results | |
# target_sizes = torch.tensor([image.size[::-1]]) | |
# results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0] | |
# # Draw bounding boxes and labels | |
# draw = ImageDraw.Draw(image) | |
# for box, label, score in zip(results["boxes"], results["labels"], results["scores"]): | |
# x_min, y_min, x_max, y_max = box | |
# draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3) | |
# label_name = model.config.id2label[label.item()] | |
# draw.text((x_min, y_min), f"{label_name}: {score:.2f}", fill="red") | |
# return image | |
# @app.post("/detect_accident") | |
# async def process_frame(file: UploadFile = File(...)): | |
# """API endpoint to process an uploaded frame.""" | |
# try: | |
# # Read and preprocess image | |
# image = Image.open(io.BytesIO(await file.read())) | |
# image = image.convert("RGB") # Ensure compatibility with the model | |
# # Detect accidents | |
# processed_image = detect_accident(image) | |
# # Save the processed image into bytes to send back | |
# img_byte_arr = io.BytesIO() | |
# processed_image.save(img_byte_arr, format="JPEG") | |
# img_byte_arr.seek(0) | |
# # Return the image as a streaming response | |
# return StreamingResponse(img_byte_arr, media_type="image/jpeg") | |
# except Exception as e: | |
# return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500) | |
# # Run the app | |
# if __name__ == "__main__": | |
# import uvicorn | |
# uvicorn.run(app, host="0.0.0.0", port=8000) | |
# import gradio as gr | |
# from transformers import DetrImageProcessor, DetrForObjectDetection | |
# from PIL import Image, ImageDraw | |
# import torch | |
# import cv2 | |
# import numpy as np | |
# # Load model and processor | |
# model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection") | |
# processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection") | |
# # Function to detect accidents in an image | |
# def detect_accident(image): | |
# inputs = processor(images=image, return_tensors="pt") | |
# outputs = model(**inputs) | |
# # Post-process the results | |
# target_sizes = torch.tensor([image.size[::-1]]) | |
# results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0] | |
# # Draw boxes and labels on the image | |
# draw = ImageDraw.Draw(image) | |
# for box, label, score in zip(results["boxes"], results["labels"], results["scores"]): | |
# x_min, y_min, x_max, y_max = box | |
# draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3) | |
# draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red") | |
# return image | |
# # Function to detect accidents frame-by-frame in a video | |
# def detect_accident_in_video(video_path): | |
# cap = cv2.VideoCapture(video_path) | |
# frames = [] | |
# while True: | |
# ret, frame = cap.read() | |
# if not ret: | |
# break | |
# # Convert frame to PIL Image | |
# frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) | |
# pil_frame = Image.fromarray(frame_rgb) | |
# # Run accident detection on the frame | |
# processed_frame = detect_accident(pil_frame) | |
# # Convert PIL image back to numpy array for video | |
# frames.append(np.array(processed_frame)) | |
# cap.release() | |
# # Save processed frames as output video | |
# height, width, _ = frames[0].shape | |
# out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, (width, height)) | |
# for frame in frames: | |
# out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) | |
# out.release() | |
# return "output.mp4" | |
# # Gradio app interface | |
# with gr.Blocks() as interface: | |
# gr.Markdown("# Traffic Accident Detection") | |
# gr.Markdown( | |
# "Upload an image or video to detect traffic accidents using the DETR model. " | |
# "For videos, the system processes frame by frame and outputs a new video with accident detection." | |
# ) | |
# # Input components | |
# with gr.Tab("Image Input"): | |
# image_input = gr.Image(type="pil", label="Upload Image") | |
# image_output = gr.Image(type="pil", label="Detection Output") | |
# image_button = gr.Button("Detect Accidents in Image") | |
# with gr.Tab("Video Input"): | |
# video_input = gr.Video(label="Upload Video") | |
# video_output = gr.Video(label="Processed Video") | |
# video_button = gr.Button("Detect Accidents in Video") | |
# # Define behaviors | |
# image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output) | |
# video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output) | |
# interface.launch() |