File size: 2,718 Bytes
3126b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
094461a
3126b1e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import tensorflow as tf
import matplotlib.cm as cm
import numpy as np
# The image helpers are imported from `tensorflow.keras.utils` first, with
# `tensorflow.keras.preprocessing.image` as the fallback location. Only an
# import failure should trigger the fallback; a bare `except` would also hide
# KeyboardInterrupt/SystemExit and unrelated errors.
try:
    from tensorflow.keras.utils import array_to_img, img_to_array
except ImportError:
    from tensorflow.keras.preprocessing.image import array_to_img, img_to_array

def process_image(img, size=(224, 224)):
    """Apply Keras ImageNet ``mode='torch'`` preprocessing and resize an image.

    Args:
        img: single image without a batch axis, e.g. shape (H, W, 3)
            (assumed channels-last -- confirm against callers).
        size (tuple, optional): target size forwarded to ``tf.image.resize``.
            Defaults to (224, 224).

    Returns:
        The preprocessed, resized image with a leading axis of length 1 added.
    """
    # preprocess_input writes its result over floating-point NumPy inputs.
    # Work on a copy so the caller's image (reused later for the overlay)
    # keeps its original pixel values.
    if isinstance(img, np.ndarray):
        img = img.copy()
    img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    img_array = tf.image.resize(img_array, size)[None,]
    return img_array

def get_gradcam_model(model, input_shape=(224, 224, 3)):
    """Build a Keras model that returns both predictions and feature maps.

    Args:
        model: object exposing ``forward_features`` and ``forward_head``;
            the output of the former is fed to the latter.
        input_shape (tuple, optional): per-sample input shape (no batch axis)
            used for the symbolic input. Defaults to (224, 224, 3).

    Returns:
        tf.keras.Model: maps an input batch to ``[preds, feats]``.
    """
    inp = tf.keras.Input(shape=input_shape)
    feats = model.forward_features(inp)
    preds = model.forward_head(feats)
    return tf.keras.models.Model(inp, [preds, feats])

def _resolve_colormap(cmap):
    """Return the matplotlib colormap designated by ``cmap``.

    ``matplotlib.cm.get_cmap`` was removed in matplotlib 3.9, so the lookup
    goes through the ``matplotlib.colormaps`` registry when it offers
    ``get_cmap``; older releases fall back to ``cm.get_cmap``.
    """
    import matplotlib  # local import: the file only binds ``matplotlib.cm``

    registry = getattr(matplotlib, "colormaps", None)
    if hasattr(registry, "get_cmap"):
        return registry.get_cmap(cmap)
    return cm.get_cmap(cmap)


def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.6):
    """Grad-CAM for a single image

    Args:
        img (np.ndarray): processed or raw image without batch axis, e.g. (224, 224, 3)
        grad_model (tf.keras.Model): model returning ``(preds, feats)`` for a batch
        process (bool, optional): apply ``process_image`` (ImageNet pre-processing
            and resize) before inference. Defaults to True.
        decode (bool, optional): decode predictions with the ImageNet
            ``decode_predictions`` helper; if False the raw prediction tensor
            is returned instead. Defaults to True.
        pred_index (int, optional): class index to explain; the top predicted
            class is used when None. Defaults to None.
        cmap (str, optional): colormap. Defaults to 'jet'.
        alpha (float, optional): heatmap weight in the overlay. Defaults to 0.6.

    Returns:
        preds_decode: decoded predictions for the image (or raw ``preds`` when
            ``decode`` is False)
        overlay: image produced by ``array_to_img`` from ``img + heatmap * alpha``
    """
    # process image for inference
    if process:
        img_array = process_image(img)
    else:
        img_array = tf.convert_to_tensor(img)[None,]
        # rescale the already-processed image to 0..255 so it can be blended
        # with the colorized heatmap; a constant image is left unchanged
        if img.min()!=img.max():
            img = (img - img.min())/(img.max() - img.min())
            img = np.uint8(img*255.0)
    # get prediction
    # A single gradient is taken, so a non-persistent tape is enough and its
    # resources are released right after ``tape.gradient``.
    with tf.GradientTape() as tape:
        preds, feats = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]
    # compute heatmap
    grads = tape.gradient(class_channel, feats)
    # average the gradients over every axis but the last (assumes feats is
    # (1, h, w, channels) -- TODO confirm for the wrapped model)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    feats = feats[0]
    heatmap = feats @ pooled_grads[..., tf.newaxis]
    heatmap = tf.squeeze(heatmap)
    heatmap = tf.maximum(heatmap, 0)
    # divide_no_nan yields 0 instead of NaN when the map has no positive
    # activation (max == 0), keeping the uint8 cast below well-defined
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    heatmap = heatmap.numpy()
    heatmap = np.uint8(255 * heatmap)
    # colorize heatmap
    colormap = _resolve_colormap(cmap)
    colors = colormap(np.arange(256))[:, :3]  # drop the alpha column
    heatmap = colors[heatmap]
    heatmap = array_to_img(heatmap)
    # resize takes (width, height)
    heatmap = heatmap.resize((img.shape[1], img.shape[0]))
    heatmap = img_to_array(heatmap)
    overlay = img + heatmap * alpha
    overlay = array_to_img(overlay)
    # decode prediction
    preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0] if decode else preds
    return preds_decode, overlay