Skip to content

Instantly share code, notes, and snippets.

@wayofnumbers
Last active October 31, 2019 20:47
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save wayofnumbers/4cef7bbabb2e53a2ff8f20e9016a5aff to your computer and use it in GitHub Desktop.
Save wayofnumbers/4cef7bbabb2e53a2ff8f20e9016a5aff to your computer and use it in GitHub Desktop.
Simple PyTorch model showing how to build your own nn.Linear
# We'll use fast.ai to showcase how to build your own 'nn.Linear' module
%matplotlib inline
from fastai.basics import *
import sys
# create and download/prepare our MNIST dataset
path = Config().data_path()/'mnist'
path.mkdir(parents=True)
!wget http://deeplearning.net/data/mnist/mnist.pkl.gz -P {path}
# Read the train and validation splits out of the gzipped pickle; the third
# element of the pickled tuple is discarded.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so only run this against a download you trust.
with gzip.open(path/'mnist.pkl.gz', 'rb') as f:
    (x_train, y_train), (x_valid, y_valid), _ = pickle.load(f, encoding='latin-1')

# Preview the first training example as a 28x28 grayscale image, then display
# the shape of the training array (notebook cell output).
plt.imshow(x_train[0].reshape((28,28)), cmap="gray")
x_train.shape

# Turn each array into a PyTorch tensor, rebinding the same four names.
x_train, y_train, x_valid, y_valid = (
    torch.tensor(arr) for arr in (x_train, y_train, x_valid, y_valid)
)
# Unpack the two dimensions of the training tensor.
n, c = x_train.shape
x_train.shape, y_train.min(), y_train.max()

# Wrap the tensors in datasets and build the fast.ai DataBunch used for training.
bs = 64
train_ds = TensorDataset(x_train, y_train)
valid_ds = TensorDataset(x_valid, y_valid)
data = DataBunch.create(train_ds, valid_ds, bs=bs)
# create a simple MNIST logistic model with only one Linear layer
class Mnist_Logistic(nn.Module):
    """Logistic-regression classifier made of a single fully connected layer.

    Args:
        n_in: Number of input features per example. Defaults to 784, which
            matches a flattened 28x28 image.
        n_out: Number of output logits. Defaults to 10.
    """

    def __init__(self, n_in=784, n_out=10):
        super().__init__()
        # The only learnable layer: an affine map from n_in features to n_out logits.
        self.lin = nn.Linear(n_in, n_out, bias=True)

    def forward(self, xb):
        """Return the raw (un-normalized) logits of the linear layer for ``xb``."""
        return self.lin(xb)
# Training setup: learning rate, loss function, and a fresh model instance.
lr = 2e-2
loss_func = nn.CrossEntropyLoss()
model = Mnist_Logistic()
# define update function with weight decay
def update(x, y, lr, wd=1e-5):
    """Run one manual gradient-descent step on the module-level ``model``.

    The loss is ``loss_func(model(x), y)`` plus an L2 penalty: ``wd`` times
    the sum of squares of every model parameter.

    Args:
        x: Batch of inputs passed to ``model``.
        y: Targets passed to ``loss_func`` together with the predictions.
        lr: Step size multiplied with each parameter's gradient.
        wd: Weight-decay coefficient for the L2 penalty. Defaults to 1e-5.

    Returns:
        The penalized loss for this batch as a Python float.
    """
    y_hat = model(x)
    # weight decay: sum of squared values over all parameters
    w2 = sum((p ** 2).sum() for p in model.parameters())
    # add to regular loss
    loss = loss_func(y_hat, y) + w2 * wd
    # No requires_grad assignment is needed here: the loss is computed from
    # the model parameters, so autograd already tracks it. (The previous
    # `loss.requres_grad = True` was misspelled and only set a stray attribute.)
    loss.backward()
    with torch.no_grad():
        for p in model.parameters():
            # Apply the step outside of autograd, then clear the gradient so
            # it does not accumulate into the next call.
            p.sub_(lr * p.grad)
            p.grad.zero_()
    return loss.item()
# Train for a single epoch: take one update step per mini-batch of the
# training loader, record each batch loss, then plot the loss curve.
losses = [
    update(xb, yb, lr)
    for xb, yb in data.train_dl
]
plt.plot(losses);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment