# Spaces: Running on Zero
import torch | |
import numpy as np | |
from monai.transforms import Compose, LoadImage, EnsureChannelFirst, Lambda, Resize, NormalizeIntensity, GaussianSmooth, ScaleIntensity, AsDiscrete, KeepLargestConnectedComponent, Invert, Rotate90, SaveImage, Transform | |
from monai.inferers import SlidingWindowInferer | |
from monai.networks.nets import UNet | |
class RgbaToGrayscale(Transform):
    """Collapse a channel-first RGB/RGBA/grayscale image to a single channel.

    The input must be a tensor that is 3D ``(C, X, Y)``, or 4D with a trailing
    singleton axis ``(C, X, Y, 1)`` that is dropped first. ``C`` must be 1, 3 or 4:

    * ``C == 1``: returned unchanged.
    * ``C == 3`` or ``C == 4``: the first three channels are combined with the
      weights (0.2989, 0.5870, 0.1140) (the BT.601 luma coefficients, as used by
      MATLAB's ``rgb2gray``). A fourth (alpha) channel is ignored. The result
      keeps a leading channel axis of size 1.

    Raises:
        ValueError: if the (squeezed) input is not 3D, or ``C`` is not 1, 3 or 4.
    """

    # Per-channel weights applied to R, G, B.
    _RGB_WEIGHTS = (0.2989, 0.5870, 0.1140)

    def __call__(self, x):
        """Return ``x`` reduced to shape ``(1, X, Y)`` (see class docstring)."""
        # Only drop a trailing singleton axis from a 4D input. A 3D image whose last
        # spatial axis happens to be 1 must not lose that axis.
        if x.ndim == 4 and x.shape[-1] == 1:
            x = x.squeeze(-1)
        if x.ndim != 3:
            raise ValueError(f"Input tensor must be 3D. Shape: {x.shape}")

        num_channels = x.shape[0]
        if num_channels == 1:  # Already grayscale
            return x
        if num_channels not in (3, 4):
            raise ValueError(f"Unsupported channel number: {x.shape[0]}")

        # einsum needs both operands in the same floating dtype. Integer images are
        # promoted to float32, and float64 images keep float64 weights.
        if not x.is_floating_point():
            x = x.float()
        weights = torch.tensor(self._RGB_WEIGHTS, dtype=x.dtype, device=x.device)
        # Weighted sum over the first three channels (any alpha channel is skipped).
        # unsqueeze restores the channel axis.
        return torch.einsum('cwh,c->wh', x[:3], weights).unsqueeze(0)

    def inverse(self, x):
        """Return ``x`` unchanged; the grayscale reduction is not reversed.

        NOTE(review): this class derives from ``Transform`` rather than
        ``InvertibleTransform``, so MONAI's ``Compose``/``Invert`` machinery may
        never call this method. Confirm before relying on it.
        """
        return x
# 2D UNet: 1 input channel (grayscale), 4 output channels (one per class).
model = UNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=4,
    channels=[64, 128, 256, 512],
    strides=[2, 2, 2],
    num_res_units=3,
)
checkpoint_path = 'segmentation_model.pt'
# Load onto CPU first; the inference function moves the model to its target device.
# NOTE(review): torch.load unpickles arbitrary objects. Only load trusted files.
# weights_only=True (torch >= 1.13) is safer if the checkpoint holds only tensors.
checkpoint = torch.load(checkpoint_path, map_location='cpu')


def _verify_state_dict_keys(net, state_dict):
    """Raise RuntimeError naming the keys that differ between ``net`` and ``state_dict``.

    An explicit check is used instead of ``assert``, because asserts are
    stripped when Python runs with ``-O``.
    """
    expected = net.state_dict().keys()
    found = state_dict.keys()
    missing = sorted(expected - found)
    unexpected = sorted(found - expected)
    if missing or unexpected:
        raise RuntimeError(
            "Model and checkpoint keys do not match "
            f"(missing: {missing}, unexpected: {unexpected})"
        )


# The weights are stored under the 'network' key of the checkpoint dict.
_verify_state_dict_keys(model, checkpoint['network'])
model.load_state_dict(checkpoint['network'])
model.eval()  # inference mode for layers like dropout/normalization
# Preprocessing pipeline: image file path -> single-channel 768x768 tensor
# with intensities scaled to [-1, 1]. Steps run in list order.
pre_transforms = Compose(
    transforms=[
        # Read the file; image_only=True returns just the image (no separate meta dict).
        LoadImage(image_only=True),
        # Put the channel axis first.
        EnsureChannelFirst(),
        # Reduce RGB/RGBA to one channel (grayscale input passes through).
        RgbaToGrayscale(),
        # Resample the spatial axes to a fixed 768x768 grid.
        Resize(spatial_size=(768, 768)),
        # Drop a trailing axis if it has size 1 (torch squeeze is a no-op otherwise).
        Lambda(func=lambda img: img.squeeze(-1)),
        # Intensity normalization with MONAI defaults.
        NormalizeIntensity(),
        # Very light Gaussian smoothing.
        GaussianSmooth(sigma=0.1),
        # Final linear rescale into [-1, 1].
        ScaleIntensity(minv=-1, maxv=1),
    ]
)
# Postprocessing pipeline: 4-channel network output -> cleaned label map in the
# geometry of the original input image. Steps run in list order.
post_transforms = Compose(
    transforms=[
        # Argmax over channels, then re-encode as 4-channel one-hot.
        AsDiscrete(argmax=True, to_onehot=4),
        # Keep only the largest connected component (MONAI default label handling).
        KeepLargestConnectedComponent(),
        # Collapse the one-hot channels back to a single label map.
        AsDiscrete(argmax=True),
        # Undo the invertible steps of pre_transforms (e.g. Resize) recorded on the tensor.
        Invert(pre_transforms),
        # To also write the mask to disk, a step such as
        # SaveImage(output_dir='./', output_postfix='seg', output_ext='.nii', resample=False)
        # can be appended here.
    ]
)
def load_and_segment_image(input_image_path, device):
    """Segment one image file and return its label map.

    Args:
        input_image_path: Path handed to the module-level ``pre_transforms``.
        device: Torch device (or device string) to run inference on. As a side
            effect, the module-level ``model`` is moved to this device.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with per-pixel class labels. Every
        size-1 axis is squeezed away.
    """
    batch = pre_transforms(input_image_path)
    batch = batch.unsqueeze(0).to(device)  # prepend a batch axis for the inferer

    # Tile the input with 512x512 windows (75% overlap), 16 windows per forward pass.
    window_inferer = SlidingWindowInferer(roi_size=(512, 512), sw_batch_size=16, overlap=0.75)
    with torch.no_grad():
        net = model.to(device)
        logits = window_inferer(batch, net)
        logits = logits.squeeze(0)  # remove the batch axis again

    labels = post_transforms(logits)

    # Rotate by k=3 quarter turns in the plane of spatial axes (0, 1).
    # NOTE(review): if the image reader returned the 2D image transposed (reversed
    # spatial indexing), a k=3 rotation of a transpose gives a left-right mirror of
    # the original rather than the original orientation. Confirm against the caller.
    quarter_turns = Rotate90(spatial_axes=(0, 1), k=3)
    labels = quarter_turns(labels).to('cpu')

    label_array = labels.squeeze().detach().numpy()
    return label_array.astype(np.uint8)