Created
May 8, 2016 20:09
-
-
Save gvergnaud/e171cf8d00609d759f9f5ecd30c49f57 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/* ----------------------------------------- *
   Learning Machine Learning
 * ----------------------------------------- */
// Gradient descent on the squared-error cost function, for a parameter
// vector Theta of any length.
// multiplyVectors :: (List Number, List Number) -> Number
// Dot product: the sum of xs[i] * ys[i] over every index of xs.
// If ys is shorter than xs the missing entries are undefined, giving NaN.
const multiplyVectors = (xs, ys) => {
  let total = 0
  xs.forEach((value, index) => {
    total += value * ys[index]
  })
  return total
}
// sum :: (a -> Number, List a) -> Number
// Adds up transformer(item) for every item of list; an empty list sums to 0.
const sum = (transformer, list) => {
  let total = 0
  list.forEach(item => {
    total += transformer(item)
  })
  return total
}
// getNextTheta :: (List Number, List Data) -> Theta -> Theta
// Performs one batch gradient-descent step on the squared-error cost.
// Each Data item is a pair [xs, y]: a feature vector xs and its target y.
// step[i] is the learning rate applied to theta[i]. Every component is
// updated simultaneously from the *current* theta:
//   theta[i] - step[i] / m * sum_j (xs_j · theta - y_j) * xs_j[i]
// where m is the number of examples.
const getNextTheta = (step, dataSet) => theta => {
  // The residual (xs · theta - y) of an example does not depend on which
  // parameter is being updated, so compute it once per step rather than once
  // per parameter (O(p·m) instead of O(p²·m)).
  const scored = dataSet.map(([xs, y]) => [xs, multiplyVectors(xs, theta) - y])
  return theta.map((currentTheta, i) =>
    currentTheta - step[i] / dataSet.length * sum(([xs, residual]) => residual * xs[i], scored)
  )
}
// isConverging :: (Theta, Theta, Number?) -> Bool
// True when every component of xs is within `tolerance` (absolute difference,
// default 0.001) of the component at the same index in ys.
// A NaN difference is never "within tolerance", so NaN never converges.
const isConverging = (xs, ys, tolerance = 0.001) =>
  xs.every((x, i) => Math.abs(x - ys[i]) < tolerance)
// repeatUntilConverge :: (Theta -> Theta, Theta, Number?) -> Theta
// Applies f repeatedly, starting from theta, until two successive values
// satisfy isConverging. Returns the last value produced by f.
// Written as a loop rather than recursion: gradient descent can need tens of
// thousands of steps, and JavaScript engines do not eliminate tail calls, so
// recursing would exhaust the call stack.
// Throws an Error if an iterate contains a non-finite number (the process has
// diverged, e.g. a step is too large). Such an iterate can never satisfy
// isConverging, so without this check the loop would never end.
// Throws a RangeError if no convergence happens within maxIterations
// (unbounded by default).
const repeatUntilConverge = (f, theta, maxIterations = Infinity) => {
  let current = theta
  for (let iteration = 1; iteration <= maxIterations; iteration++) {
    const next = f(current)
    if (!next.every(value => Number.isFinite(value))) {
      throw new Error(`repeatUntilConverge: diverged after ${iteration} iteration(s)`)
    }
    if (isConverging(next, current)) {
      return next
    }
    current = next
  }
  throw new RangeError(`repeatUntilConverge: no convergence within ${maxIterations} iterations`)
}
// gradientDescent :: { step :: List Number, dataSet :: List Data } -> Theta
// Fits theta by batch gradient descent and returns the converged theta.
// Starts from the all-zero vector, with one zero per feature of the first
// example.
// Rejects, with a descriptive error, inputs that would otherwise fail
// obscurely:
//  - an empty dataSet: there is no feature vector to size theta from;
//  - fewer learning rates than features: step[i] would be undefined, so the
//    update would be NaN.
const gradientDescent = ({ step, dataSet }) => {
  if (!Array.isArray(dataSet) || dataSet.length === 0) {
    throw new Error('gradientDescent: dataSet must be a non-empty list of [features, value] pairs')
  }
  const firstFeatures = dataSet[0][0]
  if (step == null || step.length < firstFeatures.length) {
    const given = step == null ? 'none' : step.length
    throw new Error(`gradientDescent: expected ${firstFeatures.length} learning rates in step, got ${given}`)
  }
  const initialTheta = firstFeatures.map(() => 0)
  return repeatUntilConverge(getNextTheta(step, dataSet), initialTheta)
}
// linearRegression :: (List Data, List Number) -> (Data -> Value)
// Trains theta with gradientDescent on dataSet, using one learning rate per
// feature from step. Returns the hypothesis h(data) = data · theta.
function linearRegression(dataSet, step) {
  const fittedTheta = gradientDescent({ step, dataSet })
  const predict = features => multiplyVectors(features, fittedTheta)
  return predict
}
/* ----------------------------------------- *
   Example usage: estimating house prices
 * ----------------------------------------- */
// Training examples: house features plus the observed price that the
// regression learns to predict.
const trainingSet = [
  { size: 2104, bedroomsCount: 3, floorsCount: 1, age: 8, price: 400 },
  { size: 1416, bedroomsCount: 2, floorsCount: 1, age: 15, price: 232 },
  { size: 1534, bedroomsCount: 3, floorsCount: 1, age: 2, price: 315 },
  { size: 2500, bedroomsCount: 4, floorsCount: 2, age: 7, price: 460 },
]
// formatData :: House -> Data
// Builds the feature vector [1, size, bedroomsCount, floorsCount, age].
// The leading constant 1 is the intercept term, weighted by theta[0].
const formatData = house => {
  const { size, bedroomsCount, floorsCount, age } = house
  return [1, size, bedroomsCount, floorsCount, age]
}

// formatTrainingSet :: HouseWithPrice -> [Data, Number]
// Splits a training record into its feature vector and its target price.
const formatTrainingSet = record => {
  const { price, ...features } = record
  return [formatData(features), price]
}
// Train on the example houses. One learning rate per feature, in the order
// produced by formatData: [bias, size, bedroomsCount, floorsCount, age].
// Sizes are in the thousands while the other features are single digits,
// so size gets a far smaller rate to keep its updates from overshooting.
const h = linearRegression(trainingSet.map(formatTrainingSet), [0.01, 0.0000001, 0.01, 0.01, 0.01])

const housesToEstimate = [
  { size: 1503, bedroomsCount: 1, floorsCount: 1, age: 4 },
  { size: 2000, bedroomsCount: 2, floorsCount: 2, age: 2 },
  { size: 1000, bedroomsCount: 3, floorsCount: 1, age: 18 },
]

// Derive the printed size from each house itself. The original third
// message hard-coded "size: 1503" for a house of size 1000.
for (const house of housesToEstimate) {
  console.log(`size: ${house.size}, estimated price ${h(formatData(house))}`)
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment