MingGatsby commited on
Commit
6d24ac5
1 Parent(s): 46c3233

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +306 -0
  2. requirement.txt +3 -0
app.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import required libraries
2
+ import os
3
+ import io
4
+ import torch
5
+ import tempfile
6
+ import numpy as np
7
+ import streamlit as st
8
+
9
+ # Import utility and custom functions
10
+ from PIL import Image
11
+ from Util.DICOM import DICOM_Utils
12
+ from Util.Custom_Model import Build_Custom_Model, reshape_transform
13
+
14
+ # Import additional MONAI and PyTorch Grad-CAM utilities
15
+ from monai.config import print_config
16
+ from monai.utils import set_determinism
17
+ from monai.networks.nets import SEResNet50
18
+ from monai.transforms import (
19
+ Activations,
20
+ EnsureChannelFirst,
21
+ AsDiscrete,
22
+ Compose,
23
+ LoadImage,
24
+ RandFlip,
25
+ RandRotate,
26
+ RandZoom,
27
+ ScaleIntensity,
28
+ AsChannelFirst,
29
+ AddChannel,
30
+ RandSpatialCrop,
31
+ ScaleIntensityRangePercentiles,
32
+ Resize,
33
+ )
34
+ from pytorch_grad_cam import GradCAM
35
+ from pytorch_grad_cam.utils.image import show_cam_on_image
36
+ from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
37
+
38
+
39
+ # (Int) Random seed
40
+ SEED = 0
41
+
42
+ # (Int) Model parameters
43
+ NUM_CLASSES = 1
44
+
45
+ # (String) CT Model directory
46
+ CT_MODEL_DIRECTORY = "C:\\Src\\GitHub\\AI_UI\\models\\CLOTS\\CT"
47
+
48
+ # (String) MRI Model directory
49
+ MRI_MODEL_DIRECTORY = "C:\\Src\\GitHub\\AI_UI\\models\\CLOTS\\MRI"
50
+
51
+ # (Boolean) Use custom model
52
+ CUSTOM_MODEL_FLAG = True
53
+
54
+ # (List[int]) Image size
55
+ SPATIAL_SIZE = [224, 224]
56
+
57
+ # (String) CT Model file name
58
+ CT_MODEL_FILE_NAME = "best_metric_model.pth"
59
+
60
+ # (String) MRI Model file name
61
+ MRI_MODEL_FILE_NAME = "best_metric_model.pth"
62
+
63
+ # (Boolean) List model modules
64
+ LIST_MODEL_MODULES = False
65
+
66
+ # (String) Model name
67
+ CT_MODEL_NAME = "swin_base_patch4_window7_224"
68
+
69
+ # (String) Model name
70
+ MRI_MODEL_NAME = "swin_base_patch4_window7_224"
71
+
72
+ # (Float) Model inference threshold
73
+ CT_INFERENCE_THRESHOLD = 0.5
74
+
75
+ # (Float) Model inference threshold
76
+ MRI_INFERENCE_THRESHOLD = 0.5
77
+
78
+ # (Int) Display CAM Class ID
79
+ CAM_CLASS_ID = 0
80
+
81
+ # (Int) Window Center for image display
82
+ DEFAULT_CT_WINDOW_CENTER = 40
83
+
84
+ # (Int) Window Width for image display
85
+ DEFAULT_CT_WINDOW_WIDTH = 100
86
+
87
+ # (Int) Window Center for image display
88
+ DEFAULT_MRI_WINDOW_CENTER = 400
89
+
90
+ # (Int) Window Width for image display
91
+ DEFAULT_MRI_WINDOW_WIDTH = 1000
92
+
93
+ # (Int) Minimum value for Window Center
94
+ WINDOW_CENTER_MIN = -600
95
+
96
+ # (Int) Maximum value for Window Center
97
+ WINDOW_CENTER_MAX = 1000
98
+
99
+ # (Int) Minimum value for Window Width
100
+ WINDOW_WIDTH_MIN = 1
101
+
102
+ # (Int) Maximum value for Window Width
103
+ WINDOW_WIDTH_MAX = 3000
104
+
105
+ # Evaluation Transforms
106
+ eval_transforms = Compose(
107
+ [
108
+ LoadImage(image_only=True),
109
+ AsChannelFirst(),
110
+ ScaleIntensityRangePercentiles(lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True),
111
+ Resize(spatial_size=SPATIAL_SIZE)
112
+ ]
113
+ )
114
+
115
+ # CAM Original Transforms
116
+ cam_original_transforms = Compose(
117
+ [
118
+ LoadImage(image_only=True),
119
+ AsChannelFirst(),
120
+ Resize(spatial_size=SPATIAL_SIZE)
121
+ ]
122
+ )
123
+
124
+ # CAM Original Transforms
125
+ original_transforms = Compose(
126
+ [
127
+ LoadImage(image_only=True),
128
+ AsChannelFirst()
129
+ ]
130
+ )
131
+
132
+ # Function to convert PIL Image to byte stream in PNG format for downloading
133
+ def image_to_bytes(image):
134
+ byte_stream = io.BytesIO()
135
+ image.save(byte_stream, format='PNG')
136
+ return byte_stream.getvalue()
137
+
138
+ set_determinism(seed=SEED)
139
+ torch.manual_seed(SEED)
140
+
141
+ # Parameters
142
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
143
+ ct_root_dir = tempfile.mkdtemp() if CT_MODEL_DIRECTORY is None else CT_MODEL_DIRECTORY
144
+ mri_root_dir = tempfile.mkdtemp() if MRI_MODEL_DIRECTORY is None else MRI_MODEL_DIRECTORY
145
+
146
+ def load_model(root_dir, model_name, model_file_name):
147
+ if CUSTOM_MODEL_FLAG:
148
+ model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
149
+ else:
150
+ model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
151
+ model.load_state_dict(torch.load(os.path.join(root_dir, model_file_name)))
152
+ model.eval()
153
+ return model
154
+
155
+ ct_model = load_model(ct_root_dir, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
156
+ mri_model = load_model(mri_root_dir, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)
157
+ if LIST_MODEL_MODULES:
158
+ for ct_name, _ in ct_model.named_modules():
159
+ print(ct_name)
160
+
161
+ for mri_name, _ in mri_model.named_modules():
162
+ print(mri_name)
163
+
164
+ # Initialize Streamlit
165
+ st.title("Analyze")
166
+
167
+ # Use Streamlit's number_input to adjust WINDOW_CENTER and WINDOW_WIDTH
168
+ st.sidebar.header("Windowing Parameters for DICOM")
169
+ CT_WINDOW_CENTER = st.sidebar.number_input("CT Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX, value=DEFAULT_CT_WINDOW_CENTER, step=1)
170
+ CT_WINDOW_WIDTH = st.sidebar.number_input("CT Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_CT_WINDOW_WIDTH, step=1)
171
+ MRI_WINDOW_CENTER = st.sidebar.number_input("MRI Window Center", min_value=WINDOW_CENTER_MIN, max_value=WINDOW_CENTER_MAX, value=DEFAULT_MRI_WINDOW_CENTER, step=1)
172
+ MRI_WINDOW_WIDTH = st.sidebar.number_input("MRI Window Width", min_value=WINDOW_WIDTH_MIN, max_value=WINDOW_WIDTH_MAX, value=DEFAULT_MRI_WINDOW_WIDTH, step=1)
173
+
174
+ uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
175
+ if uploaded_ct_file is not None:
176
+ # Save the uploaded file to a temporary location
177
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
178
+ temp_file.write(uploaded_ct_file.getvalue())
179
+
180
+ # Apply evaluation transforms to the DICOM image for model prediction
181
+ image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device)
182
+
183
+ # Predict
184
+ with torch.no_grad():
185
+ outputs = ct_model(image_tensor).sigmoid().to("cpu").numpy()
186
+ prob = outputs[0][0]
187
+ CLOTS_CLASSIFICATION = False
188
+ if(prob >= CT_INFERENCE_THRESHOLD):
189
+ CLOTS_CLASSIFICATION=True
190
+
191
+ st.header("CT Classification")
192
+ st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
193
+ st.subheader(f"Confidence : {prob * 100:.1f}%")
194
+
195
+ # Load the original DICOM image for download
196
+ download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device)
197
+ download_image = download_image_tensor.squeeze()
198
+
199
+ # Transform the download image and apply windowing
200
+ transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
201
+ windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
202
+
203
+ # Streamlit button to trigger image download
204
+ image_data = image_to_bytes(Image.fromarray(windowed_download_image))
205
+ st.download_button(
206
+ label="Download CT Image",
207
+ data=image_data,
208
+ file_name="downloaded_ct_image.png",
209
+ mime="image/png"
210
+ )
211
+
212
+ # Load the original DICOM image for display
213
+ display_image_tensor = cam_original_transforms(temp_file.name).unsqueeze(0).to(device)
214
+ display_image = display_image_tensor.squeeze()
215
+
216
+ # Transform the image and apply windowing
217
+ transformed_image = DICOM_Utils.transform_image_for_display(display_image)
218
+ windowed_image = DICOM_Utils.apply_windowing(transformed_image, CT_WINDOW_CENTER, CT_WINDOW_WIDTH)
219
+ st.image(Image.fromarray(windowed_image), caption="Original CT Visualization", use_column_width=True)
220
+
221
+ # Expand to three channels
222
+ windowed_image = np.expand_dims(windowed_image, axis=2)
223
+ windowed_image = np.tile(windowed_image, [1, 1, 3])
224
+
225
+ # Ensure both are of float32 type
226
+ windowed_image = windowed_image.astype(np.float32)
227
+
228
+ # Normalize to [0, 1] range
229
+ windowed_image = np.float32(windowed_image) / 255
230
+
231
+ # Build the CAM (Class Activation Map)
232
+ target_layers = [ct_model.model.norm]
233
+ cam = GradCAM(model=ct_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
234
+ grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
235
+ grayscale_cam = grayscale_cam[0, :]
236
+
237
+ # Now you can safely call the show_cam_on_image function
238
+ visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
239
+ st.image(Image.fromarray(visualization), caption="CAM CT Visualization", use_column_width=True)
240
+
241
+ uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
242
+ if uploaded_mri_file is not None:
243
+ # Save the uploaded file to a temporary location
244
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".dcm") as temp_file:
245
+ temp_file.write(uploaded_mri_file.getvalue())
246
+
247
+ # Apply evaluation transforms to the DICOM image for model prediction
248
+ image_tensor = eval_transforms(temp_file.name).unsqueeze(0).to(device)
249
+
250
+ # Predict
251
+ with torch.no_grad():
252
+ outputs = mri_model(image_tensor).sigmoid().to("cpu").numpy()
253
+ prob = outputs[0][0]
254
+ CLOTS_CLASSIFICATION = False
255
+ if(prob >= MRI_INFERENCE_THRESHOLD):
256
+ CLOTS_CLASSIFICATION=True
257
+
258
+ st.header("MRI Classification")
259
+ st.subheader(f"Ischaemic Stroke : {CLOTS_CLASSIFICATION}")
260
+ st.subheader(f"Confidence : {prob * 100:.1f}%")
261
+
262
+ # Load the original DICOM image for download
263
+ download_image_tensor = original_transforms(temp_file.name).unsqueeze(0).to(device)
264
+ download_image = download_image_tensor.squeeze()
265
+
266
+ # Transform the download image and apply windowing
267
+ transformed_download_image = DICOM_Utils.transform_image_for_display(download_image)
268
+ windowed_download_image = DICOM_Utils.apply_windowing(transformed_download_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
269
+
270
+ # Streamlit button to trigger image download
271
+ image_data = image_to_bytes(Image.fromarray(windowed_download_image))
272
+ st.download_button(
273
+ label="Download MRI Image",
274
+ data=image_data,
275
+ file_name="downloaded_mri_image.png",
276
+ mime="image/png"
277
+ )
278
+
279
+ # Load the original DICOM image for display
280
+ display_image_tensor = cam_original_transforms(temp_file.name).unsqueeze(0).to(device)
281
+ display_image = display_image_tensor.squeeze()
282
+
283
+ # Transform the image and apply windowing
284
+ transformed_image = DICOM_Utils.transform_image_for_display(display_image)
285
+ windowed_image = DICOM_Utils.apply_windowing(transformed_image, MRI_WINDOW_CENTER, MRI_WINDOW_WIDTH)
286
+ st.image(Image.fromarray(windowed_image), caption="Original MRI Visualization", use_column_width=True)
287
+
288
+ # Expand to three channels
289
+ windowed_image = np.expand_dims(windowed_image, axis=2)
290
+ windowed_image = np.tile(windowed_image, [1, 1, 3])
291
+
292
+ # Ensure both are of float32 type
293
+ windowed_image = windowed_image.astype(np.float32)
294
+
295
+ # Normalize to [0, 1] range
296
+ windowed_image = np.float32(windowed_image) / 255
297
+
298
+ # Build the CAM (Class Activation Map)
299
+ target_layers = [mri_model.model.norm]
300
+ cam = GradCAM(model=mri_model, target_layers=target_layers, reshape_transform=reshape_transform, use_cuda=True)
301
+ grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
302
+ grayscale_cam = grayscale_cam[0, :]
303
+
304
+ # Now you can safely call the show_cam_on_image function
305
+ visualization = show_cam_on_image(windowed_image, grayscale_cam, use_rgb=True)
306
+ st.image(Image.fromarray(visualization), caption="CAM MRI Visualization", use_column_width=True)
requirement.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ gradio
3
+ monai