Skip to content

Instantly share code, notes, and snippets.

@saeedizadi
Created November 19, 2018 23:48
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save saeedizadi/12af1ed03f6b165b932d6e35d037e18a to your computer and use it in GitHub Desktop.
Save saeedizadi/12af1ed03f6b165b932d6e35d037e18a to your computer and use it in GitHub Desktop.
# Generator / discriminator pair for adversarial super-resolution training.
# Models are built generator-first; keys 'gen' and 'disc' index all three dicts.
model = {
    'gen': UncertaintyNet(ndense=12, nconvs=8, growth_rate=8, scale=scale).cuda(),
    'disc': Discriminator().cuda(),
}
# One Adam optimizer per network, sharing the same lr / weight decay.
optimizer = {
    name: optim.Adam(net.parameters(), lr=lr, weight_decay=weight_decay)
    for name, net in model.items()
}
# 'gen': L1 pixel loss against the high-res target.
# 'disc': BCELoss expects probabilities, so Discriminator presumably ends in a
# sigmoid -- NOTE(review): confirm.
criterion = {
    'gen': nn.L1Loss().cuda(),
    'disc': nn.BCELoss().cuda(),
}
def _set_requires_grad(net, flag):
    """Set ``requires_grad`` to *flag* on every parameter of *net*."""
    for param in net.parameters():
        param.requires_grad = flag


def _to_float(loss):
    """Return the value of a one-element loss tensor/Variable as a Python float.

    ``.numpy()`` may yield either a 0-d or a 1-element array depending on the
    PyTorch version, so the array is flattened before indexing.
    """
    return float(loss.data.cpu().numpy().reshape(-1)[0])


def _make_lowres(highres, updown):
    """Apply *updown* to every image of a CPU batch and stack the results."""
    return torch.stack([updown(img) for img in highres])


def _discriminator_step(model, optimizer, criterion, highres, lowres):
    """Run one discriminator update (real -> 1, generated -> 0); return its loss."""
    _set_requires_grad(model['disc'], True)
    model['disc'].zero_grad()
    model['gen'].eval()
    model['disc'].train()

    batch = highres.size(0)
    real_targets = Variable(torch.ones(batch).cuda())
    real_output = model['disc'](highres)
    real_loss = criterion['disc'](real_output.squeeze(), real_targets.squeeze())

    # Detach so this loss does not backpropagate into the generator.
    fake_inputs = model['gen'](lowres).detach()
    fake_targets = Variable(torch.zeros(batch).cuda())
    fake_output = model['disc'](fake_inputs)
    fake_loss = criterion['disc'](fake_output.squeeze(), fake_targets.squeeze())

    disc_loss = real_loss + fake_loss
    # Scalar loss: call backward() without a `gradient` argument. Passing the
    # loss itself there would scale every gradient by the loss value.
    disc_loss.backward()
    optimizer['disc'].step()
    return disc_loss


def _generator_step(model, optimizer, criterion, highres, lowres, alpha):
    """Run one generator update; return (adversarial_loss, total_loss)."""
    # Freeze the discriminator so only the generator accumulates gradients.
    _set_requires_grad(model['disc'], False)
    model['gen'].zero_grad()
    model['gen'].train()
    model['disc'].eval()

    gen_outputs = model['gen'](lowres)
    content_loss = criterion['gen'](gen_outputs, highres)

    # The generator is rewarded when the discriminator labels its output real.
    # Squeeze both sides, matching the discriminator step.
    real_targets = Variable(torch.ones(highres.size(0)).cuda())
    disc_output = model['disc'](gen_outputs)
    adv_loss = criterion['disc'](disc_output.squeeze(), real_targets.squeeze())

    total_loss = adv_loss + alpha * content_loss
    total_loss.backward()
    optimizer['gen'].step()
    return adv_loss, total_loss


def gan_train(model, weightdir, epochs, dataloader, optimizer, criterion, scale,
              log_step=10, alpha=10):
    """Adversarially train ``model['gen']`` against ``model['disc']``.

    Each batch from *dataloader* is downscaled by *scale* (bicubic, then
    converted to grayscale) to form the generator input. The discriminator is
    updated on real vs. generated images. The generator is then updated with
    ``adversarial_loss + alpha * criterion['gen'](output, highres)``.

    Args:
        model: dict with 'gen' and 'disc' modules (both moved to CUDA here).
        weightdir: not used by this function.
        epochs: number of passes over *dataloader*.
        dataloader: iterable of high-res image batches (CPU tensors accepted
            by ``transforms.ToPILImage``). NOTE(review): assumed to yield bare
            tensors, not (input, target) tuples.
        optimizer: dict with 'gen' and 'disc' optimizers.
        criterion: dict with 'gen' (content) and 'disc' (BCE) losses.
        scale: integer downscaling factor applied to the module-level
            ``IMAGE_SIZE`` (presumably (height, width) -- confirm).
        log_step: not used by this function.
        alpha: weight of the content loss in the generator objective
            (default 10, the previously hard-coded value).

    Returns:
        list with one dict per epoch holding sample-weighted mean losses:
        'gen' (total generator loss), 'disc' (discriminator loss) and
        'gen_adv' (adversarial part of the generator loss).
    """
    # Integer low-res size. The previous ``map(lambda x: x / scale, ...)`` is not
    # subscriptable on Python 3, and true division would give float sizes.
    lr_size = tuple(int(x // scale) for x in IMAGE_SIZE)
    updown = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(size=lr_size, interpolation=Image.BICUBIC),
        transforms.Grayscale(),
        transforms.ToTensor(),
    ])

    history = []
    with tqdm(total=epochs, leave=False, dynamic_ncols=True, disable=True) as pbar:
        try:
            for _epoch in range(epochs):
                g_loss_sum = d_loss_sum = gd_loss_sum = 0.
                n_seen = 0
                for highres in dataloader:
                    batch = highres.size(0)
                    lowres = _make_lowres(highres, updown)
                    highres = Variable(highres).cuda()
                    lowres = Variable(lowres).cuda()

                    disc_loss = _discriminator_step(
                        model, optimizer, criterion, highres, lowres)
                    d_loss_sum += _to_float(disc_loss) * batch

                    adv_loss, gen_loss = _generator_step(
                        model, optimizer, criterion, highres, lowres, alpha)
                    gd_loss_sum += _to_float(adv_loss) * batch
                    g_loss_sum += _to_float(gen_loss) * batch
                    n_seen += batch

                denom = max(n_seen, 1)  # guard against an empty dataloader
                history.append({
                    'gen': g_loss_sum / denom,
                    'disc': d_loss_sum / denom,
                    'gen_adv': gd_loss_sum / denom,
                })
                pbar.update()
        finally:
            # Do not leave the discriminator frozen for the caller.
            _set_requires_grad(model['disc'], True)
    return history
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment