Skip to content

Instantly share code, notes, and snippets.

@sherjilozair
Last active August 29, 2015 14:01
Show Gist options
  • Save sherjilozair/d1cd2684b8dc94deb55c to your computer and use it in GitHub Desktop.
Save sherjilozair/d1cd2684b8dc94deb55c to your computer and use it in GitHub Desktop.
reboot.yaml
!obj:pylearn2.train.Train {
dataset: &train !obj:pylearn2.datasets.mnist.MNIST {
    # Training split: examples [0, 50000) of the official MNIST training set.
    # Anchored as &train so the same object can be reused by alias elsewhere.
    stop: 50000,
    start: 0,
    which_set: 'train'
},
model: !obj:galatea.adversarial.AdversaryPair {
    # Generator: an MLP with two 1200-unit rectified-linear hidden layers and
    # a sigmoid output layer whose output space is a 28x28, 1-channel image.
    # NOTE(review): layer 'y' gives no `dim` and the MLP gives no `nvis` /
    # `input_space`; presumably the Generator / Sigmoid implementation in use
    # derives them (e.g. from `output_space`) - confirm before relying on it.
    generator: !obj:galatea.adversarial.Generator {
        mlp: !obj:pylearn2.models.mlp.MLP {
            layers: [
                !obj:pylearn2.models.mlp.RectifiedLinear {
                    layer_name: 'h0',
                    dim: 1200,
                    irange: .05,
                    max_col_norm: 1.9365,
                },
                !obj:pylearn2.models.mlp.RectifiedLinear {
                    layer_name: 'h1',
                    dim: 1200,
                    irange: .05,
                    max_col_norm: 1.9365,
                },
                !obj:pylearn2.models.mlp.Sigmoid {
                    max_col_norm: 1.9365,
                    init_bias: !obj:pylearn2.models.dbm.init_sigmoid_bias_from_marginals { dataset: *train},
                    layer_name: 'y',
                    irange: .05,
                    output_space: !obj:pylearn2.space.Conv2DSpace {
                        shape: [28, 28],
                        num_channels: 1
                    }
                }
            ],
        }},
    # Discriminator: two convolutional rectified-linear layers and a single
    # sigmoid output unit, reading 28x28, 1-channel images.
    #
    # Every kernel and pool window must fit inside its layer's input: a
    # kernel larger than the 28x28 input gives a negative detector shape and
    # the layer cannot be built. With valid convolution (pylearn2's default
    # border_mode) the spatial shapes work out to:
    #   h0: 28x28 -> 5x5 conv -> 24x24 -> 4x4 pool, stride 2 -> 11x11
    #   h1: 11x11 -> 4x4 conv ->  8x8  -> 4x4 pool, stride 2 ->  3x3
    discriminator:
        !obj:pylearn2.models.mlp.MLP {
        input_space: !obj:pylearn2.space.Conv2DSpace {
            shape: [28, 28],
            num_channels: 1
        },
        layers: [
            !obj:pylearn2.models.mlp.ConvRectifiedLinear {
                #W_lr_scale: .1,
                #b_lr_scale: .1,
                layer_name: 'h0',
                kernel_shape: [5, 5],
                output_channels: 1,
                pool_shape: [4, 4],
                pool_stride: [2, 2],
                max_kernel_norm: 1.9365,
                irange: .005,
                #max_col_norm: 1.9365,
            },
            !obj:pylearn2.models.mlp.ConvRectifiedLinear {
                #W_lr_scale: .1,
                #b_lr_scale: .1,
                layer_name: 'h1',
                output_channels: 1,
                kernel_shape: [4, 4],
                pool_shape: [4, 4],
                pool_stride: [2, 2],
                max_kernel_norm: 1.9365,
                irange: .005,
            },
            !obj:pylearn2.models.mlp.Sigmoid {
                #W_lr_scale: .1,
                #b_lr_scale: .1,
                max_col_norm: 1.9365,
                layer_name: 'y',
                dim: 1,
                irange: .005
            }
        ],
    },
},
algorithm: !obj:pylearn2.training_algorithms.sgd.SGD {
    # Minibatch SGD (100 examples per batch) with a momentum learning rule.
    learning_rate: .1,
    batch_size: 100,
    learning_rule: !obj:pylearn2.training_algorithms.learning_rule.Momentum { init_momentum: .5 },
    # Monitored splits: the &train subset, examples [50000, 60000) of the
    # official training set, and the official test set.
    monitoring_dataset: {
        'train' : *train,
        'valid' : !obj:pylearn2.datasets.mnist.MNIST { which_set: 'train', start: 50000, stop: 60000 },
        'test' : !obj:pylearn2.datasets.mnist.MNIST { which_set: 'test' }
    },
    # Adversarial objective. Discriminator inputs get a default include
    # probability / scale, with an explicit override for layer 'h0'.
    cost: !obj:galatea.adversarial.AdversaryCost2 {
        scale_grads: 0,
        #target_scale: 1.,
        discriminator_default_input_include_prob: .5,
        discriminator_default_input_scale: 2.,
        discriminator_input_include_probs: { 'h0': .8 },
        discriminator_input_scales: { 'h0': 1.25 }
    },
    # Disabled alternative cost, kept for reference:
    #!obj:pylearn2.costs.mlp.dropout.Dropout {
    #    input_include_probs: { 'h0' : .8 },
    #    input_scales: { 'h0': 1. }
    #},
    # NOTE(review): termination_criterion is disabled, so SGD presumably
    # trains until interrupted. The channel it names ("valid_y_misclass")
    # may not be produced by the adversarial cost - confirm before enabling.
    #termination_criterion: !obj:pylearn2.termination_criteria.MonitorBased {
    #    channel_name: "valid_y_misclass",
    #    prop_decrease: 0.,
    #    N: 100
    #},
    update_callbacks: !obj:pylearn2.training_algorithms.sgd.ExponentialDecay { decay_factor: 1.000004, min_lr: .000001 }
},
extensions: [
    # Disabled MonitorBasedSaveBest extension (tracks 'valid_y_misclass'):
    #!obj:pylearn2.train_extensions.best_params.MonitorBasedSaveBest {
    #    channel_name: 'valid_y_misclass',
    #    save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}_best.pkl"
    #},
    # Momentum schedule for the SGD Momentum rule: presumably moves momentum
    # toward final_momentum between epochs `start` and `saturate` - confirm.
    !obj:pylearn2.training_algorithms.learning_rule.MomentumAdjustor { final_momentum: .7, start: 1, saturate: 250 }
],
# Output file for the trained model and how often it is written (save_freq).
save_freq: 1,
save_path: "${PYLEARN2_TRAIN_FILE_FULL_STEM}.pkl"
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment