Skip to content

Instantly share code, notes, and snippets.

@jamesthompson
Created September 19, 2012 21:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jamesthompson/3752532 to your computer and use it in GitHub Desktop.
Save jamesthompson/3752532 to your computer and use it in GitHub Desktop.
Linear Regression in Scala
/** Ordinary least-squares fit of a straight line `y = m x + c` to a set of (x, y) points. */
object LinearRegression {

  /** Result of a least-squares line fit.
    *
    * @param gradient        fitted slope `m`
    * @param intercept       fitted y-axis intercept `c`
    * @param rSquared        coefficient of determination, SSR / SST (NaN when all y are equal)
    * @param stdErrGradient  standard error of `m` (NaN when only two points were supplied)
    * @param stdErrIntercept standard error of `c` (NaN when only two points were supplied)
    */
  final case class Fit(
      gradient: Double,
      intercept: Double,
      rSquared: Double,
      stdErrGradient: Double,
      stdErrIntercept: Double
  )

  private def sqr(v: Double): Double = v * v

  /** Fits `y = m x + c` to `data` by ordinary least squares and returns the fit statistics.
    *
    * @param data (x, y) observations
    * @return the fitted gradient, intercept, R² and standard errors
    * @throws IllegalArgumentException if fewer than two points are given, or if the x values
    *                                  are not spread out (their centred sum of squares is not
    *                                  positive), in which case the gradient is undefined
    */
  def regress(data: List[(Double, Double)]): Fit = {
    // List.length is O(n); compute it once.
    val n = data.length
    require(n >= 2, s"at least 2 points are needed to fit a line, got $n")

    val xs = data.map(_._1)
    val ys = data.map(_._2)
    val xBar = xs.sum / n
    val yBar = ys.sum / n

    // Centred sums of squares and cross-products (Sxx, Syy, Sxy).
    val sxx = xs.map(x => sqr(x - xBar)).sum
    val syy = ys.map(y => sqr(y - yBar)).sum
    val sxy = data.map { case (x, y) => (x - xBar) * (y - yBar) }.sum
    require(sxx > 0.0, "x values must not all be identical; the gradient is undefined")

    val m = sxy / sxx
    val c = yBar - m * xBar

    val fitted = xs.map(x => m * x + c)
    val rss = fitted.zip(ys).map { case (f, y) => sqr(f - y) }.sum
    val ssr = fitted.map(f => sqr(f - yBar)).sum

    // The residual variance has n - 2 degrees of freedom. It is undefined for two points,
    // so report NaN rather than dividing by zero.
    val residualVariance = if (n > 2) rss / (n - 2) else Double.NaN
    val varGradient = residualVariance / sxx
    val varIntercept = residualVariance / n + sqr(xBar) * varGradient

    Fit(m, c, ssr / syy, math.sqrt(varGradient), math.sqrt(varIntercept))
  }

  /** Fits `y = m x + c` to `data` and prints the equation, R² and the standard errors of
    * `m` and `c` to standard output.
    *
    * @param data (x, y) observations
    * @throws IllegalArgumentException under the same conditions as [[regress]]
    */
  def fit(data: List[(Double, Double)]): Unit = {
    val r = regress(data)
    println(s"y = ${r.gradient} x + ${r.intercept}")
    println(s"R^2 Coefficient of Determination = ${r.rSquared}")
    println(s"Std. error of m = ${r.stdErrGradient}")
    println(s"Std. error of c = ${r.stdErrIntercept}")
  }
}
@jamesthompson
Copy link
Author

In the printed equation y = m x + c:
m = gradient (slope) of the fitted line
c = y-axis intercept of the fitted line

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment