import torch
import numpy as np


def object_detection_collate(batch):
    """Collate function for object detection datasets.

    Each sample in ``batch`` is an ``(image, boxes, labels)`` triple. Images are
    stacked into a single batch tensor, while boxes and labels are returned as
    lists of tensors because the number of objects varies per image.
    """
    images = []
    gt_boxes = []
    gt_labels = []

    # Infer the element types from the first sample; every sample in the batch
    # is expected to use the same types.
    image_type = type(batch[0][0])
    box_type = type(batch[0][1])
    label_type = type(batch[0][2])

    for image, boxes, labels in batch:
        if image_type is np.ndarray:
            images.append(torch.from_numpy(image))
        elif image_type is torch.Tensor:
            images.append(image)
        else:
            raise TypeError(f"Image should be tensor or np.ndarray, but got {image_type}.")

        if box_type is np.ndarray:
            gt_boxes.append(torch.from_numpy(boxes))
        elif box_type is torch.Tensor:
            gt_boxes.append(boxes)
        else:
            raise TypeError(f"Boxes should be tensor or np.ndarray, but got {box_type}.")

        if label_type is np.ndarray:
            gt_labels.append(torch.from_numpy(labels))
        elif label_type is torch.Tensor:
            gt_labels.append(labels)
        else:
            raise TypeError(f"Labels should be tensor or np.ndarray, but got {label_type}.")

    return torch.stack(images), gt_boxes, gt_labels
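

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: synthetic samples
    # (an assumption for illustration) show how the collate function plugs into
    # a DataLoader. Images must share a shape so torch.stack succeeds; the
    # number of boxes per image may vary.
    from torch.utils.data import DataLoader

    samples = [
        (
            np.random.rand(3, 32, 32).astype(np.float32),  # image (C, H, W)
            np.random.rand(n, 4).astype(np.float32),       # n boxes, xyxy or similar
            np.random.randint(0, 10, size=(n,)),           # n class labels
        )
        for n in (1, 3, 2, 5)
    ]

    loader = DataLoader(samples, batch_size=2, collate_fn=object_detection_collate)
    for images, gt_boxes, gt_labels in loader:
        # images: (2, 3, 32, 32); gt_boxes / gt_labels: lists of length 2
        print(images.shape, [b.shape for b in gt_boxes], [l.shape for l in gt_labels])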