Created April 5, 2017 19:55
-
-
Save anonymous/4cccd394f4f7954d577ceb2f75971094 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(* Training data for NetTrain, one entry per element of mnistDigits:
   "random" - a uniformly random integer label in 1..10 per example, fed to
              the generator's EmbeddingLayer (it is drawn independently of
              the real image - TODO confirm that is intended);
   "Input"  - each real image reshaped to a {1, 28, 28} array, plus uniform
              noise in [-0.05, 0.05] (RandomReal is re-evaluated for every
              image, so each image gets its own noise).
   NOTE(review): mnistDigits is defined outside this snippet; presumably a
   list of 28x28 grayscale digit images - TODO confirm. *)
trainingData = <|
  "random" -> RandomInteger[{1, 10}, Length[mnistDigits]],
  "Input" ->
    Map[
      RandomReal[{-0.05, 0.05}, {1, 28, 28}] +
        ArrayReshape[ImageData[#], {1, 28, 28}] &,
      mnistDigits]
  |>;
(* Generator: maps an integer label in 1..10 to a {1, 28, 28} image.
   - EmbeddingLayer[8*6*6, 10] embeds the label as a 288-element vector,
     which ReshapeLayer turns into 8 channels of 6x6.
   - The first DeconvolutionLayer (kernel 4, stride 2, no padding) upsamples
     6x6 -> 14x14; the second (kernel 4, stride 2, padding 1) upsamples
     14x14 -> 28x28 with a single output channel.
   - LogisticSigmoid keeps the generated pixel values in (0, 1).
   NOTE(review): a user reports FlattenLayer::nettypeinc ("dynamic dimensions
   cannot be flattened with other levels") in Mathematica 11.3 when this net
   is used in wganNet. The dynamic dimension may come from this
   EmbeddingLayer's default input shape - TODO confirm and, if so, pin its
   "Input" specification. *)
generator =
  NetChain[{
    EmbeddingLayer[8*6*6, 10],
    ReshapeLayer[{8, 6, 6}],
    DeconvolutionLayer[8, 4, "Stride" -> 2],
    Ramp,
    DeconvolutionLayer[1, 4, "Stride" -> 2, "PaddingSize" -> 1],
    LogisticSigmoid
  }];
(* Critic / discriminator: scores a single {1, 28, 28} image.
   Shapes: ConvolutionLayer (4 filters, 4x4) -> {4, 25, 25};
   PoolingLayer (3x3, stride 1) -> {4, 23, 23}; LinearLayer[16] -> Ramp ->
   LinearLayer[1], so the output is a length-1 vector.
   In NetChain a bare integer n is shorthand for LinearLayer[n]. The original
   chain wrote `16` and `1`; they are spelled out here for readability. *)
discriminator =
  NetChain[{
    ConvolutionLayer[4, 4], Tanh,
    PoolingLayer[3, 1],
    LinearLayer[16], Ramp,
    LinearLayer[1]
  }, "Input" -> {1, 28, 28}];
(* Full WGAN training graph; its scalar "Output" port is used as the loss.
   - "gen" turns the "random" label into a fake {1, 28, 28} image.
   - "cat" joins the fake image and the real "Input" image into {2, 28, 28}
     (presumably fake first, real second, following the order in which the
     edges are listed - TODO confirm). "reshape" makes that {2, 1, 28, 28}.
   - "discrimop" applies the same discriminator to each of the two images,
     giving two length-1 scores.
   - "flat" turns the scores into a length-2 vector. "scale" multiplies it
     elementwise by {-1, 1}, and "total" sums the result, so the loss is
     (score of the 2nd image) - (score of the 1st image).
   NOTE(review): Mathematica 11.3 reportedly rejects this graph with
   FlattenLayer::nettypeinc ("dynamic dimensions cannot be flattened with
   other levels") - TODO confirm where the dynamic dimension originates. *)
wganNet =
  NetGraph[
    <|
      "gen" -> generator,
      "discrimop" -> NetMapOperator[discriminator],
      "cat" -> CatenateLayer[],
      "reshape" -> ReshapeLayer[{2, 1, 28, 28}],
      "flat" -> FlattenLayer[],
      "total" -> SummationLayer[],
      "scale" -> ConstantTimesLayer["Scaling" -> {-1, 1}]
    |>,
    {
      NetPort["random"] -> "gen" -> "cat",
      NetPort["Input"] -> "cat",
      "cat" -> "reshape" -> "discrimop" -> "flat" -> "scale" -> "total"
    },
    "Input" -> {1, 28, 28}];
(* Train wganNet by minimizing its "Output" port on trainingData.
   The original discarded NetTrain's return value, which lost the trained
   net. It is now kept in trainedWganNet, so the trained generator can be
   pulled out afterwards, e.g. NetExtract[trainedWganNet, "gen"].
   - LearningRateMultipliers: "scale" -> 0 freezes the {-1, 1} weights of the
     ConstantTimesLayer. "gen" -> -0.05 gives the generator a smaller,
     negative learning rate, so it moves against the loss gradient while the
     discriminator moves with it (the adversarial min-max).
   - "WeightClipping" clips discriminator weights, as in the WGAN recipe:
     bound 1 for the {"discrimop", 1} part and 0.001 for the rest of
     "discrimop" (presumably the more specific rule takes precedence -
     TODO confirm).
     NOTE(review): a user reports that Mathematica 12 requires
     "WeightClipping" -> {"discrimop" -> 0.001} instead.
   NOTE(review): progressFuncCreator is defined outside this snippet. *)
trainedWganNet =
  NetTrain[wganNet, trainingData, "Output",
    Method -> {"ADAM", "Beta1" -> 0.5, "LearningRate" -> 0.01,
      "WeightClipping" -> {{"discrimop", 1} -> 1, "discrimop" -> 0.001}},
    TrainingProgressReporting -> {progressFuncCreator[Range[10]],
      "Interval" -> Quantity[0.3, "Seconds"]},
    LearningRateMultipliers -> {"scale" -> 0, "gen" -> -0.05},
    TargetDevice -> "GPU", BatchSize -> 32, MaxTrainingRounds -> 5000];
Thank you for sharing!
The following error appears in Mathematica 11.3.0.0: "FlattenLayer::nettypeinc: Type inconsistency in FlattenLayer: dynamic dimensions cannot be flattened with other levels."
I would be very grateful if someone could help me.
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Thank you for sharing! In v12, one change is needed to make this code run: replace the "WeightClipping" option value with "WeightClipping" -> {"discrimop" -> 0.001}.