Skip to content

Instantly share code, notes, and snippets.

@philip-bl
Created April 26, 2019 23:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
Save philip-bl/04b035b81696efee1e15450225b335a5 to your computer and use it in GitHub Desktop.
Linear layer for tensors of any shape in pytorch
class ShapyLinear(nn.Module):
    """Affine map between tensors of fixed shapes.

    Can model any affine function from the set of tensors of any (fixed)
    shape ``in_shape`` to the set of tensors of any (fixed) shape
    ``out_shape``.  In :meth:`forward` the leading modes of ``inputs`` are
    interpreted as indices of samples, followed by modes matching
    ``in_shape``; the affine function is applied to each sample.
    """

    def __init__(self, in_shape, out_shape):
        """:param in_shape: shape of one input sample
        :param out_shape: shape of one output sample"""
        super().__init__()
        in_shape = tuple(in_shape)
        out_shape = tuple(out_shape)
        self.input_num_modes = len(in_shape)
        # Allocate uninitialized storage: xavier_normal_ overwrites every
        # entry, so zero-filling first would be wasted work.  Do NOT set
        # requires_grad here — nn.Parameter enables it itself, and the
        # in-place init would otherwise rely on nn.init's internal no_grad.
        weight = torch.empty(*in_shape, *out_shape)
        nn.init.xavier_normal_(weight)
        self.weight = nn.Parameter(weight)
        self.bias = nn.Parameter(torch.zeros(*out_shape))

    @property
    def weight_contraction_modes(self):
        """Indices of modes of `self.weight` over which tensor contraction with input is performed."""
        return tuple(range(self.input_num_modes))

    def forward(self, inputs):
        # calculate how many leading modes of `inputs` index samples
        num_sample_modes = inputs.ndimension() - self.input_num_modes
        # the trailing `input_num_modes` modes of `inputs` are contracted
        # against the matching leading modes of `self.weight`
        inputs_contraction_modes = tuple(range(num_sample_modes, inputs.ndimension()))
        contracted = torch.tensordot(
            inputs, self.weight,
            dims=(inputs_contraction_modes, self.weight_contraction_modes),
        )
        return contracted + self.bias

    def __repr__(self):
        return f"ShapyLinear(input_num_modes={self.input_num_modes}, weight.shape={tuple(self.weight.shape)})"
def test_shapy_linear():
    """Check that ShapyLinear matches `torch.nn.Linear` up to reshaping.

    A ShapyLinear (2,3)->(4,5) equals a Linear 6->20 whose weight matrix is
    the ShapyLinear weight with the output modes permuted in front of the
    input modes and then flattened.
    """
    shapy = ShapyLinear((2, 3), (4, 5))
    lin = nn.Linear(2 * 3, 4 * 5)
    # Copy parameters under no_grad instead of poking `.data`, which
    # bypasses autograd's version tracking and is a deprecated pattern.
    with torch.no_grad():
        lin.weight.copy_(shapy.weight.permute(2, 3, 0, 1).reshape(4 * 5, 2 * 3))
        lin.bias.copy_(shapy.bias.reshape(4 * 5))
    X = torch.randn(6, 7, 2, 3)
    result_lin = lin(X.reshape(6 * 7, 2 * 3)).reshape(6, 7, 4, 5)
    result_shapy = shapy(X)
    # tensordot and Linear may reduce in different orders, so bit-exact
    # equality is not guaranteed in float32; use float32-sized tolerances.
    assert torch.allclose(result_shapy, result_lin, rtol=1e-5, atol=1e-6)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment