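"""Learned-optimizer experiment: a small LSTM ("HigherLevel") reads the
flattened gradients of a simple linear network ("BaseLevel"), proposes
per-parameter updates squashed to [-1, 1], and applies them as differentiable
parameter replacements; an Adam step on the LSTM then trains the update rule
through the base network's loss.
"""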
import torch
import torch.nn as nn


class Level(nn.Module):
    def __init__(self):
        super().__init__()

    def set_params(self, new_params):
        # Replace each parameter tensor of the direct child layers with the
        # corresponding tensor from new_params, so updates computed by the
        # higher level stay part of the autograd graph.
        index = 0
        for layer in self.children():
            for key in layer._parameters.keys():
                if layer._parameters[key] is not None:
                    assert layer._parameters[key].size() == new_params[index].size()
                    layer._parameters[key] = new_params[index]
                    index += 1

    def apply_new_gradients(self, grads, lr=1):
        # `grads` is a flat vector of adjusted gradients; slice it back into
        # per-parameter shapes and replace each parameter with an updated
        # clone. retain_grad lets the non-leaf result still accumulate .grad.
        new_vars = []
        offset = 0
        for p in self.parameters():
            numel = p.numel()
            g = grads[offset:offset + numel].view_as(p)
            offset += numel
            cl = p - g * lr
            cl.retain_grad()
            new_vars.append(cl)
        self.set_params(new_vars)

    def get_gradients(self):
        # Flatten and concatenate the gradients of all parameters into a
        # single vector of length param_count, detached from the loss graph.
        return torch.cat([p.grad.detach().view(-1) for p in self.parameters()])

    def detach_params(self):
        # Cut each parameter's update history so backprop through the chain
        # of updates does not grow without bound.
        print("Detaching")
        for layer in self.children():
            for key in layer._parameters.keys():
                if layer._parameters[key] is not None:
                    layer._parameters[key] = layer._parameters[key].detach()
                    layer._parameters[key].requires_grad = True

    def forward(self, *inputs):
        pass
class BaseLevel(Level):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10, bias=False)

    def forward(self, data):
        return self.layer(data)
class HigherLevel(Level):
    def __init__(self, hidden_size, num_layers, param_count):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm_model = nn.LSTM(1, self.hidden_size, num_layers=self.num_layers)
        self.ff = nn.Linear(self.hidden_size, 1)
        self.param_count = param_count
        self.device = "cpu"
        self.h_state = None
        self.reset_hidden()
        self.optim = torch.optim.Adam(self.parameters())

    def forward(self, inp):
        self.lstm_model.flatten_parameters()
        # Shape the gradient vector to (seq_len=1, batch=param_count, features=1)
        net_inp = inp.view(1, self.param_count, 1)
        # Push all the values through the network and get a new adjustment
        # for each input gradient
        out, self.h_state = self.lstm_model(net_inp, self.h_state)
        # Squash the adjustments between -1 and 1
        out = torch.tanh(self.ff(out.view(self.param_count, -1)))
        return out.view(-1)

    def detach_hidden(self):
        self.h_state = (self.h_state[0].detach(), self.h_state[1].detach())

    def reset_hidden(self):
        self.h_state = (torch.zeros(self.num_layers, self.param_count, self.hidden_size).to(self.device),
                        torch.zeros(self.num_layers, self.param_count, self.hidden_size).to(self.device))

    def to(self, device):
        super().to(device)
        self.device = device
        self.reset_hidden()
        return self
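
# Shape walk-through (added note, for the single 10x10 base weight above, so
# param_count == 100): the flat gradient vector (100,) is viewed as LSTM input
# (seq_len=1, batch=100, features=1), giving LSTM output (1, 100, hidden_size);
# the linear head maps each row to a scalar and tanh bounds it, so view(-1)
# yields (100,) again. The same LSTM is applied coordinate-wise, one "batch"
# element per parameter, as in learned-optimizer setups.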
if __name__ == "__main__":
    cuda = True
    device = "cuda" if cuda else "cpu"
    input_data = torch.rand(10, 10, device=device)
    target_data = torch.rand(10, 10, device=device)
    base = BaseLevel()
    paramcount = sum(p.numel() for p in base.parameters())
    higher = HigherLevel(5, 1, paramcount)
    base.to(device)
    higher.to(device)
    loss_f = torch.nn.MSELoss()
    for k in range(1000):
        # Feed forward through the base level
        out = base(input_data)
        loss = loss_f(out, target_data)
        print(f"{k}: {loss.item()}")
        # Backprop; retain the graph because the base parameters stay linked
        # to earlier higher-level outputs until the next detach
        higher.optim.zero_grad()
        loss.backward(retain_graph=True)
        higher.optim.step()
        # Get the base gradients and push them through the higher level
        grads = base.get_gradients()
        adj_grads = higher(grads)
        # Update the base level with the adjusted gradients
        base.apply_new_gradients(adj_grads)
        # Every 200 iterations, truncate backprop through time: detach the
        # LSTM hidden state and cut the base parameters' update history
        if (k + 1) % 200 == 0:
            higher.detach_hidden()
            base.detach_params()
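
    # --- Hedged follow-up sketch (not in the original gist) ------------------
    # One way to probe whether the learned update rule generalizes is to apply
    # it to a fresh BaseLevel on new random data, with no further Adam steps
    # on the higher level. `evaluate_learned_optimizer` is a hypothetical
    # helper added for illustration.
    def evaluate_learned_optimizer(steps=50):
        fresh = BaseLevel().to(device)
        higher.reset_hidden()
        x = torch.rand(10, 10, device=device)
        y = torch.rand(10, 10, device=device)
        for _ in range(steps):
            loss = loss_f(fresh(x), y)
            loss.backward()
            # Apply the learned update, then immediately cut the graph so
            # evaluation memory stays bounded.
            fresh.apply_new_gradients(higher(fresh.get_gradients()))
            fresh.detach_params()
            higher.detach_hidden()
        return loss.item()

    print(f"fresh-network eval loss: {evaluate_learned_optimizer()}")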