Created December 9, 2017 00:17
Denoising autoencoder (PyTorch, CUDA)
import os
import torch
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import save_image

# Output directories: reconstructions go to ./mlp_img, first-layer filter
# visualisations go to ./filters (the original script only created ./mlp_img,
# which makes the later save_image call into ./filters fail).
if not os.path.exists('./mlp_img'):
    os.mkdir('./mlp_img')
if not os.path.exists('./filters'):
    os.mkdir('./filters')


def to_img(x):
    # Reshape a flat (batch, 784) tensor back into (batch, 1, 28, 28) images.
    x = x.view(x.size(0), 1, 28, 28)
    return x


num_epochs = 20
batch_size = 128
learning_rate = 1e-3


def add_noise(img):
    # Corrupt the input with zero-mean Gaussian noise (std 0.2).
    noise = torch.randn(img.size()) * 0.2
    noisy_img = img + noise
    return noisy_img


def plot_sample_img(img, name):
    img = img.view(1, 28, 28)
    save_image(img, './sample_{}.png'.format(name))


def min_max_normalization(tensor, min_value, max_value):
    # Linearly rescale a tensor into [min_value, max_value].
    min_tensor = tensor.min()
    tensor = tensor - min_tensor
    max_tensor = tensor.max()
    tensor = tensor / max_tensor
    tensor = tensor * (max_value - min_value) + min_value
    return tensor


def tensor_round(tensor):
    return torch.round(tensor)


# Binarize MNIST: scale to [0, 1], then round each pixel to 0 or 1.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda tensor: min_max_normalization(tensor, 0, 1)),
    transforms.Lambda(lambda tensor: tensor_round(tensor))
])

dataset = MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


class autoencoder(nn.Module):
    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 256),
            nn.ReLU(True),
            nn.Linear(256, 64),
            nn.ReLU(True))
        self.decoder = nn.Sequential(
            nn.Linear(64, 256),
            nn.ReLU(True),
            nn.Linear(256, 28 * 28),
            nn.Sigmoid())

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


model = autoencoder().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(
    model.parameters(), lr=learning_rate, weight_decay=1e-5)

for epoch in range(num_epochs):
    for data in dataloader:
        img, _ = data
        img = img.view(img.size(0), -1)
        noisy_img = add_noise(img)
        noisy_img = Variable(noisy_img).cuda()
        img = Variable(img).cuda()
        # ===================forward=====================
        output = model(noisy_img)
        loss = criterion(output, img)
        MSE_loss = nn.MSELoss()(output, img)
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ===================log=========================
    # (on PyTorch >= 0.4, use loss.item() instead of loss.data[0])
    print('epoch [{}/{}], loss:{:.4f}, MSE_loss:{:.4f}'
          .format(epoch + 1, num_epochs, loss.data[0], MSE_loss.data[0]))
    if epoch % 10 == 0:
        x = to_img(img.cpu().data)
        x_hat = to_img(output.cpu().data)
        x_noisy = to_img(noisy_img.cpu().data)
        weights = to_img(model.encoder[0].weight.cpu().data)
        save_image(x, './mlp_img/x_{}.png'.format(epoch))
        save_image(x_hat, './mlp_img/x_hat_{}.png'.format(epoch))
        save_image(x_noisy, './mlp_img/x_noisy_{}.png'.format(epoch))
        save_image(weights, './filters/epoch_{}.png'.format(epoch))

torch.save(model.state_dict(), './sim_autoencoder.pth')
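
To reuse the saved checkpoint later, a minimal inference sketch (untested, reusing the helpers and dataloader defined above and the same file name passed to torch.save) could look like:

# Reload the trained weights and denoise one batch from the dataloader.
model = autoencoder().cuda()
model.load_state_dict(torch.load('./sim_autoencoder.pth'))
model.eval()

img, _ = next(iter(dataloader))
img = img.view(img.size(0), -1)
noisy_img = Variable(add_noise(img)).cuda()
denoised = model(noisy_img)
save_image(to_img(denoised.cpu().data), './mlp_img/denoised_sample.png')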
Thanks for the code; it works really nicely.
I just want to note that ToTensor already normalizes the image to the range [0, 1], so the normalization lambda is not needed.
The documentation is below, unless I am thinking of something else:
Converts a PIL Image or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
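If that is right, the transform could probably be simplified to keep only the binarization step; a small untested variation of the img_transform above:

img_transform = transforms.Compose([
    transforms.ToTensor(),                       # PIL image -> FloatTensor in [0.0, 1.0]
    transforms.Lambda(lambda t: torch.round(t))  # binarize pixels to {0, 1}
])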
Hi, nice code. However, do you know how to share the transpose of the encoder's weight matrix with the decoder (tied weights)? Thanks in advance.
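One possible way is to drop the separate decoder Linear layer and apply the encoder's weight matrix transposed via F.linear, learning only a separate decoder bias. A rough, untested sketch (TiedAutoencoder is a made-up name, not from the gist above):

import torch
from torch import nn
import torch.nn.functional as F

class TiedAutoencoder(nn.Module):
    def __init__(self, in_dim=28 * 28, hidden_dim=64):
        super(TiedAutoencoder, self).__init__()
        self.encoder = nn.Linear(in_dim, hidden_dim)
        # The decoder reuses encoder.weight transposed, so only a bias is learned here.
        self.decoder_bias = nn.Parameter(torch.zeros(in_dim))

    def forward(self, x):
        h = F.relu(self.encoder(x))
        # F.linear(h, W, b) computes h @ W.t() + b; passing W = encoder.weight.t()
        # makes the decoder's weight matrix the transpose of the encoder's.
        x_hat = torch.sigmoid(F.linear(h, self.encoder.weight.t(), self.decoder_bias))
        return x_hat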