---
library_name: transformers
tags:
- chest_x_ray
- x_ray
- medical_imaging
- radiology
- segmentation
- classification
- lungs
- heart
base_model:
- timm/tf_efficientnetv2_s.in21k_ft_in1k
pipeline_tag: image-segmentation
---

This model performs both segmentation and classification on chest radiographs (X-rays). It uses a `tf_efficientnetv2_s` backbone with a U-Net decoder for segmentation and a linear layer for classification. For frontal radiographs, the model segments the 1) right lung, 2) left lung, and 3) heart. The model also predicts the chest X-ray view (AP, PA, lateral), patient age, and patient sex.

The [CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/) (small version) and [NIH Chest X-ray](https://nihcc.app.box.com/v/ChestXray-NIHCC) datasets were used to train the model. Segmentation masks were obtained from the CheXmask [dataset](https://physionet.org/content/chexmask-cxr-segmentation-data/0.4/) ([paper](https://www.nature.com/articles/s41597-024-03358-1)). The final dataset comprised 335,516 images from 96,385 patients and was split into 80% training/20% validation. A holdout test set was not used since minimal tuning was performed.

The view classifier was trained only on CheXpert images (NIH images were excluded from the loss function), because lateral radiographs are only present in CheXpert. This avoids unwanted bias, which can occur when one class originates from only a single dataset: the model may then learn dataset-specific cues instead of the view itself.
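Excluding NIH images from the view loss amounts to a masked cross-entropy. The following is a minimal NumPy sketch of that idea, not the author's training code; the function name `masked_view_loss`, the boolean `include` flag (True for CheXpert, False for NIH), and averaging over included samples only are assumptions made for illustration.

```python
import numpy as np


def masked_view_loss(logits, labels, include):
    """Mean cross-entropy over the samples where `include` is True (hypothetical helper).

    logits:  (N, 3) raw view scores in the order AP, PA, lateral
    labels:  (N,) integer view labels; values for excluded samples are ignored
    include: (N,) booleans, True for CheXpert images, False for NIH images
    """
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels)
    include = np.asarray(include, dtype=bool)

    # Numerically stable log-softmax over the 3 view classes
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    # Excluded (NIH) samples may carry placeholder labels such as -1,
    # so replace them with a valid index before gathering
    safe_labels = np.where(include, labels, 0)
    nll = -log_probs[np.arange(len(safe_labels)), safe_labels]

    if not include.any():
        return 0.0  # batch without CheXpert images: no view-loss contribution
    # Only included samples contribute to the mean
    return float(nll[include].mean())
```

Because excluded rows are dropped before averaging, changing the logits of an NIH image leaves the loss, and therefore the gradient of the view head, unchanged.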
Validation performance is as follows:

```
Segmentation (Dice similarity coefficient):
  Right Lung: 0.957
  Left Lung:  0.948
  Heart:      0.943

Age Prediction:
  Mean Absolute Error: 5.25 years

Classification:
  View (AP, PA, lateral): 99.42% accuracy
  Female: 0.999 AUC
```

To use the model:

```
import cv2
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)

img = cv2.imread(..., 0)  # flag 0 loads the image as grayscale
x = model.preprocess(img)  # only takes a single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add channel, batch dims
x = x.float()

with torch.inference_mode():
    out = model(x.to(device))
```

The output is a dictionary containing 4 keys:

* `mask` has 3 channels containing the segmentation masks. Take the argmax over the channel dimension to create a single image mask (i.e., `out["mask"].argmax(1)`): 1 = right lung, 2 = left lung, 3 = heart.
* `age`, the predicted age in years.
* `view`, with 3 classes, one for each possible view. Take the argmax to select the predicted view (i.e., `out["view"].argmax(1)`): 0 = AP, 1 = PA, 2 = lateral.
* `female`, which can be binarized with `out["female"] >= 0.5`.

You can use the segmentation mask to crop the region containing the lungs from the rest of the X-ray.
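One way to crop is to take the bounding box of all lung pixels. The sketch below assumes a 2D NumPy label map using the label values listed above (1 = right lung, 2 = left lung); the helper `crop_to_lungs` and its `margin` parameter are hypothetical. It also assumes the mask and image share the same height and width. If `model.preprocess` resizes the image, either crop the preprocessed image or first resize the mask to the original size with nearest-neighbor interpolation (e.g. `cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)`).

```python
import numpy as np


def crop_to_lungs(img, mask, margin=0):
    """Crop `img` to the bounding box of the lung labels (hypothetical helper).

    img:    array of shape (height, width) or (height, width, channels)
    mask:   2D label map of shape (height, width); 1 = right lung, 2 = left lung
    margin: extra padding in pixels around the box, clipped to the image border
    """
    img = np.asarray(img)
    mask = np.asarray(mask)
    if img.shape[:2] != mask.shape:
        raise ValueError("img and mask must have the same height and width")

    # Row and column indices of every lung pixel (heart label 3 is ignored)
    rows, cols = np.where(np.isin(mask, (1, 2)))
    if rows.size == 0:
        raise ValueError("mask contains no lung pixels")

    h, w = mask.shape
    # Inclusive min / exclusive max, padded by margin and clipped to the image
    y0 = max(int(rows.min()) - margin, 0)
    y1 = min(int(rows.max()) + 1 + margin, h)
    x0 = max(int(cols.min()) - margin, 0)
    x1 = min(int(cols.max()) + 1 + margin, w)
    return img[y0:y1, x0:x1]
```

For the model output above, a single image's mask would be obtained with something like `out["mask"].argmax(1)[0].cpu().numpy()`.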
You can also calculate the [cardiothoracic ratio (CTR)](https://radiopaedia.org/articles/cardiothoracic-ratio?lang=us) using this function, which divides the maximal horizontal width of the heart by that of the combined lungs (the lung extent serves as a proxy for the thoracic width):

```
import numpy as np

def calculate_ctr(mask):
    # mask: single 2D label map with dims (height, width),
    # where 1 = right lung, 2 = left lung, 3 = heart
    mask = np.asarray(mask)
    lungs = np.isin(mask, (1, 2))
    heart = mask == 3
    if not lungs.any() or not heart.any():
        raise ValueError("mask must contain both lung and heart pixels")
    # column (x) indices of all lung and heart pixels
    lung_x = np.where(lungs)[1]
    heart_x = np.where(heart)[1]
    # maximal horizontal extent of the lungs and of the heart
    lung_range = lung_x.max() - lung_x.min()
    heart_range = heart_x.max() - heart_x.min()
    return heart_range / lung_range
```

If you have `pydicom` installed, you can also load a DICOM image directly:

```
img = model.load_image_from_dicom(path_to_dicom)
```

This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes any and all responsibility regarding their own use of this model and its outputs.