Last active: August 12, 2017 20:23
-
-
Save MulticolorWorld/48dc7e51d46462c736af66638303bc68 to your computer and use it in GitHub Desktop.
LeNetExample
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/**
 * Trains a LeNet-style convolutional network on an image dataset and prints evaluation
 * statistics for a held-out split.
 *
 * Steps:
 * 1. Split the image URIs 80/20 into train/test.
 * 2. Wrap each split in an [ImageRecordReader]-backed iterator.
 * 3. Build a conv/pool/conv/pool/dense/softmax network.
 * 4. Fit it once over the training iterator.
 * 5. Evaluate on the test iterator.
 *
 * Each image's label comes from its parent directory name ([ParentPathLabelGenerator]).
 * The images are therefore expected to sit in `numLabels` label folders.
 *
 * `height`, `width` and `images` are dataset-specific placeholders (`TODO`). They must be
 * filled in before running; as written, the program throws [NotImplementedError] at the
 * first placeholder.
 *
 * @param args unused.
 */
fun main(args: Array<String>) {
    val seed = 12345L
    // The original referenced an undefined `randomNumGen`. Seed it from `seed` so the
    // train/test split is reproducible across runs.
    val randomNumGen = java.util.Random(seed)

    // Dataset-specific values; all images are assumed to share the same dimensions.
    val height: Int = TODO("image height, all same value")
    val width: Int = TODO("image width, all same value")
    val channels = 3
    val batchSize = 64
    // Number of classes. Used by the iterators, the output layer and the evaluation,
    // so it is kept in one place.
    val numLabels = 2
    // Label column index handed to RecordReaderDataSetIterator. ImageRecordReader is
    // presumably emitting the label after the image column — confirm if you change readers.
    val labelIndex = 1
    // Cap handed to BalancedPathFilter as its maxPaths argument.
    val maxPaths = 200

    // Image file URIs, laid out in one parent folder per label.
    val images: Collection<java.net.URI> =
        TODO("collection of image file uri. one parent folder per label")

    val labelMaker = ParentPathLabelGenerator()
    val pathFilter = BalancedPathFilter(randomNumGen, labelMaker, maxPaths)
    val (trainData, testData) = CollectionInputSplit(images).sample(pathFilter, 80.0, 20.0)

    // Raw pixel values are in 0..255. Rescale to 0..1 so the conv layers train stably.
    val scaler = org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler(0.0, 1.0)

    // Both iterators are configured identically; only the input split differs.
    val (trainIter, testIter) = listOf(trainData, testData).map { split ->
        val reader = ImageRecordReader(height, width, channels, labelMaker).apply {
            initialize(split)
        }
        RecordReaderDataSetIterator(reader, batchSize, labelIndex, numLabels).apply {
            setPreProcessor(scaler)
        }
    }

    val conf = NeuralNetConfiguration.Builder().apply {
        this.seed = seed
        iterations(1)
        regularization(true).l2(0.0005)
        learningRate = 0.01
        weightInit = WeightInit.XAVIER
        optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        updater(Updater.NESTEROVS).momentum(0.9)
        // The original also called list() here and discarded the result; the single
        // list() call below is the one that is used.
    }.list().apply {
        layer(0, ConvolutionLayer.Builder(5, 5)
            .nIn(channels)
            .stride(1, 1)
            .nOut(20)
            .activation(Activation.IDENTITY)
            .build())
        layer(1, SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        layer(2, ConvolutionLayer.Builder(5, 5)
            .stride(1, 1)
            .nOut(50)
            .activation(Activation.IDENTITY)
            .build())
        layer(3, SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
            .kernelSize(2, 2)
            .stride(2, 2)
            .build())
        layer(4, DenseLayer.Builder()
            .activation(Activation.RELU)
            .nOut(500)
            .build())
        layer(5, OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
            .nOut(numLabels)
            .activation(Activation.SOFTMAX)
            .build())
        // Lets the builder infer nIn for layers 2..5 from the image shape.
        inputType = InputType.convolutional(height, width, channels)
        backprop(true)
        pretrain(false)
    }.build()

    val model = MultiLayerNetwork(conf).apply {
        init()
        setListeners(ScoreIterationListener(1))
    }
    model.fit(trainIter)

    val evaluation = Evaluation(numLabels)
    for (ds in testIter) {
        // false = inference mode (not training) for the forward pass.
        evaluation.eval(ds.labels, model.output(ds.featureMatrix, false))
    }
    println(evaluation.stats())
}
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment