Skip to content

Instantly share code, notes, and snippets.

@ammkrn
Last active January 31, 2022 13:10
Show Gist options
  • Save ammkrn/d7212abd94dfb86308647c888ca23ac0 to your computer and use it in GitHub Desktop.
Save ammkrn/d7212abd94dfb86308647c888ca23ac0 to your computer and use it in GitHub Desktop.
Lean 4 BTree first attempt
/-- Branching parameter; every size bound below is derived from it. -/
def B : Nat := 6
/-- `B - 1`. Not referenced by the rest of the visible source. -/
def MIDPOINT : Nat := B - 1
/-- Largest number of elements a node holds before an insert forces a split. -/
def CAPACITY : Nat := 2 * B - 1
/-- Element count below which a non-root node is treated as underfull. -/
def MINIMUM : Nat := Nat.pred B
/-- `B - 2`. Not referenced by the rest of the visible source. -/
def LHS_BACK : Nat := MINIMUM - 1
/-- `B - 1`. Not referenced by the rest of the visible source. -/
def RHS_FRONT : Nat := MINIMUM
/-
BTree implementation where the nodes are dependent on `Height : Nat`. For nodes with
height h+1, their edges field is an array of nodes with height h. For nodes with
height 0, their edges field is a single PUnit.
This has to be separated into two types (`NodeCore` and `Node`), because recursive
arguments to a constructor can't be dependent on the type
being defined. In other words, the constructor for `Node K V (h+1)` cannot take
an argument `edges : Array (Node K V h)`, since the type of `edges` depends on `h`.
-/
/--
Field layout shared by every node. `elems` holds the key/value pairs and `edges` holds
whatever child storage the parameter `R` describes. `Height` is a phantom parameter:
no field mentions it.
-/
structure NodeCore (K : Type u) (V : K → Type v) (Height : Nat) (R : Type w) [Ord K] where
/-- Key/value pairs stored in this node. -/
elems : Array (Sigma V)
/-- Child storage; `Node` instantiates `R` with `Unit` at height 0 and with an array of child nodes otherwise. -/
edges : R
/--
Height-indexed node type, defined by recursion on the height.
At height `0` the `edges` slot is `Unit`; at height `h + 1` it is an array of
nodes of height `h`.
-/
@[reducible]
def Node (K : Type u) (V : K → Type v) [Ord K] : ∀ (height : Nat), Type (max u v)
-- Leaf: there are no children to store.
| 0 => NodeCore K V 0 Unit
-- Internal node: children sit exactly one level lower.
| Nat.succ h => NodeCore K V (h+1) (Array (Node K V h))
/-- A tree is a root node packaged together with that root's height. -/
structure BTree (K : Type u) (V : K → Type v) [Ord K] where
/-- Height of `root`; `0` means the root is a leaf. -/
height : Nat
/-- The root node, whose type is indexed by `height`. -/
root : Node K V height
/--
Type of a single child of a node of the given height: `PUnit` at height `0`
(no children), and a node of height `h` at height `h + 1`.
Used as the "stolen edge" type by the steal helpers.
-/
abbrev RecTypeElem (K : Type u) (V : K → Type v) [Ord K] : ∀ (height : Nat), Type (max u v)
| 0 => PUnit
| h+1 => Node K V h
/--
Type of the whole child collection of a node of the given height: `PUnit` at
height `0`, and an array of height-`h` nodes at height `h + 1`.
Not referenced by the rest of the visible source.
-/
abbrev RecType (K : Type u) (V : K → Type v) [Ord K] : ∀ (height : Nat), Type (max u v)
| 0 => PUnit
| h+1 => Array (Node K V h)
section BTreeDefs
variable {K : Type u} {V : K → Type v} {n : Nat} [Ord K] [Inhabited K] [∀ k, Inhabited <| V k]
/-- Default key/value pair: the default key together with that key's default value. -/
instance : Inhabited (Sigma V) := ⟨Inhabited.default, Inhabited.default⟩
/-- Default leaf: empty `elems` and the default `edges` value.
Declared with high priority; presumably so it wins over the height-generic instance below. -/
instance (priority := high) : Inhabited (Node K V 0) := ⟨∅, Inhabited.default⟩
/-- `∅` for a leaf is the default leaf defined above. -/
instance (priority := high) : EmptyCollection (Node K V 0) := ⟨Inhabited.default⟩
/-- The empty tree: height `0` with an empty leaf as root. -/
instance : EmptyCollection (BTree K V) := ⟨0, ∅⟩
/-- Default node at an arbitrary height, chosen by cases on the height (low priority). -/
instance (priority := low) : Inhabited (Node K V n) :=
⟨match n with
-- Height 0: reuse the default leaf.
| 0 => Inhabited.default
-- Height n+1: empty `elems` and the default value of the child array type.
| Nat.succ n => ⟨∅, Inhabited.default⟩⟩
/-- Height-generic accessor for a node's `elems` array. Both cases read the same field;
the split on the height is only there so the node type unfolds to a `NodeCore`. -/
@[reducible]
def Node.getElems : ∀ {n : Nat}, Node K V n → Array (Sigma V)
  | 0, leaf => leaf.elems
  | _+1, inner => inner.elems
section Util
/-- Insert `a` at index `i` when `i` is at most `as.size`; for any larger index
append `a` at the end instead. -/
def Array.insertAtOrPush {A : Type u} (as : Array A) (i : Nat) (a : A) : Array A :=
  if as.size < i then as.push a else as.insertAt i a
/-- Render an `Ordering` as the short name of its constructor. -/
instance : ToString Ordering :=
  ⟨fun
    | Ordering.lt => "lt"
    | Ordering.eq => "eq"
    | Ordering.gt => "gt"⟩
/-- A string made of exactly `n` space characters; used to indent printed elements. -/
def blank (n : Nat) : String :=
  String.mk (List.replicate n ' ')
end Util
/--
Render a node as text.
* A leaf prints each element followed by a newline.
* An internal node prints each child in order; after a child it prints the next
  not-yet-printed element of this node (if one remains), indented by
  `4 * (n + 1)` spaces and followed by a newline.
-/
@[reducible]
partial def Node.toString [ToString K] [∀ k, ToString (V k)] : ∀ {n} (node : Node K V n), String
  | 0, leaf =>
    leaf.elems.foldl (fun text entry => text ++ (ToString.toString entry) ++ "\n") ""
  | n+1, inner => Id.run do
    let mut text := ""
    -- Index of the next element of `inner` still to be printed.
    let mut nextElem := 0
    for child in inner.edges do
      text := text ++ toString child
      match inner.elems.get? nextElem with
      | some entry =>
        text := text ++ (blank (4 * n.succ)) ++ s!"{entry}\n"
        nextElem := nextElem + 1
      -- No element left to interleave (e.g. after the last child).
      | none => pure ()
    text
/-- Nodes of any height print via `Node.toString`. -/
instance [ToString K] [∀ k, ToString (V k)] : ToString (Node K V n) where
  toString := Node.toString

/-- A tree prints as the rendering of its root node. -/
instance [ToString K] [∀ k, ToString (V k)] : ToString (BTree K V) where
  toString t := s!"{t.root}"
/--
Linear scan over the elements of `node`, comparing `k` against each stored key in order.

* The scan stops at the first stored key `k'` for which `compare k k'` is not `gt`
  and returns that comparison result (`lt` or `eq`) with the index of `k'`.
  For `eq` the index is the position of the matching element; for `lt` it is the
  position of the edge to descend into.
* If every comparison is `gt`, the result is `(gt, size)` where `size` is the
  number of elements, i.e. the index of the last edge. A `gt` result is never used
  by this file to index into `elems`.
-/
def Node.findPos (node : Node K V n) (k : K) : (Ordering × Nat) := Id.run do
  let elems := node.getElems
  let mut i := 0
  for entry in elems do
    let ord := Ord.compare k entry.fst
    match ord with
    -- `k` is larger than this key: keep scanning.
    | Ordering.gt => i := i + 1
    -- `lt` or `eq`: this is the position we want.
    | _ => return (ord, i)
  return (Ordering.gt, elems.size)
/--
Look up `k` in the subtree rooted at `node`, returning the stored key/value pair if present.

In a leaf only an exact hit counts. In an internal node an exact hit is returned
directly; otherwise the search continues in the edge selected by `findPos`
(the last edge when `k` is greater than every key in the node).
-/
-- Original author's note: for some reason `Array.get` is only defined on `Type`.
@[reducible]
def Node.get? : ∀ {n : Nat} (node : Node K V n) (k : K), Option (Sigma V)
  | 0, leaf, key =>
    match leaf.findPos key with
    | (Ordering.eq, i) => some (leaf.elems.get! i)
    | _ => none
  | n+1, inner, key =>
    match inner.findPos key with
    | (Ordering.eq, i) => some (inner.elems.get! i)
    | (Ordering.lt, i) => Node.get? (inner.edges.get! i) key
    | (Ordering.gt, _) => Node.get? inner.edges.back key
/-- Look up `k` in the tree by searching from the root. -/
def BTree.get? (t : BTree K V) (k : K) : Option (Sigma V) :=
  Node.get? t.root k
/--
Used when we need to split an overflowing node.

Inserts `inbound` into the node's elements at `inPos` (or appends it when `inPos`
is past the end) and cuts the combined array in two. Returns
`(left, median, right)`: `left` is the lower part without its last element,
`median` is that last element, and `right` is the upper part.

The assertions require the combined array to hold exactly `2 * B` elements, i.e.
the node must have been at `CAPACITY` before the insertion.
-/
def Node.splitElems (node : Node K V n) (inbound : Sigma V) (inPos : Nat) : (Array (Sigma V) × (Sigma V) × Array (Sigma V)) :=
let all := node.getElems.insertAtOrPush inPos inbound
-- Upper part: everything from index `B` onwards.
let rhs := all.extract B all.size
-- Lower part: `all` cut down to size `B` (checked by the assertion below).
let lhs := all.shrink B
assert! all.size = B * 2
assert! rhs.size = B
assert! lhs.size = B
-- The last element of the lower part becomes the separator pushed up to the parent.
(lhs.pop, lhs.back, rhs)
/--
Used when we need to split an overflowing node.

Replaces the edge at `inPos` by the pair `inL`, `inR` (in that order) and cuts the
resulting edge array in two, returning `(left, right)`.

The assertions require `2 * B + 1` edges after the replacement, with `B` going to
the left part and `B + 1` to the right part.
-/
def Node.splitEdges (node : Node K V n.succ) (inL inR : Node K V n) (inPos : Nat) : (Array (Node K V n) × Array (Node K V n)) :=
/- Since inPos is always <= the edges length, `set!; insertAt` is fine. -/
-- Overwrite the old edge with `inR`, then insert `inL` just before it.
let all := (node.edges.set! inPos inR).insertAt inPos inL
let rhs := all.extract B all.size
let lhs := all.shrink B
assert! all.size = (2 * B).succ
assert! rhs.size = B.succ
assert! lhs.size = B
⟨lhs, rhs⟩
/--
Split a leaf that is too full to accommodate an additional element.
`inbound` is placed at `keyPos` before splitting; the result is
`(left leaf, median element, right leaf)`.
-/
def Leaf.split (node : Node K V 0) (inbound : Sigma V) (keyPos : Nat) : (Node K V 0 × Sigma V × Node K V 0) :=
  let (lowElems, median, highElems) := node.splitElems inbound keyPos
  (⟨lowElems, Inhabited.default⟩, median, ⟨highElems, Inhabited.default⟩)
/--
Split an internal node that is too full to accommodate an additional element.

`elem` is the separator coming up from a split child, and `edgeL`/`edgeR` are the two
halves of that child. The same `inPos` is used for both arrays: if the original
`findPos` result was `lt` or `eq`, the key position equals the edge position; if it
was `gt`, `inPos` is the last edge position and also the right key position, because
the element is placed with `insertAtOrPush`.

Returns `(left node, median element, right node)`.
-/
def Internal.split
    (node : Node K V n.succ)
    (elem : Sigma V)
    (edgeL edgeR : Node K V n)
    (inPos : Nat) : (Node K V n.succ × Sigma V × Node K V n.succ) :=
  let (lowElems, median, highElems) := node.splitElems elem inPos
  let (lowEdges, highEdges) := node.splitEdges edgeL edgeR inPos
  (⟨lowElems, lowEdges⟩, median, ⟨highElems, highEdges⟩)
/--
Insert the pair `⟨k, v⟩` into the subtree rooted at the given node.

The `Bool` is `true` iff `k` was already present (its entry is then overwritten).
The `Sum` is `Sum.inl node'` when the node absorbed the change without splitting,
and `Sum.inr (left, median, right)` when it had to split; the caller one level up
is responsible for placing `median` and the two halves.
-/
def Node.insert : ∀ {n : Nat} (k : K), V k → Node K V n → (Bool × Sum (Node K V n) (Node K V n × Sigma V × Node K V n))
  | 0, k, v, leaf =>
    match leaf.findPos k with
    -- Key already stored here: overwrite in place.
    | (Ordering.eq, slot) =>
      (true, Sum.inl { leaf with elems := leaf.elems.set! slot ⟨k, v⟩ })
    | (_, slot) =>
      if leaf.elems.size < CAPACITY then
        -- Room left: just place the new element.
        (false, Sum.inl { leaf with elems := leaf.elems.insertAtOrPush slot ⟨k, v⟩ })
      else
        -- Leaf is full: split it.
        (false, Sum.inr (Leaf.split leaf ⟨k, v⟩ slot))
  | n+1, k, v, inner =>
    match inner.findPos k with
    -- Key already stored in this internal node: overwrite in place.
    | (Ordering.eq, slot) =>
      (true, Sum.inl { inner with elems := inner.elems.set! slot ⟨k, v⟩ })
    | (_, slot) =>
      match Node.insert k v (inner.edges.get! slot) with
      -- Child absorbed the insert: store the updated child.
      | (existed, Sum.inl child') =>
        (existed, Sum.inl { inner with edges := inner.edges.set! slot child' })
      -- Child split: take its separator and both halves.
      | (existed, Sum.inr (l, sep, r)) =>
        if inner.elems.size < CAPACITY then
          (existed, Sum.inl { inner with
            elems := inner.elems.insertAtOrPush slot sep,
            edges := (inner.edges.set! slot r).insertAt slot l })
        else
          -- No room for the separator: this node splits as well.
          (existed, Sum.inr (Internal.split inner sep l r slot))
/-- What a left sibling gives up when stolen from: its last element and, for an
internal node, its last edge (a leaf contributes `PUnit.unit` in the edge slot). -/
def Node.leftStolenElems : ∀ {n : Nat}, Node K V n → (Sigma V × RecTypeElem K V n)
  | 0, leaf => ⟨leaf.elems.back, PUnit.unit⟩
  | h+1, inner => ⟨inner.elems.back, inner.edges.back⟩
/-- The left sibling after being stolen from: its last element removed and, for an
internal node, its last edge removed as well. -/
def Node.leftAfterStealing : ∀ {n : Nat}, Node K V n → Node K V n
  | 0, leaf => { leaf with elems := leaf.elems.pop }
  | h+1, inner => { inner with edges := inner.edges.pop, elems := inner.elems.pop }
/-- What a right sibling gives up when stolen from: its first element and, for an
internal node, its first edge (a leaf contributes `PUnit.unit` in the edge slot). -/
def Node.rightStolenElems : ∀ {n : Nat}, Node K V n → (Sigma V × RecTypeElem K V n)
  | 0, leaf => ⟨leaf.elems.get! 0, PUnit.unit⟩
  | h+1, inner => ⟨inner.elems.get! 0, inner.edges.get! 0⟩
/-- The right sibling after being stolen from: its first element removed and, for an
internal node, its first edge removed as well. -/
def Node.rightAfterStealing : ∀ {n : Nat}, Node K V n → Node K V n
  | 0, leaf => { leaf with elems := leaf.elems.eraseIdx 0 }
  | h+1, inner => { inner with edges := inner.edges.eraseIdx 0, elems := inner.elems.eraseIdx 0 }
/-- Give the right-hand node what was taken from its left sibling: the parent's old
separator goes to the front of `elems`, and (for an internal node) the stolen edge
goes to the front of `edges`. -/
def Node.addStolenToRight : ∀ {n : Nat}, (stolenEdge : RecTypeElem K V n) → (oldMiddle : Sigma V) → Node K V n → Node K V n
  | 0, _, oldMiddle, leaf =>
    { leaf with elems := leaf.elems.insertAtOrPush 0 oldMiddle }
  | h+1, stolenEdge, oldMiddle, inner =>
    { inner with
      edges := inner.edges.insertAt 0 stolenEdge,
      elems := inner.elems.insertAtOrPush 0 oldMiddle }
/-- Give the left-hand node what was taken from its right sibling: the parent's old
separator is appended to `elems`, and (for an internal node) the stolen edge is
appended to `edges`. -/
def Node.addStolenToLeft : ∀ {n : Nat}, (oldMiddle : Sigma V) → (stolenEdge : RecTypeElem K V n) → Node K V n → Node K V n
  | 0, oldMiddle, _, leaf =>
    { leaf with elems := leaf.elems.push oldMiddle }
  | h+1, oldMiddle, stolenEdge, inner =>
    { inner with
      edges := inner.edges.push stolenEdge,
      elems := inner.elems.push oldMiddle }
/--
Try to steal from the left sibling while rebalancing `node`, after something has
been deleted from `edge` (the child at `edgeIdx`).

```
      .. x   y   z ..
            / \
 leftSibling   edge
```

`y` is the separator that gets swapped out; it has index `leftIdx = edgeIdx - 1`.
The left sibling's last element moves up to replace `y`, and `y` (plus the sibling's
last edge, for internal nodes) moves to the front of `edge`.

Returns `none` when `edge` is the leftmost child or when the left sibling holds
exactly `MINIMUM` elements. A sibling with fewer than `MINIMUM` elements is
treated as impossible and panics.
-/
@[reducible]
def Node.tryStealLeft (node : Node K V n.succ) (edgeIdx : Nat) (edge : Node K V n) : Option (Node K V n.succ) :=
  match edgeIdx with
  -- Leftmost edge: there is no left sibling.
  | 0 => none
  | leftIdx+1 =>
    let donor := node.edges.get! leftIdx
    let donorSize := donor.getElems.size
    if donorSize < MINIMUM then
      panic! "tryStealLeft, impossible size < MINIMUM"
    else if donorSize = MINIMUM then
      -- Sibling has nothing to spare.
      none
    else
      let separator := node.elems.get! leftIdx
      let (promoted, movedEdge) := donor.leftStolenElems
      let donor' := donor.leftAfterStealing
      let edge' := edge.addStolenToRight movedEdge separator
      some { node with
        elems := node.elems.set! leftIdx promoted,
        edges := (node.edges.set! leftIdx donor').set! edgeIdx edge' }
/--
Try to steal from the right sibling while rebalancing `node`, after something has
been deleted from `edge` (the child at `edgeIdx`).

```
 .. x   y   z ..
       / \
   edge   rightSibling
```

`y` has index `edgeIdx`; the right sibling is the child at `edgeIdx.succ`.
The right sibling's first element moves up to replace `y`, and `y` (plus the
sibling's first edge, for internal nodes) is appended to `edge`.

Returns `none` when `edge` is the rightmost child or when the right sibling holds
exactly `MINIMUM` elements. A sibling with fewer than `MINIMUM` elements is
treated as impossible and panics.
-/
@[reducible]
def Node.tryStealRight (node : Node K V n.succ) (edgeIdx : Nat) (edge : Node K V n) : Option (Node K V n.succ) :=
  match node.edges.get? edgeIdx.succ with
  | some donor =>
    let donorSize := donor.getElems.size
    if donorSize < MINIMUM then
      panic! s!"Impossible < MINIMUM steal right"
    else if donorSize = MINIMUM then
      -- Sibling has nothing to spare.
      none
    else
      let separator := node.elems.get! edgeIdx
      let donor' := donor.rightAfterStealing
      let (promoted, movedEdge) := donor.rightStolenElems
      let edge' := edge.addStolenToLeft separator movedEdge
      some { node with
        elems := node.elems.set! edgeIdx promoted,
        edges := (node.edges.set! edgeIdx edge').set! edgeIdx.succ donor' }
  -- Rightmost edge: there is no right sibling.
  | none => none
/--
Merge two sibling nodes into one. The merged `elems` are the left node's elements,
then the separator `midElem`, then the right node's elements. For internal nodes
the merged `edges` are the left node's edges followed by the right node's edges.
-/
def Node.mergeAux : ∀ {n : Nat}
    (lhs : Node K V n)
    (midElem : Sigma V)
    (rhs : Node K V n), Node K V n
  | 0, lhs, midElem, rhs =>
    ⟨Array.append (lhs.elems.push midElem) rhs.elems, PUnit.unit⟩
  | h+1, lhs, midElem, rhs =>
    ⟨Array.append (lhs.elems.push midElem) rhs.elems, Array.append lhs.edges rhs.edges⟩
/--
Rebalance `node` by merging, after something has been deleted from `edge` (the child
at `edgeIdx`) and stealing from either sibling has failed because the available
siblings hold only `MINIMUM` elements.

`edge` is merged with one of its siblings, with the separator from `node.elems`
placed between them; that separator and one edge slot are removed from `node`.
The leftmost child merges with its right sibling; every other child merges with
its left sibling.
-/
@[reducible]
def Node.merge (node : Node K V n.succ) (edgeIdx : Nat) (edge : Node K V n) : Node K V n.succ :=
  match edgeIdx with
  -- Leftmost child: merge with the right sibling.
  | 0 =>
    let rightSib := node.edges.get! edgeIdx.succ
    let separator := node.elems.get! 0
    let fused := Node.mergeAux edge separator rightSib
    { node with
      elems := node.elems.eraseIdx 0,
      edges := (node.edges.eraseIdx 0).set! 0 fused }
  -- Any other child: merge with the left sibling.
  | leftSiblingIdx+1 =>
    let leftSib := node.edges.get! leftSiblingIdx
    let separator := node.elems.get! leftSiblingIdx
    let fused := Node.mergeAux leftSib separator edge
    { node with
      elems := node.elems.eraseIdx leftSiblingIdx,
      edges := (node.edges.eraseIdx leftSiblingIdx).set! leftSiblingIdx fused }
/--
Restore the minimum-size invariant for the child of `node` at `edgePos`, after an
element has been deleted from that child.

If the child still holds at least `MINIMUM` elements, `node` is returned unchanged.
Otherwise, in order:
1. try to steal from the child's left sibling,
2. else try to steal from its right sibling,
3. else merge, since failing to steal from either side means the reachable
   sibling(s) are small enough that merging is the way to rebalance.

The size test is on the child (`edge`), not on `node` itself: `node` may well be
full while its child is underfull, and a small `node` (e.g. a one-element root)
must not trigger steals or merges on a healthy child.
-/
@[reducible]
def Node.rebalance (node : Node K V n.succ) (edgePos : Nat) : Node K V n.succ :=
  let edge := node.edges.get! edgePos
  if edge.getElems.size >= MINIMUM
  then node
  else
    match node.tryStealLeft edgePos edge with
    | some x => x
    | none =>
      match node.tryStealRight edgePos edge with
      | some x => x
      | none => node.merge edgePos edge
/--
Helper for deleting an element that lives in an internal node. Removes and returns
the leftmost element of the leftmost leaf of the given subtree; when called on the
edge just right of the element being deleted, that is the next-largest element,
which the caller puts in the deleted element's place. Each internal node on the
path rebalances its first child on the way back up.
-/
@[reducible]
def Node.deleteInternal : ∀ {n : Nat}, Node K V n → ((Sigma V) × Node K V n)
  | 0, leaf =>
    (leaf.elems.get! 0, { leaf with elems := leaf.elems.eraseIdx 0 })
  | h+1, inner =>
    let (taken, child') := Node.deleteInternal (inner.edges.get! 0)
    (taken, Node.rebalance { inner with edges := inner.edges.set! 0 child' } 0)
/--
Delete `k` from the subtree rooted at the given node, returning the removed pair
(`none` if `k` was absent) together with the updated node.

If the element sits in a leaf it is simply removed, and ancestors rebalance on the
way back up. If it sits in an internal node, it is replaced by the element obtained
from `deleteInternal` on the edge to its right, and that edge is then rebalanced.
-/
@[reducible]
def Node.delete : ∀ {n : Nat}, K → Node K V n → (Option (Sigma V) × Node K V n)
  | 0, k, leaf =>
    match leaf.findPos k with
    -- Found the element we're trying to delete.
    | (Ordering.eq, i) =>
      (some (leaf.elems.get! i), { leaf with elems := leaf.elems.eraseIdx i })
    -- Not present, and there is nowhere further down to look.
    | _ => (none, leaf)
  -- Original author's question: how to do this without an almost identical `n+2` case?
  | n+1, k, inner =>
    match inner.findPos k with
    -- The element to delete is stored in this internal node: swap, then delete.
    | (Ordering.eq, elemPos) =>
      let succPos := elemPos + 1
      let (replacement, child') := Node.deleteInternal (inner.edges.get! succPos)
      let inner' : Node K V n.succ := { inner with
        elems := inner.elems.set! elemPos replacement,
        edges := inner.edges.set! succPos child' }
      (some (inner.elems.get! elemPos), inner'.rebalance succPos)
    -- The element is either absent or somewhere below the selected edge.
    | (_, edgePos) =>
      match Node.delete k (inner.edges.get! edgePos) with
      | (some removed, child') =>
        (some removed, Node.rebalance { inner with edges := inner.edges.set! edgePos child' } edgePos)
      -- Nothing was removed: return this node untouched.
      | (none, _) => (none, inner)
/--
Insert `⟨k, v⟩` into the tree. The `Bool` is `true` iff `k` was already present.
If the insertion makes the root overflow and split, a new root one level higher is
built whose only element is the median and whose two edges are the split halves.
-/
def BTree.insert (t : BTree K V) (k : K) (v : V k) : (Bool × BTree K V) :=
  let (existed, outcome) := Node.insert k v t.root
  match outcome with
  | Sum.inl root' => (existed, { t with root := root' })
  | Sum.inr (l, median, r) =>
    (existed, BTree.mk t.height.succ (NodeCore.mk #[median] #[l, r]))
/--
Delete `k` from the tree, returning the removed pair (`none` if `k` was absent)
and the updated tree.

The only special case is when deletion leaves the root with no elements, because a
merge took the root's only element. The merged child is then the only thing left,
so it becomes the new root and the height drops by one.
-/
def BTree.delete (t : BTree K V) (k : K) : (Option (Sigma V) × BTree K V) :=
match t.root.delete k with
-- Nothing was removed; the root is stored back as returned.
| (none, root') => ⟨none, { t with root := root' }⟩
| (some x, root') =>
if root'.getElems.isEmpty
then
-- `hEq` records which height we are at so the type of `root'` can be rewritten below.
match hEq:t.height with
-- An empty leaf root is simply an empty tree.
| 0 => ⟨some x, { t with root := root' }⟩
/- This doesn't work if the pattern is changed to `h+1` -/
| Nat.succ h => by
-- Expose `root'` as an internal node so that its `edges` field can be read.
rw [hEq] at root'
simp [Node] at root'
-- Promote the first child to be the new root, one level lower.
exact ⟨some x, ⟨h, root'.edges.get! 0⟩⟩
else ⟨some x, { t with root := root' }⟩
end BTreeDefs
section Tests
/-- Test fixture: a tree holding the keys `0 .. 99`, each mapped to its decimal string.
Keys are inserted in five runs of 20, starting at 0, 40, 80, 60 and 20 in that order. -/
def insertLoop : BTree Nat (fun _ => String) := Id.run do
  let mut tree : BTree Nat (fun _ => String) := ∅
  for start in [0, 40, 80, 60, 20] do
    for i in [start : start + 20] do
      tree := (tree.insert i s!"{i}").snd
  tree
/-- Test fixture: delete every odd key in `0 .. 99` from `t`.
Keys are visited in five runs of 20, starting at 0, 40, 20, 80 and 60 in that order. -/
def deleteLoop (t : BTree Nat (fun _ => String)) : BTree Nat (fun _ => String) := Id.run do
  let mut tree := t
  for start in [0, 40, 20, 80, 60] do
    for i in [start : start + 20] do
      if i % 2 = 1 then tree := (tree.delete i).snd
  tree
-- Smoke tests: print the tree after all inserts, then again after deleting the odd keys.
#eval insertLoop
#eval deleteLoop insertLoop
end Tests
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment