Skip to content

Instantly share code, notes, and snippets.

@koshian2
Created September 14, 2018 04:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save koshian2/2098e2261d673c818f6bdc51fa485e86 to your computer and use it in GitHub Desktop.
Save koshian2/2098e2261d673c818f6bdc51fa485e86 to your computer and use it in GitHub Desktop.
Vanilla Auto Encoder using Shake-Shake regularization (failure)
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10, STL10
from shake_shake_function import get_alpha_beta, shake_function
import os
import pickle
import zipfile
import datetime
import numpy as np
class EncoderModule(nn.Module):
    """Two parallel conv -> BatchNorm -> ReLU branches merged by Shake-Shake.

    Both branches share the same geometry (kernel, stride, padding) but have
    independent weights. Their outputs are combined by ``shake_function`` using
    coefficients produced by ``get_alpha_beta`` (both from the project module
    ``shake_shake_function``).

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced by each branch.
        stride: convolution stride of both branches.
        kernel: convolution kernel size of both branches.
        pad: zero padding of both branches.
    """

    def __init__(self, input_channels, output_channels, stride, kernel, pad):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel, padding=pad, stride=stride)
        self.conv1 = nn.Conv2d(input_channels, output_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.conv2 = nn.Conv2d(input_channels, output_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(output_channels)

    def forward(self, x):
        branches = [
            F.relu(bn(conv(x)), inplace=True)
            for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2))
        ]
        # All three flags True while training, all False otherwise. The meaning
        # of each flag lives in shake_shake_function (presumably
        # forward/backward/per-sample randomization) — TODO confirm.
        shake_config = (bool(self.training),) * 3
        alpha, beta = get_alpha_beta(x.size(0), shake_config, x.is_cuda)
        return shake_function(branches[0], branches[1], alpha, beta)
class Encoder(nn.Module):
    """Stack of four Shake-Shake conv stages whose output is flattened.

    Channel progression: color_channels -> 32 (1x1) -> 64 -> 128 -> 256.
    The last two stages downsample with strides ``pooling_kernels[0]`` and
    ``pooling_kernels[1]``.

    Args:
        color_channels: input image channels.
        pooling_kernels: pair of strides for the two downsampling stages.
        n_neurons_in_middle_layer: size of each flattened output row. It must
            equal 256 * H * W of the final feature map; otherwise ``view``
            either fails or silently regroups samples.
    """

    def __init__(self, color_channels, pooling_kernels, n_neurons_in_middle_layer):
        super().__init__()
        self.n_neurons_in_middle_layer = n_neurons_in_middle_layer
        first_stride, second_stride = pooling_kernels[0], pooling_kernels[1]
        self.bottle = EncoderModule(color_channels, 32, stride=1, kernel=1, pad=0)
        self.m1 = EncoderModule(32, 64, stride=1, kernel=3, pad=1)
        self.m2 = EncoderModule(64, 128, stride=first_stride, kernel=3, pad=1)
        self.m3 = EncoderModule(128, 256, stride=second_stride, kernel=3, pad=1)

    def forward(self, x):
        features = x
        for stage in (self.bottle, self.m1, self.m2, self.m3):
            features = stage(features)
        return features.view(-1, self.n_neurons_in_middle_layer)
class DecoderModule(nn.Module):
    """Two parallel transposed-conv -> BatchNorm -> activation branches merged by Shake-Shake.

    Each branch uses ``kernel_size == stride`` with no padding. This scales the
    spatial size by exactly ``stride``; ``stride=1`` keeps the size and acts as
    a 1x1 projection.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels produced by each branch.
        stride: upsampling factor (also the kernel size).
        activation: ``"relu"`` (default) or ``"sigmoid"``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
            Previously an unknown name left ``self.activation`` unset, so the
            error only surfaced later in ``forward``.
    """

    def __init__(self, input_channels, output_channels, stride, activation="relu"):
        super().__init__()
        self.convt1 = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=stride, stride=stride)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.convt2 = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=stride, stride=stride)
        self.bn2 = nn.BatchNorm2d(output_channels)
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()
        else:
            raise ValueError(
                f"unsupported activation {activation!r}; expected 'relu' or 'sigmoid'"
            )

    def forward(self, x):
        path1 = self.activation(self.bn1(self.convt1(x)))
        path2 = self.activation(self.bn2(self.convt2(x)))
        # All three flags True while training, all False otherwise. The meaning
        # of each flag lives in shake_shake_function — TODO confirm.
        shake_config = (bool(self.training),) * 3
        alpha, beta = get_alpha_beta(x.size(0), shake_config, x.is_cuda)
        return shake_function(path1, path2, alpha, beta)
class Decoder(nn.Module):
    """Reshape a flat vector to a 256-channel map and upsample it back to an image.

    Channel progression: 256 -> 128 -> 64 -> 32 -> color_channels. The two
    upsampling strides are applied as ``pooling_kernels[1]`` then
    ``pooling_kernels[0]``, reversing the order in which the encoder
    downsamples. The final stage uses a sigmoid activation.

    Args:
        color_channels: output image channels.
        pooling_kernels: pair of strides (same pair the encoder used).
        decoder_input_size: spatial side length of the reshaped input map.
    """

    def __init__(self, color_channels, pooling_kernels, decoder_input_size):
        super().__init__()
        self.decoder_input_size = decoder_input_size
        self.m1 = DecoderModule(256, 128, stride=1)
        self.m2 = DecoderModule(128, 64, stride=pooling_kernels[1])
        self.m3 = DecoderModule(64, 32, stride=pooling_kernels[0])
        self.bottle = DecoderModule(32, color_channels, stride=1, activation="sigmoid")

    def forward(self, x):
        side = self.decoder_input_size
        out = x.view(-1, 256, side, side)
        for stage in (self.m1, self.m2, self.m3, self.bottle):
            out = stage(out)
        return out
class VAE(nn.Module):
    """Shake-Shake convolutional autoencoder with a tanh-bounded latent code.

    Despite the name, nothing here is variational: ``forward`` is a
    deterministic encoder -> fc1 -> tanh -> fc2 -> ReLU -> decoder pipeline,
    trained with ``nn.MSELoss`` on mixup-augmented inputs.

    The instance also owns its data loaders, optimizer/scheduler (after
    ``init_model``), loss history, and an output directory named after the
    dataset.

    Args:
        dataset: one of ``"mnist"``, ``"fashion-mnist"``, ``"cifar"``, ``"stl"``.

    Raises:
        ValueError: for an unsupported dataset name.
    """

    # dataset -> (pooling strides, encoder output side length, color channels)
    # resolution:
    #   mnist, fashion-mnist : 28 -> 14 -> 7
    #   cifar                : 32 -> 8 -> 4
    #   stl                  : 96 -> 24 -> 6
    _DATASET_CONFIG = {
        "mnist": ((2, 2), 7, 1),
        "fashion-mnist": ((2, 2), 7, 1),
        "cifar": ((4, 2), 4, 3),
        "stl": ((4, 4), 6, 3),
    }

    def __init__(self, dataset):
        super().__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if dataset not in self._DATASET_CONFIG:
            raise ValueError(
                f"unsupported dataset {dataset!r}; expected one of {list(self._DATASET_CONFIG)}"
            )
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # number of latent features
        self.n_latent_features = 64
        pooling_kernel, encoder_output_size, color_channels = self._DATASET_CONFIG[dataset]
        # number of neurons in the flattened middle layer
        n_neurons_middle_layer = 256 * encoder_output_size * encoder_output_size
        # Encoder
        self.encoder = Encoder(color_channels, pooling_kernel, n_neurons_middle_layer)
        # Middle
        self.fc1 = nn.Linear(n_neurons_middle_layer, self.n_latent_features)
        self.fc2 = nn.Linear(self.n_latent_features, n_neurons_middle_layer)
        # Decoder
        self.decoder = Decoder(color_channels, pooling_kernel, encoder_output_size)
        # Latent values from the most recent forward pass (used by bootstrap_sampling)
        self.latent_cache = None
        # data
        self.train_loader, self.test_loader = self.load_data(dataset)
        # history
        self.history = {"loss": [], "val_loss": []}
        # model name / output directory
        self.model_name = dataset
        os.makedirs(self.model_name, exist_ok=True)

    def bootstrap_sampling(self):
        """Decode 64 latent vectors resampled feature-wise from the latent cache.

        Each latent feature is drawn independently, with replacement, from the
        values that feature took in the most recent forward pass. This keeps
        each feature's marginal distribution but drops cross-feature
        correlations. Falls back to uniform noise in [-1, 1] when no forward
        pass has run yet.
        """
        if self.latent_cache is None:
            z = (torch.rand(64, self.n_latent_features) * 2 - 1).to(self.device)
        else:
            # Work on a detached CPU copy without mutating the cache. .numpy()
            # refuses tensors that require grad or live on the GPU.
            cache = self.latent_cache.detach().cpu().numpy()
            z = np.zeros((64, self.n_latent_features), dtype="float32")
            for i in range(self.n_latent_features):
                index = np.random.choice(cache.shape[0], 64)
                z[:, i] = cache[index, i]
            z = torch.from_numpy(z).to(self.device)
        z = F.relu(self.fc2(z))
        # decode
        return self.decoder(z)

    def random_sampling(self):
        """Decode 64 latent vectors drawn uniformly from [-1, 1] (the tanh range)."""
        z = (torch.rand(64, self.n_latent_features) * 2 - 1).to(self.device)
        z = F.relu(self.fc2(z))
        return self.decoder(z)

    def forward(self, x):
        # Encoder
        h = self.encoder(x)
        # Bottle-neck (torch.tanh: F.tanh is deprecated)
        latent = torch.tanh(self.fc1(h))
        # Cache a detached copy. Storing the live tensor would keep a reference
        # to its autograd graph and make it unusable with .numpy().
        self.latent_cache = latent.detach()
        # decoder
        z = F.relu(self.fc2(latent))
        return self.decoder(z)

    # Data
    def load_data(self, dataset):
        """Return ``(train_loader, test_loader)`` for ``dataset``, downloading to ./data.

        Raises:
            ValueError: for an unsupported dataset name.
        """
        data_transform = transforms.Compose([
            transforms.ToTensor()
        ])
        if dataset == "mnist":
            train = MNIST(root="./data", train=True, transform=data_transform, download=True)
            test = MNIST(root="./data", train=False, transform=data_transform, download=True)
        elif dataset == "fashion-mnist":
            train = FashionMNIST(root="./data", train=True, transform=data_transform, download=True)
            test = FashionMNIST(root="./data", train=False, transform=data_transform, download=True)
        elif dataset == "cifar":
            train = CIFAR10(root="./data", train=True, transform=data_transform, download=True)
            test = CIFAR10(root="./data", train=False, transform=data_transform, download=True)
        elif dataset == "stl":
            train = STL10(root="./data", split="unlabeled", transform=data_transform, download=True)
            test = STL10(root="./data", split="test", transform=data_transform, download=True)
        else:
            raise ValueError(f"unsupported dataset {dataset!r}")
        train_loader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=0)
        test_loader = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True, num_workers=0)
        return train_loader, test_loader

    def mixup_data(self, x, alpha):
        """Blend each sample of ``x`` with a randomly permuted partner (mixup).

        Args:
            x: batch tensor; the permutation is along dim 0.
            alpha: Beta(alpha, alpha) concentration. ``alpha <= 0`` disables
                mixing (lam = 1).

        Returns:
            ``(mixed_x, target)``, where ``target`` is whichever source batch
            received the larger mixing weight.
        """
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1
        # Build the permutation on x's own device. The old code put it on
        # self.device while x was still on the CPU, and indexing a CPU tensor
        # with a CUDA index raises.
        index = torch.randperm(x.size(0), device=x.device)
        x2 = x[index]
        mixed_x = lam * x + (1 - lam) * x2
        if lam >= 0.5:
            return mixed_x, x
        return mixed_x, x2

    # Model
    def init_model(self):
        """Move the model to ``self.device``, then create optimizer, loss and LR scheduler."""
        # Move first so the optimizer is built over parameters already on the
        # target device.
        self.to(self.device)
        if self.device == "cuda":
            torch.backends.cudnn.benchmark = True
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)
        self.criterion = nn.MSELoss()
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=150, eta_min=5e-5)

    # Train
    def fit_train(self, epoch):
        """Run one training epoch with mixup; append the per-sample mean loss to history."""
        self.train()
        print(f"\nEpoch: {epoch+1:d} {datetime.datetime.now()}")
        train_loss = 0
        samples_cnt = 0
        for batch_idx, (inputs, _) in enumerate(self.train_loader):
            # mixup on train
            mixup_x, true_x = self.mixup_data(inputs, 0.2)
            mixup_x, true_x = mixup_x.to(self.device), true_x.to(self.device)
            self.optimizer.zero_grad()
            recon_batch = self(mixup_x)
            loss = self.criterion(recon_batch, true_x)
            loss.backward()
            self.optimizer.step()
            # MSELoss returns a batch mean. Weight it by batch size so that
            # dividing by samples_cnt yields a true per-sample mean, comparable
            # across batch sizes (and with val_loss).
            train_loss += loss.item() * inputs.size(0)
            samples_cnt += inputs.size(0)
            if batch_idx % 50 == 0:
                print(batch_idx, len(self.train_loader), f"Loss: {train_loss/samples_cnt:f}")
        self.history["loss"].append(train_loss / samples_cnt)

    def test(self, epoch):
        """Evaluate on the test set (no mixup) and save reconstruction/sample images.

        Writes ``reconstruction_epoch_{epoch}.png`` (first test batch),
        ``sampling_bootstrap_epoch_{epoch}.png`` and
        ``sampling_random_epoch_{epoch}.png`` into ``self.model_name``.
        """
        self.eval()
        val_loss = 0
        samples_cnt = 0
        with torch.no_grad():
            for batch_idx, (inputs, _) in enumerate(self.test_loader):
                # no mixup on test time
                inputs = inputs.to(self.device)
                recon_batch = self(inputs)
                # Weight the batch-mean loss by batch size (see fit_train).
                val_loss += self.criterion(recon_batch, inputs).item() * inputs.size(0)
                samples_cnt += inputs.size(0)
                if batch_idx == 0:
                    save_image(recon_batch, f"{self.model_name}/reconstruction_epoch_{str(epoch)}.png", nrow=8)
        print(batch_idx, len(self.test_loader), f"ValLoss: {val_loss/samples_cnt:f}")
        self.history["val_loss"].append(val_loss / samples_cnt)
        # sampling: pure generation, no gradients needed
        with torch.no_grad():
            save_image(self.bootstrap_sampling(), f"{self.model_name}/sampling_bootstrap_epoch_{str(epoch)}.png", nrow=8)
            save_image(self.random_sampling(), f"{self.model_name}/sampling_random_epoch_{str(epoch)}.png", nrow=8)

    # save results
    def save_history(self):
        """Pickle ``self.history`` to ``{model_name}/{model_name}_history.dat``."""
        with open(f"{self.model_name}/{self.model_name}_history.dat", "wb") as fp:
            pickle.dump(self.history, fp)

    def save_to_zip(self):
        """Zip every file in the output directory into ``{model_name}.zip``."""
        with zipfile.ZipFile(f"{self.model_name}.zip", "w") as archive:
            for file in os.listdir(self.model_name):
                archive.write(f"{self.model_name}/{file}", file)

    def save_to_googledrive(self, drive_service):
        """Upload ``{model_name}.zip`` to Google Drive via ``drive_service``.

        Returns:
            The API response for the created file (only the ``id`` field is
            requested).
        """
        # Imported here: the module never imports googleapiclient at top level.
        import googleapiclient.http

        saving_filename = self.model_name + ".zip"
        file_metadata = {
            'name': saving_filename,
            'mimeType': 'application/octet-stream'
        }
        media = googleapiclient.http.MediaFileUpload(saving_filename,
                                                     mimetype='application/octet-stream',
                                                     resumable=True)
        created = drive_service.files().create(body=file_metadata,
                                               media_body=media,
                                               fields='id').execute()
        return created
def google_drive_init():
    """Authenticate the current user and return a Google Drive v3 service client.

    Requires the ``google.colab`` package, which is available inside Google
    Colab, plus ``googleapiclient``. Both are imported here because the module
    never imports them at top level; the original references raised NameError.
    """
    from google.colab import auth
    from googleapiclient.discovery import build

    auth.authenticate_user()
    return build('drive', 'v3')
def main(dataset="mnist", n_epochs=1):
    """Train the Shake-Shake autoencoder and upload the results to Google Drive.

    Args:
        dataset: dataset name passed to ``VAE`` (default ``"mnist"``).
        n_epochs: number of train/test epochs (default 1).
    """
    googleclient = google_drive_init()
    net = VAE(dataset)
    net.init_model()
    for i in range(n_epochs):
        net.fit_train(i)
        # Since PyTorch 1.1 the LR scheduler must step after optimizer.step()
        # (done inside fit_train). Stepping first skips the initial LR.
        net.scheduler.step()
        net.test(i)
    net.save_history()
    net.save_to_zip()
    net.save_to_googledrive(googleclient)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment