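"""Two-level gradient meta-learning experiment (in the spirit of "learning to learn").

A BaseLevel linear network is fitted to a fixed random regression target,
while a HigherLevel LSTM receives the base level's gradients and emits
adjusted gradients. The adjusted gradients are applied through a
differentiable update, so the base loss can be back-propagated into the LSTM
and its Adam optimizer trains it to produce better updates.
"""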
import torch
import torch.nn as nn
import numpy as np

class Level(nn.Module):
    def __init__(self):
        super().__init__()

    def set_params(self, new_params):
        # Rebind every parameter to the matching entry of `new_params`.
        # The new tensors stay in the autograd graph, so a later backward
        # pass can reach the computation that produced them.
        index = 0
        for layer in self.children():
            for key in layer._parameters.keys():
                if layer._parameters[key] is not None:
                    assert layer._parameters[key].size() == new_params[index].size()
                    layer._parameters[key] = new_params[index]
                    index += 1
    def apply_new_gradients(self, grads, lr=1):
        # `grads` is a flat tensor holding one adjusted gradient per parameter
        # element; slice it back into per-parameter chunks and apply the update
        # with differentiable ops so the higher level can be trained through it.
        new_vars = []
        offset = 0
        for p in self.parameters():
            n = p.numel()
            g = grads[offset:offset + n].view_as(p)
            offset += n
            cl = p.clone() - g * lr
            cl.retain_grad()
            new_vars.append(cl)
        self.set_params(new_vars)

    def get_gradients(self):
        # Return all parameter gradients as one detached flat vector.
        grads = [p.grad.detach().view(-1) for p in self.parameters()]
        return torch.cat(grads)
    def detach_params(self):
        # Cut the parameters out of the accumulated graph (truncating
        # back-propagation through earlier updates), then mark them as
        # requiring gradients again.
        print("Detaching")
        for layer in self.children():
            for key in layer._parameters.keys():
                if layer._parameters[key] is not None:
                    layer._parameters[key] = layer._parameters[key].detach()
                    layer._parameters[key].requires_grad = True

    def forward(self, *input):
        pass

class BaseLevel(Level):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10, bias=False)

    def forward(self, data):
        return self.layer(data)
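
# Hypothetical demo (not part of the original gist): a minimal sketch of the
# trick Level.set_params relies on. After apply_new_gradients, the base
# parameters are non-leaf graph nodes, so a later loss can back-propagate
# through the update that produced them. Never called by the script itself.
def _demo_differentiable_update():
    lvl = BaseLevel()
    fake_grads = torch.randn(sum(p.numel() for p in lvl.parameters()))
    lvl.apply_new_gradients(fake_grads, lr=0.1)      # weight becomes a non-leaf tensor
    lvl(torch.rand(4, 10)).sum().backward()          # grads flow through the update
    print([p.grad.shape for p in lvl.parameters()])  # kept thanks to retain_grad()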

class HigherLevel(Level):
    def __init__(self, hidden_size, num_layers, param_count):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm_model = nn.LSTM(1, self.hidden_size, num_layers=self.num_layers)
        self.ff = nn.Linear(self.hidden_size, 1)
        self.param_count = param_count
        self.device = "cpu"
        self.h_state = None
        self.reset_hidden()
        self.optim = torch.optim.Adam(self.parameters())
    def forward(self, inp):
        self.lstm_model.flatten_parameters()
        # Shape the flat gradient vector as (seq_len=1, batch=param_count, features=1)
        net_inp = inp.view(1, self.param_count, 1)
        # Push every gradient element through the LSTM, keeping the hidden state
        out, self.h_state = self.lstm_model(net_inp, self.h_state)
        # Squash the gradient adjustments into (-1, 1)
        out = torch.tanh(self.ff(out.view(self.param_count, -1)))
        return out.view(-1)
    def detach_hidden(self):
        # Truncate back-propagation through time for the LSTM state
        self.h_state = (self.h_state[0].detach(), self.h_state[1].detach())

    def reset_hidden(self):
        self.h_state = (torch.zeros(self.num_layers, self.param_count, self.hidden_size, device=self.device),
                        torch.zeros(self.num_layers, self.param_count, self.hidden_size, device=self.device))

    def to(self, device):
        super().to(device)
        self.device = device
        self.reset_hidden()
        return self
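
# Hypothetical shape check (not part of the original gist): for the bias-free
# 10x10 base layer, param_count is 100, so the higher level maps a flat
# gradient vector of shape (100,) to an adjusted vector of the same shape,
# squashed into (-1, 1) by the tanh. Never called by the script itself.
def _demo_higher_shapes():
    higher = HigherLevel(hidden_size=5, num_layers=1, param_count=100)
    adj = higher(torch.randn(100))
    print(adj.shape)        # torch.Size([100])
    print(adj.abs().max())  # strictly below 1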

if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    input_data = torch.rand(10, 10).to(device)
    target_data = torch.rand(10, 10).to(device)

    base = BaseLevel()
    paramcount = sum(int(np.prod(p.size())) for p in base.parameters())
    higher = HigherLevel(5, 1, paramcount)
    base.to(device)
    higher.to(device)

    loss_f = torch.nn.MSELoss()

    for k in range(1000):
        # Forward pass through the base level
        out = base(input_data)
        loss = loss_f(out, target_data)
        print(f"{k}: {loss.item()}")

        # Backprop: the loss reaches the higher level's parameters through the
        # differentiable updates applied in earlier iterations
        higher.optim.zero_grad()
        loss.backward(retain_graph=True)
        higher.optim.step()

        # Get the base-level gradients and push them through the higher level
        grads = base.get_gradients()
        adj_grads = higher(grads)

        # Update the base level with the adjusted gradients
        base.apply_new_gradients(adj_grads)
        # Periodically detach the LSTM state and the base parameters to
        # truncate the ever-growing backprop graph
        if (k + 1) % 200 == 0:
            higher.detach_hidden()
            base.detach_params()