# Gist by @JayanthRR, created February 6, 2019
import types

from six import add_metaclass
import torch
from torch import nn
import torch.nn.functional as F


class PatchModules(type):
    def __call__(cls, *args, **kwargs):
        """Called when you call ReparamModule(...)"""
        obj = type.__call__(cls, *args, **kwargs)
        for module in obj.modules():
            print(module, type(module))
            if not isinstance(module, ReparamModule):
                # pesky hacks
                module._get_weights = types.MethodType(ReparamModule._get_weights, module)
                module._get_module_names = types.MethodType(ReparamModule._get_module_names, module)
        obj._module_names = obj._get_module_names()
        return obj
@add_metaclass(PatchModules)
class ReparamModule(nn.Module):
    def __init__(self, D_IN, D_OUT):
        super(ReparamModule, self).__init__()
        self.fc1 = nn.Linear(D_IN, D_OUT)

    def forward(self, x):
        return F.relu(self.fc1(x))

    def get_weights(self, clone=False):
        weights = self._get_weights()
        if clone:
            weights = [w.clone().detach().requires_grad_() for w in weights]
        return weights

    def _get_module_names(self, mn_list=None):
        # Collect (module, parameter_name) pairs in the same order as _get_weights.
        if mn_list is None:
            mn_list = []
        for name, param in self._parameters.items():
            if param is not None:
                mn_list.append((self, name))
        for name, module in self._modules.items():
            if module is not None:
                module._get_module_names(mn_list=mn_list)
        return tuple(mn_list)

    def _get_weights(self, w_list=None):
        # Collect the registered parameters of this module and all submodules.
        if w_list is None:
            w_list = []
        for name, param in self._parameters.items():
            if param is not None:
                w_list.append(param)
        for name, module in self._modules.items():
            if module is not None:
                module._get_weights(w_list=w_list)
        return tuple(w_list)

    def forward_with_weights(self, input, *new_ws):
        old_ws = self._get_weights()  # original registered parameters (not restored afterwards)
        for (m, n), w in zip(self._module_names, new_ws):
            # Bypass nn.Module.__setattr__ so the new tensor shadows the registered parameter.
            super(nn.Module, m).__setattr__(n, w)
        output = self.forward(input)
        return output
def test_assignment():
    batch_size = 20
    D_IN = 100
    D_OUT = 2
    rand_input = torch.randn(batch_size, D_IN)
    rand_output = torch.randn(batch_size, D_OUT)
    loss_fn = torch.nn.MSELoss(reduction='sum')

    queue = []
    parameter_gradients = []

    rpm = ReparamModule(D_IN, D_OUT)

    # Compute output with initial parameters
    output_1 = rpm.forward(rand_input)
    rpm.zero_grad()
    loss_1 = loss_fn(rand_output, output_1)
    print("loss_1 :", loss_1.item())
    loss_1.backward()
    queue.append(rpm.get_weights(clone=True))

    gradients = []
    gradients.append(rpm.fc1.weight.grad.clone().detach())
    gradients.append(rpm.fc1.bias.grad.clone().detach())

    # Weight update
    for f in rpm.parameters():
        f.data.sub_(f.grad.data * 0.1)

    parameter_gradients.append(gradients)
    queue.append(rpm.get_weights(clone=True))

    # Compute output with updated parameters
    output_2 = rpm.forward(rand_input)
    rpm.zero_grad()
    loss_2 = loss_fn(rand_output, output_2)
    print("loss_2 :", loss_2.item())
    loss_2.backward()

    gradients = []
    gradients.append(rpm.fc1.weight.grad.clone().detach())
    gradients.append(rpm.fc1.bias.grad.clone().detach())

    # Weight update
    for f in rpm.parameters():
        f.data.sub_(f.grad.data * 0.1)

    parameter_gradients.append(gradients)
    queue.append(rpm.get_weights(clone=True))

    # Compute with old parameters
    old_params = queue[0]
    output_3 = rpm.forward_with_weights(rand_input, *old_params)
    rpm.zero_grad()
    loss_3 = loss_fn(rand_output, output_3)
    print("loss_3 :", loss_3.item())
    loss_3.backward()

    for f in rpm.parameters():
        print(f.grad.data)  # Zero: the registered parameters were not used in forward_with_weights.

    gradients = []
    # Note: rpm.fc1.weight now resolves to the substituted tensor, so these are its gradients.
    gradients.append(rpm.fc1.weight.grad.clone().detach())
    gradients.append(rpm.fc1.bias.grad.clone().detach())
    parameter_gradients.append(gradients)

    queue.append(rpm.get_weights(clone=True))
    # rpm.get_weights does not return the new weights assigned in forward_with_weights

    print("Should be True,", torch.all(torch.eq(output_1, output_3)))
    print("Should be False,", torch.all(torch.eq(output_1, output_2)))

    grad1_w = parameter_gradients[0][0]
    grad2_w = parameter_gradients[2][0]
    grad1_b = parameter_gradients[0][1]
    grad2_b = parameter_gradients[2][1]
    print("Should be True,", torch.all(torch.eq(grad1_w, grad2_w)))
    print("Should be True,", torch.all(torch.eq(grad1_b, grad2_b)))


test_assignment()
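

# A minimal additional sketch (not part of the original gist): since the registered
# parameters receive zero gradient when forward_with_weights is used, gradients can
# instead be taken directly with respect to the substituted weight tensors via
# torch.autograd.grad. The function name below is a hypothetical addition.
def test_grad_through_substituted_weights():
    batch_size, D_IN, D_OUT = 20, 100, 2
    rand_input = torch.randn(batch_size, D_IN)
    rand_output = torch.randn(batch_size, D_OUT)
    loss_fn = torch.nn.MSELoss(reduction='sum')

    rpm = ReparamModule(D_IN, D_OUT)
    new_ws = rpm.get_weights(clone=True)  # detached clones with requires_grad=True
    output = rpm.forward_with_weights(rand_input, *new_ws)
    loss = loss_fn(rand_output, output)

    # Gradients flow into the substituted tensors, not into rpm.parameters().
    grads = torch.autograd.grad(loss, new_ws)
    print("Gradients w.r.t. substituted weights:", [g.shape for g in grads])


test_grad_through_substituted_weights()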