Last active
October 31, 2019 20:47
-
-
Save wayofnumbers/4cef7bbabb2e53a2ff8f20e9016a5aff to your computer and use it in GitHub Desktop.
A simple PyTorch model showcasing how to build your own version of nn.Linear
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# We'll use fast.ai to showcase how to build your own 'nn.Linear' module | |
%matplotlib inline | |
from fastai.basics import * | |
import sys | |
# create and download/prepare our MNIST dataset | |
path = Config().data_path()/'mnist' | |
path.mkdir(parents=True) | |
!wget http://deeplearning.net/data/mnist/mnist.pkl.gz -P {path} | |
# Get the images downloaded into data set | |
with gzip.open(path/'mnist.pkl.gz', 'rb') as f: | |
((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding='latin-1') | |
# Sanity-check the data: show the first training image (stored as a flat
# 784-vector, so reshape to 28x28) and display the array's shape.
plt.imshow(x_train[0].reshape((28, 28)), cmap="gray")
x_train.shape
# Move everything from numpy arrays to PyTorch tensors, one split at a time.
x_train = torch.tensor(x_train)
y_train = torch.tensor(y_train)
x_valid = torch.tensor(x_valid)
y_valid = torch.tensor(y_valid)
n, c = x_train.shape
# Show the tensor shape and the label range (digits 0-9).
x_train.shape, y_train.min(), y_train.max()
# prepare dataset and create fast.ai DataBunch for training | |
# mini-batch size used by both the train and valid loaders
bs=64 | |
# wrap the (inputs, labels) tensor pairs as map-style datasets
train_ds = TensorDataset(x_train, y_train) | |
valid_ds = TensorDataset(x_valid, y_valid) | |
# DataBunch bundles the two DataLoaders for the training loop below
data = DataBunch.create(train_ds, valid_ds, bs=bs) | |
# create a simple MNIST logistic model with only one Linear layer | |
class Mnist_Logistic(nn.Module):
    """Logistic-regression classifier for MNIST.

    A single affine map from the 784 flattened pixels to 10 digit logits —
    the hand-rolled equivalent of a one-layer network.
    """

    def __init__(self):
        super().__init__()
        # 28*28 = 784 inputs, one logit per digit class, with a bias term.
        self.lin = nn.Linear(784, 10, bias=True)

    def forward(self, xb):
        """Return the raw class logits for a batch of flattened images."""
        return self.lin(xb)
# Instantiate the model plus the training hyperparameters it will be
# optimized with: a fixed learning rate and plain cross-entropy loss.
model = Mnist_Logistic()
lr = 2e-2
loss_func = nn.CrossEntropyLoss()
# define update function with weight decay | |
def update(x, y, lr):
    """Run one manual SGD step on the global `model` and return the batch loss.

    Computes cross-entropy on the batch plus an L2 weight-decay penalty,
    backpropagates, and updates every parameter in place.

    Args:
        x: batch of flattened input images.
        y: batch of integer class labels.
        lr: learning rate for this step.

    Returns:
        The scalar loss value for this batch as a Python float.
    """
    wd = 1e-5  # weight-decay (L2 regularization) coefficient
    y_hat = model(x)
    # Accumulate the squared L2 norm of all parameters for weight decay.
    w2 = 0.
    for p in model.parameters(): w2 += (p**2).sum()
    # Regular loss plus the scaled penalty term.
    loss = loss_func(y_hat, y) + w2*wd
    # NOTE: the original set `loss.requres_grad = True` here — a misspelled
    # attribute that silently did nothing. The loss already requires grad
    # because it depends on the model parameters, so the line is removed.
    loss.backward()
    # Manual SGD: step each parameter against its gradient, then clear the
    # gradient so it doesn't accumulate into the next batch.
    with torch.no_grad():
        for p in model.parameters():
            p.sub_(lr * p.grad)
            p.grad.zero_()
    return loss.item()
# Train for a single epoch: one update() call per mini-batch, collecting
# each batch's loss so the training curve can be plotted afterwards.
losses = [update(xb, yb, lr) for xb, yb in data.train_dl]
plt.plot(losses);
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment