BhumikaMak commited on
Commit
b158018
·
verified ·
1 Parent(s): 81eac19

updated args passing

Browse files
Files changed (1) hide show
  1. yolov5.py +7 -6
yolov5.py CHANGED
@@ -76,14 +76,13 @@ def xai_yolov5(image):
76
  # Grad-CAM visualization
77
  cam_image, renormalized_cam_image = generate_cam_image(model, target_layers, tensor, image, boxes)
78
 
79
- rgb_img_float, batch_explanations = dff_nmf(image, target_lyr = -5, n_components = 8)
80
  im = visualize_batch_explanations(rgb_img_float, batch_explanations)
81
 
82
  # Combine results
83
- print('shapes........', image.shape, detections_img.shape, renormalized_cam_image.shape, im.shape)
84
  final_image = np.hstack((image, detections_img, renormalized_cam_image, im))
85
  caption = "Results using YOLOv5"
86
- return Image.fromarray(final_image), caption
87
 
88
 
89
 
@@ -194,8 +193,8 @@ def dff_nmf(image, target_lyr, n_components):
194
  yaml_data = requests.get(yolov5_categories_url).text
195
  labels = yaml.safe_load(yaml_data)['names'] # Parse the YAML file to get class names
196
  num_classes = model.model.model.model[-1].nc
197
-
198
- for indx in range( explanations[0].shape[0]):
199
  upsampled_input = explanations[0][indx]
200
  upsampled_input = torch.tensor(upsampled_input)
201
  device = next(model.parameters()).device
@@ -234,7 +233,9 @@ def dff_nmf(image, target_lyr, n_components):
234
  plt.show()
235
  plt.savefig("test_" + str(indx) + ".png" )
236
  plt.clf()
237
- return rgb_img_float, batch_explanations
 
 
238
 
239
 
240
  def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):
 
76
  # Grad-CAM visualization
77
  cam_image, renormalized_cam_image = generate_cam_image(model, target_layers, tensor, image, boxes)
78
 
79
+ rgb_img_float, batch_explanations, result = dff_nmf(image, target_lyr = -5, n_components = 8)
80
  im = visualize_batch_explanations(rgb_img_float, batch_explanations)
81
 
82
  # Combine results
 
83
  final_image = np.hstack((image, detections_img, renormalized_cam_image, im))
84
  caption = "Results using YOLOv5"
85
+ return Image.fromarray(final_image), caption, result
86
 
87
 
88
 
 
193
  yaml_data = requests.get(yolov5_categories_url).text
194
  labels = yaml.safe_load(yaml_data)['names'] # Parse the YAML file to get class names
195
  num_classes = model.model.model.model[-1].nc
196
+ results = []
197
+ for indx in range(explanations[0].shape[0]):
198
  upsampled_input = explanations[0][indx]
199
  upsampled_input = torch.tensor(upsampled_input)
200
  device = next(model.parameters()).device
 
233
  plt.show()
234
  plt.savefig("test_" + str(indx) + ".png" )
235
  plt.clf()
236
+ results.append(Image.open(f"test_{indx}.png"))
237
+
238
+ return rgb_img_float, batch_explanations, results
239
 
240
 
241
  def visualize_batch_explanations(rgb_img_float, batch_explanations, image_weight=0.7):