Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save vedraiyani/d7aa2efd21fd138239e961e024ef5c13 to your computer and use it in GitHub Desktop.
Save vedraiyani/d7aa2efd21fd138239e961e024ef5c13 to your computer and use it in GitHub Desktop.
def _top_predictions(
    preds: np.ndarray,
    class_names_mapping: Dict[int, str],
    top_preds_count: int,
) -> tuple:
    """Return ``(indexes, scores, names)`` of the highest-scoring classes.

    Only row 0 of ``preds`` is ranked. This presumably is ``model.predict``
    output of shape (batch, n_classes) for a single image -- TODO confirm.
    Indexes, scores and names are all taken from that same row, best first.
    If ``top_preds_count`` exceeds the number of classes, all classes are
    returned.
    """
    scores = np.asarray(preds)[0]
    # argsort is ascending; reverse it so the best-scoring class comes first.
    indexes = np.argsort(scores)[::-1][:top_preds_count]
    values = scores[indexes]
    names = [class_names_mapping[int(i)] for i in indexes]
    return indexes, values, names


def plot_eli5_top_explanations(
    model: "Model",
    image: np.ndarray,
    class_names_mapping: Dict[int, str],
    top_preds_count: int = 3,
    fig_name: Optional[str] = None,
    image_columns: int = 3,
) -> None:
    """Plot an eli5 explanation for each of the model's top predicted classes.

    The model is run on ``image`` once. The ``top_preds_count`` best-scoring
    classes are selected, and ``eli5.show_prediction`` is called once per class
    (``targets=[class_index]``). Each result is drawn in its own subplot, titled
    with its rank, class name and score.

    Args:
        model: Model passed to ``model.predict`` and ``eli5.show_prediction``.
        image: Input passed unchanged to both calls. Presumably a batch
            holding a single image -- TODO confirm against callers.
        class_names_mapping: Maps a class index to a display name.
        top_preds_count: Number of top classes to explain (must be >= 1).
        fig_name: If truthy, the figure is also saved to this path.
        image_columns: Subplots per grid row (must be >= 1).

    Raises:
        ValueError: If ``top_preds_count`` or ``image_columns`` is below 1.
        KeyError: If a top class index is missing from
            ``class_names_mapping``.

    Note:
        ``plt.style.use('dark_background')`` changes the global matplotlib
        style, and that change persists after this call.
    """
    # Validate first: plt.subplots with 0 rows/columns fails with a far less
    # helpful message.
    if top_preds_count < 1:
        raise ValueError(f"top_preds_count must be >= 1, got {top_preds_count}")
    if image_columns < 1:
        raise ValueError(f"image_columns must be >= 1, got {image_columns}")

    image_rows = math.ceil(top_preds_count / image_columns)
    preds = model.predict(image)
    top_indexes, top_values, top_names = _top_predictions(
        preds, class_names_mapping, top_preds_count
    )

    plt.style.use('dark_background')
    # squeeze=False always yields a 2-D array of Axes, so `.flat` works even
    # for a 1x1 grid.
    _, axes = plt.subplots(
        image_rows,
        image_columns,
        figsize=(image_columns * 5, image_rows * 5),
        squeeze=False,
    )
    # Hide axes on every cell, including unused trailing cells in the grid.
    for ax in axes.flat:
        ax.set_axis_off()

    for rank, (index, value, name, ax) in enumerate(
        zip(top_indexes, top_values, top_names, axes.flat), start=1
    ):
        # Named "grad_cam" by the author; presumably a Grad-CAM overlay image
        # returned by eli5 for a Keras model -- not verifiable from here.
        class_grad_cam = eli5.show_prediction(model, image, targets=[int(index)])
        ax.imshow(class_grad_cam)
        ax.set_title(f"{rank}. class: {name} pred: {value:.3f}", pad=20)

    # Save before show(): some backends release the figure once it is shown.
    if fig_name:
        plt.savefig(fig_name)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment