hilmantm committed
Commit 2d8940f · 1 Parent(s): 3d2dd18

fix: change inference function

Files changed (2):
  1. app.py +43 -33
  2. requirements.txt +6 -1
app.py CHANGED
@@ -2,18 +2,18 @@ from transformers import DetrForObjectDetection, DetrImageProcessor
  import torch
  from PIL import Image
  import matplotlib.pyplot as plt
- import io
  import gradio as gr
- from random import choice
+ import cv2
+ import torch
+ import supervision as sv
+ import numpy as np
 
  DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  CHECKPOINT = 'facebook/detr-resnet-50'
  CHECKPOINT_ACCIDENT_DETECTION = 'hilmantm/detr-traffic-accident-detection'
  CONFIDENCE_TRESHOLD = 0.5
  IOU_TRESHOLD = 0.8
- COLORS = ["#ff7f7f", "#ff7fbf", "#ff7fff", "#bf7fff",
-           "#7f7fff", "#7fbfff", "#7fffff", "#7fffbf",
-           "#7fff7f", "#bfff7f", "#ffff7f", "#ffbf7f"]
+ NMS_TRESHOLD = 0.5
  fdic = {
      "family" : "Impact",
      "style" : "italic",
@@ -26,34 +26,44 @@ image_processor = DetrImageProcessor.from_pretrained(CHECKPOINT)
  model = DetrForObjectDetection.from_pretrained(CHECKPOINT_ACCIDENT_DETECTION)
  model.to(DEVICE)
 
- def get_figure(in_pil_img, in_results):
-     plt.figure(figsize=(16, 16))
-     plt.imshow(in_pil_img)
-     ax = plt.gca()
-
-     for prediction in in_results:
-         selected_color = choice(COLORS)
-
-         x, y = prediction['box']['xmin'], prediction['box']['ymin'],
-         w, h = prediction['box']['xmax'] - prediction['box']['xmin'], prediction['box']['ymax'] - prediction['box']['ymin']
-
-         ax.add_patch(plt.Rectangle((x, y), w, h, fill=False, color=selected_color, linewidth=3))
-         ax.text(x, y, f"{prediction['label']}: {round(prediction['score']*100, 1)}%", fontdict=fdic)
-
-     plt.axis("off")
-
-     return plt.gcf()
-
- def inference_from_image(image):
-     results = model(image)
-     figure = get_figure(image, results)
-
-     buf = io.BytesIO()
-     figure.savefig(buf, bbox_inches='tight')
-     buf.seek(0)
-     output_pil_img = Image.open(buf)
-
-     return output_pil_img
+ # use this function only for DETR Algorithm
+ # def detect_object(model, test_image_path, nms_treshold = 0.5):
+
+ def inference_from_image(pil_image):
+
+     box_annotator = sv.BoxAnnotator()
+     numpy_image = np.array(pil_image)
+     # Convert BGR to RGB if needed (OpenCV uses BGR by default)
+     opencv_image_bgr = cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR)
+     image = cv2.cvtColor(opencv_image_bgr, cv2.COLOR_BGR2RGB)
+
+     # inference
+     with torch.no_grad():
+         # load image and predict
+         inputs = image_processor(images=image, return_tensors='pt').to(DEVICE)
+         outputs = model(**inputs)
+         # post-process
+         target_sizes = torch.tensor([image.shape[:2]]).to(DEVICE)
+         results = image_processor.post_process_object_detection(
+             outputs=outputs,
+             threshold=CONFIDENCE_TRESHOLD,
+             target_sizes=target_sizes
+         )[0]
+
+     if results['scores'].shape[0] != 0 or results['labels'].shape[0] != 0:
+         # annotate
+         detections = sv.Detections.from_transformers(transformers_results=results).with_nms(threshold=NMS_TRESHOLD)
+         labels = [
+             f"{model.config.id2label[class_id]} {confidence:0.2f}"
+             for _, confidence, class_id, _
+             in detections
+         ]
+         frame = box_annotator.annotate(scene=image.copy(), detections=detections, labels=labels)
+         result_image = Image.fromarray(frame)
+         return result_image
+     else:
+         print("No object detected")
+         return None
 
  with gr.Blocks() as demo:
      gr.Markdown(
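
Note: the gr.Blocks body is truncated in this diff, so for context here is a minimal usage sketch of the reworked inference_from_image. The test image path and the gr.Interface wiring below are illustrative assumptions, not part of this commit.

# Usage sketch only: assumes the app.py above has been imported or executed in
# the same session. The image path and the Interface wiring are hypothetical.
from PIL import Image
import gradio as gr

pil_in = Image.open("sample_accident.jpg")   # hypothetical test image
annotated = inference_from_image(pil_in)     # returns an annotated PIL image, or None if nothing passes CONFIDENCE_TRESHOLD
if annotated is not None:
    annotated.save("annotated.jpg")

# One possible image-in / image-out wiring for the demo:
demo = gr.Interface(
    fn=inference_from_image,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
)
# demo.launch()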
requirements.txt CHANGED
@@ -1,2 +1,7 @@
  torch
- transformers[timm]
+ transformers[timm]
+ supervision==0.3.0
+ pytorch-lightning
+ roboflow
+ timm
+ numpy
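
As a quick sanity check that the pinned stack installs and can load the checkpoints referenced in app.py, here is a hedged smoke-test sketch; the blank 640x480 test image is an assumption used only to exercise one forward pass.

# Environment smoke test (sketch): load the checkpoints named in app.py and run
# a single forward pass on a blank image. The blank image is a test-only assumption.
import torch
import numpy as np
from PIL import Image
from transformers import DetrForObjectDetection, DetrImageProcessor

CHECKPOINT = 'facebook/detr-resnet-50'
CHECKPOINT_ACCIDENT_DETECTION = 'hilmantm/detr-traffic-accident-detection'

processor = DetrImageProcessor.from_pretrained(CHECKPOINT)
model = DetrForObjectDetection.from_pretrained(CHECKPOINT_ACCIDENT_DETECTION)
model.eval()

dummy = Image.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
with torch.no_grad():
    inputs = processor(images=dummy, return_tensors='pt')
    outputs = model(**inputs)

results = processor.post_process_object_detection(
    outputs=outputs,
    threshold=0.5,
    target_sizes=torch.tensor([dummy.size[::-1]]),  # (height, width)
)[0]
print(f"detections above 0.5 confidence: {results['scores'].shape[0]}")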