"""Single-script Blocks reproduction of the Goodfellow et al. (2014) MNIST GAN.
"""
from collections import OrderedDict
from blocks.algorithms import GradientDescent, Momentum
from blocks.bricks import (
MLP, Rectifier, Logistic, Linear, LinearMaxout, Sequence, Random
)
from blocks.extensions import Printing, Timing
from blocks.extensions.training import SharedVariableModifier
from blocks.extensions.monitoring import (DataStreamMonitoring,
TrainingDataMonitoring)
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_dropout
from blocks.main_loop import MainLoop
from blocks.initialization import Uniform, Constant
from blocks.roles import INPUT
from blocks.select import Selector
from fuel.datasets import MNIST
from fuel.transformers import Flatten
from fuel.streams import DataStream
from fuel.schemes import SequentialScheme
import numpy
import theano
from theano import tensor
from theano.tensor.nnet import binary_crossentropy
if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Data: the 60k-example MNIST 'train' split is divided into the first
    # 50k examples for training and the last 10k as a held-out set.
    # Fuel's ``subset`` accepts a slice (or list); a slice also works on
    # Python 3, where ``range`` is not a list.
    # ------------------------------------------------------------------
    train = MNIST(which_sets=('train',), sources=('features',),
                  subset=slice(0, 50000))
    valid = MNIST(which_sets=('train',), sources=('features',),
                  subset=slice(50000, 60000))
    batch_size = 100

    # ------------------------------------------------------------------
    # Generator: noise_dim -> 1200 -> 1200 -> 784 with ReLU hidden layers
    # and a logistic output, so every generated pixel lies in (0, 1).
    # ------------------------------------------------------------------
    noise_dim = 100
    generator = MLP([Rectifier(), Rectifier(), Logistic()],
                    [noise_dim, 1200, 1200, 784],
                    weights_init=Uniform(width=0.1),
                    biases_init=Constant(0.0),
                    name='generator_mlp')

    # Per-pixel mean intensity over the training subset. Dividing by 255
    # assumes the raw features are 0-255 pixel values. Clipping keeps the
    # log-odds below finite.
    raw_train_data, = train.get_data(request=slice(0, train.num_examples))
    float_train_data = raw_train_data / 255.
    marginals = numpy.clip(
        float_train_data.reshape(float_train_data.shape[0], -1).mean(axis=0),
        1e-7, 1 - 1e-7)

    # Zero the output layer's weights and set its bias to the log-odds of
    # the pixel marginals. Since sigmoid(log(p / (1 - p))) == p, the
    # untrained generator outputs the mean training image for any noise.
    generator.push_initialization_config()
    generator_last_linear = generator.linear_transformations[-1]
    generator_last_linear.weights_init = Constant(0)
    generator.initialize()
    log_odds = (numpy.log(marginals / (1 - marginals))
                .astype(theano.config.floatX))
    generator_last_linear.b.set_value(log_odds)

    # ------------------------------------------------------------------
    # Discriminator: two maxout layers (240 units, 5 pieces each) and a
    # single logistic unit estimating P(input is a real image).
    # ------------------------------------------------------------------
    d0 = LinearMaxout(784, 240, 5, name='d0',
                      weights_init=Uniform(width=0.01),
                      biases_init=Constant(0.0))
    d1 = LinearMaxout(240, 240, 5, name='d1',
                      weights_init=Uniform(width=0.01),
                      biases_init=Constant(0.0))
    d2linear = Linear(240, 1, weights_init=Constant(0.01),
                      name='d2linear',
                      biases_init=Constant(0))
    d2logistic = Logistic(name='d2logistic')
    discriminator = Sequence([d0.apply, d1.apply,
                              d2linear.apply, d2logistic.apply],
                             name='discriminator')
    discriminator.initialize()

    # ------------------------------------------------------------------
    # Graph.
    # ------------------------------------------------------------------
    features = tensor.matrix('features')
    disc_pred = discriminator.apply(features)

    # Uniform noise on [-sqrt(3), sqrt(3)] has zero mean and unit variance.
    # One noise vector is drawn per real example in the minibatch, and its
    # width must equal the generator's input dimension.
    # NOTE(review): the extra 1e-7 widening of the bounds is kept as-is;
    # its purpose is not evident from this script.
    SQRT3 = numpy.cast[theano.config.floatX](numpy.sqrt(3))
    noise = Random().theano_rng.uniform(low=-SQRT3 - 1e-7,
                                        high=SQRT3 + 1e-7,
                                        size=(features.shape[0], noise_dim))
    sample = generator.apply(noise)
    gen_pred = discriminator.apply(sample)

    # Discriminator: mean of the BCE for labelling samples as fake (0) and
    # data as real (1).
    disc_cost = 0.5 * (binary_crossentropy(gen_pred, 0).mean() +
                       binary_crossentropy(disc_pred, 1).mean())
    # Generator: the "non-saturating" objective. It minimises the BCE of
    # its samples against the *real* label instead of maximising the
    # discriminator's loss.
    gen_cost = binary_crossentropy(gen_pred, 1).mean()

    cg = ComputationGraph([gen_cost, disc_cost])
    # Dropout 0.5 on the inputs of d1's linear transform and of d2linear,
    # i.e. on the discriminator's hidden representations.
    upper_disc_inputs = VariableFilter(bricks=[d1.linear, d2linear],
                                       roles=[INPUT])(cg)
    dropped_cg = apply_dropout(cg, upper_disc_inputs, 0.5)
    # Dropout 0.2 on the discriminator's own inputs. The discriminator is
    # applied twice (to data and to samples), so this filter can match
    # both of its input variables.
    discriminator_inputs = VariableFilter(bricks=[discriminator],
                                          roles=[INPUT])(dropped_cg)
    final_cg = apply_dropout(dropped_cg, discriminator_inputs, 0.2)
    drop_gen_cost, drop_disc_cost = final_cg.outputs
    drop_gen_cost.name = 'gen_cost_with_dropout'
    drop_disc_cost.name = 'disc_cost_with_dropout'

    # Each player descends only its own objective, and only w.r.t. its own
    # parameters. Both updates are applied in the same training step.
    gen_params = list(Selector(generator).get_parameters().values())
    disc_params = list(Selector(discriminator).get_parameters().values())
    gen_grads = tensor.grad(drop_gen_cost, gen_params)
    disc_grads = tensor.grad(drop_disc_cost, disc_params)
    gradients = OrderedDict(zip(gen_params + disc_params,
                                gen_grads + disc_grads))
    rule = Momentum(0.1, 0.5)
    # GradientDescent requires a ``cost`` argument, but the explicitly
    # supplied ``gradients`` are what the step rule applies.
    algorithm = GradientDescent(cost=gen_cost + disc_cost,
                                step_rule=rule,
                                gradients=gradients)

    # ------------------------------------------------------------------
    # Data streams. Monitoring uses ``valid`` (examples 50000-59999),
    # which is disjoint from the training subset.
    # ------------------------------------------------------------------
    train_stream = Flatten(DataStream.default_stream(
        train,
        iteration_scheme=SequentialScheme(batch_size=batch_size,
                                          examples=train.num_examples)))
    valid_stream = Flatten(DataStream.default_stream(
        valid,
        iteration_scheme=SequentialScheme(batch_size=batch_size,
                                          examples=valid.num_examples)))

    # ------------------------------------------------------------------
    # Main loop.
    #  * Momentum ramps linearly from 0.5 to 0.7 over the first 125000
    #    iterations, then stays at 0.7. It is recomputed after each epoch.
    #  * The learning rate decays as 0.1 / 1.000004**iterations, floored at
    #    1e-6. It is recomputed after each batch.
    # NOTE(review): no stopping-criterion extension (e.g. FinishAfter) is
    # configured, so training presumably runs until interrupted — confirm.
    # ------------------------------------------------------------------
    main_loop = MainLoop(
        algorithm=algorithm, data_stream=train_stream,
        extensions=[
            TrainingDataMonitoring([drop_gen_cost, drop_disc_cost],
                                   prefix='train',
                                   after_epoch=True),
            DataStreamMonitoring(data_stream=valid_stream,
                                 variables=[drop_gen_cost, drop_disc_cost],
                                 prefix='valid',
                                 after_epoch=True),
            SharedVariableModifier(
                rule.momentum,
                lambda x: min(0.7, 0.5 + x * 0.2 / 125000),
                after_batch=False,
                after_epoch=True),
            SharedVariableModifier(
                rule.learning_rate,
                lambda x: max(1e-6, 0.1 / (1.000004 ** x)),
                after_batch=True),
            Timing(),
            Printing(),
        ])
    main_loop.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment