Skip to content

Instantly share code, notes, and snippets.

@dte dte/Adversarial2.py
Created Mar 20, 2018

Embed
What would you like to do?
# Build the encoder (Q), decoder (P) and latent-space discriminator (D_gauss),
# move them to the GPU when one is available, and create their optimizers.
# NOTE(review): Q_net, P_net, D_net_gauss and `optim` are defined/imported
# outside this excerpt.
torch.manual_seed(10)  # fixed seed so weight initialisation is reproducible

# Encoder / decoder. (The original line `Q, P = Q_net() = Q_net(), P_net(0)`
# is a SyntaxError: a function call cannot be an assignment target.)
# NOTE(review): Q_net() takes no argument but P_net is called with 0 —
# confirm against P_net's signature that the 0 is intended.
Q, P = Q_net(), P_net(0)
D_gauss = D_net_gauss()  # discriminator that judges latent codes

if torch.cuda.is_available():
    Q = Q.cuda()
    P = P.cuda()
    # Move the discriminator built above rather than constructing a second
    # one: the original bound this one to an unused `D_cat` and replaced
    # D_gauss with a fresh network, which also consumed extra seeded RNG and
    # made the CUDA path initialise differently from the CPU path.
    D_gauss = D_gauss.cuda()

# Learning rates: gen_lr drives the reconstruction optimizers (P_decoder,
# Q_encoder); reg_lr drives the adversarial ones (Q_generator, D_gauss_solver).
gen_lr, reg_lr = 0.0006, 0.0008

# Optimizers are created after the .cuda() moves so they hold the parameters
# of the modules actually used. Q deliberately has two Adam instances over
# the same parameters: one stepped on the reconstruction loss, one on the
# generator (adversarial) loss.
P_decoder = optim.Adam(P.parameters(), lr=gen_lr)
Q_encoder = optim.Adam(Q.parameters(), lr=gen_lr)
Q_generator = optim.Adam(Q.parameters(), lr=reg_lr)
D_gauss_solver = optim.Adam(D_gauss.parameters(), lr=reg_lr)
# One adversarial-autoencoder update on the batch X, in three phases:
# reconstruction, discriminator, generator.
# NOTE(review): X, TINY, train_batch_size, X_dim, z_dim, F and Variable are
# defined outside this excerpt; presumably this runs inside a batch loop.

# --- 1) Reconstruction: update encoder Q and decoder P together. ---
# Clear stale gradients (e.g. Q's from a previous generator phase) so that
# only recon_loss drives this step.
P.zero_grad()
Q.zero_grad()
z_sample = Q(X)
X_sample = P(z_sample)
# view() replaces the deprecated non-in-place Tensor.resize(); both require
# X to hold exactly train_batch_size * X_dim elements.
recon_loss = F.binary_cross_entropy(X_sample + TINY,
                                    X.view(train_batch_size, X_dim) + TINY)
recon_loss.backward()
P_decoder.step()
Q_encoder.step()

# --- 2) Discriminator: separate prior samples from encoder codes. ---
Q.eval()  # eval mode while producing fake codes (turns off dropout; see Q.train() below)
# Prior samples: randn scaled by 5, i.e. N(0, 5^2) — standard deviation 5.
z_real_gauss = Variable(torch.randn(train_batch_size, z_dim) * 5)
if torch.cuda.is_available():
    z_real_gauss = z_real_gauss.cuda()
# detach(): only D_gauss is trained in this phase, so the discriminator loss
# must not backpropagate into Q's parameters.
z_fake_gauss = Q(X).detach()
D_gauss.zero_grad()  # drop gradients left over from the generator phase
D_real_gauss, D_fake_gauss = D_gauss(z_real_gauss), D_gauss(z_fake_gauss)
# Presumably D_gauss outputs probabilities in (0, 1) — confirm; TINY keeps
# the logs finite when an output reaches exactly 0 or 1.
D_loss_gauss = -torch.mean(torch.log(D_real_gauss + TINY)
                           + torch.log(1 - D_fake_gauss + TINY))
# The original called D_loss.backward(); D_loss is never defined (NameError).
D_loss_gauss.backward()
D_gauss_solver.step()

# --- 3) Generator: update Q so its codes are scored as "real" by D_gauss. ---
Q.train()  # back to training mode (re-enables dropout)
# Q still holds its reconstruction gradients from phase 1; without this,
# Q_generator.step() would apply them a second time.
Q.zero_grad()
z_fake_gauss = Q(X)
D_fake_gauss = D_gauss(z_fake_gauss)
G_loss = -torch.mean(torch.log(D_fake_gauss + TINY))
G_loss.backward()
Q_generator.step()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.