---
library_name: transformers
tags:
- mammography
- cancer
- breast_cancer
- radiology
- breast_density
license: apache-2.0
base_model:
- timm/tf_efficientnetv2_s.in21k_ft_in1k
pipeline_tag: image-classification
---

This is an ensemble model for predicting breast cancer and breast density from screening mammography.
The model uses 3 basic CNNs (`tf_efficientnetv2_s` backbone) and performs inference on each provided image (e.g., the CC and MLO views).
Each net in the ensemble uses a different input resolution: 2048 x 1024, 1920 x 1280, and 1536 x 1536.
The final outputs are averaged across the provided views and the neural nets.
The model can also perform inference on a single view (image), although performance is expected to be lower.
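The averaging step can be pictured as below. This is a minimal sketch, not the model's internal code: the card does not say whether averaging is done on probabilities or logits, so the hypothetical helper `average_predictions` simply averages whatever equal-length score vectors it receives (a 1-element cancer score or the 4 density scores), indexed as `preds[net][view]`.

```python
def average_predictions(preds):
    """Average a nested list preds[net][view] of equal-length score vectors
    into a single vector (element-wise mean over every net/view pair)."""
    vectors = [vec for per_net in preds for vec in per_net]
    if not vectors:
        raise ValueError("no predictions to average")
    n = len(vectors)
    # zip(*vectors) groups the i-th score from every net/view together
    return [sum(col) / n for col in zip(*vectors)]


# 3 nets x 2 views (CC, MLO), one cancer score each
cancer = average_predictions([[[0.2], [0.4]], [[0.6], [0.8]], [[0.1], [0.3]]])
```

With a single view, the same function averages over the nets only, which is one way to see why the single-view setting has less information to pool.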

A hybrid classification-segmentation model was first pretrained on the Curated Breast Imaging Subset of the Digital Database for Screening Mammography
[(CBIS-DDSM)](https://www.cancerimagingarchive.net/collection/cbis-ddsm/). This dataset contains film mammography studies
(as opposed to digital) with accompanying ROI annotations for benign and malignant masses and calcifications.

The resulting model was then further trained on data from the [RSNA Screening Mammography Breast Cancer Detection challenge](https://www.kaggle.com/competitions/rsna-breast-cancer-detection/).
The data was split into 80%/10%/10% train/val/test, and evaluation was performed on the 10% holdout test split.
This procedure was repeated with 3 separate splits to better assess the model's performance.
The provided weights are from the first data split.
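The card does not say whether the splits were made per patient, per study, or per image, or whether they were stratified. As an illustration only, the sketch below (hypothetical helper `split_ids`) shuffles unique patient IDs with a different seed for each repetition and cuts them 80%/10%/10%. Splitting by patient is an assumption here, chosen because it keeps images of the same patient from landing in both train and test.

```python
import random


def split_ids(ids, seed, train_frac=0.8, val_frac=0.1):
    """Shuffle unique IDs with a fixed seed and split into train/val/test.
    The test split receives whatever remains after train and val."""
    unique_ids = sorted(set(ids))      # sort first so the result depends only on the seed
    rng = random.Random(seed)          # independent RNG; does not touch global state
    rng.shuffle(unique_ids)
    n_train = round(train_frac * len(unique_ids))
    n_val = round(val_frac * len(unique_ids))
    train = unique_ids[:n_train]
    val = unique_ids[n_train:n_train + n_val]
    test = unique_ids[n_train + n_val:]
    return train, val, test


# Three repetitions, one seed each (hypothetical patient IDs)
patients = [f"patient{i}" for i in range(100)]
splits = [split_ids(patients, seed) for seed in range(3)]
```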

An exponential moving average (EMA) was used during training and improved performance.
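The card does not give EMA details (what exactly was averaged or the decay rate). A common form keeps a running average of the model weights that is updated after every optimizer step and used for evaluation. The sketch below assumes that form, with a hypothetical decay of 0.999 and plain Python floats standing in for weight tensors.

```python
def ema_update(ema_params, params, decay=0.999):
    """Blend the current weights into the running average, in place:
    ema = decay * ema + (1 - decay) * current."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


# The EMA copy usually starts as a copy of the initial weights
ema = {"w": 0.0}
ema_update(ema, {"w": 1.0}, decay=0.9)  # ema["w"] is now ~0.1
ema_update(ema, {"w": 1.0}, decay=0.9)  # ema["w"] is now ~0.19
```

A higher decay makes the averaged weights change more slowly, smoothing out step-to-step noise in training.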

Note that the model was trained using cropped images, and thus it is recommended to crop the image prior to inference.
A cropping model is provided here: https://huggingface.co/ianpan/mammo-crop

The primary evaluation metric is the area under the receiver operating characteristic curve (AUC/AUROC).
Below are the AUCs for each of the 3 splits, along with their mean and standard deviation.

```
Split 1: 0.9464
Split 2: 0.9467
Split 3: 0.9422

Mean (std.): 0.9451 (0.002)

```
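The summary line can be reproduced from the per-split values. The reported 0.002 corresponds to the population standard deviation (dividing by n); the sample standard deviation (dividing by n - 1) of the same three values rounds to 0.003.

```python
from statistics import mean, pstdev, stdev

split_aucs = [0.9464, 0.9467, 0.9422]  # Split 1, 2, 3 from the table above

auc_mean = mean(split_aucs)
auc_pstd = pstdev(split_aucs)  # population standard deviation (divides by n)
auc_sstd = stdev(split_aucs)   # sample standard deviation (divides by n - 1)

print(f"Mean (std.): {auc_mean:.4f} ({auc_pstd:.3f})")  # Mean (std.): 0.9451 (0.002)
print(f"Sample std.: {auc_sstd:.3f}")                    # Sample std.: 0.003
```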

As this is a screening test, high sensitivity is desirable. We also calculate the specificity
at several fixed sensitivity levels, along with the corresponding score thresholds (averaged across the 3 splits):

```
Sensitivity: 98.1%, Specificity: 65.4% +/- 7.2%, Threshold: 0.0072 +/- 0.0021
Sensitivity: 94.3%, Specificity: 78.7% +/- 0.9%, Threshold: 0.0127 +/- 0.0011
Sensitivity: 90.5%, Specificity: 84.8% +/- 2.7%, Threshold: 0.0184 +/- 0.0027
```
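The card does not describe how these thresholds were selected. A common approach, sketched below under that assumption, is to pick the highest score threshold that still reaches the target sensitivity on labeled data and then measure specificity at that threshold. The helper name `specificity_at_sensitivity` is hypothetical.

```python
import math


def specificity_at_sensitivity(scores, labels, target_sensitivity):
    """Return (threshold, sensitivity, specificity) for the highest threshold whose
    sensitivity is at least target_sensitivity. A case is called positive when
    score >= threshold; labels are 1 (cancer) or 0 (no cancer)."""
    positives = sorted((s for s, y in zip(scores, labels) if y == 1), reverse=True)
    negatives = [s for s, y in zip(scores, labels) if y == 0]
    if not positives or not negatives:
        raise ValueError("need at least one positive and one negative case")
    # Smallest number of positives that must be detected; the epsilon guards
    # against floating-point error such as 0.7 * 10 = 7.000000000000001.
    needed = math.ceil(target_sensitivity * len(positives) - 1e-9)
    needed = min(max(needed, 1), len(positives))
    # Setting the threshold at the needed-th highest positive score
    # detects at least that many positives.
    threshold = positives[needed - 1]
    tp = sum(s >= threshold for s in positives)
    tn = sum(s < threshold for s in negatives)
    return threshold, tp / len(positives), tn / len(negatives)
```

Because screening favors sensitivity, lowering the target sensitivity raises the threshold and trades missed cancers for fewer false positives, which is the pattern visible in the table.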

Example usage:

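```
import cv2
import torch
from transformers import AutoModel

def crop_mammo(img, model, device):
  # Original (height, width) of the image, passed to the crop model with the preprocessed input
  img_shape = torch.tensor([img.shape[:2]]).to(device)
  x = model.preprocess(img)
  x = torch.from_numpy(x).expand(1, 1, -1, -1).float().to(device)
  with torch.inference_mode():
    coords = model(x, img_shape)
  # Bounding box returned as (x, y, w, h) in the original image's pixel coordinates
  coords = coords[0].cpu().numpy()
  x, y, w, h = coords
  return img[y: y + h, x: x + w]

# Fall back to CPU if no GPU is available (inference will be slower)
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True is needed because the model code lives in the Hub repositories
crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)

model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
model = model.eval().to(device)

# Load both views of one breast as single-channel grayscale images
cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
mlo_img = cv2.imread("mammo_mlo.png", cv2.IMREAD_GRAYSCALE)

# The model was trained on cropped images, so crop before inference
cc_img = crop_mammo(cc_img, crop_model, device)
mlo_img = crop_mammo(mlo_img, crop_model, device)

with torch.inference_mode():
  output = model({"cc": cc_img, "mlo": mlo_img}, device=device)
```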

Note that the model preprocesses the data into the necessary format within its `forward` function, so the cropped grayscale images can be passed in directly.
`output` is a dictionary containing two keys: `cancer` and `density`. `output['cancer']` is a tensor of shape (N, 1) and `output['density']` is a tensor of shape (N, 4).
To get the predicted density class, take the argmax: `output['density'].argmax(1)`. If only a single study is provided, then N=1.
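The sketch below turns the two outputs into per-study results. It makes two assumptions not stated in the card: that the 4 density columns are ordered as BI-RADS categories A, B, C, D, and that `output['cancer']` holds scores on the same scale as the thresholds in the sensitivity table. The helper `summarize_predictions` is hypothetical and works on plain lists, so tensors are converted with `.tolist()` first.

```python
# Assumed mapping from density column index to BI-RADS density category
DENSITY_LABELS = ["A", "B", "C", "D"]


def summarize_predictions(cancer_scores, density_scores, threshold):
    """cancer_scores: one float per study; density_scores: one list of 4 floats per study."""
    results = []
    for cancer, density in zip(cancer_scores, density_scores):
        # Index of the largest density score, equivalent to argmax over dim 1
        best = max(range(len(density)), key=density.__getitem__)
        results.append({
            "cancer_score": cancer,
            "flagged": cancer >= threshold,  # True if at or above the chosen operating point
            "density": DENSITY_LABELS[best],
        })
    return results
```

Usage, with a threshold taken from the table purely as an illustration: `summarize_predictions(output["cancer"][:, 0].tolist(), output["density"].tolist(), threshold=0.0127)`. Since the table reports thresholds averaged across splits, an operating point would ideally be checked on your own data.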

You can also access each neural net separately using `model.net{i}` (e.g., `model.net0`). In that case, preprocessing is not performed inside `forward`, so you must call the net's `preprocess` method yourself:
```
input_dict = model.net0.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
with torch.inference_mode():
  out = model.net0(input_dict) 
```

The model also supports batch inference: construct a dictionary for each breast and pass a list of these dictionaries to the model.
For example, to perform inference on each breast of 2 patients (`pt1`, `pt2`):

```
cc_images = ["rt_pt1_cc.png", "lt_pt1_cc.png", "rt_pt2_cc.png", "lt_pt2_cc.png"]
mlo_images = ["rt_pt1_mlo.png", "lt_pt1_mlo.png", "rt_pt2_mlo.png", "lt_pt2_mlo.png"]

# Load each view as grayscale, then crop it
cc_images = [cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in cc_images]
mlo_images = [cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in mlo_images]

cc_images = [crop_mammo(img, crop_model, device) for img in cc_images]
mlo_images = [crop_mammo(img, crop_model, device) for img in mlo_images]

# One dictionary per breast; the CC and MLO lists must be in the same order
input_dict = [{"cc": cc_img, "mlo": mlo_img} for cc_img, mlo_img in zip(cc_images, mlo_images)]
with torch.inference_mode():
  output = model(input_dict, device=device)
```
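If your files follow a naming pattern like the one in the example (`<side>_<patient>_<view>.png`, a hypothetical convention), the CC and MLO path lists can be built automatically instead of by hand. The sketch below (hypothetical helper `pair_views`) groups files by patient and side and keeps only breasts that have both views, so the two lists stay aligned.

```python
import re
from collections import defaultdict

# Hypothetical convention from the example: "<side>_<patient>_<view>.png"
FILENAME_RE = re.compile(r"^(?P<side>rt|lt)_(?P<patient>[^_]+)_(?P<view>cc|mlo)\.png$")


def pair_views(filenames):
    """Return [((patient, side), cc_path, mlo_path), ...] for breasts with both views."""
    groups = defaultdict(dict)
    for name in filenames:
        match = FILENAME_RE.match(name)
        if match is None:
            continue  # skip files that do not follow the convention
        groups[(match["patient"], match["side"])][match["view"]] = name
    # Sorted for a stable order; drop breasts missing either view
    return [
        (key, views["cc"], views["mlo"])
        for key, views in sorted(groups.items())
        if "cc" in views and "mlo" in views
    ]
```

The resulting CC and MLO paths can then be read with `cv2.imread` and cropped with `crop_mammo` exactly as in the example before building `input_dict`.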

Note that if you are converting images from DICOM to 8-bit PNG/JPEG, it is important to apply the VOI lookup table (VOI LUT) to the pixel values, which can be done using `pydicom.pixels.apply_voi_lut`.
If you have `pydicom` installed, you can also load a DICOM image directly with the model's helper method, which handles the proper 8-bit conversion for you:
```
img = model.load_image_from_dicom(path_to_dicom)
```
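If you prefer to convert DICOM files yourself, the sketch below shows one possible approach. It is not what `load_image_from_dicom` does internally: the helper names `window_to_uint8` and `load_dicom_as_uint8` are hypothetical, and the min-max rescaling after the VOI LUT is an assumed choice. It does handle one DICOM detail worth knowing: `MONOCHROME1` images store inverted intensities (low values are bright), so they are flipped to match the usual `MONOCHROME2` appearance.

```python
import numpy as np


def window_to_uint8(arr, photometric="MONOCHROME2"):
    """Min-max rescale a (VOI-LUT-applied) pixel array to uint8 in [0, 255]."""
    arr = arr.astype(np.float32)
    lo, hi = float(arr.min()), float(arr.max())
    if hi > lo:
        arr = (arr - lo) / (hi - lo)
    else:
        arr = np.zeros_like(arr)  # constant image: avoid division by zero
    if photometric == "MONOCHROME1":
        arr = 1.0 - arr  # invert so that dense tissue appears bright
    return np.round(arr * 255.0).astype(np.uint8)


def load_dicom_as_uint8(path):
    # Imported lazily so the function above works without pydicom installed.
    # pydicom.pixels.apply_voi_lut requires pydicom >= 3.0.
    import pydicom
    from pydicom.pixels import apply_voi_lut

    ds = pydicom.dcmread(path)
    arr = apply_voi_lut(ds.pixel_array, ds)
    return window_to_uint8(arr, ds.PhotometricInterpretation)
```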