|
import gradio as gr |
|
import utils |
|
import Model_Class |
|
import Model_Seg |
|
|
|
import SimpleITK as sitk |
|
import torch |
|
from numpy import uint8 |
|
import spaces |
|
|
|
# Torch device used for both the segmentation and classification models:
# the GPU when torch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
|
|
|
image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")

# Embed the pipeline figure inline as a PNG data URI so the page needs no
# separate static-file route to display it.
article_html = (
    "<img src='data:image/png;base64,"
    f"{image_base64}"
    "' alt='Anatomical pipeline illustration' style='width:100%;'>"
)
|
|
|
# Bullet-point introduction rendered under the page title; styled by the
# ".markdown-block" rules in the CSS below.
description_markdown = """
- This tool combines a U-Net Segmentation Model with a ResNet-50 for Classification.
- For more info, check out the GitHub repository here: https://github.com/FJDorfner/Anatomy-Aware-Classification-axSpA
- **Usage:** Drag a pelvic X-ray into one of the upload tabs (PNG/JPG or NIfTI/DICOM) and click **Run**.
- **Process:** The input image will be segmented and cropped to the sacroiliac joints (SIJ) before classification.
- **Please Note:** This tool is intended for research purposes only.
- **Privacy:** Please ensure data privacy and don't upload any sensitive patient information to this tool.
"""
|
|
|
# Custom stylesheet passed to gr.Blocks(css=...): centers the page title,
# shrink-wraps and centers the description block (elem_classes="markdown-block"),
# left-aligns its bullet list, and hides the default Gradio footer.
css = """
h1 {
text-align: center;
display:block;
}
.markdown-block {
padding: 10px; /* Padding around the text */
border-radius: 5px; /* Rounded corners */
display: inline-flex; /* Use inline-flex to shrink to content size */
flex-direction: column;
justify-content: center; /* Vertically center content */
align-items: center; /* Horizontally center items within */
margin: auto; /* Center the block */
}

.markdown-block ul, .markdown-block ol {
border-radius: 5px;
padding: 10px;
padding-left: 20px; /* Adjust padding for bullet alignment */
text-align: left; /* Ensure text within list is left-aligned */
list-style-position: inside;/* Ensures bullets/numbers are inside the content flow */
}

footer {
display:none !important
}
"""
|
# Cut-off on the r-axSpA probability above which an image is reported as
# r-axSpA. Defined once so the decision and the user-facing explanation
# cannot drift apart.
R_AXSPA_THRESHOLD = 0.59


@spaces.GPU(duration=20)
def predict_image(input_image, input_file):
    """Segment, crop and classify a pelvic radiograph.

    Args:
        input_image: File path from the PNG/JPG ``gr.Image`` input, or None.
        input_file: File path from the NIfTI/DICOM ``gr.File`` input, or None.
            Only used when ``input_image`` is None.

    Returns:
        A 5-tuple ``(overlay, probabilities, explanation, gradcam, crop)``
        matching the Gradio outputs: the segmentation overlay, a dict of
        class probabilities for ``gr.Label``, a decision sentence, the
        Grad-CAM map, and the cropped image that was classified. When no
        input is given, every element except the explanation message is None.
    """
    # The image tab takes precedence when both inputs are populated.
    if input_image is not None:
        image_path = input_image
    elif input_file is not None:
        image_path = input_file
    else:
        return None, None, "Please input an image before pressing run", None, None

    image_mask = Model_Seg.load_and_segment_image(image_path, device)
    overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)

    # Add a leading singleton axis to the 2-D arrays before wrapping them as
    # SimpleITK images for mask-based cropping.
    image_mask_im = sitk.GetImageFromArray(image_mask[None, :, :].astype(uint8))
    image_im = sitk.GetImageFromArray(original_image_np[None, :, :].astype(uint8))
    cropped_boxed_im, _ = utils.mask_and_crop(image_im, image_mask_im)

    cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
    # Drop singleton axes so gr.Image can display the crop.
    cropped_boxed_array_disp = cropped_boxed_array.squeeze()
    cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
    prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor, device)

    gradcam = Model_Class.make_GradCAM(image_transformed, device)

    # Index 0 is nr-axSpA and index 1 is r-axSpA (see pred_dict below).
    nr_axSpA_prob = float(prediction[0].item())
    r_axSpA_prob = float(prediction[1].item())

    if r_axSpA_prob > R_AXSPA_THRESHOLD:
        considered = "be considered r-axSpA"
    else:
        considered = "not be considered r-axSpA"

    explanation = (
        f"According to the pre-determined cut-off threshold of {R_AXSPA_THRESHOLD}, "
        f"the image should {considered}. This Tool is for research purposes only."
    )

    pred_dict = {"nr-axSpA": nr_axSpA_prob, "r-axSpA": r_axSpA_prob}

    return overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp
|
|
|
|
|
|
|
|
|
with gr.Blocks(css=css, title="Anatomy Aware axSpA") as iface:
    gr.Markdown("# Anatomy-Aware Image Classification for radiographic axSpA")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    with gr.Row():
        # Left column: the two alternative inputs plus the action buttons.
        with gr.Column():
            # predict_image uses input_image first and falls back to input_file.
            with gr.Tab("PNG/JPG"):
                input_image = gr.Image(type='filepath', label="Upload an X-ray Image")

            with gr.Tab("NIfTI/DICOM"):
                input_file = gr.File(type='filepath', label="Upload an X-ray Image")

            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()

        # Right column: model outputs, with the secondary views collapsed.
        with gr.Column():
            overlay_image_np = gr.Image(label="Segmentation Mask")
            pred_dict = gr.Label(label="Prediction")
            explanation = gr.Textbox(label="Classification Decision")

            with gr.Accordion("Additional Information", open=False):
                gradcam = gr.Image(label="GradCAM")
                cropped_boxed_array_disp = gr.Image(label="Bounding Box")

    submit_button.click(
        predict_image,
        inputs=[input_image, input_file],
        outputs=[overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp],
    )

    # Clear BOTH inputs: otherwise a previously uploaded NIfTI/DICOM file
    # survives "Clear" and is silently re-analysed on the next "Run".
    clear_button.add(
        [input_image, input_file, overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp]
    )

    gr.HTML(article_html)
|
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue, then start the server; queue() returns the
    # Blocks instance, so the two calls chain.
    iface.queue().launch()
|
|