Skip to content

Instantly share code, notes, and snippets.

@tkroman
Created September 6, 2014 13:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save tkroman/d1d1697f2e3dcd2af5ce to your computer and use it in GitHub Desktop.
Save tkroman/d1d1697f2e3dcd2af5ce to your computer and use it in GitHub Desktop.
import java.util
import java.util.Collections
import scala.util.Random
object Perceptron {
  /** Selects the classpath resource `data<Variant>.csv` read when no path argument is given. */
  private val Variant = 13

  /** One CSV record: `x1;x2;label`. */
  private val CsvRegex = """(.+)\;(.+)\;(.+)""".r

  /** Learning rate passed to every weight update. */
  private val Eta = 0.5

  /** Upper bound on training epochs, so non-separable data still terminates. */
  private val MaxEpochs = 5000

  /** A 3-component vector: either weights (w1, w2, w3) or an input (x1, x2, -1). */
  private type Vec3 = (Double, Double, Double)

  /**
   * Trains a perceptron on 80% of a shuffled `x1;x2;label` CSV and checks it on the rest.
   *
   * Prints, one per line:
   *  1. whether every held-out example is classified correctly,
   *  2. the number of training epochs that were run,
   *  3. the final weight vector.
   *
   * @param args optional; when `args(0)` is present it is read as the path of the CSV file,
   *             otherwise the classpath resource `data13.csv` is used.
   */
  def main(args: Array[String]): Unit = {
    val examples = Random.shuffle(loadLines(args).map(parse))
    // Split a single shuffle so held-out examples are never seen during training;
    // the test set receives every row not used for training.
    val (learningSet, testSet) = examples.splitAt((0.8 * examples.length).toInt)
    val (w, retryCount) = train((0.5, 0.5, 0.5), learningSet, 0)
    println(testSet.forall { case (x, expected) => makeGuess(w)(x) == expected })
    println(retryCount)
    println(w)
  }

  /** Reads the trimmed, non-blank lines of the input CSV and closes the underlying source. */
  private def loadLines(args: Array[String]): Vector[String] = {
    val source = args.headOption match {
      case Some(path) => io.Source.fromFile(path)
      case None =>
        val name = s"data${Variant}.csv"
        val url = Option(getClass.getResource(name))
          .getOrElse(throw new IllegalStateException(s"Classpath resource $name not found"))
        // fromURL works for file: and jar: URLs and does not choke on URL-encoded paths.
        io.Source.fromURL(url)
    }
    try source.getLines().map(_.trim).filter(_.nonEmpty).toVector
    finally source.close()
  }

  /** Parses `x1;x2;label` into the input vector (x1, x2, -1) and its integer label. */
  private def parse(line: String): (Vec3, Int) = line match {
    case CsvRegex(x1, x2, label) => ((x1.toDouble, x2.toDouble, -1.0), label.trim.toInt)
    case _ => throw new IllegalArgumentException(s"Malformed CSV line: '$line'")
  }

  /** Sign (-1, 0 or 1) of the dot product of the weights `w` and the input `x`. */
  private def makeGuess(w: Vec3)(x: Vec3): Int =
    math.signum(w._1 * x._1 + w._2 * x._2 + w._3 * x._3).toInt

  /** One perceptron update: moves `w` along `x`, scaled by `eta * (d - y) / 2`. */
  private def approxNextW(w: Vec3)(eta: Double, d: Double, y: Double, x: Vec3): Vec3 = {
    // (d - y) is +-2 when the guess has the wrong sign and +-1 when it is 0;
    // the 0.5 factor halves it.
    val step = eta * (d - y) * 0.5
    (w._1 + step * x._1, w._2 + step * x._2, w._3 + step * x._3)
  }

  /**
   * Runs one pass over `examples` in a fresh random order.
   *
   * @return the updated weights and the number of misclassified examples in this pass
   */
  private def runEpoch(w0: Vec3, examples: Vector[(Vec3, Int)]): (Vec3, Int) =
    Random.shuffle(examples).foldLeft((w0, 0)) { case ((w, errors), (x, expected)) =>
      val guess = makeGuess(w)(x)
      if (guess == expected) (w, errors)
      else (approxNextW(w)(Eta, expected, guess, x), errors + 1)
    }

  /**
   * Repeats epochs until one makes no mistakes or `MaxEpochs` epochs have run.
   *
   * @return the final weights and the total number of epochs run (at least 1)
   */
  @scala.annotation.tailrec
  private def train(w: Vec3, examples: Vector[(Vec3, Int)], epochsDone: Int): (Vec3, Int) = {
    val (next, errors) = runEpoch(w, examples)
    val epochs = epochsDone + 1
    if (errors > 0 && epochs < MaxEpochs) train(next, examples, epochs) else (next, epochs)
  }
}
// Sample input file (data13.csv). Each line is `x1;x2;label`, where label is -1 or 1:
// 0.181;0.985;-1
// 0.153;0.074;-1
// 0.917;0.794;1
// 0.642;0.453;1
// 0.964;0.453;1
// 0.739;0.04;1
// 0.79;0.055;1
// 0.661;0.68;1
// 0.364;0.194;-1
// 0.129;0.716;-1
// 0.155;0.963;-1
// 0.36;0.993;-1
// 0.132;0.958;-1
// 0.074;0.854;-1
// 0.762;0.271;1
// 0.489;0.142;-1
// 0.016;0.496;-1
// 0.324;0.202;-1
// 0.362;0.727;-1
// 0.797;0.877;1
// 0.768;0.557;1
// 0.628;0.512;-1
// 0.703;0.081;1
// 0.035;0.949;-1
// 0.137;0.892;-1
// 0.738;0.532;1
// 0.537;0.084;-1
// 0.953;0.105;1
// 0.873;0.25;1
// 0.401;0.013;-1
// 0.613;0.485;-1
// 0.301;0.098;-1
// 0.726;0.382;1
// 0.37;0.705;-1
// 0.681;0.034;-1
// 0.828;0.677;1
// 0.447;0.037;-1
// 0.393;0.358;-1
// 0.491;0.482;-1
// 0.285;0.592;-1
// 0.784;0.965;1
// 0.221;0.722;-1
// 0.794;0.579;1
// 0.421;0.593;-1
// 0.489;0.917;-1
// 0.604;0.547;-1
// 0.339;0.074;-1
// 0.987;0.429;1
// 0.311;0.642;-1
// 0.291;0.583;-1
// 0.875;0.717;1
// 0.733;0.852;1
// 0;0.891;-1
// 0.033;0.58;-1
// 0.154;0.259;-1
// 0.409;0.611;-1
// 0.307;0.413;-1
// 0.579;0.288;-1
// 0.926;0.449;1
// 0.764;0.455;1
// 0.98;0.959;1
// 0.681;0.456;1
// 0.939;0.037;1
// 0.549;0.136;-1
// 0.771;0.278;1
// 0.087;0.836;-1
// 0.945;0.774;1
// 0.513;0.402;-1
// 0.429;0.121;-1
// 0.547;0.177;-1
// 0.693;0.41;1
// 0.408;0.004;-1
// 0.704;0.237;1
// 0.813;0.228;1
// 0.933;0.694;1
// 0.982;0.692;1
// 0.346;0.376;-1
// 0.633;0.741;1
// 0.415;0.376;-1
// 0.985;0.217;1
// 0.106;0.289;-1
// 0.758;0.464;1
// 0.889;0.207;1
// 0.169;0.327;-1
// 0.733;0.725;1
// 0.937;0.148;1
// 0.106;0.936;-1
// 0.282;0.773;-1
// 0.977;0.435;1
// 0.921;0.544;1
// 0.814;0.314;1
// 0.074;0.305;-1
// 0.588;0.168;-1
// 0.46;0.4;-1
// 0.765;0.09;1
// 0.297;0.673;-1
// 0.316;0.675;-1
// 0.303;0.467;-1
// 0.708;0.015;1
// 0.699;0.626;1
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment