import gradio as gr
from transformers import DetrImageProcessor, DetrForObjectDetection
from PIL import Image, ImageDraw
import torch
# Load the model and processor
model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
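# Optional GPU inference -- a minimal sketch (an assumption, not part of the original
# app); the processor tensors below would also need to be moved with .to(device):
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)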
def detect_accident(image):
    """Runs accident detection on the input image and draws the results."""
    image = image.convert("RGB")  # DETR expects 3-channel RGB input
    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    # Post-process results; image.size is (width, height), the model expects (height, width)
    target_sizes = torch.tensor([image.size[::-1]])
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
    # Draw bounding boxes and labels
    draw = ImageDraw.Draw(image)
    for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
        x_min, y_min, x_max, y_max = box.tolist()  # plain floats for PIL's drawing API
        draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
        label_name = model.config.id2label[label.item()]
        draw.text((x_min, y_min), f"{label_name}: {score:.2f}", fill="red")
    return image
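# A quick sanity check for detect_accident outside Gradio -- a sketch assuming a
# hypothetical local test file "example.jpg":
# test_image = Image.open("example.jpg")
# detect_accident(test_image).save("example_annotated.jpg")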
# Define the Gradio interface
def process_image(image):
    processed_image = detect_accident(image)
    return processed_image
# Launch the Gradio app
interface = gr.Interface(fn=process_image, inputs=gr.Image(type="pil"), outputs=gr.Image(type="pil"))
interface.launch(server_name="0.0.0.0", server_port=8000)
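# Once the app is running it can also be called programmatically -- a sketch assuming
# a recent gradio_client release (the exact API differs across versions) and a
# hypothetical local file "example.jpg":
# from gradio_client import Client, handle_file
# client = Client("http://localhost:8000")
# result_path = client.predict(handle_file("example.jpg"), api_name="/predict")
# print(result_path)  # path to the annotated image written by the app

# --- Alternative 1 (kept commented out for reference): FastAPI endpoint version ---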
# from fastapi import FastAPI, File, UploadFile
# from fastapi.responses import StreamingResponse, JSONResponse
# from fastapi.middleware.cors import CORSMiddleware
# from transformers import DetrImageProcessor, DetrForObjectDetection
# from PIL import Image, ImageDraw
# import io
# import torch
# # Initialize FastAPI app
# app = FastAPI()
# # Add CORS middleware to allow communication with external clients
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],  # Change this to specific domains in production
#     allow_methods=["*"],
#     allow_headers=["*"],
# )
# # Load the model and processor
# model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
# processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
# def detect_accident(image):
#     """Runs accident detection on the input image."""
#     inputs = processor(images=image, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     # Post-process results
#     target_sizes = torch.tensor([image.size[::-1]])
#     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
#     # Draw bounding boxes and labels
#     draw = ImageDraw.Draw(image)
#     for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
#         x_min, y_min, x_max, y_max = box.tolist()
#         draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
#         label_name = model.config.id2label[label.item()]
#         draw.text((x_min, y_min), f"{label_name}: {score:.2f}", fill="red")
#     return image
# @app.post("/detect_accident")
# async def process_frame(file: UploadFile = File(...)):
#     """API endpoint to process an uploaded frame."""
#     try:
#         # Read and preprocess image
#         image = Image.open(io.BytesIO(await file.read()))
#         image = image.convert("RGB")  # Ensure compatibility with the model
#         # Detect accidents
#         processed_image = detect_accident(image)
#         # Save the processed image into bytes to send back
#         img_byte_arr = io.BytesIO()
#         processed_image.save(img_byte_arr, format="JPEG")
#         img_byte_arr.seek(0)
#         # Return the image as a streaming response
#         return StreamingResponse(img_byte_arr, media_type="image/jpeg")
#     except Exception as e:
#         return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
# # Run the app
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)
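# # A minimal client sketch for the endpoint above (run from a separate process),
# # using the standard requests library and a hypothetical local file "frame.jpg":
# # import requests
# # with open("frame.jpg", "rb") as f:
# #     resp = requests.post("http://localhost:8000/detect_accident", files={"file": f})
# # with open("frame_annotated.jpg", "wb") as out_f:
# #     out_f.write(resp.content)

# --- Alternative 2 (kept commented out for reference): Gradio Blocks version with video support ---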
# import gradio as gr
# from transformers import DetrImageProcessor, DetrForObjectDetection
# from PIL import Image, ImageDraw
# import torch
# import cv2
# import numpy as np
# # Load model and processor
# model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
# processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
# # Function to detect accidents in an image
# def detect_accident(image):
#     inputs = processor(images=image, return_tensors="pt")
#     with torch.no_grad():
#         outputs = model(**inputs)
#     # Post-process the results
#     target_sizes = torch.tensor([image.size[::-1]])
#     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
#     # Draw boxes and labels on the image
#     draw = ImageDraw.Draw(image)
#     for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
#         x_min, y_min, x_max, y_max = box.tolist()
#         draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
#         label_name = model.config.id2label[label.item()]  # readable class name instead of the raw tensor id
#         draw.text((x_min, y_min), f"{label_name}: {score:.2f}", fill="red")
#     return image
# # Function to detect accidents frame-by-frame in a video
# def detect_accident_in_video(video_path):
#     cap = cv2.VideoCapture(video_path)
#     fps = cap.get(cv2.CAP_PROP_FPS) or 10  # fall back to 10 fps if the source rate is unknown
#     frames = []
#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break
#         # Convert frame to PIL Image
#         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#         pil_frame = Image.fromarray(frame_rgb)
#         # Run accident detection on the frame
#         processed_frame = detect_accident(pil_frame)
#         # Convert PIL image back to numpy array for video
#         frames.append(np.array(processed_frame))
#     cap.release()
#     if not frames:
#         raise gr.Error("Could not read any frames from the uploaded video.")
#     # Save processed frames as output video at the source frame rate
#     height, width, _ = frames[0].shape
#     out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
#     for frame in frames:
#         out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
#     out.release()
#     return "output.mp4"
# # Gradio app interface
# with gr.Blocks() as interface:
#     gr.Markdown("# Traffic Accident Detection")
#     gr.Markdown(
#         "Upload an image or video to detect traffic accidents using the DETR model. "
#         "For videos, the system processes frame by frame and outputs a new video with accident detection."
#     )
#     # Input components
#     with gr.Tab("Image Input"):
#         image_input = gr.Image(type="pil", label="Upload Image")
#         image_output = gr.Image(type="pil", label="Detection Output")
#         image_button = gr.Button("Detect Accidents in Image")
#     with gr.Tab("Video Input"):
#         video_input = gr.Video(label="Upload Video")
#         video_output = gr.Video(label="Processed Video")
#         video_button = gr.Button("Detect Accidents in Video")
#     # Define behaviors
#     image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output)
#     video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output)
# interface.launch()
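# # Note: frame-by-frame video inference can run past default request timeouts.
# # A one-line sketch (assuming a Gradio version that supports Blocks.queue):
# # interface.queue().launch()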