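"""Gradio demo for anatomy-aware classification of radiographic axSpA.

A U-Net segments the pelvis on an uploaded X-ray, the image is cropped to the
sacroiliac joints (SIJ), and a ResNet-50 classifies the crop. Research use only.
"""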
import gradio as gr
import utils
import Model_Class
import Model_Seg
import SimpleITK as sitk
import torch
import spaces
from numpy import uint8, rot90, fliplr
from monai.transforms import Rotate90
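# Embed the pipeline figure as a base64 data URI so the page needs no external image hosting.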
image_base64 = utils.image_to_base64("anatomy_aware_pipeline.png")
article_html = f"<img src='data:image/png;base64,{image_base64}' alt='Anatomical pipeline illustration' style='width:100%;'>"
description_markdown = """
- This tool combines a U-Net segmentation model with a ResNet-50 classifier.
- For more info, check out the GitHub repository: https://github.com/FJDorfner/Anatomy-Aware-Classification-axSpA
- **Usage:** Drag a pelvic X-ray into the box and hit Run.
- **Process:** The input image is segmented and cropped to the sacroiliac joints (SIJ) before classification.
- **Please Note:** This tool is intended for research purposes only.
- **Privacy:** Please ensure data privacy and don't upload any sensitive patient information to this tool.
"""
css = """
h1 {
text-align: center;
display:block;
}
.markdown-block {
padding: 10px; /* Padding around the text */
border-radius: 5px; /* Rounded corners */
display: inline-flex; /* Use inline-flex to shrink to content size */
flex-direction: column;
justify-content: center; /* Vertically center content */
align-items: center; /* Horizontally center items within */
margin: auto; /* Center the block */
}
.markdown-block ul, .markdown-block ol {
border-radius: 5px;
padding: 10px;
padding-left: 20px; /* Adjust padding for bullet alignment */
text-align: left; /* Ensure text within list is left-aligned */
list-style-position: inside; /* Keep bullets/numbers inside the content flow */
}
footer {
display:none !important
}
"""
@spaces.GPU(duration=20)
def predict_image(input_image, input_file):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
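    # Accept input from either tab: the image component (PNG/JPG) or the file component (NIfTI/DICOM).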
if input_image is not None:
image_path = input_image
elif input_file is not None:
image_path = input_file
else:
        return None, None, "Please input an image before pressing Run.", None, None
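    # Step 1: segment the pelvis with the U-Net and overlay the mask on the original image.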
image_mask = Model_Seg.load_and_segment_image(image_path, device)
overlay_image_np, original_image_np = utils.overlay_mask(image_path, image_mask)
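    # Reorient the overlay for display: rotate 90 degrees three times, then mirror horizontally.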
overlay_image_np = rot90(overlay_image_np, k=3)
overlay_image_np = fliplr(overlay_image_np)
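    # Step 2: wrap the arrays as SimpleITK images (adding a channel axis) and crop to the SIJ bounding box.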
image_mask_im = sitk.GetImageFromArray(image_mask[None, :, :].astype(uint8))
image_im = sitk.GetImageFromArray(original_image_np[None, :, :].astype(uint8))
cropped_boxed_im, _ = utils.mask_and_crop(image_im, image_mask_im)
cropped_boxed_array = sitk.GetArrayFromImage(cropped_boxed_im)
cropped_boxed_tensor = torch.Tensor(cropped_boxed_array)
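    # Rotate the cropped tensor by 270 degrees (MONAI Rotate90 with k=3) before classification.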
rotate = Rotate90(spatial_axes=(0, 1), k=3)
cropped_boxed_tensor = rotate(cropped_boxed_tensor)
cropped_boxed_array_disp = cropped_boxed_tensor.numpy().squeeze().astype(uint8)
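    # Step 3: classify the crop with the ResNet-50 and compute a Grad-CAM heatmap for interpretability.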
prediction, image_transformed = Model_Class.load_and_classify_image(cropped_boxed_tensor, device)
gradcam = Model_Class.make_GradCAM(image_transformed, device)
nr_axSpA_prob = float(prediction[0].item())
r_axSpA_prob = float(prediction[1].item())
    # Step 4: apply the pre-determined decision threshold of 0.59 to the r-axSpA probability.
    considered = "be considered r-axSpA" if r_axSpA_prob > 0.59 else "not be considered r-axSpA"
    explanation = f"According to the pre-determined cut-off threshold of 0.59, the image should {considered}. This tool is for research purposes only."
pred_dict = {"nr-axSpA": nr_axSpA_prob, "r-axSpA": r_axSpA_prob}
return overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp
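# Build the Gradio UI: upload tabs and buttons on the left, segmentation and prediction outputs on the right.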
with gr.Blocks(css=css, title="Anatomy Aware axSpA") as iface:
gr.Markdown("# Anatomy-Aware Image Classification for radiographic axSpA")
gr.Markdown(description_markdown, elem_classes="markdown-block")
with gr.Row():
with gr.Column():
with gr.Tab("PNG/JPG"):
input_image = gr.Image(type='filepath', label="Upload an X-ray Image")
with gr.Tab("NIfTI/DICOM"):
input_file = gr.File(type='filepath', label="Upload an X-ray Image")
with gr.Row():
submit_button = gr.Button("Run", variant="primary")
clear_button = gr.ClearButton()
with gr.Column():
overlay_image_np = gr.Image(label="Segmentation Mask")
pred_dict = gr.Label(label="Prediction")
            explanation = gr.Textbox(label="Classification Decision")
with gr.Accordion("Additional Information", open=False):
gradcam = gr.Image(label="GradCAM")
cropped_boxed_array_disp = gr.Image(label="Bounding Box")
    submit_button.click(
        predict_image,
        inputs=[input_image, input_file],
        outputs=[overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp],
    )
    clear_button.add([input_image, input_file, overlay_image_np, pred_dict, explanation, gradcam, cropped_boxed_array_disp])
gr.HTML(article_html)
if __name__ == "__main__":
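    # Queue requests so concurrent users are served in order (recommended for GPU-backed Spaces).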
iface.queue()
iface.launch()