Skip to content

Instantly share code, notes, and snippets.

@ssnl
Last active May 16, 2018 05:52
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ssnl/f2f56534aefb22d8612dbd7a5da28ed8 to your computer and use it in GitHub Desktop.
Save ssnl/f2f56534aefb22d8612dbd7a5da28ed8 to your computer and use it in GitHub Desktop.
import types

from six import add_metaclass
from torch import nn
class PatchModules(type):
    """Metaclass that post-processes every instance right after construction.

    After the normal constructor has run, each entry yielded by
    ``obj.modules()`` that is not itself a ``ReparamModule`` gets bound
    copies of ``ReparamModule._get_weights`` and
    ``ReparamModule._get_module_names`` attached to it, so the recursive
    traversals in ``ReparamModule`` can call those helpers on every
    submodule. The flattened ``(module, name)`` listing is then cached on
    the new object as ``_module_names``.
    """

    def __call__(cls, *args, **kwargs):
        """Called when you call ReparamModule(...).

        Args:
            *args: Positional arguments forwarded to the class constructor.
            **kwargs: Keyword arguments forwarded to the class constructor.

        Returns:
            The fully constructed instance, with its submodules patched and
            ``_module_names`` populated.
        """
        obj = type.__call__(cls, *args, **kwargs)
        for module in obj.modules():
            if isinstance(module, ReparamModule):
                # Already provides both helpers through inheritance.
                continue
            # pesky hacks: bind the helpers onto the foreign instance so the
            # recursion in ReparamModule can call them as ordinary methods.
            module._get_weights = types.MethodType(
                ReparamModule._get_weights, module)
            module._get_module_names = types.MethodType(
                ReparamModule._get_module_names, module)
        # Must run after the loop: the traversal calls the helpers that were
        # just attached to the non-ReparamModule submodules.
        obj._module_names = obj._get_module_names()
        return obj
@add_metaclass(PatchModules)
class ReparamModule(nn.Module):
    """``nn.Module`` whose parameters can be swapped out for one forward pass.

    The ``PatchModules`` metaclass fills in ``self._module_names`` after
    construction: a tuple of ``(module, parameter_name)`` pairs collected by
    ``_get_module_names``. ``_get_weights`` walks the module tree in the same
    order, so the two sequences line up index by index.
    """

    def get_weights(self, clone=False):
        """Return the parameters of this module and all of its submodules.

        Args:
            clone: When true, return detached copies that require grad
                instead of the live parameters.

        Returns:
            A tuple of the live parameters when ``clone`` is false; otherwise
            a list of independent clones, in the same order.
        """
        weights = self._get_weights()
        if not clone:
            return weights
        return [w.clone().detach().requires_grad_() for w in weights]

    def _get_module_names(self, mn_list=None):
        """Collect ``(module, name)`` for every non-``None`` parameter.

        This module's own parameters come first, followed by each submodule's
        (recursively), in ``_parameters`` / ``_modules`` iteration order.

        Args:
            mn_list: Accumulator shared across the recursion; a fresh list is
                created for the top-level call.

        Returns:
            A tuple snapshot of the accumulator after this subtree is visited.
        """
        if mn_list is None:
            mn_list = []
        mn_list.extend(
            (self, name)
            for name, param in self._parameters.items()
            if param is not None
        )
        for child in self._modules.values():
            if child is not None:
                # Children append into the shared accumulator; their own
                # return value is not needed here.
                child._get_module_names(mn_list=mn_list)
        return tuple(mn_list)

    def _get_weights(self, w_list=None):
        """Collect every non-``None`` parameter of this module and submodules.

        Traversal order is identical to ``_get_module_names``.

        Args:
            w_list: Accumulator shared across the recursion; a fresh list is
                created for the top-level call.

        Returns:
            A tuple snapshot of the accumulator after this subtree is visited.
        """
        if w_list is None:
            w_list = []
        w_list.extend(
            param for param in self._parameters.values() if param is not None
        )
        for child in self._modules.values():
            if child is not None:
                child._get_weights(w_list=w_list)
        return tuple(w_list)

    def forward_with_weights(self, input, *new_ws):
        """Run ``forward`` with ``new_ws`` standing in for the parameters.

        Args:
            input: Value passed straight through to ``self.forward``.
            *new_ws: Replacement weights, matched positionally against
                ``self._module_names``. ``zip`` stops at the shorter
                sequence, so any trailing parameters without a replacement
                keep their current value.

        Returns:
            Whatever ``self.forward(input)`` returns.

        The original parameters are put back even if ``forward`` raises.
        """
        old_ws = self._get_weights()
        try:
            for (m, n), w in zip(self._module_names, new_ws):
                # super(nn.Module, m) skips nn.Module.__setattr__, so the
                # value is set as a plain attribute on the instance.
                super(nn.Module, m).__setattr__(n, w)
            return self.forward(input)
        finally:
            # Restore unconditionally so a failing forward pass cannot leave
            # the caller-supplied tensors installed on the modules.
            for (m, n), w in zip(self._module_names, old_ws):
                super(nn.Module, m).__setattr__(n, w)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment