Skip to content

Instantly share code, notes, and snippets.

@vishnubob
Created March 31, 2018 20:44
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vishnubob/469147549f71911cf51f69d896bd1bc9 to your computer and use it in GitHub Desktop.
Save vishnubob/469147549f71911cf51f69d896bd1bc9 to your computer and use it in GitHub Desktop.
Simple Autoencoder with a useful PIL-based image tiler
# MXNet Autoencoder
# based on example from SherlockLiao
# https://github.com/SherlockLiao/mxnet-gluon-tutorial/blob/master/08-AutoEncoder/simple_autoencoder.py
import operator
import bisect
import os
import numpy as np
import mxnet as mx
from mxnet import gluon
from PIL import Image
def gcd(a, b):
    """Return the greatest common divisor of ``a`` and ``b``.

    Implemented with the iterative Euclidean algorithm. The signs of the
    arguments are ignored, so the result is never negative, and
    ``gcd(0, 0)`` is ``0``.
    """
    # Work on magnitudes so that a negative argument is still reduced
    # instead of short-circuiting the ``b > 0`` loop condition.
    a, b = abs(a), abs(b)
    while b > 0:
        a, b = b, a % b
    return a
def lcm(a, b):
    """Return the least common multiple of ``a`` and ``b``.

    Integer inputs give an exact integer result. Signs are ignored, and the
    result is ``0`` when either argument is ``0``.
    """
    a, b = abs(a), abs(b)
    if a == 0 or b == 0:
        # Also avoids dividing by gcd(0, 0) == 0.
        return 0
    # Divide before multiplying and use floor division so the result stays
    # an exact integer instead of a (possibly rounded) float.
    return a // gcd(a, b) * b
def find_res(cnt, ratio=.5):
    """Pick a grid layout (two factors of ``cnt``) for tiling ``cnt`` images.

    Every divisor ``w`` of ``cnt`` below ``cnt`` itself is a candidate; it is
    paired with ``h = cnt // w`` and scored by ``w / h``. The candidate chosen
    is the one with the smallest score that is still >= ``ratio``; if no
    score reaches ``ratio`` the candidate with the largest score is used.

    Args:
        cnt: number of images to arrange; must be a positive integer.
        ratio: target value for the ``w / h`` score.

    Returns:
        A two-element list ``[larger, smaller]`` of ints whose product is
        ``cnt``.

    Raises:
        ValueError: if ``cnt`` is smaller than 1.
    """
    if cnt < 1:
        raise ValueError('cnt must be a positive integer, got %r' % (cnt,))
    if cnt == 1:
        # A single image has no divisor below itself; without this guard the
        # candidate list would be empty and the lookup below would fail.
        return [1, 1]
    rats = []
    for width in range(1, cnt):
        if cnt % width:
            continue
        height = cnt // width
        rats.append((width / height, width, height))
    # Widths are visited in ascending order and the score grows with the
    # width, so ``rats`` is already sorted by score as bisect requires.
    keys = [rat[0] for rat in rats]
    idx = min(len(rats) - 1, bisect.bisect_left(keys, ratio))
    _, width, height = rats[idx]
    return [max(width, height), min(width, height)]
def norm_ip(img, min, max):
    """Clip ``img`` to ``[min, max]`` and rescale it linearly from 0.

    The divisor is ``max - min + 1e-5``; the small constant keeps the division
    defined when ``min == max``, so ``max`` maps to just under 1.
    """
    span = max - min + 1e-5
    clipped = np.clip(img, min, max)
    return (clipped - min) / span
def norm_range(t, range=None):
    """Normalise ``t`` with ``norm_ip`` over an explicit or data-derived range.

    When ``range`` is given, its first two items are used as the lower and
    upper bound; otherwise the bounds are ``t.min()`` and ``t.max()``.
    """
    if range is None:
        low, high = t.min(), t.max()
    else:
        low, high = range[0], range[1]
    return norm_ip(t, low, high)
def tile_image(imgs):
    """Arrange a batch of flattened square images into one 2-D mosaic.

    Args:
        imgs: array of shape ``(count, side * side)``, one flattened square
            image per row.

    Returns:
        A 2-D array built by laying the images out row by row on the grid
        chosen by ``find_res(count)``.

    Raises:
        ValueError: if the flattened image length is not a perfect square.
    """
    # based on code found at
    # https://stackoverflow.com/questions/26521365/cleanly-tile-numpy-array-of-images-stored-in-a-flattened-1d-format
    imgcnt = imgs.shape[0]
    # find_res returns [larger, smaller]: the larger factor is the number of
    # tiles per mosaic row, the smaller one the number of mosaic rows.
    (prow, pcol) = find_res(imgcnt)
    sqr = int(round(imgs.shape[1] ** .5))
    if sqr * sqr != imgs.shape[1]:
        raise ValueError(
            'expected flattened square images, got length %d' % imgs.shape[1])
    imgs = imgs.reshape(imgcnt, sqr, sqr)
    tiled = [
        np.hstack(list(imgs[start:start + prow]))
        for start in range(0, prow * pcol, prow)
    ]
    return np.vstack(tiled)
# Make sure the output directory for the saved preview images exists.
# exist_ok=True replaces the separate exists() check, which could race with
# another process creating the directory between the check and the mkdir.
os.makedirs('./mlp_img', exist_ok=True)
def to_img(x):
    """Map a batch from ``[-1, 1]`` to ``[0, 1]`` images of shape (N, 1, 28, 28).

    Values are shifted and scaled with ``0.5 * (x + 1)``, clipped to
    ``[0, 1]`` and reshaped so that each item becomes a single-channel
    28x28 image.

    Only ``clip``, ``reshape`` and ``shape`` are used, which MXNet NDArrays
    and NumPy arrays both provide (the previous ``clamp``/``view``/``size()``
    calls exist only on PyTorch tensors).
    """
    x = 0.5 * (x + 1)
    x = x.clip(0, 1)
    x = x.reshape((x.shape[0], 1, 28, 28))
    return x
# Training hyper-parameters.
batch_size = 128
num_epochs = 100
learning_rate = 0.001
# Device context used for the model parameters and for every batch.
ctx = mx.gpu()
def transform(data, label):
    """Convert a sample to float32, scaling ``data`` from [0, 255] to [-1, 1].

    ``data`` is divided by 255, then shifted by 0.5 and divided by 0.5;
    ``label`` is only cast to float32.
    """
    pixels = data.astype('float32') / 255
    pixels = (pixels - 0.5) / 0.5
    return pixels, label.astype('float32')
# Training split of Fashion-MNIST; `transform` is applied to every sample.
# To train on the classic digits instead, swap in:
#mnist_train = gluon.data.vision.MNIST(train=True, transform=transform)
mnist_train = gluon.data.vision.FashionMNIST(
    train=True,
    transform=transform,
)
# Shuffled mini-batches of `batch_size` samples.
dataloader = gluon.data.DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
)
class autoencoder(gluon.Block):
    """Fully connected autoencoder with a 3-unit bottleneck.

    The encoder stacks Dense layers of 128, 64, 12 and 3 units (ReLU on all
    but the last); the decoder stacks 12, 64, 128 and 28 * 28 units (ReLU,
    with tanh on the final layer).
    """

    def __init__(self):
        super(autoencoder, self).__init__()
        # (units, activation) for each Dense layer, in construction order.
        encoder_layers = ((128, 'relu'), (64, 'relu'), (12, 'relu'), (3, None))
        decoder_layers = ((12, 'relu'), (64, 'relu'), (128, 'relu'),
                          (28 * 28, 'tanh'))
        with self.name_scope():
            self.encoder = gluon.nn.Sequential('encoder_')
            with self.encoder.name_scope():
                for units, activation in encoder_layers:
                    self.encoder.add(
                        gluon.nn.Dense(units, activation=activation))
            self.decoder = gluon.nn.Sequential('decoder_')
            with self.decoder.name_scope():
                for units, activation in decoder_layers:
                    self.decoder.add(
                        gluon.nn.Dense(units, activation=activation))

    def forward(self, x):
        """Encode ``x`` and return its reconstruction from the decoder."""
        return self.decoder(self.encoder(x))
def autoencode():
    """Train the autoencoder and periodically save a reconstruction preview.

    For each of ``num_epochs`` epochs the model is fitted on every batch of
    ``dataloader`` with an L2 reconstruction loss and the Adam optimizer, and
    the summed loss divided by the number of samples is printed. Every tenth
    epoch (starting with epoch 0) the output of the last batch is tiled into
    one grayscale image and written to ``./mlp_img``. After the last epoch
    the parameters are written to ``./handwritten-digits-autoencoder.params``.
    """
    model = autoencoder()
    model.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
    criterion = gluon.loss.L2Loss()
    optimizer = gluon.Trainer(model.collect_params(), 'adam',
                              {'learning_rate': learning_rate,
                               'wd': 1e-5})
    for epoch in range(num_epochs):
        running_loss = 0.0
        n_total = 0.0
        for img, _ in dataloader:
            # Flatten every image to a vector and move it to the device.
            img = img.reshape((img.shape[0], -1)).as_in_context(ctx)
            with mx.autograd.record():
                output = model(img)
                loss = criterion(output, img)
            loss.backward()
            optimizer.step(img.shape[0])
            running_loss += mx.nd.sum(loss).asscalar()
            n_total += img.shape[0]
        # ===================log========================
        print('epoch [{}/{}], loss:{:.4f}'
              .format(epoch + 1, num_epochs, running_loss / n_total))
        if epoch % 10 == 0:
            fn = './mlp_img/{}_autoencoder.png'.format(epoch)
            # Min-max scale the last batch's reconstructions to [0, 1) before
            # converting to 8-bit; scaling with "+ minv" produced negative
            # values that wrapped around in uint8.
            im = norm_range(output.asnumpy())
            im = (im * 255.0).astype(np.uint8)
            im = tile_image(im)
            # PIL sizes are (width, height); save the mosaic at 2x scale.
            sz = (int(im.shape[1]) * 2, int(im.shape[0]) * 2)
            im = Image.fromarray(im, mode='L').resize(sz)
            im.save(fn)
    # NOTE(review): newer MXNet releases appear to deprecate save_params in
    # favour of save_parameters - confirm against the installed version.
    model.save_params('./handwritten-digits-autoencoder.params')
# Start training only when this file is executed as a script.
if __name__ == '__main__':
    autoencode()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment