MultVAE model architecture: a variational autoencoder with a multinomial likelihood for collaborative filtering.
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.BaseModel import BaseModel  # project-specific base class; import path assumed, expected to subclass nn.Module


class MultVAE(BaseModel):
    """
    Variational Autoencoder with Multinomial Likelihood model class
    """
    def __init__(self, model_conf, num_users, num_items, device):
        """
        :param model_conf: model configuration
        :param num_users: number of users
        :param num_items: number of items
        :param device: choice of device
        """
        super(MultVAE, self).__init__()
        self.num_users = num_users
        self.num_items = num_items

        # Encoder dims run from the item count down to the latent size;
        # the decoder mirrors them back up to the item count.
        self.enc_dims = [self.num_items] + model_conf.enc_dims
        self.dec_dims = self.enc_dims[::-1]
        self.dims = self.enc_dims + self.dec_dims[1:]

        self.total_anneal_steps = model_conf.total_anneal_steps
        self.anneal_cap = model_conf.anneal_cap
        self.dropout = model_conf.dropout

        self.eps = 1e-6
        self.anneal = 0.
        self.update_count = 0

        self.device = device

        # Encoder MLP: the last linear layer outputs 2 * latent_dim values,
        # later split into the mean and log-variance of q(z|x).
        self.encoder = nn.ModuleList()
        for i, (d_in, d_out) in enumerate(zip(self.enc_dims[:-1], self.enc_dims[1:])):
            if i == len(self.enc_dims[:-1]) - 1:
                d_out *= 2
            self.encoder.append(nn.Linear(d_in, d_out))
            if i != len(self.enc_dims[:-1]) - 1:
                self.encoder.append(nn.Tanh())

        # Decoder MLP: mirrors the encoder, with Tanh between all but the last layer.
        self.decoder = nn.ModuleList()
        for i, (d_in, d_out) in enumerate(zip(self.dec_dims[:-1], self.dec_dims[1:])):
            self.decoder.append(nn.Linear(d_in, d_out))
            if i != len(self.dec_dims[:-1]) - 1:
                self.decoder.append(nn.Tanh())

        self.to(self.device)

    def forward(self, rating_matrix):
        """
        Forward pass
        :param rating_matrix: batch of user rating vectors (batch x num_items)
        """
        # encoder: L2-normalize each user's rating vector, apply dropout, run the MLP
        h = F.dropout(F.normalize(rating_matrix), p=self.dropout, training=self.training)
        for layer in self.encoder:
            h = layer(h)

        # split the encoder output into the parameters of q(z|x)
        mu_q = h[:, :self.enc_dims[-1]]
        logvar_q = h[:, self.enc_dims[-1]:]  # log sigma^2, shape: batch x latent_dim
        std_q = torch.exp(0.5 * logvar_q)    # sigma, shape: batch x latent_dim

        # reparameterization trick; at evaluation time the mean mu_q is used directly
        # (note: this implementation draws noise with std 0.01 rather than 1)
        epsilon = torch.zeros_like(std_q).normal_(mean=0, std=0.01)
        sampled_z = mu_q + epsilon * std_q if self.training else mu_q

        # decoder
        output = sampled_z
        for layer in self.decoder:
            output = layer(output)

        if self.training:
            # KL divergence between q(z|x) and the standard normal prior, averaged over the batch
            kl_loss = ((0.5 * (-logvar_q + torch.exp(logvar_q) + torch.pow(mu_q, 2) - 1)).sum(1)).mean()
            return output, kl_loss
        else:
            return output
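
For completeness, a hypothetical usage sketch. SimpleNamespace stands in for the project's model_conf object (only the field names that __init__ actually reads are set), and all sizes below are made up for illustration; this assumes BaseModel subclasses nn.Module so the instance is callable.

from types import SimpleNamespace

# Hypothetical configuration; field names match what __init__ reads.
conf = SimpleNamespace(enc_dims=[600, 200],   # 200-dim latent space
                       total_anneal_steps=200000,
                       anneal_cap=0.2,
                       dropout=0.5)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MultVAE(conf, num_users=1000, num_items=5000, device=device)

# Dummy batch of rating vectors (batch x num_items)
ratings = torch.rand(32, 5000, device=device)

model.train()
logits, kl = model(ratings)      # training mode returns (logits, KL term)
model.eval()
scores = model(ratings)          # eval mode returns logits only
print(logits.shape, kl.item(), scores.shape)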
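
The gist shows only the model, so here is a minimal training-step sketch, not part of the original, showing how the returned kl_loss is typically combined with the multinomial log-likelihood in the standard Mult-VAE objective, and how the update_count / total_anneal_steps / anneal_cap fields defined in __init__ are commonly used for linear KL annealing. The train_step function and the optimizer argument are illustrative assumptions.

# Minimal training-step sketch (assumed, not from the gist): the Mult-VAE
# objective pairs the decoder logits with a multinomial log-likelihood and
# anneals the KL weight from 0 up to anneal_cap over total_anneal_steps updates.
def train_step(model, batch, optimizer):
    model.train()
    logits, kl_loss = model(batch)

    # multinomial negative log-likelihood: -sum_i x_i * log softmax(logits)_i
    neg_ll = -(F.log_softmax(logits, dim=1) * batch).sum(dim=1).mean()

    # linear KL annealing using the counters defined in __init__
    if model.total_anneal_steps > 0:
        model.anneal = min(model.anneal_cap,
                           model.update_count / model.total_anneal_steps)
    else:
        model.anneal = model.anneal_cap
    model.update_count += 1

    loss = neg_ll + model.anneal * kl_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()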