|
from monai.transforms import Transform, Compose, LoadImage, EnsureChannelFirst |
|
import torch |
|
import skimage |
|
import torch |
|
import SimpleITK as sitk |
|
import numpy as np |
|
from PIL import Image |
|
from io import BytesIO |
|
import matplotlib.pyplot as plt |
|
import SimpleITK as sitk |
|
from matplotlib.colors import ListedColormap |
|
import base64 |
|
import numpy as np |
|
from cv2 import dilate |
|
from scipy.ndimage import label |
|
from Model_Seg import RgbaToGrayscale |
|
|
|
def image_to_base64(image_path):
    """Return the content of a file as a base64-encoded string.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        str: Base64 encoding of the raw file bytes, decoded as UTF-8 text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
|
|
|
class CustomCLAHE(Transform):
    """Contrast enhancement built around CLAHE, as described by Qiu et al.

    The image is rescaled to [0, 1], median filtered, sharpened by subtracting
    a weighted Gaussian-blurred copy, rescaled again and finally passed through
    Contrast-Limited Adaptive Histogram Equalization (CLAHE).

    Attributes:
        p1 (float): Weighting factor; determines the degree of contour enhancement. Default is 0.6.
        p2 (None or int): Kernel size for the adaptive histogram. Default is None.
        p3 (float): Clip limit for histogram equalization. Default is 0.01.
    """

    def __init__(self, p1=0.6, p2=None, p3=0.01):
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3

    def __call__(self, data):
        """Apply the enhancement pipeline to the input data.

        Args:
            data (Union[dict, torch.Tensor, np.ndarray]): Either the image itself or a
                dictionary holding it under the key ``"image"``. Only the first entry
                along the leading axis is processed.

        Returns:
            Union[dict, torch.Tensor]: For dictionary input, the same dictionary with
            ``"image"`` replaced by the enhanced tensor; otherwise the enhanced tensor.
        """
        is_dict = isinstance(data, dict)
        im = data["image"] if is_dict else data

        # Tensors are converted through .numpy(); plain arrays (which the
        # documented contract allows) have no such method and are used directly.
        im = im.numpy() if hasattr(im, "numpy") else np.asarray(im)

        # Keep only the first entry of the leading axis, then restore a
        # leading axis of length one.
        im = im[0]
        im = im[None, :, :]

        im = skimage.exposure.rescale_intensity(im, in_range="image", out_range=(0, 1))
        im_noi = skimage.filters.median(im)
        # Subtracting a weighted blurred copy emphasises contours.
        im_fil = im_noi - self.p1 * skimage.filters.gaussian(im_noi, sigma=1)
        im_fil = skimage.exposure.rescale_intensity(im_fil, in_range="image", out_range=(0, 1))
        im_ce = skimage.exposure.equalize_adapthist(im_fil, kernel_size=self.p2, clip_limit=self.p3)

        if is_dict:
            data["image"] = torch.Tensor(im_ce)
        else:
            data = torch.Tensor(im_ce)

        return data
|
|
|
|
|
|
|
def custom_colormap():
    """Build the four-entry RGBA colormap used when drawing a mask overlay.

    Returns:
        ListedColormap: Colormap whose first entry is fully transparent and whose
        remaining three entries are green, red and yellow at 50 % opacity.
    """
    rgba_entries = [
        (0, 0, 0, 0),    # fully transparent
        (0, 1, 0, 0.5),  # green
        (1, 0, 0, 0.5),  # red
        (1, 1, 0, 0.5),  # yellow
    ]
    return ListedColormap(rgba_entries)
|
|
|
def read_image(image_path):
    """Load an image file into a squeezed numpy array.

    The MONAI pipeline (load, channel-first, RGBA-to-grayscale) is tried first and
    its result is cast to ``uint8``. If that raises, the file is read with
    SimpleITK instead, keeping whatever dtype SimpleITK returns.

    Args:
        image_path: Path of the image file.

    Returns:
        numpy.ndarray or None: The squeezed pixel array, or None when both
        loaders fail (the second failure is printed).
    """
    monai_loader = Compose([
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        RgbaToGrayscale(),
    ])

    try:
        loaded = monai_loader(image_path)
        return loaded.numpy().astype(np.uint8).squeeze()
    except Exception:
        # Deliberately ignored: SimpleITK is tried next.
        pass

    try:
        sitk_image = sitk.ReadImage(image_path)
        return sitk.GetArrayFromImage(sitk_image).squeeze()
    except Exception as err:
        print("Failed Loading the Image: ", err)
        return None
|
|
|
def overlay_mask(image_path, image_mask):
    """Render a segmentation mask on top of an image and return the rendering.

    Args:
        image_path: Path of the image, loaded through ``read_image``.
        image_mask: Mask array drawn over the image with ``custom_colormap``.

    Returns:
        tuple: ``(overlay_image_np, original_image_np)`` where the first item is the
        rendered PNG decoded into a numpy array and the second is the loaded image
        as ``uint8``.

    Raises:
        AttributeError: If ``read_image`` could not load the file and returned None.
    """
    original_image_np = read_image(image_path).squeeze().astype(np.uint8)

    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(original_image_np, cmap='gray')
        plt.imshow(image_mask, cmap=custom_colormap(), alpha=0.5)
        plt.axis('off')

        buffer = BytesIO()
        plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call would leak one figure.
        plt.close(fig)

    buffer.seek(0)
    overlay_image_np = np.array(Image.open(buffer))
    return overlay_image_np, original_image_np
|
|
|
|
|
def bounding_box_mask(image, label):
    """Generate a bounding-box mask around the labeled region of an image.

    Args:
        image (SimpleITK.Image): The input image; only its spacing is used.
        label (SimpleITK.Image): The labeled image containing the region of interest.
            Assumed to be two-dimensional once singleton axes are squeezed.

    Returns:
        SimpleITK.Image: A ``uint8`` image with a leading singleton axis in which
        pixels inside the axis-aligned bounding box of all non-zero label pixels
        are 1 and all others are 0, carrying the spacing of ``image``.

    Raises:
        IndexError: If ``label`` contains no non-zero pixel.
    """
    original_spacing = image.GetSpacing()

    label_array = np.squeeze(sitk.GetArrayFromImage(label))

    # Rows / columns that contain at least one labeled pixel, in ascending order.
    foreground = label_array != 0
    occupied_rows = np.flatnonzero(np.any(foreground, axis=1))
    occupied_cols = np.flatnonzero(np.any(foreground, axis=0))

    top, bottom = occupied_rows[0], occupied_rows[-1]
    left, right = occupied_cols[0], occupied_cols[-1]

    # Everything outside the box is background by construction, so the box can
    # be painted onto an empty array.
    bounding_box_array = np.zeros(label_array.shape, dtype=np.uint8)
    bounding_box_array[top:bottom + 1, left:right + 1] = 1

    bounding_box_image = sitk.GetImageFromArray(bounding_box_array[None, ...])
    bounding_box_image.SetSpacing(original_spacing)
    return bounding_box_image
|
|
|
|
|
def threshold_based_crop(image):
    """Crop an image to the bounding box of its Otsu foreground.

    Otsu's threshold estimator separates background from foreground (in medical
    imaging the background is usually air); the image is then cropped to the
    foreground's axis-aligned bounding box.

    Args:
        image (SimpleITK image): An image where the anatomy and background
            intensities form a bi-modal distribution (the assumption underlying
            Otsu's method).

    Return:
        Cropped image based on the foreground's axis-aligned bounding box.
    """
    inside_value, outside_value = 0, 255

    otsu_mask = sitk.OtsuThreshold(image, inside_value, outside_value)
    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(otsu_mask)

    # The first half of the bounding box is passed as the start index and the
    # second half as the size of the region of interest.
    bbox = shape_stats.GetBoundingBox(outside_value)
    half = int(len(bbox) / 2)
    roi_size = bbox[half:]
    roi_start = bbox[0:half]
    return sitk.RegionOfInterest(image, roi_size, roi_start)
|
|
|
def _sij_side_mask(dilated_sacrum_mask, mask_array, pelvis_value, side):
    """Return the SIJ region of one pelvis side as a single connected component."""
    pelvis_mask = mask_array == pelvis_value
    side_mask = dilated_sacrum_mask & pelvis_mask

    if not side_mask.any():
        print(f"Warning: No {side} intersection")
        # No overlap with the dilated sacrum: approximate the joint from the
        # pelvis mask alone.
        side_mask = create_median_height_array(pelvis_mask)
        if side_mask is None:
            # The pelvis label is absent altogether, so this side contributes
            # an empty mask instead of crashing in the component labelling.
            side_mask = np.zeros_like(pelvis_mask, dtype=int)

    return keep_largest_component(side_mask)


def creat_SIJ_mask(image, input_label):
    """Create a mask of the sacroiliac joints (SIJ) from a pelvis and sacrum segmentation.

    The sacrum mask is dilated by 5 % of its column extent; the overlap of the
    dilated sacrum with each pelvis side forms that side's joint region, reduced
    to its largest connected component. A side without any overlap falls back to
    ``create_median_height_array`` and prints a warning.

    Args:
        image (SimpleITK.Image): X-ray image; only its spacing is used.
        input_label (SimpleITK.Image): Segmentation mask with labels 1 (sacrum),
            2 (left pelvis) and 3 (right pelvis).

    Returns:
        SimpleITK.Image: Mask of the SIJ with a leading singleton axis and the
        spacing of ``image``.

    Raises:
        IndexError: If the segmentation contains no sacrum pixel.
    """
    original_spacing = image.GetSpacing()
    mask_array = sitk.GetArrayFromImage(input_label).squeeze()

    sacrum_value = 1
    left_pelvis_value = 2
    right_pelvis_value = 3

    sacrum_mask = mask_array == sacrum_value

    # Column extent of the sacrum determines how far it is dilated.
    sacrum_columns = np.flatnonzero(np.any(sacrum_mask, axis=0))
    box_width = sacrum_columns[-1] - sacrum_columns[0]
    dilation_extent = int(np.round(0.05 * box_width))
    dilated_sacrum_mask = dilate_mask(sacrum_mask, dilation_extent)

    intersection_left = _sij_side_mask(
        dilated_sacrum_mask, mask_array, left_pelvis_value, "left"
    )
    intersection_right = _sij_side_mask(
        dilated_sacrum_mask, mask_array, right_pelvis_value, "right"
    )

    intersection_mask = intersection_left + intersection_right
    intersection_mask = intersection_mask[None, ...]

    intersection_mask_im = sitk.GetImageFromArray(intersection_mask)
    intersection_mask_im.SetSpacing(original_spacing)
    return intersection_mask_im
|
|
|
def dilate_mask(mask, extent):
    """Dilate a binary mask with a square structuring element.

    Args:
        mask (numpy.ndarray): Binary segmentation mask; it is cast to ``uint8``
            before dilation.
        extent (int): Half-size of the kernel; the kernel is a square of ones
            with side ``2 * extent + 1``.

    Returns:
        numpy.ndarray: The mask after a single dilation pass.
    """
    side = 2 * extent + 1
    structuring_element = np.ones((side, side), np.uint8)
    return dilate(mask.astype(np.uint8), structuring_element, iterations=1)
|
|
|
def mask_and_crop(image, input_label):
    """Mask an image with its SIJ region and crop it to the foreground.

    Args:
        image (SimpleITK.Image): The image to be processed.
        input_label (SimpleITK.Image): The corresponding segmentation label image.

    Returns:
        tuple: A tuple containing two SimpleITK.Image objects.
            - cropped_boxed_image: The image after bounding-box masking and cropping.
            - mask: Binary (0/1) mask obtained from the cropped label-masked image.

    Note:
        This function relies on creat_SIJ_mask(), bounding_box_mask() and
        threshold_based_crop().
    """
    sij_label = creat_SIJ_mask(image, input_label)
    box_mask = bounding_box_mask(image, sij_label)

    # Blank out everything outside the bounding box / outside the joint mask.
    boxed_image = sitk.Mask(image, box_mask, maskingValue=0, outsideValue=0)
    masked_image = sitk.Mask(image, sij_label, maskingValue=0, outsideValue=0)

    cropped_boxed_image = threshold_based_crop(boxed_image)
    cropped_masked_image = threshold_based_crop(masked_image)

    # Any remaining positive intensity counts as mask foreground.
    cropped_pixels = np.squeeze(sitk.GetArrayFromImage(cropped_masked_image))
    binary_pixels = np.where(cropped_pixels > 0, 1, 0)
    mask = sitk.GetImageFromArray(binary_pixels[None, ...])
    return cropped_boxed_image, mask
|
|
|
def create_median_height_array(mask, half_width=5):
    """Create a mask from the columns whose filled height is closest to the median.

    For every column containing non-zero elements, the height is the span from the
    first to the last non-zero row. Around each column whose height is closest to
    the rounded median height, a band of median height is filled, starting at that
    column's first non-zero row.

    Args:
        mask (numpy.ndarray): A 2-D binary mask with 1s representing the label and
            0s the background.
        half_width (int): Number of columns the band extends to the left of a
            selected column; to the right it extends ``half_width - 1`` columns.
            Default is 5.

    Returns:
        numpy.ndarray: A new binary mask of integer dtype with the bands set to 1,
        or None if the input mask has no non-zero columns.

    Note:
        This function is only used when there is no intersection between pelvis and
        sacrum, and creates an alternative SIJ mask that serves as an approximate
        replacement. When no column matches the rounded median exactly (e.g. the
        median falls between two heights), the nearest heights are used so that the
        result is never empty.
    """
    occupied = np.asarray(mask) != 0
    rows, cols = occupied.shape

    filled_cols = np.flatnonzero(occupied.any(axis=0))
    if filled_cols.size == 0:
        return None

    # First / last non-zero row of every filled column.
    first_rows = occupied.argmax(axis=0)[filled_cols]
    last_rows = rows - 1 - occupied[::-1].argmax(axis=0)[filled_cols]
    heights = last_rows - first_rows + 1

    median_height = int(round(float(np.median(heights))))

    # Exact matches have distance zero, so they are still the only columns
    # selected whenever they exist.
    distance = np.abs(heights - median_height)
    selected = distance == distance.min()

    new_array = np.zeros_like(mask, dtype=int)
    for col, start_idx in zip(filled_cols[selected], first_rows[selected]):
        start_col = max(0, col - half_width)
        end_col = min(cols, col + half_width)
        new_array[start_idx:start_idx + median_height, start_col:end_col] = 1
    return new_array
|
|
|
def keep_largest_component(mask):
    """Retain only the largest connected component of a binary segmentation mask.

    Args:
        mask (numpy.ndarray): A binary mask with 1s representing the label and 0s
            the background.

    Returns:
        numpy.ndarray: The input mask itself when it holds at most one component;
        otherwise a new array of the same dtype in which only the largest
        component (the first one on ties) is set.
    """
    components, component_count = label(mask)

    # Nothing to prune with zero or one component.
    if component_count < 2:
        return mask

    sizes = np.bincount(components.ravel())
    # Index 0 counts the background; zero it so it can never win.
    sizes[0] = 0
    winner = np.argmax(sizes)

    return (components == winner).astype(mask.dtype)