Skip to content

Instantly share code, notes, and snippets.

@lolz0r
Created January 22, 2019 17:06
Show Gist options
  • Save lolz0r/c3ab23c6763667edf2d8b01ee154aa13 to your computer and use it in GitHub Desktop.
Save lolz0r/c3ab23c6763667edf2d8b01ee154aa13 to your computer and use it in GitHub Desktop.
Learned basis functions for 3x3 convolution kernels (PyTorch)
class ConvSeluSVD(nn.Module):
    """3x3 convolution whose kernels are linear combinations of a small basis.

    Each (output channel, input channel) kernel is synthesised from 3 learned
    coefficients (``self.params``) multiplied by a (3, 9) basis matrix, and
    the resulting 9 values are reshaped to a 3x3 filter. The convolution uses
    ``padding=1`` and is followed by an in-place SELU and an optional 2x2
    max-pool.

    Args:
        inputSize: number of input channels.
        outputSize: number of output channels.
        stride: convolution stride.
        maxpool: if True, apply 2x2 max pooling after the activation.
        ownBasis: if True, the layer owns a learnable basis initialised from
            ``_INITIAL_BASIS``. Otherwise no basis parameter is created and
            one must be supplied to ``forward`` on every call.
    """

    # Initial values for the learnable basis: 3 rows, each a flattened 3x3
    # filter (forward() reshapes the 9 columns into 3x3 kernels).
    _INITIAL_BASIS = (
        (-0.21535662, -0.30022025, -0.26041868, -0.314888, -0.45471892,
         -0.3971264, -0.26603645, -0.3896653, -0.33079177),
        (0.34970352, 0.50572443, 0.36894855, 0.07661748, 0.08152138,
         0.02740295, -0.28591475, -0.49375448, -0.38343033),
        (-0.3019736, -0.02775075, 0.29349312, -0.50207216, -0.05312577,
         0.5471206, -0.39858055, -0.09402011, 0.31616086),
    )

    def __init__(self, inputSize, outputSize, stride=1, maxpool=False, ownBasis=False):
        super().__init__()
        self.inputSize = inputSize
        self.outputSize = outputSize
        self.stride = stride
        # 3 basis coefficients per (output, input) channel pair; the singleton
        # middle dim lets it be matrix-multiplied by a (3, 9) basis.
        self.params = nn.Parameter(
            torch.empty(outputSize * inputSize, 1, 3).normal_(0, 0.02))
        self.selu = nn.SELU(inplace=True)
        self.bias = nn.Parameter(torch.zeros(outputSize))
        self.maxpool = maxpool
        if ownBasis:
            self.basisWeights = nn.Parameter(torch.tensor(self._INITIAL_BASIS))
        else:
            # Register an empty slot so the attribute always exists (it is not
            # included in parameters() or state_dict()); forward() can then
            # report a clear error instead of an AttributeError.
            self.register_parameter("basisWeights", None)

    def forward(self, input, basis_=None):
        """Convolve ``input`` with kernels synthesised from the basis.

        Args:
            input: tensor accepted by ``conv2d`` with ``inputSize`` channels,
                presumably shaped (N, inputSize, H, W) -- confirm with callers.
            basis_: optional (3, 9) basis that overrides ``self.basisWeights``.

        Returns:
            SELU-activated feature map with ``outputSize`` channels, max-pooled
            by a factor of 2 when ``maxpool`` is set.

        Raises:
            ValueError: if no basis is passed and the layer owns none.
        """
        if basis_ is None:
            basis_ = self.basisWeights
            if basis_ is None:
                raise ValueError(
                    "ConvSeluSVD was created with ownBasis=False; "
                    "a basis must be passed to forward()")
        # (O*I, 1, 3) @ (3, 9) broadcasts to (O*I, 1, 9): one flattened 3x3
        # kernel per channel pair, laid out as conv2d's (out, in, kH, kW).
        weights = torch.matmul(self.params, basis_).reshape(
            self.outputSize, self.inputSize, 3, 3)
        x = torch.nn.functional.conv2d(
            input, weights, bias=self.bias, stride=self.stride, padding=1)
        x = self.selu(x)
        if self.maxpool:
            x = torch.nn.functional.max_pool2d(x, 2)
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment