|
import tensorflow as tf |
|
import matplotlib.cm as cm |
|
import numpy as np |
|
try: |
|
from tensorflow.keras.utils import array_to_img, img_to_array |
|
except: |
|
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array |
|
|
|
def process_image(img, size=(224, 224)):
    """Apply ImageNet 'torch'-mode pre-processing, resize, and add a batch axis.

    Args:
        img: image without a batch dimension, e.g. shape (H, W, 3).
            A NumPy input is copied first, so the caller's array is never
            modified.
        size (tuple[int, int], optional): target (height, width) passed to
            ``tf.image.resize``. Defaults to (224, 224).

    Returns:
        tf.Tensor: the pre-processed, resized image with a leading batch
        axis of length 1.
    """
    # preprocess_input's NumPy code path divides/subtracts *in place* when the
    # array is already floating point, which would silently rewrite the
    # caller's image. Work on a private copy instead.
    if isinstance(img, np.ndarray):
        img = img.copy()
    img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    # [None,] prepends the batch dimension expected by Keras models.
    return tf.image.resize(img_array, size)[None,]
|
|
|
def get_gradcam_model(model, input_shape=(224, 224, 3)):
    """Wrap ``model`` into a Keras model that returns ``[preds, feats]`` for Grad-CAM.

    Args:
        model: backbone exposing ``forward_features`` (input -> feature map)
            and ``forward_head`` (feature map -> predictions).
        input_shape (tuple, optional): per-sample input shape without the
            batch axis. Defaults to (224, 224, 3). Keep it consistent with
            the ``size`` used when pre-processing images.

    Returns:
        tf.keras.Model: functional model mapping an input batch to
        ``[preds, feats]``. Exposing the feature map lets the gradient of a
        class score be taken with respect to it.
    """
    inputs = tf.keras.Input(shape=input_shape)
    feats = model.forward_features(inputs)
    preds = model.forward_head(feats)
    return tf.keras.models.Model(inputs, [preds, feats])
|
|
|
def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.6):
    """Grad-CAM for a single image.

    Args:
        img (np.ndarray): processed or raw image without a batch dimension,
            e.g. shape (224, 224, 3). It is not modified.
        grad_model (tf.keras.Model): model returning ``[preds, feats]`` for a
            batched input (for example, one built by ``get_gradcam_model``).
        process (bool, optional): apply ImageNet pre-processing via
            ``process_image``. Defaults to True.
        decode (bool, optional): decode predictions with
            ``imagenet_utils.decode_predictions``. Defaults to True.
        pred_index (int, optional): class index to explain. Defaults to None,
            meaning the top-scoring class.
        cmap (str, optional): matplotlib colormap name (a Colormap instance
            is also accepted). Defaults to 'jet'.
        alpha (float, optional): weight of the heatmap in the overlay.
            Defaults to 0.6.

    Returns:
        preds_decode: top-5 decoded predictions if ``decode`` is True,
            otherwise the raw ``preds`` tensor.
        overlay: PIL image of the colorized heatmap blended over ``img``.
    """
    if process:
        # Hand process_image a copy: Keras' NumPy pre-processing normalizes
        # float arrays in place, which would corrupt both the caller's image
        # and the overlay background built from ``img`` below.
        img_array = process_image(np.array(img))
    else:
        img_array = tf.convert_to_tensor(img)[None,]

    preds, heatmap = _gradcam_heatmap(grad_model, img_array, pred_index)
    overlay = _gradcam_overlay(img, heatmap, cmap, alpha)

    if decode:
        preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0]
    else:
        preds_decode = preds
    return preds_decode, overlay


def _gradcam_heatmap(grad_model, img_array, pred_index):
    """Return ``(preds, heatmap)``, where heatmap is a 2-D uint8 array scaled to [0, 255]."""
    # A non-persistent tape suffices: exactly one gradient is requested.
    with tf.GradientTape() as tape:
        preds, feats = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, feats)
    # Per-channel weights: the gradient averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    # Weighted sum of feature channels. Drop only the trailing matmul axis
    # (instead of tf.squeeze) so 1xW / Hx1 / 1x1 maps stay 2-D.
    heatmap = (feats[0] @ pooled_grads[..., tf.newaxis])[..., 0]
    # Keep positive evidence only. divide_no_nan yields zeros rather than NaN
    # when nothing is positive (NaN -> uint8 conversion is undefined).
    heatmap = tf.maximum(heatmap, 0)
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    return preds, np.uint8(255 * heatmap.numpy())


def _gradcam_overlay(img, heatmap, cmap, alpha):
    """Colorize ``heatmap``, resize it to ``img`` and blend it over ``img`` with weight ``alpha``."""
    # Float copy: never touch the caller's array, and avoid integer overflow
    # while rescaling.
    base = np.asarray(img, dtype=np.float64)
    lo, hi = base.min(), base.max()
    if lo != hi:
        base = (base - lo) / (hi - lo)
    # Clip so a constant image with values above 1 saturates instead of
    # overflowing the uint8 conversion.
    base = np.uint8(np.clip(base * 255.0, 0, 255))

    # Sample the colormap evenly over [0, 1]. This matches cmap(0..255) for
    # 256-entry maps and stays correct for maps with fewer/more entries.
    colors = _resolve_cmap(cmap)(np.linspace(0.0, 1.0, 256))[:, :3]
    colored = array_to_img(colors[heatmap])
    # PIL's resize takes (width, height).
    colored = img_to_array(colored.resize((base.shape[1], base.shape[0])))
    return array_to_img(base + colored * alpha)


def _resolve_cmap(cmap):
    """Return a matplotlib Colormap for a name, a Colormap instance, or None (default map).

    ``matplotlib.cm.get_cmap`` was deprecated in matplotlib 3.7 and removed in
    3.9. Prefer the ``matplotlib.colormaps`` registry (3.5+) and fall back to
    ``cm.get_cmap`` only on older releases.
    """
    import matplotlib

    registry = getattr(matplotlib, 'colormaps', None)
    if registry is None:
        return cm.get_cmap(cmap)
    if cmap is None:
        cmap = matplotlib.rcParams['image.cmap']
    if isinstance(cmap, str):
        return registry[cmap]
    return cmap