File size: 1,158 Bytes
b559e06 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch
import numpy as np
def _as_tensor(value, what):
    """Convert one sample field to a ``torch.Tensor``.

    Args:
        value: The field value; must be a ``torch.Tensor`` or an
            ``np.ndarray`` (subclasses of either are accepted).
        what: Capitalised field name used in the error message.

    Returns:
        ``value`` itself if it is already a tensor, otherwise a tensor built
        with ``torch.from_numpy`` (sharing memory with the array unless a
        copy was required, see below).

    Raises:
        TypeError: If ``value`` is neither a tensor nor an ndarray.
    """
    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, np.ndarray):
        # torch.from_numpy rejects arrays with negative strides (e.g. the
        # result of ``arr[:, ::-1]`` in a flip augmentation); copy those into
        # a fresh C-ordered array first. Other arrays stay zero-copy.
        if any(stride < 0 for stride in value.strides):
            value = value.copy()
        return torch.from_numpy(value)
    raise TypeError(f"{what} should be tensor or np.ndarray, but got {type(value)}.")


def object_detection_collate(batch):
    """Collate ``(image, boxes, labels)`` samples into one batch.

    Each field may be a ``torch.Tensor`` or an ``np.ndarray``. Types are
    checked per sample, so a batch may mix the two.

    Args:
        batch: Non-empty sequence of ``(image, boxes, labels)`` samples.

    Returns:
        Tuple ``(images, gt_boxes, gt_labels)``:

        - ``images``: all images combined with ``torch.stack`` along a new
          leading dimension, so every image must have the same shape.
        - ``gt_boxes`` / ``gt_labels``: lists with one tensor per sample.
          They are left unstacked, presumably because the number of objects
          differs between samples.

    Raises:
        ValueError: If ``batch`` is empty.
        TypeError: If any field is neither a tensor nor an ndarray.
    """
    if not batch:
        raise ValueError("Cannot collate an empty batch.")
    images, gt_boxes, gt_labels = [], [], []
    for image, boxes, labels in batch:
        images.append(_as_tensor(image, "Image"))
        gt_boxes.append(_as_tensor(boxes, "Boxes"))
        gt_labels.append(_as_tensor(labels, "Labels"))
    return torch.stack(images), gt_boxes, gt_labels