Skip to content

Instantly share code, notes, and snippets.

@udithhaputhanthri
Last active May 16, 2020 03:50
Show Gist options
  • Save udithhaputhanthri/38d8698391c8392657149fbd7f3d44e1 to your computer and use it in GitHub Desktop.
Save udithhaputhanthri/38d8698391c8392657149fbd7f3d44e1 to your computer and use it in GitHub Desktop.
Introduction to DCGAN using PyTorch
import os

# DCGAN training loop: alternate one discriminator (D) step and one
# generator (G) step per batch, logging and checkpointing every 50 batches.

# Put both networks in training mode (affects layers such as dropout and
# batch-norm).
D.train()
G.train()
# Fixed noise batch. It is not used inside this loop; presumably it is kept
# for generating comparable samples elsewhere (TODO confirm against the rest
# of the script).
noise_for_generate = torch.randn(batch_size, noise_channels, 1, 1, device=device)

# torch.save does not create missing directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs('latest_model', exist_ok=True)

for epoch in range(epochs):
    for idx, (x, _) in enumerate(data_loader):
        x = x.to(device)
        x_len = x.shape[0]  # the last batch may be smaller than batch_size

        ### Train D
        D.zero_grad()
        z = torch.randn(x_len, noise_channels, 1, 1, device=device)
        # Smoothed targets (0.9 for real, 0.1 for fake) instead of hard 1/0.
        label_real_D = torch.full((x_len,), 0.9, device=device)
        label_fake_D = torch.full((x_len,), 0.1, device=device)
        real = D(x).reshape(-1)
        # Generate the fake batch once; it is reused (with its graph) in the
        # G step below. G's weights do not change in between, so this matches
        # calling G(z) a second time.
        fake_imgs = G(z)
        # detach() so the D loss does not backpropagate into G.
        fake = D(fake_imgs.detach()).reshape(-1)
        loss_D = criterion(real, label_real_D) + criterion(fake, label_fake_D)
        loss_D.backward()
        opt_D.step()

        ### Train G
        G.zero_grad()
        # G is trained so that the just-updated D scores its output as real
        # (hard target 1).
        label_real_G = torch.ones(x_len, device=device)
        loss_G = criterion(D(fake_imgs).reshape(-1), label_real_G)
        loss_G.backward()
        opt_G.step()

        ### Return current state
        if idx % 50 == 0:
            # idx counts batches, so report progress against the number of
            # batches, not the number of samples in the dataset.
            print(
                f'epoch:{epoch}/{epochs} iteration:{idx}/{len(data_loader)} '
                f'Loss D :{loss_D.item()} -- Loss G :{loss_G.item()}'
            )
            # NOTE(review): the original indentation was lost in the scrape;
            # checkpointing is assumed to happen at each logging step. Confirm.
            torch.save({'state_dict': G.state_dict()}, 'latest_model/checkpoint_G.pth.tar')
            torch.save({'state_dict': D.state_dict()}, 'latest_model/checkpoint_D.pth.tar')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment