library_name: transformers
tags:
  - radiology
  - medical_imaging
  - bone_age
  - x_ray
license: apache-2.0
base_model:
  - timm/convnextv2_tiny.fcmae_ft_in22k_in1k
pipeline_tag: image-classification

This model was trained and validated on 14,036 pediatric hand radiographs from the publicly available RSNA Pediatric Bone Age Challenge dataset. It can be loaded with:

from transformers import AutoModel

model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)

The model is a 3-fold ensemble built on the convnextv2_tiny backbone. The individual models can be accessed through model.net0, model.net1, and model.net2. Each was originally trained with both a regression head and a classification head. This model loads only the classification head, because its stand-alone performance was slightly better and it also produces better GradCAMs. To obtain a prediction, the softmax function is applied to the output logits, each resulting class probability is multiplied by its class index, and the products are summed. The result is a scalar float: the predicted bone age in months.
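The logits-to-scalar step described above is the expected class index under the softmax distribution. Below is a minimal NumPy sketch of that step for one image. It assumes the logits are a 1-D array and that class i corresponds to i months. The function name and the 240-class example are hypothetical and are not part of the model's API.

```python
import numpy as np


def expected_bone_age(logits: np.ndarray) -> float:
    """Softmax over classes, then sum(p_i * i): the expected class index in months."""
    # Subtract the max before exponentiating for numerical stability
    shifted = logits - logits.max()
    probs = np.exp(shifted) / np.exp(shifted).sum()
    months = np.arange(logits.shape[-1])  # assumption: class i = i months
    return float((probs * months).sum())


# Hypothetical example: logits sharply peaked at class 120 give roughly 120 months
logits = np.full(240, -10.0)
logits[120] = 10.0
print(round(expected_bone_age(logits), 2))
```

Using the expectation rather than the argmax gives a continuous prediction that can fall between two classes.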

In addition to standard data augmentation, two further augmentations were applied:

  • Cropping the radiograph with the model https://huggingface.co/ianpan/bone-age-crop, with probability 0.5
  • Histogram matching to a reference image (available in this repo under Files, ref_img.png), with probability 0.5

The model was trained over 20,000 iterations using a batch size of 64 across 2 NVIDIA RTX 3090 GPUs.
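The training budget can be estimated from these figures. The sketch below relies on three assumptions that the text above does not state:

- The 20,000 iterations apply to each fold model.
- The batch of 64 is split evenly across the 2 GPUs (standard data-parallel training).
- Each of the 3 fold models trains on roughly two thirds of the 14,036 images.

```python
iterations = 20_000
batch_size = 64
num_gpus = 2
dataset_size = 14_036
num_folds = 3

# Images processed per fold model (assumes 20,000 iterations per model)
images_seen = iterations * batch_size
# Per-GPU batch, assuming an even data-parallel split of the global batch
per_gpu_batch = batch_size // num_gpus
# Assumed training split per fold model: (k - 1) / k of the data
train_size = dataset_size * (num_folds - 1) // num_folds
approx_epochs = images_seen / train_size

print(f"images seen per model: {images_seen:,}")
print(f"per-GPU batch: {per_gpu_batch}")
print(f"approx. epochs per fold model: {approx_epochs:.0f}")
```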

Note that the cropping and histogram-matching augmentations could be applied simultaneously and in combination with standard data augmentations. As a result, the model accommodates a wide range of variation in the appearance of hand radiographs.
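Because each augmentation is drawn independently with probability 0.5, a training image can receive neither, one, or both. The sketch below shows only that sampling logic and makes the following assumptions:

- `crop_fn` and `match_fn` are user-supplied callables. In practice these might wrap the bone-age-crop model and `skimage.exposure.match_histograms` with ref_img.png.
- The function and argument names are hypothetical.
- Cropping happens before histogram matching. This order is not specified above.

```python
from typing import Callable

import numpy as np


def extra_augment(
    img: np.ndarray,
    crop_fn: Callable[[np.ndarray], np.ndarray],
    match_fn: Callable[[np.ndarray], np.ndarray],
    rng: np.random.Generator,
    p_crop: float = 0.5,
    p_match: float = 0.5,
) -> np.ndarray:
    """Apply cropping and histogram matching, each independently with its own probability."""
    # Two independent draws: neither, either one, or both augmentations may be applied
    if rng.random() < p_crop:
        img = crop_fn(img)  # assumed order: crop first
    if rng.random() < p_match:
        img = match_fn(img)
    return img


# Usage with toy stand-ins for the real crop and histogram-matching functions
rng = np.random.default_rng(0)
img = np.arange(16, dtype=np.float32).reshape(4, 4)
out = extra_augment(img, crop_fn=lambda a: a[1:3, 1:3], match_fn=lambda a: a / a.max(), rng=rng)
print(out.shape)
```

With both probabilities at 0.5, about a quarter of the images receive both augmentations.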

On the original challenge test set of 200 multi-annotated pediatric hand radiographs, this model achieves a mean absolute error (MAE) of 4.16 months when both cropping and histogram matching are applied to the input radiograph. This surpasses the top solutions from the original challenge. Detailed results are below, with single-model performance (model.net0) in brackets:

Crop (-) / Histogram Matching (-): 4.42 [4.67] months
Crop (+) / Histogram Matching (-): 4.47 [4.84] months
Crop (-) / Histogram Matching (+): 4.34 [4.59] months
Crop (+) / Histogram Matching (+): 4.16 [4.45] months
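The reported MAEs separate the contribution of ensembling from that of preprocessing. The short script below tabulates both effects using only the values listed above.

```python
# MAE in months: (ensemble, single model net0), keyed by (crop, histogram matching)
results = {
    (False, False): (4.42, 4.67),
    (True, False): (4.47, 4.84),
    (False, True): (4.34, 4.59),
    (True, True): (4.16, 4.45),
}

baseline_ensemble = results[(False, False)][0]
for (crop, hist), (ens, single) in results.items():
    ensemble_gain = single - ens          # how much the 3-model ensemble beats net0
    vs_baseline = baseline_ensemble - ens  # positive = better than no preprocessing
    print(
        f"crop={crop!s:5} hist={hist!s:5}  "
        f"ensemble gain {ensemble_gain:.2f}  vs. no preprocessing {vs_baseline:+.2f}"
    )
```

Ensembling lowers the MAE by 0.25 to 0.37 months in every setting. Cropping alone slightly increases the ensemble error, but cropping combined with histogram matching gives the largest improvement.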

Cropping and histogram matching the image together therefore give the best results. See https://huggingface.co/ianpan/bone-age-crop for how to crop a bone age radiograph with a pretrained model. To histogram match to a reference image:

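import cv2
from skimage.exposure import match_histograms

x = cv2.imread("target_radiograph.png", cv2.IMREAD_GRAYSCALE)
ref = cv2.imread("ref_img.png", cv2.IMREAD_GRAYSCALE)  # download ref_img.png from this repo
x = match_histograms(x, ref)  # returns a floating-point array following the reference's intensity distribution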
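`match_histograms` maps each intensity in the target image to the reference intensity at the same quantile of the cumulative distribution. The minimal NumPy sketch below shows that idea for single-channel images only, to clarify what the step does. The function name is hypothetical. In practice, use the scikit-image function shown above, which also handles multichannel images.

```python
import numpy as np


def match_histogram_sketch(source: np.ndarray, reference: np.ndarray) -> np.ndarray:
    """Remap source intensities so their empirical CDF follows the reference's CDF."""
    # Index of each pixel's unique intensity, plus the count of each unique intensity
    _, src_idx, src_counts = np.unique(
        source.ravel(), return_inverse=True, return_counts=True
    )
    ref_values, ref_counts = np.unique(reference.ravel(), return_counts=True)

    # Empirical CDF (quantile) of each unique intensity
    src_cdf = np.cumsum(src_counts) / source.size
    ref_cdf = np.cumsum(ref_counts) / reference.size

    # Reference intensity at each source quantile (linear interpolation, clamped at the ends)
    matched = np.interp(src_cdf, ref_cdf, ref_values)
    return matched[src_idx].reshape(source.shape)


source = np.array([[0, 10], [20, 30]], dtype=np.uint8)
reference = np.array([[100, 100], [200, 200]], dtype=np.uint8)
print(match_histogram_sketch(source, reference))
```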

Patient sex is an important variable affecting the model's prediction. This is passed to the model's forward() function using the female argument:

import torch

# 1 indicates female, 0 male; one entry per image in the batch
model(x, female=torch.tensor([1, 0, 1, 0]))  # assuming a batch size of 4

Example usage for a single image:

import cv2
import torch
from skimage.exposure import match_histograms
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Step 1: crop the radiograph with the bone-age-crop model
crop_model = AutoModel.from_pretrained("ianpan/bone-age-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)
img = cv2.imread(..., 0)
img_shape = torch.tensor([img.shape[:2]])
x = crop_model.preprocess(img) # only takes single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0) # add channel, batch dims
x = x.float()

# if you do not provide img_shape
# model will return normalized coordinates
with torch.inference_mode():
  coords = crop_model(x.to(device), img_shape.to(device))

# only 1 sample in batch
coords = coords[0].cpu().numpy()
# coords already rescaled with img_shape; cast to int for array slicing
x0, y0, w, h = coords.astype(int)
cropped_img = img[y0: y0 + h, x0: x0 + w]

# Step 2: histogram match the crop to the reference image (ref_img.png from this repo)
ref = cv2.imread("ref_img.png", 0)
cropped_img = match_histograms(cropped_img, ref)

# Step 3: predict bone age
model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
model = model.eval().to(device)
x = model.preprocess(cropped_img)
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0) # add channel, batch dims
x = x.float()
female = torch.tensor([1]) # 1 = female, 0 = male

with torch.inference_mode():
  bone_age = model(x.to(device), female.to(device)) # predicted bone age in months

If you want the raw logits (class i = i months), you can pass return_logits=True to forward():

bone_age_logits = model(x.to(device), female.to(device), return_logits=True)

To run single-model inference, call one of the individual nets directly:

bone_age = model.net0(x.to(device), female.to(device))

If you have pydicom installed, you can also load a DICOM image directly:

img = model.load_image_from_dicom(path_to_dicom)
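The DICOM header usually records patient sex, so the `female` flag can often be read from the same file. The sketch below makes the following assumptions:

- The file uses the standard PatientSex attribute (0010,0040), whose defined values are "M", "F", and "O".
- The helper names are hypothetical.
- Values that cannot be mapped (missing or "O") raise an error, so the user must decide how to handle them.

```python
def female_flag_from_sex(sex: str) -> int:
    """Map a DICOM PatientSex value to the model's female flag (1 = female, 0 = male)."""
    value = sex.strip().upper()
    if value == "F":
        return 1
    if value == "M":
        return 0
    # "O", empty, or unexpected values cannot be mapped automatically
    raise ValueError(f"Cannot infer female flag from PatientSex={sex!r}")


def read_female_flag(path_to_dicom: str) -> int:
    """Read PatientSex from a DICOM header (requires pydicom)."""
    import pydicom  # imported lazily so the mapping above works without pydicom

    # Skip pixel data; only the header is needed here
    ds = pydicom.dcmread(path_to_dicom, stop_before_pixels=True)
    return female_flag_from_sex(str(ds.get("PatientSex", "")))
```

The result can then be wrapped for the model, for example `female = torch.tensor([read_female_flag(path_to_dicom)])`.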

This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes any and all responsibility regarding their own use of this model and its outputs.