Skip to content

Instantly share code, notes, and snippets.

@dwf
Created October 4, 2013 00:02
Show Gist options
  • Save dwf/6819012 to your computer and use it in GitHub Desktop.
Save dwf/6819012 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
def plot_filters(filters, shape, num_rows, num_cols):
    """
    Draw the rows of `filters` as grayscale images on a grid of subplots.

    Assumes filters are rows (so transpose it that way if necessary).
    matplotlib isn't the speediest horse in town, I wouldn't draw
    more than a few dozen this way if you need it to frequently update.
    An alternative would be to pack stuff into one big array rather
    than using subplots (this is what pylearn2's show_filters script
    does, but in a somewhat inflexible way).

    Parameters
    ----------
    filters : array_like
        2-D array (or anything `np.asarray` turns into one) with one
        flattened filter per row.
    shape : tuple
        Shape each row is reshaped to before being displayed.
    num_rows, num_cols : int
        Dimensions of the subplot grid. At most `num_rows * num_cols`
        filters are drawn, taken from the first rows of `filters`.

    Raises
    ------
    ValueError
        If `filters` is not 2-dimensional.

    Notes
    -----
    All subplots share one color scale spanning the minimum and maximum
    of the whole `filters` array (not only the rows being displayed).
    If there is nothing to display (no rows, or an empty grid) the
    function returns without touching any matplotlib state.
    """
    # Validate up front, before any matplotlib state is changed. An explicit
    # raise is used instead of `assert` so the check survives `python -O`.
    filters = np.asarray(filters)
    if filters.ndim != 2:
        raise ValueError(
            "filters must be a 2-D array with one filter per row, "
            "got %d dimension(s)" % filters.ndim
        )

    num_filters_displayed = min(num_rows * num_cols, filters.shape[0])
    if num_filters_displayed <= 0:
        # Nothing to draw; also avoids min()/max() on a zero-size array.
        return

    min_value, max_value = filters.min(), filters.max()

    # Speed things up a bit by drawing only when we're done.
    interactive = plt.isinteractive()
    try:
        if interactive:
            plt.ioff()
        for i in range(num_filters_displayed):
            plt.subplot(num_rows, num_cols, i + 1)
            plt.imshow(filters[i].reshape(shape),
                       cmap=plt.cm.gray, interpolation='nearest')
            plt.xticks([])
            plt.yticks([])
            # Same limits on every subplot so intensities are comparable.
            plt.clim(min_value, max_value)
    finally:
        # Restore interactive state, if necessary.
        if interactive:
            # Re-enable interactive mode *before* show(): in non-interactive
            # mode show() blocks until the figure windows are closed.
            plt.ion()
            plt.show()  # Do queued up plotting commands.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment