import torch, torch.nn as nn, torch.nn.functional as F
import torch.optim as optim

# tied autoencoder using off-the-shelf nn modules
class TiedAutoEncoderOffTheShelf(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.decoder = nn.Linear(out, inp, bias=False)

        # tie the weights: both layers wrap views of the same underlying tensor
        self.encoder.weight = nn.Parameter(weight)
        self.decoder.weight = nn.Parameter(weight.transpose(0, 1))

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = self.decoder(encoded_feats)
        return encoded_feats, reconstructed_output

# tied autoencoder using functional calls
class TiedAutoEncoderFunctional(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        self.param = nn.Parameter(torch.randn(out, inp))

    def forward(self, input):
        encoded_feats = F.linear(input, self.param)
        reconstructed_output = F.linear(encoded_feats, self.param.t())
        return encoded_feats, reconstructed_output

# mixed approach
class MixedApproachTiedAutoEncoder(nn.Module):
    def __init__(self, inp, out, weight):
        super().__init__()
        self.encoder = nn.Linear(inp, out, bias=False)
        self.encoder.weight = nn.Parameter(weight)

    def forward(self, input):
        encoded_feats = self.encoder(input)
        reconstructed_output = F.linear(encoded_feats, self.encoder.weight.t())
        return encoded_feats, reconstructed_output

if __name__ == '__main__':
    tied_module_F = TiedAutoEncoderFunctional(5, 6)

    # instantiate off-the-shelf autoencoder with the same initial weight
    offshelf_weight = tied_module_F.param.data.clone()
    tied_module_offshelf = TiedAutoEncoderOffTheShelf(5, 6, offshelf_weight)

    # instantiate mixed-type autoencoder with the same initial weight
    mixed_weight = tied_module_F.param.data.clone()
    tied_module_mixed = MixedApproachTiedAutoEncoder(5, 6, mixed_weight)

    assert torch.equal(tied_module_offshelf.encoder.weight.data, tied_module_F.param.data), 'F vs offshelf: param not equal'
    assert torch.equal(tied_module_mixed.encoder.weight.data, tied_module_F.param.data), 'F vs mixed: param not equal'

    optim_F = optim.SGD(tied_module_F.parameters(), lr=1)
    optim_offshelf = optim.SGD(tied_module_offshelf.parameters(), lr=1)
    optim_mixed = optim.SGD(tied_module_mixed.parameters(), lr=1)

    # common input
    input = torch.rand(5, 5)

    # zero the gradients
    optim_F.zero_grad()
    optim_offshelf.zero_grad()
    optim_mixed.zero_grad()

    # get the output from all three modules
    reconstruction_F = tied_module_F(input)
    reconstruction_offshelf = tied_module_offshelf(input)
    reconstruction_mixed = tied_module_mixed(input)

    # back-propagation
    reconstruction_F[1].sum().backward()
    reconstruction_offshelf[1].sum().backward()
    reconstruction_mixed[1].sum().backward()

    # optimizer step
    optim_F.step()
    optim_offshelf.step()
    optim_mixed.step()

    # check the equality of outputs and parameters
    assert torch.equal(reconstruction_offshelf[0], reconstruction_F[0]), 'F vs offshelf: bottleneck not equal'
    assert torch.equal(reconstruction_offshelf[1], reconstruction_F[1]), 'F vs offshelf: output not equal'
    assert (tied_module_offshelf.encoder.weight.data - tied_module_F.param.data).pow(2).sum() < 1e-10, 'F vs offshelf: param after step not equal'
    assert (tied_module_offshelf.encoder.weight.data - offshelf_weight).pow(2).sum() < 1e-10, 'offshelf: source weight tensor not equal'

    assert torch.equal(reconstruction_mixed[0], reconstruction_F[0]), 'F vs mixed: bottleneck not equal'
    assert torch.equal(reconstruction_mixed[1], reconstruction_F[1]), 'F vs mixed: output not equal'
    assert (tied_module_mixed.encoder.weight.data - tied_module_F.param.data).pow(2).sum() < 1e-10, 'F vs mixed: param after step not equal'
    assert (tied_module_mixed.encoder.weight.data - mixed_weight).pow(2).sum() < 1e-10, 'mixed: source weight tensor not equal'

    print('success!')
To me, the tied autoencoder with functional calls looks the cleanest, since it avoids wrapping another layer's weight in nn.Parameter(another_layer.weight).
Apart from that, I do not see any particular merit in the other approaches.
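To make that difference concrete, here is a minimal sketch (my own illustration, assuming the classes defined above, not part of the original gist): the functional version registers a single Parameter, whereas the off-the-shelf version registers two Parameter objects that merely alias the same storage.

import torch

# assumes TiedAutoEncoderFunctional and TiedAutoEncoderOffTheShelf from above
w = torch.randn(6, 5)
m_fun = TiedAutoEncoderFunctional(5, 6)
m_off = TiedAutoEncoderOffTheShelf(5, 6, w)

# functional version: one shared weight (30 values)
print(sum(p.numel() for p in m_fun.parameters()))   # 30

# off-the-shelf version: encoder.weight and decoder.weight are two distinct
# Parameter objects sharing storage, so both appear in parameters()/state_dict()
print(sum(p.numel() for p in m_off.parameters()))   # 60
print(m_off.encoder.weight.data_ptr() == m_off.decoder.weight.data_ptr())  # True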
Hello @InnovArul, thank you for this nice work! I am currently building an autoencoder for dimensionality reduction with a beginner's level of knowledge in PyTorch. Sorry if my question is very trivial, but can the same concept be applied to a non-linear model? I was thinking of setting requires_grad=False on the decoder layer so that the model trains only the encoder weights. Is this a correct approach?
Hi, sorry that I missed your message. I hope you have already found the answer.
Just to answer your question: yes, in my understanding, setting decoder.requires_grad_(False) prevents the decoder from adding its gradient to the weights, so the weights only receive gradients from the encoder.
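For completeness, here is a minimal sketch (my own illustration, not part of the original gist) of a non-linear tied autoencoder in the functional style; detach() on the decoder side plays the same role as freezing a separate decoder layer, so the tied weight only receives gradients through the encoder path. Treat it as an illustration rather than a recommendation.

import torch, torch.nn as nn, torch.nn.functional as F

class NonLinearTiedAutoEncoder(nn.Module):
    def __init__(self, inp, out):
        super().__init__()
        self.param = nn.Parameter(torch.randn(out, inp))

    def forward(self, input):
        # non-linear encoder
        encoded_feats = torch.relu(F.linear(input, self.param))
        # detach() blocks the decoder path from contributing gradient to the
        # tied weight, analogous to requires_grad_(False) on a decoder layer
        reconstructed_output = F.linear(encoded_feats, self.param.detach().t())
        return encoded_feats, reconstructed_output

model = NonLinearTiedAutoEncoder(5, 6)
input = torch.rand(4, 5)
_, reconstruction = model(input)
F.mse_loss(reconstruction, input).backward()
# model.param.grad now holds the gradient from the encoder path only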
Yes, thank you. Actually, that's what I did. But perhaps my question should have been: are there any merits to using the other approaches you have listed here?