This is an ensemble model for predicting breast cancer and breast density from screening mammography. The ensemble consists of 3 CNNs (each with a tf_efficientnetv2_s backbone) and performs inference on each provided image (i.e., the CC and MLO views). Each net uses a different input resolution: 2048 x 1024, 1920 x 1280, and 1536 x 1536. The final outputs are averaged across the provided views and the nets. The model can also perform inference on a single view (image), although performance will be reduced.

A hybrid classification-segmentation model was first pretrained on the Curated Breast Imaging Subset of Digital Database for Screening Mammography (CBIS-DDSM). This dataset contains film mammography studies (as opposed to digital) with accompanying ROI annotations for benign and malignant masses and calcifications.

The resultant model was further trained on data from the RSNA Screening Mammography Breast Cancer Detection challenge. The data was split into 80%/10%/10% train/val/test. Evaluation was performed on the 10% holdout test split. This procedure was repeated 3 separate times to better assess the model's performance. The provided weights are from the first data split.

Exponential moving averaging (EMA) was used during training and improved performance.
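
For context, a minimal sketch of how weight EMA is commonly implemented; this is a generic illustration, not the training code used for this model, and the update_ema helper and decay value are hypothetical:

import copy
import torch

def update_ema(ema_model, model, decay=0.999):
  # w_ema <- decay * w_ema + (1 - decay) * w, applied after each optimizer step
  with torch.no_grad():
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
      ema_p.mul_(decay).add_(p, alpha=1 - decay)

# ema_model = copy.deepcopy(model)  # created once before training begins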

Note that the model was trained using cropped images, and thus it is recommended to crop the image prior to inference. A cropping model is provided here: https://huggingface.co/ianpan/mammo-crop

The primary evaluation metric is the area under the receiver operating characteristic curve (AUC/AUROC). The per-split AUCs and the mean and standard deviation across the 3 splits are shown below.

Split 1: 0.9464
Split 2: 0.9467
Split 3: 0.9422

Mean (std.): 0.9451 (0.002)

As this is a screening test, high sensitivity is desirable. We also calculate the specificity at varying sensitivities, shown below (averaged across 3 splits):

Sensitivity: 98.1%, Specificity: 65.4% +/- 7.2%, Threshold: 0.0072 +/- 0.0021
Sensitivity: 94.3%, Specificity: 78.7% +/- 0.9%, Threshold: 0.0127 +/- 0.0011
Sensitivity: 90.5%, Specificity: 84.8% +/- 2.7%, Threshold: 0.0184 +/- 0.0027
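
For reference, specificity at a fixed sensitivity can be read off the ROC curve. Below is a minimal sketch using scikit-learn, where y_true and y_prob are hypothetical arrays of ground-truth labels and predicted cancer probabilities:

import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

# hypothetical labels and predicted probabilities
y_true = np.array([0, 0, 1, 1, 0, 1])
y_prob = np.array([0.01, 0.20, 0.80, 0.65, 0.10, 0.90])

auc = roc_auc_score(y_true, y_prob)

fpr, tpr, thresholds = roc_curve(y_true, y_prob)
target_sensitivity = 0.90
idx = np.argmax(tpr >= target_sensitivity)  # first threshold reaching the target sensitivity
specificity = 1 - fpr[idx]
threshold = thresholds[idx]
print(f"AUC={auc:.4f}, spec@{target_sensitivity:.0%} sens={specificity:.3f}, threshold={threshold:.4f}")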

Example usage:

import cv2
import torch
from transformers import AutoModel

def crop_mammo(img, model, device):
  # predict the breast bounding box and return the cropped image
  img_shape = torch.tensor([img.shape[:2]]).to(device)
  x = model.preprocess(img)
  x = torch.from_numpy(x).expand(1, 1, -1, -1).float().to(device)
  with torch.inference_mode():
    coords = model(x, img_shape)
  coords = coords[0].cpu().numpy()
  # coords are (x, y, w, h) in the original image's pixel coordinates
  x, y, w, h = coords
  return img[y: y + h, x: x + w]

device = "cuda:0"

crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)

model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
model = model.eval().to(device)

cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
mlo_img = cv2.imread("mammo_mlo.png", cv2.IMREAD_GRAYSCALE)

cc_img = crop_mammo(cc_img, crop_model, device)
mlo_img = crop_mammo(mlo_img, crop_model, device)

with torch.inference_mode():
  output = model({"cc": cc_img, "mlo": mlo_img}, device=device)

Note that the model preprocesses the input images into the necessary format within its forward function. output is a dictionary containing two keys: cancer and density. output['cancer'] is a tensor of shape (N, 1) and output['density'] is a tensor of shape (N, 4). If you want the predicted density class, take the argmax: output['density'].argmax(1). If only a single study is provided, then N=1.
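
As a short example, the predictions can be turned into discrete outputs using one of the thresholds reported above; the mapping of the 4 density classes to BI-RADS categories A-D is an assumption here:

cancer_prob = output["cancer"][0].item()    # predicted probability for the single study
cancer_flag = cancer_prob >= 0.0184         # threshold for the ~90% sensitivity operating point above

density_idx = output["density"].argmax(1)[0].item()
density_label = "ABCD"[density_idx]         # assumed ordering of density classes (BI-RADS A-D)
print(cancer_prob, cancer_flag, density_label)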

You can also access each neural net separately using model.net{i} (i.e., model.net0, model.net1, model.net2). However, you must then apply the preprocessing yourself, outside of the forward function:

input_dict = model.net0.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
with torch.inference_mode():
  out = model.net0(input_dict) 
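
For example, a hypothetical sketch of running all three nets yourself and averaging their cancer outputs; this assumes each net exposes the same preprocess method as net0 and returns a dictionary with the same cancer/density keys as the ensemble, which you should verify:

per_net_outputs = []
for net in [model.net0, model.net1, model.net2]:
  # each net has its own preprocessing (the nets use different input resolutions)
  net_input = net.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
  with torch.inference_mode():
    per_net_outputs.append(net(net_input))

# assumes each per-net output is a dict with a "cancer" tensor, mirroring the ensemble output
cancer_avg = torch.stack([o["cancer"] for o in per_net_outputs]).mean(0)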

The model also supports batch inference. Construct a dictionary for each breast and pass a list of dictionaries to the model. For example, to perform inference on both breasts of 2 patients (pt1, pt2):

cc_images = ["rt_pt1_cc.png", "lt_pt1_cc.png", "rt_pt2_cc.png", "lt_pt2_cc.png"]
mlo_images = ["rt_pt1_mlo.png", lt_pt1_mlo.png", "rt_pt2_mlo.png", "lt_pt2_mlo.png"]

cc_images = [cv2.imread(fpath, cv2.IMREAD_GRAYSCALE) for fpath in cc_images]
mlo_images = [cv2.imread(fpath, cv2.IMREAD_GRAYSCALE) for fpath in mlo_images]

cc_images = [crop_mammo(img, crop_model, device) for img in cc_images]
mlo_images = [crop_mammo(img, crop_model, device) for img in mlo_images]

input_dict = [{"cc": cc_img, "mlo": mlo_img} for cc_img, mlo_img in zip(cc_images, mlo_images)]
with torch.inference_mode():
  output = model(input_dict, device=device)
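
The outputs are again dictionaries of stacked tensors, with one row per input dictionary (here N=4), presumably in the same order as the input list. For example:

# pair each prediction back with its breast (order assumed to match the input list)
breast_ids = ["rt_pt1", "lt_pt1", "rt_pt2", "lt_pt2"]
cancer_probs = output["cancer"].squeeze(1).cpu().tolist()  # shape (4, 1) -> list of 4 floats
for breast_id, prob in zip(breast_ids, cancer_probs):
  print(breast_id, round(prob, 4))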

Note that if you are converting images from DICOM to 8-bit PNG/JPEG, it is important to apply the VOI lookup table to the pixel values, which can be done using pydicom.pixels.apply_voi_lut. If you have pydicom installed, you can also load a DICOM image directly, which handles the proper 8-bit conversion for you:

img = model.load_image_from_dicom(path_to_dicom)
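
If you prefer to do the conversion yourself, below is a minimal sketch; the MONOCHROME1 inversion and min-max scaling are common conventions and are assumptions here, not a description of what load_image_from_dicom does internally:

import numpy as np
import pydicom
from pydicom.pixels import apply_voi_lut

ds = pydicom.dcmread(path_to_dicom)
arr = apply_voi_lut(ds.pixel_array, ds)              # apply the VOI lookup table
if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
  arr = arr.max() - arr                              # invert so higher values are brighter
arr = arr.astype(np.float32)
arr = (arr - arr.min()) / (arr.max() - arr.min())    # min-max scale to [0, 1]
img = (arr * 255).astype(np.uint8)                   # 8-bit image, ready for crop_mammo / the model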