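# Page header text from the hosting site's file view, kept as comments:
# Author: DiGuaQiu
# Commit: 737e510 (verified) - "Create app.py"
# Links: raw | history | blame
# Size: 4.82 kB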
import tempfile
from pathlib import Path
import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix
import gradio as gr
import utils
# Markdown shown under the page title; rendered with elem_classes="markdown-block"
# so it picks up the ".markdown-block" rules from the page CSS.
# Each label is wrapped in a closed **...** bold span (the original left them
# unclosed and had a stray quote after the paper URL).
description_markdown = """
- **GitHub:** https://github.com/hhaentze/mrsegmentator
- **Paper:** https://arxiv.org/abs/2405.06463
- **Please Note:** This tool is intended for research purposes only.
"""
# Custom stylesheet passed to gr.Blocks(css=...): centers the <h1> title,
# styles the description box (elem_classes="markdown-block") and hides the
# page <footer>. The inline CSS comments now match the actual values
# (the box is dark with white text, not light gray with black text).
css = """
h1 {
    text-align: center;
    display: block;
}
.markdown-block {
    background-color: #0b0f1a; /* Near-black background */
    color: white; /* White text */
    padding: 10px; /* Padding around the text */
    border-radius: 5px; /* Rounded corners */
    box-shadow: 0 0 10px rgba(11,15,26,1); /* Shadow in the same near-black tone */
    display: inline-flex; /* Use inline-flex to shrink to content size */
    flex-direction: column;
    justify-content: center; /* Vertically center content (main axis is vertical) */
    align-items: center; /* Horizontally center items within */
    margin: auto; /* Center the block */
}
.markdown-block ul, .markdown-block ol {
    background-color: #1e2936; /* Dark slate panel behind the list */
    border-radius: 5px;
    padding: 10px;
    box-shadow: 0 0 10px rgba(0,0,0,0.3);
    padding-left: 20px; /* Overrides the left padding above for bullet alignment */
    text-align: left; /* Ensure text within list is left-aligned */
    list-style-position: inside; /* Ensures bullets/numbers are inside the content flow */
}
footer {
    display: none !important;
}
"""
# File names of the bundled sample images. The UI offers them from "images/",
# and their precomputed segmentations are looked up under "segmentations/"
# (with the "seg" postfix) instead of running the model.
examples = [
    "amos_0555.nii.gz",
    "amos_0517.nii.gz",
    "amos_0541.nii.gz",
    "amos_0571.nii.gz",
]
def save_file(segmentation, path):
    """Return a downloadable file path for ``segmentation``.

    If ``path`` names one of the bundled example images, the precomputed
    segmentation under ``segmentations/`` is returned and nothing is written.
    Otherwise the segmentation is written next to the input image, with the
    "seg" postfix added to its file name via ``add_postfix``, and that new
    path is returned.

    The input image is deliberately NOT overwritten. Callers keep passing the
    same input path on later runs, so overwriting it would make a second
    "Run" segment the previous segmentation instead of the original image.

    Args:
        segmentation: SimpleITK image to save.
        path: file path of the input image (may include directories).

    Returns:
        Path (str) of the segmentation file.
    """
    source = Path(path)
    if source.name in examples:
        # Use only the bare file name: ``path`` may carry a directory prefix
        # (e.g. "images/" or an upload cache dir) that must not end up inside
        # "segmentations/". This matches the lookup in infer_wrapper.
        return "segmentations/" + add_postfix(source.name, "seg")

    target = source.with_name(add_postfix(source.name, "seg"))
    sitk.WriteImage(segmentation, str(target))
    return str(target)
def infer(image_path):
    """Run MRSegmentator on a single image and return its segmentation.

    ``inference.infer`` writes its output into a temporary directory. The
    result is loaded with ``sitk.ReadImage`` before the directory is removed
    when the ``with`` block exits.

    Args:
        image_path: path to the input image file.

    Returns:
        The segmentation as a SimpleITK image.
    """
    with tempfile.TemporaryDirectory() as tmpdirname:
        inference.infer(
            [image_path],
            tmpdirname,
            [0, 1, 2, 3, 4],  # NOTE(review): presumably model fold indices - confirm with mrsegmentator
            cpu_only=not torch.cuda.is_available(),  # use the GPU only when CUDA is available
        )
        # Output file name assumed here: the input's name with the "seg" postfix.
        filename = add_postfix(Path(image_path).name, "seg")
        segmentation = sitk.ReadImage(str(Path(tmpdirname) / filename))
    return segmentation
def infer_wrapper(input_file, image_state, seg_state, slider=50):
    """Gradio callback for the "Run" button.

    Obtains a segmentation for the selected file:
    - for a bundled example, the precomputed one under ``segmentations/``;
    - otherwise by running ``infer`` on it.
    It then saves the segmentation for download, appends its array form to
    ``seg_state`` and renders the overlay.

    Args:
        input_file: value of the ``gr.File`` input. This is a file path
            (string-like, possibly exposing the path as ``.name``), or
            ``None`` when nothing has been selected.
        image_state: ``gr.State`` value; its last element is passed to
            ``utils.display`` as the image.
        seg_state: ``gr.State`` list; the new segmentation array (from
            ``utils.sitk2numpy``) is appended to it.
        slider: relative slice position forwarded to ``utils.display``.

    Returns:
        Tuple of (overlay for the AnnotatedImage, updated ``seg_state``,
        path of the segmentation file offered for download).

    Raises:
        gr.Error: if "Run" is pressed before a file is selected.
    """
    if input_file is None:
        # Without this, Path(None) raises an opaque TypeError in the UI.
        raise gr.Error("Please upload an image or select an example first.")

    # Accept both a plain path string and an object exposing the path as ``.name``.
    image_path = getattr(input_file, "name", input_file)
    filename = Path(image_path).name

    if filename in examples:
        # Bundled sample: reuse the precomputed result instead of running the model.
        segmentation = sitk.ReadImage("segmentations/" + add_postfix(filename, "seg"))
    else:
        segmentation = infer(image_path)

    seg_path = save_file(segmentation, image_path)
    seg_state.append(utils.sitk2numpy(segmentation))
    # NOTE(review): here utils.display receives the last *elements* of the
    # states, while slider.change passes the state values themselves -
    # confirm utils.display handles both forms.
    return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path
with gr.Blocks(css=css, title="MRSegmentator") as iface:
    # Page title and project links (styled by the ".markdown-block" CSS rules).
    gr.Markdown("# Robust Multi-Modality Segmentation of 40 Classes in MRI and CT Imaging")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    # Session state shared between callbacks. Both start as empty lists;
    # read_and_display and infer_wrapper return updated versions of them.
    image_state = gr.State([])
    seg_state = gr.State([])

    with gr.Row():
        # Left column: input file, examples, controls and download.
        with gr.Column():
            input_file = gr.File(
                type="filepath",
                label="Upload an MRI Image (.nii/.nii.gz)",
                # ".nii" added to match the label. ".gz" is kept next to ".nii.gz":
                # NOTE(review): presumably because double extensions are not
                # matched reliably by the file picker - confirm before removing.
                file_types=[".nii", ".nii.gz", ".gz"],
            )
            gr.Examples(["images/" + ex for ex in examples], input_file)
            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()
            # NOTE(review): with minimum 1 and step 2 only odd values are on the
            # grid, so the initial value 50 is off-grid - confirm intent.
            slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
            download_file = gr.File(label="Download Segmentation", interactive=False)
        # Right column: axial slice with the segmentation overlay.
        with gr.Column():
            overlay_image_np = gr.AnnotatedImage(label="Axial View")

    # Selecting a file renders it; read_and_display also returns both states.
    input_file.change(
        utils.read_and_display,
        inputs=[input_file, image_state, seg_state],
        outputs=[overlay_image_np, image_state, seg_state],
    )
    # Moving the slider re-renders the overlay from the stored states.
    slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])
    # "Run" computes (or looks up) the segmentation and offers it for download.
    submit_button.click(
        infer_wrapper,
        inputs=[input_file, image_state, seg_state, slider],
        outputs=[overlay_image_np, seg_state, download_file],
    )
    # "Clear" resets the inputs, outputs and both session states.
    clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])
if __name__ == "__main__":
    # Turn on Gradio's request queue, then serve the app.
    iface.queue()
    # Default host/port; pass server_name= / server_port= to launch() to override.
    iface.launch()