@siraben
Last active November 30, 2020 01:09
tree tilt
#! /usr/bin/env nix-shell
#! nix-shell --pure -i ghci -p "haskellPackages.ghcWithPackages (h: [ h.QuickCheck ])"
-- https://leetcode.com/problems/binary-tree-tilt/
import Test.QuickCheck
data Tree = Node Int Tree Tree | Leaf deriving (Show)
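instance Arbitrary Tree where
  -- Halve the size budget for each subtree so that generation terminates
  arbitrary = do
    s <- getSize
    if s <= 0
      then pure Leaf
      else oneof [ pure Leaf
                 , Node <$> arbitrary <*> resize (s `div` 2) arbitrary <*> resize (s `div` 2) arbitrary ]
  -- Shrink to a leaf, to either subtree, or by shrinking the value and subtrees
  shrink Leaf = []
  shrink (Node x l r) = [Leaf] ++ [l, r] ++ [Node x' l' r' | (x', l', r') <- shrink (x, l, r)]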
-- Run `quickCheck derivedCorrect` to test the derived solution against the naive one
derivedCorrect :: Tree -> Bool
derivedCorrect t = solve t == naiveSolve t
-- Sum of all node values in a tree
treeSum :: Tree -> Int
treeSum Leaf = 0
treeSum (Node i t2 t3) = i + treeSum t2 + treeSum t3
-- Replace each node's value with its tilt: |sum of left subtree - sum of right subtree|
tiltTree :: Tree -> Tree
tiltTree Leaf = Leaf
tiltTree (Node _ l r) = Node (abs (treeSum l - treeSum r)) (tiltTree l) (tiltTree r)
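-- Naive solution: sum of all tilts. Subtree sums are recomputed at every node,
-- so this takes O(n^2) time on a list-shaped tree
naiveSolve :: Tree -> Int
naiveSolve = treeSum . tiltTree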
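-- Derived solution by tupling: solve' t returns (sum of t, total tilt of t)
-- in a single traversal, so it runs in O(n) time
solve :: Tree -> Int
solve = snd . solve'
solve' :: Tree -> (Int, Int)
solve' Leaf = (0, 0)
solve' (Node n l r) = (n + l_sum + r_sum, abs (l_sum - r_sum) + l_sol + r_sol)
  where
    (l_sum, l_sol) = solve' l
    (r_sum, r_sol) = solve' r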
-- Test cases
ex1 = Node 1 (Node 2 Leaf Leaf) (Node 3 Leaf Leaf) -- should be 1
ex2 = Node 4 (Node 2 (Node 3 Leaf Leaf) (Node 5 Leaf Leaf)) (Node 9 Leaf (Node 7 Leaf Leaf)) -- should be 15
ex3 = Node 21 (Node 7 (Node 1 (Node 3 Leaf Leaf) (Node 3 Leaf Leaf)) (Node 1 Leaf Leaf)) (Node 14 (Node 2 Leaf Leaf) (Node 2 Leaf Leaf)) -- should be 9
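-- Hypothetical helper (not part of the original gist): checks both solutions
-- against the expected values noted in the test-case comments above.
-- Evaluate `examplesCorrect` in GHCi; it should be True.
examplesCorrect :: Bool
examplesCorrect =
  map solve examples == expected && map naiveSolve examples == expected
  where
    examples = [ex1, ex2, ex3]
    expected = [1, 15, 9]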
{-
Spec:
  solve' t = (treeSum t, naiveSolve t)
Base case:
  solve' Leaf
= { spec of solve'; unfold naiveSolve }
  (treeSum Leaf, treeSum (tiltTree Leaf))
= { unfold treeSum, tiltTree }
  (0, treeSum Leaf)
= { unfold treeSum }
  (0, 0)
Inductive case (assuming the spec holds for l and r):
  solve' (Node n l r)
= { spec of solve'; unfold naiveSolve, tiltTree }
  (treeSum (Node n l r), treeSum (Node (abs (treeSum l - treeSum r)) (tiltTree l) (tiltTree r)))
= { unfold treeSum }
  (n + treeSum l + treeSum r, treeSum (Node (abs (treeSum l - treeSum r)) (tiltTree l) (tiltTree r)))
= { unfold treeSum }
  (n + treeSum l + treeSum r, abs (treeSum l - treeSum r) + treeSum (tiltTree l) + treeSum (tiltTree r))
= { fold naiveSolve }
  (n + treeSum l + treeSum r, abs (treeSum l - treeSum r) + naiveSolve l + naiveSolve r)
= { induction hypothesis: solve' l = (treeSum l, naiveSolve l), likewise for r; introduce let }
  let (l_sum, l_sol) = solve' l
      (r_sum, r_sol) = solve' r
  in (n + l_sum + r_sum, abs (l_sum - r_sum) + l_sol + r_sol)
QED
Now define:
  solve t = snd (solve' t)
Correctness of solve:
  solve t = snd (solve' t) = snd (treeSum t, naiveSolve t) = naiveSolve t
-}
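-- The derivation above rests on the spec solve' t = (treeSum t, naiveSolve t),
-- which is stronger than `derivedCorrect` (that property only compares the
-- second component). A hypothetical QuickCheck property for the full spec,
-- not part of the original gist; run `quickCheck specHolds` in GHCi.
specHolds :: Tree -> Bool
specHolds t = solve' t == (treeSum t, naiveSolve t)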