monai
medical
lung_nodule_ct_detection / scripts /detection_inferer.py
katielink's picture
use detection inferer
f6cc1e0
raw
history blame
2.66 kB
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Union
import numpy as np
import torch
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
from monai.inferers.inferer import Inferer
from torch import Tensor
class RetinaNetInferer(Inferer):
"""
RetinaNet Inferer takes RetinaNet as input
Args:
detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
map into boxes and classification scores.
args: other optional args to be passed to detector.
kwargs: other optional keyword args to be passed to detector.
"""
def __init__(self, detector: RetinaNetDetector, *args, **kwargs) -> None:
Inferer.__init__(self)
self.detector = detector
self.sliding_window_size = None
if self.detector.inferer is not None:
if hasattr(self.detector.inferer, "roi_size"):
self.sliding_window_size = np.prod(self.detector.inferer.roi_size)
def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any):
"""Unified callable function API of Inferers.
Args:
inputs: model input data for inference.
network: target detection network to execute inference.
supports callable that fullfilles requirements of network in
monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``
args: optional args to be passed to ``network``.
kwargs: optional keyword args to be passed to ``network``.
"""
self.detector.network = network
self.detector.training = self.detector.network.training
# if image smaller than sliding window roi size, no need to use sliding window inferer
# use sliding window inferer only when image is large
use_inferer = self.sliding_window_size is not None and not all(
[data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs]
)
return self.detector(inputs, use_inferer=use_inferer, *args, **kwargs)