use huggingface models
Browse files- app.py +56 -74
- requirements.txt +2 -1
app.py
CHANGED
@@ -7,20 +7,20 @@ import torch
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
from einops import rearrange
|
10 |
-
from importlib import import_module
|
11 |
from pytorch_grad_cam import GradCAM
|
12 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
13 |
from skimage.exposure import match_histograms
|
14 |
-
from
|
15 |
|
16 |
|
17 |
class ModelForGradCAM(nn.Module):
|
18 |
-
def __init__(self, model):
|
19 |
super().__init__()
|
20 |
self.model = model
|
|
|
21 |
|
22 |
def forward(self, x):
|
23 |
-
return self.model(
|
24 |
|
25 |
|
26 |
def convert_bone_age_to_string(bone_age: float):
|
@@ -47,67 +47,29 @@ def convert_bone_age_to_string(bone_age: float):
|
|
47 |
return str_output
|
48 |
|
49 |
|
50 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
51 |
-
|
52 |
-
cfg_crop = import_module("skp.configs.boneage.cfg_crop_simple_resize").cfg
|
53 |
-
crop_model = load_model_from_config(
|
54 |
-
cfg_crop, weights_path="crop.pt", device=device, eval_mode=True
|
55 |
-
)
|
56 |
-
|
57 |
-
cfg = import_module("skp.configs.boneage.cfg_female_channel_reg_cls_match_hist").cfg
|
58 |
-
cfg.backbone = "convnextv2_tiny"
|
59 |
-
|
60 |
-
model_list = load_kfold_ensemble_as_list(
|
61 |
-
cfg, [f"net{i}.pt" for i in range(3)], device=device, eval_mode=True
|
62 |
-
)
|
63 |
-
|
64 |
-
ref_img = rearrange(cv2.imread("ref_img.png", 0), "h w -> h w 1 ")
|
65 |
-
|
66 |
-
with open("greulich_and_pyle_ages.json", "r") as f:
|
67 |
-
greulich_and_pyle_ages = json.load(f)["bone_ages"]
|
68 |
-
|
69 |
-
greulich_and_pyle_ages = {k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()}
|
70 |
-
|
71 |
-
model_grad_cam = ModelForGradCAM(model_list[0])
|
72 |
-
target_layers = [model_grad_cam.model.backbone.stages[-1]]
|
73 |
-
|
74 |
-
|
75 |
@spaces.GPU
|
76 |
def predict_bone_age(Radiograph, Sex, Heatmap):
|
77 |
-
|
78 |
-
x =
|
79 |
-
x =
|
80 |
-
x = rearrange(x, "h w c -> 1 c h w")
|
81 |
# crop
|
|
|
82 |
with torch.inference_mode():
|
83 |
-
box = crop_model(
|
84 |
-
|
85 |
-
|
86 |
-
box[[0, 2]] = box[[0, 2]] * x0.shape[1]
|
87 |
-
box[[1, 3]] = box[[1, 3]] * x0.shape[0]
|
88 |
-
box = box.numpy().astype("int")
|
89 |
-
x, y, w, h = box
|
90 |
-
x0 = x0[y : y + h, x : x + w]
|
91 |
# histogram matching
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
x = np.concatenate([x, ch], axis=-1)
|
99 |
-
x = torch.from_numpy(x)
|
100 |
-
x = rearrange(x, "h w c -> 1 c h w")
|
101 |
with torch.inference_mode():
|
102 |
-
bone_age = []
|
103 |
-
for each_model in model_list:
|
104 |
-
pred = each_model({"x": x.to(device).float()}, return_loss=False)[
|
105 |
-
"logits1"
|
106 |
-
][0].cpu()
|
107 |
-
pred = (pred.softmax(0) * torch.arange(240)).sum().numpy()
|
108 |
-
bone_age.append(pred)
|
109 |
-
bone_age = np.mean(bone_age)
|
110 |
|
|
|
|
|
111 |
gp_ages = greulich_and_pyle_ages["female" if Sex else "male"]
|
112 |
diffs_gp = np.abs(bone_age - gp_ages)
|
113 |
diffs_gp = np.argsort(diffs_gp)
|
@@ -119,29 +81,33 @@ def predict_bone_age(Radiograph, Sex, Heatmap):
|
|
119 |
closest2 = convert_bone_age_to_string(closest2)
|
120 |
|
121 |
if Heatmap:
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
targets = [ClassifierOutputTarget(round(bone_age))]
|
123 |
with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
|
124 |
-
grayscale_cam = cam(
|
125 |
-
input_tensor=x.to(device).float(), targets=targets, eigen_smooth=True
|
126 |
-
)
|
127 |
|
128 |
heatmap = cv2.applyColorMap(
|
129 |
(grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET
|
130 |
)
|
131 |
-
image = cv2.cvtColor(
|
|
|
|
|
132 |
image_weight = 0.6
|
133 |
grad_cam_image = (1 - image_weight) * heatmap[..., ::-1] + image_weight * image
|
134 |
-
grad_cam_image = grad_cam_image
|
135 |
else:
|
136 |
# if no heatmap desired, just show image
|
137 |
-
grad_cam_image = cv2.cvtColor(
|
138 |
-
x[0, 0].cpu().numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
|
139 |
-
)
|
140 |
|
141 |
return (
|
142 |
bone_age_str,
|
143 |
f"The closest Greulich & Pyle bone ages are:\n 1) {closest1}\n 2) {closest2}",
|
144 |
-
grad_cam_image,
|
145 |
)
|
146 |
|
147 |
|
@@ -157,11 +123,8 @@ with gr.Blocks() as demo:
|
|
157 |
"""
|
158 |
# Deep Learning Model for Pediatric Bone Age
|
159 |
|
160 |
-
This model predicts the bone age from a single frontal view hand radiograph.
|
161 |
-
|
162 |
-
[RSNA Pediatric Bone Age Challenge](https://www.rsna.org/rsnai/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017) dataset.
|
163 |
-
The model achieves a mean absolute error of 4.26 months on the original test set comprising 200 multi-annotated hand radiographs,
|
164 |
-
which is competitive with [top solutions](https://pubs.rsna.org/doi/10.1148/radiol.2018180736) from the original challenge.
|
165 |
|
166 |
There is also an option to output a heatmap over the radiograph to show regions where the model is focusing on
|
167 |
to make its prediction. However, this takes extra computation and will increase the runtime.
|
@@ -172,7 +135,7 @@ with gr.Blocks() as demo:
|
|
172 |
|
173 |
Created by: Ian Pan, <https://ianpan.me>
|
174 |
|
175 |
-
Last updated: December
|
176 |
"""
|
177 |
)
|
178 |
gr.Interface(
|
@@ -184,8 +147,27 @@ with gr.Blocks() as demo:
|
|
184 |
["examples/10043.png", "Female", "No"],
|
185 |
["examples/8888.png", "Female", "Yes"],
|
186 |
],
|
187 |
-
cache_examples=
|
188 |
)
|
189 |
|
190 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
191 |
demo.launch(share=True)
|
|
|
7 |
import torch.nn as nn
|
8 |
|
9 |
from einops import rearrange
|
|
|
10 |
from pytorch_grad_cam import GradCAM
|
11 |
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
|
12 |
from skimage.exposure import match_histograms
|
13 |
+
from transformers import AutoModel
|
14 |
|
15 |
|
16 |
class ModelForGradCAM(nn.Module):
|
17 |
+
def __init__(self, model, female):
|
18 |
super().__init__()
|
19 |
self.model = model
|
20 |
+
self.female = female
|
21 |
|
22 |
def forward(self, x):
|
23 |
+
return self.model(x, self.female, return_logits=True)
|
24 |
|
25 |
|
26 |
def convert_bone_age_to_string(bone_age: float):
|
|
|
47 |
return str_output
|
48 |
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
@spaces.GPU
|
51 |
def predict_bone_age(Radiograph, Sex, Heatmap):
|
52 |
+
x = crop_model.preprocess(Radiograph)
|
53 |
+
x = torch.from_numpy(x).float().to(device)
|
54 |
+
x = rearrange(x, "h w -> 1 1 h w")
|
|
|
55 |
# crop
|
56 |
+
img_shape = torch.tensor([Radiograph.shape[:2]]).to(device)
|
57 |
with torch.inference_mode():
|
58 |
+
box = crop_model(x, img_shape=img_shape).to("cpu").numpy()
|
59 |
+
x, y, w, h = box[0]
|
60 |
+
cropped = Radiograph[y : y + h, x : x + w]
|
|
|
|
|
|
|
|
|
|
|
61 |
# histogram matching
|
62 |
+
x = match_histograms(cropped, ref_img)
|
63 |
+
|
64 |
+
x = model.preprocess(x)
|
65 |
+
x = torch.from_numpy(x).float().to(device)
|
66 |
+
x = rearrange(x, "h w -> 1 1 h w")
|
67 |
+
female = torch.tensor([Sex]).to(device)
|
|
|
|
|
|
|
68 |
with torch.inference_mode():
|
69 |
+
bone_age = model(x, female)[0].item()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
+
# get closest G&P ages
|
72 |
+
# from: https://rad.esmil.com/Reference/G_P_BoneAge/
|
73 |
gp_ages = greulich_and_pyle_ages["female" if Sex else "male"]
|
74 |
diffs_gp = np.abs(bone_age - gp_ages)
|
75 |
diffs_gp = np.argsort(diffs_gp)
|
|
|
81 |
closest2 = convert_bone_age_to_string(closest2)
|
82 |
|
83 |
if Heatmap:
|
84 |
+
# net1 and net2 to give good GradCAMs
|
85 |
+
# net0 is bad for some reason
|
86 |
+
# because GradCAM expects 1 input tensor, need to
|
87 |
+
# pass female during class instantiation
|
88 |
+
model_grad_cam = ModelForGradCAM(model.net1, female)
|
89 |
+
target_layers = [model_grad_cam.model.backbone.stages[-1]]
|
90 |
targets = [ClassifierOutputTarget(round(bone_age))]
|
91 |
with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
|
92 |
+
grayscale_cam = cam(input_tensor=x, targets=targets, eigen_smooth=True)
|
|
|
|
|
93 |
|
94 |
heatmap = cv2.applyColorMap(
|
95 |
(grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET
|
96 |
)
|
97 |
+
image = cv2.cvtColor(
|
98 |
+
x[0, 0].to("cpu").numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
|
99 |
+
)
|
100 |
image_weight = 0.6
|
101 |
grad_cam_image = (1 - image_weight) * heatmap[..., ::-1] + image_weight * image
|
102 |
+
grad_cam_image = grad_cam_image
|
103 |
else:
|
104 |
# if no heatmap desired, just show image
|
105 |
+
grad_cam_image = cv2.cvtColor(x[0, 0].to("cpu").numpy(), cv2.COLOR_GRAY2RGB)
|
|
|
|
|
106 |
|
107 |
return (
|
108 |
bone_age_str,
|
109 |
f"The closest Greulich & Pyle bone ages are:\n 1) {closest1}\n 2) {closest2}",
|
110 |
+
grad_cam_image.astype("uint8"),
|
111 |
)
|
112 |
|
113 |
|
|
|
123 |
"""
|
124 |
# Deep Learning Model for Pediatric Bone Age
|
125 |
|
126 |
+
This model predicts the bone age from a single frontal view hand radiograph. Read more about the model here:
|
127 |
+
<https://huggingface.co/ianpan/bone-age>
|
|
|
|
|
|
|
128 |
|
129 |
There is also an option to output a heatmap over the radiograph to show regions where the model is focusing on
|
130 |
to make its prediction. However, this takes extra computation and will increase the runtime.
|
|
|
135 |
|
136 |
Created by: Ian Pan, <https://ianpan.me>
|
137 |
|
138 |
+
Last updated: December 16, 2024
|
139 |
"""
|
140 |
)
|
141 |
gr.Interface(
|
|
|
147 |
["examples/10043.png", "Female", "No"],
|
148 |
["examples/8888.png", "Female", "Yes"],
|
149 |
],
|
150 |
+
cache_examples="lazy",
|
151 |
)
|
152 |
|
153 |
if __name__ == "__main__":
|
154 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
155 |
+
print(f"Using device `{device}` ...")
|
156 |
+
|
157 |
+
crop_model = AutoModel.from_pretrained(
|
158 |
+
"ianpan/bone-age-crop", trust_remote_code=True
|
159 |
+
)
|
160 |
+
model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
|
161 |
+
|
162 |
+
crop_model, model = crop_model.eval().to(device), model.eval().to(device)
|
163 |
+
|
164 |
+
ref_img = cv2.imread("ref_img.png", 0)
|
165 |
+
|
166 |
+
with open("greulich_and_pyle_ages.json", "r") as f:
|
167 |
+
greulich_and_pyle_ages = json.load(f)["bone_ages"]
|
168 |
+
|
169 |
+
greulich_and_pyle_ages = {
|
170 |
+
k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()
|
171 |
+
}
|
172 |
+
|
173 |
demo.launch(share=True)
|
requirements.txt
CHANGED
@@ -5,4 +5,5 @@ gradio
|
|
5 |
scikit-image
|
6 |
spaces
|
7 |
timm
|
8 |
-
torch
|
|
|
|
5 |
scikit-image
|
6 |
spaces
|
7 |
timm
|
8 |
+
torch
|
9 |
+
transformers
|