Vanilla autoencoder using Shake-Shake regularization (failure)
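A PyTorch convolutional autoencoder for MNIST, Fashion-MNIST, CIFAR-10, and STL-10. Every encoder/decoder block runs two parallel convolution paths combined with Shake-Shake regularization, mixup is applied to the training inputs, and sample images are generated both by uniform random sampling of the latent space and by a feature-wise bootstrap of cached latent codes. The `shake_shake_function` module it imports is not part of the gist; a hedged stand-in sketch is included after the imports below.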
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision.datasets import MNIST, FashionMNIST, CIFAR10, STL10
import os
import pickle
import zipfile
import datetime
import numpy as np

try:
    # external helper module (not included in this gist)
    from shake_shake_function import get_alpha_beta, shake_function
except ImportError:
    get_alpha_beta = shake_function = None  # a stand-in sketch is defined below
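# --------------------------------------------------------------------------
# `shake_shake_function` is not included in this gist. The fallback below is
# a minimal sketch, an ASSUMPTION based on the Shake-Shake paper (Gastaldi,
# 2017): the forward pass mixes the two branches with a random alpha, and
# the backward pass re-mixes the gradients with an independent random beta
# (an even 0.5/0.5 average is used at eval time). The `shake_config` tuple
# is assumed to mean (shake_forward, shake_backward, shake_per_image).
# --------------------------------------------------------------------------
if shake_function is None:
    class _ShakeShake(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x1, x2, alpha, beta):
            ctx.save_for_backward(beta)
            return alpha * x1 + (1 - alpha) * x2

        @staticmethod
        def backward(ctx, grad_output):
            beta, = ctx.saved_tensors
            # gradients flow through both branches, re-weighted by beta
            return beta * grad_output, (1 - beta) * grad_output, None, None

    shake_function = _ShakeShake.apply

    def get_alpha_beta(batch_size, shake_config, is_cuda):
        forward_shake, backward_shake, shake_image = shake_config
        size = (batch_size, 1, 1, 1) if shake_image else (1, 1, 1, 1)
        alpha = torch.rand(size) if forward_shake else torch.full(size, 0.5)
        beta = torch.rand(size) if backward_shake else torch.full(size, 0.5)
        if is_cuda:
            alpha, beta = alpha.cuda(), beta.cuda()
        return alpha, beta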
class EncoderModule(nn.Module):
    def __init__(self, input_channels, output_channels, stride, kernel, pad):
        super().__init__()
        # two parallel convolution paths for Shake-Shake
        self.conv1 = nn.Conv2d(input_channels, output_channels, kernel_size=kernel, padding=pad, stride=stride)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.conv2 = nn.Conv2d(input_channels, output_channels, kernel_size=kernel, padding=pad, stride=stride)
        self.bn2 = nn.BatchNorm2d(output_channels)

    def forward(self, x):
        path1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
        path2 = F.relu(self.bn2(self.conv2(x)), inplace=True)
        # Shake-Shake: random interpolation while training, even average at eval time
        if self.training:
            shake_config = (True, True, True)
        else:
            shake_config = (False, False, False)
        alpha, beta = get_alpha_beta(x.size(0), shake_config, x.is_cuda)
        return shake_function(path1, path2, alpha, beta)

class Encoder(nn.Module):
    def __init__(self, color_channels, pooling_kernels, n_neurons_in_middle_layer):
        super().__init__()
        self.n_neurons_in_middle_layer = n_neurons_in_middle_layer
        self.bottle = EncoderModule(color_channels, 32, stride=1, kernel=1, pad=0)
        self.m1 = EncoderModule(32, 64, stride=1, kernel=3, pad=1)
        self.m2 = EncoderModule(64, 128, stride=pooling_kernels[0], kernel=3, pad=1)
        self.m3 = EncoderModule(128, 256, stride=pooling_kernels[1], kernel=3, pad=1)

    def forward(self, x):
        out = self.m3(self.m2(self.m1(self.bottle(x))))
        return out.view(-1, self.n_neurons_in_middle_layer)

class DecoderModule(nn.Module):
    def __init__(self, input_channels, output_channels, stride, activation="relu"):
        super().__init__()
        # two parallel transposed-convolution paths for Shake-Shake
        self.convt1 = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=stride, stride=stride)
        self.bn1 = nn.BatchNorm2d(output_channels)
        self.convt2 = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=stride, stride=stride)
        self.bn2 = nn.BatchNorm2d(output_channels)
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "sigmoid":
            self.activation = nn.Sigmoid()

    def forward(self, x):
        path1 = self.activation(self.bn1(self.convt1(x)))
        path2 = self.activation(self.bn2(self.convt2(x)))
        # Shake-Shake: random interpolation while training, even average at eval time
        if self.training:
            shake_config = (True, True, True)
        else:
            shake_config = (False, False, False)
        alpha, beta = get_alpha_beta(x.size(0), shake_config, x.is_cuda)
        return shake_function(path1, path2, alpha, beta)

class Decoder(nn.Module):
    def __init__(self, color_channels, pooling_kernels, decoder_input_size):
        super().__init__()
        self.decoder_input_size = decoder_input_size
        self.m1 = DecoderModule(256, 128, stride=1)
        self.m2 = DecoderModule(128, 64, stride=pooling_kernels[1])
        self.m3 = DecoderModule(64, 32, stride=pooling_kernels[0])
        self.bottle = DecoderModule(32, color_channels, stride=1, activation="sigmoid")

    def forward(self, x):
        out = x.view(-1, 256, self.decoder_input_size, self.decoder_input_size)
        out = self.m3(self.m2(self.m1(out)))
        return self.bottle(out)

class VAE(nn.Module):
    # Despite the class name, this is a vanilla autoencoder (no KL term or reparametrization)
    def __init__(self, dataset):
        super().__init__()
        assert dataset in ["mnist", "fashion-mnist", "cifar", "stl"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # latent features
        self.n_latent_features = 64
        # resolution
        # mnist, fashion-mnist : 28 -> 14 -> 7
        # cifar : 32 -> 8 -> 4
        # stl : 96 -> 24 -> 6
        if dataset in ["mnist", "fashion-mnist"]:
            pooling_kernel = [2, 2]
            encoder_output_size = 7
        elif dataset == "cifar":
            pooling_kernel = [4, 2]
            encoder_output_size = 4
        elif dataset == "stl":
            pooling_kernel = [4, 4]
            encoder_output_size = 6
        # color channels
        if dataset in ["mnist", "fashion-mnist"]:
            color_channels = 1
        else:
            color_channels = 3
        # neurons in the middle layer
        n_neurons_middle_layer = 256 * encoder_output_size * encoder_output_size
        # Encoder
        self.encoder = Encoder(color_channels, pooling_kernel, n_neurons_middle_layer)
        # Middle
        self.fc1 = nn.Linear(n_neurons_middle_layer, self.n_latent_features)
        self.fc2 = nn.Linear(self.n_latent_features, n_neurons_middle_layer)
        # Decoder
        self.decoder = Decoder(color_channels, pooling_kernel, encoder_output_size)
        # Latent values cache (used for bootstrap sampling)
        self.latent_cache = None
        # data
        self.train_loader, self.test_loader = self.load_data(dataset)
        # history
        self.history = {"loss": [], "val_loss": []}
        # model name
        self.model_name = dataset
        if not os.path.exists(self.model_name):
            os.mkdir(self.model_name)
    def bootstrap_sampling(self):
        if self.latent_cache is None:
            # no cache yet: sample uniformly from [-1, 1], the range of tanh
            z = (torch.rand(64, self.n_latent_features) * 2 - 1).to(self.device)
        else:
            # feature-wise bootstrap: resample each latent dimension
            # independently from the cached codes of the last batch
            cache = self.latent_cache.detach().cpu().numpy()
            z = np.zeros((64, self.n_latent_features), dtype="float32")
            for i in range(self.n_latent_features):
                index = np.random.choice(cache.shape[0], 64)
                z[:, i] = cache[index, i]
            z = torch.tensor(z).to(self.device)
        z = F.relu(self.fc2(z))
        # decode
        return self.decoder(z)

    def random_sampling(self):
        z = (torch.rand(64, self.n_latent_features) * 2 - 1).to(self.device)
        z = F.relu(self.fc2(z))
        return self.decoder(z)
    def forward(self, x):
        # Encoder
        h = self.encoder(x)
        # Bottleneck
        latent = torch.tanh(self.fc1(h))  # F.tanh is deprecated
        self.latent_cache = latent  # cache latent codes for bootstrap sampling
        # Decoder
        z = F.relu(self.fc2(latent))
        return self.decoder(z)
    # Data
    def load_data(self, dataset):
        data_transform = transforms.Compose([
            transforms.ToTensor()
        ])
        if dataset == "mnist":
            train = MNIST(root="./data", train=True, transform=data_transform, download=True)
            test = MNIST(root="./data", train=False, transform=data_transform, download=True)
        elif dataset == "fashion-mnist":
            train = FashionMNIST(root="./data", train=True, transform=data_transform, download=True)
            test = FashionMNIST(root="./data", train=False, transform=data_transform, download=True)
        elif dataset == "cifar":
            train = CIFAR10(root="./data", train=True, transform=data_transform, download=True)
            test = CIFAR10(root="./data", train=False, transform=data_transform, download=True)
        elif dataset == "stl":
            train = STL10(root="./data", split="unlabeled", transform=data_transform, download=True)
            test = STL10(root="./data", split="test", transform=data_transform, download=True)
        train_loader = torch.utils.data.DataLoader(train, batch_size=128, shuffle=True, num_workers=0)
        test_loader = torch.utils.data.DataLoader(test, batch_size=64, shuffle=True, num_workers=0)
        return train_loader, test_loader
    def mixup_data(self, x, alpha):
        # mixup: blend each sample with a randomly shuffled partner
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1
        batch_size = x.size(0)
        # keep the permutation on the same device as x (x is still on the CPU here)
        index = torch.randperm(batch_size, device=x.device)
        x2 = x[index, :]
        mixed_x = lam * x + (1 - lam) * x2
        # use the dominant component as the reconstruction target
        if lam >= 0.5:
            return mixed_x, x
        else:
            return mixed_x, x2
    # Model
    def init_model(self):
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)
        self.criterion = nn.MSELoss()
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=150, eta_min=5e-5)
        if self.device == "cuda":
            torch.backends.cudnn.benchmark = True
        self.to(self.device)
    # Train
    def fit_train(self, epoch):
        self.train()
        print(f"\nEpoch: {epoch+1:d} {datetime.datetime.now()}")
        train_loss = 0
        samples_cnt = 0
        for batch_idx, (inputs, _) in enumerate(self.train_loader):
            # mixup on train
            mixup_x, true_x = self.mixup_data(inputs, 0.2)
            mixup_x, true_x = mixup_x.to(self.device), true_x.to(self.device)
            self.optimizer.zero_grad()
            recon_batch = self(mixup_x)
            loss = self.criterion(recon_batch, true_x)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            samples_cnt += inputs.size(0)
            if batch_idx % 50 == 0:
                print(batch_idx, len(self.train_loader), f"Loss: {train_loss/samples_cnt:f}")
        self.history["loss"].append(train_loss / samples_cnt)
    def test(self, epoch):
        self.eval()
        val_loss = 0
        samples_cnt = 0
        with torch.no_grad():
            for batch_idx, (inputs, _) in enumerate(self.test_loader):
                # no mixup at test time
                inputs = inputs.to(self.device)
                recon_batch = self(inputs)
                val_loss += self.criterion(recon_batch, inputs).item()
                samples_cnt += inputs.size(0)
                if batch_idx == 0:
                    save_image(recon_batch, f"{self.model_name}/reconstruction_epoch_{str(epoch)}.png", nrow=8)
            print(batch_idx, len(self.test_loader), f"ValLoss: {val_loss/samples_cnt:f}")
            self.history["val_loss"].append(val_loss / samples_cnt)
            # sampling
            save_image(self.bootstrap_sampling(), f"{self.model_name}/sampling_bootstrap_epoch_{str(epoch)}.png", nrow=8)
            save_image(self.random_sampling(), f"{self.model_name}/sampling_random_epoch_{str(epoch)}.png", nrow=8)
    # save results
    def save_history(self):
        with open(f"{self.model_name}/{self.model_name}_history.dat", "wb") as fp:
            pickle.dump(self.history, fp)

    def save_to_zip(self):
        with zipfile.ZipFile(f"{self.model_name}.zip", "w") as zf:
            for file in os.listdir(self.model_name):
                zf.write(f"{self.model_name}/{file}", file)

    def save_to_googledrive(self, drive_service):
        # Google Colab only: requires the google-api-python-client package
        import googleapiclient.http
        saving_filename = self.model_name + ".zip"
        file_metadata = {
            "name": saving_filename,
            "mimeType": "application/octet-stream"
        }
        media = googleapiclient.http.MediaFileUpload(saving_filename,
                                                     mimetype="application/octet-stream",
                                                     resumable=True)
        drive_service.files().create(body=file_metadata,
                                     media_body=media,
                                     fields="id").execute()

def google_drive_init():
    # Google Colab only: authenticate and build a Drive v3 client
    from google.colab import auth
    import googleapiclient.discovery
    auth.authenticate_user()
    return googleapiclient.discovery.build("drive", "v3")

#from torch_summary import summary

def main():
    googleclient = google_drive_init()
    net = VAE("mnist")
    #summary(net, (1, 28, 28))
    #exit()
    net.init_model()
    for i in range(1):
        net.fit_train(i)
        net.test(i)
        # step the LR scheduler once per epoch, after the optimizer updates
        net.scheduler.step()
    net.save_history()
    net.save_to_zip()
    net.save_to_googledrive(googleclient)

if __name__ == "__main__":
    main()
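Per epoch, the script writes reconstructions plus bootstrap and random samples as PNG grids into a folder named after the dataset, pickles the loss history there, zips the folder, and uploads the archive to Google Drive (the Drive steps assume a Colab runtime).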