ianpan commited on
Commit
21b0590
·
1 Parent(s): 041fdf1

use huggingface models

Browse files
Files changed (2) hide show
  1. app.py +56 -74
  2. requirements.txt +2 -1
app.py CHANGED
@@ -7,20 +7,20 @@ import torch
7
  import torch.nn as nn
8
 
9
  from einops import rearrange
10
- from importlib import import_module
11
  from pytorch_grad_cam import GradCAM
12
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
13
  from skimage.exposure import match_histograms
14
- from skp.utils import load_model_from_config, load_kfold_ensemble_as_list
15
 
16
 
17
  class ModelForGradCAM(nn.Module):
18
- def __init__(self, model):
19
  super().__init__()
20
  self.model = model
 
21
 
22
  def forward(self, x):
23
- return self.model({"x": x})["logits1"]
24
 
25
 
26
  def convert_bone_age_to_string(bone_age: float):
@@ -47,67 +47,29 @@ def convert_bone_age_to_string(bone_age: float):
47
  return str_output
48
 
49
 
50
- device = "cuda" if torch.cuda.is_available() else "cpu"
51
-
52
- cfg_crop = import_module("skp.configs.boneage.cfg_crop_simple_resize").cfg
53
- crop_model = load_model_from_config(
54
- cfg_crop, weights_path="crop.pt", device=device, eval_mode=True
55
- )
56
-
57
- cfg = import_module("skp.configs.boneage.cfg_female_channel_reg_cls_match_hist").cfg
58
- cfg.backbone = "convnextv2_tiny"
59
-
60
- model_list = load_kfold_ensemble_as_list(
61
- cfg, [f"net{i}.pt" for i in range(3)], device=device, eval_mode=True
62
- )
63
-
64
- ref_img = rearrange(cv2.imread("ref_img.png", 0), "h w -> h w 1 ")
65
-
66
- with open("greulich_and_pyle_ages.json", "r") as f:
67
- greulich_and_pyle_ages = json.load(f)["bone_ages"]
68
-
69
- greulich_and_pyle_ages = {k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()}
70
-
71
- model_grad_cam = ModelForGradCAM(model_list[0])
72
- target_layers = [model_grad_cam.model.backbone.stages[-1]]
73
-
74
-
75
  @spaces.GPU
76
  def predict_bone_age(Radiograph, Sex, Heatmap):
77
- x0 = rearrange(Radiograph, "h w -> h w 1")
78
- x = cfg_crop.val_transforms(image=x0)["image"]
79
- x = torch.from_numpy(x)
80
- x = rearrange(x, "h w c -> 1 c h w")
81
  # crop
 
82
  with torch.inference_mode():
83
- box = crop_model({"x": x.to(device).float()}, return_loss=False)["logits"][
84
- 0
85
- ].cpu()
86
- box[[0, 2]] = box[[0, 2]] * x0.shape[1]
87
- box[[1, 3]] = box[[1, 3]] * x0.shape[0]
88
- box = box.numpy().astype("int")
89
- x, y, w, h = box
90
- x0 = x0[y : y + h, x : x + w]
91
  # histogram matching
92
- x0 = match_histograms(x0, ref_img)
93
- x = cfg.val_transforms(image=x0)["image"]
94
- # create image channel for female/male
95
- ch = np.zeros_like(x)
96
- if Sex: # 0- male, 1- female
97
- ch[...] = 255
98
- x = np.concatenate([x, ch], axis=-1)
99
- x = torch.from_numpy(x)
100
- x = rearrange(x, "h w c -> 1 c h w")
101
  with torch.inference_mode():
102
- bone_age = []
103
- for each_model in model_list:
104
- pred = each_model({"x": x.to(device).float()}, return_loss=False)[
105
- "logits1"
106
- ][0].cpu()
107
- pred = (pred.softmax(0) * torch.arange(240)).sum().numpy()
108
- bone_age.append(pred)
109
- bone_age = np.mean(bone_age)
110
 
 
 
111
  gp_ages = greulich_and_pyle_ages["female" if Sex else "male"]
112
  diffs_gp = np.abs(bone_age - gp_ages)
113
  diffs_gp = np.argsort(diffs_gp)
@@ -119,29 +81,33 @@ def predict_bone_age(Radiograph, Sex, Heatmap):
119
  closest2 = convert_bone_age_to_string(closest2)
120
 
121
  if Heatmap:
 
 
 
 
 
 
122
  targets = [ClassifierOutputTarget(round(bone_age))]
123
  with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
124
- grayscale_cam = cam(
125
- input_tensor=x.to(device).float(), targets=targets, eigen_smooth=True
126
- )
127
 
128
  heatmap = cv2.applyColorMap(
129
  (grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET
130
  )
131
- image = cv2.cvtColor(x[0, 0].cpu().numpy().astype("uint8"), cv2.COLOR_GRAY2RGB)
 
 
132
  image_weight = 0.6
133
  grad_cam_image = (1 - image_weight) * heatmap[..., ::-1] + image_weight * image
134
- grad_cam_image = grad_cam_image.astype("uint8")
135
  else:
136
  # if no heatmap desired, just show image
137
- grad_cam_image = cv2.cvtColor(
138
- x[0, 0].cpu().numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
139
- )
140
 
141
  return (
142
  bone_age_str,
143
  f"The closest Greulich & Pyle bone ages are:\n 1) {closest1}\n 2) {closest2}",
144
- grad_cam_image,
145
  )
146
 
147
 
@@ -157,11 +123,8 @@ with gr.Blocks() as demo:
157
  """
158
  # Deep Learning Model for Pediatric Bone Age
159
 
160
- This model predicts the bone age from a single frontal view hand radiograph.
161
- The model was trained on the publicly available
162
- [RSNA Pediatric Bone Age Challenge](https://www.rsna.org/rsnai/ai-image-challenge/rsna-pediatric-bone-age-challenge-2017) dataset.
163
- The model achieves a mean absolute error of 4.26 months on the original test set comprising 200 multi-annotated hand radiographs,
164
- which is competitive with [top solutions](https://pubs.rsna.org/doi/10.1148/radiol.2018180736) from the original challenge.
165
 
166
  There is also an option to output a heatmap over the radiograph to show regions where the model is focusing on
167
  to make its prediction. However, this takes extra computation and will increase the runtime.
@@ -172,7 +135,7 @@ with gr.Blocks() as demo:
172
 
173
  Created by: Ian Pan, <https://ianpan.me>
174
 
175
- Last updated: December 15, 2024
176
  """
177
  )
178
  gr.Interface(
@@ -184,8 +147,27 @@ with gr.Blocks() as demo:
184
  ["examples/10043.png", "Female", "No"],
185
  ["examples/8888.png", "Female", "Yes"],
186
  ],
187
- cache_examples=False,
188
  )
189
 
190
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  demo.launch(share=True)
 
7
  import torch.nn as nn
8
 
9
  from einops import rearrange
 
10
  from pytorch_grad_cam import GradCAM
11
  from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
12
  from skimage.exposure import match_histograms
13
+ from transformers import AutoModel
14
 
15
 
16
  class ModelForGradCAM(nn.Module):
17
+ def __init__(self, model, female):
18
  super().__init__()
19
  self.model = model
20
+ self.female = female
21
 
22
  def forward(self, x):
23
+ return self.model(x, self.female, return_logits=True)
24
 
25
 
26
  def convert_bone_age_to_string(bone_age: float):
 
47
  return str_output
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  @spaces.GPU
51
  def predict_bone_age(Radiograph, Sex, Heatmap):
52
+ x = crop_model.preprocess(Radiograph)
53
+ x = torch.from_numpy(x).float().to(device)
54
+ x = rearrange(x, "h w -> 1 1 h w")
 
55
  # crop
56
+ img_shape = torch.tensor([Radiograph.shape[:2]]).to(device)
57
  with torch.inference_mode():
58
+ box = crop_model(x, img_shape=img_shape).to("cpu").numpy()
59
+ x, y, w, h = box[0]
60
+ cropped = Radiograph[y : y + h, x : x + w]
 
 
 
 
 
61
  # histogram matching
62
+ x = match_histograms(cropped, ref_img)
63
+
64
+ x = model.preprocess(x)
65
+ x = torch.from_numpy(x).float().to(device)
66
+ x = rearrange(x, "h w -> 1 1 h w")
67
+ female = torch.tensor([Sex]).to(device)
 
 
 
68
  with torch.inference_mode():
69
+ bone_age = model(x, female)[0].item()
 
 
 
 
 
 
 
70
 
71
+ # get closest G&P ages
72
+ # from: https://rad.esmil.com/Reference/G_P_BoneAge/
73
  gp_ages = greulich_and_pyle_ages["female" if Sex else "male"]
74
  diffs_gp = np.abs(bone_age - gp_ages)
75
  diffs_gp = np.argsort(diffs_gp)
 
81
  closest2 = convert_bone_age_to_string(closest2)
82
 
83
  if Heatmap:
84
+ # net1 and net2 to give good GradCAMs
85
+ # net0 is bad for some reason
86
+ # because GradCAM expects 1 input tensor, need to
87
+ # pass female during class instantiation
88
+ model_grad_cam = ModelForGradCAM(model.net1, female)
89
+ target_layers = [model_grad_cam.model.backbone.stages[-1]]
90
  targets = [ClassifierOutputTarget(round(bone_age))]
91
  with GradCAM(model=model_grad_cam, target_layers=target_layers) as cam:
92
+ grayscale_cam = cam(input_tensor=x, targets=targets, eigen_smooth=True)
 
 
93
 
94
  heatmap = cv2.applyColorMap(
95
  (grayscale_cam[0] * 255).astype("uint8"), cv2.COLORMAP_JET
96
  )
97
+ image = cv2.cvtColor(
98
+ x[0, 0].to("cpu").numpy().astype("uint8"), cv2.COLOR_GRAY2RGB
99
+ )
100
  image_weight = 0.6
101
  grad_cam_image = (1 - image_weight) * heatmap[..., ::-1] + image_weight * image
102
+ grad_cam_image = grad_cam_image
103
  else:
104
  # if no heatmap desired, just show image
105
+ grad_cam_image = cv2.cvtColor(x[0, 0].to("cpu").numpy(), cv2.COLOR_GRAY2RGB)
 
 
106
 
107
  return (
108
  bone_age_str,
109
  f"The closest Greulich & Pyle bone ages are:\n 1) {closest1}\n 2) {closest2}",
110
+ grad_cam_image.astype("uint8"),
111
  )
112
 
113
 
 
123
  """
124
  # Deep Learning Model for Pediatric Bone Age
125
 
126
+ This model predicts the bone age from a single frontal view hand radiograph. Read more about the model here:
127
+ <https://huggingface.co/ianpan/bone-age>
 
 
 
128
 
129
  There is also an option to output a heatmap over the radiograph to show regions where the model is focusing on
130
  to make its prediction. However, this takes extra computation and will increase the runtime.
 
135
 
136
  Created by: Ian Pan, <https://ianpan.me>
137
 
138
+ Last updated: December 16, 2024
139
  """
140
  )
141
  gr.Interface(
 
147
  ["examples/10043.png", "Female", "No"],
148
  ["examples/8888.png", "Female", "Yes"],
149
  ],
150
+ cache_examples="lazy",
151
  )
152
 
153
  if __name__ == "__main__":
154
+ device = "cuda" if torch.cuda.is_available() else "cpu"
155
+ print(f"Using device `{device}` ...")
156
+
157
+ crop_model = AutoModel.from_pretrained(
158
+ "ianpan/bone-age-crop", trust_remote_code=True
159
+ )
160
+ model = AutoModel.from_pretrained("ianpan/bone-age", trust_remote_code=True)
161
+
162
+ crop_model, model = crop_model.eval().to(device), model.eval().to(device)
163
+
164
+ ref_img = cv2.imread("ref_img.png", 0)
165
+
166
+ with open("greulich_and_pyle_ages.json", "r") as f:
167
+ greulich_and_pyle_ages = json.load(f)["bone_ages"]
168
+
169
+ greulich_and_pyle_ages = {
170
+ k: np.asarray(v) for k, v in greulich_and_pyle_ages.items()
171
+ }
172
+
173
  demo.launch(share=True)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ gradio
5
  scikit-image
6
  spaces
7
  timm
8
- torch
 
 
5
  scikit-image
6
  spaces
7
  timm
8
+ torch
9
+ transformers