princeml commited on
Commit
d939c54
·
1 Parent(s): 3d84a0d

Create models.py

Browse files
Files changed (1) hide show
  1. models.py +530 -0
models.py ADDED
@@ -0,0 +1,530 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import os
4
+ import json
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+ import matplotlib.pyplot as plt
8
+ import tensorflow as tf
9
+ from tensorflow.keras import layers, models, optimizers
10
+
11
+ from custom_layers import yolov4_neck, yolov4_head, nms
12
+ from utils import load_weights, get_detection_data, draw_bbox, voc_ap, draw_plot_func, read_txt_to_list
13
+ from config import yolo_config
14
+ from loss import yolo_loss
15
+
16
+
17
+ class Yolov4(object):
18
+ def __init__(self,
19
+ weight_path=None,
20
+ class_name_path='coco_classes.txt',
21
+ config=yolo_config,
22
+ ):
23
+ assert config['img_size'][0] == config['img_size'][1], 'not support yet'
24
+ assert config['img_size'][0] % config['strides'][-1] == 0, 'must be a multiple of last stride'
25
+ self.class_names = [line.strip() for line in open(class_name_path).readlines()]
26
+ self.img_size = yolo_config['img_size']
27
+ self.num_classes = len(self.class_names)
28
+ self.weight_path = weight_path
29
+ self.anchors = np.array(yolo_config['anchors']).reshape((3, 3, 2))
30
+ self.xyscale = yolo_config['xyscale']
31
+ self.strides = yolo_config['strides']
32
+ self.output_sizes = [self.img_size[0] // s for s in self.strides]
33
+ self.class_color = {name: list(np.random.random(size=3)*255) for name in self.class_names}
34
+ # Training
35
+ self.max_boxes = yolo_config['max_boxes']
36
+ self.iou_loss_thresh = yolo_config['iou_loss_thresh']
37
+ self.config = yolo_config
38
+ assert self.num_classes > 0, 'no classes detected!'
39
+
40
+ tf.keras.backend.clear_session()
41
+ if yolo_config['num_gpu'] > 1:
42
+ mirrored_strategy = tf.distribute.MirroredStrategy()
43
+ with mirrored_strategy.scope():
44
+ self.build_model(load_pretrained=True if self.weight_path else False)
45
+ else:
46
+ self.build_model(load_pretrained=True if self.weight_path else False)
47
+
48
+ def build_model(self, load_pretrained=True):
49
+ # core yolo model
50
+ input_layer = layers.Input(self.img_size)
51
+ yolov4_output = yolov4_neck(input_layer, self.num_classes)
52
+ self.yolo_model = models.Model(input_layer, yolov4_output)
53
+
54
+ # Build training model
55
+ y_true = [
56
+ layers.Input(name='input_2', shape=(52, 52, 3, (self.num_classes + 5))), # label small boxes
57
+ layers.Input(name='input_3', shape=(26, 26, 3, (self.num_classes + 5))), # label medium boxes
58
+ layers.Input(name='input_4', shape=(13, 13, 3, (self.num_classes + 5))), # label large boxes
59
+ layers.Input(name='input_5', shape=(self.max_boxes, 4)), # true bboxes
60
+ ]
61
+ loss_list = tf.keras.layers.Lambda(yolo_loss, name='yolo_loss',
62
+ arguments={'num_classes': self.num_classes,
63
+ 'iou_loss_thresh': self.iou_loss_thresh,
64
+ 'anchors': self.anchors})([*self.yolo_model.output, *y_true])
65
+ self.training_model = models.Model([self.yolo_model.input, *y_true], loss_list)
66
+
67
+ # Build inference model
68
+ yolov4_output = yolov4_head(yolov4_output, self.num_classes, self.anchors, self.xyscale)
69
+ # output: [boxes, scores, classes, valid_detections]
70
+ self.inference_model = models.Model(input_layer,
71
+ nms(yolov4_output, self.img_size, self.num_classes,
72
+ iou_threshold=self.config['iou_threshold'],
73
+ score_threshold=self.config['score_threshold']))
74
+
75
+ if load_pretrained and self.weight_path and self.weight_path.endswith('.weights'):
76
+ if self.weight_path.endswith('.weights'):
77
+ load_weights(self.yolo_model, self.weight_path)
78
+ print(f'load from {self.weight_path}')
79
+ elif self.weight_path.endswith('.h5'):
80
+ self.training_model.load_weights(self.weight_path)
81
+ print(f'load from {self.weight_path}')
82
+
83
+ self.training_model.compile(optimizer=optimizers.Adam(lr=1e-3),
84
+ loss={'yolo_loss': lambda y_true, y_pred: y_pred})
85
+
86
+ def load_model(self, path):
87
+ self.yolo_model = models.load_model(path, compile=False)
88
+ yolov4_output = yolov4_head(self.yolo_model.output, self.num_classes, self.anchors, self.xyscale)
89
+ self.inference_model = models.Model(self.yolo_model.input,
90
+ nms(yolov4_output, self.img_size, self.num_classes)) # [boxes, scores, classes, valid_detections]
91
+
92
+ def save_model(self, path):
93
+ self.yolo_model.save(path)
94
+
95
+ def preprocess_img(self, img):
96
+ img = cv2.resize(img, self.img_size[:2])
97
+ img = img / 255.
98
+ return img
99
+
100
+ def fit(self, train_data_gen, epochs, val_data_gen=None, initial_epoch=0, callbacks=None):
101
+ self.training_model.fit(train_data_gen,
102
+ steps_per_epoch=len(train_data_gen),
103
+ validation_data=val_data_gen,
104
+ validation_steps=len(val_data_gen),
105
+ epochs=epochs,
106
+ callbacks=callbacks,
107
+ initial_epoch=initial_epoch)
108
+ # raw_img: RGB
109
+ def predict_img(self, raw_img, random_color=True, plot_img=True, figsize=(10, 10), show_text=True, return_output=True):
110
+ print('img shape: ', raw_img.shape)
111
+ img = self.preprocess_img(raw_img)
112
+ imgs = np.expand_dims(img, axis=0)
113
+ pred_output = self.inference_model.predict(imgs)
114
+ detections = get_detection_data(img=raw_img,
115
+ model_outputs=pred_output,
116
+ class_names=self.class_names)
117
+
118
+ output_img = draw_bbox(raw_img, detections, cmap=self.class_color, random_color=random_color, figsize=figsize,
119
+ show_text=show_text, show_img=False)
120
+ if return_output:
121
+ return output_img, detections
122
+ else:
123
+ return detections
124
+
125
+ def predict(self, img_path, random_color=True, plot_img=True, figsize=(10, 10), show_text=True):
126
+ raw_img = img_path
127
+ return self.predict_img(raw_img, random_color, plot_img, figsize, show_text)
128
+
129
+ def export_gt(self, annotation_path, gt_folder_path):
130
+ with open(annotation_path) as file:
131
+ for line in file:
132
+ line = line.split(' ')
133
+ filename = line[0].split(os.sep)[-1].split('.')[0]
134
+ objs = line[1:]
135
+ # export txt file
136
+ with open(os.path.join(gt_folder_path, filename + '.txt'), 'w') as output_file:
137
+ for obj in objs:
138
+ x_min, y_min, x_max, y_max, class_id = [float(o) for o in obj.strip().split(',')]
139
+ output_file.write(f'{self.class_names[int(class_id)]} {x_min} {y_min} {x_max} {y_max}\n')
140
+
141
+ def export_prediction(self, annotation_path, pred_folder_path, img_folder_path, bs=2):
142
+ with open(annotation_path) as file:
143
+ img_paths = [os.path.join(img_folder_path, line.split(' ')[0].split(os.sep)[-1]) for line in file]
144
+ # print(img_paths[:20])
145
+ for batch_idx in tqdm(range(0, len(img_paths), bs)):
146
+ # print(len(img_paths), batch_idx, batch_idx*bs, (batch_idx+1)*bs)
147
+ paths = img_paths[batch_idx:batch_idx+bs]
148
+ # print(paths)
149
+ # read and process img
150
+ imgs = np.zeros((len(paths), *self.img_size))
151
+ raw_img_shapes = []
152
+ for j, path in enumerate(paths):
153
+ img = cv2.imread(path)
154
+ raw_img_shapes.append(img.shape)
155
+ img = self.preprocess_img(img)
156
+ imgs[j] = img
157
+
158
+ # process batch output
159
+ b_boxes, b_scores, b_classes, b_valid_detections = self.inference_model.predict(imgs)
160
+ for k in range(len(paths)):
161
+ num_boxes = b_valid_detections[k]
162
+ raw_img_shape = raw_img_shapes[k]
163
+ boxes = b_boxes[k, :num_boxes]
164
+ classes = b_classes[k, :num_boxes]
165
+ scores = b_scores[k, :num_boxes]
166
+ # print(raw_img_shape)
167
+ boxes[:, [0, 2]] = (boxes[:, [0, 2]] * raw_img_shape[1]) # w
168
+ boxes[:, [1, 3]] = (boxes[:, [1, 3]] * raw_img_shape[0]) # h
169
+ cls_names = [self.class_names[int(c)] for c in classes]
170
+ # print(raw_img_shape, boxes.astype(int), cls_names, scores)
171
+
172
+ img_path = paths[k]
173
+ filename = img_path.split(os.sep)[-1].split('.')[0]
174
+ # print(filename)
175
+ output_path = os.path.join(pred_folder_path, filename+'.txt')
176
+ with open(output_path, 'w') as pred_file:
177
+ for box_idx in range(num_boxes):
178
+ b = boxes[box_idx]
179
+ pred_file.write(f'{cls_names[box_idx]} {scores[box_idx]} {b[0]} {b[1]} {b[2]} {b[3]}\n')
180
+
181
+
182
+ def eval_map(self, gt_folder_path, pred_folder_path, temp_json_folder_path, output_files_path):
183
+ """Process Gt"""
184
+ ground_truth_files_list = glob(gt_folder_path + '/*.txt')
185
+ assert len(ground_truth_files_list) > 0, 'no ground truth file'
186
+ ground_truth_files_list.sort()
187
+ # dictionary with counter per class
188
+ gt_counter_per_class = {}
189
+ counter_images_per_class = {}
190
+
191
+ gt_files = []
192
+ for txt_file in ground_truth_files_list:
193
+ file_id = txt_file.split(".txt", 1)[0]
194
+ file_id = os.path.basename(os.path.normpath(file_id))
195
+ # check if there is a correspondent detection-results file
196
+ temp_path = os.path.join(pred_folder_path, (file_id + ".txt"))
197
+ assert os.path.exists(temp_path), "Error. File not found: {}\n".format(temp_path)
198
+ lines_list = read_txt_to_list(txt_file)
199
+ # create ground-truth dictionary
200
+ bounding_boxes = []
201
+ is_difficult = False
202
+ already_seen_classes = []
203
+ for line in lines_list:
204
+ class_name, left, top, right, bottom = line.split()
205
+ # check if class is in the ignore list, if yes skip
206
+ bbox = left + " " + top + " " + right + " " + bottom
207
+ bounding_boxes.append({"class_name": class_name, "bbox": bbox, "used": False})
208
+ # count that object
209
+ if class_name in gt_counter_per_class:
210
+ gt_counter_per_class[class_name] += 1
211
+ else:
212
+ # if class didn't exist yet
213
+ gt_counter_per_class[class_name] = 1
214
+
215
+ if class_name not in already_seen_classes:
216
+ if class_name in counter_images_per_class:
217
+ counter_images_per_class[class_name] += 1
218
+ else:
219
+ # if class didn't exist yet
220
+ counter_images_per_class[class_name] = 1
221
+ already_seen_classes.append(class_name)
222
+
223
+ # dump bounding_boxes into a ".json" file
224
+ new_temp_file = os.path.join(temp_json_folder_path, file_id+"_ground_truth.json") #TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
225
+ gt_files.append(new_temp_file)
226
+ with open(new_temp_file, 'w') as outfile:
227
+ json.dump(bounding_boxes, outfile)
228
+
229
+ gt_classes = list(gt_counter_per_class.keys())
230
+ # let's sort the classes alphabetically
231
+ gt_classes = sorted(gt_classes)
232
+ n_classes = len(gt_classes)
233
+ print(gt_classes, gt_counter_per_class)
234
+
235
+ """Process prediction"""
236
+
237
+ dr_files_list = sorted(glob(os.path.join(pred_folder_path, '*.txt')))
238
+
239
+ for class_index, class_name in enumerate(gt_classes):
240
+ bounding_boxes = []
241
+ for txt_file in dr_files_list:
242
+ # the first time it checks if all the corresponding ground-truth files exist
243
+ file_id = txt_file.split(".txt", 1)[0]
244
+ file_id = os.path.basename(os.path.normpath(file_id))
245
+ temp_path = os.path.join(gt_folder_path, (file_id + ".txt"))
246
+ if class_index == 0:
247
+ if not os.path.exists(temp_path):
248
+ error_msg = f"Error. File not found: {temp_path}\n"
249
+ print(error_msg)
250
+ lines = read_txt_to_list(txt_file)
251
+ for line in lines:
252
+ try:
253
+ tmp_class_name, confidence, left, top, right, bottom = line.split()
254
+ except ValueError:
255
+ error_msg = f"""Error: File {txt_file} in the wrong format.\n
256
+ Expected: <class_name> <confidence> <left> <top> <right> <bottom>\n
257
+ Received: {line} \n"""
258
+ print(error_msg)
259
+ if tmp_class_name == class_name:
260
+ # print("match")
261
+ bbox = left + " " + top + " " + right + " " + bottom
262
+ bounding_boxes.append({"confidence": confidence, "file_id": file_id, "bbox": bbox})
263
+ # sort detection-results by decreasing confidence
264
+ bounding_boxes.sort(key=lambda x: float(x['confidence']), reverse=True)
265
+ with open(temp_json_folder_path + "/" + class_name + "_dr.json", 'w') as outfile:
266
+ json.dump(bounding_boxes, outfile)
267
+
268
+ """
269
+ Calculate the AP for each class
270
+ """
271
+ sum_AP = 0.0
272
+ ap_dictionary = {}
273
+ # open file to store the output
274
+ with open(output_files_path + "/output.txt", 'w') as output_file:
275
+ output_file.write("# AP and precision/recall per class\n")
276
+ count_true_positives = {}
277
+ for class_index, class_name in enumerate(gt_classes):
278
+ count_true_positives[class_name] = 0
279
+ """
280
+ Load detection-results of that class
281
+ """
282
+ dr_file = temp_json_folder_path + "/" + class_name + "_dr.json"
283
+ dr_data = json.load(open(dr_file))
284
+
285
+ """
286
+ Assign detection-results to ground-truth objects
287
+ """
288
+ nd = len(dr_data)
289
+ tp = [0] * nd # creates an array of zeros of size nd
290
+ fp = [0] * nd
291
+ for idx, detection in enumerate(dr_data):
292
+ file_id = detection["file_id"]
293
+ gt_file = temp_json_folder_path + "/" + file_id + "_ground_truth.json"
294
+ ground_truth_data = json.load(open(gt_file))
295
+ ovmax = -1
296
+ gt_match = -1
297
+ # load detected object bounding-box
298
+ bb = [float(x) for x in detection["bbox"].split()]
299
+ for obj in ground_truth_data:
300
+ # look for a class_name match
301
+ if obj["class_name"] == class_name:
302
+ bbgt = [float(x) for x in obj["bbox"].split()]
303
+ bi = [max(bb[0], bbgt[0]), max(bb[1], bbgt[1]), min(bb[2], bbgt[2]), min(bb[3], bbgt[3])]
304
+ iw = bi[2] - bi[0] + 1
305
+ ih = bi[3] - bi[1] + 1
306
+ if iw > 0 and ih > 0:
307
+ # compute overlap (IoU) = area of intersection / area of union
308
+ ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + \
309
+ (bbgt[2] - bbgt[0]+ 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
310
+ ov = iw * ih / ua
311
+ if ov > ovmax:
312
+ ovmax = ov
313
+ gt_match = obj
314
+
315
+ min_overlap = 0.5
316
+ if ovmax >= min_overlap:
317
+ # if "difficult" not in gt_match:
318
+ if not bool(gt_match["used"]):
319
+ # true positive
320
+ tp[idx] = 1
321
+ gt_match["used"] = True
322
+ count_true_positives[class_name] += 1
323
+ # update the ".json" file
324
+ with open(gt_file, 'w') as f:
325
+ f.write(json.dumps(ground_truth_data))
326
+ else:
327
+ # false positive (multiple detection)
328
+ fp[idx] = 1
329
+ else:
330
+ fp[idx] = 1
331
+
332
+
333
+ # compute precision/recall
334
+ cumsum = 0
335
+ for idx, val in enumerate(fp):
336
+ fp[idx] += cumsum
337
+ cumsum += val
338
+ print('fp ', cumsum)
339
+ cumsum = 0
340
+ for idx, val in enumerate(tp):
341
+ tp[idx] += cumsum
342
+ cumsum += val
343
+ print('tp ', cumsum)
344
+ rec = tp[:]
345
+ for idx, val in enumerate(tp):
346
+ rec[idx] = float(tp[idx]) / gt_counter_per_class[class_name]
347
+ print('recall ', cumsum)
348
+ prec = tp[:]
349
+ for idx, val in enumerate(tp):
350
+ prec[idx] = float(tp[idx]) / (fp[idx] + tp[idx])
351
+ print('prec ', cumsum)
352
+
353
+ ap, mrec, mprec = voc_ap(rec[:], prec[:])
354
+ sum_AP += ap
355
+ text = "{0:.2f}%".format(
356
+ ap * 100) + " = " + class_name + " AP " # class_name + " AP = {0:.2f}%".format(ap*100)
357
+
358
+ print(text)
359
+ ap_dictionary[class_name] = ap
360
+
361
+ n_images = counter_images_per_class[class_name]
362
+ # lamr, mr, fppi = log_average_miss_rate(np.array(prec), np.array(rec), n_images)
363
+ # lamr_dictionary[class_name] = lamr
364
+
365
+ """
366
+ Draw plot
367
+ """
368
+ if True:
369
+ plt.plot(rec, prec, '-o')
370
+ # add a new penultimate point to the list (mrec[-2], 0.0)
371
+ # since the last line segment (and respective area) do not affect the AP value
372
+ area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
373
+ area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
374
+ plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')
375
+ # set window title
376
+ fig = plt.gcf() # gcf - get current figure
377
+ fig.canvas.set_window_title('AP ' + class_name)
378
+ # set plot title
379
+ plt.title('class: ' + text)
380
+ # plt.suptitle('This is a somewhat long figure title', fontsize=16)
381
+ # set axis titles
382
+ plt.xlabel('Recall')
383
+ plt.ylabel('Precision')
384
+ # optional - set axes
385
+ axes = plt.gca() # gca - get current axes
386
+ axes.set_xlim([0.0, 1.0])
387
+ axes.set_ylim([0.0, 1.05]) # .05 to give some extra space
388
+ # Alternative option -> wait for button to be pressed
389
+ # while not plt.waitforbuttonpress(): pass # wait for key display
390
+ # Alternative option -> normal display
391
+ plt.show()
392
+ # save the plot
393
+ # fig.savefig(output_files_path + "/classes/" + class_name + ".png")
394
+ # plt.cla() # clear axes for next plot
395
+
396
+ # if show_animation:
397
+ # cv2.destroyAllWindows()
398
+
399
+ output_file.write("\n# mAP of all classes\n")
400
+ mAP = sum_AP / n_classes
401
+ text = "mAP = {0:.2f}%".format(mAP * 100)
402
+ output_file.write(text + "\n")
403
+ print(text)
404
+
405
+ """
406
+ Count total of detection-results
407
+ """
408
+ # iterate through all the files
409
+ det_counter_per_class = {}
410
+ for txt_file in dr_files_list:
411
+ # get lines to list
412
+ lines_list = read_txt_to_list(txt_file)
413
+ for line in lines_list:
414
+ class_name = line.split()[0]
415
+ # check if class is in the ignore list, if yes skip
416
+ # if class_name in args.ignore:
417
+ # continue
418
+ # count that object
419
+ if class_name in det_counter_per_class:
420
+ det_counter_per_class[class_name] += 1
421
+ else:
422
+ # if class didn't exist yet
423
+ det_counter_per_class[class_name] = 1
424
+ # print(det_counter_per_class)
425
+ dr_classes = list(det_counter_per_class.keys())
426
+
427
+ """
428
+ Plot the total number of occurences of each class in the ground-truth
429
+ """
430
+ if True:
431
+ window_title = "ground-truth-info"
432
+ plot_title = "ground-truth\n"
433
+ plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
434
+ x_label = "Number of objects per class"
435
+ output_path = output_files_path + "/ground-truth-info.png"
436
+ to_show = False
437
+ plot_color = 'forestgreen'
438
+ draw_plot_func(
439
+ gt_counter_per_class,
440
+ n_classes,
441
+ window_title,
442
+ plot_title,
443
+ x_label,
444
+ output_path,
445
+ to_show,
446
+ plot_color,
447
+ '',
448
+ )
449
+
450
+ """
451
+ Finish counting true positives
452
+ """
453
+ for class_name in dr_classes:
454
+ # if class exists in detection-result but not in ground-truth then there are no true positives in that class
455
+ if class_name not in gt_classes:
456
+ count_true_positives[class_name] = 0
457
+ # print(count_true_positives)
458
+
459
+ """
460
+ Plot the total number of occurences of each class in the "detection-results" folder
461
+ """
462
+ if True:
463
+ window_title = "detection-results-info"
464
+ # Plot title
465
+ plot_title = "detection-results\n"
466
+ plot_title += "(" + str(len(dr_files_list)) + " files and "
467
+ count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
468
+ plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
469
+ # end Plot title
470
+ x_label = "Number of objects per class"
471
+ output_path = output_files_path + "/detection-results-info.png"
472
+ to_show = False
473
+ plot_color = 'forestgreen'
474
+ true_p_bar = count_true_positives
475
+ draw_plot_func(
476
+ det_counter_per_class,
477
+ len(det_counter_per_class),
478
+ window_title,
479
+ plot_title,
480
+ x_label,
481
+ output_path,
482
+ to_show,
483
+ plot_color,
484
+ true_p_bar
485
+ )
486
+
487
+ """
488
+ Draw mAP plot (Show AP's of all classes in decreasing order)
489
+ """
490
+ if True:
491
+ window_title = "mAP"
492
+ plot_title = "mAP = {0:.2f}%".format(mAP * 100)
493
+ x_label = "Average Precision"
494
+ output_path = output_files_path + "/mAP.png"
495
+ to_show = True
496
+ plot_color = 'royalblue'
497
+ draw_plot_func(
498
+ ap_dictionary,
499
+ n_classes,
500
+ window_title,
501
+ plot_title,
502
+ x_label,
503
+ output_path,
504
+ to_show,
505
+ plot_color,
506
+ ""
507
+ )
508
+
509
+ def predict_raw(self, img_path):
510
+ raw_img = cv2.imread(img_path)
511
+ print('img shape: ', raw_img.shape)
512
+ img = self.preprocess_img(raw_img)
513
+ imgs = np.expand_dims(img, axis=0)
514
+ return self.yolo_model.predict(imgs)
515
+
516
+ def predict_nonms(self, img_path, iou_threshold=0.413, score_threshold=0.1):
517
+ raw_img = cv2.imread(img_path)
518
+ print('img shape: ', raw_img.shape)
519
+ img = self.preprocess_img(raw_img)
520
+ imgs = np.expand_dims(img, axis=0)
521
+ yolov4_output = self.yolo_model.predict(imgs)
522
+ output = yolov4_head(yolov4_output, self.num_classes, self.anchors, self.xyscale)
523
+ pred_output = nms(output, self.img_size, self.num_classes, iou_threshold, score_threshold)
524
+ pred_output = [p.numpy() for p in pred_output]
525
+ detections = get_detection_data(img=raw_img,
526
+ model_outputs=pred_output,
527
+ class_names=self.class_names)
528
+ draw_bbox(raw_img, detections, cmap=self.class_color, random_color=True)
529
+ return detections
530
+