@InnovArul
Last active April 11, 2024 11:01
tied linear layer experiment
import torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim

# tied autoencoder using off-the-shelf nn modules
class TiedAutoEncoderOffTheShelf(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.decoder = nn.Linear(out, inp, bias=False)

        # tie the weights: nn.Parameter shares storage with the given tensor,
        # so encoder.weight and decoder.weight view the same memory
        self.encoder.weight = nn.Parameter(weight)
        self.decoder.weight = nn.Parameter(weight.transpose(0, 1))

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = self.decoder(encoded_feats)
        return encoded_feats, reconstructed_output

# tied autoencoder using functional calls
class TiedAutoEncoderFunctional(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        self.param = nn.Parameter(torch.randn(out, inp))

    def forward(self, input):
        encoded_feats = F.linear(input, self.param)
        reconstructed_output = F.linear(encoded_feats, self.param.t())
        return encoded_feats, reconstructed_output

# mixed approach: nn.Linear encoder, functional decoder reusing its weight
class MixedApproachTiedAutoEncoder(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.encoder.weight = nn.Parameter(weight)

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = F.linear(encoded_feats, self.encoder.weight.t())
        return encoded_feats, reconstructed_output

if __name__ == '__main__':
    tied_module_F = TiedAutoEncoderFunctional(5, 6)

    # instantiate off-the-shelf auto-encoder with the same initial weight
    offshelf_weight = tied_module_F.param.data.clone()
    tied_module_offshelf = TiedAutoEncoderOffTheShelf(5, 6, offshelf_weight)

    # instantiate mixed type auto-encoder with the same initial weight
    mixed_weight = tied_module_F.param.data.clone()
    tied_module_mixed = MixedApproachTiedAutoEncoder(5, 6, mixed_weight)

    assert torch.equal(tied_module_offshelf.encoder.weight.data, tied_module_F.param.data), 'F vs offshelf: param not equal'
    assert torch.equal(tied_module_mixed.encoder.weight.data, tied_module_F.param.data), 'F vs mixed: param not equal'

    optim_F = optim.SGD(tied_module_F.parameters(), lr=1)
    optim_offshelf = optim.SGD(tied_module_offshelf.parameters(), lr=1)
    optim_mixed = optim.SGD(tied_module_mixed.parameters(), lr=1)

    # common input
    input = torch.rand(5, 5)

    # zero the gradients
    optim_F.zero_grad()
    optim_offshelf.zero_grad()
    optim_mixed.zero_grad()

    # get outputs from all three modules
    reconstruction_F = tied_module_F(input)
    reconstruction_offshelf = tied_module_offshelf(input)
    reconstruction_mixed = tied_module_mixed(input)

    # back propagation
    reconstruction_F[1].sum().backward()
    reconstruction_offshelf[1].sum().backward()
    reconstruction_mixed[1].sum().backward()

    # step
    optim_F.step()
    optim_offshelf.step()
    optim_mixed.step()

    # check the equality of outputs and parameters
    assert torch.equal(reconstruction_offshelf[0], reconstruction_F[0]), 'F vs offshelf: bottleneck not equal'
    assert torch.equal(reconstruction_offshelf[1], reconstruction_F[1]), 'F vs offshelf: output not equal'
    assert (tied_module_offshelf.encoder.weight.data - tied_module_F.param.data).pow(2).sum() < 1e-10, 'F vs offshelf: param after step not equal'
    assert (tied_module_offshelf.encoder.weight.data - offshelf_weight).pow(2).sum() < 1e-10, 'offshelf: source weight tensor not equal'

    assert torch.equal(reconstruction_mixed[0], reconstruction_F[0]), 'F vs mixed: bottleneck not equal'
    assert torch.equal(reconstruction_mixed[1], reconstruction_F[1]), 'F vs mixed: output not equal'
    assert (tied_module_mixed.encoder.weight.data - tied_module_F.param.data).pow(2).sum() < 1e-10, 'F vs mixed: param after step not equal'
    assert (tied_module_mixed.encoder.weight.data - mixed_weight).pow(2).sum() < 1e-10, 'mixed: source weight tensor not equal'

    print('success!')
InnovArul commented Dec 12, 2020

> It seems L14 (the decoder weight assignment) has to be changed as follows for PyTorch 1.7.0 to work:
>
> self.decoder.weight = nn.Parameter(self.encoder.weight.transpose(0,1))

Thanks for letting me know. I have updated the gist to reflect it.

Note for myself: nn.Parameter shares the underlying memory with the given torch.Tensor
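
For example, a minimal check of that sharing (illustrative, not part of the gist):

import torch, torch.nn as nn

t = torch.randn(6, 5)
p = nn.Parameter(t)
t[0, 0] = 42.0                        # in-place change to the source tensor
print(p[0, 0].item())                 # 42.0 -- the parameter sees it
print(p.data_ptr() == t.data_ptr())   # True -- same underlying storage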

AFAgarap commented Mar 9, 2021

Hello @InnovArul, thank you for this nice work. But may I ask why we can't just use this approach?

InnovArul commented Mar 11, 2021

> Hello @InnovArul, thank you for this nice work. But may I ask why we can't just use this approach?

Ideally, decoder.weight is the transpose of encoder.weight, so we need decoder.weight = encoder.weight.t(). However, that assignment throws the following error: cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected).

To overcome that, we need to wrap it with nn.Parameter().
Does this answer your question?
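
Put together, the failing and working assignments look like this (a small sketch under the same setup as the gist):

import torch, torch.nn as nn

enc = nn.Linear(5, 6, bias=False)
dec = nn.Linear(6, 5, bias=False)

# dec.weight = enc.weight.t()               # raises: cannot assign 'torch.FloatTensor' as parameter 'weight'
dec.weight = nn.Parameter(enc.weight.t())   # works; the transpose is a view, so the weights stay tied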

AFAgarap commented:
Yes, thank you. Actually, that's what I did. But maybe my question should have been more about whether there are any merits to using the other approaches you have listed here?

InnovArul commented:
To me, the tied auto-encoder with functional calls looks the cleanest, since it does not involve nn.Parameter(another_layer.weight).
Apart from that, I do not see any particular merit in the other approaches.

azimnoralfian commented:
Hello @InnovArul, thank you for this nice work! I am currently building an autoencoder for dimensionality reduction with a beginner level of knowledge in PyTorch. Sorry if my question is very trivial, but can the same concept be applied to a non-linear model? I was thinking of setting requires_grad=False on the decoder layer so that the model trains the encoder weights only. Is this a correct approach?

InnovArul commented:
Hi, sorry that I missed your message. I hope you already found the answer.

Just to answer your question: yes, in my understanding, setting decoder.requires_grad_(False) would stop the decoder from adding its gradient to the weights, so the weights receive gradients from the encoder path only.
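
For instance, with the off-the-shelf tied layout and a non-linearity in between, that could look like this (a minimal sketch; the sigmoid is an illustrative choice, not from the gist):

import torch, torch.nn as nn

enc = nn.Linear(5, 6, bias=False)
dec = nn.Linear(6, 5, bias=False)
dec.weight = nn.Parameter(enc.weight.t())  # tied: shared underlying storage
dec.requires_grad_(False)                  # decoder contributes no gradient

x = torch.rand(3, 5)
out = dec(torch.sigmoid(enc(x)))
out.sum().backward()
print(dec.weight.grad)        # None -- decoder is frozen
print(enc.weight.grad.shape)  # torch.Size([6, 5]) -- encoder path only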
