@Pibborn
Created June 1, 2018 16:12
MLP without nn.Module?
import torch
import torch.nn as nn
import torch.nn.functional as F
batch_size = 8
## data loading
from torchvision import datasets, transforms
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
(x, y) = next(iter(train_loader))
# one-hot encode the integer targets with a fixed identity embedding
y_ = nn.Embedding(10, 10)
y_.weight.data = torch.eye(10)
y_.weight.requires_grad_(False)  # a fixed lookup table, not a learned layer
y = y_(y)  # shape: (batch_size, 10)
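# A simpler alternative (my addition, assuming PyTorch >= 1.1, which added
# F.one_hot and postdates this gist):
# y = F.one_hot(y, num_classes=10).float()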
## model definition
# plain leaf tensors with requires_grad=True stand in for nn.Linear parameters
w1 = torch.rand(28 * 28, 100, requires_grad=True)
b1 = torch.ones(100, requires_grad=True)
print('x reshaped: {}'.format(x.view(batch_size, -1).size()))
print('w1: {}'.format(w1.size()))
o1 = torch.sigmoid(torch.matmul(x.view(batch_size, -1), w1) + b1)
print('o1: {}'.format(o1.size()))
w2 = torch.rand(100, 10, requires_grad=True)
b2 = torch.ones(10, requires_grad=True)
print('w2: {}'.format(w2.size()))
# log_softmax over the class dimension; note the gist squashes the logits
# through a sigmoid first, which is unusual but kept as written
o2 = F.log_softmax(torch.sigmoid(torch.matmul(o1, w2) + b2), dim=1)
print('o2: {}'.format(o2.size()))
print('y: {}'.format(y.size()))
print('o2 * y: {}'.format((o2 * y).size()))
print(o2[0])
# negative log-likelihood: keep only the log-prob of the true class for each
# sample, then average over the batch
loss = torch.mean(-torch.sum(o2 * y, dim=1))
loss.backward()
print(loss)
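# quick sanity check (my addition, not in the original gist): backward()
# populated the .grad attribute of every leaf tensor
print('w1.grad: {}'.format(w1.grad.size()))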
## training
# the original loop called loss(data), but loss is a tensor, not a function;
# repeat the forward pass per batch and take a manual SGD step instead
lr = 0.01
for i, (data, target) in enumerate(train_loader):
    for p in (w1, b1, w2, b2):
        if p.grad is not None:
            p.grad.zero_()  # clear gradients accumulated by the last step
    target = y_(target)  # one-hot targets, shape (batch_size, 10)
    h = torch.sigmoid(torch.matmul(data.view(data.size(0), -1), w1) + b1)
    out = F.log_softmax(torch.sigmoid(torch.matmul(h, w2) + b2), dim=1)
    loss = torch.mean(-torch.sum(out * target, dim=1))
    loss.backward()
    with torch.no_grad():  # update the parameters outside the autograd graph
        for p in (w1, b1, w2, b2):
            p -= lr * p.grad
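# To answer the title question: the idiomatic nn.Module version of the same
# network, for comparison. A minimal sketch (class and variable names below
# are my own, not from the gist):
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        h = torch.sigmoid(self.fc1(x.view(x.size(0), -1)))
        return F.log_softmax(self.fc2(h), dim=1)

# usage sketch: integer targets work directly with F.nll_loss, so no one-hot
# encoding is needed
# model = MLP()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# loss = F.nll_loss(model(data), target)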