@piscisaureus
Created April 23, 2018 16:59
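// TensorFlow.js (Layers API) port of the MNIST generator and critic from a
// PyTorch WGAN-GP implementation; the hyperparameters below match the
// WGAN-GP setup of Gulrajani et al. 2017 ("Improved Training of Wasserstein
// GANs"), and the PyTorch originals are kept as comments above each layer.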
// tslint:disable:typedef
// tslint:disable:comment-format
import * as tf from "@tensorflow/tfjs";
const DIM = 64; // Model dimensionality.
const BATCH_SIZE = 50;
const CRITIC_ITERS = 5; // How many critic iterations per generator iteration
const LAMBDA = 10; // Gradient penalty lambda hyperparameter
const ITERS = 200000; // How many generator iterations to train for
const OUTPUT_DIM = 784; // Number of pixels in MNIST (28*28)
const generator = tf.sequential({ layers: [
// Preprocess:
// nn.Linear(128, 4*4*4*DIM),
// nn.ReLU(True),
tf.layers.dense({inputShape: [128], units: 4*4*4 * DIM}),
tf.layers.activation({activation: "relu"}),
// Reshape
// output = output.view(-1, 4*DIM, 4, 4)
// (tfjs layers default to channels-last, so the target shape is
// [4, 4, 4*DIM] rather than PyTorch's channels-first [4*DIM, 4, 4].)
tf.layers.reshape({ targetShape: [4, 4, 4 * DIM] }),
// Block 1:
// nn.ConvTranspose2d(4*DIM, 2*DIM, 5),
// nn.ReLU(True),
tf.layers.conv2dTranspose({filters: 2 * DIM, kernelSize: 5}),
tf.layers.activation({activation: "relu"}),
// Crop 8x8 -> 7x7, mirroring the PyTorch reference's
// output = output[:, :, :7, :7]; without this crop the network ends at
// 30x30 and the final reshape to OUTPUT_DIM (28*28) fails.
tf.layers.cropping2d({ cropping: [[0, 1], [0, 1]] }),
// Block 2:
// nn.ConvTranspose2d(2*DIM, DIM, 5),
// nn.ReLU(True),
tf.layers.conv2dTranspose({filters: DIM, kernelSize: 5}),
tf.layers.activation({activation: "relu"}),
// Deconv output:
// nn.ConvTranspose2d(DIM, 1, 8, stride=2)
tf.layers.conv2dTranspose({filters: 1, kernelSize: 8, strides: 2}),
// Finalize:
// output = self.sigmoid(output)
// output = output.view(-1, OUTPUT_DIM)
tf.layers.activation({ activation: "sigmoid" }),
tf.layers.reshape({ targetShape: [OUTPUT_DIM] })
]});
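
// ---------------------------------------------------------------------------
// Hedged sketch, not in the original gist: sampling the generator. Latent
// vectors are 128-dim (matching the dense layer's inputShape above); the
// result is an [n, OUTPUT_DIM] batch of flat images in [0, 1].
// ---------------------------------------------------------------------------
function sampleGenerator(n = BATCH_SIZE): tf.Tensor2D {
  return tf.tidy(() =>
    generator.predict(tf.randomNormal([n, 128])) as tf.Tensor2D);
}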
const discriminator = tf.sequential({ layers: [
// Reshape input:
// input.view(-1, 1, 28, 28)
// (channels-last, with an explicit inputShape so tf.sequential can build)
tf.layers.reshape({ inputShape: [OUTPUT_DIM], targetShape: [28, 28, 1] }),
// Main block:
// nn.Conv2d(1, DIM, 5, stride=2, padding=2),
// nn.ReLU(True),
// nn.Conv2d(DIM, 2*DIM, 5, stride=2, padding=2),
// nn.ReLU(True),
// nn.Conv2d(2*DIM, 4*DIM, 5, stride=2, padding=2),
// nn.ReLU(True),
tf.layers.zeroPadding2d({ padding: 2 }),
tf.layers.conv2d({ filters: DIM, kernelSize: 5, strides: 2 }),
tf.layers.activation({activation: "relu"}),
tf.layers.zeroPadding2d({ padding: 2 }),
tf.layers.conv2d({ filters: 2 * DIM, kernelSize: 5, strides: 2 }),
tf.layers.activation({activation: "relu"}),
tf.layers.zeroPadding2d({ padding: 2 }),
tf.layers.conv2d({ filters: 4 * DIM, kernelSize: 5, strides: 2 }),
tf.layers.activation({activation: "relu"}),
// Reshape and fully connected layer
// output = output.view(-1, 4*4*4*DIM)
// nn.Linear(4*4*4*DIM, 1)
// (targetShape excludes the batch dimension, which PyTorch's view
// expresses with the leading -1)
tf.layers.reshape({ targetShape: [4 * 4 * 4 * DIM] }),
tf.layers.dense({ units: 1 })
]});
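
// ---------------------------------------------------------------------------
// Hedged sketch, not in the original gist, of where LAMBDA, BATCH_SIZE,
// CRITIC_ITERS and ITERS come in: the WGAN-GP critic objective of Gulrajani
// et al. 2017. `realBatch` is a hypothetical [BATCH_SIZE, OUTPUT_DIM] tensor
// of MNIST images; wiring the loss into an optimizer is left out here.
// ---------------------------------------------------------------------------

// Gradient penalty: push the critic's gradient toward unit L2 norm at points
// interpolated between real and fake samples.
function gradientPenalty(real: tf.Tensor2D, fake: tf.Tensor2D): tf.Scalar {
  return tf.tidy(() => {
    const alpha = tf.randomUniform([BATCH_SIZE, 1]);
    const interpolates =
      alpha.mul(real).add(tf.scalar(1).sub(alpha).mul(fake));
    // Summing the critic's outputs gives per-example input gradients, since
    // each example is scored independently.
    const gradFn = tf.grad((x: tf.Tensor) =>
      (discriminator.apply(x) as tf.Tensor).sum() as tf.Scalar);
    const grads = gradFn(interpolates);
    const norms = grads.square().sum(1).sqrt();
    return norms.sub(1).square().mean().mul(LAMBDA) as tf.Scalar;
  });
}

// Critic loss for one batch: D(fake) - D(real) + gradient penalty. The full
// algorithm minimizes this CRITIC_ITERS times per generator update, for
// ITERS generator updates.
function criticLoss(realBatch: tf.Tensor2D): tf.Scalar {
  return tf.tidy(() => {
    const fake = sampleGenerator();
    const dReal = (discriminator.apply(realBatch) as tf.Tensor).mean();
    const dFake = (discriminator.apply(fake) as tf.Tensor).mean();
    return dFake.sub(dReal)
      .add(gradientPenalty(realBatch, fake)) as tf.Scalar;
  });
}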