Skip to content

Instantly share code, notes, and snippets.

@waynejo
Created October 28, 2016 13:24
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save waynejo/8b7200035232dc89d31874ec8891975d to your computer and use it in GitHub Desktop.
Save waynejo/8b7200035232dc89d31874ec8891975d to your computer and use it in GitHub Desktop.
package io.github.robhinds.nn
import io.github.robhinds.nn.builder.NNBuilder
import scala.io.Source
object Main {

  /** Classpath location of the training CSV (first line is a header). */
  private val TrainingResource = "/train.csv"

  /** Age substituted when the age column of a row is blank. */
  private val DefaultAge = 35.0

  /** Divisor applied to the age value before it is fed to the network. */
  private val AgeScale = 100.0

  /** Percentage of the row count passed as the third argument of `train`. */
  private val SplitPercent = 20

  /** Default upper bound of the iteration-count sweep. */
  private val DefaultMaxIterations = 1000000

  /** Default step between consecutive iteration counts in the sweep. */
  private val DefaultIterationStep = 50

  /**
   * Loads the training CSV once, then trains a fresh 3-20-1 network for every
   * iteration count in the sweep, printing the count before each run.
   *
   * @param args optional overrides: `args(0)` = largest iteration count
   *             (default 1000000), `args(1)` = sweep step (default 50).
   *             With no arguments the sweep matches the original hard-coded one.
   * @throws NumberFormatException if a supplied argument is not an integer
   * @throws IllegalStateException if the CSV resource is missing or empty
   */
  def main(args: Array[String]): Unit = {
    val maxIterations = args.headOption.map(_.toInt).getOrElse(DefaultMaxIterations)
    val step = args.drop(1).headOption.map(_.toInt).getOrElse(DefaultIterationStep)

    val rows = loadRows(TrainingResource)
    // The data does not depend on the sweep value, so build it once instead of
    // once per network (the sweep runs tens of thousands of times).
    val inputData = rows.map(toFeatures)
    val outputData = rows.map(toLabel)
    // NOTE(review): presumably the size of a held-out/validation split —
    // confirm against the network's `train` signature.
    val splitSize = inputData.size * SplitPercent / 100

    (1 to maxIterations by step).foreach { v =>
      val network = NNBuilder
        .inputNeurons(3)
        .hiddenNeurons(20)
        .outputNeurons(1)
        .learningRate(0.05)
        .iterations(v)
        .build()
      print(s"$v : ")
      network.train(inputData, outputData, splitSize)
    }
  }

  /**
   * Reads a CSV file from the classpath, echoes its header line to stdout and
   * returns the remaining lines split on commas. The underlying source is
   * always closed, even if reading fails.
   *
   * NOTE(review): `split(",")` does not honour CSV quoting. If a field (e.g. a
   * quoted name) contains a comma, later column indices shift by one. Confirm
   * the indices used by `toFeatures`/`toLabel` against the real file layout.
   *
   * @param resource absolute classpath path of the CSV file
   * @return one array of column strings per data line, header excluded
   * @throws IllegalStateException if the resource is missing or has no lines
   */
  private def loadRows(resource: String): List[Array[String]] = {
    // getResource returns null (not an exception) when the file is absent.
    val url = Option(getClass.getResource(resource)).getOrElse(
      throw new IllegalStateException(s"Classpath resource not found: $resource"))
    val source = Source.fromURL(url)
    try {
      val lines = source.getLines()
      if (!lines.hasNext) throw new IllegalStateException(s"Resource is empty: $resource")
      println(lines.next())
      // toList forces the lazy iterator before the source is closed below.
      lines.map(_.split(",")).toList
    } finally {
      source.close()
    }
  }

  /**
   * Converts one split row into the three network inputs:
   *  - 1.0 if column 5 equals "male", otherwise 0.0;
   *  - column 6 as a number (DefaultAge when blank) divided by AgeScale;
   *  - (column 2 - 1) / 2.0 (presumably maps a 1..3 class value onto 0..1 — confirm).
   */
  private def toFeatures(columns: Array[String]): List[Double] = {
    val isMale = if ("male" == columns(5)) 1.0 else 0.0
    val age = if (columns(6).isEmpty) DefaultAge else columns(6).toDouble
    val scaledClass = (columns(2).toDouble - 1) / 2.0
    List(isMale, age / AgeScale, scaledClass)
  }

  /** Extracts the single expected network output: column 1 parsed as a Double. */
  private def toLabel(columns: Array[String]): List[Double] =
    List(columns(1).toDouble)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment