---
library_name: transformers
tags:
- radiology
- medical_imaging
- bone_age
- x_ray
license: apache-2.0
base_model:
- timm/convnextv2_tiny.fcmae_ft_in22k_in1k
pipeline_tag: image-classification
---

This model has been trained and validated on 14,036 pediatric hand radiographs from the [RSNA Pediatric Bone Age Challenge](https://www.rsna.org/rsnai/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017) dataset, which is publicly available. 
It can be loaded using:
```
from transformers import AutoModel

model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
```
The model is a 3-fold ensemble built on the `convnextv2_tiny` backbone. 
The individual models can be accessed through `model.net0`, `model.net1`, and `model.net2`.
Originally, the model was trained with both a regression head and a classification head. 
However, this release loads only the classification head, as its stand-alone performance was slightly better. The classification head also generates better GradCAMs.
To convert the classification output to a bone age estimate, the softmax function is applied to the output logits, and the resulting probabilities are multiplied by their corresponding class indices and summed.
This yields a scalar float value representing the predicted bone age in units of months. 
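For illustration, a minimal sketch of that expected-value computation (assuming one class per month of age; the exact number of classes is internal to the model, and 240 is used here only as a placeholder):
```
import torch
import torch.nn.functional as F

logits = torch.randn(1, 240)  # hypothetical logits: batch of 1, one class per month
probs = F.softmax(logits, dim=1)
months = torch.arange(probs.shape[1], dtype=probs.dtype)
bone_age = (probs * months).sum(dim=1)  # expected value, in months
```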

In addition to standard data augmentation, two further augmentations were applied (see the sketch after this list):
- Using a cropped radiograph (from the model <https://huggingface.co/ianpan/bone-age-crop>) with probability 0.5
- Histogram matching with a reference image (available in this repo under Files, `ref_img.png`) with probability 0.5
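A minimal sketch of how these two augmentations might be composed during training; the actual training code is not included in this repo, and `crop_fn` is a hypothetical stand-in for the cropping model:
```
import random
from skimage.exposure import match_histograms

def augment(img, ref, crop_fn, p=0.5):
    # apply the cropping model with probability p
    if random.random() < p:
        img = crop_fn(img)
    # histogram match against the reference image with probability p
    if random.random() < p:
        img = match_histograms(img, ref)
    return img
```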

The model was trained over 20,000 iterations using a batch size of 64 across 2 NVIDIA RTX 3090 GPUs. 

Note that both of the above augmentations could be applied simultaneously and in conjunction with standard data augmentations. Thus, the model accommodates a wide range of variability in the appearance of a hand radiograph. 

On the original challenge test set comprising 200 multi-annotated pediatric hand radiographs, this model achieves a **mean absolute error of 4.16 months** (when applying both cropping and histogram matching to the input radiograph), which surpasses the [top solutions](https://pubs.rsna.org/doi/10.1148/radiol.2018180736) from the original challenge.
Specific results are as follows, with single-model performance using `model.net0` in brackets:
```
Crop (-) / Histogram Matching (-): 4.42 [4.67] months
Crop (+) / Histogram Matching (-): 4.47 [4.84] months
Crop (-) / Histogram Matching (+): 4.34 [4.59] months
Crop (+) / Histogram Matching (+): 4.16 [4.45] months
```

Thus, it is preferable to both crop and histogram-match the image to obtain optimal results. See <https://huggingface.co/ianpan/bone-age-crop> for how to crop a bone age radiograph with a pretrained model.
To histogram match with a reference image: 
```
import cv2
from skimage.exposure import match_histograms

x = cv2.imread("target_radiograph.png", 0)  # 0 = read as grayscale
ref = cv2.imread("ref_img.png", 0)  # download ref_img.png from this repo
x = match_histograms(x, ref)
```

Patient sex is an important variable affecting the model's prediction. This is passed to the model's `forward()` function using the `female` argument:
```
# x: preprocessed image tensor, e.g., shape (4, 1, H, W)
# 1 indicates female, 0 male
model(x, female=torch.tensor([1, 0, 1, 0]))  # assuming batch size of 4
```

Example usage for a single image:
```
import cv2
import torch
from skimage.exposure import match_histograms
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

crop_model = AutoModel.from_pretrained("ianpan/bone-age-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)
img = cv2.imread(..., 0)
img_shape = torch.tensor([img.shape[:2]])
x = crop_model.preprocess(img) # only takes single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0) # add channel, batch dims
x = x.float()

# if you do not provide img_shape
# model will return normalized coordinates
with torch.inference_mode():
  coords = crop_model(x.to(device), img_shape.to(device))

# only 1 sample in batch
coords = coords[0].cpu().numpy()
# coords already rescaled with img_shape
x, y, w, h = coords.round().astype(int)  # cast to int for array slicing
cropped_img = img[y: y + h, x: x + w]

model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
model = model.eval().to(device)
x = model.preprocess(cropped_img)
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0) 
x = x.float()
female = torch.tensor([1])

with torch.inference_mode():
  bone_age = model(x.to(device), female.to(device))
```

If you want the raw logits (class `i` = `i` months), you can pass `return_logits=True` to `forward()`:
```
bone_age_logits = model(x, female, return_logits=True)
```

To run single model inference, simply access one of the nets:
```
bone_age = model.net0(x, female)
```
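If you want to ensemble manually, e.g., to inspect per-fold predictions, a minimal sketch is below. This assumes the ensemble simply averages the three nets' outputs; the exact aggregation used inside `forward()` is not documented here:
```
with torch.inference_mode():
    preds = torch.stack(
        [net(x, female) for net in (model.net0, model.net1, model.net2)]
    )
bone_age = preds.mean(dim=0)  # assumed: simple average across the 3 folds
```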

If you have `pydicom` installed, you can also load a DICOM image directly:
```
img = model.load_image_from_dicom(path_to_dicom)
```
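For reference, a roughly equivalent manual load with `pydicom` might look like the following sketch (`load_image_from_dicom` may apply additional windowing or rescaling internally):
```
import pydicom

ds = pydicom.dcmread(path_to_dicom)
img = ds.pixel_array  # raw pixel array; some DICOMs need rescaling/windowing
```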
 
This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use. 
The user assumes any and all responsibility regarding their own use of this model and its outputs.