Skip to content

Instantly share code, notes, and snippets.

@callmephilip
Created May 7, 2020 11:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save callmephilip/df9f09c98e12a4f09fea43eabfd18717 to your computer and use it in GitHub Desktop.
Save callmephilip/df9f09c98e12a4f09fea43eabfd18717 to your computer and use it in GitHub Desktop.
Helper for debugging FastAI image data bunches
from fastai.vision import *
import pandas as pd
import matplotlib
from matplotlib.gridspec import GridSpec
def _class_counts(dataset, n_classes):
    """Return per-class item counts for *dataset*, ordered by class index.

    Classes with no items get an explicit 0, so the result always has
    exactly ``n_classes`` entries and lines up with the class label list.
    """
    # NOTE(review): assumes dataset.y.items holds integer class indices
    # (0..n_classes-1), as a single-label fastai v1 label list does --
    # confirm for other label types.
    counts = pd.Series(dataset.y.items).value_counts(sort=False)
    return counts.reindex(range(n_classes), fill_value=0).values


def visualize_data_bunch(db, figsize=(16,16), layoutRows=3, layoutCols=5):
    """Plot a one-figure overview of an image data bunch for debugging.

    The figure has three parts:

    * top-left: a grouped horizontal bar chart with the number of training
      and validation items per class;
    * top-right: a text summary (dataset sizes, batch size, number of
      transforms and how many of them are ``RandTransform`` instances);
    * bottom: a ``layoutRows`` x ``layoutCols`` grid in which one randomly
      chosen training item is fetched and drawn repeatedly.

    Args:
        db: data bunch exposing ``train_ds``, ``valid_ds``, ``classes`` and
            ``batch_size``.
        figsize: matplotlib figure size in inches.
        layoutRows: number of image rows below the summary section.
        layoutCols: number of grid columns (shared by both sections).

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    graphRows = 4  # grid rows reserved for the bar chart / text summary
    totalRows = layoutRows + graphRows
    graphCols = int(layoutCols * 0.75)  # chart takes ~3/4 of the width

    # tfms may be None when no transforms were configured.
    transforms = db.train_ds.tfms or []
    number_of_random_transforms = sum(
        1 for t in transforms if isinstance(t, RandTransform)
    )
    total_number_of_transforms = len(transforms)

    labels = list(db.classes)
    # Counts are re-indexed by class index so bar i always belongs to
    # labels[i], even when a class is missing from one of the datasets.
    train_ds_counts = _class_counts(db.train_ds, len(labels))
    valid_ds_counts = _class_counts(db.valid_ds, len(labels))

    bar_height = 0.35
    positions = list(range(len(labels)))

    with plt.style.context('Solarize_Light2'):
        fig = plt.figure(figsize=figsize, constrained_layout=True)
        gs = GridSpec(totalRows, layoutCols, figure=fig)
        ax1 = fig.add_subplot(gs[0:graphRows, 0:graphCols])
        ax2 = fig.add_subplot(gs[0:graphRows, graphCols:layoutCols])

        # Offset the two series around each tick so they sit side by side
        # instead of the validation bars being drawn over the training bars.
        ax1.barh([p - bar_height / 2 for p in positions], train_ds_counts,
                 bar_height, label='Training')
        ax1.barh([p + bar_height / 2 for p in positions], valid_ds_counts,
                 bar_height, label='Validation')
        ax1.set_yticks(positions)
        ax1.set_yticklabels([str(label) for label in labels])
        ax1.set_xlabel('Number of items')
        ax1.set_ylabel('Classes')
        ax1.set_title('Training and Validation datasets')
        ax1.legend()

        ax2.axis('off')
        ax2.text(0, 0.9, "Training dataset: %d items" % (len(db.train_ds)))
        ax2.text(0, 0.8, "Validation dataset: %d items" % (len(db.valid_ds)))
        ax2.text(0, 0.7, "Batch size: %d items" % (db.batch_size))
        ax2.text(0, 0.6, "Image transformations: %d (%d random)" % (total_number_of_transforms, number_of_random_transforms))

        # An empty training set has no item to preview; skip the image grid
        # rather than letting random.randint(0, -1) raise.
        if len(db.train_ds) > 0:
            # The same item is used for every cell on purpose.
            # NOTE(review): presumably each train_ds[...] access re-applies
            # the random transforms, so the grid shows augmentation
            # variety -- confirm.
            ds_item_index = random.randint(0, len(db.train_ds) - 1)
            for imageRow in range(graphRows, totalRows):
                for imageCol in range(layoutCols):
                    ax = fig.add_subplot(gs[imageRow, imageCol])
                    x, y = db.train_ds[ds_item_index]
                    x.show(ax, y=y)

        plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment