Skip to content

Instantly share code, notes, and snippets.

@gchhablani
Last active May 17, 2020 07:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gchhablani/b7b35fadf406dddb2e71d87fabed560b to your computer and use it in GitHub Desktop.
Save gchhablani/b7b35fadf406dddb2e71d87fabed560b to your computer and use it in GitHub Desktop.
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an input whose last dimension is ``input_dim`` to an
    ``encoding_dim``-wide feature vector. Two linear heads turn that encoding
    into the mean and the standard deviation of a diagonal Gaussian over
    ``latent_dim`` latent variables. A sample drawn with the
    reparameterization trick is decoded back to ``input_dim`` values.

    The decoder ends in a plain ``Linear`` layer, so its output is unbounded
    (no sigmoid/tanh is applied). The defaults reproduce the original
    784 -> 128 -> 64 architecture. 784 is presumably a flattened 28x28 image;
    confirm with callers.
    """

    def __init__(
        self,
        input_dim=784,
        hidden_dim=128,
        encoding_dim=64,
        latent_dim=64,
    ):
        super().__init__()
        # Submodules are created in the same order and under the same
        # attribute names as before, so state_dict keys stay compatible.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, encoding_dim),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )
        self._mu = nn.Linear(encoding_dim, latent_dim)
        # NOTE: despite its name, this head predicts log(sigma**2), the
        # log-variance: sampler() computes sigma = exp(0.5 * output).
        # The name is kept so existing checkpoints still load.
        self._log_sigma = nn.Linear(encoding_dim, latent_dim)

    def sampler(self, encoding):
        """Draw z ~ N(mu, sigma^2) using the reparameterization trick.

        Args:
            encoding: Output of ``self.encoder``; its last dimension is
                ``encoding_dim``.

        Returns:
            ``mu + sigma * eps`` with ``eps ~ N(0, 1)``, shaped like ``mu``.

        Side effects:
            Stores ``mu`` and ``sigma`` on ``self.z_mean`` and
            ``self.z_sigma`` so a KL-divergence term can be computed after
            the forward pass.
        """
        mu = self._mu(encoding)
        sigma = torch.exp(0.5 * self._log_sigma(encoding))
        # randn_like draws from torch's RNG on sigma's own device and dtype.
        # This avoids a NumPy round-trip and host-to-device copy every step,
        # needs no global ``device``, and is reproducible under
        # torch.manual_seed. The noise does not require grad, so gradients
        # flow only through mu and sigma.
        eps = torch.randn_like(sigma)
        self.z_mean = mu
        self.z_sigma = sigma
        return mu + sigma * eps

    def forward(self, inp):
        """Encode ``inp``, sample a latent vector, and decode it."""
        return self.decoder(self.sampler(self.encoder(inp)))
def kld_loss(z_mean, z_sigma, *, reduction="mean"):
    """KL divergence KL(N(z_mean, z_sigma^2) || N(0, 1)) for a diagonal Gaussian.

    Each element contributes ``0.5 * (mu^2 + sigma^2 - log(sigma^2) - 1)``.

    Args:
        z_mean: Tensor of means.
        z_sigma: Tensor of standard deviations, broadcast-compatible with
            ``z_mean``. It is expected to be strictly positive. A sigma that
            has underflowed to 0 makes the log term -inf, so the loss is +inf.
        reduction: How to reduce the per-element terms.
            ``"mean"`` (default, the original behavior) averages over every
            element, i.e. over the batch *and* the latent dimensions.
            ``"sum"`` sums over every element.
            ``"batchmean"`` sums over the last dimension and averages over
            the remaining ones, giving the per-sample divergence averaged
            over the batch.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: If ``reduction`` is not one of the options above.
    """
    per_element = z_mean * z_mean + z_sigma * z_sigma - 2 * torch.log(z_sigma) - 1
    if reduction == "mean":
        return 0.5 * torch.mean(per_element)
    if reduction == "sum":
        return 0.5 * torch.sum(per_element)
    if reduction == "batchmean":
        return 0.5 * per_element.sum(dim=-1).mean()
    raise ValueError(f"unsupported reduction: {reduction!r}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment