@kris-singh · Created January 30, 2019 09:35
Crude implementation: registering backward hooks on the parameters of a small MNIST classifier to print gradient shapes and norms during backpropagation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms


def backward_hook_function(grad):
    # A hook registered via Tensor.register_hook receives a single argument:
    # the gradient of the loss with respect to that tensor.
    print(grad.shape)
    print(grad.norm())


class DummyModel(nn.Module):
    def __init__(self, input_size, hidden_size1, hidden_size2, num_classes):
        super(DummyModel, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size1)
        self.layer2 = nn.Linear(hidden_size1, hidden_size2)
        self.out_layer = nn.Linear(hidden_size2, num_classes)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        out = F.log_softmax(self.out_layer(x), dim=1)
        return out


def train(model, dataset, device, criterion):
    # Register a hook on every parameter of the linear/conv layers so that
    # each parameter's gradient is printed when backward() runs.
    for module in model.children():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            for params in module.parameters():
                params.register_hook(backward_hook_function)
    for key, (data, target) in enumerate(dataset):
        data = data.to(device)
        labels = target.to(device)
        data = data.view(1, -1)  # flatten the 28x28 image (batch size is 1)
        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        break  # a single batch is enough to inspect the gradients


def main():
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('public_data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=1, shuffle=True, **kwargs)
    model = DummyModel(784, 512, 128, 10).to(device)
    criterion = F.nll_loss
    train(model, train_loader, device, criterion)


if __name__ == "__main__":
    main()
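If gradients are wanted per module rather than per parameter, hooks can instead be registered on the layers themselves. A minimal sketch of that variant, assuming a reasonably recent PyTorch (1.8+) where nn.Module.register_full_backward_hook is available; the hook name module_backward_hook is illustrative and not part of the original gist:

def module_backward_hook(module, grad_input, grad_output):
    # grad_input / grad_output are tuples of gradients with respect to the
    # module's inputs and outputs; individual entries may be None.
    print(module, [g.norm() for g in grad_output if g is not None])

# Register on each linear/conv layer instead of on its parameters.
for module in model.children():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        module.register_full_backward_hook(module_backward_hook)

Unlike the per-parameter hook above, which fires once per weight or bias tensor, a module-level hook fires once per layer during backward() and sees the gradients flowing through the layer's activations.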