This model crops the foreground from the background in CT slices. It is a lightweight model (based on timm/mobilenetv3_small_100.lamb_in1k) trained on CT examinations from the public TotalSegmentator dataset, version 2.0.1.
The following function was used to generate masks for each CT:
```python
import nibabel as nib
import numpy as np
from scipy.ndimage import binary_closing, binary_fill_holes, minimum_filter
from skimage.measure import label


def generate_mask(array):
    mask = (array > 0).astype("uint8")
    # keep only the largest connected component
    mask_label = label(mask)
    labels, counts = np.unique(mask_label, return_counts=True)
    labels, counts = labels[1:], counts[1:]  # drop background (label 0)
    max_label = labels[np.argmax(counts)]
    mask = mask_label == max_label
    # close small gaps and fill holes, slice by slice
    mask = np.stack([
        binary_fill_holes(binary_closing(mask[:, :, i]))
        for i in range(mask.shape[2])
    ], axis=2).astype("uint8")
    # lightly erode the mask edges
    mask = np.stack([
        minimum_filter(mask[:, :, i], size=3)
        for i in range(mask.shape[2])
    ], axis=2)
    return mask


array = nib.load("ct.nii.gz").get_fdata()
# apply soft tissue window
array = apply_ct_window(array, window_level=50, window_width=400)
mask = generate_mask(array)
```
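Note that `apply_ct_window` is not defined in the snippet above. A minimal sketch of a standard linear CT windowing function matching that call signature might look like the following; the implementation here is an assumption, not the exact code used:

```python
def apply_ct_window(array, window_level, window_width):
    # assumed implementation: clip HU values to the window range,
    # then rescale linearly to [0, 255]
    lower = window_level - window_width / 2
    upper = window_level + window_width / 2
    array = np.clip(array, lower, upper)
    return ((array - lower) / (upper - lower) * 255.0).astype("uint8")
```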
Bounding box coordinates were generated from the masks for individual slices. The model was then trained to predict normalized (0-1) xywh coordinates, given an individual CT slice. If the mask was empty, the coordinates were set to all zeros.
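For illustration, a normalized xywh box (top-left x, top-left y, width, height is assumed here) can be derived from a 2D binary mask along these lines; this helper is hypothetical and not part of the released code:

```python
def mask_to_xywh(mask_2d):
    # hypothetical helper: binary mask -> normalized (0-1) xywh box
    ys, xs = np.nonzero(mask_2d)
    if len(xs) == 0:
        return (0.0, 0.0, 0.0, 0.0)  # empty mask -> all zeros
    h, w = mask_2d.shape
    x1, x2 = xs.min(), xs.max() + 1
    y1, y2 = ys.min(), ys.max() + 1
    return (x1 / w, y1 / h, (x2 - x1) / w, (y2 - y1) / h)
```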
Images were converted from Hounsfield units (HU) to 4 CT windows:
- Soft tissue (level=50, width=400)
- Brain (level=40, width=80)
- Lung (level=-600, width=1500)
- Bone (level=400, width=1800)
During training, random combinations of these channels were selected. If more than one channel was selected, the selected channels were averaged to produce a single-channel image. Strong data augmentation was also applied. Thus, this model should be robust to different CT windows and combinations thereof; a sketch of the channel-averaging step is shown below.
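This sketch illustrates the random window selection and averaging described above, assuming a channels-last (H, W, 4) windowed image; it is not the exact training code:

```python
import numpy as np

def random_window_combination(img, rng=np.random.default_rng()):
    # img: (H, W, 4) array, one channel per CT window (assumed layout)
    n_channels = img.shape[-1]
    n_selected = rng.integers(1, n_channels + 1)
    selected = rng.choice(n_channels, size=n_selected, replace=False)
    # average the selected windows into a single channel
    return img[..., selected].mean(axis=-1)
```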
Example usage below:
```python
import cv2
import numpy as np
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
cropper = AutoModel.from_pretrained("ianpan/ct-crop", trust_remote_code=True).eval().to(device)

# single image
img = cv2.imread("ct_slice.png", cv2.IMREAD_GRAYSCALE)
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=None)

# expand all 4 sides by 2.5% each
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=0.025)

# expand box height by 2.5% in each direction
# and box width by 5% in each direction
buffer = (0.05, 0.025)
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=buffer)

# stack of images
img_list = ["ct_slice_1.png", "ct_slice_2.png", ...]
stack = np.stack([cv2.imread(img, cv2.IMREAD_GRAYSCALE) for img in img_list], axis=0)
cropped_stack = cropper.crop(stack, mode="3d", device=device, raw_hu=False, add_buffer=None)
```
You can also get the coordinates directly and do the cropping yourself; in that case, you must preprocess the input separately. Example below:
```python
# single image
img0 = cv2.imread("ct_slice.png", cv2.IMREAD_GRAYSCALE)
img_shapes = torch.tensor([_.shape[:2] for _ in [img0]]).to(device)
img = cropper.preprocess(img0, mode="2d")
# if multi-channel, need to convert from channels-last -> channels-first
img = torch.from_numpy(img).expand(1, 1, -1, -1).float().to(device)
with torch.inference_mode():
    coords = cropper(img, img_shape=img_shapes, add_buffer=None)
# if you do not provide img_shape, the output will be normalized (0-1) coordinates;
# otherwise, coordinates are scaled to img_shape
```
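To apply the predicted coordinates yourself, you could then crop along these lines. This is a sketch that assumes the scaled output is a pixel-space (x, y, w, h) box with (x, y) at the top-left corner, consistent with the xywh format described above:

```python
# coords: (1, 4) tensor in pixel space when img_shape is provided (assumed layout)
x, y, w, h = coords[0].round().long().tolist()
cropped = img0[y:y + h, x:x + w]
```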
The model also contains methods to load DICOM images, if you have pydicom installed:
```python
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=None)
# note: RescaleSlope and RescaleIntercept already applied in the method

# apply CT window
brain_window = (40, 80)
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=brain_window)

# or multiple windows
soft_tissue_window = (50, 400)
img = cropper.load_image_from_dicom(path_to_dicom_file,
                                    windows=[brain_window, soft_tissue_window])
# each window is a separate channel; img will be in channels-last format
```
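If you want to run such a multi-window image through the model yourself, the channels-last layout needs to become channels-first, as noted earlier. A sketch, assuming `preprocess` preserves the channels-last layout:

```python
# img: (H, W, 2) channels-last array from load_image_from_dicom (assumed shape)
x = cropper.preprocess(img, mode="2d")
# channels-last -> channels-first, then add a batch dimension
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0).float().to(device)
```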
You can also load a stack of DICOM images from a folder:
```python
dicom_folder = "/path/to/ct/head/images/"
# dicom_extension is used to filter files; default is ".dcm"
# pass "" if you do not want to filter files
# default sort is by ImagePositionPatient using automatically determined
# orientation; can also sort by InstanceNumber
# CT windows can also be applied, as above
stack = cropper.load_stack_from_dicom_folder(dicom_folder,
                                             windows=None,
                                             dicom_extension=".dcm",
                                             sort_by_instance_number=False)

# raw Hounsfield units can be passed directly to the cropper
cropped_stack = cropper.crop(stack, mode="3d", device=device, raw_hu=True)
```
By default, the cropper will not remove slices from a stack, even if they are predicted to be empty. You can enable this by specifying `remove_empty_slices=True`, which will also return the indices (in the original input) of the removed empty slices.

```python
cropped_stack, empty_slice_indices = cropper.crop(stack, mode="3d", remove_empty_slices=True)
```
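If you need to map the cropped slices back to the original stack, the returned indices can be used like this (a sketch, assuming `empty_slice_indices` is an iterable of integer slice indices):

```python
removed = set(empty_slice_indices)
# indices of the slices that were kept, relative to the original stack
kept_indices = [i for i in range(stack.shape[0]) if i not in removed]
```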