Navigation Menu

Skip to content

Instantly share code, notes, and snippets.

@t-ae
Last active July 23, 2019 07:39
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save t-ae/3cd33e4f0535b98c2df9bfeef49645e5 to your computer and use it in GitHub Desktop.
VAE on Swift for TensorFlow
// VAE by modifying official autoencoder code
// https://github.com/tensorflow/swift-models/blob/2fa11ba1d28ef09454af9da77e22b585cf3e5b7b/Autoencoder/main.swift
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Foundation
import TensorFlow
import Python
// Import Python modules
let matplotlib = Python.import("matplotlib")
let np = Python.import("numpy")
let plt = Python.import("matplotlib.pyplot")
// Turn off using display on server / linux
// NOTE(review): matplotlib backends are normally selected *before* importing
// pyplot; confirm "Agg" actually takes effect when set afterwards like this.
matplotlib.use("Agg")
// Some globals
let epochCount = 50  // number of passes over the training set
let batchSize = 128  // rows consumed per gradient step
let outputFolder = "./output/"  // destination for the PNGs written by plot(image:name:)
let imageHeight = 28, imageWidth = 28  // MNIST image dimensions (flattened to 784)
/// Renders a flattened grayscale image and saves it as
/// `<outputFolder><name>.png` at 300 dpi.
/// - Parameters:
///   - image: `imageHeight * imageWidth` pixel values in row-major order.
///   - name: file name (without extension) for the saved figure.
func plot(image: [Float], name: String) {
    // `withIntermediateDirectories: true` makes creation idempotent: it does
    // not throw when the directory (or a parent) already exists, which removes
    // both the crash when an intermediate path is missing and the race between
    // a separate fileExists check and the creation call the original had.
    try! FileManager.default.createDirectory(atPath: outputFolder,
                                             withIntermediateDirectories: true,
                                             attributes: nil)
    // Reshape the flat pixel buffer into a 2-D numpy array for imshow.
    let pixels = np.array([image]).reshape([imageHeight, imageWidth])
    let ax = plt.gca()
    ax.imshow(pixels, cmap: "gray")
    plt.savefig("\(outputFolder)\(name).png", dpi: 300)
    plt.close()
}
/// Reads a file into an array of bytes.
/// Reads a file into an array of bytes.
///
/// Searches a fixed list of candidate folders and returns the contents of the
/// first match; terminates the process if the file is found in none of them.
func readFile(_ filename: String) -> [UInt8] {
    let possibleFolders = [".", "Resources", "Autoencoder/Resources"]
    for folder in possibleFolders {
        let candidatePath = URL(fileURLWithPath: folder)
            .appendingPathComponent(filename)
            .path
        if !FileManager.default.fileExists(atPath: candidatePath) {
            continue
        }
        // Read via Python so the raw bytes land in a numpy buffer that can be
        // converted straight into a Swift array.
        let rawBytes = Python.open(candidatePath, "rb").read()
        return Array(numpy: np.frombuffer(rawBytes, dtype: np.uint8))!
    }
    print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).")
    exit(-1)
}
/// Reads MNIST images and labels from specified file paths.
/// Reads MNIST images and labels from specified file paths.
/// - Returns: images as `[rowCount, 784]` floats scaled into [0, 1], and the
///   matching integer class labels.
func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>,
labels: Tensor<Int32>) {
    print("Reading data.")
    // Drop the IDX headers: 16 bytes for the image file, 8 for the label file.
    let pixelScalars = readFile(imagesFile).dropFirst(16).map { Float($0) }
    let labelScalars = readFile(labelsFile).dropFirst(8).map { Int32($0) }
    let rowCount = labelScalars.count
    print("Constructing data tensors.")
    let imageTensor = Tensor(shape: [rowCount, imageHeight * imageWidth],
                             scalars: pixelScalars) / 255.0
    return (images: imageTensor, labels: Tensor(labelScalars))
}
/// Encoder half of the VAE: maps a flattened image (784 floats) to the mean
/// and log-variance of a diagonal Gaussian over the 4-dimensional latent
/// space.
struct Encoder: Layer {
    typealias Input = Tensor<Float>
    typealias Output = Encoded
    var encoder1 = Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128,
                                activation: relu)
    var encoder2 = Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
    var encoder3 = Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
    // Two parallel linear heads produce the Gaussian parameters; identity
    // activation so mean and log-variance can take any real value.
    var encoderMean = Dense<Float>(inputSize: 12, outputSize: 4, activation: identity)
    var encoderLogVar = Dense<Float>(inputSize: 12, outputSize: 4, activation: identity)
    @differentiable
    func callAsFunction(_ input: Input) -> Output {
        let hidden = input.sequenced(through: encoder1, encoder2, encoder3)
        return Encoded(mean: encoderMean(hidden), logVar: encoderLogVar(hidden))
    }
}
/// Output of `Encoder`: the mean and log-variance of the diagonal Gaussian
/// approximate posterior over the latent space.
struct Encoded: Differentiable {
var mean: Tensor<Float>
var logVar: Tensor<Float>
}
/// Decoder half of the VAE: maps a 4-dimensional latent vector back to a
/// flattened `imageHeight * imageWidth` image.
struct Decoder: Layer {
typealias Input = Tensor<Float>
typealias Output = Tensor<Float>
// Mirror of the encoder widths: 4 -> 12 -> 64 -> 128 -> 784.
var decoder1 = Dense<Float>(inputSize: 4, outputSize: 12, activation: relu)
var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
// NOTE(review): tanh outputs lie in (-1, 1) while the training pixels are
// scaled to [0, 1] in readMNIST — confirm this activation is intended
// (sigmoid would match the input range exactly).
var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,
activation: tanh)
/// Runs the latent code through all four dense layers.
@differentiable
func callAsFunction(_ input: Input) -> Output {
return input.sequenced(through: decoder1, decoder2, decoder3, decoder4)
}
}
/// The full variational autoencoder: encode, sample a latent vector with the
/// reparameterization trick, then decode.
struct VAE: Layer {
    typealias Input = Tensor<Float>
    typealias Output = VAEResult
    var encoder = Encoder()
    var decoder = Decoder()
    @differentiable
    func callAsFunction(_ input: Input) -> Output {
        let posterior = encoder(input)
        // Reparameterization: z = mean + eps * std keeps the stochastic
        // sampling step differentiable w.r.t. the encoder outputs.
        let eps = Tensor<Float>(randomNormal: posterior.mean.shape)
        let std = exp(posterior.logVar / 2)
        let z = posterior.mean + eps * std
        return VAEResult(image: decoder(z),
                         mean: posterior.mean,
                         logVar: posterior.logVar)
    }
}
/// One forward pass of `VAE`: the reconstructed image plus the posterior
/// parameters needed by the KL term of the loss.
struct VAEResult: Differentiable {
var image: Tensor<Float>
var mean: Tensor<Float>
var logVar: Tensor<Float>
}
/// Standard VAE objective: per-sample squared-error reconstruction plus the
/// analytic KL divergence between N(mean, exp(logVar)) and N(0, I), averaged
/// over the batch.
@differentiable
func loss(result: VAEResult, original: Tensor<Float>) -> Tensor<Float> {
    let reconstructionTerm = (result.image - original).squared().sum(alongAxes: 1)
    let klTerm = (1 + result.logVar - result.mean.squared() - exp(result.logVar))
        .sum(alongAxes: 1) * -0.5
    return (reconstructionTerm + klTerm).mean()
}
// MNIST data logic
/// Slices the `index`-th contiguous batch of `batchSize` rows out of `x`.
/// The caller is responsible for keeping the range in bounds.
func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int) -> Tensor<Scalar> {
    let lower = batchSize * index
    let upper = lower + batchSize
    return x[lower..<upper]
}
// Load the training set; the one-hot labels are only consulted for the row
// count when sizing the epoch below.
let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
                                        labelsFile: "train-labels-idx1-ubyte")
let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
var vae = VAE()
let optimizer = Adam(for: vae)
// Training loop: each epoch first saves one image's reconstruction as a
// qualitative probe, then runs minibatch gradient descent over the full set.
for epoch in 1...epochCount {
    // Probe with the epoch-th training image.
    let sampleImage = Tensor(shape: [1, imageHeight * imageWidth],
                             scalars: images[epoch].scalars)
    let sampleResult = vae(sampleImage)
    plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
    plot(image: sampleResult.image.scalars, name: "epoch-\(epoch)-output")
    let sampleLoss = loss(result: sampleResult, original: sampleImage)
    print("[Epoch: \(epoch)] Loss: \(sampleLoss)")
    let stepsPerEpoch = Int(labels.shape[0]) / batchSize
    for step in 0..<stepsPerEpoch {
        let batch = minibatch(in: images, at: step)
        let grads = vae.gradient { vae -> Tensor<Float> in
            loss(result: vae(batch), original: batch)
        }
        optimizer.update(&vae.allDifferentiableVariables, along: grads)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment