'''Trains a simple convnet on the Zalando MNIST dataset.
Gets to 81.03% test accuracy after 30 epochs
(there is still a lot of margin for parameter tuning).
3 seconds per epoch on a GeForce GTX 980 GPU with CuDNN 5.
'''
from __future__ import print_function
import numpy as np
from mnist import MNIST
import matplotlib.pyplot as plt


def draw_neural_net(ax, left, right, bottom, top, layer_sizes):
    '''
    Draw a neural network cartoon using matplotlib.

    :usage:
        >>> fig = plt.figure(figsize=(12, 12))
        >>> draw_neural_net(fig.gca(), .1, .9, .1, .9, [4, 7, 2])