Skip to content

Instantly share code, notes, and snippets.

@khanhnamle1994
Created June 16, 2020 13:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save khanhnamle1994/d86d79ee5e32f4098829d12dbe383d53 to your computer and use it in GitHub Desktop.
Save khanhnamle1994/d86d79ee5e32f4098829d12dbe383d53 to your computer and use it in GitHub Desktop.
SVAE model architecture
class SVAE(nn.Module):
    """
    SVAE model: item embeddings -> GRU -> encoder -> Gaussian latent -> decoder.

    Every position of an input item-id sequence is embedded, run through a
    single-layer GRU, passed to ``Encoder``, projected to the mean and log
    standard deviation of a diagonal Gaussian, sampled with the
    reparameterization trick, and finally passed to ``Decoder``.

    :param hyper_params: dict read here for the keys ``total_items``,
        ``item_embed_size``, ``rnn_size``, ``hidden_size`` and
        ``latent_size``; it is also passed unchanged to ``Encoder`` and
        ``Decoder``.
    """

    def __init__(self, hyper_params):
        # Zero-argument super() avoids hard-coding a (previously wrong) class name.
        super().__init__()
        self.hyper_params = hyper_params
        self.encoder = Encoder(hyper_params)
        self.decoder = Decoder(hyper_params)
        self.item_embed = nn.Embedding(
            hyper_params['total_items'], hyper_params['item_embed_size']
        )
        self.gru = nn.GRU(
            hyper_params['item_embed_size'], hyper_params['rnn_size'],
            batch_first=True, num_layers=1,
        )
        # Projects each encoder output row to [mu | log_sigma], each latent_size wide.
        self.linear1 = nn.Linear(
            hyper_params['hidden_size'], 2 * hyper_params['latent_size']
        )
        # In-place initializer; the non-underscore ``xavier_normal`` is deprecated.
        nn.init.xavier_normal_(self.linear1.weight)
        # Not used by forward(); kept so existing references to ``model.tanh`` keep working.
        self.tanh = nn.Tanh()

    def sample_latent(self, h_enc):
        """
        Draw z ~ N(mu, sigma^2) with the reparameterization trick.

        :param h_enc: 2-D tensor of encoder outputs, one row per position;
            its last dimension must match ``linear1`` (``hidden_size``).
        :return: tensor ``mu + sigma * eps`` of shape ``[rows x latent_size]``.

        Side effect: stores ``mu`` and ``log_sigma`` on ``self.z_mean`` and
        ``self.z_log_sigma`` (returned by ``forward``, presumably for a KL
        term in the loss).
        """
        latent_size = self.hyper_params['latent_size']
        params = self.linear1(h_enc)
        mu = params[:, :latent_size]
        log_sigma = params[:, latent_size:]
        sigma = torch.exp(log_sigma)
        # Noise is drawn by torch on sigma's device and dtype (so the model works
        # on GPU and honours torch.manual_seed); it does not require grad, so
        # gradients reach linear1 only through mu and sigma.
        eps = torch.randn_like(sigma)
        self.z_mean = mu
        self.z_log_sigma = log_sigma
        return mu + sigma * eps

    def forward(self, x):
        """
        Run a forward pass over a batch of item-id sequences.

        :param x: integer tensor of item ids with shape ``[bsz x seq_len]``.
        :return: tuple ``(dec_out, z_mean, z_log_sigma)``. ``dec_out`` has
            shape ``[bsz x seq_len x D]`` where D is the decoder's output size
            (``total_items`` per the original shape notes); the latent
            statistics have shape ``[bsz * seq_len x latent_size]``.
        """
        bsz, seq_len = x.shape[0], x.shape[1]
        embedded = self.item_embed(x)                  # [bsz x seq_len x item_embed_size]
        rnn_out, _ = self.gru(embedded)                # [bsz x seq_len x rnn_size]
        # Flatten batch and time so every position is encoded independently.
        rnn_out = rnn_out.reshape(bsz * seq_len, -1)   # [bsz*seq_len x rnn_size]
        enc_out = self.encoder(rnn_out)                # [bsz*seq_len x hidden_size]
        sampled_z = self.sample_latent(enc_out)        # [bsz*seq_len x latent_size]
        dec_out = self.decoder(sampled_z)              # [bsz*seq_len x total_items]
        dec_out = dec_out.reshape(bsz, seq_len, -1)    # [bsz x seq_len x total_items]
        return dec_out, self.z_mean, self.z_log_sigma
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment