@khanhnamle1994 · Created June 16, 2020 13:41
MultVAE model architecture
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultVAE(BaseModel):
    """
    Variational Autoencoder with Multinomial Likelihood (Mult-VAE) model class.
    BaseModel is assumed to be defined elsewhere in the surrounding repository,
    typically as a thin subclass of nn.Module.
    """
    def __init__(self, model_conf, num_users, num_items, device):
        """
        :param model_conf: model configuration
        :param num_users: number of users
        :param num_items: number of items
        :param device: choice of device
        """
        super(MultVAE, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.enc_dims = [self.num_items] + model_conf.enc_dims
        self.dec_dims = self.enc_dims[::-1]
        self.dims = self.enc_dims + self.dec_dims[1:]

        self.total_anneal_steps = model_conf.total_anneal_steps
        self.anneal_cap = model_conf.anneal_cap

        self.dropout = model_conf.dropout
        self.eps = 1e-6
        self.anneal = 0.
        self.update_count = 0

        self.device = device

        # Encoder: the last linear layer outputs both the mean and the
        # log-variance of q(z|x), hence the doubled output dimension and
        # the absence of a Tanh after it.
        self.encoder = nn.ModuleList()
        for i, (d_in, d_out) in enumerate(zip(self.enc_dims[:-1], self.enc_dims[1:])):
            if i == len(self.enc_dims[:-1]) - 1:
                d_out *= 2
            self.encoder.append(nn.Linear(d_in, d_out))
            if i != len(self.enc_dims[:-1]) - 1:
                self.encoder.append(nn.Tanh())

        # Decoder: mirrors the encoder dimensions back up to num_items.
        self.decoder = nn.ModuleList()
        for i, (d_in, d_out) in enumerate(zip(self.dec_dims[:-1], self.dec_dims[1:])):
            self.decoder.append(nn.Linear(d_in, d_out))
            if i != len(self.dec_dims[:-1]) - 1:
                self.decoder.append(nn.Tanh())

        self.to(self.device)

    def forward(self, rating_matrix):
        """
        Forward pass
        :param rating_matrix: rating matrix (batch x num_items)
        """
        # encoder: L2-normalize each user's row, then apply dropout
        h = F.dropout(F.normalize(rating_matrix), p=self.dropout, training=self.training)
        for layer in self.encoder:
            h = layer(h)

        # split the encoder output into mean and log-variance
        mu_q = h[:, :self.enc_dims[-1]]
        logvar_q = h[:, self.enc_dims[-1]:]  # log sigma^2, batch x latent_dim
        std_q = torch.exp(0.5 * logvar_q)    # sigma, batch x latent_dim

        # reparametrization trick: Gaussian noise with small std (0.01);
        # self.training is a bool, so the noise term vanishes at eval time
        epsilon = torch.zeros_like(std_q).normal_(mean=0, std=0.01)
        sampled_z = mu_q + self.training * epsilon * std_q

        # decoder
        output = sampled_z
        for layer in self.decoder:
            output = layer(output)

        if self.training:
            # KL divergence between q(z|x) and the standard normal prior
            kl_loss = ((0.5 * (-logvar_q + torch.exp(logvar_q) + torch.pow(mu_q, 2) - 1)).sum(1)).mean()
            return output, kl_loss
        else:
            return output
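
For context, here is a minimal usage sketch showing how the class might be instantiated and called. The SimpleNamespace config is a hypothetical stand-in for the repository's real model_conf object (only the fields the constructor actually reads are set, and the dimension values are illustrative), and BaseModel is assumed to resolve to an nn.Module subclass as in the surrounding repo.

# Minimal usage sketch; model_conf fields and sizes below are assumptions.
from types import SimpleNamespace

import torch

model_conf = SimpleNamespace(
    enc_dims=[600, 200],        # hidden then latent size: num_items -> 600 -> 200
    total_anneal_steps=200000,  # KL annealing horizon
    anneal_cap=0.2,             # maximum KL weight
    dropout=0.5,
)

num_users, num_items = 1000, 5000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultVAE(model_conf, num_users, num_items, device)

# fake implicit-feedback batch: a binary click matrix for 32 users
batch = torch.randint(0, 2, (32, num_items), dtype=torch.float32, device=device)

model.train()
logits, kl_loss = model(batch)       # training mode returns (logits, KL term)
print(logits.shape, kl_loss.item())  # torch.Size([32, 5000]) and a scalar

model.eval()
logits = model(batch)                # eval mode returns logits only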
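
Note that forward() never touches total_anneal_steps, anneal_cap, or update_count; those fields only make sense inside a training loop. Below is a sketch, continuing from the example above, of how they would plausibly be used, following the standard Mult-VAE recipe of Liang et al. (2018): multinomial log-likelihood plus a linearly annealed KL term. The loop structure here is an assumption, not code from this gist.

# One training step with KL annealing (assumed usage of the anneal fields).
import torch
import torch.nn.functional as F

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

def train_step(model, batch, optimizer):
    model.train()
    logits, kl_loss = model(batch)

    # multinomial log-likelihood: log-softmax of the logits, weighted by the
    # observed interactions, summed per user and averaged over the batch
    log_softmax = F.log_softmax(logits, dim=1)
    neg_ll = -(log_softmax * batch).sum(dim=1).mean()

    # linearly anneal the KL weight from 0 up to anneal_cap
    if model.total_anneal_steps > 0:
        model.anneal = min(model.anneal_cap,
                           model.update_count / model.total_anneal_steps)
    else:
        model.anneal = model.anneal_cap

    loss = neg_ll + model.anneal * kl_loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.update_count += 1
    return loss.item()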