gcvit-tf / gcvit /utils /gradcam.py
awsaf49's picture
alpha=0.6
094461a
import tensorflow as tf
import matplotlib.cm as cm
import numpy as np
try:
from tensorflow.keras.utils import array_to_img, img_to_array
except:
from tensorflow.keras.preprocessing.image import array_to_img, img_to_array
def process_image(img, size=(224, 224)):
    """Pre-process a single image for inference.

    Applies ImageNet ``torch``-mode normalisation, resizes the result to
    ``size`` and prepends a batch axis.

    Args:
        img: image without a batch axis, e.g. (H, W, 3).
        size (tuple, optional): target (height, width). Defaults to (224, 224).

    Returns:
        Tensor with a leading batch axis of length 1.
    """
    # Keras' preprocess_input writes its result over floating-point ndarrays
    # in place; work on a copy so the caller's image is left untouched.
    if isinstance(img, np.ndarray) and np.issubdtype(img.dtype, np.floating):
        img = img.copy()
    img_array = tf.keras.applications.imagenet_utils.preprocess_input(img, mode='torch')
    img_array = tf.image.resize(img_array, size)[None,]
    return img_array
def get_gradcam_model(model):
    """Build a Keras model that returns both predictions and feature maps.

    A symbolic 224x224x3 input is passed through ``model.forward_features``
    and the result through ``model.forward_head``.

    Args:
        model: object exposing ``forward_features`` and ``forward_head``.

    Returns:
        tf.keras.models.Model whose outputs are ``[preds, feats]``.
    """
    image_input = tf.keras.Input(shape=(224, 224, 3))
    feature_maps = model.forward_features(image_input)
    predictions = model.forward_head(feature_maps)
    return tf.keras.models.Model(
        inputs=image_input,
        outputs=[predictions, feature_maps],
    )
def get_gradcam_prediction(img, grad_model, process=True, decode=True, pred_index=None, cmap='jet', alpha=0.6):
    """Grad-CAM for a single image.

    Args:
        img (np.ndarray): raw or already pre-processed image without the
            batch axis, e.g. (224, 224, 3).
        grad_model (tf.keras.Model): model returning ``(preds, feats)`` when
            called on a batch.
        process (bool, optional): apply ImageNet pre-processing through
            ``process_image``. When False, ``img`` is fed to the model as-is
            and min-max rescaled to uint8 for the overlay. Defaults to True.
        decode (bool, optional): decode the predictions with the ImageNet
            ``decode_predictions`` helper. Defaults to True.
        pred_index (int, optional): class index to explain; the
            highest-scoring class is used when None. Defaults to None.
        cmap (str, optional): matplotlib colormap. Defaults to 'jet'.
        alpha (float, optional): weight of the heatmap in the overlay.
            Defaults to 0.6.

    Returns:
        preds_decode: decoded predictions for the image when ``decode`` is
            True, otherwise the raw prediction tensor.
        overlay: image (as produced by ``array_to_img``) of the input
            blended with the colourised Grad-CAM heatmap.
    """
    import matplotlib  # local import: only ``matplotlib.cm`` is bound at file level

    img = np.asarray(img)

    # process image for inference
    if process:
        # Keras' preprocess_input overwrites float ndarrays in place; feed it
        # a copy so the pixels used for the overlay below stay intact.
        img_array = process_image(img.copy())
    else:
        img_array = tf.convert_to_tensor(img)[None,]
        # img is already in model space: stretch it to 0-255 for display.
        # A constant image cannot be min-max scaled and is left as-is.
        if img.min() != img.max():
            img = (img - img.min()) / (img.max() - img.min())
        img = np.uint8(img * 255.0)

    # get prediction (gradient is taken once, so a persistent tape is not needed)
    with tf.GradientTape() as tape:
        preds, feats = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # compute heatmap: weight each feature channel by its mean gradient
    grads = tape.gradient(class_channel, feats)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    heatmap = feats[0] @ pooled_grads[..., tf.newaxis]
    # drop only the trailing channel axis so 1-sized spatial axes survive
    heatmap = tf.squeeze(heatmap, axis=-1)
    heatmap = tf.maximum(heatmap, 0)
    # divide_no_nan yields 0 instead of NaN when no activation is positive
    heatmap = tf.math.divide_no_nan(heatmap, tf.math.reduce_max(heatmap))
    heatmap = np.uint8(255 * heatmap.numpy())

    # colorize heatmap
    # ``cm.get_cmap`` was removed in matplotlib 3.9; use the colormap
    # registry when it is available and fall back on older releases.
    registry = getattr(matplotlib, "colormaps", None)
    if hasattr(registry, "get_cmap"):
        colormap = registry.get_cmap(cmap)
    else:
        colormap = cm.get_cmap(cmap)
    colors = colormap(np.arange(256))[:, :3]
    heatmap = array_to_img(colors[heatmap])
    heatmap = heatmap.resize((img.shape[1], img.shape[0]))
    heatmap = img_to_array(heatmap)
    overlay = array_to_img(img + heatmap * alpha)

    # decode prediction
    if decode:
        preds_decode = tf.keras.applications.imagenet_utils.decode_predictions(preds.numpy())[0]
    else:
        preds_decode = preds
    return preds_decode, overlay