"""Gradio app that runs DETR traffic-accident detection on an uploaded image.

The model and processor are loaded once at import time; the Gradio interface
is then launched on all network interfaces, port 8000.
"""

import gradio as gr
import torch
from PIL import Image, ImageDraw
from transformers import DetrForObjectDetection, DetrImageProcessor

# Hugging Face checkpoint shared by the model and its image processor.
MODEL_NAME = "hilmantm/detr-traffic-accident-detection"

# Load the model and processor
model = DetrForObjectDetection.from_pretrained(MODEL_NAME)
processor = DetrImageProcessor.from_pretrained(MODEL_NAME)


def detect_accident(image, threshold=0.9):
    """Run accident detection on an image and draw the detections on a copy.

    Args:
        image: A PIL image, or ``None`` (Gradio passes ``None`` when the user
            submits without uploading anything).
        threshold: Minimum confidence score for a detection to be kept and
            drawn. Defaults to 0.9, the value previously hard-coded here.

    Returns:
        A new RGB PIL image with a red rectangle and a ``"label: score"``
        caption for every kept detection, or ``None`` when ``image`` is
        ``None``. The input image itself is not modified.
    """
    if image is None:
        return None

    # convert() always returns a new image, so this both normalises the mode
    # (RGBA / L / P uploads) and keeps the drawing off the caller's object.
    annotated = image.convert("RGB")

    inputs = processor(images=annotated, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process results. PIL's size is (width, height); the processor
    # expects (height, width), hence the reversal.
    target_sizes = torch.tensor([annotated.size[::-1]])
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=threshold
    )[0]

    # Draw bounding boxes and labels
    draw = ImageDraw.Draw(annotated)
    for box, label, score in zip(
        results["boxes"], results["labels"], results["scores"]
    ):
        # Hand PIL plain Python floats rather than 0-d tensors.
        x_min, y_min, x_max, y_max = box.tolist()
        draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)

        label_id = label.item()
        # Fall back to the numeric id if the config has no name for it.
        label_name = model.config.id2label.get(label_id, str(label_id))
        draw.text(
            (x_min, y_min), f"{label_name}: {score.item():.2f}", fill="red"
        )

    return annotated


# Define the Gradio interface
def process_image(image):
    """Gradio callback: return the input image annotated with detections."""
    processed_image = detect_accident(image)
    return processed_image


# Launch the Gradio app
interface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)
# NOTE: "0.0.0.0" makes the server listen on every network interface.
interface.launch(server_name="0.0.0.0", server_port=8000)


# ---------------------------------------------------------------------------
# Alternative variants of this app, kept for reference.
# Everything below this line is commented out and does not run.
# ---------------------------------------------------------------------------

# --- Variant: FastAPI endpoint ----------------------------------------------

# from fastapi import FastAPI, File, UploadFile
# from fastapi.responses import StreamingResponse, JSONResponse
# from fastapi.middleware.cors import CORSMiddleware
# from transformers import DetrImageProcessor, DetrForObjectDetection
# from PIL import Image, ImageDraw
# import io
# import torch

# # Initialize FastAPI app
# app = FastAPI()

# # Add CORS middleware to allow communication with external clients
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["*"],  # Change this to specific domains in production
#     allow_methods=["*"],
#     allow_headers=["*"],
# )

# # Load the model and processor
# model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
# processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")

# def detect_accident(image):
#     """Runs accident detection on the input image."""
#     inputs = processor(images=image, return_tensors="pt")
#     outputs = model(**inputs)

#     # Post-process results
#     target_sizes = torch.tensor([image.size[::-1]])
#     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

#     # Draw bounding boxes and labels
#     draw = ImageDraw.Draw(image)
#     for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
#         x_min, y_min, x_max, y_max = box
#         draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
#         label_name = model.config.id2label[label.item()]
#         draw.text((x_min, y_min), f"{label_name}: {score:.2f}", fill="red")

#     return image

# @app.post("/detect_accident")
# async def process_frame(file: UploadFile = File(...)):
#     """API endpoint to process an uploaded frame."""
#     try:
#         # Read and preprocess image
#         image = Image.open(io.BytesIO(await file.read()))
#         image = image.convert("RGB")  # Ensure compatibility with the model

#         # Detect accidents
#         processed_image = detect_accident(image)

#         # Save the processed image into bytes to send back
#         img_byte_arr = io.BytesIO()
#         processed_image.save(img_byte_arr, format="JPEG")
#         img_byte_arr.seek(0)

#         # Return the image as a streaming response
#         return StreamingResponse(img_byte_arr, media_type="image/jpeg")
#     except Exception as e:
#         return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)

# # Run the app
# if __name__ == "__main__":
#     import uvicorn
#     uvicorn.run(app, host="0.0.0.0", port=8000)

# --- Variant: Gradio Blocks app with image and video tabs --------------------

# import gradio as gr
# from transformers import DetrImageProcessor, DetrForObjectDetection
# from PIL import Image, ImageDraw
# import torch
# import cv2
# import numpy as np

# # Load model and processor
# model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
# processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")

# # Function to detect accidents in an image
# def detect_accident(image):
#     inputs = processor(images=image, return_tensors="pt")
#     outputs = model(**inputs)

#     # Post-process the results
#     target_sizes = torch.tensor([image.size[::-1]])
#     results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]

#     # Draw boxes and labels on the image
#     draw = ImageDraw.Draw(image)
#     for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
#         x_min, y_min, x_max, y_max = box
#         draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
#         draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")

#     return image

# # Function to detect accidents frame-by-frame in a video
# def detect_accident_in_video(video_path):
#     cap = cv2.VideoCapture(video_path)
#     frames = []

#     while True:
#         ret, frame = cap.read()
#         if not ret:
#             break

#         # Convert frame to PIL Image
#         frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
#         pil_frame = Image.fromarray(frame_rgb)

#         # Run accident detection on the frame
#         processed_frame = detect_accident(pil_frame)

#         # Convert PIL image back to numpy array for video
#         frames.append(np.array(processed_frame))

#     cap.release()

#     # Save processed frames as output video
#     height, width, _ = frames[0].shape
#     out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*"mp4v"), 10, (width, height))

#     for frame in frames:
#         out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))

#     out.release()
#     return "output.mp4"

# # Gradio app interface
# with gr.Blocks() as interface:
#     gr.Markdown("# Traffic Accident Detection")
#     gr.Markdown(
#         "Upload an image or video to detect traffic accidents using the DETR model. "
#         "For videos, the system processes frame by frame and outputs a new video with accident detection."
#     )

#     # Input components
#     with gr.Tab("Image Input"):
#         image_input = gr.Image(type="pil", label="Upload Image")
#         image_output = gr.Image(type="pil", label="Detection Output")
#         image_button = gr.Button("Detect Accidents in Image")

#     with gr.Tab("Video Input"):
#         video_input = gr.Video(label="Upload Video")
#         video_output = gr.Video(label="Processed Video")
#         video_button = gr.Button("Detect Accidents in Video")

#     # Define behaviors
#     image_button.click(fn=detect_accident, inputs=image_input, outputs=image_output)
#     video_button.click(fn=detect_accident_in_video, inputs=video_input, outputs=video_output)

# interface.launch()