# NOTE: the lines below were HuggingFace Spaces banner text captured by the
# scraper ("Spaces: Running on Zero", repeated) — kept as a comment so the
# module remains importable.
from monai.transforms import Transform | |
import torch | |
import skimage | |
import torch | |
import SimpleITK as sitk | |
import numpy as np | |
from PIL import Image | |
from io import BytesIO | |
import matplotlib.pyplot as plt | |
import SimpleITK as sitk | |
from matplotlib.colors import ListedColormap | |
import base64 | |
import numpy as np | |
from cv2 import dilate | |
from scipy.ndimage import label | |
def image_to_base64(image_path):
    """Read the file at *image_path* and return its contents as a base64 string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode('utf-8')
class CustomCLAHE(Transform):
    """Contrast-Limited Adaptive Histogram Equalization (CLAHE) as a custom MONAI
    transform, as described by Qiu et al.

    Pipeline: rescale to [0, 1] -> median filter -> subtract a weighted Gaussian
    blur (unsharp-mask style contour enhancement) -> rescale -> CLAHE.

    Attributes:
        p1 (float): Weighting factor; determines degree of contour enhancement. Default is 0.6.
        p2 (None or int): Kernel size for the adaptive histogram. Default is None.
        p3 (float): Clip limit for histogram equalization. Default is 0.01.
    """

    def __init__(self, p1=0.6, p2=None, p3=0.01):
        # Bug fix: the base Transform initializer was never invoked.
        super().__init__()
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3

    def __call__(self, data):
        """Apply the CLAHE algorithm to input data.

        Args:
            data (Union[dict, torch.Tensor]): Input data. Either a dictionary
                holding the image under the "image" key, or the image itself.
                NOTE(review): the code calls ``.numpy()``, so the image must be
                a torch.Tensor, not a plain ndarray — confirm against callers.

        Returns:
            Union[dict, torch.Tensor]: The dictionary with its "image" replaced
            by the transformed tensor, or the transformed tensor itself,
            matching the input container type.
        """
        if isinstance(data, dict):
            im = data["image"]
        else:
            im = data
        im = im.numpy()
        # Keep only the first channel, then restore a leading channel axis.
        im = im[0]
        im = im[None, :, :]
        im = skimage.exposure.rescale_intensity(im, in_range="image", out_range=(0, 1))
        # Median filtering suppresses noise before the sharpening step.
        im_noi = skimage.filters.median(im)
        im_fil = im_noi - self.p1 * skimage.filters.gaussian(im_noi, sigma=1)
        im_fil = skimage.exposure.rescale_intensity(im_fil, in_range="image", out_range=(0, 1))
        im_ce = skimage.exposure.equalize_adapthist(im_fil, kernel_size=self.p2, clip_limit=self.p3)
        if isinstance(data, dict):
            data["image"] = torch.Tensor(im_ce)
        else:
            data = torch.Tensor(im_ce)
        return data
def custom_colormap():
    """Build a 4-class RGBA colormap for mask overlays.

    Class 0 is fully transparent (background); classes 1-3 are green, red and
    yellow, each drawn at 50% opacity.
    """
    rgba_entries = [
        (0, 0, 0, 0),    # class 0: fully transparent background
        (0, 1, 0, 0.5),  # class 1: green, 50% alpha
        (1, 0, 0, 0.5),  # class 2: red, 50% alpha
        (1, 1, 0, 0.5),  # class 3: yellow, 50% alpha
    ]
    return ListedColormap(rgba_entries)
def read_image(image_path):
    """Load an image from disk as a squeezed numpy array.

    Tries PIL first (converted to 8-bit grayscale); on failure falls back to
    SimpleITK. Returns None when both readers fail.
    """
    try:
        pil_image = Image.open(image_path).convert('L')
        return np.array(pil_image).squeeze()
    except Exception:
        pass
    try:
        itk_image = sitk.ReadImage(image_path)
        return sitk.GetArrayFromImage(itk_image).squeeze()
    except Exception as err:
        print("Failed Loading the Image: ", err)
        return None
def overlay_mask(image_path, image_mask):
    """Render *image_mask* on top of the image stored at *image_path*.

    Args:
        image_path (str): Path of the background image; loaded via read_image().
        image_mask (np.ndarray): Label mask rendered with custom_colormap().

    Returns:
        tuple: (overlay_image_np, original_image_np) — the rendered overlay as
        a numpy array decoded from the PNG buffer, and the grayscale source
        image as uint8.
    """
    original_image_np = read_image(image_path).squeeze().astype(np.uint8)
    # adjust mask intensities for display
    image_mask_disp = image_mask
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(original_image_np, cmap='gray')
    plt.imshow(image_mask_disp, cmap=custom_colormap(), alpha=0.5)
    plt.axis('off')
    # Save the overlay to an in-memory buffer
    buffer = BytesIO()
    plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    # Bug fix: close the figure so repeated calls do not accumulate open
    # matplotlib figures (memory leak / "too many open figures" warning).
    plt.close(fig)
    buffer.seek(0)
    overlay_image_np = np.array(Image.open(buffer))
    return overlay_image_np, original_image_np
def bounding_box_mask(image, label):
    """Generate a binary bounding-box mask around the labeled region of *label*.

    Args:
        image (SimpleITK.Image): The input image (used only for its spacing).
        label (SimpleITK.Image): The labeled image containing the region of interest.

    Returns:
        SimpleITK.Image: Binary mask with the same spacing as the input image;
        pixels inside the axis-aligned bounding box of all nonzero label pixels
        are 1, all others 0.

    Note:
        Assumes both inputs are SimpleITK.Image objects. The previous version
        also converted *image* to an array that was never used; that dead work
        has been removed.
    """
    # get original spacing
    original_spacing = image.GetSpacing()
    # convert label to a 2-D array
    label_array = np.squeeze(sitk.GetArrayFromImage(label))
    # row/column extents of the nonzero region
    nonzero_rows = np.nonzero(np.any(label_array != 0, axis=1))[0]
    nonzero_cols = np.nonzero(np.any(label_array != 0, axis=0))[0]
    row_start, row_stop = nonzero_rows[0], np.max(nonzero_rows)
    col_start, col_stop = nonzero_cols[0], np.max(nonzero_cols)
    # fill the box with 1s; pixels outside the box are zero by construction,
    # since the box covers every nonzero label pixel
    bounding_box_array = label_array.copy()
    bounding_box_array[row_start:row_stop + 1, col_start:col_stop + 1] = 1
    # add channel dimension
    bounding_box_array = bounding_box_array[None, ...].astype(np.uint8)
    # get Image from array mask and apply original spacing
    bounding_box_image = sitk.GetImageFromArray(bounding_box_array)
    bounding_box_image.SetSpacing(original_spacing)
    return bounding_box_image
def threshold_based_crop(image):
    """Crop *image* to the foreground's axis-aligned bounding box.

    Otsu's threshold estimator separates background and foreground — in medical
    imaging the background is usually air — and the image is then cropped to
    the foreground's bounding box.

    Args:
        image (SimpleITK image): An image whose anatomy and background
            intensities form a bi-modal distribution (the assumption
            underlying Otsu's method).

    Return:
        Cropped image based on the foreground's axis-aligned bounding box.
    """
    inside_value = 0
    outside_value = 255
    stats_filter = sitk.LabelShapeStatisticsImageFilter()
    # uncomment for debugging
    #sitk.WriteImage(image, "./image.png")
    stats_filter.Execute(sitk.OtsuThreshold(image, inside_value, outside_value))
    bounding_box = stats_filter.GetBoundingBox(outside_value)
    # SimpleITK packs the box as (start indices..., sizes...): split in half.
    half = int(len(bounding_box) / 2)
    return sitk.RegionOfInterest(image, bounding_box[half:], bounding_box[:half])
def creat_SIJ_mask(image, input_label):
    """Create a mask for the sacroiliac joints (SIJ) from a pelvis/sacrum segmentation mask.

    The sacrum label is dilated by 5% of its horizontal extent; its overlap with
    each pelvis label gives the left/right SIJ regions. When a side has no
    overlap, an approximate band derived from that pelvis label's median column
    height is used instead.

    Args:
        image (SimpleITK.Image): X-ray image (used only for its spacing).
        input_label (SimpleITK.Image): Segmentation mask with labels for the
            sacrum (1), left pelvis (2) and right pelvis (3).

    Returns:
        SimpleITK.Image: Mask of the SIJ with the same spacing as *image*.
    """
    original_spacing = image.GetSpacing()
    # uncomment for debugging
    #sitk.WriteImage(input_label, "./input_label.png")
    mask_array = sitk.GetArrayFromImage(input_label).squeeze()
    sacrum_value = 1
    left_pelvis_value = 2
    right_pelvis_value = 3
    sacrum_mask = (mask_array == sacrum_value)
    # Dilation extent is 5% of the sacrum's horizontal width.
    nonzero_cols = np.nonzero(np.any(sacrum_mask != 0, axis=0))[0]
    box_width = np.max(nonzero_cols) - nonzero_cols[0]
    dilation_extent = int(np.round(0.05 * box_width))
    dilated_sacrum_mask = dilate_mask(sacrum_mask, dilation_extent)
    intersection_left = (dilated_sacrum_mask & (mask_array == left_pelvis_value))
    if np.all(intersection_left == 0):
        print("Warning: No left intersection")
        intersection_left = create_median_height_array(mask_array == left_pelvis_value)
        if intersection_left is None:
            # Bug fix: the fallback returns None for an empty pelvis label,
            # which previously crashed keep_largest_component().
            intersection_left = np.zeros_like(mask_array, dtype=int)
    intersection_left = keep_largest_component(intersection_left)
    intersection_right = (dilated_sacrum_mask & (mask_array == right_pelvis_value))
    if np.all(intersection_right == 0):
        print("Warning: No right intersection")
        intersection_right = create_median_height_array(mask_array == right_pelvis_value)
        if intersection_right is None:
            # Same guard for the right side.
            intersection_right = np.zeros_like(mask_array, dtype=int)
    intersection_right = keep_largest_component(intersection_right)
    # Combine both sides and restore a channel dimension.
    intersection_mask = intersection_left + intersection_right
    intersection_mask = intersection_mask[None, ...]
    instersection_mask_im = sitk.GetImageFromArray(intersection_mask)
    instersection_mask_im.SetSpacing(original_spacing)
    return instersection_mask_im
def dilate_mask(mask, extent):
    """Dilate a binary mask with a square structuring element.

    (Doc fix: the previous docstring was copied from keep_largest_component()
    and described connected-component filtering, which this function does not do.)

    Args:
        mask (numpy.ndarray): Binary segmentation mask, with 1s indicating the
            label and 0s indicating the background.
        extent (int): Dilation radius; the kernel is (2*extent+1) x (2*extent+1).
            An extent of 0 yields a 1x1 kernel, i.e. no dilation.

    Returns:
        numpy.ndarray: The dilated mask as uint8.
    """
    mask_uint8 = mask.astype(np.uint8)
    kernel = np.ones((2*extent+1, 2*extent+1), np.uint8)
    dilated_mask = dilate(mask_uint8, kernel, iterations=1)
    return dilated_mask
def mask_and_crop(image, input_label):
    """Performs masking and cropping operations on an image and its label.

    Args:
        image (SimpleITK.Image): The image to be processed.
        input_label (SimpleITK.Image): The corresponding label image.

    Returns:
        tuple: A tuple containing two SimpleITK.Image objects.
            - cropped_boxed_image: The image after bounding-box masking and cropping.
            - mask: The binary mask corresponding to the label after cropping.

    Note:
        Relies on creat_SIJ_mask(), bounding_box_mask() and threshold_based_crop().
    """
    sij_label = creat_SIJ_mask(image, input_label)
    box_mask = bounding_box_mask(image, sij_label)
    boxed_image = sitk.Mask(image, box_mask, maskingValue=0, outsideValue=0)
    masked_image = sitk.Mask(image, sij_label, maskingValue=0, outsideValue=0)
    cropped_boxed_image = threshold_based_crop(boxed_image)
    cropped_masked_image = threshold_based_crop(masked_image)
    # Binarize the cropped mask and restore a channel dimension.
    mask_array = np.squeeze(sitk.GetArrayFromImage(cropped_masked_image))
    mask_array = np.where(mask_array > 0, 1, 0)
    mask = sitk.GetImageFromArray(mask_array[None, ...])
    return cropped_boxed_image, mask
def create_median_height_array(mask):
    """Build a replacement mask from the median column height of *mask*.

    For every column containing nonzero pixels, the height (span between first
    and last nonzero row, inclusive) and starting row are recorded. Columns
    whose height equals the median height are then painted as bands (10 columns
    wide, median height tall) into a fresh array.

    Args:
        mask (numpy.ndarray): A binary mask with 1s representing the label and
            0s the background.

    Returns:
        numpy.ndarray or None: A new binary mask, or None when *mask* has no
        nonzero columns.

    Note:
        Only used when pelvis and sacrum do not intersect; serves as an
        approximate stand-in for the SIJ mask.
    """
    rows, cols = mask.shape
    column_details = []
    for col in range(cols):
        nz_rows = np.nonzero(mask[:, col])[0]
        if nz_rows.size:
            span = nz_rows[-1] - nz_rows[0] + 1
            column_details.append((span, nz_rows[0], col))
    if not column_details:
        return None
    median_height = round(np.median([entry[0] for entry in column_details]))
    new_array = np.zeros_like(mask, dtype=int)
    for height, start_row, col in column_details:
        if height != median_height:
            continue
        band_lo = max(0, col - 5)
        band_hi = min(cols, col + 5)
        new_array[start_row:start_row + median_height, band_lo:band_hi] = 1
    return new_array
def keep_largest_component(mask):
    """Retain only the largest connected component of a binary mask.

    Args:
        mask (numpy.ndarray): A binary mask with 1s representing the label and
            0s the background.

    Returns:
        numpy.ndarray: Mask containing only the largest connected component;
        the original mask object when it has at most one component.
    """
    labeled_array, num_features = label(mask)
    # Nothing to prune with zero or one component.
    if num_features <= 1:
        return mask
    # Component sizes, skipping the background bin (label 0).
    component_sizes = np.bincount(labeled_array.flat)[1:]
    largest_label = int(np.argmax(component_sizes)) + 1
    return (labeled_array == largest_label).astype(mask.dtype)