BhumikaMak commited on
Commit
b3188de
·
verified ·
1 Parent(s): c0f1ad1

fix subplots plotting

Browse files
Files changed (1) hide show
  1. yolov5.py +3 -2
yolov5.py CHANGED
@@ -211,7 +211,8 @@ def dff_nmf(image, target_lyr, n_components):
211
  scores = scores * objectness # Adjust scores by objectness
212
  boxes = output1[..., :4] # First 4 values are x1, y1, x2, y2
213
  boxes = boxes[confidence_mask] # Filter boxes by confidence mask
214
- fig, ax = plt.subplots(1, figsize=(10, 10))
 
215
  ax.imshow(torch.tensor(batch_explanations[0][indx]).cpu().numpy(), cmap="RdYlGn") # Display image
216
  top_score_idx = scores.argmax(dim=0) # Get the index of the max score
217
  top_score = scores[top_score_idx].item()
@@ -228,7 +229,7 @@ def dff_nmf(image, target_lyr, n_components):
228
  predicted_label = labels[top_class_id] # Map ID to label
229
  ax.text(x1, y1, f"{predicted_label}: {top_score:.2f}",
230
  color='r', fontsize=12, verticalalignment='top')
231
-
232
 
233
  fig.canvas.draw() # Draw the canvas to make sure the image is rendered
234
  image_array = np.array(fig.canvas.renderer.buffer_rgba()) # Convert to numpy array
 
211
  scores = scores * objectness # Adjust scores by objectness
212
  boxes = output1[..., :4] # First 4 values are x1, y1, x2, y2
213
  boxes = boxes[confidence_mask] # Filter boxes by confidence mask
214
+ fig, ax = plt.subplots(1, figsize=(8, 8))
215
+ ax.axis("off")
216
  ax.imshow(torch.tensor(batch_explanations[0][indx]).cpu().numpy(), cmap="RdYlGn") # Display image
217
  top_score_idx = scores.argmax(dim=0) # Get the index of the max score
218
  top_score = scores[top_score_idx].item()
 
229
  predicted_label = labels[top_class_id] # Map ID to label
230
  ax.text(x1, y1, f"{predicted_label}: {top_score:.2f}",
231
  color='r', fontsize=12, verticalalignment='top')
232
+ plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
233
 
234
  fig.canvas.draw() # Draw the canvas to make sure the image is rendered
235
  image_array = np.array(fig.canvas.renderer.buffer_rgba()) # Convert to numpy array