Skip to content

Instantly share code, notes, and snippets.

/mnist_gan.m

Created Apr 5, 2017
Embed
What would you like to do?
(* Training set keyed by the two input ports of the training graph:
   "random" - one integer in 1..10 per element of mnistDigits;
   "Input"  - every element of mnistDigits converted with ImageData, reshaped
              to {1, 28, 28}, with uniform noise in [-0.05, 0.05] added.
   NOTE(review): mnistDigits is not defined in this snippet; it is presumably
   a list of 28x28 grayscale MNIST images - confirm before running. *)
trainingData = Association[
   "random" -> RandomInteger[{1, 10}, Length[mnistDigits]],
   "Input" -> Map[
     Function[digitImage,
      RandomReal[{-0.05, 0.05}, {1, 28, 28}] +
       ArrayReshape[ImageData[digitImage], {1, 28, 28}]],
     mnistDigits]];
(* Generator network: embeds an integer in 1..10 into a vector of length
   8*6*6, reshapes it to {8, 6, 6}, then applies two stride-2 deconvolutions
   (Ramp after the first, LogisticSigmoid after the second).
   The expected spatial sizes are 6 -> 14 -> 28, which matches the
   {1, 28, 28} images the discriminator accepts. *)
generator = NetChain[{
    EmbeddingLayer[8*6*6, 10],
    ReshapeLayer[{8, 6, 6}],
    DeconvolutionLayer[8, 4, "Stride" -> 2],
    ElementwiseLayer[Ramp],
    DeconvolutionLayer[1, 4, "Stride" -> 2, "PaddingSize" -> 1],
    ElementwiseLayer[LogisticSigmoid]}];
(* Discriminator (critic) network over {1, 28, 28} inputs:
   4-channel 4x4 convolution -> Tanh -> 3x3 pooling with stride 1 ->
   16-unit linear layer -> Ramp -> 1-unit linear layer.
   The last layer is linear, so the single output value is unbounded. *)
discriminator = NetChain[{
    ConvolutionLayer[4, 4],
    ElementwiseLayer[Tanh],
    PoolingLayer[3, 1],
    LinearLayer[16],
    ElementwiseLayer[Ramp],
    LinearLayer[1]},
   "Input" -> {1, 28, 28}];
(* Combined training graph for generator and discriminator.
   Data flow, as wired by the edge list below:
     "random" port -> "gen" -> "cat"   (generated image)
     "Input" port  -> "cat"            (training image)
     "cat" -> "reshape" to {2, 1, 28, 28} -> "discrimop" (the discriminator
     mapped over the first dimension with NetMapOperator) -> "flat" ->
     "scale" (multiplication by the constant {-1, 1}) -> "total" (sum).
   Presumably the generated image is the first input of "cat", so that the
   output equals D(real) - D(generated) - TODO confirm the input order.
   NOTE(review): a reader of this gist reported that Mathematica 11.3 rejects
   this graph with FlattenLayer::nettypeinc ("dynamic dimensions cannot be
   flattened with other levels"); not verified here. *)
wganNet =
NetGraph[<|"gen" -> generator,
"discrimop" -> NetMapOperator[discriminator],
"cat" -> CatenateLayer[],
"reshape" -> ReshapeLayer[{2, 1, 28, 28}],
"flat" -> FlattenLayer[], "total" -> SummationLayer[],
"scale" ->
ConstantTimesLayer["Scaling" -> {-1, 1}]|>, {NetPort["random"] ->
"gen" -> "cat", NetPort["Input"] -> "cat",
"cat" ->
"reshape" -> "discrimop" -> "flat" -> "scale" -> "total"},
"Input" -> {1, 28, 28}];
(* Trains wganNet on trainingData, passing "Output" as the loss
   specification (presumably the graph's own output port is minimized
   directly - confirm against the NetTrain documentation).
   Options used:
     - ADAM optimizer with Beta1 = 0.5 and learning rate 0.01.
     - "WeightClipping": 0.001 for "discrimop", with a separate value of 1
       for the layer at {"discrimop", 1}.
       NOTE(review): a reader of this gist reported that version 12 needs
       this simplified to "WeightClipping" -> {"discrimop" -> 0.001};
       not verified here.
     - LearningRateMultipliers: "scale" -> 0 keeps the {-1, 1} constant
       fixed; "gen" -> -0.05 gives the generator a negative multiplier,
       presumably so that it is updated against the discriminator's
       objective - TODO confirm.
     - TargetDevice -> "GPU", batch size 32, at most 5000 training rounds.
   progressFuncCreator is not defined in this snippet and must be supplied
   elsewhere. The call ends with ";" and its result is not assigned, so the
   trained net is not kept in any variable here. *)
NetTrain[wganNet, trainingData, "Output",
Method -> {"ADAM", "Beta1" -> 0.5, "LearningRate" -> 0.01,
"WeightClipping" -> {{"discrimop", 1} -> 1,
"discrimop" -> 0.001}},
TrainingProgressReporting -> {progressFuncCreator[Range[10]],
"Interval" -> Quantity[0.3, "Seconds"]},
LearningRateMultipliers -> {"scale" -> 0, "gen" -> -0.05},
TargetDevice -> "GPU", BatchSize -> 32, MaxTrainingRounds -> 5000];
@joohaeng

This comment has been minimized.

Copy link

@joohaeng joohaeng commented May 9, 2019

Thank you for sharing! In v12, one change is needed to make this code run: replace the weight-clipping option with "WeightClipping" -> {"discrimop" -> 0.001}.

@titovie70

This comment has been minimized.

Copy link

@titovie70 titovie70 commented Oct 8, 2019

Thank you for sharing!
The following error appears in Mathematica 11.3.0.0: "FlattenLayer::nettypeinc: Type inconsistency in FlattenLayer: dynamic dimensions cannot be flattened with other levels."
I would be very grateful if someone could help me.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment