|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, List, Union |
|
|
|
import numpy as np |
|
import torch |
|
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector |
|
from monai.inferers.inferer import Inferer |
|
from torch import Tensor |
|
|
|
|
|
class RetinaNetInferer(Inferer):
    """
    RetinaNet Inferer takes RetinaNet as input.

    Wraps a ``RetinaNetDetector``. On each call it attaches the given network to the
    detector and decides whether the detector should use its sliding-window inferer,
    based on the size of the input images.

    Args:
        detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
            map into boxes and classification scores.
        args: other optional args. Currently accepted for interface compatibility only;
            they are not used or forwarded anywhere.
        kwargs: other optional keyword args. Currently accepted for interface compatibility
            only; they are not used or forwarded anywhere.
    """

    def __init__(self, detector: RetinaNetDetector, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        self.detector = detector
        # Element count of one sliding-window ROI (product of ``roi_size``), or None when the
        # detector has no inferer or its inferer exposes no ``roi_size``. None means the
        # sliding-window inferer is never requested in ``__call__``.
        # NOTE(review): this is computed once here. If the detector's inferer is replaced
        # after construction, this value is stale. Also, if ``roi_size`` can be a scalar int,
        # np.prod returns that int rather than the window volume. Confirm against the
        # inferer's contract.
        self.sliding_window_size = None
        inferer = self.detector.inferer
        if inferer is not None and hasattr(inferer, "roi_size"):
            self.sliding_window_size = np.prod(inferer.roi_size)

    def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any):
        """Unified callable function API of Inferers.

        Side effects: sets ``self.detector.network`` to ``network`` and copies
        ``network.training`` onto ``self.detector.training``.

        Args:
            inputs: model input data for inference.
                Presumably a list of per-image tensors, or a batched tensor whose first
                dimension indexes images. TODO confirm against the detector's contract.
            network: target detection network to execute inference.
                Supports a callable that fulfills the requirements of ``network`` in
                ``monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``.
            args: optional positional args forwarded to ``self.detector``.
            kwargs: optional keyword args forwarded to ``self.detector``. They must not
                contain ``use_inferer``, which this method sets itself.

        Returns:
            The result of ``self.detector(inputs, ..., use_inferer=...)``.
        """
        self.detector.network = network
        # Keep the detector's train/eval mode in sync with the network being run.
        self.detector.training = self.detector.network.training

        # Request the sliding-window inferer unless every image is smaller than one ROI.
        # Size is measured as the element count of ``data_i[0, ...]``, i.e. the first slice
        # along each image's leading dimension (presumably the channel dim; verify).
        # The generator lets ``all`` short-circuit on the first large image.
        window = self.sliding_window_size
        use_inferer = window is not None and not all(data_i[0, ...].numel() < window for data_i in inputs)

        return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs)
|
|