Created
October 28, 2016 13:24
-
-
Save waynejo/8b7200035232dc89d31874ec8891975d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.github.robhinds.nn

import io.github.robhinds.nn.builder.NNBuilder

import scala.io.Codec
import scala.io.Source
/** Iteration-count sweep over the `/train.csv` classpath resource.
  *
  * Reads the CSV once and encodes three input features plus one target value per row. It then
  * builds and trains a fresh network for every iteration count in `1 to 1000000 by 50`,
  * printing `"<count> : "` before each training run.
  */
object Main {

  /** Value substituted for an empty age cell (before scaling by 1/100). */
  private val DefaultAge: Double = 35.0

  /** Percentage of the row count passed as the third argument to `train`. */
  private val ThirdArgPercent: Int = 20

  def main(args: Array[String]): Unit = {
    val rows = readRows("/train.csv")

    // The data never changes between runs, so encode it once instead of re-mapping every
    // row on each of the 20,000 loop iterations.
    val inputData: List[List[Double]] = rows.map(toFeatures)
    val outputData: List[List[Double]] = rows.map(toLabel)
    // NOTE(review): presumably the number of rows held out for evaluation; confirm against
    // the signature of the network's `train` method.
    val holdoutSize = inputData.size * ThirdArgPercent / 100

    (1 to 1000000 by 50).foreach { v =>
      val network = NNBuilder
        .inputNeurons(3) // must match the length of the lists built by toFeatures
        .hiddenNeurons(20)
        .outputNeurons(1) // must match the length of the list built by toLabel
        .learningRate(0.05)
        .iterations(v)
        .build()
      print(s"$v : ")
      network.train(inputData, outputData, holdoutSize)
    }
  }

  /** Loads a classpath CSV resource as UTF-8, prints its first (header) line, and returns
    * every remaining line split on commas.
    *
    * @param resource classpath path of the CSV, e.g. `"/train.csv"`
    * @return the non-header lines, each split with `split(",")`
    * @throws IllegalStateException if the resource is not on the classpath or is empty
    */
  private def readRows(resource: String): List[Array[String]] = {
    val url = Option(getClass.getResource(resource)).getOrElse(
      throw new IllegalStateException(s"resource not found on classpath: $resource")
    )
    val source = Source.fromURL(url)(Codec.UTF8)
    try {
      val lines = source.getLines()
      if (!lines.hasNext) throw new IllegalStateException(s"resource is empty: $resource")
      println(lines.next())
      // Materialise with toList before close(): the iterator reads from the open source.
      lines.map(_.split(",")).toList
    } finally source.close()
  }

  /** Encodes one split row as `List(isMale, age / 100, (class - 1) / 2)`.
    *
    * NOTE(review): indices 5/6 look shifted by one from the Kaggle Titanic header
    * (Sex = 4, Age = 5). Presumably a quoted name containing a comma is split into two
    * cells by `split(",")`; confirm against the data. The class scaling maps to [0, 1]
    * only if column 2 holds values 1 to 3.
    */
  private def toFeatures(columns: Array[String]): List[Double] = {
    val isMale = if (columns(5) == "male") 1.0 else 0.0
    val age = if (columns(6).isEmpty) DefaultAge else columns(6).toDouble
    val pclass = columns(2).toDouble
    List(isMale, age / 100.0, (pclass - 1) / 2.0)
  }

  /** Target value for one split row: column 1 parsed as a Double, in a single-element list. */
  private def toLabel(columns: Array[String]): List[Double] =
    List(columns(1).toDouble)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment