@shayaf84
Last active November 20, 2021 16:09
import matplotlib.pyplot as plt
import tensorflow as tf


def plotAttributions(baseline, attributions,
                     image,
                     cmap=None,
                     overlay_alpha=0.4):
    # Sum the absolute values of the attributions over the channel axis
    # to get a per-pixel feature importance map
    mask = tf.reduce_sum(tf.math.abs(attributions), axis=-1)

    fig, ax = plt.subplots(nrows=2, ncols=2, squeeze=False, figsize=(8, 8))

    # Original image
    ax[0, 1].set_title('Original')
    ax[0, 1].imshow(image)
    ax[0, 1].axis('off')

    # Baseline image
    ax[0, 0].set_title('Baseline')
    ax[0, 0].imshow(baseline)
    ax[0, 0].axis('off')

    # Just the feature attributions
    ax[1, 0].set_title('Attribution mask')
    ax[1, 0].imshow(mask, cmap=cmap)
    ax[1, 0].axis('off')

    # Attributions overlaid on the image; overlay_alpha is the opacity
    # of the image drawn on top of the mask
    ax[1, 1].set_title('Overlay')
    ax[1, 1].imshow(mask, cmap=cmap)
    ax[1, 1].imshow(image, alpha=overlay_alpha)
    ax[1, 1].axis('off')

    plt.tight_layout()
    return fig
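
The mask step in the function collapses per-channel attributions into a single importance value per pixel. Below is a minimal NumPy sketch of that reduction, assuming attributions shaped (height, width, channels). The helper name `attribution_mask` and the sample values are hypothetical, invented only for illustration; the function itself uses `tf.reduce_sum` and `tf.math.abs`.

```python
import numpy as np


def attribution_mask(attributions):
    """Collapse (H, W, C) attributions into an (H, W) importance map."""
    # Absolute value first, so positive and negative attributions
    # both count as importance instead of cancelling out.
    return np.sum(np.abs(attributions), axis=-1)


# Hypothetical 2x2 RGB attributions, made up for this example.
attributions = np.array([
    [[0.5, -0.5, 0.0], [0.0, 0.0, 0.0]],
    [[-1.0, 0.25, 0.25], [0.125, -0.125, 0.25]],
])

mask = attribution_mask(attributions)
print(mask.shape)
print(mask)
```

Note that a plain sum without the absolute value would give 0 for the first pixel, hiding the fact that two of its channels carry sizeable attributions of opposite sign.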