Created
May 7, 2020 11:33
-
-
Save callmephilip/df9f09c98e12a4f09fea43eabfd18717 to your computer and use it in GitHub Desktop.
Helper for debugging FastAI image data bunches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random

# The star import comes first so the explicit imports below take precedence
# over any same-named symbol that fastai re-exports.
from fastai.vision import *
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import pandas as pd
def _count_per_class(class_indices, n_classes):
    """Return the number of occurrences of each class index 0..n_classes-1.

    Classes that never occur get an explicit count of 0, so the result always
    has exactly ``n_classes`` entries, ordered by class index.
    """
    counts = pd.Series(class_indices).value_counts(sort=False)
    return counts.reindex(range(n_classes), fill_value=0).values


def visualize_data_bunch(db, figsize=(16,16), layoutRows=3, layoutCols=5):
    """Plot a summary figure for an image data bunch.

    The figure has three parts:

    * a grouped horizontal bar chart with the number of training and
      validation items per class;
    * a text panel with dataset sizes, batch size and transform counts;
    * a ``layoutRows`` x ``layoutCols`` grid in which one randomly chosen
      training item is fetched and displayed once per cell.

    Args:
        db: Data bunch exposing ``train_ds``, ``valid_ds``, ``classes`` and
            ``batch_size``. ``train_ds.y.items`` / ``valid_ds.y.items`` are
            assumed to hold integer indices into ``db.classes``
            (single-label classification) -- TODO confirm for other label
            types.
        figsize: Figure size passed to ``plt.figure``.
        layoutRows: Number of image rows below the summary section.
        layoutCols: Number of columns of the whole grid; must be at least 2
            so that both the bar chart and the text panel get a column.

    Raises:
        ValueError: If ``layoutCols`` is smaller than 2.
    """
    if layoutCols < 2:
        raise ValueError("layoutCols must be at least 2, got %r" % (layoutCols,))

    graphRows = 4  # grid rows reserved for the bar chart and the text panel
    totalRows = layoutRows + graphRows
    graphCols = int(layoutCols * 0.75)  # bar chart takes ~3/4 of the width

    # ``tfms`` may be None when no transforms were configured.
    tfms = db.train_ds.tfms or []
    total_number_of_transforms = len(tfms)
    number_of_random_transforms = sum(isinstance(t, RandTransform) for t in tfms)

    labels = list(db.classes)
    # Counts are aligned with ``labels`` by class index; a class missing from
    # one of the splits shows up as 0 instead of shifting the other bars.
    train_ds_counts = _count_per_class(db.train_ds.y.items, len(labels))
    valid_ds_counts = _count_per_class(db.valid_ds.y.items, len(labels))

    bar_height = 0.35
    positions = list(range(len(labels)))
    n_train = len(db.train_ds)

    with plt.style.context('Solarize_Light2'):
        fig = plt.figure(figsize=figsize, constrained_layout=True)
        gs = GridSpec(totalRows, layoutCols, figure=fig)
        ax1 = fig.add_subplot(gs[0:graphRows, 0:graphCols])
        ax2 = fig.add_subplot(gs[0:graphRows, graphCols:layoutCols])

        # Offset the two series around each tick so that the validation bars
        # sit next to the training bars instead of being painted over them.
        ax1.barh([p - bar_height / 2 for p in positions], train_ds_counts,
                 bar_height, label='Training')
        ax1.barh([p + bar_height / 2 for p in positions], valid_ds_counts,
                 bar_height, label='Validation')
        ax1.set_yticks(positions)
        ax1.set_yticklabels(labels)
        ax1.set_xlabel('Number of items')
        ax1.set_ylabel('Classes')
        ax1.set_title('Training and Validation datasets')
        ax1.legend()

        ax2.axis('off')
        ax2.text(0, 0.9, "Training dataset: %d items" % (n_train))
        ax2.text(0, 0.8, "Validation dataset: %d items" % (len(db.valid_ds)))
        ax2.text(0, 0.7, "Batch size: %d items" % (db.batch_size))
        ax2.text(0, 0.6, "Image transformations: %d (%d random)" % (total_number_of_transforms, number_of_random_transforms))

        if n_train:
            # One index is picked once and re-fetched for every cell.
            # NOTE(review): presumably the dataset re-applies its random
            # transforms on each access, so the cells show different
            # augmentations of the same image -- confirm.
            ds_item_index = random.randint(0, n_train - 1)
            for imageRow in range(graphRows, totalRows):
                for imageCol in range(layoutCols):
                    ax = fig.add_subplot(gs[imageRow, imageCol])
                    x, y = db.train_ds[ds_item_index]
                    x.show(ax, y=y)

        plt.show()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment