# Import required libraries
import os
import io
import torch
import numpy as np
import streamlit as st

# Import utility and custom functions
from PIL import Image
from Util.DICOM import DICOM_Utils
from Util.Custom_Model import Build_Custom_Model, reshape_transform

# Import additional MONAI and PyTorch Grad-CAM utilities
from monai.config import print_config
from monai.utils import set_determinism
from monai.networks.nets import SEResNet50
from monai.transforms import (
    Activations,
    EnsureChannelFirst,
    AsDiscrete,
    Compose,
    LoadImage,
    RandFlip,
    RandRotate,
    RandZoom,
    ScaleIntensity,
    AsChannelFirst,
    AddChannel,
    RandSpatialCrop,
    ScaleIntensityRangePercentiles,
    Resize,
)
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# ---------------------------------------------------------------------------
# Shared settings
# ---------------------------------------------------------------------------
# (Int) Seed handed to set_determinism and torch.manual_seed
SEED = 0
# (Int) Number of classifier outputs
NUM_CLASSES = 1
# (Boolean) True -> build the network with Build_Custom_Model; False -> MONAI SEResNet50
CUSTOM_MODEL_FLAG = True
# (List[int]) Spatial size every image is resized to before inference / CAM
SPATIAL_SIZE = [224, 224]
# (Boolean) Print the module names of both loaded models
LIST_MODEL_MODULES = False
# (Int) Class index targeted when building the Grad-CAM map
CAM_CLASS_ID = 0

# ---------------------------------------------------------------------------
# CT model and display defaults
# ---------------------------------------------------------------------------
# (String) Folder holding the CT checkpoint
CT_MODEL_DIRECTORY = "models/CLOTS/CT"
# (String) CT checkpoint file inside CT_MODEL_DIRECTORY
CT_MODEL_FILE_NAME = "best_metric_model.pth"
# (String) CT architecture name given to the model builder
CT_MODEL_NAME = "swin_base_patch4_window7_224"
# (Float) Probability at or above which a CT study is classified positive
CT_INFERENCE_THRESHOLD = 0.5
# (Int) Initial CT window center shown in the sidebar
DEFAULT_CT_WINDOW_CENTER = 40
# (Int) Initial CT window width shown in the sidebar
DEFAULT_CT_WINDOW_WIDTH = 100

# ---------------------------------------------------------------------------
# MRI model and display defaults
# ---------------------------------------------------------------------------
# (String) Folder holding the MRI checkpoint
MRI_MODEL_DIRECTORY = "models/CLOTS/MRI"
# (String) MRI checkpoint file inside MRI_MODEL_DIRECTORY
MRI_MODEL_FILE_NAME = "best_metric_model.pth"
# (String) MRI architecture name given to the model builder
MRI_MODEL_NAME = "swin_base_patch4_window7_224"
# (Float) Probability at or above which an MRI study is classified positive
MRI_INFERENCE_THRESHOLD = 0.5
# (Int) Initial MRI window center shown in the sidebar
DEFAULT_MRI_WINDOW_CENTER = 400
# (Int) Initial MRI window width shown in the sidebar
DEFAULT_MRI_WINDOW_WIDTH = 1000
# (Int) Minimum value for Window Center
WINDOW_CENTER_MIN = -600
# (Int) Maximum value for Window Center
WINDOW_CENTER_MAX = 1000
# (Int) Minimum value for Window Width
WINDOW_WIDTH_MIN = 1
# (Int) Maximum value for Window Width
WINDOW_WIDTH_MAX = 3000

# NOTE(review): AsChannelFirst is deprecated in recent MONAI releases
# (EnsureChannelFirst(channel_dim=-1) is the replacement) -- verify against the
# installed MONAI version before upgrading.

# Evaluation transforms: move the trailing channel axis first, rescale
# intensities by the 20th/80th percentiles to [0, 1], resize to SPATIAL_SIZE.
# The result is the classifier input.
eval_transforms = Compose(
    [
        AsChannelFirst(),
        ScaleIntensityRangePercentiles(
            lower=20, upper=80, b_min=0.0, b_max=1.0, clip=False, relative=True
        ),
        Resize(spatial_size=SPATIAL_SIZE),
    ]
)

# CAM transforms: channel-first + resize only (no intensity rescaling), so the
# displayed image has the same spatial size as the model input used for Grad-CAM.
cam_transforms = Compose([AsChannelFirst(), Resize(spatial_size=SPATIAL_SIZE)])

# Original transforms: channel-first only, keeping the native resolution for download.
original_transforms = Compose([AsChannelFirst()])


def image_to_bytes(image):
    """Encode a PIL image as PNG and return the raw bytes.

    Args:
        image: A ``PIL.Image.Image`` to encode.

    Returns:
        bytes: PNG-encoded image data, suitable for ``st.download_button``.
    """
    byte_stream = io.BytesIO()
    image.save(byte_stream, format="PNG")
    return byte_stream.getvalue()


set_determinism(seed=SEED)
torch.manual_seed(SEED)

# Single device used for every model and tensor in this app.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def load_model(root_dir, model_name, model_file_name):
    """Build a classifier, load its checkpoint and put it in eval mode.

    Args:
        root_dir: Directory containing the checkpoint.
        model_name: Architecture name passed to ``Build_Custom_Model``
            (ignored when ``CUSTOM_MODEL_FLAG`` is False).
        model_file_name: Checkpoint file name inside ``root_dir``.

    Returns:
        The model on ``device``, in eval mode.
    """
    if CUSTOM_MODEL_FLAG:
        model = Build_Custom_Model(model_name, NUM_CLASSES, pretrained=False).to(device)
    else:
        model = SEResNet50(spatial_dims=2, in_channels=1, num_classes=NUM_CLASSES).to(device)
    # NOTE(review): torch.load unpickles the checkpoint -- only load trusted
    # files (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(os.path.join(root_dir, model_file_name), map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _read_dicom_pixels(uploaded_file):
    """Read an uploaded DICOM into a float32 ``(rows, cols, 1)`` array.

    Reports a Streamlit error and returns ``None`` when the upload is not a
    readable DICOM or its pixel data is not a single 2D frame.
    """
    import pydicom
    from pydicom.errors import InvalidDicomError

    try:
        dicom_data = pydicom.dcmread(uploaded_file)
    except InvalidDicomError as err:
        st.error(f"Could not read the uploaded file as DICOM: {err}")
        return None

    pixels = dicom_data.pixel_array
    # The pipeline adds exactly one trailing channel axis and resizes two
    # spatial dimensions, so multi-frame or colour (3D) arrays cannot be used.
    if pixels.ndim != 2:
        st.error(
            "Expected a single-frame 2D DICOM image, "
            f"but the pixel array has shape {pixels.shape}."
        )
        return None

    # float32 for the MONAI transforms; AsChannelFirst later moves the added
    # trailing channel axis to the front.
    return pixels.astype(np.float32)[:, :, np.newaxis]


def _windowed_display_image(transforms, pixels, window_center, window_width):
    """Apply ``transforms`` to ``pixels``, then display-transform and window the result.

    Returns whatever ``DICOM_Utils.apply_windowing`` produces; it is passed
    straight to ``Image.fromarray`` (presumably a uint8 2D array -- confirm in
    Util.DICOM).
    """
    image = transforms(pixels).to(device).squeeze()  # drop the size-1 channel axis
    image = DICOM_Utils.transform_image_for_display(image)
    return DICOM_Utils.apply_windowing(image, window_center, window_width)


def _cam_overlay(model, image_tensor, windowed_image):
    """Return a Grad-CAM heat-map for ``CAM_CLASS_ID`` overlaid on ``windowed_image``.

    Args:
        model: Classifier exposing ``model.model.norm`` (the CAM target layer).
        image_tensor: Batched model input, as used for classification.
        windowed_image: 2D windowed display image; assumed to be 0-255 valued
            since it is divided by 255 below -- TODO confirm.
    """
    # show_cam_on_image needs a 3-channel float image in [0, 1].
    rgb_image = np.tile(np.expand_dims(windowed_image, axis=2), [1, 1, 3])
    rgb_image = rgb_image.astype(np.float32) / 255

    cam = GradCAM(
        model=model,
        target_layers=[model.model.norm],
        reshape_transform=reshape_transform,
        # Hard-coding use_cuda=True fails on CPU-only hosts, which this app
        # otherwise supports via `device`. NOTE(review): newer pytorch_grad_cam
        # releases drop `use_cuda` and infer the device from the model; remove
        # this argument when upgrading.
        use_cuda=device.type == "cuda",
    )
    grayscale_cam = cam(input_tensor=image_tensor, targets=[ClassifierOutputTarget(CAM_CLASS_ID)])
    return show_cam_on_image(rgb_image, grayscale_cam[0, :], use_rgb=True)


def analyze_dicom_upload(uploaded_file, model, modality, threshold, window_center, window_width):
    """Classify an uploaded DICOM slice and render the results in Streamlit.

    Writes file details, the predicted class and confidence, a PNG download
    button for the windowed full-resolution image, the windowed resized image,
    and its Grad-CAM overlay.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.
        model: Classifier returned by ``load_model``.
        modality: Label used in headings, captions and file names (e.g. "MRI").
        threshold: Probability at or above which the result is positive.
        window_center: Display window center.
        window_width: Display window width.
    """
    file_details = {
        "FileName": uploaded_file.name,
        "FileType": uploaded_file.type,
        "FileSize": uploaded_file.size,
    }
    st.write(file_details)

    pixels = _read_dicom_pixels(uploaded_file)
    if pixels is None:
        return
    st.write(f"Shape of dicom_array: {pixels.shape}")
    st.write(f"Data type of dicom_array: {pixels.dtype}")

    # Add a batch axis for the model.
    image_tensor = eval_transforms(pixels).clone().detach().unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(image_tensor).sigmoid().to("cpu").numpy()
    prob = outputs[0][0]
    is_positive = bool(prob >= threshold)

    st.header(f"{modality} Classification")
    st.subheader(f"Ischaemic Stroke : {is_positive}")
    st.subheader(f"Confidence : {prob * 100:.1f}%")

    # Full-resolution windowed image offered for download.
    download_image = _windowed_display_image(original_transforms, pixels, window_center, window_width)
    st.download_button(
        label=f"Download {modality} Image",
        data=image_to_bytes(Image.fromarray(download_image)),
        file_name=f"downloaded_{modality.lower()}_image.png",
        mime="image/png",
    )

    # Resized windowed image, displayed and reused as the CAM background.
    display_image = _windowed_display_image(cam_transforms, pixels, window_center, window_width)
    st.image(
        Image.fromarray(display_image),
        caption=f"Original {modality} Visualization",
        use_column_width=True,
    )

    visualization = _cam_overlay(model, image_tensor, display_image)
    st.image(
        Image.fromarray(visualization),
        caption=f"CAM {modality} Visualization",
        use_column_width=True,
    )


# NOTE(review): both checkpoints are loaded at module level, i.e. on every
# Streamlit script run; caching them (st.cache_resource) would need care because
# GradCAM registers hooks on the shared model.
ct_model = load_model(CT_MODEL_DIRECTORY, CT_MODEL_NAME, CT_MODEL_FILE_NAME)
mri_model = load_model(MRI_MODEL_DIRECTORY, MRI_MODEL_NAME, MRI_MODEL_FILE_NAME)

if LIST_MODEL_MODULES:
    for ct_name, _ in ct_model.named_modules():
        print(ct_name)
    for mri_name, _ in mri_model.named_modules():
        print(mri_name)

# Initialize Streamlit
st.title("Analyze")

# Sidebar controls for the display window (center/width) of each modality.
st.sidebar.header("Windowing Parameters for DICOM")
MRI_WINDOW_CENTER = st.sidebar.number_input(
    "MRI Window Center",
    min_value=WINDOW_CENTER_MIN,
    max_value=WINDOW_CENTER_MAX,
    value=DEFAULT_MRI_WINDOW_CENTER,
    step=1,
)
MRI_WINDOW_WIDTH = st.sidebar.number_input(
    "MRI Window Width",
    min_value=WINDOW_WIDTH_MIN,
    max_value=WINDOW_WIDTH_MAX,
    value=DEFAULT_MRI_WINDOW_WIDTH,
    step=1,
)
CT_WINDOW_CENTER = st.sidebar.number_input(
    "CT Window Center",
    min_value=WINDOW_CENTER_MIN,
    max_value=WINDOW_CENTER_MAX,
    value=DEFAULT_CT_WINDOW_CENTER,
    step=1,
)
CT_WINDOW_WIDTH = st.sidebar.number_input(
    "CT Window Width",
    min_value=WINDOW_WIDTH_MIN,
    max_value=WINDOW_WIDTH_MAX,
    value=DEFAULT_CT_WINDOW_WIDTH,
    step=1,
)

uploaded_mri_file = st.file_uploader("Upload a candidate MRI DICOM", type=["dcm"])
if uploaded_mri_file is not None:
    analyze_dicom_upload(
        uploaded_mri_file,
        mri_model,
        "MRI",
        MRI_INFERENCE_THRESHOLD,
        MRI_WINDOW_CENTER,
        MRI_WINDOW_WIDTH,
    )

# CT analysis is currently disabled. To enable it with the same in-memory
# pipeline used for MRI, uncomment:
# uploaded_ct_file = st.file_uploader("Upload a candidate CT DICOM", type=["dcm"])
# if uploaded_ct_file is not None:
#     analyze_dicom_upload(
#         uploaded_ct_file,
#         ct_model,
#         "CT",
#         CT_INFERENCE_THRESHOLD,
#         CT_WINDOW_CENTER,
#         CT_WINDOW_WIDTH,
#     )