Skip to content

Instantly share code, notes, and snippets.

What would you like to do?
Helper for debugging FastAI image data bunches
import random

# NOTE(review): the module name was missing from the original star-import line;
# `fastai.vision` is assumed because the helper uses fastai's `RandTransform`.
from fastai.vision import *
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.gridspec import GridSpec
def _count_items_per_class(items, num_classes):
    """Return how many of ``items`` fall into each class index.

    The result always has ``num_classes`` entries ordered by class index
    (0, 1, ..., num_classes - 1); classes that never occur get a count of 0.
    This keeps the counts aligned with the class-name list used as bar labels.

    Assumes ``items`` holds integer class indices -- TODO confirm for label
    types other than single-label categories.
    """
    counts = pd.Series(items).value_counts()
    return counts.reindex(range(num_classes), fill_value=0).to_numpy()


def visualize_data_bunch(db, figsize=(16,16), layoutRows=3, layoutCols=5):
    """Draw a debugging overview of an image data bunch.

    The figure is laid out on a grid with ``layoutCols`` columns:

    * the top 4 grid rows hold a horizontal bar chart of items per class for
      the training and validation sets (left, ~75% of the columns) and a text
      panel with dataset sizes, batch size and transform counts (right);
    * the remaining ``layoutRows`` rows show one randomly chosen training
      item, fetched again for every cell so that each cell goes through the
      dataset's transforms independently.

    Args:
        db: data bunch exposing ``classes``, ``batch_size``, ``train_ds`` and
            ``valid_ds`` (with ``tfms`` and ``y.items`` on the datasets).
        figsize: size passed to ``plt.figure``.
        layoutRows: number of grid rows used for sample images.
        layoutCols: number of grid columns in the whole figure.
    """
    graphRows = 4  # grid rows reserved for the bar chart and the text panel
    totalRows = layoutRows + graphRows
    graphCols = int(layoutCols * 0.75)

    # A dataset without transforms may expose `tfms` as None.
    tfms = db.train_ds.tfms
    if tfms is None:
        tfms = []
    number_of_random_transforms = sum(isinstance(t, RandTransform) for t in tfms)
    total_number_of_transforms = len(tfms)

    labels = db.classes
    # Counts are indexed by class so bar i always belongs to labels[i], even
    # when a class is missing from one of the two datasets.
    train_ds_counts = _count_items_per_class(db.train_ds.y.items, len(labels))
    valid_ds_counts = _count_items_per_class(db.valid_ds.y.items, len(labels))

    fig = plt.figure(figsize=figsize, constrained_layout=True)
    gs = GridSpec(totalRows, layoutCols, figure=fig)
    ax1 = fig.add_subplot(gs.new_subplotspec((0, 0), colspan=graphCols, rowspan=graphRows))
    ax2 = fig.add_subplot(
        gs.new_subplotspec((0, graphCols), colspan=layoutCols - graphCols, rowspan=graphRows)
    )

    width = 0.35  # bar thickness
    ax1.barh(labels, train_ds_counts, width, label='Training')
    ax1.barh(labels, valid_ds_counts, width, label='Validation')
    ax1.set_xlabel('Number of items')
    ax1.set_title('Training and Validation datasets')
    ax1.legend()  # the bar labels above are only visible through a legend

    ax2.text(0, 0.9, "Training dataset: %d items" % (len(db.train_ds)))
    ax2.text(0, 0.8, "Validation dataset: %d items" % (len(db.valid_ds)))
    ax2.text(0, 0.7, "Batch size: %d items" % (db.batch_size))
    ax2.text(0, 0.6, "Image transformations: %d (%d random)" % (total_number_of_transforms, number_of_random_transforms))

    # One index for the whole grid: every cell shows the same training item.
    ds_item_index = random.randint(0, len(db.train_ds) - 1)
    for imageRow in range(graphRows, totalRows):
        for imageCol in range(layoutCols):
            ax = fig.add_subplot(gs.new_subplotspec((imageRow, imageCol), colspan=1))
            # NOTE(review): the source line here was garbled
            # (`x,y = db.train_ds[ds_item_index], y=y)`); it is reconstructed as
            # unpacking the item and showing it on this cell's axes -- confirm.
            x, y = db.train_ds[ds_item_index]
            x.show(ax=ax, y=y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment