@chowells79
Last active February 16, 2023 17:17

-- Finds the largest Integer x such that 'f x <= target'
--
-- preconditions:
-- the provided function is monotonically non-decreasing
-- there exists an Integer y such that 'f y > target'
-- there exists an Integer z such that 'f z <= target'
--
-- violation of the first precondition may result in incorrect output
-- violation of the second or third precondition results in non-termination
-- (the upward or downward bound search, respectively, never finishes)
unboundedSearch :: Ord a => (Integer -> a) -> a -> Integer
unboundedSearch f target
  | f 0 <= target = up 0 16
  | otherwise = down (-16) 0
  where
    -- search for an upper bound after which to switch to binary search
    --
    -- invariant:
    -- f lower <= target
    up lower upper
      | f upper <= target = up upper (upper * upper)
      | otherwise = binarySearch f target lower upper
    -- search for a lower bound after which to switch to binary search
    --
    -- invariant:
    -- target < f upper
    down lower upper
      | target < f lower = down (negate lower * lower) lower
      | otherwise = binarySearch f target lower upper
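The 'up' and 'down' phases only locate a bracket for 'binarySearch'. Each step squares the magnitude of the current bound (16, 256, 65536, 4294967296, ...), so even very large answers are bracketed after a few steps. The sketch below isolates that bracketing phase so the bounds can be inspected directly. 'bracketFor' is a hypothetical helper that is not part of the gist. It mirrors 'up' and 'down' but returns the '(lower, upper)' pair instead of calling 'binarySearch', and it assumes the same preconditions as 'unboundedSearch'.

```haskell
-- Hypothetical helper, not part of the original gist: it reproduces the
-- 'up' / 'down' bracketing phase of 'unboundedSearch' but returns the
-- bracket instead of calling 'binarySearch'. It assumes the same
-- preconditions as 'unboundedSearch' (f is monotonically non-decreasing,
-- and f crosses the target somewhere).
bracketFor :: Ord a => (Integer -> a) -> a -> (Integer, Integer)
bracketFor f target
  | f 0 <= target = up 0 16
  | otherwise     = down (-16) 0
  where
    -- invariant: f lower <= target
    -- square the upper bound (16, 256, 65536, ...) until f upper > target
    up lower upper
      | f upper <= target = up upper (upper * upper)
      | otherwise         = (lower, upper)
    -- invariant: target < f upper
    -- square the magnitude of the lower bound, keeping it negative
    -- (-16, -256, -65536, ...), until f lower <= target
    down lower upper
      | target < f lower = down (negate lower * lower) lower
      | otherwise        = (lower, upper)
```

Loaded in GHCi with the monotone signed square '\x -> x * abs x' and a target of 1000, 'bracketFor' stops at '(16, 256)', because 16 * 16 = 256 <= 1000 but 256 * 256 > 1000. 'bracketFor id (-1000)' descends through -16 and -256 and stops at '(-65536, -256)'. Each new bound is the square of the previous one, so the bracket passed to 'binarySearch' is at most roughly quadratically larger than the answer. The binary search phase therefore still takes a number of steps logarithmic in the answer's magnitude.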
-- Finds the largest Integer x between lower (inclusive) and upper
-- (exclusive) such that 'f x <= target'
--
-- preconditions:
-- the provided function is monotonically non-decreasing
-- f lower <= target
-- target < f upper
--
-- violation of the preconditions may result in incorrect output or non-termination
binarySearch :: Ord a => (Integer -> a) -> a -> Integer -> Integer -> Integer
binarySearch f target = go
  where
    -- invariants:
    -- f lower <= target
    -- target < f upper
    go lower upper
      | lower == midpoint = lower
      | f midpoint <= target = go midpoint upper
      | otherwise = go lower midpoint
      where
        midpoint = halfLower + halfUpper + parity
          where
            -- use div instead of quot so that the midpoint always rounds towards lower
            (halfLower, parityLower) = lower `divMod` 2
            (halfUpper, parityUpper) = upper `divMod` 2
            parity = (parityLower + parityUpper) `div` 2
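'go' computes 'midpoint' by halving each endpoint with 'divMod' and adding back the carry from the two remainders. 'Integer' is arbitrary-precision, so 'lower + upper' cannot overflow, and this formula appears to equal a plain floor division of 'lower + upper' by 2. What matters for correctness is that it floors rather than truncates. The sketch below compares three midpoint formulas. 'midpointSplit', 'midpointDirect', and 'midpointQuot' are hypothetical names, not part of the gist.

```haskell
-- Hypothetical names, not part of the original gist.

-- The gist's formula: halve each endpoint with divMod (floor semantics),
-- then add 1 only when both endpoints were odd.
midpointSplit :: Integer -> Integer -> Integer
midpointSplit lower upper = halfLower + halfUpper + parity
  where
    (halfLower, parityLower) = lower `divMod` 2
    (halfUpper, parityUpper) = upper `divMod` 2
    parity = (parityLower + parityUpper) `div` 2

-- Integer cannot overflow, so a direct floor division gives the same value.
midpointDirect :: Integer -> Integer -> Integer
midpointDirect lower upper = (lower + upper) `div` 2

-- quot truncates toward zero, so for a negative odd sum it rounds
-- toward upper instead of toward lower.
midpointQuot :: Integer -> Integer -> Integer
midpointQuot lower upper = (lower + upper) `quot` 2
```

Consider the bracket (-3, -2), which satisfies the invariants whenever f (-3) <= target < f (-2). Flooring gives -3, so 'lower == midpoint' fires and 'go' returns. Truncating with 'quot' would give -2, which equals 'upper'. Because target < f upper, 'go' would then call 'go (-3) (-2)' again and never terminate.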