Last active
January 31, 2022 13:10
-
-
Save ammkrn/d7212abd94dfb86308647c888ca23ac0 to your computer and use it in GitHub Desktop.
Lean 4 BTree first attempt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/-- Branching parameter of the tree; the size bounds below are all derived from it. -/
def B : Nat := 6
/-- One less than `B`. Not referenced by the definitions visible in this file. -/
def MIDPOINT : Nat := Nat.pred B
/-- A node whose element count has reached this value is split on the next insertion. -/
def CAPACITY : Nat := 2 * B - 1
/-- Element count a sibling must exceed before an element can be stolen from it. -/
def MINIMUM : Nat := B - 1
/-- `B - 2`. Not referenced by the definitions visible in this file. -/
def LHS_BACK : Nat := B - 2
/-- `B - 1`. Not referenced by the definitions visible in this file. -/
def RHS_FRONT : Nat := B - 1
/--
Shared record behind every B-tree node.

Nodes are indexed by `Height : Nat`. A node of height `h+1` stores an array of
nodes of height `h` in `edges`; a node of height `0` stores a single `Unit`.

The definition is split into two types (`NodeCore` and `Node`) because a
recursive constructor argument cannot have a type that depends on the index of
the type being defined. In other words, a constructor for `Node K V (h+1)`
cannot take an argument `edges : Array (Node K V h)`, since the type of `edges`
depends on `h`. `NodeCore` therefore leaves the type of `edges` as the
parameter `R`, and `Node` picks `R` by recursion on the height.
-/
structure NodeCore (K : Type u) (V : K → Type v) (Height : Nat) (R : Type w) [Ord K] where
  /-- Key/value pairs held by this node. -/
  elems : Array (Sigma V)
  /-- Child payload; its concrete type is supplied through `R`. -/
  edges : R
/--
Node type indexed by height.

* Height `0`: a `NodeCore` whose `edges` field is `Unit` (no children).
* Height `h+1`: a `NodeCore` whose `edges` field is an array of `Node K V h`.
-/
@[reducible]
def Node (K : Type u) (V : K → Type v) [Ord K] : ∀ (height : Nat), Type (max u v)
  | 0 => NodeCore K V 0 Unit
  | Nat.succ h => NodeCore K V (h+1) (Array (Node K V h))
/-- A B-tree: a root node packaged together with the height that indexes its type. -/
structure BTree (K : Type u) (V : K → Type v) [Ord K] where
  /-- Height of `root`; `0` means the root has no children. -/
  height : Nat
  /-- Root node, whose type is determined by `height`. -/
  root : Node K V height
/--
Type of a single child of a node of the given height: `PUnit` at height `0`
(there is no child), and `Node K V h` at height `h+1`.
-/
abbrev RecTypeElem (K : Type u) (V : K → Type v) [Ord K] : ∀ (height : Nat), Type (max u v)
  | 0 => PUnit
  | h+1 => Node K V h
/--
Type of the whole child collection of a node of the given height: `PUnit` at
height `0`, and `Array (Node K V h)` at height `h+1`.
-/
abbrev RecType (K : Type u) (V : K → Type v) [Ord K] : ∀ (height : Nat), Type (max u v)
  | 0 => PUnit
  | h+1 => Array (Node K V h)
section BTreeDefs | |
variable {K : Type u} {V : K → Type v} {n : Nat} [Ord K] [Inhabited K] [∀ k, Inhabited <| V k] | |
/-- Default key/value pair, built from the default key and that key's default value. -/
instance : Inhabited (Sigma V) := ⟨⟨Inhabited.default, Inhabited.default⟩⟩

/-- Default height-0 node: no elements and the default `edges` value. -/
instance (priority := high) : Inhabited (Node K V 0) := ⟨⟨∅, Inhabited.default⟩⟩

/-- `∅` for a height-0 node is its `Inhabited` default. -/
instance (priority := high) : EmptyCollection (Node K V 0) := ⟨Inhabited.default⟩

/-- `∅` for a tree: height `0` with an empty root. -/
instance : EmptyCollection (BTree K V) := ⟨⟨0, ∅⟩⟩

/-- Default node at an arbitrary height: no elements and the default `edges` value. -/
instance (priority := low) : Inhabited (Node K V n) where
  default :=
    match n with
    | 0 => Inhabited.default
    | Nat.succ m => ⟨∅, Inhabited.default⟩
/-- Height-generic accessor for a node's `elems` array. -/
@[reducible]
def Node.getElems : ∀ {n : Nat}, Node K V n → Array (Sigma V)
  | 0, leaf => leaf.elems
  | h+1, inner => inner.elems
section Util

/--
Insert `a` at index `i` when `i` is at most `as.size`; otherwise append `a` at
the end.
-/
def Array.insertAtOrPush {A : Type u} (as : Array A) (i : Nat) (a : A) : Array A :=
  if as.size < i then as.push a else as.insertAt i a

/-- Render an `Ordering` as `"lt"`, `"eq"` or `"gt"`. -/
instance : ToString Ordering :=
  ⟨fun ord =>
    match ord with
    | Ordering.lt => "lt"
    | Ordering.eq => "eq"
    | Ordering.gt => "gt"⟩

/-- A string made of `n` space characters. -/
def blank (n : Nat) := "".pushn ' ' n

end Util
/--
Render a subtree as text, one element per line.

A height-0 node prints its elements in order. A taller node alternates between
the rendering of a child and the element that follows it, indenting that
element by `4 * height` spaces.
-/
@[reducible]
partial def Node.toString [ToString K] [∀ k, ToString (V k)] : ∀ {n} (node : Node K V n), String
  | 0, node =>
    node.elems.foldl (fun acc elem => acc ++ (ToString.toString elem) ++ "\n") ""
  | n+1, node => Id.run do
    let indent := blank (4 * n.succ)
    let mut out := ""
    let mut idx := 0
    for child in node.edges do
      out := out ++ toString child
      -- The last child has no element after it, so the lookup yields `none`.
      match node.elems.get? idx with
      | some elem => out := out ++ indent ++ s!"{elem}\n"
      | none => pure ()
      idx := idx + 1
    return out
/-- Nodes are printed with `Node.toString`. -/
instance [ToString K] [∀ k, ToString (V k)] : ToString (Node K V n) where
  toString := Node.toString

/-- A tree is printed as its root node. -/
instance [ToString K] [∀ k, ToString (V k)] : ToString (BTree K V) where
  toString := fun tree => s!"{tree.root}"
/--
Linear search for `k` among the elements of `node`.

The elements are scanned left to right and the scan stops at the first element
whose key is not smaller than `k`:

* `(Ordering.eq, i)`: the key at index `i` equals `k`.
* `(Ordering.lt, i)`: `k` is smaller than the key at index `i` (and larger than
  every key before it), so `i` is the index of the edge to descend into.
* `(Ordering.gt, size)`: `k` is larger than every key in the node; the returned
  index is the number of elements, i.e. the position of the last edge. This
  index is never used to access `elems`.
-/
def Node.findPos (node : Node K V n) (k : K) : (Ordering × Nat) := Id.run do
  let elems := node.getElems
  let mut pos := 0
  for entry in elems do
    match Ord.compare k entry.fst with
    | Ordering.gt => pos := pos + 1
    | ord => return (ord, pos)
  return (Ordering.gt, elems.size)
/--
Look up `k` in the subtree rooted at `node`, returning the stored key/value
pair if present.

At height `0` only an exact hit in `elems` succeeds. At greater heights an
exact hit is returned directly; otherwise the search continues in the edge
selected by `findPos` (the last edge when `k` is greater than every key).

Original author's note: "For some reason `Array.get` is only defined on `Type`."
-/
@[reducible]
def Node.get? : ∀ {n : Nat} (node : Node K V n) (k : K), Option (Sigma V)
  | 0, node, k =>
    match node.findPos k with
    | (Ordering.eq, hit) => some (node.elems.get! hit)
    | _ => none
  | n+1, node, k =>
    match node.findPos k with
    | (Ordering.eq, hit) => some (node.elems.get! hit)
    | (Ordering.lt, slot) => (node.edges.get! slot).get? k
    | (Ordering.gt, _) => node.edges.back.get? k
/-- Look up `k` in the tree by searching from the root. -/
def BTree.get? (t : BTree K V) (k : K) : Option (Sigma V) :=
  Node.get? t.root k
/--
Element half of splitting an overflowing node.

`inbound` is inserted at `inPos` into the node's elements. The result is
asserted to hold `2 * B` elements, and is returned as
`(first B - 1 elements, element at index B - 1, last B elements)`; the middle
component is the median that moves up to the parent.
-/
def Node.splitElems (node : Node K V n) (inbound : Sigma V) (inPos : Nat) : (Array (Sigma V) × (Sigma V) × Array (Sigma V)) :=
  let all := Array.insertAtOrPush node.getElems inPos inbound
  let rhs := all.extract B all.size
  let lhs := all.shrink B
  assert! all.size = B * 2
  assert! rhs.size = B
  assert! lhs.size = B
  -- The last element of the left half is promoted as the median.
  let median := lhs.back
  let lower := lhs.pop
  (lower, median, rhs)
/--
Edge half of splitting an overflowing internal node.

The edge at `inPos` is replaced by `inR`, and `inL` is inserted just before it.
The resulting array is asserted to hold `2 * B + 1` edges, and is returned as
`(first B edges, remaining B + 1 edges)`.
-/
def Node.splitEdges (node : Node K V n.succ) (inL inR : Node K V n) (inPos : Nat) : (Array (Node K V n) × Array (Node K V n)) :=
  -- Overwrite the edge that was split with its right half, then put the left
  -- half in front of it.
  let replaced := node.edges.set! inPos inR
  let all := replaced.insertAt inPos inL
  let rhs := all.extract B all.size
  let lhs := all.shrink B
  assert! all.size = (2 * B).succ
  assert! rhs.size = B.succ
  assert! lhs.size = B
  (lhs, rhs)
/--
Split a leaf that is too full to take `inbound`. Returns the new left leaf, the
median element to hand to the parent, and the new right leaf.
-/
def Leaf.split (node : Node K V 0) (inbound : Sigma V) (keyPos : Nat) : (Node K V 0 × Sigma V × Node K V 0) :=
  let (lowerElems, median, upperElems) := node.splitElems inbound keyPos
  let left : Node K V 0 := ⟨lowerElems, ()⟩
  let right : Node K V 0 := ⟨upperElems, ()⟩
  (left, median, right)
/--
Split an internal node that is too full to take an additional element.

`elem` is the element pushed up from a child split, `edgeL`/`edgeR` are the two
halves of that child, and `inPos` is the position reported by `findPos`. The
same index is used for both arrays: it is the position of the child that was
split, and the position at which `elem` is inserted (`insertAtOrPush` appends
when the position is past the end).
-/
def Internal.split
    (node : Node K V n.succ)
    (elem : Sigma V)
    (edgeL edgeR : Node K V n)
    (inPos : Nat) : (Node K V n.succ × Sigma V × Node K V n.succ) :=
  let (lowerElems, median, upperElems) := node.splitElems elem inPos
  let (lowerEdges, upperEdges) := node.splitEdges edgeL edgeR inPos
  let left : Node K V n.succ := ⟨lowerElems, lowerEdges⟩
  let right : Node K V n.succ := ⟨upperElems, upperEdges⟩
  (left, median, right)
/--
Insert `k ↦ v` into the subtree rooted at the given node.

The first component of the result is `true` iff `k` was already present (its
entry is then overwritten in place). The second component is `Sum.inl node'`
when the node absorbed the insertion, and `Sum.inr (left, median, right)` when
the node had to split; the caller handles the split on the way back up.
-/
def Node.insert : ∀ {n : Nat} (k : K), V k → Node K V n → (Bool × Sum (Node K V n) (Node K V n × Sigma V × Node K V n))
  | 0, k, v, node =>
    let entry : Sigma V := ⟨k, v⟩
    match node.findPos k with
    | (Ordering.eq, hit) =>
      (true, Sum.inl { node with elems := node.elems.set! hit entry })
    | (_, slot) =>
      if node.elems.size < CAPACITY then
        -- Room left: place the element directly.
        (false, Sum.inl { node with elems := node.elems.insertAtOrPush slot entry })
      else
        -- Leaf is full: split it around the new element.
        (false, Sum.inr (Leaf.split node entry slot))
  | n+1, k, v, node =>
    match node.findPos k with
    | (Ordering.eq, hit) =>
      (true, Sum.inl { node with elems := node.elems.set! hit ⟨k, v⟩ })
    | (_, slot) =>
      match (node.edges.get! slot).insert k v with
      | (existed, Sum.inl child') =>
        (existed, Sum.inl { node with edges := node.edges.set! slot child' })
      | (existed, Sum.inr (lhs, sep, rhs)) =>
        if node.elems.size < CAPACITY then
          -- Room left: adopt the separator and both halves of the split child.
          (existed, Sum.inl { node with
            elems := node.elems.insertAtOrPush slot sep,
            edges := (node.edges.set! slot rhs).insertAt slot lhs })
        else
          -- This node is full as well: split it and propagate upwards.
          (existed, Sum.inr (Internal.split node sep lhs rhs slot))
/--
What a left sibling gives up when it is stolen from: its last element and, for
non-leaf nodes, its last edge (`PUnit.unit` at height `0`).
-/
def Node.leftStolenElems : ∀ {n : Nat}, Node K V n → (Sigma V × RecTypeElem K V n)
  | 0, leaf => ⟨leaf.elems.back, PUnit.unit⟩
  | h+1, inner => ⟨inner.elems.back, inner.edges.back⟩
/--
A left sibling after it has been stolen from: its last element and, for
non-leaf nodes, its last edge are dropped.
-/
def Node.leftAfterStealing : ∀ {n : Nat}, Node K V n → Node K V n
  | 0, leaf => ⟨leaf.elems.pop, PUnit.unit⟩
  | h+1, inner => ⟨inner.elems.pop, inner.edges.pop⟩
/--
What a right sibling gives up when it is stolen from: its first element and,
for non-leaf nodes, its first edge (`PUnit.unit` at height `0`).
-/
def Node.rightStolenElems : ∀ {n : Nat}, Node K V n → (Sigma V × RecTypeElem K V n)
  | 0, leaf => ⟨leaf.elems.get! 0, PUnit.unit⟩
  | h+1, inner => ⟨inner.elems.get! 0, inner.edges.get! 0⟩
/--
A right sibling after it has been stolen from: its first element and, for
non-leaf nodes, its first edge are removed.
-/
def Node.rightAfterStealing : ∀ {n : Nat}, Node K V n → Node K V n
  | 0, leaf => ⟨leaf.elems.eraseIdx 0, PUnit.unit⟩
  | h+1, inner => ⟨inner.elems.eraseIdx 0, inner.edges.eraseIdx 0⟩
/--
Give the right-hand node what was taken from its left sibling: the parent's old
separator becomes its first element and, for non-leaf nodes, the stolen edge
becomes its first edge.
-/
def Node.addStolenToRight : ∀ {n : Nat}, (stolenEdge : RecTypeElem K V n) → (oldMiddle : Sigma V) → Node K V n → Node K V n
  | 0, _, sep, leaf => ⟨leaf.elems.insertAtOrPush 0 sep, PUnit.unit⟩
  | h+1, moved, sep, inner => ⟨inner.elems.insertAtOrPush 0 sep, inner.edges.insertAt 0 moved⟩
/--
Give the left-hand node what was taken from its right sibling: the parent's old
separator becomes its last element and, for non-leaf nodes, the stolen edge
becomes its last edge.
-/
def Node.addStolenToLeft : ∀ {n : Nat}, (oldMiddle : Sigma V) → (stolenEdge : RecTypeElem K V n) → Node K V n → Node K V n
  | 0, sep, _, leaf => ⟨leaf.elems.push sep, PUnit.unit⟩
  | h+1, sep, moved, inner => ⟨inner.elems.push sep, inner.edges.push moved⟩
/--
Try to rebalance by stealing from the left sibling of `edge`, after something
has been deleted from `edge`.

```
      .. x y z ..
          /   \
leftSibling   edge
```

`y` is the separator being rotated; it sits at index `edgeIdx - 1` in
`node.elems`, the same index as the left sibling in `node.edges`.

Returns `none` when `edge` is the leftmost child or when the left sibling holds
exactly `MINIMUM` elements. Otherwise the sibling's last element replaces the
separator, and the old separator (plus the sibling's last edge, if any) is
prepended to `edge`.
-/
@[reducible]
def Node.tryStealLeft (node : Node K V n.succ) (edgeIdx : Nat) (edge : Node K V n) : Option (Node K V n.succ) :=
  match edgeIdx with
  -- Leftmost edge: there is no left sibling.
  | 0 => none
  | leftIdx+1 =>
    let donor := node.edges.get! leftIdx
    let donorSize := donor.getElems.size
    if donorSize < MINIMUM then
      panic! "tryStealLeft, impossible size < MINIMUM"
    else if donorSize = MINIMUM then
      none
    else
      let separator := node.elems.get! leftIdx
      let (promoted, movedEdge) := donor.leftStolenElems
      let donor' := donor.leftAfterStealing
      let edge' := edge.addStolenToRight movedEdge separator
      some { node with
        elems := node.elems.set! leftIdx promoted,
        edges := (node.edges.set! leftIdx donor').set! edgeIdx edge' }
/--
Try to rebalance by stealing from the right sibling of `edge`, after something
has been deleted from `edge`.

```
.. x y z ..
    /   \
 edge   rightSibling
```

`y` is the separator at index `edgeIdx`; the right sibling is the edge at
`edgeIdx.succ`.

Returns `none` when `edge` is the rightmost child or when the right sibling
holds exactly `MINIMUM` elements. Otherwise the sibling's first element
replaces the separator, and the old separator (plus the sibling's first edge,
if any) is appended to `edge`.
-/
@[reducible]
def Node.tryStealRight (node : Node K V n.succ) (edgeIdx : Nat) (edge : Node K V n) : Option (Node K V n.succ) :=
  match node.edges.get? edgeIdx.succ with
  -- Rightmost edge: there is no right sibling.
  | none => none
  | some donor =>
    let donorSize := donor.getElems.size
    if donorSize < MINIMUM then
      panic! s!"Impossible < MINIMUM steal right"
    else if donorSize = MINIMUM then
      none
    else
      let separator := node.elems.get! edgeIdx
      let donor' := donor.rightAfterStealing
      let (promoted, movedEdge) := donor.rightStolenElems
      let edge' := edge.addStolenToLeft separator movedEdge
      some { node with
        elems := node.elems.set! edgeIdx promoted,
        edges := (node.edges.set! edgeIdx edge').set! edgeIdx.succ donor' }
/--
Merge two sibling nodes into one. The merged `elems` are the left elements,
then the separator `midElem`, then the right elements. For non-leaf nodes the
merged `edges` are the left edges followed by the right edges.
-/
def Node.mergeAux : ∀ {n : Nat}
    (lhs : Node K V n)
    (midElem : Sigma V)
    (rhs : Node K V n), Node K V n
  | 0, left, sep, right =>
    let joined := left.elems.push sep ++ right.elems
    ⟨joined, PUnit.unit⟩
  | h+1, left, sep, right =>
    let joined := left.elems.push sep ++ right.elems
    ⟨joined, left.edges ++ right.edges⟩
/--
Rebalance `node` by merging, used after a deletion from `edge` when stealing
from both siblings has failed because the reachable siblings hold only
`MINIMUM` elements.

`edge` is merged with one of its siblings, pulling the separator between them
down from `node.elems`: the leftmost child merges with its right sibling, and
every other child merges with its left sibling. The separator and one of the
two edge slots are removed from `node`, and the remaining slot receives the
merged node.
-/
@[reducible]
def Node.merge (node : Node K V n.succ) (edgeIdx : Nat) (edge : Node K V n) : Node K V n.succ :=
  match edgeIdx with
  -- Leftmost child: merge with the right sibling (the edge at index 1).
  | 0 =>
    let separator := node.elems.get! 0
    let right := node.edges.get! 1
    let fused := Node.mergeAux edge separator right
    { node with
      elems := node.elems.eraseIdx 0,
      edges := (node.edges.eraseIdx 0).set! 0 fused }
  -- Any other child: merge with the left sibling.
  | leftIdx+1 =>
    let separator := node.elems.get! leftIdx
    let left := node.edges.get! leftIdx
    let fused := Node.mergeAux left separator edge
    { node with
      elems := node.elems.eraseIdx leftIdx,
      edges := (node.edges.eraseIdx leftIdx).set! leftIdx fused }
/--
Repair `node` after an element has been removed from its child at `edgePos`.

If that child still holds at least `MINIMUM` elements, `node` is returned
unchanged. Otherwise, in order:

1. try to steal from the child's left sibling,
2. else try to steal from its right sibling,
3. else merge, since failing to steal from either side means the reachable
   sibling(s) are small enough that a merge is needed.

The underflow test is made on the child (`edge`), not on `node` itself. Testing
the parent would leave underfull children unrepaired below a well-filled
parent, and would steal or merge on every deletion below a small parent such as
the root, which can produce a merged child larger than `CAPACITY`.
-/
@[reducible]
def Node.rebalance (node : Node K V n.succ) (edgePos : Nat) : Node K V n.succ :=
  let edge := node.edges.get! edgePos
  if edge.getElems.size >= MINIMUM
  then node
  else
    match node.tryStealLeft edgePos edge with
    | some fixed => fixed
    | none =>
      match node.tryStealRight edgePos edge with
      | some fixed => fixed
      | none => node.merge edgePos edge
/--
Helper for deleting an element that lives in an internal node.

Removes and returns the leftmost element of the leftmost leaf below the given
node. When called on the child to the right of the element being deleted, this
is the next-larger element, which the caller uses as the replacement. Each
internal node on the path is rebalanced on the way back up.
-/
@[reducible]
def Node.deleteInternal : ∀ {n : Nat}, Node K V n → ((Sigma V) × Node K V n)
  | 0, leaf => (leaf.elems.get! 0, { leaf with elems := leaf.elems.eraseIdx 0 })
  | n+1, node =>
    let (successor, child') := (node.edges.get! 0).deleteInternal
    let node' : Node K V n.succ := { node with edges := node.edges.set! 0 child' }
    (successor, node'.rebalance 0)
/--
Delete `k` from the subtree rooted at the given node, returning the removed
key/value pair (or `none` if `k` was absent) together with the updated node.

* In a leaf, the element is simply removed.
* In an internal node that holds `k`, the element is replaced by the value
  returned from `deleteInternal` on the child to its right, and that child's
  slot is rebalanced.
* Otherwise the deletion recurses into the edge chosen by `findPos`; that slot
  is rebalanced only if something was actually removed.

Original author's open question: how to do this without having an almost
identical `n+2` case?
-/
@[reducible]
def Node.delete : ∀ {n : Nat}, K → Node K V n → (Option (Sigma V) × Node K V n)
  | 0, k, node =>
    match node.findPos k with
    -- Found the element to delete.
    | (Ordering.eq, hit) =>
      (some (node.elems.get! hit), { node with elems := node.elems.eraseIdx hit })
    -- Not present, and there is nowhere further down to look.
    | _ => (none, node)
  | n+1, k, node =>
    match node.findPos k with
    -- The element lives in this internal node: swap in its successor.
    | (Ordering.eq, hit) =>
      let succPos := hit + 1
      let (successor, child') := (node.edges.get! succPos).deleteInternal
      let node' : Node K V n.succ := { node with
        elems := node.elems.set! hit successor,
        edges := node.edges.set! succPos child' }
      (some (node.elems.get! hit), node'.rebalance succPos)
    -- The element is either absent or somewhere below the chosen edge.
    | (_, slot) =>
      match (node.edges.get! slot).delete k with
      | (none, _) => (none, node)
      | (some removed, child') =>
        let node' : Node K V n.succ := { node with edges := node.edges.set! slot child' }
        (some removed, node'.rebalance slot)
/--
Insert `k ↦ v` into the tree. The `Bool` is `true` iff `k` was already present.

If the insertion makes the root split, a new root one level taller is created,
holding the median as its only element and the two halves as its edges.
-/
def BTree.insert (t : BTree K V) (k : K) (v : V k) : (Bool × BTree K V) :=
  match t.root.insert k v with
  | (existed, Sum.inl root') =>
    (existed, { t with root := root' })
  | (existed, Sum.inr (lhs, sep, rhs)) =>
    (existed, BTree.mk t.height.succ (NodeCore.mk #[sep] #[lhs, rhs]))
/--
Delete `k` from the tree, returning the removed key/value pair (or `none` if
`k` was absent) together with the updated tree.

The only special case is a deletion that leaves the root with no elements. If
the tree has height `0` the empty root is kept. Otherwise the root's first edge
becomes the new root and the height drops by one.
-/
def BTree.delete (t : BTree K V) (k : K) : (Option (Sigma V) × BTree K V) :=
  match t.root.delete k with
  | (none, root') => (none, { t with root := root' })
  | (some removed, root') =>
    if root'.getElems.isEmpty
    then
      match hHeight : t.height with
      | 0 => (some removed, { t with root := root' })
      -- Original author's note: this doesn't work if the pattern is changed to `h+1`.
      | Nat.succ h => by
        rw [hHeight] at root'
        simp [Node] at root'
        exact ⟨some removed, ⟨h, root'.edges.get! 0⟩⟩
    else (some removed, { t with root := root' })
end BTreeDefs | |
section Tests

/--
Build a tree with keys `0..99`, each mapped to its decimal string, inserted in
the range order `0-19`, `40-59`, `80-99`, `60-79`, `20-39`.
-/
def insertLoop : BTree Nat (fun _ => String) := Id.run do
  let mut tree : BTree Nat (fun _ => String) := ∅
  for (lo, hi) in [(0, 20), (40, 60), (80, 100), (60, 80), (20, 40)] do
    for i in [lo : hi] do
      tree := (tree.insert i s!"{i}").snd
  return tree

/--
Delete every odd key below `100` from `t`, visiting the ranges in the order
`0-19`, `40-59`, `20-39`, `80-99`, `60-79`.
-/
def deleteLoop (t : BTree Nat (fun _ => String)) : BTree Nat (fun _ => String) := Id.run do
  let mut tree := t
  for (lo, hi) in [(0, 20), (40, 60), (20, 40), (80, 100), (60, 80)] do
    for i in [lo : hi] do
      if i % 2 = 1 then tree := (tree.delete i).snd
  return tree

#eval insertLoop
#eval deleteLoop insertLoop

end Tests
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment