---
license: mit
---
|
This is the trained ControlNet + Stable Diffusion model for generating synthetic CT/MRI images from segmentation maps.

It requires a customized ControlNet–Stable Diffusion pipeline (installation steps below).
|
|
|
The model is trained on the JHU dataset, which contains 5,312 CT volumes with corresponding segmentation masks.

The CT volumes were cut into 2D slices, giving roughly 1.3M 2D slices.
|
|
|
The training and inference code is available at [Diff_Synth_CT](https://github.com/Onkarsus13/DiffCTSeg).
|
|
|
Training details:

- Hardware: 8× NVIDIA A6000
- Batch size: 8 × 4 × 32
|
|
|
For direct inference:
|
|
|
|
|
Step 1: Clone the GitHub repository to get the customized ControlNet–Stable Diffusion pipeline implementation:
|
``` |
|
# Clone the repository that provides the customized ControlNet-StableDiffusion pipeline.
git clone https://github.com/Onkarsus13/DiffCTSeg
|
``` |
|
|
|
Step 2: Go into the repository and install it together with its dependencies:
|
``` |
|
cd DiffCTSeg
# Install the package in editable mode with its torch extras.
pip install -e ".[torch]"
# Quote the extras: unquoted brackets are a glob pattern in shells such as
# zsh, which then aborts with "no matches found".
pip install -e ".[all,dev,notebooks]"
|
``` |
|
|
|
Step 3: Run `python test_eraser.py`, or run the code below:
|
|
|
```python |
|
from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler, PNDMScheduler, DDIMScheduler, DPMSolverMultistepScheduler |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
import glob |
|
|
|
|
|
# Colour palette of the segmentation masks: class ID -> (R, G, B).
# The IDs are the same ones used as keys of ``class_dict`` (names shown
# in the trailing comments).
class_dict_BTCV = dict(enumerate([
    (0, 0, 0),        # 0: background
    (255, 60, 0),     # 1: aorta
    (255, 60, 232),   # 2: kidney_left
    (134, 79, 117),   # 3: liver
    (125, 0, 190),    # 4: postcava
    (117, 200, 191),  # 5: stomach
    (230, 91, 101),   # 6: gall_bladder
    (255, 0, 155),    # 7: kidney_right
    (75, 205, 155),   # 8: pancreas
    (100, 37, 200),   # 9: spleen
]))
|
|
|
# Human-readable name for each class ID; these names are joined into the
# text prompt given to the pipeline.
class_dict = dict(enumerate([
    "background",
    "aorta",
    "kidney_left",
    "liver",
    "postcava",
    "stomach",
    "gall_bladder",
    "kidney_right",
    "pancreas",
    "spleen",
]))
|
|
|
def rgb_to_onehot(rgb_arr, color_dict=None):
    """Convert a colour-coded segmentation mask into a one-hot class array.

    Args:
        rgb_arr: Array-like of shape (H, W, C) with C >= 3. Only the first
            three channels are compared, so RGBA input is accepted.
        color_dict: Mapping of class ID -> (R, G, B) colour. ``None`` (the
            default) uses ``class_dict_BTCV``. Output channels follow the
            class IDs in ascending order, so for IDs 0..N-1 channel ``k``
            is class ``k``.

    Returns:
        ``np.int8`` array of shape (H, W, len(color_dict)). It is 1 where a
        pixel exactly matches that class colour and 0 elsewhere. A pixel that
        matches no colour is all zeros, so ``argmax`` maps it to channel 0.

    Raises:
        ValueError: if ``rgb_arr`` is not 3-D with at least three channels.
    """
    if color_dict is None:
        color_dict = class_dict_BTCV

    pixels = np.asarray(rgb_arr)
    if pixels.ndim != 3 or pixels.shape[-1] < 3:
        raise ValueError(
            f"expected an (H, W, C>=3) array, got shape {pixels.shape}"
        )
    # Drop alpha (or any extra channels) before the exact colour match.
    rgb = pixels[..., :3]

    height, width = rgb.shape[:2]
    onehot = np.zeros((height, width, len(color_dict)), dtype=np.int8)

    # Iterate (id, colour) pairs sorted by ID. The channel order is then
    # independent of dict insertion order, and non-contiguous IDs no longer
    # raise KeyError.
    for channel, (_, color) in enumerate(sorted(color_dict.items())):
        onehot[..., channel] = np.all(rgb == np.asarray(color), axis=-1)
    return onehot
|
|
|
|
|
|
|
# Hub id of the released ControlNet + Stable Diffusion checkpoint.
MODEL_ID = "onkarsus13/PixArt_Dual_Tone_CT_SEG_V0.1"
# Placeholder: path to a colour-coded segmentation mask (palette: class_dict_BTCV).
SEG_MASK_PATH = "<Give Segmentation Mask>"

pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    safety_checker=None,
    feature_extractor=None,
)
pipe.scheduler = UniPCMultistepScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")
# Model CPU offload moves each sub-model onto the GPU only while it runs.
# Moving the whole pipeline to CUDA first (the former `pipe.to('cuda:0')`)
# defeats that memory saving, so the pipeline is intentionally left on CPU here.
pipe.enable_model_cpu_offload()

generator = torch.Generator(device="cpu").manual_seed(1)

images = Image.open(SEG_MASK_PATH)
# Map mask colours back to class IDs, then collect the classes present.
npi = rgb_to_onehot(np.asarray(images.convert("RGB"))).argmax(-1)
unique_ids = np.unique(npi)

# NOTE(review): "containg" is kept verbatim; it presumably matches the
# prompt template used during training, so confirm before fixing the spelling.
prompt = 'CT image containg ' + " ".join(class_dict[i] for i in unique_ids)
print(prompt)

# NOTE(review): the positional (image, [mask]) arguments follow the customized
# DiffCTSeg pipeline, not necessarily upstream diffusers; confirm the signature.
image = pipe(
    prompt,
    images,
    [images],
    num_inference_steps=30,
    generator=generator,
    controlnet_conditioning_scale=1.0,
).images[0]

image.save('./result.png')
|
|
|
|
|
|
|
``` |