File size: 7,255 Bytes
ab163d2 f722806 ab163d2 53435a7 ab163d2 04f372b ae39e2f 04f372b ab163d2 9afb16d ab163d2 9afb16d ab163d2 574fc22 9afb16d ab163d2 9afb16d ab163d2 9afb16d ab163d2 9afb16d ab163d2 9afb16d ab163d2 9afb16d ab163d2 f7975aa ab163d2 f722806 ab163d2 f722806 ab163d2 f722806 ab163d2 ae39e2f ab163d2 9afb16d 04f372b ab163d2 5da7918 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 |
import gradio as gr
import utils
import Model_Class
import Model_Seg
import SimpleITK as sitk
import torch
from numpy import uint8
import spaces
# Inference device: CUDA when PyTorch reports it as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pipeline illustration, embedded as a base64 data URI inside an <img> tag;
# it is rendered with gr.HTML at the bottom of the page.
image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
article_html = (
    "<img src='data:image/png;base64,{}' "
    "alt='Anatomical pipeline illustration' style='width:100%;'>"
).format(image_base64)
# Text of the "about this tool" bullets as (bold label, text) pairs; the first
# bullet has no label. Both description strings below are rendered from this
# single tuple, so the Markdown and HTML versions always carry the same wording.
_DESCRIPTION_POINTS = (
    (None, "This tool combines a U-Net Segmentation Model with a ResNet-50 for Classification."),
    ("Usage:", "Just drag a pelvic x-ray into the box and hit run."),
    ("Process:", "The input image will be segmented and cropped to the SIJ before classification."),
    ("Please Note:", "This tool is intended for research purposes only."),
    ("Privacy:", "Please ensure data privacy and don't upload any sensitive patient information to this tool."),
)

# Markdown version: one "- " list item per bullet, labels in **bold**.
description_markdown = "\n" + "".join(
    "- {}\n".format(body if label is None else "**{}** {}".format(label, body))
    for label, body in _DESCRIPTION_POINTS
)

# Inline styles for the HTML version: a near-black card around a slightly
# lighter list panel, both with white text.
_DESCRIPTION_CARD_STYLE = (
    "background-color: #0b0f1a; color: white !important; padding: 10px; "
    "border-radius: 5px; box-shadow: 0 0 10px rgba(11,15,26,1); "
    "display: inline-flex; flex-direction: column; justify-content: center; "
    "align-items: center; margin: auto;"
)
_DESCRIPTION_LIST_STYLE = (
    "background-color: #1e2936; border-radius: 5px; color: white !important; "
    "padding: 10px; box-shadow: 0 0 10px rgba(0,0,0,0.3); padding-left: 20px; "
    "text-align: left; list-style-position: inside;"
)

# HTML version: the same bullets as <li> items, labels in <strong>.
description_html = (
    '\n<div style="{}">\n<ul style="{}">\n'.format(
        _DESCRIPTION_CARD_STYLE, _DESCRIPTION_LIST_STYLE
    )
    + "".join(
        "<li>{}</li>\n".format(
            body if label is None else "<strong>{}</strong> {}".format(label, body)
        )
        for label, body in _DESCRIPTION_POINTS
    )
    + "</ul>\n</div>\n"
)
# Page stylesheet. The description-card rules mirror the inline styles used by
# description_html: a near-black card with white text around a lighter list
# panel. (Previously the text colour was black, which is unreadable on that
# background, and the card rules were duplicated under a second class name.)
# NOTE(review): ``css`` is not passed to ``gr.Blocks`` in this file, so none of
# these rules, including the footer one, currently take effect.
css = """
/* Description card */
.markdown-block,
.gradiocontainer-custom {
    background-color: #0b0f1a !important; /* Near-black background */
    color: white !important;              /* Light text, readable on the dark background */
    padding: 10px;                        /* Padding around the text */
    border-radius: 5px;                   /* Rounded corners */
    box-shadow: 0 0 10px rgba(11,15,26,1);
    display: inline-flex;                 /* Shrink to content size */
    flex-direction: column;
    justify-content: center;              /* Vertically center content */
    align-items: center;                  /* Horizontally center items within */
    margin: auto;                         /* Center the block */
}
/* Slightly lighter panel behind lists inside the card */
.markdown-block ul, .markdown-block ol,
.gradiocontainer-custom ul, .gradiocontainer-custom ol {
    background-color: #1e2936 !important;
    border-radius: 5px;
    padding: 10px;
    box-shadow: 0 0 10px rgba(0,0,0,0.3);
    padding-left: 20px;                   /* Room for bullet alignment */
    text-align: left;                     /* Left-align list text */
    list-style-position: inside;          /* Keep bullets/numbers inside the content flow */
}
/* Hide the page footer element */
footer {
    display: none !important;
}
"""
# Cut-off on the r-axSpA probability. It is defined once so that the decision
# and the threshold quoted in the explanation text cannot disagree.
R_AXSPA_THRESHOLD = 0.59


@spaces.GPU(duration=20)
def predict_image(input_image, input_file):
    """Segment a pelvic radiograph, crop it to the segmented region and classify it.

    Args:
        input_image: File path from the PNG/JPG image input, or None. Takes
            precedence when both inputs are set.
        input_file: File path from the NIfTI/DICOM file input, or None.

    Returns:
        A 5-tuple ``(overlay_image, {"nr-axSpA": p0, "r-axSpA": p1},
        explanation_text, gradcam, cropped_image)``, in the order of the
        Gradio outputs wired to the Run button. When neither input is given,
        every element except the explanation (a prompt to upload an image)
        is None.
    """
    if input_image is not None:
        image_path = input_image
    elif input_file is not None:
        image_path = input_file
    else:
        return None, None, "Please input an image before pressing run", None, None

    # 1) Segment the image and build the mask overlay shown to the user.
    image_mask = Model_Seg.load_and_segment_image(image_path, device)
    overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)

    # 2) Crop to the segmented region via utils.mask_and_crop. Both arrays get
    #    a leading singleton axis and are cast to uint8 before becoming
    #    SimpleITK images (presumably they are 2-D; confirm against utils).
    image_mask_im = sitk.GetImageFromArray(image_mask[None, :, :].astype(uint8))
    image_im = sitk.GetImageFromArray(original_image_np[None, :, :].astype(uint8))
    cropped_boxed_im, _ = utils.mask_and_crop(image_im, image_mask_im)
    cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
    # Singleton axes dropped for display in the "Bounding Box" image.
    cropped_boxed_array_disp = cropped_boxed_array.squeeze()

    # 3) Classify the crop and compute its GradCAM map.
    cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
    prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor, device)
    gradcam = Model_Class.make_GradCAM(image_transformed, device)

    # Output index 0 is reported as nr-axSpA, index 1 as r-axSpA.
    nr_axSpA_prob = float(prediction[0].item())
    r_axSpA_prob = float(prediction[1].item())

    # Decision: r-axSpA only when strictly above the cut-off.
    considered = (
        "be considered r-axSpA"
        if r_axSpA_prob > R_AXSPA_THRESHOLD
        else "not be considered r-axSpA"
    )
    explanation = (
        f"According to the pre-determined cut-off threshold of {R_AXSPA_THRESHOLD}, "
        f"the image should {considered}. This Tool is for research purposes only."
    )
    pred_dict = {"nr-axSpA": nr_axSpA_prob, "r-axSpA": r_axSpA_prob}
    return overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp
# Page layout, top to bottom:
#   - title and description;
#   - an input column (PNG/JPG or NIfTI/DICOM upload, Run/Clear buttons) beside
#     an output column (mask overlay, class probabilities, decision text, and a
#     collapsible section with the GradCAM map and the cropped image);
#   - the pipeline figure at the end.
# NOTE(review): ``css`` is defined in this file but not passed to gr.Blocks, so
# the "markdown-block" class and the footer rule are not applied. The same
# description is also rendered twice (Markdown and HTML). Confirm which is
# intended.
with gr.Blocks(title="Anatomy Aware axSpA") as iface:
    gr.Markdown("# Anatomy-Aware Image Classification for radiographic axSpA")
    gr.Markdown(description_markdown, elem_classes="markdown-block")
    gr.HTML(description_html)
    with gr.Row():
        with gr.Column():
            with gr.Tab("PNG/JPG"):
                input_image = gr.Image(type='filepath', label="Upload an X-ray Image")
            with gr.Tab("NIfTI/DICOM"):
                input_file = gr.File(type='filepath', label="Upload an X-ray Image")
            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()
        with gr.Column():
            overlay_image_np = gr.Image(label="Segmentation Mask")
            pred_dict = gr.Label(label="Prediction")
            explanation = gr.Textbox(label="Classification Decision")
            with gr.Accordion("Additional Information", open=False):
                gradcam = gr.Image(label="GradCAM")
                cropped_boxed_array_disp = gr.Image(label="Bounding Box")
    # Order must match the 5-tuple returned by predict_image.
    prediction_outputs = [overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp]
    submit_button.click(predict_image, inputs=[input_image, input_file], outputs=prediction_outputs)
    # Clear both upload widgets (input_file was missing before, so a stale
    # NIfTI/DICOM upload survived a Clear) and every output.
    clear_button.add([input_image, input_file, *prediction_outputs])
    gr.HTML(article_html)
if __name__ == "__main__":
    # Enable Gradio's request queue, then start the server. ``queue()`` returns
    # the Blocks instance itself, so the two calls can be chained.
    iface.queue().launch()
|