Skip to content

Instantly share code, notes, and snippets.

@udithhaputhanthri
Created March 26, 2021 08:25
Show Gist options
  • Save udithhaputhanthri/fd6af184ff46e59c6cb9012c21e6303f to your computer and use it in GitHub Desktop.
Save udithhaputhanthri/fd6af184ff46e59c6cb9012c21e6303f to your computer and use it in GitHub Desktop.
WGAN_scripts
# WGAN training with a gradient penalty: for every batch, the critic C is
# updated `crit_repeats` times, then the generator G `gen_repeats` times.
# Critic, Generator, get_gradient, gradient_penalty, get_crit_loss,
# get_gen_loss, data_loader and the hyper-parameters are defined elsewhere.
C = Critic(img_channels, hidden_C).to(device)
G = Generator(noise_channels, img_channels, hidden_G).to(device)
# Optional custom weight initialisation (init_weights is defined elsewhere):
# C = C.apply(init_weights)
# G = G.apply(init_weights)
wandb.watch(G, log='all', log_freq=10)
wandb.watch(C, log='all', log_freq=10)
opt_C = torch.optim.Adam(C.parameters(), lr=lr, betas=(0.5, 0.999))
opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

gen_repeats = 1   # generator updates per batch
crit_repeats = 3  # critic updates per batch
c_lambda = 10     # gradient-penalty weight forwarded to get_crit_loss

# Fixed noise batch; presumably for sampling comparable images across epochs
# (not used in this snippet).
noise_for_generate = torch.randn(batch_size, noise_channels, 1, 1, device=device)

losses_C = []  # mean critic loss of each epoch
losses_G = []  # mean generator loss of each epoch

for epoch in range(1, epochs + 1):
    loss_C_epoch = []  # per-batch critic loss (averaged over crit_repeats)
    loss_G_epoch = []  # per-batch generator loss (averaged over gen_repeats)
    for idx, (x, _) in enumerate(data_loader):
        C.train()
        G.train()
        real_imgs = x.to(device)
        x_len = real_imgs.shape[0]

        ### Train C
        loss_C_iter = 0
        for _ in range(crit_repeats):
            opt_C.zero_grad()
            z = torch.randn(x_len, noise_channels, 1, 1, device=device)
            # G's output is only a critic input here; no_grad avoids building
            # G's autograd graph just to discard it (same values as .detach()).
            with torch.no_grad():
                fake_imgs = G(z)
            real_C_out = C(real_imgs)
            fake_C_out = C(fake_imgs)
            # NOTE(review): requires_grad=True is presumably what makes the
            # real/fake interpolation inside get_gradient differentiable, since
            # neither real_imgs nor fake_imgs requires grad — confirm there.
            epsilon = torch.rand(x_len, 1, 1, 1, device=device, requires_grad=True)
            gradient = get_gradient(C, real_imgs, fake_imgs, epsilon)
            gp = gradient_penalty(gradient)
            loss_C = get_crit_loss(fake_C_out, real_C_out, gp, c_lambda=c_lambda)
            loss_C.backward()
            opt_C.step()
            loss_C_iter += loss_C.item() / crit_repeats

        ### Train G
        loss_G_iter = 0
        for _ in range(gen_repeats):
            opt_G.zero_grad()
            z = torch.randn(x_len, noise_channels, 1, 1, device=device)
            fake_C_out = C(G(z))
            loss_G = get_gen_loss(fake_C_out)
            # This backward also writes grads into C's parameters; they are
            # cleared by opt_C.zero_grad() before the next critic step.
            loss_G.backward()
            opt_G.step()
            loss_G_iter += loss_G.item() / gen_repeats

        # Record this batch's losses (previously computed and discarded).
        loss_C_epoch.append(loss_C_iter)
        loss_G_epoch.append(loss_G_iter)

    # Store the epoch means; skip when data_loader yielded no batches to
    # avoid a ZeroDivisionError.
    if loss_C_epoch:
        losses_C.append(sum(loss_C_epoch) / len(loss_C_epoch))
        losses_G.append(sum(loss_G_epoch) / len(loss_G_epoch))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment