Skip to content

Instantly share code, notes, and snippets.

@mike-neck
Last active Jun 5, 2016
Embed
What would you like to do?
機械学習でFizzBuzz
/*
* Copyright 2016 Shinya Mochida
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.sample.fizzbuzz
import org.encog.Encog
import org.encog.engine.network.activation.ActivationSigmoid
import org.encog.ml.data.MLData
import org.encog.ml.data.MLDataPair
import org.encog.ml.data.MLDataSet
import org.encog.ml.data.basic.BasicMLData
import org.encog.ml.data.basic.BasicMLDataPair
import org.encog.ml.data.basic.BasicMLDataSet
import org.encog.neural.networks.BasicNetwork
import org.encog.neural.networks.layers.BasicLayer
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation
/**
 * Bounds of the integer range from which the training inputs are generated.
 *
 * Both ends are inclusive: `main` builds the range with `rangeTo`.
 */
object StudyRange {
    /** First number of the training range. */
    val STUDY_START: Int = 201

    /** Last number of the training range. */
    val STUDY_END: Int = 1000
}
/**
 * Trains one [Network] per [FizzBuzz] category on the numbers in
 * [StudyRange], then prints the predicted and the expected label for each
 * number from 1 to 40.
 *
 * @param args command-line arguments (unused)
 */
fun main(args: Array<String>) {
    try {
        val inputList: List<InputNumber> =
            (StudyRange.STUDY_START..StudyRange.STUDY_END).map(toInput)
        val studyDataList: List<StudyData> =
            fizzBuzzList().map { StudyData(inputList, it) }

        // One trained classifier per category. The element type must be
        // Network: that is the class declared in this file and the type
        // that InputNumber.fizzBuzz() accepts.
        val networks: List<Network> =
            studyDataList.map { Network(it.dataSet(), it.fizzBuzz) }

        val testData: List<InputNumber> = (1..40).map(toInput)
        testData
            .map { "${it.num} -> ${it.fizzBuzz(networks)} (correct: ${it.correct()})" }
            .forEach { println(it) }
    } finally {
        // Shut Encog down even when training or evaluation throws.
        Encog.getInstance().shutdown()
    }
}
/*------------------------*/
/**
 * Encodes this flag as a numeric network input.
 *
 * @return `1.0` for `true`, `0.0` for `false`
 */
fun Boolean.toInputData(): Double = when {
    this -> 1.0
    else -> 0.0
}
/**
 * Converts every flag of this array with [toInputData], keeping the order.
 *
 * @return a new array of the same size holding `1.0` / `0.0` values
 */
fun BooleanArray.toDoubleArray(): DoubleArray =
    DoubleArray(size) { index -> this[index].toInputData() }
/**
 * Rounds this value with [Math.round] and narrows the result to [Int].
 *
 * @return the rounded value as an `Int`
 */
fun Double.round(): Int {
    val nearest: Long = Math.round(this)
    return nearest.toInt()
}
/*------------------------*/
/**
 * One training sample: a number paired with the category a network is being
 * trained to recognise.
 *
 * @property input the number together with its divisibility flags
 * @property fizzBuzz the category this sample is evaluated against
 */
data class FizzBuzzStudy(val input: InputNumber, val fizzBuzz: FizzBuzz)
/**
 * Builds the two-element feature vector of this sample.
 *
 * @return `[fizz, buzz]` of the input number, each encoded as `1.0` / `0.0`
 */
fun FizzBuzzStudy.inputArray(): DoubleArray {
    val flags = booleanArrayOf(input.fizz, input.buzz)
    return flags.toDoubleArray()
}
/**
 * Wraps the feature vector from [inputArray] in an Encog [BasicMLData].
 */
fun FizzBuzzStudy.inputData(): MLData {
    val features: DoubleArray = inputArray()
    return BasicMLData(features)
}
/**
 * Builds the expected network output for this sample.
 *
 * The input number's actual category is compared against this sample's
 * [FizzBuzzStudy.fizzBuzz] via `ideal`, and the resulting array is wrapped
 * in a [BasicMLData].
 */
fun FizzBuzzStudy.idealData(): MLData {
    val actual: FizzBuzz = input.toFizzBuzz()
    val expectedOutput: DoubleArray = actual.ideal(fizzBuzz)
    return BasicMLData(expectedOutput)
}
/**
 * Combines [inputData] and [idealData] into a single Encog training pair.
 */
fun FizzBuzzStudy.dataPair(): MLDataPair {
    val features = inputData()
    val expected = idealData()
    return BasicMLDataPair(features, expected)
}
/*------------------------*/
/**
 * Training material for a single category.
 *
 * @property dataList the numbers to train on
 * @property fizzBuzz the category every number is labelled against
 */
class StudyData(val dataList: List<InputNumber>, val fizzBuzz: FizzBuzz) {

    /** Pairs every number in [dataList] with [fizzBuzz]. */
    private fun studyList(): List<FizzBuzzStudy> =
        dataList.map { number -> number.toStudyData(fizzBuzz) }

    /** Converts every sample into an Encog input/ideal pair, in list order. */
    fun dataPairList(): List<MLDataPair> =
        studyList().map { study -> study.dataPair() }
}
/**
 * Collects the pairs from [StudyData.dataPairList] into a [BasicMLDataSet].
 */
fun StudyData.dataSet(): MLDataSet {
    val pairs: List<MLDataPair> = dataPairList()
    return BasicMLDataSet(pairs)
}
/*------------------------*/
/**
 * Maps an integer to an [InputNumber] whose flags record whether the integer
 * is divisible by 3 (`fizz`) and by 5 (`buzz`).
 */
val toInput: (Int) -> InputNumber = { number ->
    InputNumber(number, number % 3 == 0, number % 5 == 0)
}
/**
 * A number together with its two divisibility flags.
 *
 * @property num the number itself
 * @property fizz the "fizz" flag (set by `toInput` when `num` is divisible by 3)
 * @property buzz the "buzz" flag (set by `toInput` when `num` is divisible by 5)
 */
data class InputNumber(
    val num: Int,
    val fizz: Boolean,
    val buzz: Boolean
)
/**
 * Classifies this number from its flags.
 *
 * @return [FizzBuzz.FIZZ_BUZZ] when both flags are set, [FizzBuzz.BUZZ] or
 *   [FizzBuzz.FIZZ] when only that flag is set, otherwise [FizzBuzz.NONE]
 */
fun InputNumber.toFizzBuzz(): FizzBuzz = when {
    fizz && buzz -> FizzBuzz.FIZZ_BUZZ
    buzz -> FizzBuzz.BUZZ
    fizz -> FizzBuzz.FIZZ
    else -> FizzBuzz.NONE
}
/**
 * Produces the reference answer for this number, derived from its flags.
 *
 * @return the category name when at least one flag is set, otherwise the
 *   number rendered as a string
 */
fun InputNumber.correct(): String = when {
    fizz && buzz -> FizzBuzz.FIZZ_BUZZ.name
    buzz -> FizzBuzz.BUZZ.name
    fizz -> FizzBuzz.FIZZ.name
    else -> "$num"
}
/**
 * Asks the trained networks for this number's label.
 *
 * The networks are consulted in list order and the first one whose `match`
 * accepts this number decides the answer.
 *
 * @param networks the trained classifiers, one per category
 * @return the name of the matching category, or the number itself when no
 *   network matches or the match is [FizzBuzz.NONE]
 */
fun InputNumber.fizzBuzz(networks: List<Network>): String {
    val category: FizzBuzz? = networks.firstOrNull { it.match(this) }?.fizzBuzz
    return if (category == null || category == FizzBuzz.NONE) "$num" else category.name
}
/**
 * Pairs this number with the category it should be evaluated against.
 *
 * @param fizzBuzz the category the resulting sample belongs to
 */
fun InputNumber.toStudyData(fizzBuzz: FizzBuzz): FizzBuzzStudy =
    FizzBuzzStudy(this, fizzBuzz)
/**
 * Builds the two-element network input for this number.
 *
 * @return `[fizz, buzz]`, each encoded as `1.0` / `0.0`
 */
fun InputNumber.toDataArray(): DoubleArray {
    val flags = booleanArrayOf(fizz, buzz)
    return flags.toDoubleArray()
}
/*------------------------*/
/**
 * The four categories a number can be classified into.
 */
enum class FizzBuzz {
    /** Only the "fizz" flag is set. */
    FIZZ,

    /** Only the "buzz" flag is set. */
    BUZZ,

    /** Both flags are set. */
    FIZZ_BUZZ,

    /** Neither flag is set. */
    NONE
}
/**
 * Lists the categories in the order the networks are built and consulted:
 * the combined category first, the "no match" category last.
 */
fun fizzBuzzList(): List<FizzBuzz> = listOf(
    FizzBuzz.FIZZ_BUZZ,
    FizzBuzz.BUZZ,
    FizzBuzz.FIZZ,
    FizzBuzz.NONE
)
/**
 * Builds the single-element ideal output for a training sample.
 *
 * @param fizzBuzz the category to compare this one against
 * @return `[1.0]` when the two categories are equal, otherwise `[0.0]`
 */
fun FizzBuzz.ideal(fizzBuzz: FizzBuzz): DoubleArray {
    val target = if (this == fizzBuzz) 1.0 else 0.0
    return doubleArrayOf(target)
}
/*------------------------*/
/**
 * A binary classifier for one [FizzBuzz] category, trained on construction.
 *
 * The network has 2 inputs (the fizz / buzz flags), one sigmoid layer of 4
 * neurons and a single sigmoid output. Training uses resilient propagation
 * and stops as soon as the reported error drops below [targetError] or after
 * [maxIterations] iterations, whichever comes first. Progress is printed to
 * standard output.
 *
 * @property data the training set (input / ideal pairs)
 * @property fizzBuzz the category this network recognises
 * @param maxIterations upper bound on training iterations (default 30)
 * @param targetError error threshold that ends training early (default 0.01)
 */
class Network(
    val data: MLDataSet,
    val fizzBuzz: FizzBuzz,
    maxIterations: Int = 30,
    targetError: Double = 0.01
) {
    val network = BasicNetwork()

    init {
        println("[$fizzBuzz] ${data.size()}")

        // BasicLayer(activation, hasBias, neuronCount): the input layer has
        // no activation function; the output layer has no bias neuron.
        network.addLayer(BasicLayer(null, true, 2))
        network.addLayer(BasicLayer(ActivationSigmoid(), true, 4))
        network.addLayer(BasicLayer(ActivationSigmoid(), false, 1))
        network.structure.finalizeStructure()
        network.reset()

        val training = ResilientPropagation(network, data)
        for (t in 1..maxIterations) {
            training.iteration()
            println("[$fizzBuzz]Iteration #$t Error: [${training.error}]")
            if (training.error < targetError) break
        }
        training.finishTraining()
    }

    /** Runs the trained network on [input]. */
    private fun compute(input: MLData): MLData = network.compute(input)

    /**
     * Tells whether this network classifies [input] as its category.
     *
     * @return `true` when the single network output rounds to 1
     */
    fun match(input: InputNumber): Boolean =
        compute(BasicMLData(input.toDataArray())).getData(0).round() == 1
}
[FIZZ_BUZZ] 800
[FIZZ_BUZZ]Iteration #1 Error: [0.12226975500233841]
[FIZZ_BUZZ]Iteration #2 Error: [0.09330275911057367]
[FIZZ_BUZZ]Iteration #3 Error: [0.07428148506734375]
[FIZZ_BUZZ]Iteration #4 Error: [0.0641429130682073]
[FIZZ_BUZZ]Iteration #5 Error: [0.06124317773136271]
[FIZZ_BUZZ]Iteration #6 Error: [0.06065131132730094]
[FIZZ_BUZZ]Iteration #7 Error: [0.058779038071560166]
[FIZZ_BUZZ]Iteration #8 Error: [0.05727240367712047]
[FIZZ_BUZZ]Iteration #9 Error: [0.054328420864253144]
[FIZZ_BUZZ]Iteration #10 Error: [0.05105871759203601]
[FIZZ_BUZZ]Iteration #11 Error: [0.04768288172604686]
[FIZZ_BUZZ]Iteration #12 Error: [0.04513699047882237]
[FIZZ_BUZZ]Iteration #13 Error: [0.04258875332991354]
[FIZZ_BUZZ]Iteration #14 Error: [0.040274350258516334]
[FIZZ_BUZZ]Iteration #15 Error: [0.036694693557748057]
[FIZZ_BUZZ]Iteration #16 Error: [0.03308502824167907]
[FIZZ_BUZZ]Iteration #17 Error: [0.025876701138245077]
[FIZZ_BUZZ]Iteration #18 Error: [0.019051917410186872]
[FIZZ_BUZZ]Iteration #19 Error: [0.01086335266804373]
[FIZZ_BUZZ]Iteration #20 Error: [0.0051298132102303405]
[BUZZ] 800
[BUZZ]Iteration #1 Error: [0.4095437300251804]
[BUZZ]Iteration #2 Error: [0.335520645907876]
[BUZZ]Iteration #3 Error: [0.2620690114678582]
[BUZZ]Iteration #4 Error: [0.19611320691900438]
[BUZZ]Iteration #5 Error: [0.14839898668401375]
[BUZZ]Iteration #6 Error: [0.12167051800750166]
[BUZZ]Iteration #7 Error: [0.11103496605170611]
[BUZZ]Iteration #8 Error: [0.1042954774556595]
[BUZZ]Iteration #9 Error: [0.09409627993023403]
[BUZZ]Iteration #10 Error: [0.0842071166408978]
[BUZZ]Iteration #11 Error: [0.0665664138325176]
[BUZZ]Iteration #12 Error: [0.046526843403418255]
[BUZZ]Iteration #13 Error: [0.02868040403273247]
[BUZZ]Iteration #14 Error: [0.01767991692674723]
[BUZZ]Iteration #15 Error: [0.005869018280375759]
[FIZZ] 800
[FIZZ]Iteration #1 Error: [0.2064513978391209]
[FIZZ]Iteration #2 Error: [0.19932015564362507]
[FIZZ]Iteration #3 Error: [0.19241633088418336]
[FIZZ]Iteration #4 Error: [0.188115462435641]
[FIZZ]Iteration #5 Error: [0.17692070426824663]
[FIZZ]Iteration #6 Error: [0.16503889949583014]
[FIZZ]Iteration #7 Error: [0.15042310250842145]
[FIZZ]Iteration #8 Error: [0.13070643271522736]
[FIZZ]Iteration #9 Error: [0.10772843916608814]
[FIZZ]Iteration #10 Error: [0.08286916330891221]
[FIZZ]Iteration #11 Error: [0.05901658645569269]
[FIZZ]Iteration #12 Error: [0.03865155706791457]
[FIZZ]Iteration #13 Error: [0.022807244591016386]
[FIZZ]Iteration #14 Error: [0.01121368412581513]
[FIZZ]Iteration #15 Error: [0.004453462179809962]
[NONE] 800
[NONE]Iteration #1 Error: [0.23645302462711473]
[NONE]Iteration #2 Error: [0.2263712608586554]
[NONE]Iteration #3 Error: [0.21419234991340613]
[NONE]Iteration #4 Error: [0.19831781594244577]
[NONE]Iteration #5 Error: [0.18152056723441828]
[NONE]Iteration #6 Error: [0.16717549435315046]
[NONE]Iteration #7 Error: [0.14490907935465216]
[NONE]Iteration #8 Error: [0.12126651228005893]
[NONE]Iteration #9 Error: [0.09689166018982734]
[NONE]Iteration #10 Error: [0.07365214098467762]
[NONE]Iteration #11 Error: [0.0534529928888325]
[NONE]Iteration #12 Error: [0.035090588685562064]
[NONE]Iteration #13 Error: [0.02194091557381065]
[NONE]Iteration #14 Error: [0.012960881942904694]
[NONE]Iteration #15 Error: [0.006164686125590212]
1 -> 1 (correct: 1)
2 -> 2 (correct: 2)
3 -> FIZZ (correct: FIZZ)
4 -> 4 (correct: 4)
5 -> BUZZ (correct: BUZZ)
6 -> FIZZ (correct: FIZZ)
7 -> 7 (correct: 7)
8 -> 8 (correct: 8)
9 -> FIZZ (correct: FIZZ)
10 -> BUZZ (correct: BUZZ)
11 -> 11 (correct: 11)
12 -> FIZZ (correct: FIZZ)
13 -> 13 (correct: 13)
14 -> 14 (correct: 14)
15 -> FIZZ_BUZZ (correct: FIZZ_BUZZ)
16 -> 16 (correct: 16)
17 -> 17 (correct: 17)
18 -> FIZZ (correct: FIZZ)
19 -> 19 (correct: 19)
20 -> BUZZ (correct: BUZZ)
21 -> FIZZ (correct: FIZZ)
22 -> 22 (correct: 22)
23 -> 23 (correct: 23)
24 -> FIZZ (correct: FIZZ)
25 -> BUZZ (correct: BUZZ)
26 -> 26 (correct: 26)
27 -> FIZZ (correct: FIZZ)
28 -> 28 (correct: 28)
29 -> 29 (correct: 29)
30 -> FIZZ_BUZZ (correct: FIZZ_BUZZ)
31 -> 31 (correct: 31)
32 -> 32 (correct: 32)
33 -> FIZZ (correct: FIZZ)
34 -> 34 (correct: 34)
35 -> BUZZ (correct: BUZZ)
36 -> FIZZ (correct: FIZZ)
37 -> 37 (correct: 37)
38 -> 38 (correct: 38)
39 -> FIZZ (correct: FIZZ)
40 -> BUZZ (correct: BUZZ)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment