Last active
August 12, 2021 06:09
-
-
Save thomasnield/6f7d1cb8ab0faa839299cfdbbb70d860 to your computer and use it in GitHub Desktop.
simple_logistic_regression
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import org.apache.commons.math3.distribution.NormalDistribution
import org.nield.kotlinstatistics.randomFirst
import kotlin.math.exp
import kotlin.math.ln
import kotlin.math.ln1p
// See graph | |
// https://www.desmos.com/calculator/6cb10atg3l | |
// Helpful Resources: | |
// StatQuest on YouTube: https://www.youtube.com/watch?v=yIYKR4sgzI8&list=PLblh5JKOoLUKxzEP5HA2d-Li7IJkHfXSe
// Brandon Foltz on YouTube: https://www.youtube.com/playlist?list=PLIeGtxpvyG-JmBQ9XoFD4rs-b3hkcX7Uu | |
/**
 * One labelled data point for the logistic regression.
 *
 * @property independent the predictor value `x`.
 * @property dependent the observed binary outcome; `true` is the outcome whose
 *   probability the logistic model estimates.
 */
data class Observation(
    val independent: Double,
    val dependent: Boolean
)
/**
 * Training data for the regression.
 *
 * Each pair is (predictor value `x`, observed outcome). The pairs are converted
 * eagerly, in order, into a read-only list of [Observation]s.
 */
val allObservations: List<Observation> = listOf(
    1.0 to false,
    1.5 to false,
    2.1 to false,
    2.4 to false,
    2.5 to true,
    3.1 to false,
    4.2 to false,
    4.4 to true,
    4.6 to true,
    4.9 to false,
    5.2 to true,
    5.6 to false,
    6.1 to true,
    6.4 to true,
    6.6 to true,
    7.0 to false,
    7.6 to true,
    7.8 to true,
    8.4 to true,
    8.8 to true,
    9.2 to true
).map { pair -> Observation(pair.first, pair.second) }
/**
 * Numerically stable `ln(1 + e^z)` (the softplus function).
 *
 * Branching on the sign of [z] keeps the argument of `exp` non-positive, so it
 * cannot overflow. `ln1p` preserves precision when that exponential is tiny.
 */
private fun softplus(z: Double): Double =
    if (z > 0) z + ln1p(exp(-z)) else ln1p(exp(z))

/**
 * Log-likelihood of [observations] under the logistic model
 * `p(x) = 1 / (1 + e^-(b0 + b1*x))`.
 *
 * With `z = b0 + b1*x`, this uses `ln p = -softplus(-z)` and
 * `ln(1 - p) = -softplus(z)`. These stay finite even when `p` would round to
 * exactly 0.0 or 1.0, where the naive `ln(p)` / `ln(1 - p)` gives -Infinity.
 */
private fun logLikelihood(b0: Double, b1: Double, observations: List<Observation>): Double =
    observations.fold(0.0) { total, obs ->
        val z = b0 + b1 * obs.independent
        total + if (obs.dependent) -softplus(-z) else -softplus(z)
    }

/**
 * Fits the one-variable logistic model `p(x) = 1 / (1 + e^-(b0 + b1*x))` to
 * [allObservations] by maximizing the log-likelihood with simple hill climbing.
 *
 * Each iteration does the following:
 * - picks one of the two coefficients at random;
 * - nudges it by a standard-normal sample;
 * - keeps the nudge only if the log-likelihood strictly improves, and restores
 *   the previous value otherwise.
 *
 * Finally it prints the fitted curve as an expression in `x`, followed by the
 * best log-likelihood found.
 */
fun main() {
    /*allObservations.forEach {
        println("${it.independent},${if (it.dependent) 1 else 0}")
    }*/

    val iterations = 10_000
    val normalDistribution = NormalDistribution(0.0, 1.0)

    // betas[0] is the intercept b0, betas[1] is the slope b1
    val betas = doubleArrayOf(.01, .01)

    // Start from the likelihood of the initial coefficients, so a random step
    // is only accepted when it actually improves on the starting point.
    var bestLikelihood = logLikelihood(betas[0], betas[1], allObservations)

    repeat(iterations) {
        val selectedBeta = (0..1).asSequence().randomFirst()
        val previous = betas[selectedBeta]

        // make a random adjustment to one of the two coefficients
        betas[selectedBeta] += normalDistribution.sample()

        val likelihood = logLikelihood(betas[0], betas[1], allObservations)
        if (likelihood > bestLikelihood) {
            bestLikelihood = likelihood
        } else {
            // Revert by restoring the saved value. Subtracting the adjustment
            // back out could leave floating-point residue.
            betas[selectedBeta] = previous
        }
    }

    val (b0, b1) = betas
    println("1.0 / (1 + exp(-($b0 + $b1*x)))")
    println("BEST LIKELIHOOD: $bestLikelihood")
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment