monai
medical
File size: 2,662 Bytes
f6cc1e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Union

import numpy as np
import torch
from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
from monai.inferers.inferer import Inferer
from torch import Tensor


class RetinaNetInferer(Inferer):
    """
    Inferer that runs a RetinaNet detection network through a ``RetinaNetDetector``.

    The detector's sliding-window inferer (when it exposes ``roi_size``) is only requested
    when at least one input image has at least as many elements as the sliding-window ROI;
    smaller images are handed to the detector with ``use_inferer=False``.

    Args:
        detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
            map into boxes and classification scores.
        args: accepted for API compatibility; currently unused.
        kwargs: accepted for API compatibility; currently unused.
    """

    def __init__(self, detector: RetinaNetDetector, *args: Any, **kwargs: Any) -> None:
        super().__init__()
        self.detector = detector
        # Element count of one sliding-window ROI, or None when the detector has no inferer
        # exposing ``roi_size`` (in which case this class never requests the inferer).
        self.sliding_window_size = None
        if self.detector.inferer is not None and hasattr(self.detector.inferer, "roi_size"):
            self.sliding_window_size = np.prod(self.detector.inferer.roi_size)

    def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any):
        """Unified callable function API of Inferers: run ``network`` on ``inputs`` via ``self.detector``.

        Args:
            inputs: model input data for inference; a list of image tensors or a batched
                tensor. Each image is indexed as ``image[0, ...]`` to measure its size, which
                presumably assumes channel-first images (TODO confirm against callers).
            network: target detection network to execute inference.
                Supports any callable that fulfills the requirements of ``network`` in
                ``monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``.
            args: optional positional args forwarded to ``self.detector``.
            kwargs: optional keyword args forwarded to ``self.detector``.

        Returns:
            Whatever ``self.detector`` returns for these inputs.
        """
        # Attach the network to the detector and mirror the network's train/eval mode.
        self.detector.network = network
        self.detector.training = self.detector.network.training

        # If every image is smaller (by element count) than the sliding-window ROI, there is
        # no need for the sliding-window inferer; use it only when some image is large.
        use_inferer = self.sliding_window_size is not None and not all(
            image[0, ...].numel() < self.sliding_window_size for image in inputs
        )

        return self.detector(inputs, *args, use_inferer=use_inferer, **kwargs)