BhumikaMak committed on
Commit
f72bf07
·
verified ·
1 Parent(s): 9101eba

update: fix dff_nmf failing with "no attribute 'xywh'" when reading YOLO results

Browse files
Files changed (1) hide show
  1. yolov8.py +15 -15
yolov8.py CHANGED
@@ -161,37 +161,37 @@ class DeepFeatureFactorization:
161
  return True
162
 
163
 
164
-
165
  def dff_nmf(image, target_lyr, n_components):
166
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
167
  mean = [0.485, 0.456, 0.406] # Mean for RGB channels
168
  std = [0.229, 0.224, 0.225] # Standard deviation for RGB channels
169
  img = cv2.resize(image, (640, 640))
170
  rgb_img_float = np.float32(img) / 255.0
171
- input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
172
 
173
  model = YOLO('yolov8s.pt') # Ensure the model is loaded correctly
174
- dff = DeepFeatureFactorization(model=model,
175
- target_layer=model.model.model[int(target_lyr)],
176
  computation_on_concepts=None)
177
-
178
  concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
179
 
180
  # Getting predictions directly from YOLO
181
  with torch.no_grad():
182
  results = model(input_tensor)
183
 
184
- # Post-processing to extract detections
185
- boxes, scores, classes = results.xywh[0][:, :4], results.xywh[0][:, 4], results.xywh[0][:, 5]
186
- boxes = boxes.cpu().numpy()
187
- scores = scores.cpu().numpy()
188
- classes = classes.cpu().numpy()
 
 
189
 
190
  # Filter detections with confidence score > threshold (e.g., 0.5)
191
  high_conf_boxes = boxes[scores > 0.5]
192
  high_conf_classes = classes[scores > 0.5]
193
-
194
- # Use the processed detections for visualization and further tasks
195
  # Example visualization and output processing
196
  fig, ax = plt.subplots(1, figsize=(8, 8))
197
  ax.axis("off")
@@ -209,11 +209,11 @@ def dff_nmf(image, target_lyr, n_components):
209
  image_array = np.array(fig.canvas.renderer.buffer_rgba())
210
  image_resized = cv2.resize(image_array, (640, 640))
211
  rgba_channels = cv2.split(image_resized)
212
- alpha_channel = rgba_channels[3]
213
  rgb_channels = np.stack(rgba_channels[:3], axis=-1)
214
-
215
  visualization = show_factorization_on_image(rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
216
-
217
  return rgb_img_float, batch_explanations, visualization
218
 
219
 
 
161
  return True
162
 
163
 
 
164
  def dff_nmf(image, target_lyr, n_components):
165
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
166
  mean = [0.485, 0.456, 0.406] # Mean for RGB channels
167
  std = [0.229, 0.224, 0.225] # Standard deviation for RGB channels
168
  img = cv2.resize(image, (640, 640))
169
  rgb_img_float = np.float32(img) / 255.0
170
+ input_tensor = torch.from_numpy(rgb_img_float).permute(2, 0, 1).unsqueeze(0).to(device)
171
 
172
  model = YOLO('yolov8s.pt') # Ensure the model is loaded correctly
173
+ dff = DeepFeatureFactorization(model=model,
174
+ target_layer=model.model.model[int(target_lyr)],
175
  computation_on_concepts=None)
176
+
177
  concepts, batch_explanations, explanations = dff(input_tensor, model, n_components)
178
 
179
  # Getting predictions directly from YOLO
180
  with torch.no_grad():
181
  results = model(input_tensor)
182
 
183
+ # Assuming results is a list, extract the first element
184
+ detections = results[0] # The first element should contain the detection data
185
+
186
+ # Access detection results
187
+ boxes = detections.boxes.xyxy.cpu().numpy() # Bounding box coordinates (xyxy)
188
+ scores = detections.scores.cpu().numpy() # Confidence scores
189
+ classes = detections.classes.cpu().numpy() # Class IDs
190
 
191
  # Filter detections with confidence score > threshold (e.g., 0.5)
192
  high_conf_boxes = boxes[scores > 0.5]
193
  high_conf_classes = classes[scores > 0.5]
194
+
 
195
  # Example visualization and output processing
196
  fig, ax = plt.subplots(1, figsize=(8, 8))
197
  ax.axis("off")
 
209
  image_array = np.array(fig.canvas.renderer.buffer_rgba())
210
  image_resized = cv2.resize(image_array, (640, 640))
211
  rgba_channels = cv2.split(image_resized)
212
+ alpha_channel = rgba_channels[3]
213
  rgb_channels = np.stack(rgba_channels[:3], axis=-1)
214
+
215
  visualization = show_factorization_on_image(rgb_img_float, np.transpose(rgb_channels, (2, 0, 1)), image_weight=0.3)
216
+
217
  return rgb_img_float, batch_explanations, visualization
218
 
219