import torch
import numpy as np


def _to_tensor(value, name):
    """Return *value* as a ``torch.Tensor``, converting NumPy arrays.

    Args:
        value: A ``torch.Tensor`` or ``np.ndarray`` (subclasses accepted).
        name: Capitalised field name ("Image", "Boxes", "Labels") used to
            build the error message.

    Raises:
        TypeError: If *value* is neither a tensor nor a NumPy array.
    """
    if isinstance(value, torch.Tensor):
        return value
    if isinstance(value, np.ndarray):
        # torch.from_numpy does not copy: the tensor shares the array's memory.
        return torch.from_numpy(value)
    raise TypeError(
        f"{name} should be tensor or np.ndarray, but got {type(value)}."
    )


def object_detection_collate(batch):
    """Collate ``(image, boxes, labels)`` samples into a detection batch.

    Images are stacked into a single tensor along a new leading dimension,
    so they must all have the same shape. Boxes and labels are returned as
    per-sample lists and are never stacked, so their sizes may differ from
    sample to sample.

    Every element is type-checked individually, so a batch may mix NumPy
    arrays and tensors.

    Args:
        batch: A sequence of ``(image, boxes, labels)`` triples whose items
            are ``torch.Tensor`` or ``np.ndarray`` instances.

    Returns:
        A tuple ``(images, gt_boxes, gt_labels)`` where ``images`` is the
        stacked image tensor and ``gt_boxes`` / ``gt_labels`` are lists of
        tensors, one entry per sample, in batch order.

    Raises:
        IndexError: If *batch* is empty.
        TypeError: If any image, boxes or labels item is neither a tensor
            nor a NumPy array.
    """
    if not batch:
        # Same exception type the previous ``batch[0]`` lookup produced,
        # so existing callers that catch it keep working.
        raise IndexError("Cannot collate an empty batch.")

    images = []
    gt_boxes = []
    gt_labels = []
    for image, boxes, labels in batch:
        images.append(_to_tensor(image, "Image"))
        gt_boxes.append(_to_tensor(boxes, "Boxes"))
        gt_labels.append(_to_tensor(labels, "Labels"))

    return torch.stack(images), gt_boxes, gt_labels