@rjwebb
Last active December 18, 2018 23:08
Gradient descent implementation in Haskell
import Control.Monad (liftM)
import Data.List (find)
-- Given a list, return the consecutive pairs of items, e.g. [1,2,3] -> [(1,2),(2,3)]
pairs :: [a] -> [(a, a)]
pairs values = zip values (tail values)
-- Given a list, return the first item that differs from the item before it
-- by less than `threshold`, or Nothing if no such item exists
dropUntilStepSizeLessThan :: (Ord a, Num a) => a -> [a] -> Maybe a
dropUntilStepSizeLessThan threshold infl = result
  where
    result = liftM snd resultPair
    resultPair = find stepSizeLessThan (pairs infl)
    stepSizeLessThan (a,b) = abs (a - b) < threshold
-- Return the next iterate of gradient descent: cur_x - gamma * f'(cur_x)
next_x :: Num a => (a -> a) -> a -> a -> a
next_x df gamma cur_x = cur_x - gamma * (df cur_x)
-- Given the derivative of a function (df), the step size multiplier (gamma),
-- the step size threshold (threshold) and a starting value (start_x),
-- approximate a local minimum of f, i.e. a point where df is 0,
-- using gradient descent
grad_desc :: (Ord a, Num a) => (a -> a) -> a -> a -> a -> Maybe a
grad_desc df gamma threshold start_x = dropUntilStepSizeLessThan threshold values
  where
    values = iterate next start_x
    next = next_x df gamma
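-- Example derivative: f'(x) for f(x) = x^4 - 3x^3 (plus any constant)
df :: Double -> Double
df x = 4 * (x**3) - 9 * (x**2)

-- Example run; gamma, threshold and starting point are illustrative choices
main :: IO ()
main = print (grad_desc df 0.01 0.00001 6)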
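Because `iterate` produces an infinite list, `find` inside `grad_desc` either returns `Just` or keeps searching forever; it never actually returns `Nothing`, for instance when `gamma` is too large and the iterates diverge. The following is a minimal sketch of a hypothetical bounded variant, assuming an extra iteration cap `maxIter` that is not in the original: it inspects only the first `maxIter` iterates of x_{n+1} = x_n - gamma * f'(x_n), so non-convergence shows up as `Nothing`. The names `gradDescBounded` and `maxIter` are illustrative only. Since 4x^3 - 9x^2 = x^2(4x - 9), the example derivative vanishes at x = 9/4 = 2.25, which is where the iteration should settle.

```haskell
import Data.List (find)

-- Consecutive pairs of a list, as in the gist: [a,b,c] -> [(a,b),(b,c)]
pairs :: [a] -> [(a, a)]
pairs values = zip values (tail values)

-- One gradient descent step: x_{n+1} = x_n - gamma * f'(x_n)
next_x :: Num a => (a -> a) -> a -> a -> a
next_x deriv gamma cur_x = cur_x - gamma * deriv cur_x

-- Hypothetical bounded variant (not part of the original gist): only the
-- first maxIter iterates are examined, so a run that does not converge
-- within the cap yields Nothing instead of searching an infinite list.
gradDescBounded :: (Ord a, Num a) => (a -> a) -> a -> a -> Int -> a -> Maybe a
gradDescBounded deriv gamma threshold maxIter start_x =
  fmap snd (find closeEnough (pairs iterates))
  where
    iterates = take maxIter (iterate (next_x deriv gamma) start_x)
    closeEnough (a, b) = abs (a - b) < threshold

-- Example derivative from the gist: f'(x) = 4x^3 - 9x^2
df :: Double -> Double
df x = 4 * x ^ 3 - 9 * x ^ 2
```

In GHCi, `gradDescBounded df 0.01 0.00001 10000 6` should give a `Just` value close to 2.25, while a cap that is too small, such as `gradDescBounded df 0.01 0.00001 2 6`, gives `Nothing`. Capping the list with `take` keeps the original lazy `iterate`/`find` structure and only changes when the search gives up.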
rjwebb commented Dec 18, 2018

Replaced `dropUntil` with `find`. Since `find` returns a `Maybe` value, I had to use `liftM snd` instead of plain `snd` to apply it to the value inside the monad.
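For `Maybe`, `liftM` behaves exactly like `fmap` (and its operator form `<$>`), so the second component of the found pair can be extracted in any of these ways. A small sketch; the helper names `firstClosePair`, `viaLiftM`, `viaFmap` and `viaOperator` are hypothetical and not taken from the gist:

```haskell
import Control.Monad (liftM)
import Data.List (find)

-- Hypothetical helper: the first consecutive pair whose difference is
-- below the threshold, or Nothing if the (finite) list has no such pair.
firstClosePair :: (Ord a, Num a) => a -> [a] -> Maybe (a, a)
firstClosePair threshold xs = find close (zip xs (drop 1 xs))
  where
    close (a, b) = abs (a - b) < threshold

-- Three equivalent ways to apply snd to the value inside the Maybe.
viaLiftM, viaFmap, viaOperator :: (Ord a, Num a) => a -> [a] -> Maybe a
viaLiftM    t xs = liftM snd (firstClosePair t xs)
viaFmap     t xs = fmap snd (firstClosePair t xs)
viaOperator t xs = snd <$> firstClosePair t xs
```

For example, `viaFmap 2 [100, 50, 40, 38, 37]` gives `Just 37`, because (38, 37) is the first pair whose difference is below 2, and `viaFmap 1 [100, 50, 40, 38, 37]` gives `Nothing`.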
