Skip to content

Instantly share code, notes, and snippets.

@SannaPersson
Created September 13, 2022 14:03
Show Gist options
  • Save SannaPersson/4779d7c3abd7c192b3649e4bf1c9fd62 to your computer and use it in GitHub Desktop.
Save SannaPersson/4779d7c3abd7c192b3649e4bf1c9fd62 to your computer and use it in GitHub Desktop.
variational_autoencoder3
# Define train function
def train(num_epochs, model, optimizer, loss_fn):
    """Train a VAE for ``num_epochs`` passes over the module-level ``train_loader``.

    Each batch is moved to the module-level ``device``, flattened to
    ``(-1, INPUT_DIM)`` and fed through ``model``. The optimised objective is
    the reconstruction loss plus the KL divergence of the approximate
    posterior ``N(mu, sigma^2)`` from a standard normal prior.

    Args:
        num_epochs: Number of full passes over ``train_loader``.
        model: Module whose forward pass returns ``(x_reconst, mu, sigma)``.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Reconstruction loss, called as ``loss_fn(x_reconst, x)``.
            NOTE(review): presumably a ``reduction="sum"`` loss so its scale
            matches the summed KL term -- confirm at the call site.

    Relies on ``train_loader``, ``device``, ``INPUT_DIM`` and ``tqdm`` being
    defined elsewhere in this file.
    """
    # Ensure training-mode behaviour (dropout, batch-norm stat updates) even if
    # the model was previously switched to eval mode.
    model.train()
    for epoch in range(num_epochs):
        # Wrap the loader itself (not enumerate(...)) so tqdm can read its
        # length and render a real progress bar.
        loop = tqdm(train_loader)
        loop.set_description(f"Epoch [{epoch + 1}/{num_epochs}]")
        for x, _ in loop:  # labels are not used by the VAE objective
            # Forward pass
            x = x.to(device).view(-1, INPUT_DIM)
            x_reconst, mu, sigma = model(x)

            # ELBO terms; formulas from https://www.youtube.com/watch?v=igP03FXZqgo&t=2182s
            reconst_loss = loss_fn(x_reconst, x)
            kl_div = _gaussian_kl(mu, sigma)

            # Backprop and optimize
            loss = reconst_loss + kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())


def _gaussian_kl(mu, sigma):
    """Return KL(N(mu, sigma^2) || N(0, I)), summed over every element.

    Closed form: ``-0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)``. The
    ``0.5`` factor was missing before, which double-weighted the KL term.

    Args:
        mu: Posterior means; NOTE(review): presumably (batch, latent_dim).
        sigma: Posterior standard deviations, same shape as ``mu``; must be
            strictly positive or ``log`` yields ``-inf``.
    """
    var = sigma.pow(2)  # computed once, used in both the log and linear terms
    return -0.5 * torch.sum(1 + torch.log(var) - mu.pow(2) - var)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment