# Example of using a saved CycleGAN model for image translation (3D vessel segmentation).
# Based on https://machinelearningmastery.com/cyclegan-tutorial-with-keras/
from keras.models import load_model
import numpy as np
import tensorflow_addons as tfa
from scipy.ndimage import zoom
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
from huggingface_hub import hf_hub_download
from skimage.morphology import binary_erosion, binary_dilation
from skimage import draw


def predict_mask(image, dim_x, dim_y, dim_z, _resize=True, norm_=True, mode_='test', patch_size=(64,128,128,1), _step=64, _step_z=32, _patch_size_z=64):
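    """Segment vessels in a 3D image with a pretrained 3D CycleGAN generator (B -> A).

    Parameter summary (as used below; voxel spacings are presumably micrometres):
      image               : 3D numpy array indexed as (y, x, z).
      dim_x, dim_y, dim_z : voxel spacing of the input volume; with `_resize` True the volume
                            is resampled to the ~0.333 x 0.333 x 0.5 spacing assumed by the model.
      _resize             : resample to the training voxel spacing before prediction.
      norm_               : apply percentile-based intensity equalization.
      mode_               : label used only for logging.
      patch_size          : model input patch size (z, y, x, channels).
      _step, _step_z      : sliding-window strides in-plane and along z.
      _patch_size_z       : patch depth along z.

    Returns a uint8 mask (0/255) with the same shape as the input image.
    """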

    # custom layer required to deserialize the generator
    cust = {'InstanceNormalization': tfa.layers.InstanceNormalization}
    # download the pretrained B->A generator (raw image -> segmentation) from the Hugging Face Hub and load it
    model_dir = hf_hub_download(repo_id="Hemaxi/3DCycleGAN", filename="CycleGANVesselSegmentation.h5")
    model_BtoA = load_model(model_dir, custom_objects=cust)

    print('Mode: {}'.format(mode_))

    _patch_size = patch_size[1]   # in-plane patch size (y/x)
    _nbslices = patch_size[0]     # number of slices per patch (z)

    # percentiles used for the intensity equalization below
    perceqmin = 1
    perceqmax = 99

    # rescale intensities to the 8-bit range
    image = ((image / np.max(image)) * 255).astype('uint8')

    print('Image Shape: {}'.format(image.shape))
    print('----------------------------------------')

    # original dimensions, used to restore the mask to the input size at the end
    initial_image_x = np.shape(image)[0]
    initial_image_y = np.shape(image)[1]
    initial_image_z = np.shape(image)[2]

    # percentile-based intensity equalization
    if norm_:
        minval = np.percentile(image, perceqmin)
        maxval = np.percentile(image, perceqmax)
        image = np.clip(image, minval, maxval)
        image = (((image - minval) / (maxval - minval)) * 255).astype('uint8')

    # resample to the voxel spacing the model was trained on (~0.333 x 0.333 x 0.5)
    if _resize:
        image = zoom(image, (dim_x/0.333, dim_y/0.333, dim_z/0.5), order=3, mode='nearest')
        image = ((image / np.max(image)) * 255.0).astype('uint8')


    # image size after the optional resampling
    size_y = np.shape(image)[0]
    size_x = np.shape(image)[1]
    size_depth = np.shape(image)[2]
    aux_sizes_or = [size_y, size_x, size_depth]

    # pad each dimension up to the next multiple of the patch size
    new_size_y = int((size_y/_patch_size) + 1) * _patch_size
    new_size_x = int((size_x/_patch_size) + 1) * _patch_size
    new_size_z = int((size_depth/_patch_size_z) + 1) * _patch_size_z
    aux_sizes = [new_size_y, new_size_x, new_size_z]
    
    ## pad with low random intensities (1-49) so the padded borders are not flat zeros
    aux_img = np.random.randint(1, 50, (aux_sizes[0], aux_sizes[1], aux_sizes[2]))
    aux_img[0:aux_sizes_or[0], 0:aux_sizes_or[1], 0:aux_sizes_or[2]] = image
    image = aux_img
    del aux_img
        
    # per-voxel vote accumulators for foreground and background predictions
    final_mask_foreground = np.zeros(np.shape(image), dtype='uint8')
    final_mask_background = np.zeros(np.shape(image), dtype='uint8')

    # number of sliding-window positions along the first axis (window _patch_size, stride _step)
    total_iterations = int((image.shape[0] - _patch_size) / _step) + 1

    # slide overlapping patches over the volume (stride _step in-plane, _step_z in depth)
    with tqdm(total=total_iterations) as pbar:
        i = 0
        while i + _patch_size <= image.shape[0]:
            j = 0
            while j + _patch_size <= image.shape[1]:
                k = 0
                while k + _patch_size_z <= image.shape[2]:

                    B_real = np.zeros((1, _nbslices, _patch_size, _patch_size, 1), dtype='float32')
                    _slice = image[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z]

                    # reorder to (z, y, x) and add a channel axis to match the model input
                    _slice = _slice.transpose(2, 0, 1)
                    _slice = np.expand_dims(_slice, axis=-1)

                    # scale intensities from [0, 255] to [-1, 1]
                    B_real[0, :] = (_slice - 127.5) / 127.5

                    # translate the image patch (domain B) to a segmentation patch (domain A)
                    A_generated = model_BtoA.predict(B_real)

                    A_generated = (A_generated + 1) / 2  # from [-1,1] to [0,1]

                    A_generated = A_generated[0, :, :, :, 0]
                    A_generated = A_generated.transpose(1, 2, 0)  # back to (y, x, z)

                    # binarize the generated patch
                    A_generated = (A_generated > 0.5) * 1
                    A_generated = A_generated.astype('uint8')

                    # accumulate foreground/background votes over the overlapping regions
                    final_mask_foreground[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] = final_mask_foreground[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] + A_generated
                    final_mask_background[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] = final_mask_background[i:i+_patch_size, j:j+_patch_size, k:k+_patch_size_z] + (1 - A_generated)

                    k = k + _step_z
                j = j + _step
            i = i + _step
            pbar.update(1)


    del _slice
    del A_generated
    del B_real

    # majority vote over the accumulated overlapping predictions
    final_mask = (final_mask_foreground >= final_mask_background) * 1

    # crop away the padding added before patch extraction
    image = image[0:aux_sizes_or[0], 0:aux_sizes_or[1], 0:aux_sizes_or[2]]
    print('Image Shape: {}'.format(image.shape))
    print('----------------------------------------')

    final_mask = final_mask[0:aux_sizes_or[0], 0:aux_sizes_or[1], 0:aux_sizes_or[2]]


    if _resize:
        # resample the mask back to the original voxel spacing; clip before casting to
        # uint8 so cubic interpolation overshoot cannot wrap around
        final_mask = zoom(final_mask, (0.333/dim_x, 0.333/dim_y, 0.5/dim_z), order=3, mode='nearest')
        final_mask = (np.clip(final_mask, 0, 1) * 255.0).astype('uint8')

        final_size_x = np.shape(final_mask)[0]
        final_size_y = np.shape(final_mask)[1]
        final_size_z = np.shape(final_mask)[2]

        # pad/crop so the mask matches the original image dimensions exactly
        aux_mask = np.zeros((initial_image_x, initial_image_y, initial_image_z)).astype('uint8')
        aux_mask[0:min(initial_image_x, final_size_x), 0:min(initial_image_y, final_size_y), 0:min(initial_image_z, final_size_z)] = final_mask[0:min(initial_image_x, final_size_x), 0:min(initial_image_y, final_size_y), 0:min(initial_image_z, final_size_z)]

        final_mask = aux_mask.copy()


    print('Mask Shape: {}'.format(final_mask.shape))
    print('----------------------------------------')
    # rescale the binary mask to {0, 255}
    final_mask = final_mask / np.max(final_mask)
    final_mask = final_mask * 255.0
    final_mask = final_mask.astype('uint8')


    # binarize the mask before the closing operation that fills small holes
    mask = final_mask
    mask[mask != 0] = 1
    mask = mask.astype('uint8')

    # ellipsoidal structuring element (semi-axes 9, 9, 3); trim the empty one-voxel border added by draw.ellipsoid
    ellipsoid = draw.ellipsoid(9, 9, 3, spacing=(1, 1, 1), levelset=False)
    ellipsoid = ellipsoid.astype('uint8')
    ellipsoid = ellipsoid[1:-1, 1:-1, 1:-1]

    # morphological closing: dilation followed by erosion, then rescale to {0, 255}
    dil = binary_dilation(mask, ellipsoid)
    closed_mask = binary_erosion(dil, ellipsoid)
    closed_mask = (closed_mask * 255.0).astype('uint8')
    
    return closed_mask
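

# Usage sketch (illustrative only, not part of the original script): the file paths,
# the use of tifffile, and the voxel spacings below are assumptions; adapt them to your data.
if __name__ == '__main__':
    import tifffile

    # hypothetical input volume; tifffile typically returns (z, y, x), while
    # predict_mask indexes the array as (y, x, z), hence the transpose
    volume = tifffile.imread('example_volume.tif')
    volume = volume.transpose(1, 2, 0)

    # assumed voxel spacing of the input volume, in the same units as the
    # 0.333/0.5 spacing hard-coded in predict_mask
    mask = predict_mask(volume, dim_x=0.333, dim_y=0.333, dim_z=0.5,
                        _resize=True, norm_=True)

    # save the resulting binary (0/255) vessel mask
    tifffile.imwrite('example_volume_mask.tif', mask)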