File size: 3,335 Bytes
cfcee64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import albumentations as A
import torch
import torch.nn as nn

from numpy.typing import NDArray
from transformers import PreTrainedModel
from timm import create_model
from typing import Optional
from .configuration import CXRConfig
from .unet import UnetDecoder, SegmentationHead

# `pydicom` is an optional dependency: it is only needed by
# `CXRModel.load_image_from_dicom`. If it cannot be imported, the flag stays
# False and that method returns None instead of raising.
_PYDICOM_AVAILABLE = False
try:
    from pydicom import dcmread

    try:
        # pydicom >= 3.0 exposes the pixel-processing helpers here.
        from pydicom.pixels import apply_voi_lut
    except ImportError:
        # pydicom < 3.0 location of the same helper.
        from pydicom.pixel_data_handlers.util import apply_voi_lut

    _PYDICOM_AVAILABLE = True
except ImportError:
    # ImportError also covers ModuleNotFoundError (its subclass) and
    # "cannot import name" failures.
    pass


class CXRModel(PreTrainedModel):
    """Chest X-ray model with a segmentation head and a classification head.

    A ``timm`` backbone created with ``features_only=True`` produces a list of
    feature maps. A UNet decoder plus ``segmentation_head`` turns them into a
    mask. Separately, the last feature map is globally average-pooled and fed
    through dropout and a linear classifier. Classifier columns are split into
    ``age`` (column 0), ``view`` (columns 1-3) and ``female`` (column 4).
    """

    config_class = CXRConfig

    def __init__(self, config):
        """Build encoder, decoder, segmentation head and classifier from ``config``.

        NOTE(review): ``config.feature_dim`` must equal the channel count of
        the encoder's last feature map, because ``forward`` pools
        ``features[-1]`` straight into ``self.classifier``. Also,
        ``config.cls_num_classes`` must be at least 5, because ``forward``
        reads classifier columns 0-4.
        """
        super().__init__(config)
        # pretrained=False: no timm weights are downloaded here; presumably the
        # weights come from the checkpoint loaded via `from_pretrained`.
        self.encoder = create_model(
            model_name=config.backbone,
            features_only=True,
            pretrained=False,
            in_chans=config.in_chans,
        )
        self.decoder = UnetDecoder(
            decoder_n_blocks=config.decoder_n_blocks,
            decoder_channels=config.decoder_channels,
            encoder_channels=config.encoder_channels,
            decoder_center_block=config.decoder_center_block,
            decoder_norm_layer=config.decoder_norm_layer,
            decoder_attention_type=config.decoder_attention_type,
        )
        # (height, width); also used by `preprocess` as the resize target.
        self.img_size = config.img_size
        self.segmentation_head = SegmentationHead(
            in_channels=config.decoder_channels[-1],
            out_channels=config.seg_num_classes,
            size=self.img_size,
        )
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=config.cls_dropout)
        self.classifier = nn.Linear(config.feature_dim, config.cls_num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Linearly map pixel values from [0, 255] to [-1, 1]."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)  # -> [0, 1]
        x = (x - 0.5) * 2.0  # -> [-1, 1]
        return x

    @staticmethod
    def load_image_from_dicom(path: str) -> Optional[NDArray]:
        """Read a DICOM file and return its pixels as a uint8 array in [0, 255].

        The modality's VOI LUT / windowing is applied. ``MONOCHROME1`` images
        (where low values are bright) are inverted so that bright means high.
        The result is then min-max scaled to [0, 255]. A constant image maps to
        all zeros.

        Returns:
            The uint8 pixel array, or None if ``pydicom`` is not installed.
        """
        if not _PYDICOM_AVAILABLE:
            print("`pydicom` is not installed, returning None ...")
            return None
        dicom = dcmread(path)
        arr = apply_voi_lut(dicom.pixel_array, dicom)
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            # invert image if needed
            arr = arr.max() - arr

        arr = arr - arr.min()
        peak = arr.max()
        # Guard against a constant image: 0 / 0 would produce NaN, and casting
        # NaN to uint8 is undefined. After min-subtraction such an image is
        # already all zeros, so it is left as is.
        if peak > 0:
            arr = arr / peak
        arr = (arr * 255).astype("uint8")
        return arr

    def preprocess(self, x: NDArray) -> NDArray:
        """Resize a numpy image to ``self.img_size`` (height, width)."""
        x = A.Resize(self.img_size[0], self.img_size[1], p=1)(image=x)["image"]
        return x

    def forward(
        self, x: torch.Tensor, return_logits: bool = False
    ) -> dict[str, torch.Tensor]:
        """Run segmentation and classification on a batch of images.

        Args:
            x: Image batch with raw pixel values in [0, 255]; ``normalize`` is
                applied internally. Presumably shaped (B, in_chans, H, W);
                confirm against the encoder.
            return_logits: If True, return raw outputs with no activations.

        Returns:
            A dict with keys:
            ``"mask"`` — segmentation head output, softmaxed over dim 1 unless
            ``return_logits``.
            ``"age"`` — shape (B, 1); no activation applied in either mode.
            ``"view"`` — shape (B, 3), softmaxed over dim 1 unless
            ``return_logits``.
            ``"female"`` — shape (B, 1), sigmoid applied unless
            ``return_logits``.
        """
        x = self.normalize(x)
        features = self.encoder(x)
        decoder_output = self.decoder(features)
        logits = self.segmentation_head(decoder_output[-1])
        # Classification branch: global-average-pool the deepest feature map
        # from (B, C, h, w) down to (B, C).
        b, n = features[-1].shape[:2]
        features = self.pooling(features[-1]).reshape(b, n)
        features = self.dropout(features)
        cls_logits = self.classifier(features)
        out = {
            "mask": logits,
            "age": cls_logits[:, 0].unsqueeze(1),
            "view": cls_logits[:, 1:4],
            "female": cls_logits[:, 4].unsqueeze(1),
        }
        if return_logits:
            return out
        out["mask"] = out["mask"].softmax(1)
        out["view"] = out["view"].softmax(1)
        out["female"] = out["female"].sigmoid()
        return out