boostedhug commited on
Commit
d664d74
·
verified ·
1 Parent(s): cdb4ebf
Files changed (1) hide show
  1. app.py +68 -0
app.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from transformers import DetrImageProcessor, DetrForObjectDetection
5
+ from PIL import Image, ImageDraw
6
+ import io
7
+ import torch
8
+
9
+ # Initialize FastAPI app
10
+ app = FastAPI()
11
+
12
+ # Add CORS middleware to allow communication with external clients
13
+ app.add_middleware(
14
+ CORSMiddleware,
15
+ allow_origins=["*"], # Change this to the specific domain in production
16
+ allow_methods=["*"],
17
+ allow_headers=["*"],
18
+ )
19
+
20
+ # Load the model and processor
21
+ model = DetrForObjectDetection.from_pretrained("hilmantm/detr-traffic-accident-detection")
22
+ processor = DetrImageProcessor.from_pretrained("hilmantm/detr-traffic-accident-detection")
23
+
24
+ def detect_accident(image):
25
+ """Runs accident detection on the input image."""
26
+ inputs = processor(images=image, return_tensors="pt")
27
+ outputs = model(**inputs)
28
+
29
+ # Post-process results
30
+ target_sizes = torch.tensor([image.size[::-1]])
31
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)[0]
32
+
33
+ # Draw bounding boxes and labels
34
+ draw = ImageDraw.Draw(image)
35
+ for box, label, score in zip(results["boxes"], results["labels"], results["scores"]):
36
+ x_min, y_min, x_max, y_max = box
37
+ draw.rectangle((x_min, y_min, x_max, y_max), outline="red", width=3)
38
+ draw.text((x_min, y_min), f"{label}: {score:.2f}", fill="red")
39
+
40
+ return image
41
+
42
+ @app.post("/detect_accident")
43
+ async def process_frame(file: UploadFile = File(...)):
44
+ """API endpoint to process an uploaded frame."""
45
+ try:
46
+ # Read and preprocess image
47
+ image = Image.open(io.BytesIO(await file.read()))
48
+ image = image.resize((256, int(image.height * 256 / image.width))) # Resize while maintaining aspect ratio
49
+
50
+ # Detect accidents
51
+ processed_image = detect_accident(image)
52
+
53
+ # Save the processed image into bytes to send back
54
+ img_byte_arr = io.BytesIO()
55
+ processed_image.save(img_byte_arr, format="JPEG")
56
+ img_byte_arr.seek(0)
57
+
58
+ return JSONResponse(
59
+ content={"status": "success", "message": "Frame processed successfully"},
60
+ media_type="image/jpeg"
61
+ )
62
+ except Exception as e:
63
+ return JSONResponse(content={"status": "error", "message": str(e)}, status_code=500)
64
+
65
+ # Run the app
66
+ if __name__ == "__main__":
67
+ import uvicorn
68
+ uvicorn.run(app, host="0.0.0.0", port=8000)