-
-
Save ssnl/f2f56534aefb22d8612dbd7a5da28ed8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import types

import torch.nn as nn
from six import add_metaclass
class PatchModules(type):
    """Metaclass that post-processes every freshly constructed instance.

    After the normal construction has finished, each submodule reported by
    ``instance.modules()`` that is not itself a ``ReparamModule`` receives
    bound copies of ``ReparamModule._get_weights`` and
    ``ReparamModule._get_module_names``. The instance's ``_module_names``
    attribute is then filled in from ``instance._get_module_names()``.
    """

    # Names of the ReparamModule helpers that get bound onto foreign
    # submodules, in the order they are attached.
    _PATCHED_HELPERS = ("_get_weights", "_get_module_names")

    def __call__(cls, *args, **kwargs):
        """Build the instance, patch its submodules and record its layout.

        This runs whenever the class using this metaclass is called, e.g.
        ``ReparamModule(...)``.
        """
        instance = type.__call__(cls, *args, **kwargs)

        for submodule in instance.modules():
            if isinstance(submodule, ReparamModule):
                # Already provides both helpers through normal inheritance.
                continue
            # Hack: attach the helpers directly to the foreign submodule so
            # the recursive traversal can descend through it.
            for helper_name in cls._PATCHED_HELPERS:
                helper = getattr(ReparamModule, helper_name)
                setattr(submodule, helper_name,
                        types.MethodType(helper, submodule))

        # Must happen after the patching above: the traversal calls
        # ``_get_module_names`` on every child module.
        instance._module_names = instance._get_module_names()
        return instance
@add_metaclass(PatchModules)
class ReparamModule(nn.Module):
    """``nn.Module`` whose forward pass can be run with substitute weights.

    The ``PatchModules`` metaclass sets ``self._module_names`` right after
    construction to the result of ``self._get_module_names()``: a tuple of
    ``(owning_module, parameter_name)`` pairs. ``_get_weights`` walks the
    module tree in exactly the same order, so position ``i`` of one
    corresponds to position ``i`` of the other.

    NOTE: ``PatchModules`` binds ``_get_weights`` and ``_get_module_names``
    onto submodules that are *not* ``ReparamModule`` instances. Those two
    methods must therefore only rely on ``_parameters``, ``_modules`` and on
    each other; calling any other ``ReparamModule`` method from them would
    fail on the patched submodules.
    """

    def get_weights(self, clone=False):
        """Return the module's weights in traversal order.

        Args:
            clone: When false (default), return the tuple produced by
                ``_get_weights``. When true, return a *list* holding, for
                each weight, ``w.clone().detach().requires_grad_()``.

        Returns:
            A tuple of the registered parameters, or a list of detached
            copies when ``clone`` is true.
        """
        weights = self._get_weights()
        if clone:
            weights = [w.clone().detach().requires_grad_() for w in weights]
        return weights

    def _get_module_names(self, mn_list=None):
        """Collect ``(module, parameter_name)`` pairs for the whole subtree.

        Own parameters come first, then each child module's pairs, visiting
        children in ``_modules`` order. ``None`` entries are skipped.

        Args:
            mn_list: Accumulator list used by the recursion; a fresh list is
                created when it is ``None``.

        Returns:
            A tuple snapshot of the accumulator.
        """
        if mn_list is None:
            mn_list = []
        for name, param in self._parameters.items():
            if param is not None:
                mn_list.append((self, name))
        for child in self._modules.values():
            if child is not None:
                # The child appends to the shared accumulator; its own
                # return value is not needed here.
                child._get_module_names(mn_list=mn_list)
        return tuple(mn_list)

    def _get_weights(self, w_list=None):
        """Collect the parameters of the whole subtree.

        Traverses in the same order as ``_get_module_names`` so the two
        results line up element by element.

        Args:
            w_list: Accumulator list used by the recursion; a fresh list is
                created when it is ``None``.

        Returns:
            A tuple snapshot of the accumulator.
        """
        if w_list is None:
            w_list = []
        for param in self._parameters.values():
            if param is not None:
                w_list.append(param)
        for child in self._modules.values():
            if child is not None:
                child._get_weights(w_list=w_list)
        return tuple(w_list)

    def forward_with_weights(self, input, *new_ws):
        """Run ``self.forward(input)`` with ``new_ws`` swapped in as weights.

        Each entry of ``new_ws`` is assigned to the attribute named by the
        entry at the same position in ``self._module_names``. The original
        weights are put back afterwards, including when the assignment or
        ``forward`` raises, so a failing call no longer leaves the module
        wired to the substitute tensors.

        Args:
            input: Value passed unchanged to ``self.forward``.
            *new_ws: Replacement weights, in ``get_weights()`` order. ``zip``
                stops at the shorter sequence, so surplus entries are
                ignored and missing ones leave those weights untouched.

        Returns:
            Whatever ``self.forward(input)`` returns.
        """
        old_ws = self._get_weights()
        try:
            for (m, n), w in zip(self._module_names, new_ws):
                # ``super(nn.Module, m)`` starts the lookup *after*
                # nn.Module, bypassing nn.Module.__setattr__ so the value is
                # stored as a plain instance attribute and ``_parameters``
                # (which ``_get_weights`` reads) is left unchanged.
                super(nn.Module, m).__setattr__(n, w)
            return self.forward(input)
        finally:
            for (m, n), w in zip(self._module_names, old_ws):
                super(nn.Module, m).__setattr__(n, w)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment