File size: 4,818 Bytes
737e510
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import tempfile
from pathlib import Path

import SimpleITK as sitk
import torch
from mrsegmentator import inference
from mrsegmentator.utils import add_postfix

import gradio as gr
import utils


# Markdown shown under the page title (rendered with the "markdown-block" CSS class).
description_markdown = """
- **GitHub:** https://github.com/hhaentze/mrsegmentator
- **Paper:** https://arxiv.org/abs/2405.06463
- **Please Note:** This tool is intended for research purposes only.
"""


# Custom stylesheet passed to gr.Blocks: centers the page title, styles the
# description box (.markdown-block) and its lists, and hides Gradio's footer.
css = """

h1 {
    text-align: center;
    display:block;
}
.markdown-block {
    background-color: #0b0f1a; /* Dark navy background */
    color: white;             /* White text */
    padding: 10px;            /* Padding around the text */
    border-radius: 5px;       /* Rounded corners */
    box-shadow: 0 0 10px rgba(11,15,26,1); /* Soft glow in the background color */
    display: inline-flex;       /* Use inline-flex to shrink to content size */
    flex-direction: column;
    justify-content: center;    /* Vertically center content */
    align-items: center;        /* Horizontally center items within */
    margin: auto;               /* Center the block */
}

.markdown-block ul, .markdown-block ol {
    background-color: #1e2936;
    border-radius: 5px;
    padding: 10px;
    box-shadow: 0 0 10px rgba(0,0,0,0.3);
    padding-left: 20px;         /* Adjust padding for bullet alignment */
    text-align: left;           /* Ensure text within list is left-aligned */
    list-style-position: inside;/* Ensures bullets/numbers are inside the content flow */
}

footer {
    display:none !important;
}
"""

# Sample images offered in the UI (served from "images/"); each one has a
# precomputed segmentation under "segmentations/", so no inference is needed.
examples = [
    "amos_0555.nii.gz",
    "amos_0517.nii.gz",
    "amos_0541.nii.gz",
    "amos_0571.nii.gz",
]


def save_file(segmentation, path):
    """Return a downloadable file path for ``segmentation``.

    For the bundled sample images a precomputed segmentation already exists under
    ``segmentations/``, so that path is returned and nothing is written. For any
    other input the segmentation is written next to the input image, under the
    input's filename with a ``seg`` postfix (via ``add_postfix``).

    Args:
        segmentation: SimpleITK image to save (ignored for sample images).
        path: Path of the input image, e.g. the temporary file Gradio created
            for an upload.

    Returns:
        Path of the segmentation file, as a string.
    """
    name = Path(path).name

    if name in examples:
        # Look up by basename, as infer_wrapper does: ``path`` may be a full
        # cached/temporary path, and prefixing that with "segmentations/" would
        # point at a file that does not exist.
        return "segmentations/" + add_postfix(name, "seg")

    # Write to a sibling file rather than overwriting the uploaded image;
    # otherwise a second "Run" on the same upload would segment the mask.
    seg_path = Path(path).with_name(add_postfix(name, "seg"))
    sitk.WriteImage(segmentation, str(seg_path))
    return str(seg_path)


def infer(image_path, folds=(0, 1, 2, 3, 4)):
    """Segment a single image with MRSegmentator and return the result.

    Runs ``inference.infer`` on ``image_path``, writing into a temporary
    directory. The segmentation it produced is then read back into memory
    before that directory is removed.

    Args:
        image_path: Path of the image file to segment.
        folds: Fold indices forwarded to ``inference.infer``. Defaults to
            0-4, the values previously hard-coded here.

    Returns:
        The segmentation as a SimpleITK image.
    """
    # Run on GPU when CUDA is available, otherwise force CPU inference.
    use_cpu = not torch.cuda.is_available()

    with tempfile.TemporaryDirectory() as tmpdirname:
        inference.infer([image_path], tmpdirname, list(folds), cpu_only=use_cpu)
        seg_file = Path(tmpdirname) / add_postfix(Path(image_path).name, "seg")
        # Read while the temporary directory still exists; it is deleted on exit.
        segmentation = sitk.ReadImage(str(seg_file))

    return segmentation


def infer_wrapper(input_file, image_state, seg_state, slider=50):
    """Handle the "Run" button: segment the current file and refresh the viewer.

    Sample images (those listed in ``examples``) use their precomputed
    segmentation from ``segmentations/``. Any other file is segmented with
    ``infer``.

    Args:
        input_file: Value of the ``gr.File`` input. This is either a path string
            or an object exposing the path as ``.name``; ``None`` means no file.
        image_state: List of loaded images; its last entry is displayed. It is
            presumably filled by ``utils.read_and_display`` when a file is
            loaded (not shown here).
        seg_state: List of segmentation arrays; the new result is appended in place.
        slider: Relative slice position forwarded to ``utils.display``.

    Returns:
        Tuple of (value returned by ``utils.display``, updated ``seg_state``,
        path of the downloadable segmentation file).

    Raises:
        gr.Error: If no file has been provided.
    """
    if input_file is None:
        raise gr.Error("Please upload an image before running the segmentation.")

    # Accept both a plain path string and a file-like object carrying the path
    # in ``.name``, so the path is resolved consistently in one place.
    input_path = getattr(input_file, "name", input_file)
    filename = Path(input_path).name

    # Sample images ship with precomputed segmentations; only run the model otherwise.
    if filename in examples:
        segmentation = sitk.ReadImage("segmentations/" + add_postfix(filename, "seg"))
    else:
        segmentation = infer(input_path)

    seg_path = save_file(segmentation, input_path)
    seg_state.append(utils.sitk2numpy(segmentation))

    return utils.display(image_state[-1], seg_state[-1], slider), seg_state, seg_path


with gr.Blocks(css=css, title="MRSegmentator") as iface:

    gr.Markdown("# Robust Multi-Modality Segmentation of 40 Classes in MRI and CT Imaging")
    gr.Markdown(description_markdown, elem_classes="markdown-block")

    # Per-session lists of loaded images and computed segmentations;
    # infer_wrapper displays the last entry of each.
    image_state = gr.State([])
    seg_state = gr.State([])

    with gr.Row():
        with gr.Column():

            # ".nii" is listed so that uncompressed NIfTI files, which the label
            # advertises, are accepted. ".gz" is kept alongside ".nii.gz" so that
            # compressed files match on their final suffix.
            input_file = gr.File(
                type="filepath",
                label="Upload an MRI Image (.nii/.nii.gz)",
                file_types=[".gz", ".nii.gz", ".nii"],
            )
            gr.Examples(["images/" + ex for ex in examples], input_file)

            with gr.Row():
                submit_button = gr.Button("Run", variant="primary")
                clear_button = gr.ClearButton()

            slider = gr.Slider(1, 100, value=50, step=2, label="Select (relative) Slice")
            download_file = gr.File(label="Download Segmentation", interactive=False)

        with gr.Column():
            overlay_image_np = gr.AnnotatedImage(label="Axial View")

    # Loading a file (upload or example) displays it right away, before any segmentation.
    input_file.change(
        utils.read_and_display,
        inputs=[input_file, image_state, seg_state],
        outputs=[overlay_image_np, image_state, seg_state],
    )
    # Moving the slider re-renders the current image/segmentation at the new slice.
    slider.change(utils.display, inputs=[image_state, seg_state, slider], outputs=[overlay_image_np])

    submit_button.click(
        infer_wrapper,
        inputs=[input_file, image_state, seg_state, slider],
        outputs=[overlay_image_np, seg_state, download_file],
    )

    clear_button.add([input_file, overlay_image_np, image_state, seg_state, download_file])


if __name__ == "__main__":
    # Enable request queuing, then serve with Gradio's default host and port.
    # To listen on all interfaces instead, call
    # launch(server_name='0.0.0.0', server_port=8080).
    iface.queue().launch()