Created
September 6, 2014 13:19
-
-
Save tkroman/d1d1697f2e3dcd2af5ce to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util | |
import java.util.Collections | |
import scala.util.Random | |
/**
 * Single-neuron perceptron trained on a two-feature, semicolon-separated CSV
 * resource (`data<Variant>.csv`) whose third column is the expected label.
 *
 * The program shuffles the rows once, splits them into disjoint learning and
 * test sets, trains until an epoch produces no misclassification (or the epoch
 * cap is reached) and prints, in order:
 *   1. whether every test row was classified as expected,
 *   2. the number of epochs that were run,
 *   3. the final weight vector.
 */
object Perceptron {

  /** Triple used both for weight vectors and for `(x1, x2, bias)` inputs. */
  type Vec3 = (Double, Double, Double)

  /** One parsed CSV row: the input vector and the expected label. */
  type Sample = (Vec3, Int)

  private val Variant = 13
  private val CsvRegex = """(.+)\;(.+)\;(.+)""".r

  // Constant third input component fed to the neuron for every row.
  private val BiasInput = -1.0
  private val LearningRate = 0.5
  private val MaxEpochs = 5000
  private val TrainFraction = 0.8

  /**
   * Classifies `x` with weights `w`.
   *
   * @return the sign of the dot product `w . x`: 1, -1, or 0 when the product
   *         is exactly zero
   */
  def makeGuess(w: Vec3)(x: Vec3): Int =
    math.signum(w._1 * x._1 + w._2 * x._2 + w._3 * x._3).toInt

  /**
   * Computes the next weight vector after a misclassification.
   *
   * @param w   current weights
   * @param eta learning rate
   * @param d   expected output
   * @param y   output actually produced
   * @param x   input vector that was misclassified
   * @return `w + eta * (d - y) * 0.5 * x`, component-wise
   */
  def approxNextW(w: Vec3)(eta: Double, d: Double, y: Double, x: Vec3): Vec3 = {
    val interm = eta * (d - y)
    (w._1 + interm * (x._1 * 0.5), w._2 + interm * (x._2 * 0.5), w._3 + interm * (x._3 * 0.5))
  }

  /**
   * Parses one `x1;x2;label` line into a [[Sample]], appending the constant
   * bias input as the third vector component.
   *
   * @throws IllegalArgumentException if the line does not have three
   *                                  `;`-separated fields or a field is not numeric
   */
  def parseSample(line: String): Sample = line match {
    case CsvRegex(x1, x2, expected) => ((x1.toDouble, x2.toDouble, BiasInput), expected.toInt)
    case _ => throw new IllegalArgumentException(s"Malformed CSV line: '$line'")
  }

  /**
   * Shuffles `items` once and cuts the result into two disjoint parts.
   *
   * @param trainFraction share of the items that goes into the first part
   * @return `(first, rest)`; together they contain every item exactly once
   */
  def shuffleSplit[A](items: Seq[A], trainFraction: Double): (Seq[A], Seq[A]) =
    scala.util.Random.shuffle(items).splitAt((trainFraction * items.length).toInt)

  /**
   * Runs perceptron training over `samples`, in the given order, until an
   * epoch finishes without a single misclassification or `maxEpochs` epochs
   * have been run. At least one epoch is always executed.
   *
   * @return the final weights and the number of epochs that were run
   */
  def train(samples: Seq[Sample], initial: Vec3, eta: Double, maxEpochs: Int): (Vec3, Int) = {
    @scala.annotation.tailrec
    def loop(w: Vec3, epoch: Int): (Vec3, Int) = {
      val (nextW, errors) = samples.foldLeft((w, 0)) { case ((current, errs), (x, expected)) =>
        val guess = makeGuess(current)(x)
        if (guess == expected) (current, errs)
        else (approxNextW(current)(eta, expected, guess, x), errs + 1)
      }
      if (errors > 0 && epoch < maxEpochs) loop(nextW, epoch + 1) else (nextW, epoch)
    }
    loop(initial, 1)
  }

  /**
   * Reads and parses every non-blank line of the named classpath resource.
   * The underlying source is always closed before returning.
   */
  private def loadSamples(resourceName: String): Vector[Sample] = {
    // getResource returns null for a missing resource; fail with a clear message instead of an NPE.
    val url = Option(getClass.getResource(resourceName))
      .getOrElse(sys.error(s"Resource $resourceName not found on the classpath"))
    val source = scala.io.Source.fromURL(url)
    // Materialise into a Vector before closing, since getLines() is lazy.
    try source.getLines().filter(_.trim.nonEmpty).map(parseSample).toVector
    finally source.close()
  }

  def main(args: Array[String]): Unit = {
    val samples = loadSamples(s"data${Variant}.csv")
    // Split exactly once so the test rows are never seen during training.
    val (learningSet, testSet) = shuffleSplit(samples, TrainFraction)
    val (w, retryCount) = train(learningSet, (0.5d, 0.5d, 0.5d), LearningRate, MaxEpochs)

    println(testSet.forall { case (x, expected) => makeGuess(w)(x) == expected })
    println(retryCount)
    println(w)
  }
}
// here's the sample file (data13.csv): | |
// 0.181;0.985;-1 | |
// 0.153;0.074;-1 | |
// 0.917;0.794;1 | |
// 0.642;0.453;1 | |
// 0.964;0.453;1 | |
// 0.739;0.04;1 | |
// 0.79;0.055;1 | |
// 0.661;0.68;1 | |
// 0.364;0.194;-1 | |
// 0.129;0.716;-1 | |
// 0.155;0.963;-1 | |
// 0.36;0.993;-1 | |
// 0.132;0.958;-1 | |
// 0.074;0.854;-1 | |
// 0.762;0.271;1 | |
// 0.489;0.142;-1 | |
// 0.016;0.496;-1 | |
// 0.324;0.202;-1 | |
// 0.362;0.727;-1 | |
// 0.797;0.877;1 | |
// 0.768;0.557;1 | |
// 0.628;0.512;-1 | |
// 0.703;0.081;1 | |
// 0.035;0.949;-1 | |
// 0.137;0.892;-1 | |
// 0.738;0.532;1 | |
// 0.537;0.084;-1 | |
// 0.953;0.105;1 | |
// 0.873;0.25;1 | |
// 0.401;0.013;-1 | |
// 0.613;0.485;-1 | |
// 0.301;0.098;-1 | |
// 0.726;0.382;1 | |
// 0.37;0.705;-1 | |
// 0.681;0.034;-1 | |
// 0.828;0.677;1 | |
// 0.447;0.037;-1 | |
// 0.393;0.358;-1 | |
// 0.491;0.482;-1 | |
// 0.285;0.592;-1 | |
// 0.784;0.965;1 | |
// 0.221;0.722;-1 | |
// 0.794;0.579;1 | |
// 0.421;0.593;-1 | |
// 0.489;0.917;-1 | |
// 0.604;0.547;-1 | |
// 0.339;0.074;-1 | |
// 0.987;0.429;1 | |
// 0.311;0.642;-1 | |
// 0.291;0.583;-1 | |
// 0.875;0.717;1 | |
// 0.733;0.852;1 | |
// 0;0.891;-1 | |
// 0.033;0.58;-1 | |
// 0.154;0.259;-1 | |
// 0.409;0.611;-1 | |
// 0.307;0.413;-1 | |
// 0.579;0.288;-1 | |
// 0.926;0.449;1 | |
// 0.764;0.455;1 | |
// 0.98;0.959;1 | |
// 0.681;0.456;1 | |
// 0.939;0.037;1 | |
// 0.549;0.136;-1 | |
// 0.771;0.278;1 | |
// 0.087;0.836;-1 | |
// 0.945;0.774;1 | |
// 0.513;0.402;-1 | |
// 0.429;0.121;-1 | |
// 0.547;0.177;-1 | |
// 0.693;0.41;1 | |
// 0.408;0.004;-1 | |
// 0.704;0.237;1 | |
// 0.813;0.228;1 | |
// 0.933;0.694;1 | |
// 0.982;0.692;1 | |
// 0.346;0.376;-1 | |
// 0.633;0.741;1 | |
// 0.415;0.376;-1 | |
// 0.985;0.217;1 | |
// 0.106;0.289;-1 | |
// 0.758;0.464;1 | |
// 0.889;0.207;1 | |
// 0.169;0.327;-1 | |
// 0.733;0.725;1 | |
// 0.937;0.148;1 | |
// 0.106;0.936;-1 | |
// 0.282;0.773;-1 | |
// 0.977;0.435;1 | |
// 0.921;0.544;1 | |
// 0.814;0.314;1 | |
// 0.074;0.305;-1 | |
// 0.588;0.168;-1 | |
// 0.46;0.4;-1 | |
// 0.765;0.09;1 | |
// 0.297;0.673;-1 | |
// 0.316;0.675;-1 | |
// 0.303;0.467;-1 | |
// 0.708;0.015;1 | |
// 0.699;0.626;1 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment