---
library_name: transformers
tags:
- chest_x_ray
- x_ray
- medical_imaging
- radiology
- segmentation
- classification
- lungs
- heart
base_model:
- timm/tf_efficientnetv2_s.in21k_ft_in1k
pipeline_tag: image-segmentation
---

This model performs both segmentation and classification on chest radiographs (X-rays).
The model uses a `tf_efficientnetv2_s` backbone with a U-Net decoder for segmentation and a linear layer for classification.
For frontal radiographs, the model segments the 1) right lung, 2) left lung, and 3) heart.
The model also predicts the chest X-ray view (AP, PA, lateral), patient age, and patient sex.
The [CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/) (small version) and [NIH Chest X-ray](https://nihcc.app.box.com/v/ChestXray-NIHCC) datasets were used to train the model.
Segmentation masks were obtained from the CheXmask [dataset](https://physionet.org/content/chexmask-cxr-segmentation-data/0.4/) ([paper](https://www.nature.com/articles/s41597-024-03358-1)).
The final dataset comprised 335,516 images from 96,385 patients and was split into 80% training/20% validation. A holdout test set was not used since minimal tuning was performed.
The view classifier was trained only on CheXpert images (NIH images were excluded from the view loss), because lateral radiographs are present only in CheXpert.
This avoids an unwanted bias: when one class originates from a single dataset, the model can learn to recognize the dataset rather than the view.
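
Excluding some images from one loss term can be done with a per-sample mask. The following is a minimal NumPy sketch of that idea, not the training code used for this model: the function name `masked_cross_entropy` and the `include` flag are hypothetical, and the actual loss, framework, and weighting are not stated in this card.

```python
import numpy as np

def masked_cross_entropy(logits, labels, include):
    """Mean cross-entropy over the samples where `include` is True.

    logits:  (N, C) array of raw class scores
    labels:  (N,) array of integer class indices
    include: (N,) boolean array; rows set to False add nothing to the loss
    """
    logits = np.asarray(logits, dtype=np.float64)
    labels = np.asarray(labels)
    include = np.asarray(include, dtype=bool)
    # numerically stable log-softmax over the class dimension
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    per_sample = -log_probs[np.arange(len(labels)), labels]
    if not include.any():
        return 0.0  # assumption: a batch with no included samples contributes zero
    return float(per_sample[include].mean())

# Hypothetical batch: rows 0 and 1 come from CheXpert, row 2 from NIH.
view_logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.0, 1.0]])
view_labels = np.array([0, 2, 1])  # 0 = AP, 1 = PA, 2 = lateral
is_chexpert = np.array([True, True, False])
view_loss = masked_cross_entropy(view_logits, view_labels, is_chexpert)
```

With this masking, NIH images can still contribute to the segmentation, age, and sex losses while leaving the view classifier untouched.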

Validation performance is as follows:
```
Segmentation (Dice similarity coefficient):
  Right Lung: 0.957
   Left Lung: 0.948
       Heart: 0.943

Age Prediction:
  Mean Absolute Error: 5.25 years

Classification:
  View (AP, PA, lateral): 99.42% accuracy
  Female: 0.999 AUC
```
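
For reference, the Dice similarity coefficient between a predicted region A and a reference region B is 2|A ∩ B| / (|A| + |B|). Below is a minimal sketch of computing it for one label in a pair of label masks. The helper name `dice_coefficient` is hypothetical, returning 1.0 when both regions are empty is an assumed convention, and this card does not state how the reported scores were aggregated across images.

```python
import numpy as np

def dice_coefficient(pred, target, label):
    """Dice score for one label in two integer masks of equal shape."""
    pred_bin = np.asarray(pred) == label
    target_bin = np.asarray(target) == label
    denom = pred_bin.sum() + target_bin.sum()
    if denom == 0:
        return 1.0  # assumed convention: both regions empty counts as a perfect match
    overlap = np.logical_and(pred_bin, target_bin).sum()
    return float(2.0 * overlap / denom)

# Toy 2x2 masks: label 1 covers two predicted pixels and one reference pixel.
pred = np.array([[1, 1], [0, 0]])
target = np.array([[1, 0], [0, 0]])
score = dice_coefficient(pred, target, label=1)  # 2 * 1 / (2 + 1)
```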

To use the model:
```
import cv2
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)
img = cv2.imread(..., cv2.IMREAD_GRAYSCALE) # read as single-channel grayscale
x = model.preprocess(img) # only takes single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0) # add channel, batch dims
x = x.float()

with torch.inference_mode():
  out = model(x.to(device))
```

The output is a dictionary with 4 keys:
* `mask` has 3 channels containing the segmentation masks. Take the argmax over the channel dimension to create a single image mask (i.e., `out["mask"].argmax(1)`): 1 = right lung, 2 = left lung, 3 = heart.
* `age`, in years.
* `view`, with 3 classes for each possible view. Take the argmax to select the predicted view (i.e., `out["view"].argmax(1)`): 0 = AP, 1 = PA, 2 = lateral.
* `female`, binarize with `out["female"] >= 0.5`. 
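
The decoding steps listed above can be collected into one helper. This is a sketch under assumptions: `decode_outputs` and `VIEW_NAMES` are hypothetical names, the inputs are assumed to have already been converted to NumPy arrays (for example with `{k: v.cpu().numpy() for k, v in out.items()}`), and `age` and `female` are assumed to hold one value per image.

```python
import numpy as np

VIEW_NAMES = ("AP", "PA", "lateral")  # index order as documented above

def decode_outputs(out):
    """Turn raw model outputs (as NumPy arrays) into readable predictions."""
    # argmax over the channel dimension gives one label mask per image
    mask = np.asarray(out["mask"]).argmax(axis=1)
    # argmax over the class dimension gives the predicted view index
    view_idx = np.asarray(out["view"]).argmax(axis=1)
    # assumption: one age and one sex score per image, flattened to shape (N,)
    age = np.asarray(out["age"]).reshape(-1)
    female = np.asarray(out["female"]).reshape(-1) >= 0.5
    return {
        "mask": mask,
        "view": [VIEW_NAMES[i] for i in view_idx],
        "age": age,
        "female": female,
    }
```

Each value in the returned dictionary has one entry per image in the batch, so `decode_outputs(out)["view"][0]` is the predicted view for the first image.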

You can use the segmentation mask to crop the region containing the lungs from the rest of the X-ray.
You can also calculate the [cardiothoracic ratio (CTR)](https://radiopaedia.org/articles/cardiothoracic-ratio?lang=us) using this function:
```
import numpy as np

def calculate_ctr(mask): # single mask with dims (height, width)
    # merge right (1) and left (2) lung labels into one binary lung mask
    lungs = np.isin(mask, [1, 2])
    heart = mask == 3
    # column (x) indices covered by the lungs and by the heart
    _, lung_x = np.where(lungs)
    _, heart_x = np.where(heart)
    # horizontal extent of each structure, in pixels
    lung_range = lung_x.max() - lung_x.min()
    heart_range = heart_x.max() - heart_x.min()
    return heart_range / lung_range
```
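
Cropping the lung region, mentioned above, can be sketched in a similar way. The helper `crop_to_lungs` and its `margin` parameter are hypothetical, and the sketch assumes the image and the mask have the same height and width; if preprocessing changes the image size, the mask would first need to be resized to match.

```python
import numpy as np

def crop_to_lungs(img, mask, margin=0):
    """Crop `img` to the bounding box of the lung labels (1 and 2) in `mask`."""
    lungs = np.isin(mask, [1, 2])
    if not lungs.any():
        return img  # no lung pixels found, so return the image unchanged
    ys, xs = np.where(lungs)
    # expand the box by `margin` pixels and clip it to the image bounds
    y0 = max(int(ys.min()) - margin, 0)
    y1 = min(int(ys.max()) + 1 + margin, mask.shape[0])
    x0 = max(int(xs.min()) - margin, 0)
    x1 = min(int(xs.max()) + 1 + margin, mask.shape[1])
    return img[y0:y1, x0:x1]
```

A small positive `margin` keeps some context around the lungs, which may help if the segmentation slightly underestimates the lung borders.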

If you have `pydicom` installed, you can also load a DICOM image directly:
```
img = model.load_image_from_dicom(path_to_dicom)
```
 
This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use. 
The user assumes any and all responsibility regarding their own use of this model and its outputs.