fa0311's picture
Upload 81 files
b559e06 verified
import torch
import numpy as np
def _as_tensor(value, name):
    """Return ``value`` as a ``torch.Tensor``, converting NumPy arrays.

    Args:
        value: A ``torch.Tensor`` (returned unchanged) or an ``np.ndarray``
            (wrapped with ``torch.from_numpy``, which shares memory with the
            source array rather than copying it).
        name: Label used at the start of the error message, e.g. ``"Image"``.

    Raises:
        TypeError: If ``value`` is neither a tensor nor a NumPy array.
    """
    # isinstance (not an exact type comparison) so that Tensor / ndarray
    # subclasses such as nn.Parameter are accepted as well.
    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, np.ndarray):
        return torch.from_numpy(value)
    raise TypeError(f"{name} should be tensor or np.ndarray, but got {type(value)}.")


def object_detection_collate(batch):
    """Collate ``(image, boxes, labels)`` samples into one batch.

    Every element of every sample is validated and converted on its own, so a
    batch may mix ``np.ndarray`` and ``torch.Tensor`` samples.

    Args:
        batch: Non-empty sequence of ``(image, boxes, labels)`` triples whose
            elements are each a ``torch.Tensor`` or an ``np.ndarray``.

    Returns:
        A tuple ``(images, gt_boxes, gt_labels)``:
            * ``images`` - a single tensor built with ``torch.stack``, so all
              images must have the same shape.
            * ``gt_boxes`` - list with one tensor per sample.
            * ``gt_labels`` - list with one tensor per sample.
        Boxes and labels are returned as lists instead of being stacked
        (presumably because the number of objects differs per image - confirm
        against the dataset).

    Raises:
        IndexError: If ``batch`` is empty.
        TypeError: If any element is neither a tensor nor a NumPy array.
    """
    if len(batch) == 0:
        # The exception type matches what indexing batch[0] used to raise,
        # but with a message that explains the actual problem.
        raise IndexError("Cannot collate an empty batch.")

    images = []
    gt_boxes = []
    gt_labels = []
    for image, boxes, labels in batch:
        images.append(_as_tensor(image, "Image"))
        gt_boxes.append(_as_tensor(boxes, "Boxes"))
        gt_labels.append(_as_tensor(labels, "Labels"))

    return torch.stack(images), gt_boxes, gt_labels