import types

from six import add_metaclass

import torch
from torch import nn
import torch.nn.functional as F
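

# This gist demonstrates "re-parameterizing" a PyTorch module: the network's
# weights can be pulled out as plain tensors, and a forward pass can be run
# with an arbitrary replacement set of weights (forward_with_weights), so that
# gradients flow to those replacement tensors instead of the registered
# parameters.
#
# PatchModules is a metaclass hook: when a ReparamModule is instantiated, it
# attaches the _get_weights / _get_module_names helpers to every plain
# nn.Module child and caches the (module, parameter_name) pairs on the
# instance as _module_names.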
class PatchModules(type):
    def __call__(cls, *args, **kwargs):
        """Called when you call ReparamModule(...)"""
        obj = type.__call__(cls, *args, **kwargs)
        for module in obj.modules():
            print(module, type(module))
            if not isinstance(module, ReparamModule):
                # pesky hacks
                module._get_weights = types.MethodType(ReparamModule._get_weights, module)
                module._get_module_names = types.MethodType(ReparamModule._get_module_names, module)
        obj._module_names = obj._get_module_names()
        return obj
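

# ReparamModule: a toy network (a single Linear layer followed by ReLU) whose
# parameters can be read out as a tuple (get_weights) and replaced
# tensor-by-tensor for a forward pass (forward_with_weights).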
@add_metaclass(PatchModules)
class ReparamModule(nn.Module):
    def __init__(self, D_IN, D_OUT):
        super(ReparamModule, self).__init__()
        self.fc1 = nn.Linear(D_IN, D_OUT)

    def forward(self, x):
        return F.relu(self.fc1(x))

    def get_weights(self, clone=False):
        weights = self._get_weights()
        if clone:
            weights = [w.clone().detach().requires_grad_() for w in weights]
        return weights
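
    # The two helpers below recurse through _parameters and _modules in
    # registration order, so _get_module_names() and _get_weights() return
    # aligned sequences: the i-th (module, name) pair owns the i-th weight.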
    def _get_module_names(self, mn_list=None):
        if mn_list is None:
            mn_list = []
        for name, param in self._parameters.items():
            if param is not None:
                mn_list.append((self, name))
        for name, module in self._modules.items():
            if module is not None:
                module._get_module_names(mn_list=mn_list)
        return tuple(mn_list)

    def _get_weights(self, w_list=None):
        if w_list is None:
            w_list = []
        for name, param in self._parameters.items():
            if param is not None:
                w_list.append(param)
        for name, module in self._modules.items():
            if module is not None:
                module._get_weights(w_list=w_list)
        return tuple(w_list)
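
    # forward_with_weights shadows each registered parameter with the supplied
    # tensor via object.__setattr__ (bypassing nn.Module.__setattr__, which
    # would try to re-register it as a Parameter), then runs the usual forward.
    # Autograd therefore flows to the supplied tensors, not to the parameters
    # stored in _parameters.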
    def forward_with_weights(self, input, *new_ws):
        old_ws = self._get_weights()  # note: collected but never restored
        for (m, n), w in zip(self._module_names, new_ws):
            super(nn.Module, m).__setattr__(n, w)
        output = self.forward(input)
        return output
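

# test_assignment() runs three forward/backward passes on fixed random data:
#   1. with the initial weights,
#   2. with the weights after one SGD step,
#   3. with clones of the initial weights replayed via forward_with_weights.
# The checks at the end verify that pass 3 reproduces pass 1's output and
# gradients, while pass 2 differs.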
def test_assignment():
    batch_size = 20
    D_IN = 100
    D_OUT = 2
    rand_input = torch.randn(batch_size, D_IN)
    rand_output = torch.randn(batch_size, D_OUT)
    loss_fn = torch.nn.MSELoss(reduction='sum')
    queue = []
    parameter_gradients = []
    rpm = ReparamModule(D_IN, D_OUT)

    # Compute output with initial parameters
    output_1 = rpm.forward(rand_input)
    rpm.zero_grad()
    loss_1 = loss_fn(rand_output, output_1)
    print("loss_1 :", loss_1.item())
    loss_1.backward()
    queue.append(rpm.get_weights(clone=True))
    gradients = []
    gradients.append(rpm.fc1.weight.grad.clone().detach())
    gradients.append(rpm.fc1.bias.grad.clone().detach())
    # Weight update
    for f in rpm.parameters():
        f.data.sub_(f.grad.data * 0.1)
    parameter_gradients.append(gradients)
    queue.append(rpm.get_weights(clone=True))

    # Compute output with updated parameters
    output_2 = rpm.forward(rand_input)
    rpm.zero_grad()
    loss_2 = loss_fn(rand_output, output_2)
    print("loss_2 :", loss_2.item())
    loss_2.backward()
    gradients = []
    gradients.append(rpm.fc1.weight.grad.clone().detach())
    gradients.append(rpm.fc1.bias.grad.clone().detach())
    # Weight update
    for f in rpm.parameters():
        f.data.sub_(f.grad.data * 0.1)
    parameter_gradients.append(gradients)
    queue.append(rpm.get_weights(clone=True))
    # Compute with old parameters
    old_params = queue[0]
    output_3 = rpm.forward_with_weights(rand_input, *old_params)
    rpm.zero_grad()
    loss_3 = loss_fn(rand_output, output_3)
    print("loss_3 :", loss_3.item())
    loss_3.backward()
    for f in rpm.parameters():
        # These are all zeros: the registered parameters were not used in
        # forward_with_weights, so backward() did not touch their (zeroed) grads.
        print(f.grad.data)
    gradients = []
    # rpm.fc1.weight now resolves to the shadowed old_params tensor, so these
    # are the gradients of loss_3 w.r.t. the old (initial) weights.
    gradients.append(rpm.fc1.weight.grad.clone().detach())
    gradients.append(rpm.fc1.bias.grad.clone().detach())
    parameter_gradients.append(gradients)
    queue.append(rpm.get_weights(clone=True))
    # rpm.get_weights() still reads _parameters, so it does not return the
    # weights assigned in forward_with_weights.

    print("Should be True,", torch.all(torch.eq(output_1, output_3)))
    print("Should be False,", torch.all(torch.eq(output_1, output_2)))
    grad1_w = parameter_gradients[0][0]
    grad2_w = parameter_gradients[2][0]
    grad1_b = parameter_gradients[0][1]
    grad2_b = parameter_gradients[2][1]
    print("Should be True,", torch.all(torch.eq(grad1_w, grad2_w)))
    print("Should be True,", torch.all(torch.eq(grad1_b, grad2_b)))


test_assignment()
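

# A minimal follow-up sketch (not part of the original test), assuming the
# classes above: because forward_with_weights routes autograd through the
# supplied tensors, the loss can be differentiated directly w.r.t. a cloned
# weight list with torch.autograd.grad. All names below (sketch_*) are
# illustrative placeholders.
def sketch_grad_wrt_supplied_weights():
    rpm = ReparamModule(100, 2)
    sketch_input = torch.randn(20, 100)
    sketch_target = torch.randn(20, 2)
    ws = rpm.get_weights(clone=True)                  # detached clones with requires_grad=True
    out = rpm.forward_with_weights(sketch_input, *ws)
    loss = torch.nn.MSELoss(reduction='sum')(out, sketch_target)
    grads = torch.autograd.grad(loss, ws)             # gradients land on the clones, not on rpm's registered parameters
    print([tuple(g.shape) for g in grads])

# sketch_grad_wrt_supplied_weights()  # uncomment to run the sketch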