import albumentations as A
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

from numpy.typing import NDArray
from transformers import PreTrainedModel
from timm import create_model
from typing import Optional
from .configuration import BoneAgeConfig


class GeM(nn.Module):
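    """Generalized mean (GeM) pooling over the spatial (and optionally temporal) dimensions.

    The exponent `p` is learnable: p = 1 recovers average pooling, and larger p
    approaches max pooling.
    """
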
    def __init__(
        self, p: int = 3, eps: float = 1e-6, dim: int = 2, flatten: bool = True
    ):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
        assert dim in {2, 3}, f"dim must be one of [2, 3], not {dim}"
        self.dim = dim
        if self.dim == 2:
            self.func = F.adaptive_avg_pool2d
        elif self.dim == 3:
            self.func = F.adaptive_avg_pool3d
        self.flatten = nn.Flatten(1) if flatten else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # assumes x.shape is (n, c, [t], h, w)
        x = self.func(x.clamp(min=self.eps).pow(self.p), output_size=1).pow(
            1.0 / self.p
        )
        return self.flatten(x)


class BoneAgeModel(nn.Module):
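    """Single bone age model: a timm backbone, GeM pooling, dropout, and a linear
    head over `num_classes` discrete age bins. `forward` appends a sex channel to
    the grayscale input, which is why `in_chans` defaults to 2.
    """
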
    def __init__(
        self, backbone, feature_dim=768, dropout=0.1, num_classes=240, in_chans=2
    ):
        super().__init__()
        self.backbone = create_model(
            model_name=backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=False,
            in_chans=in_chans,
        )
        self.pooling = GeM(p=3, dim=2)
        self.dropout = nn.Dropout(p=dropout)
        self.linear = nn.Linear(feature_dim, num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # [0, 255] -> [-1, 1]
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    def forward(
        self, x: torch.Tensor, female: torch.Tensor, return_logits: bool = False
    ) -> torch.Tensor:
        assert x.size(0) == female.size(
            0
        ), f"x.size(0) [{x.size(0)}] must equal female.size(0) [{female.size(0)}]"
        # encode sex as an extra channel: 255 everywhere for female samples, 0 otherwise
        female_ch = torch.zeros_like(x)
        female_ch[female.bool()] = 255.0
        x = torch.cat([x, female_ch], dim=1)
        x = self.normalize(x)
        features = self.pooling(self.backbone(x))
        logits = self.linear(features)
        if return_logits:
            return logits
        # regression-by-classification: return the expected value of the softmax
        # distribution over the class indices
        out = (logits.softmax(1) * torch.arange(logits.size(1)).to(logits.device)).sum(
            1
        )
        return out


class BoneAgeEnsembleModel(PreTrainedModel):
    config_class = BoneAgeConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_models = config.num_models
        for i in range(self.num_models):
            setattr(
                self,
                f"net{i}",
                BoneAgeModel(
                    config.backbone,
                    config.feature_dim,
                    config.dropout,
                    config.num_classes,
                    config.in_chans,
                ),
            )

    @staticmethod
    def load_image_from_dicom(path: str) -> Optional[NDArray]:
        try:
            from pydicom import dcmread
            from pydicom.pixels import apply_voi_lut
        except ModuleNotFoundError:
            print("`pydicom` is not installed, returning None ...")
            return None
        dicom = dcmread(path)
        arr = apply_voi_lut(dicom.pixel_array, dicom)
        # MONOCHROME1 maps low pixel values to white, so invert to the usual convention
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            arr = arr.max() - arr
        # min-max rescale to uint8 [0, 255]
        arr = arr - arr.min()
        arr = arr / arr.max()
        arr = (arr * 255).astype("uint8")
        return arr

    @staticmethod
    def preprocess(x: NDArray) -> NDArray:
        # resize so the longest side is 512 (preserving aspect ratio), then
        # zero-pad to a square 512 x 512 image
        x = A.LongestMaxSize(max_size=512, p=1)(image=x)["image"]
        x = A.PadIfNeeded(512, 512, border_mode=cv2.BORDER_CONSTANT, p=1)(image=x)[
            "image"
        ]
        return x

    def forward(
        self, x: torch.Tensor, female: torch.Tensor, return_logits: bool = False
    ) -> torch.Tensor:
        out = []
        for i in range(self.num_models):
            model = getattr(self, f"net{i}")
            out.append(model(x, female, return_logits))
        # average the member outputs (logits or expected ages) across the ensemble
        return torch.stack(out).mean(0)
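

if __name__ == "__main__":
    # Minimal usage sketch, not part of the model definition. It assumes a saved
    # checkpoint directory (or hub repo id) and a DICOM path are passed on the
    # command line; both are deployment-specific placeholders, not fixed names.
    import sys

    checkpoint, dicom_path = sys.argv[1], sys.argv[2]
    model = BoneAgeEnsembleModel.from_pretrained(checkpoint).eval()

    arr = BoneAgeEnsembleModel.load_image_from_dicom(dicom_path)  # uint8 grayscale
    arr = BoneAgeEnsembleModel.preprocess(arr)                    # 512 x 512
    x = torch.from_numpy(arr).float().unsqueeze(0).unsqueeze(0)   # (1, 1, 512, 512)
    female = torch.tensor([1])  # 1 = female, 0 = male

    with torch.no_grad():
        pred = model(x, female)  # expected bone age, in class-index units
    print(float(pred))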