-- Gradient descent implementation in Haskell
--
-- Source: GitHub gist rjwebb/8b6596af80e3a6add4e32941fab04db0
-- (last active December 18, 2018 23:08).
import Control.Monad (liftM) | |
import Data.List (find) | |
-- | All consecutive (overlapping) pairs of items in a list, in order.
--
-- >>> pairs [1, 2, 3]
-- [(1,2),(2,3)]
--
-- Lists with fewer than two items yield no pairs. The list is consumed
-- lazily, so this also works on infinite lists.
pairs :: [a] -> [(a, a)]
-- 'drop 1' is the total counterpart of 'tail': it returns [] on [] instead
-- of being a partial function.
pairs values = zip values (drop 1 values)
-- | Return the first item of a sequence whose absolute difference from the
-- item immediately before it is smaller than @threshold@. Returns 'Nothing'
-- if a finite list is exhausted without finding one (including lists with
-- fewer than two items).
--
-- The list is consumed lazily, so infinite lists are fine as long as such an
-- item eventually appears; otherwise the search never returns.
dropUntilStepSizeLessThan :: (Ord a, Num a) => a -> [a] -> Maybe a
dropUntilStepSizeLessThan threshold infl =
  snd <$> find isSmallStep (pairs infl)
  where
    -- A "step" is the jump between two neighbouring items.
    isSmallStep (previous, current) = abs (previous - current) < threshold
-- | One gradient-descent update. It moves @cur_x@ against the derivative
-- @df@ evaluated at @cur_x@, scaled by the step-size multiplier @gamma@.
next_x :: Num a => (a -> a) -> a -> a -> a
next_x df gamma cur_x = cur_x - step
  where
    step = gamma * df cur_x
-- | Approximate a stationary point of a function @f@ by gradient descent. A
-- stationary point is where the derivative is close to zero; it is usually a
-- local minimum.
--
-- Arguments:
--
-- * @df@: the derivative of @f@.
-- * @gamma@: the step-size multiplier.
-- * @threshold@: the convergence threshold on the size of one step.
-- * @start_x@: the starting value.
--
-- Returns the first iterate that moved less than @threshold@ away from the
-- previous one.
--
-- Note: the iterates form an infinite list, so the result is never
-- 'Nothing'. If the steps never shrink below @threshold@, this does not
-- terminate; this can happen when @gamma@ is too large and the iteration
-- diverges. Use 'grad_desc_bounded' when termination must be guaranteed.
grad_desc :: (Ord a, Num a) => (a -> a) -> a -> a -> a -> Maybe a
grad_desc df gamma threshold start_x = dropUntilStepSizeLessThan threshold values
  where
    values = iterate next start_x
    next = next_x df gamma

-- | Like 'grad_desc', but performs at most @max_steps@ update steps.
--
-- Returns 'Nothing' if no step smaller than @threshold@ occurred within that
-- budget. A non-positive @max_steps@ always yields 'Nothing'.
grad_desc_bounded :: (Ord a, Num a) => Int -> (a -> a) -> a -> a -> a -> Maybe a
grad_desc_bounded max_steps df gamma threshold start_x =
  dropUntilStepSizeLessThan threshold values
  where
    step = next_x df gamma
    -- The start value followed by at most max_steps updated values, so at
    -- most max_steps steps are examined. This avoids computing
    -- max_steps + 1, which could overflow.
    values = start_x : take max_steps (iterate step (step start_x))
-- | Example derivative: d/dx (x^4 - 3*x^3) = 4*x^3 - 9*x^2.
--
-- Its zeros are x = 0 and x = 9/4. Only x = 9/4 is a local minimum of the
-- original function, so gradient descent on 'df' should settle near 2.25.
--
-- Integral powers ('^') need only 'Num' and are computed by exact repeated
-- multiplication, unlike '**', which requires 'Floating'.
df :: Num a => a -> a
df x = 4 * x ^ 3 - 9 * x ^ 2
-- Revision note: replaced `dropUntil` with `find`, which returns a `Maybe`
-- value, so I had to use `liftM snd` instead of `snd` (to apply `snd` to the
-- value inside the monad).