Last active
January 27, 2021 14:47
-
-
Save Ed-Optalysys/3e312449f05d9fb68812e9717699d0cd to your computer and use it in GitHub Desktop.
Bayesian CNN Layer in PyTorch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Bayesian conv 2d implementation: applies dropout to the convolution kernel
# weights (rather than to activations) on every forward pass.
class BayesianConv2d(torch.nn.Module):
    """2-D convolution whose kernel weights are randomly dropped on each call.

    On every ``forward`` call without an explicit ``mask``, a fresh Bernoulli
    mask (keep-probability ``1 - drop_rate``) is sampled over the weight
    tensor. ``self.training`` is never consulted, so weights are dropped in
    eval mode too. Repeated calls on the same input therefore give different
    outputs.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Convolution kernel size, passed to ``torch.nn.Conv2d``.
        bias: Whether the underlying ``torch.nn.Conv2d`` has a learnable bias.
        stride, padding, dilation, groups: Passed both to the underlying
            ``torch.nn.Conv2d`` (so its weight has the right shape) and to the
            functional convolution used in ``forward``.
        drop_rate: Probability of zeroing each kernel weight; must be in [0, 1].

    Raises:
        ValueError: If ``drop_rate`` is outside [0, 1].
    """

    def __init__(self, in_ch, out_ch, kernel_size=(3, 3), bias=True, stride=1,
                 padding=0, dilation=1, groups=1, drop_rate=0.2):
        super(BayesianConv2d, self).__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not 0 <= drop_rate <= 1:
            raise ValueError(f"drop_rate must be in [0, 1], got {drop_rate}")
        # Probability of *keeping* each kernel weight.
        self.p = 1 - drop_rate
        self.stride = stride
        self.dilation = dilation
        self.padding = padding
        self.groups = groups
        self.conv_op = F.conv2d
        # Pass everything by keyword: Conv2d's positional order is
        # (in, out, kernel_size, stride, padding, dilation, groups, bias).
        self.conv_layer = torch.nn.Conv2d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x, mask=None):
        """Convolve ``x`` with a randomly masked copy of the kernel weights.

        Args:
            x: Input tensor, passed straight to ``F.conv2d``.
            mask: Optional tensor multiplied element-wise into the kernel
                weights. It is used as given and not resampled. When ``None``,
                a new Bernoulli(``self.p``) mask shaped like the weights is drawn.

        Returns:
            The result of ``F.conv2d`` with the masked weights and the layer's
            bias (the bias itself is never masked).
        """
        weight = self.conv_layer.weight
        if mask is None:
            # full_like follows the weight's device and dtype, so no global
            # CUDA switch is needed. It does not require grad.
            mask = torch.bernoulli(torch.full_like(weight, self.p))
        # No 1/p rescaling: the same masking is applied at train and test time.
        masked_kernels = weight * mask
        return self.conv_op(x, masked_kernels, self.conv_layer.bias, self.stride,
                            self.padding, self.dilation, self.groups)
# Example usage: a 3 -> 10 channel Bayesian conv that drops ~a third of its
# kernel weights on each forward pass.
bc = BayesianConv2d(3, 10, drop_rate=0.333)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.