@lakshith-403 · Created April 6, 2022
import numpy as np
import pandas as pd
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from sklearn.preprocessing import maxabs_scale

# Train on GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
print(torch.version.cuda)
class FeatureDataset(Dataset):
    """MNIST from CSV: column 0 is the digit label, columns 1-784 are pixels."""

    def __init__(self, file_name):
        data_csv = pd.read_csv(file_name, header=0)
        dataset = np.array(data_csv, dtype=np.float32)
        x = dataset[:, 1:785]        # drop the label column
        x = maxabs_scale(x, axis=1)  # scale pixels to [0, 1]
        x = x * 2.0 - 1.0            # shift to [-1, 1] to match the generator's Tanh output
        # float32 matches the models' default parameter dtype
        x = torch.tensor(x, dtype=torch.float32, device=device)
        # Dummy targets; the GAN training loop ignores them.
        y = torch.ones((x.shape[0], 1), dtype=torch.float32, device=device)
        self.x_train = x
        self.y_train = y

    def __len__(self):
        return len(self.y_train)

    def __getitem__(self, idx):
        return self.x_train[idx], self.y_train[idx]
# Hyperparameters
batch_size = 64
epochs = 50
learning_rate = 3e-4
latent_dim = 64  # size of the generator's noise input

feature_set = FeatureDataset('data/mnist_train.csv')
data_loader = DataLoader(feature_set, batch_size=batch_size, shuffle=True, drop_last=True)
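# Optional sanity check (commented out): a single batch should yield
# features of shape (batch_size, 784) and dummy labels of shape (batch_size, 1).
# xb, yb = next(iter(data_loader))
# print(xb.shape, yb.shape)  # torch.Size([64, 784]) torch.Size([64, 1])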
class Generator(nn.Module):
    """Maps a latent noise vector to a flattened 28x28 image in [-1, 1]."""

    def __init__(self):
        super(Generator, self).__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(latent_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 28 * 28),
            nn.Tanh(),  # outputs in [-1, 1], matching the rescaled real data
        )

    def forward(self, x):
        return self.linear_stack(x)
class Discriminator(nn.Module):
    """Scores a flattened 28x28 image with the probability that it is real."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1),
            nn.Sigmoid(),  # probability in (0, 1) for BCELoss
        )

    def forward(self, x):
        return self.linear_stack(x)
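# Quick shape check (commented out): the generator turns noise into 784-dim
# images and the discriminator turns each one into a single realness score.
# z = torch.randn(2, latent_dim)
# print(Generator()(z).shape)                   # torch.Size([2, 784])
# print(Discriminator()(Generator()(z)).shape)  # torch.Size([2, 1])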
disc = Discriminator().to(device)
gen = Generator().to(device)
opt_disc = optim.Adam(disc.parameters(), lr=learning_rate)
opt_gen = optim.Adam(gen.parameters(), lr=learning_rate)

# A fixed noise batch so the logged fake images are comparable across epochs.
fixed_noise = torch.randn((batch_size, latent_dim), device=device)
criterion = nn.BCELoss()

scalar_writer = SummaryWriter("logs/loss")
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0
for epoch in range(epochs):
    print(epoch)
    for batch_idx, (real, _) in enumerate(data_loader):
        noise = torch.randn((batch_size, latent_dim), device=device)
        fake = gen(noise)

        # Train the discriminator: maximize log(D(real)) + log(1 - D(G(z))).
        disc_real = disc(real)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        # retain_graph keeps the generator's graph alive so `fake` can be
        # reused for the generator update below.
        lossD.backward(retain_graph=True)
        opt_disc.step()

        # Train the generator: maximize log(D(G(z))).
        output = disc(fake)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Log a grid of fixed-noise samples next to real images once per epoch.
        if batch_idx == 0:
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
                writer_real.add_image("Real", img_grid_real, global_step=step)

        if batch_idx % 100 == 0:
            scalar_writer.add_scalar("Discriminator", lossD.item(), global_step=step)
            scalar_writer.add_scalar("Generator", lossG.item(), global_step=step)
            step += 1
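# To watch training, point TensorBoard at the log directory from a shell:
#   tensorboard --logdir logs
#
# A minimal sketch for inspecting the trained generator afterwards; the
# output path "samples.png" is an arbitrary choice, not part of the gist.
with torch.no_grad():
    samples = gen(torch.randn((batch_size, latent_dim), device=device))
    samples = samples.reshape(-1, 1, 28, 28)
    torchvision.utils.save_image(samples, "samples.png", normalize=True)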