GAN training loops in PyTorch (gist by @abesmon, last active September 21, 2021).
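Two training loops follow. The first is a DCGAN-style loop: the generator consumes (batch_size, seed_size, 1, 1) noise tensors, and the commented-out display lines suggest 64x64 output images. The gist does not include the models, data loader, or optimizers the loop references, so here is a minimal setup sketch that makes it runnable. Everything in it is an assumption, not the gist's actual code: the convolutional architectures, the hyperparameter values, and the FakeData placeholder that stands in for whatever dataset the original notebook used.

# Assumed setup sketch for the first loop; not part of the original gist.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 32
seed_size = 100
num_epochs = 50

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # scale to [-1, 1] to match Tanh
])
# Placeholder data: random 3x64x64 images. drop_last=True because the
# loop hardcodes batch_size when building the label tensors.
train_set = datasets.FakeData(size=3200, image_size=(3, 64, 64), transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)

# Generator: (B, seed_size, 1, 1) noise -> (B, 3, 64, 64) images in [-1, 1]
generator = nn.Sequential(
    nn.ConvTranspose2d(seed_size, 256, 4, 1, 0), nn.BatchNorm2d(256), nn.ReLU(),  # 4x4
    nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),        # 8x8
    nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),          # 16x16
    nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.BatchNorm2d(32), nn.ReLU(),           # 32x32
    nn.ConvTranspose2d(32, 3, 4, 2, 1), nn.Tanh(),                                # 64x64
).to(device)

# Discriminator: (B, 3, 64, 64) images -> (B, 1) probability of being real
discriminator = nn.Sequential(
    nn.Conv2d(3, 32, 4, 2, 1), nn.LeakyReLU(0.2),
    nn.Conv2d(32, 64, 4, 2, 1), nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, 4, 2, 1), nn.LeakyReLU(0.2),
    nn.Conv2d(128, 256, 4, 2, 1), nn.LeakyReLU(0.2),
    nn.Conv2d(256, 1, 4, 1, 0), nn.Flatten(), nn.Sigmoid(),
).to(device)

loss_function = nn.BCELoss()
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))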
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from IPython.display import clear_output

for epoch in range(num_epochs):
    for n, (real_samples, _) in enumerate(train_loader):
        # Train the discriminator
        optimizer_discriminator.zero_grad()
        # Data for training the discriminator
        real_samples = real_samples.to(device=device)
        real_samples_labels = torch.ones((batch_size, 1)).to(device=device)
        real_outp = discriminator(real_samples)
        real_loss = loss_function(real_outp, real_samples_labels)
        latent_space_samples = torch.randn((batch_size, seed_size, 1, 1)).to(device=device)
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1)).to(device=device)
        # Detach the fake samples so the discriminator update does not
        # backpropagate through the generator
        output_discriminator = discriminator(generated_samples.detach())
        gen_loss = loss_function(output_discriminator, generated_samples_labels)
        total_loss = real_loss + gen_loss
        total_loss.backward()
        optimizer_discriminator.step()

        # Train the generator
        generator.zero_grad()
        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, seed_size, 1, 1)).to(device=device)
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        # Target is the "real" labels: the generator improves by fooling
        # the discriminator
        loss_generator = loss_function(output_discriminator_generated, real_samples_labels)
        loss_generator.backward()
        optimizer_generator.step()

        # Show the losses once per epoch, on the last batch
        if n == len(train_loader) - 1:
            clear_output()
            print(f"Epoch: {epoch} Loss D.: {total_loss} | Loss G.: {loss_generator}")
            # plt.imshow(generated_samples[0].cpu().detach().reshape(64, 64), cmap="gray_r")
            fig, ax = plt.subplots(figsize=(8, 8))
            ax.set_xticks([]); ax.set_yticks([])
            ax.imshow(make_grid(generated_samples.cpu().detach(), nrow=8).permute(1, 2, 0))
            # plt.imshow(generated_samples[0].cpu().detach().permute(1, 2, 0))
            # plt.imshow(train_loader.real_samples.permute(1, 2, 0))
            plt.show()
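The second loop is the simpler fully connected MNIST variant: flat 100-dimensional noise vectors, with real and fake batches concatenated into a single discriminator pass. Its models and data loader are likewise not in the gist; the sketch below is one assumed setup that makes it runnable, reusing device, batch_size, and num_epochs from the sketch above. The MLP layer sizes and learning rate are guesses, and the nn.Unflatten at the end of the generator is added so its output shape matches the real MNIST batches in torch.cat.

# Assumed setup sketch for the second loop; not part of the original gist.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # scale MNIST to [-1, 1] to match Tanh
])
train_set = datasets.MNIST(root=".", train=True, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True)

# Generator: (B, 100) noise -> (B, 1, 28, 28) images
generator = nn.Sequential(
    nn.Linear(100, 256), nn.ReLU(),
    nn.Linear(256, 512), nn.ReLU(),
    nn.Linear(512, 784), nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),  # same shape as real batches, so torch.cat works
).to(device)

# Discriminator: (B, 1, 28, 28) images -> (B, 1) probability of being real
discriminator = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 512), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(512, 256), nn.ReLU(), nn.Dropout(0.3),
    nn.Linear(256, 1), nn.Sigmoid(),
).to(device)

loss_function = nn.BCELoss()
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=1e-4)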
for epoch in range(num_epochs):
    for n, (real_samples, mnist_labels) in enumerate(train_loader):
        # Data for training the discriminator
        real_samples = real_samples.to(device=device)
        real_samples_labels = torch.ones((batch_size, 1)).to(device=device)
        latent_space_samples = torch.randn((batch_size, 100)).to(device=device)
        generated_samples = generator(latent_space_samples)
        generated_samples_labels = torch.zeros((batch_size, 1)).to(device=device)
        # Detach the fake samples so the discriminator update does not
        # backpropagate through the generator
        all_samples = torch.cat((real_samples, generated_samples.detach()))
        all_samples_labels = torch.cat((real_samples_labels, generated_samples_labels))

        # Train the discriminator
        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Data for training the generator
        latent_space_samples = torch.randn((batch_size, 100)).to(device=device)

        # Train the generator
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        # Target is the "real" labels: the generator improves by fooling
        # the discriminator
        loss_generator = loss_function(output_discriminator_generated, real_samples_labels)
        loss_generator.backward()
        optimizer_generator.step()

        # Log the losses once per epoch, on the last batch
        if n == len(train_loader) - 1:
            print(f"Epoch: {epoch} Loss D.: {loss_discriminator} | Loss G.: {loss_generator}")
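After training, images can be sampled from the generator without tracking gradients. A short sketch, not from the gist; the grid size and the normalize flag are arbitrary choices:

# Sampling sketch (assumed): draw 16 images from the trained MNIST
# generator and show them in a 4x4 grid.
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

generator.eval()
with torch.no_grad():
    noise = torch.randn((16, 100), device=device)
    samples = generator(noise).cpu()               # (16, 1, 28, 28), values in [-1, 1]
grid = make_grid(samples, nrow=4, normalize=True)  # rescales to [0, 1] for display
plt.imshow(grid.permute(1, 2, 0))
plt.axis("off")
plt.show()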