# chest-x-ray-basic / modeling.py
# Repository page metadata: "ianpan's picture" | "Upload model" | commit cfcee64 (verified)
import albumentations as A
import torch
import torch.nn as nn
from numpy.typing import NDArray
from transformers import PreTrainedModel
from timm import create_model
from typing import Optional
from .configuration import CXRConfig
from .unet import UnetDecoder, SegmentationHead
# Optional dependency: pydicom is only needed by `CXRModel.load_image_from_dicom`.
# `_PYDICOM_AVAILABLE` records whether `dcmread` and `apply_voi_lut` were imported.
_PYDICOM_AVAILABLE = False
try:
    from pydicom import dcmread

    try:
        # Location of `apply_voi_lut` in newer pydicom releases.
        from pydicom.pixels import apply_voi_lut
    except ImportError:
        # Older pydicom releases have no `pydicom.pixels`; without this
        # fallback an installed pydicom would be reported as missing.
        from pydicom.pixel_data_handlers.util import apply_voi_lut
except ImportError:
    # pydicom is absent (or unusable); DICOM loading stays disabled.
    pass
else:
    _PYDICOM_AVAILABLE = True
class CXRModel(PreTrainedModel):
    """Chest X-ray network with one encoder shared by two branches.

    The encoder's feature maps are passed to a ``UnetDecoder`` followed by a
    ``SegmentationHead`` (the ``mask`` output). The last encoder feature map is
    also average-pooled and fed to a linear layer whose outputs are split into
    the ``age``, ``view`` and ``female`` predictions.
    """

    config_class = CXRConfig

    def __init__(self, config):
        """Build encoder, decoder, segmentation head and classifier from ``config``."""
        super().__init__(config)
        # Backbone is created without pretrained weights and returns a list
        # of feature maps (`features_only=True`).
        self.encoder = create_model(
            model_name=config.backbone,
            features_only=True,
            pretrained=False,
            in_chans=config.in_chans,
        )
        self.decoder = UnetDecoder(
            decoder_n_blocks=config.decoder_n_blocks,
            decoder_channels=config.decoder_channels,
            encoder_channels=config.encoder_channels,
            decoder_center_block=config.decoder_center_block,
            decoder_norm_layer=config.decoder_norm_layer,
            decoder_attention_type=config.decoder_attention_type,
        )
        self.img_size = config.img_size
        # NOTE(review): `size` presumably sets the spatial size of the mask
        # output -- confirm against `SegmentationHead`.
        self.segmentation_head = SegmentationHead(
            in_channels=config.decoder_channels[-1],
            out_channels=config.seg_num_classes,
            size=self.img_size,
        )
        self.pooling = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=config.cls_dropout)
        self.classifier = nn.Linear(config.feature_dim, config.cls_num_classes)

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Linearly rescale pixel values from [0, 255] to [-1, 1]."""
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)  # [0, 255] -> [0, 1]
        x = (x - 0.5) * 2.0  # [0, 1] -> [-1, 1]
        return x

    @staticmethod
    def load_image_from_dicom(path: str) -> Optional[NDArray]:
        """Read a DICOM file and return its pixels as a ``uint8`` array in [0, 255].

        The VOI LUT is applied, ``MONOCHROME1`` images are inverted, and the
        result is min-max scaled. If ``pydicom`` could not be imported, a
        message is printed and ``None`` is returned.

        Args:
            path: Location of the DICOM file.

        Returns:
            The scaled ``uint8`` pixel array, or ``None`` when ``pydicom`` is
            unavailable. An image with a single constant value yields all zeros.
        """
        if not _PYDICOM_AVAILABLE:
            print("`pydicom` is not installed, returning None ...")
            return None
        dicom = dcmread(path)
        # Work in float64: the arithmetic below can wrap around when done in
        # the image's native (possibly signed) integer dtype.
        arr = apply_voi_lut(dicom.pixel_array, dicom).astype("float64")
        if dicom.PhotometricInterpretation == "MONOCHROME1":
            # invert image if needed
            arr = arr.max() - arr
        arr = arr - arr.min()
        peak = arr.max()
        if peak > 0:
            # Skip the division for constant images to avoid 0 / 0 -> NaN.
            arr = arr / peak
        return (arr * 255).astype("uint8")

    def preprocess(self, x: NDArray) -> NDArray:
        """Resize an image array to ``self.img_size`` using albumentations.

        ``self.img_size[0]`` and ``self.img_size[1]`` are passed to
        ``A.Resize`` as its first two positional arguments (height, width).
        """
        resize = A.Resize(self.img_size[0], self.img_size[1], p=1)
        return resize(image=x)["image"]

    def forward(
        self, x: torch.Tensor, return_logits: bool = False
    ) -> dict[str, torch.Tensor]:
        """Run segmentation and classification on a batch of images.

        Args:
            x: Input batch with pixel values in [0, 255]; it is normalized to
                [-1, 1] internally. Presumably shaped (batch, in_chans, H, W)
                -- confirm against the encoder.
            return_logits: If True, return every output as raw logits.

        Returns:
            A dict with keys:
                ``"mask"``: segmentation output (softmax over dim 1 unless
                    ``return_logits``).
                ``"age"``: classifier column 0, shape (batch, 1); never passed
                    through an activation.
                ``"view"``: classifier columns 1-3 (softmax over dim 1 unless
                    ``return_logits``).
                ``"female"``: classifier column 4, shape (batch, 1) (sigmoid
                    unless ``return_logits``).
        """
        x = self.normalize(x)
        feature_maps = self.encoder(x)
        decoder_output = self.decoder(feature_maps)
        mask_logits = self.segmentation_head(decoder_output[-1])

        # Global-average-pool the last encoder feature map to (batch, channels).
        last_map = feature_maps[-1]
        b, n = last_map.shape[:2]
        pooled = self.pooling(last_map).reshape(b, n)
        pooled = self.dropout(pooled)
        cls_logits = self.classifier(pooled)

        # The fixed column layout below needs at least 5 classifier outputs.
        out = {
            "mask": mask_logits,
            "age": cls_logits[:, 0].unsqueeze(1),
            "view": cls_logits[:, 1:4],
            "female": cls_logits[:, 4].unsqueeze(1),
        }
        if return_logits:
            return out
        out["mask"] = out["mask"].softmax(1)
        out["view"] = out["view"].softmax(1)
        out["female"] = out["female"].sigmoid()
        return out