This is an ensemble model for predicting breast cancer and breast density from screening mammography. The model uses 3 CNNs (each with a `tf_efficientnetv2_s` backbone) and performs inference on each provided image (i.e., the CC and MLO views). Each net in the ensemble uses a different input resolution: 2048 x 1024, 1920 x 1280, and 1536 x 1536. The final outputs are averaged across the provided views and the neural nets. The model can also perform inference on a single view (image), although performance will be reduced.
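Conceptually, the ensembling amounts to averaging the per-net, per-view predictions; a minimal sketch is shown below (illustrative only; the released model performs this averaging internally in its `forward` pass, and the tensor here is hypothetical):

```python
import torch

# Hypothetical stacked predictions: (num_nets=3, num_views=2, batch size N=4, 1 output)
per_net_per_view_probs = torch.rand(3, 2, 4, 1)

# Average over nets and views to obtain the final prediction of shape (N, 1)
final_probs = per_net_per_view_probs.mean(dim=(0, 1))
```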
A hybrid classification-segmentation model was first pretrained on the Curated Breast Imaging Subset of Digital Database for Screening Mammography (CBIS-DDSM). This dataset contains film mammography studies (as opposed to digital) with accompanying ROI annotations for benign and malignant masses and calcifications.
The resultant model was further trained on data from the RSNA Screening Mammography Breast Cancer Detection challenge. The data was split into 80%/10%/10% train/val/test. Evaluation was performed on the 10% holdout test split. This procedure was repeated 3 separate times to better assess the model's performance. The provided weights are from the first data split.
Exponential moving averaging (EMA) of the model weights was used during training and improved performance.
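A minimal sketch of weight EMA is shown below (illustrative only; the decay value and update schedule are assumptions, not the exact recipe used to train this model):

```python
import copy
import torch

def make_ema(model):
    # Keep a frozen copy of the model whose weights track an exponential moving average
    ema = copy.deepcopy(model).eval()
    for p in ema.parameters():
        p.requires_grad_(False)
    return ema

@torch.no_grad()
def update_ema(ema, model, decay=0.999):
    # ema_w = decay * ema_w + (1 - decay) * w, typically called after each optimizer step
    for ema_p, p in zip(ema.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p, alpha=1 - decay)
```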
Note that the model was trained using cropped images, and thus it is recommended to crop the image prior to inference. A cropping model is provided here: https://huggingface.co/ianpan/mammo-crop
The primary evaluation metric is the area under the receiver operating characteristic curve (AUC/AUROC). The AUC on each holdout test split, along with the mean and standard deviation across the 3 splits, is shown below:

- Split 1: 0.9464
- Split 2: 0.9467
- Split 3: 0.9422
- Mean (std.): 0.9451 (0.002)
As this is a screening test, high sensitivity is desirable. We also calculate the specificity at varying sensitivities, shown below (averaged across the 3 splits):

- Sensitivity: 98.1%, Specificity: 65.4% +/- 7.2%, Threshold: 0.0072 +/- 0.0021
- Sensitivity: 94.3%, Specificity: 78.7% +/- 0.9%, Threshold: 0.0127 +/- 0.0011
- Sensitivity: 90.5%, Specificity: 84.8% +/- 2.7%, Threshold: 0.0184 +/- 0.0027
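Operating points like these can be derived from ROC analysis on a labeled validation set. A minimal sketch, assuming scikit-learn (illustrative; not the exact evaluation code used to produce the numbers above):

```python
import numpy as np
from sklearn.metrics import roc_curve

def specificity_at_sensitivity(y_true, y_prob, target_sensitivity):
    # y_true: binary cancer labels, y_prob: predicted cancer probabilities
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    idx = int(np.argmax(tpr >= target_sensitivity))  # first operating point reaching the target sensitivity
    return 1 - fpr[idx], thresholds[idx]             # (specificity, decision threshold)
```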
Example usage:
```python
import cv2
import torch
from transformers import AutoModel


def crop_mammo(img, model, device):
    # Predict a bounding box with the cropping model and crop the breast region
    img_shape = torch.tensor([img.shape[:2]]).to(device)
    x = model.preprocess(img)
    x = torch.from_numpy(x).expand(1, 1, -1, -1).float().to(device)
    with torch.inference_mode():
        coords = model(x, img_shape)
    coords = coords[0].cpu().numpy()
    x, y, w, h = coords
    return img[y: y + h, x: x + w]


device = "cuda:0"

# Cropping model (see https://huggingface.co/ianpan/mammo-crop)
crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)

# Screening model
model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
model = model.eval().to(device)

# Load and crop each view
cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
mlo_img = cv2.imread("mammo_mlo.png", cv2.IMREAD_GRAYSCALE)

cc_img = crop_mammo(cc_img, crop_model, device)
mlo_img = crop_mammo(mlo_img, crop_model, device)

with torch.inference_mode():
    output = model({"cc": cc_img, "mlo": mlo_img}, device=device)
```
Note that the model preprocesses the data within the `forward` function into the necessary format. `output` is a dictionary containing two keys: `cancer` and `density`. `output['cancer']` is a tensor of shape (N, 1) and `output['density']` is a tensor of shape (N, 4). If you want the predicted density class, take the argmax: `output['density'].argmax(1)`. If only a single study is provided, then N = 1.
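For example, a minimal sketch of reading out the predictions and applying one of the operating points reported above (the 0.0184 threshold corresponds to the ~90.5% sensitivity operating point; which threshold to use depends on the desired sensitivity/specificity trade-off):

```python
cancer_prob = output["cancer"][0].item()                 # predicted probability of cancer
density_class = output["density"].argmax(1)[0].item()    # predicted density class index (0-3)

# Flag the breast as positive at the ~90.5% sensitivity operating point (threshold from above)
flagged = cancer_prob >= 0.0184
```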
You can also access each neural net separately using `model.net{i}`. However, you must apply the preprocessing outside of the `forward` function:
```python
input_dict = model.net0.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
with torch.inference_mode():
    out = model.net0(input_dict)
```
The model also supports batch inference. Construct a dictionary for each breast and pass a list of dictionaries to the model. For example, if you want to perform inference on each breast for 2 patients (`pt1`, `pt2`):
```python
cc_images = ["rt_pt1_cc.png", "lt_pt1_cc.png", "rt_pt2_cc.png", "lt_pt2_cc.png"]
mlo_images = ["rt_pt1_mlo.png", "lt_pt1_mlo.png", "rt_pt2_mlo.png", "lt_pt2_mlo.png"]

# Load and crop each view
cc_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in cc_images]
mlo_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in mlo_images]

cc_images = [crop_mammo(_, crop_model, device) for _ in cc_images]
mlo_images = [crop_mammo(_, crop_model, device) for _ in mlo_images]

# One dictionary per breast
input_dict = [{"cc": cc_img, "mlo": mlo_img} for cc_img, mlo_img in zip(cc_images, mlo_images)]

with torch.inference_mode():
    output = model(input_dict, device=device)
```
Note that if you are converting images from DICOM to 8-bit PNG/JPEG, it is important to apply the lookup table to the pixel values, which can be done using `pydicom.pixels.apply_voi_lut`.
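A minimal sketch of such a conversion is shown below (assuming pydicom >= 3.0 for `pydicom.pixels.apply_voi_lut`; the file path, min-max rescaling, and MONOCHROME1 inversion are illustrative choices rather than the model's exact preprocessing):

```python
import numpy as np
import pydicom
from pydicom.pixels import apply_voi_lut

ds = pydicom.dcmread("mammo_cc.dcm")      # hypothetical DICOM path
arr = apply_voi_lut(ds.pixel_array, ds)   # apply the VOI LUT / windowing from the header
arr = arr.astype(np.float32)
arr = (arr - arr.min()) / (arr.max() - arr.min() + 1e-8)  # rescale to [0, 1]
if ds.get("PhotometricInterpretation") == "MONOCHROME1":
    arr = 1.0 - arr                       # invert so that higher values are brighter
img_8bit = (arr * 255).astype(np.uint8)   # 8-bit image ready for crop_mammo / the model
```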
If you have `pydicom` installed, you can also load a DICOM image directly, which handles the proper 8-bit conversion for you:

```python
img = model.load_image_from_dicom(path_to_dicom)
```
Base model: `timm/tf_efficientnetv2_s.in21k_ft_in1k`