Created November 28, 2019 18:54
-
-
Save dhuynh95/d639ac3506a73dc0c6af48ebfc7dfaf1 to your computer and use it in GitHub Desktop.
Helper functions to turn regular neural networks (NNs) into Bayesian neural networks (BNNs) via MC Dropout
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CustomDropout(nn.Module):
    """Dropout layer that can stay active outside training mode.

    Intended as a baseline for Monte Carlo (MC) Dropout: while ``activate``
    is True, dropout is applied even after ``model.eval()``, so repeated
    forward passes produce stochastic outputs.

    Args:
        p: probability of zeroing each element; must lie in [0, 1].
        activate: when True, apply dropout regardless of ``self.training``.

    Raises:
        ValueError: if ``p`` is outside [0, 1] (or is NaN).
    """

    def __init__(self, p: float, activate: bool = True):
        super().__init__()
        # Reject bad probabilities at construction time (as nn.Dropout does)
        # rather than on the first forward pass.
        if not 0.0 <= p <= 1.0:
            raise ValueError(
                f"dropout probability has to be between 0 and 1, but got {p}"
            )
        self.activate = activate
        self.p = p

    def forward(self, x):
        # Dropout is on in training mode, or whenever `activate` is set.
        return nn.functional.dropout(
            x, self.p, training=self.training or self.activate
        )

    def extra_repr(self):
        return f"p={self.p}, activate={self.activate}"
def switch_custom_dropout(m, activate: bool = True, verbose: bool = False):
    """Set ``activate`` on every CustomDropout found in ``m``.

    ``m`` itself and all of its descendants are searched, so passing a bare
    CustomDropout also works.

    Args:
        m: module tree to update in place.
        activate: new value for each CustomDropout's ``activate`` flag.
        verbose: when True, print the old and new flag for each layer switched.
    """
    # nn.Module.modules() walks m and every descendant, so no manual recursion
    # (and no need to thread `verbose` through recursive calls).
    for module in m.modules():
        if isinstance(module, CustomDropout):
            if verbose:
                print(f"Current active : {module.activate}")
                print(f"Switching to : {activate}")
            module.activate = activate
def convert_layers(model: nn.Module, original: type, replacement: Callable,
                   get_args: Callable = None, additional_args: dict = None):
    """Recursively replace every submodule of type ``original`` inside ``model``.

    Each matching child is swapped in place for ``replacement(**args)``.
    ``args`` merges ``get_args(child)`` (when given) with
    ``additional_args``; on key clashes ``additional_args`` wins. Newly
    inserted layers are not searched further, and ``model`` itself is never
    replaced, only its descendants.

    Args:
        model: module whose descendants are converted in place.
        original: class of the layers to replace (anything ``isinstance``
            accepts, e.g. a tuple of classes).
        replacement: callable, typically an ``nn.Module`` subclass, that
            builds the new layer from keyword arguments.
        get_args: optional function mapping the original layer to a dict of
            keyword arguments to pass to ``replacement``.
        additional_args: optional extra keyword arguments for ``replacement``.
    """
    # `None` default instead of a shared mutable `{}` default.
    extra_args = additional_args or {}
    for child_name, child in model.named_children():
        if isinstance(child, original):
            original_args = get_args(child) if get_args else {}
            new_layer = replacement(**{**original_args, **extra_args})
            setattr(model, child_name, new_layer)
        else:
            convert_layers(child, original, replacement, get_args, extra_args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment