This model crops CT slices to the foreground, removing the surrounding background. It is a lightweight mobilenetv3_small_100 model trained on CT examinations from the public TotalSegmentator dataset, version 2.0.1.

The following function was used to generate masks for each CT:

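```python
import nibabel as nib
import numpy as np
from scipy.ndimage import binary_closing, binary_fill_holes, minimum_filter
from skimage.measure import label


def apply_ct_window(array, window_level, window_width):
    # Hypothetical helper: apply_ct_window is not defined in the original code.
    # Clip HU to [level - width/2, level + width/2] and rescale to 0-255.
    lower = window_level - window_width / 2
    upper = window_level + window_width / 2
    array = np.clip(array, lower, upper)
    return (array - lower) / (upper - lower) * 255


def generate_mask(array):
    # foreground = any voxel brighter than the window minimum
    mask = (array > 0).astype("uint8")
    # keep only the largest 3D connected component (label 0 is background)
    mask_label = label(mask)
    labels, counts = np.unique(mask_label, return_counts=True)
    labels, counts = labels[1:], counts[1:]
    max_label = labels[np.argmax(counts)]
    mask = mask_label == max_label
    # per slice: close small gaps, then fill interior holes
    mask = np.stack([
        binary_fill_holes(binary_closing(mask[:, :, i]))
        for i in range(mask.shape[2])
    ], axis=2).astype("uint8")
    # per slice: erode by one pixel (3x3 minimum filter)
    mask = np.stack([
        minimum_filter(mask[:, :, i], size=3)
        for i in range(mask.shape[2])
    ], axis=2)
    return mask


array = nib.load("ct.nii.gz").get_fdata()
# apply soft tissue window
array = apply_ct_window(array, window_level=50, window_width=400)
mask = generate_mask(array)
```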

Bounding box coordinates were generated from the masks of individual slices. The model was then trained to predict normalized (0-1) xywh coordinates from an individual CT slice. If a slice's mask was empty, the coordinates were set to all zeros. Images were converted from Hounsfield units (HU) to 4 CT windows:

  1. Soft tissue (level=50, width=400)
  2. Brain (level=40, width=80)
  3. Lung (level=-600, width=1500)
  4. Bone (level=400, width=1800)
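A minimal sketch of how a per-slice training target might be derived from a 2D mask. The card does not specify whether (x, y) is the top-left corner or the box center, nor how pixel extents were counted; this sketch assumes an (x, y, w, h) ordering with a top-left origin and inclusive pixel extents. `mask_to_xywh` is a hypothetical name, not part of the model.

```python
import numpy as np


def mask_to_xywh(mask_2d):
    """Return (x, y, w, h) normalized to 0-1 for a 2D binary mask.

    Assumes (x, y) is the top-left corner and (w, h) the box size, each
    divided by the image width/height. An empty mask gives all zeros.
    """
    rows = np.any(mask_2d, axis=1)  # rows containing foreground
    cols = np.any(mask_2d, axis=0)  # columns containing foreground
    if not rows.any():
        return np.zeros(4, dtype=np.float32)
    y1, y2 = np.where(rows)[0][[0, -1]]
    x1, x2 = np.where(cols)[0][[0, -1]]
    height, width = mask_2d.shape
    # +1 because the last foreground index is inclusive
    return np.array([
        x1 / width,
        y1 / height,
        (x2 - x1 + 1) / width,
        (y2 - y1 + 1) / height,
    ], dtype=np.float32)
```

For a volume mask of shape (H, W, D), targets for every slice could be built with `np.stack([mask_to_xywh(mask[:, :, i]) for i in range(mask.shape[2])])`.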

During training, random combinations of the 4 window channels were selected. If more than one channel was selected, the channels were averaged to produce a single-channel image. Strong data augmentation was also applied, so the model should be robust to different CT windows and combinations thereof.
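The training code is not published on this card. A minimal sketch of the channel-combination idea, assuming each window is scaled to 0-255 and that 1 to 4 windows are drawn uniformly without replacement; `window_hu` and `random_window_mix` are hypothetical names.

```python
import numpy as np

# (level, width) pairs for the four windows listed above
WINDOWS = {
    "soft_tissue": (50, 400),
    "brain": (40, 80),
    "lung": (-600, 1500),
    "bone": (400, 1800),
}


def window_hu(hu, level, width):
    # clip to [level - width/2, level + width/2] and rescale to 0-255
    lower, upper = level - width / 2, level + width / 2
    return (np.clip(hu, lower, upper) - lower) / (upper - lower) * 255


def random_window_mix(hu, rng):
    """Pick a random non-empty subset of windows and average them channel-wise."""
    names = list(WINDOWS)
    k = rng.integers(1, len(names) + 1)  # number of windows to combine (1-4)
    chosen = rng.choice(names, size=k, replace=False)
    channels = np.stack([window_hu(hu, *WINDOWS[n]) for n in chosen], axis=0)
    return channels.mean(axis=0)  # single-channel output
```

Usage: `img = random_window_mix(hu_slice, np.random.default_rng(0))`.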

Example usage below:

```python
import cv2
import numpy as np
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
cropper = AutoModel.from_pretrained("ianpan/ct-crop", trust_remote_code=True).eval().to(device)

# single image
img = cv2.imread("ct_slice.png", cv2.IMREAD_GRAYSCALE)
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=None)

# expand all 4 sides by 2.5% each
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=0.025)

# expand box height by 2.5% in each direction
# and box width by 5% in each direction
buffer = (0.05, 0.025)
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=buffer)

# stack of images
img_list = ["ct_slice_1.png", "ct_slice_2.png"]  # , ... (add the remaining slices)
stack = np.stack([cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in img_list], axis=0)
cropped_stack = cropper.crop(stack, mode="3d", device=device, raw_hu=False, add_buffer=None)
```
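The card does not say whether `add_buffer` fractions are relative to the box or to the whole image. A minimal sketch of buffer expansion, assuming box-relative fractions, pixel (x1, y1, x2, y2) corners, and a (width, height) tuple order matching the comment in the example; `expand_box` is hypothetical and not part of the model's API.

```python
def expand_box(x1, y1, x2, y2, buffer, img_w, img_h):
    """Grow a pixel box by a fraction of its own width/height on each side.

    buffer is a float (same fraction for width and height) or a
    (width_fraction, height_fraction) tuple; the result is clipped to the image.
    """
    if isinstance(buffer, (int, float)):
        buffer = (buffer, buffer)
    bw, bh = buffer
    dx = (x2 - x1) * bw  # extra pixels added on the left and on the right
    dy = (y2 - y1) * bh  # extra pixels added on the top and on the bottom
    return (
        max(0, int(round(x1 - dx))),
        max(0, int(round(y1 - dy))),
        min(img_w, int(round(x2 + dx))),
        min(img_h, int(round(y2 + dy))),
    )
```

For example, `expand_box(100, 50, 300, 250, (0.05, 0.025), 512, 512)` gives `(90, 45, 310, 255)`. Clipping keeps the expanded box inside the image, so a buffer never produces negative indices.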

You can also get the coordinates directly and do the cropping yourself. In that case, you must preprocess the input separately. Example:

```python
# single image
img0 = cv2.imread("ct_slice.png", cv2.IMREAD_GRAYSCALE)
img_shapes = torch.tensor([_.shape[:2] for _ in [img0]]).to(device)
img = cropper.preprocess(img0, mode="2d")
# for multi-channel input, first convert channels-last -> channels-first,
# e.g. img.transpose(2, 0, 1), and add the batch dimension with unsqueeze(0)
# instead of expand
img = torch.from_numpy(img).expand(1, 1, -1, -1).float().to(device)
with torch.inference_mode():
    coords = cropper(img, img_shape=img_shapes, add_buffer=None)

# if you do not provide img_shape, the output will be normalized (0-1) coordinates;
# otherwise, the coordinates will be scaled to img_shape
```
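The card does not document the ordering of the returned coordinates. A minimal sketch of cropping from normalized output, assuming (x, y, w, h) with a top-left origin and treating an all-zero box as an empty slice (as in the training targets); verify the ordering against the model's remote code first. `crop_normalized_xywh` is hypothetical.

```python
import numpy as np


def crop_normalized_xywh(img, box):
    """Crop img (H x W or H x W x C) with a normalized (x, y, w, h) box.

    Assumes (x, y) is the top-left corner. A box with zero width or
    height is treated as "no foreground" and returns None.
    """
    h, w = img.shape[:2]
    x, y, bw, bh = (float(v) for v in box)
    if bw <= 0 or bh <= 0:
        return None
    # convert normalized coordinates to pixel corners
    x1, y1 = int(round(x * w)), int(round(y * h))
    x2, y2 = int(round((x + bw) * w)), int(round((y + bh) * h))
    # clip to the image bounds before slicing
    x1, x2 = max(0, x1), min(w, x2)
    y1, y2 = max(0, y1), min(h, y2)
    return img[y1:y2, x1:x2]
```

Usage, assuming the model is called without `img_shape` and returns one row of four values per image: `crop = crop_normalized_xywh(img0, coords[0].cpu().numpy())`.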

The model also contains methods to load DICOM images, if you have pydicom installed:

```python
path_to_dicom_file = "/path/to/image.dcm"  # placeholder path
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=None)

# note: RescaleSlope and RescaleIntercept are already applied in the method
# apply a CT window
brain_window = (40, 80)
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=brain_window)

# or multiple windows
soft_tissue_window = (50, 400)
img = cropper.load_image_from_dicom(path_to_dicom_file,
                                    windows=[brain_window, soft_tissue_window])
# each window is a separate channel; img will be in channels-last format
```
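Because `load_image_from_dicom` applies RescaleSlope and RescaleIntercept before windowing, the conversion can be sketched on an already-decoded pixel array (for example pydicom's `ds.pixel_array`). This is an illustration, not the model's method: returning a single window as a 2D array and scaling windows to 0-255 are assumptions, and `dicom_pixels_to_windows` is hypothetical.

```python
import numpy as np


def dicom_pixels_to_windows(pixel_array, slope, intercept, windows=None):
    """Convert stored DICOM pixel values to HU, then optionally window them.

    windows: None (return HU), one (level, width) tuple, or a list of tuples.
    Multiple windows are stacked channels-last (H x W x C).
    """
    # HU = stored value * RescaleSlope + RescaleIntercept
    hu = pixel_array.astype(np.float32) * slope + intercept
    if windows is None:
        return hu
    if isinstance(windows, tuple):
        windows = [windows]
    channels = []
    for level, width in windows:
        lower, upper = level - width / 2, level + width / 2
        channels.append((np.clip(hu, lower, upper) - lower) / (upper - lower) * 255)
    out = np.stack(channels, axis=-1)
    # a single window is returned as a 2D image
    return out[..., 0] if out.shape[-1] == 1 else out
```

With pydicom, the inputs would come from `ds.pixel_array`, `float(ds.RescaleSlope)` and `float(ds.RescaleIntercept)`.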

You can also load a stack of DICOM images from a folder:

```python
dicom_folder = "/path/to/ct/head/images/"

# dicom_extension is used to filter files; the default is ".dcm"
# pass "" if you do not want to filter files
# by default, slices are sorted by ImagePositionPatient along an automatically
# determined orientation; set sort_by_instance_number=True to sort by InstanceNumber
# CT windows can also be applied, as above
stack = cropper.load_stack_from_dicom_folder(dicom_folder,
                                             windows=None,
                                             dicom_extension=".dcm",
                                             sort_by_instance_number=False)

# raw Hounsfield units can be passed to the cropper with raw_hu=True
cropped_stack = cropper.crop(stack, mode="3d", device=device, raw_hu=True)
```
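The card does not describe how the orientation is determined when sorting by ImagePositionPatient. A common approach, sketched here as an assumption rather than the model's method: project each slice's ImagePositionPatient onto the slice normal (the cross product of the row and column direction cosines in ImageOrientationPatient) and sort by that distance. `sort_slices_by_position` is hypothetical.

```python
import numpy as np


def sort_slices_by_position(positions, orientation):
    """Return slice indices ordered along the scan axis.

    positions: ImagePositionPatient triples, one per slice.
    orientation: ImageOrientationPatient (row cosines, then column cosines).
    """
    row = np.asarray(orientation[:3], dtype=float)
    col = np.asarray(orientation[3:], dtype=float)
    normal = np.cross(row, col)  # direction perpendicular to the slice plane
    # signed distance of each slice origin along the normal
    dist = np.asarray(positions, dtype=float) @ normal
    return np.argsort(dist).tolist()
```

With pydicom datasets `dss`, the inputs would be `[ds.ImagePositionPatient for ds in dss]` and `dss[0].ImageOrientationPatient`. Projecting onto the normal works for axial, coronal, and oblique acquisitions alike, whereas sorting on a single coordinate such as z only works for axial series.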

By default, the cropper does not remove slices from a stack, even if they are predicted to be empty. To remove them, specify remove_empty_slices=True; the method then also returns the indices (in the original input) of the removed slices.

```python
cropped_stack, empty_slice_indices = cropper.crop(stack, mode="3d", device=device, raw_hu=True,
                                                  remove_empty_slices=True)
```
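Since empty slices were trained to produce all-zero coordinates, the filtering can be sketched as dropping slices whose predicted box is all zeros. This is an assumption about how the model decides emptiness (it may use a threshold instead), and `remove_empty_slices` here is a hypothetical standalone function, not the model's method.

```python
import numpy as np


def remove_empty_slices(stack, boxes):
    """Drop slices whose predicted box is all zeros.

    stack: array of shape (N, H, W); boxes: array-like of shape (N, 4).
    Returns the kept slices and the indices of the removed slices.
    """
    boxes = np.asarray(boxes)
    empty = np.all(boxes == 0, axis=1)  # True where the box is all zeros
    empty_indices = np.flatnonzero(empty).tolist()
    return stack[~empty], empty_indices
```

Returning the removed indices lets you map the cropped stack back to the original slice positions, for example when writing predictions back to a volume.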
Model size: 1.53M params (F32, Safetensors).