blitzkrieg0000 commited on
Commit
3ddfba1
·
verified ·
1 Parent(s): d4b45ad

Update Lib/Core.py

Browse files
Files changed (1) hide show
  1. Lib/Core.py +120 -52
Lib/Core.py CHANGED
@@ -1,87 +1,155 @@
 
 
 
 
1
  import cv2
2
  import numpy as np
3
  import torch
4
  from matplotlib import pyplot as plt
5
  from ultralytics import YOLO
6
- from ultralytics.engine.results import Masks
 
 
 
 
7
 
8
  class CablePoleSegmentation():
9
- def __init__(self, MODEL_PATH=None, retina_mask=False):
10
- if not MODEL_PATH:
11
- MODEL_PATH = "./weight/yolov9c-cable-seg.pt"
12
  self._RetinaMask=retina_mask
13
- self.Model = YOLO(MODEL_PATH) # load a custom model
 
14
 
15
 
16
- def RescaleTheMask(self, orijinal_image, masks):
17
- _masks = []
18
- for contour in masks:
19
- b_mask = np.zeros(orijinal_image.shape[:2], np.uint8)
20
- contour = contour.astype(np.int32)
21
- contour = contour.reshape(-1, 1, 2)
22
- mask = cv2.drawContours(b_mask, [contour], -1, (1, 1, 1), cv2.FILLED)
23
- _masks += [mask]
24
- return _masks
25
 
26
 
27
- def Process(self, image):
28
- with torch.no_grad():
29
- results = self.Model(
30
- image,
31
- save=False,
32
- show_boxes=False,
33
- project="./result/",
34
- conf=0.5,
35
- retina_masks=self._RetinaMask,
36
- stream=True
37
- )
 
 
 
38
 
39
  with torch.no_grad():
40
  for result in results:
41
- maskCountours = result.masks.xy
42
- boxes = result.boxes.xyxy.int().cpu().numpy()
43
- classes = result.boxes.cls.cpu().numpy()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- rescaledMasks = self.RescaleTheMask(image, maskCountours)
46
- return rescaledMasks, boxes, classes, result.plot()
47
 
48
 
 
 
 
 
49
 
50
- def PlotResults(self, masks, boxes, classes, original_image, result_image, mask, cable_mask):
51
- fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(27,15))
52
- axs[0][0].imshow(original_image)
53
- axs[0][0].set_title("Orijinal Görüntü")
54
 
55
- axs[0][1].imshow(mask)
56
- axs[0][1].set_title("Segmentasyon Maskesi")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
 
 
58
 
59
- cv2.imwrite("cable_mask.png", cable_mask)
60
- axs[1][0].imshow(cable_mask)
61
- axs[1][0].set_title("Seçilen")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- axs[1][1].imshow(result_image)
64
- axs[1][1].set_title("Sonuç")
65
- plt.show()
66
 
67
 
68
 
69
  if "__main__" == __name__:
70
- test = "data/16_3450.png"
71
  image = cv2.imread(test)
72
- model = CablePoleSegmentation(retina_mask=True)
73
- masks, boxes, classes, result_plot = model.Process(image)
74
 
 
 
75
 
76
- fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(27,15))
77
- axs[0][0].imshow(image)
78
- axs[0][0].set_title("Orijinal Görüntü")
79
- axs[1][1].imshow(np.any(masks, axis=0))
80
- axs[1][1].set_title("Sonuç")
81
- plt.show()
82
 
83
 
84
- # model.PlotResults(*model.Process(image))
 
 
 
85
 
 
 
 
 
 
 
 
 
86
 
87
 
 
1
+ import os
2
+ import sys
3
+ sys.path.append(os.getcwd())
4
+
5
  import cv2
6
  import numpy as np
7
  import torch
8
  from matplotlib import pyplot as plt
9
  from ultralytics import YOLO
10
+ from Lib.Consts import LABELS, COLOR_MAP, COLOR_MAP_RGB
11
+
12
+
13
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
+
15
 
16
  class CablePoleSegmentation():
17
+ def __init__(self, model_path=None, retina_mask=False):
18
+ if not model_path:
19
+ model_path = "./weight/yolov9c-cable-seg.pt"
20
  self._RetinaMask=retina_mask
21
+ self.Model = None
22
+ self.PrepareModel(model_path)
23
 
24
 
25
+ def PrepareModel(self, model_path):
26
+ self.Model = YOLO(model_path)
27
+ self.Model.fuse()
 
 
 
 
 
 
28
 
29
 
30
+ def ScaleMasks(self, masks: torch.Tensor, shape: tuple) -> torch.Tensor:
31
+ masks = masks.unsqueeze(0)
32
+ interpolatedMask:torch.Tensor = torch.nn.functional.interpolate(masks, shape, mode="nearest")
33
+ interpolatedMask = interpolatedMask.squeeze(0)
34
+ return interpolatedMask
35
+
36
+
37
+ def ParseResults(self, results, threshold=0.5, scale_masks=True):
38
+ batches = []
39
+
40
+ SCORES = torch.Tensor([]).to(DEVICE)
41
+ CLASSES = torch.Tensor([]).to(DEVICE)
42
+ MASKS = torch.Tensor([]).to(DEVICE)
43
+ BOXES = torch.Tensor([]).to(DEVICE)
44
 
45
  with torch.no_grad():
46
  for result in results:
47
+ original_shape = result.orig_shape
48
+ _scores = result.boxes.conf # 7
49
+ _classes = result.boxes.cls # 7
50
+ _masks = result.masks.data # 7, 480, 640
51
+ _boxes = result.boxes.xyxy # 7, 4
52
+
53
+ # Threshold Filter
54
+ conditions = _scores > threshold
55
+ SCORES = torch.cat((SCORES, _scores[conditions]), dim=0)
56
+ CLASSES = torch.cat((CLASSES, _classes[conditions]), dim=0)
57
+ BOXES = torch.cat((BOXES, _boxes[conditions]), dim=0)
58
+ mask = _masks[conditions]
59
+
60
+ if mask.shape[0] == 0:
61
+ continue
62
+
63
+ if scale_masks:
64
+ mask = self.ScaleMasks(mask, original_shape[:2])
65
+
66
+ MASKS = torch.cat((MASKS, mask), dim=0)
67
+
68
+ batches += [(SCORES, CLASSES, MASKS, BOXES)]
69
 
70
+ return batches
 
71
 
72
 
73
+ def DrawResults(self, image, scores: torch.Tensor, classes: torch.Tensor, masks: torch.Tensor, boxes: torch.Tensor, labels:dict=LABELS, class_filter:list=None):
74
+ _image = np.array(image).copy()
75
+ _image = cv2.cvtColor(_image, cv2.COLOR_BGR2RGB)
76
+ maskCanvas = np.zeros_like(_image)
77
 
 
 
 
 
78
 
79
+ with torch.no_grad():
80
+ scores = scores.cpu().numpy()
81
+ classes = classes.cpu().numpy().astype(np.int32)
82
+ masks = masks.cpu().numpy()
83
+ boxes = boxes.cpu().numpy()
84
+ colors = list(COLOR_MAP_RGB.values())
85
+
86
+ for score, cls, mask, box in zip(scores, classes, masks, boxes):
87
+ label = labels[cls]
88
+ _color = colors[cls]
89
+
90
+ if class_filter and cls not in class_filter:
91
+ continue
92
+
93
+ box = box.astype(np.int32)
94
+ mask = (cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)*_color).astype(np.uint8)
95
+ maskCanvas = cv2.addWeighted(maskCanvas, 1.0, mask, 1.0, 0)
96
+ maskCanvas = cv2.rectangle(maskCanvas, (box[0], box[1]), (box[2], box[3]), color=_color, thickness=5) # Red color for bounding box
97
+ maskCanvas = cv2.putText(maskCanvas, f"{label} : {score:.2f}", (box[0], box[1] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color=_color, thickness=2)
98
+
99
+ canvas = cv2.addWeighted(_image, 1.0, maskCanvas.astype(np.uint8), 0.5, 0)
100
+ return canvas, maskCanvas
101
 
102
+
103
+ def Process(self, image, model_threshold=0.6, overall_threshold=0.6, iou=0.7, class_filter:list=None):
104
 
105
+ with torch.no_grad():
106
+ results = self.Model(
107
+ image,
108
+ save=False,
109
+ show_boxes=False,
110
+ project="./inference/",
111
+ conf=model_threshold,
112
+ iou=iou,
113
+ retina_masks=False,
114
+ stream=True,
115
+ classes=class_filter,
116
+ device=DEVICE
117
+ )
118
+
119
+ batches = self.ParseResults(results, threshold=overall_threshold, scale_masks=True)
120
+
121
+ return batches
122
+
123
 
 
 
 
124
 
125
 
126
 
127
  if "__main__" == __name__:
128
+ test = "data/DJI_20240905091530_0003_W.JPG"
129
  image = cv2.imread(test)
130
+ model = CablePoleSegmentation(retina_mask=False)
131
+ batches = model.Process(image)
132
 
133
+ if len(batches) == 0:
134
+ exit()
135
 
136
+ scores, classes, masks, boxes = batches[0] # First
137
+ canvas, mask = model.DrawResults(image, scores, classes, masks, boxes, class_filter=None)
138
+ print(canvas.shape)
 
 
 
139
 
140
 
141
+ #! Plot
142
+ fig, axs = plt.subplots(1, 3, figsize=(27, 15))
143
+ axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
144
+ axs[0].set_title("Orijinal Görüntü")
145
 
146
+ axs[1].imshow(mask)
147
+ axs[1].set_title("Segmentasyon Maskesi")
148
+
149
+ axs[2].imshow(canvas)
150
+ axs[2].set_title("Sonuç")
151
+
152
+ plt.tight_layout()
153
+ plt.show()
154
 
155