Skip to content

Instantly share code, notes, and snippets.

Last active November 5, 2018 23:45
Show Gist options
  • Save Redchards/65f1a6f758a1a5c5efb56f83933c3f6e to your computer and use it in GitHub Desktop.
Save Redchards/65f1a6f758a1a5c5efb56f83933c3f6e to your computer and use it in GitHub Desktop.
PyTorch implementation of highway networks
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 5 17:22:52 2018

@author: Vladslinger
"""
import torch
from torchvision import datasets, transforms
def generate_linear_layers(in_size, out_size, layer_count):
    """Build ``layer_count`` independent fully connected layers.

    Args:
        in_size: Number of input features of every layer.
        out_size: Number of output features of every layer.
        layer_count: How many layers to create.

    Returns:
        A list of ``layer_count`` distinct ``torch.nn.Linear`` modules, each
        mapping ``in_size`` features to ``out_size`` features. The list is
        empty when ``layer_count`` is 0.
    """
    # The layers used to be built as Linear(in_size, in_size), silently
    # ignoring ``out_size``. Honour the requested output width; callers that
    # pass equal sizes get exactly the same layers as before.
    return [torch.nn.Linear(in_size, out_size) for _ in range(layer_count)]
class HighwayNetwork(torch.nn.Module):
    """Stack of highway layers followed by a linear read-out layer.

    Every highway layer owns two ``in_size -> in_size`` linear maps: a gate
    and a transform. With ``g = nonlinear_function(gate(x))`` the layer
    output is ``g * activation(transform(x)) + (1 - g) * x``. After the last
    highway layer, ``final_layer`` projects to ``out_size`` features.

    Args:
        in_size: Width of the input and of every highway layer.
        out_size: Width of the final projection.
        layer_count: Number of highway layers.
        nonlinear_function: Module applied to the gate pre-activation.
        activation: Module applied to the transform pre-activation.
        bias: Constant used to initialise the bias of every gate layer.

    Note:
        The default ``nonlinear_function`` / ``activation`` instances are
        created once at definition time and shared between networks; the
        defaults (``Sigmoid`` / ``ReLU``) hold no parameters.
    """

    def __init__(self, in_size, out_size, layer_count,
                 nonlinear_function=torch.nn.Sigmoid(),
                 activation=torch.nn.ReLU(), bias=-1.0):
        super(HighwayNetwork, self).__init__()

        self.carry_gate_list = torch.nn.ModuleList(
            generate_linear_layers(in_size, in_size, layer_count))
        self.linear_term_list = torch.nn.ModuleList(
            generate_linear_layers(in_size, in_size, layer_count))

        self.nonlinear_function = nonlinear_function
        self.out_size = out_size
        self.activation = activation

        self.final_layer = torch.nn.Linear(in_size, out_size)
        # Not applied in forward(); kept for callers that want probabilities.
        # An explicit dim avoids the deprecated implicit-dimension behaviour;
        # for (batch, classes) input this is the same axis as before.
        self.output_function = torch.nn.Softmax(dim=-1)

        # NOTE(review): the body of this loop was missing in the source, so
        # ``bias`` was never used. Filling the gate biases with it is the
        # presumed intent -- with the default sigmoid a negative value makes
        # each layer start out passing its input through mostly unchanged.
        for carry_gate in self.carry_gate_list:
            torch.nn.init.constant_(carry_gate.bias, bias)

    def forward(self, x):
        """Run ``x`` through all highway layers and the final projection.

        Args:
            x: Tensor whose last dimension is ``in_size``.

        Returns:
            The output of ``final_layer`` (last dimension ``out_size``);
            ``output_function`` is not applied.
        """
        out = x
        for carry_gate, linear_term in zip(self.carry_gate_list,
                                           self.linear_term_list):
            gate = self.nonlinear_function(carry_gate(out))
            # Blend the transformed signal with the untouched input.
            out = gate * self.activation(linear_term(out)) + (1.0 - gate) * out
        return self.final_layer(out)
def _batch_accuracy(scores, target):
    """Return the fraction of rows in ``scores`` whose arg-max equals ``target``."""
    hits = (scores.argmax(dim=1) == target).sum().item()
    return hits / scores.shape[0]


if __name__ == '__main__':
    # Train a highway network on MNIST for one pass over the training set,
    # then report per-batch loss and accuracy on the test set.
    batch_size = 64
    nb_digits: int = 10

    # NOTE(review): the loader expressions were truncated in the source; only
    # the dataset arguments and batch_size/shuffle survived. ToTensor() is the
    # minimum needed for the DataLoader to yield tensors -- confirm whether
    # the original also normalised the images.
    to_tensor = transforms.ToTensor()
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=to_tensor),
        batch_size=batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True,
                       transform=to_tensor),
        batch_size=batch_size, shuffle=True)

    model = HighwayNetwork(28 * 28, nb_digits, 25)
    loss = torch.nn.CrossEntropyLoss()
    learning_rate = 0.0001
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Training pass. The counter ``i`` is a batch index, although the log
    # line labels it "Epoch".
    for i, (data, target) in enumerate(train_loader):
        data = data.reshape(-1, 28 * 28)
        forward_pass = model(data)
        err = loss(forward_pass, target)

        # The optimisation step was absent, so the weights never changed.
        optimizer.zero_grad()
        err.backward()
        optimizer.step()

        acc = _batch_accuracy(forward_pass, target)
        print("Epoch {} : Loss {:.4f}".format(i, err.mean().item()))
        print("Accuracy {}%".format(acc * 100))

    # Evaluation pass: use the held-out test set (previously the training
    # loader was iterated again) and recompute the loss for every batch
    # instead of printing the last training loss.
    model.eval()
    with torch.no_grad():
        for i, (data, target) in enumerate(test_loader):
            data = data.reshape(-1, 28 * 28)
            forward_pass = model(data)
            err = loss(forward_pass, target)

            print(forward_pass[0].max(0)[1], target[0])
            acc = _batch_accuracy(forward_pass, target)
            print("Epoch {} : Loss {:.4f}".format(i, err.mean().item()))
            print("Accuracy {}".format(acc))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment