Skip to content

Instantly share code, notes, and snippets.

@Sam-Belliveau
Last active March 7, 2023 15:48
Show Gist options
  • Save Sam-Belliveau/834ab56b1a5f6f3a0f09d07672bf6817 to your computer and use it in GitHub Desktop.
Save Sam-Belliveau/834ab56b1a5f6f3a0f09d07672bf6817 to your computer and use it in GitHub Desktop.
/**
* Copyright 2023 Sam Belliveau
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the “Software”), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
import java.util.Iterator;
import java.util.Objects;
import java.util.Optional;
/**
 * A simple binary radix tree (bitwise trie) implemented in Java that acts as a
 * map from {@code long} keys to values of type {@code T}.
 *
 * <p>Keys are consumed from the most significant bit downwards: a clear bit
 * descends into the low child and a set bit into the high child. A value is
 * stored at the node reached once all remaining bits of its key are zero, so a
 * key with t trailing zero bits lives 64 - t levels below the root, and key
 * {@code 0} lives in the root itself.
 *
 * <p>Keys are ordered as <b>unsigned</b> 64-bit integers: {@link #min()},
 * {@link #max()}, {@link #next(long)}, {@link #prev(long)} and both iterators
 * place negative keys after every non-negative key, matching
 * {@link Long#compareUnsigned(long, long)}.
 *
 * <p>Lookups report absence with {@link Optional} rather than {@code null}.
 * Exceptions are still thrown for misuse: {@link #set(long, Object)} rejects
 * {@code null} values with a {@link NullPointerException}, and the iterators
 * follow the {@link Iterator} contract ({@code NoSuchElementException} once
 * exhausted, {@link IllegalStateException} from {@code remove()} without a
 * preceding {@code next()}).
 *
 * <p>For time complexities, b is the number of bits in a key (64).
 * get(), set(), and remove() are O(b).
 *
 * <p>min(), next(), max(), and prev() are also O(b). Subtrees are pruned as
 * soon as they become empty, so every child that exists holds at least one
 * value and a failed descent is recovered with a single min() or max() call.
 *
 * <p>Iterating over the entire sorted tree is therefore O(k * b), where k is
 * the number of elements in the tree.
 *
 * <p>Iterators look one element ahead. Removing the element most recently
 * returned (through {@link Iterator#remove()} or directly on the tree) is safe;
 * removing, through the tree, an element that has not been returned yet may
 * still let the iterator return it.
 *
 * <p>No synchronization is performed; instances are not safe for concurrent
 * mutation.
 *
 * @param <T> the type of the stored values
 * @author Sam B. (sam.belliveau@gmail.com)
 */
public class RadixTree<T> implements Iterable<RadixTree.Pair<T>> {
    /*** PAIR ***/

    /**
     * An immutable key/value entry returned by the search methods and the
     * iterators.
     *
     * @param <V> the type of the value
     */
    public static final class Pair<V> {
        /** Pairs {@code key} with the contained value, or returns empty if there is none. */
        private static <T> Optional<Pair<T>> of(final long key,
                final Optional<T> value) {
            return value.map(v -> new Pair<>(key, v));
        }

        /** The key of this entry. */
        public final long key;

        /** The value stored under {@link #key}; never {@code null}. */
        public final V value;

        private Pair(final long key, final V value) {
            this.key = key;
            this.value = value;
        }

        @Override
        public String toString() {
            return key + "=" + value;
        }
    }

    /*** CONFIGURATION CONSTANTS ***/

    /** Number of bits in a key. */
    public static final int kBits = Long.SIZE;

    /** Mask of the most significant key bit, which selects the child to descend into. */
    public static final long kMask = Long.MIN_VALUE;

    /*** MEMBER VARIABLES ***/

    // Value whose (relative) key is 0 at this node.
    private Optional<T> mVal;
    // Subtree for keys whose current top bit is 0 / 1. Absent when empty.
    private Optional<RadixTree<T>> mL;
    private Optional<RadixTree<T>> mH;
    // Number of values in this subtree, INCLUDING this node's own mVal.
    // remove() relies on this to prune only truly empty children.
    private int mSize;

    /*** CONSTRUCTORS ***/

    /** Creates an empty tree. */
    public RadixTree() {
        mVal = Optional.empty();
        mL = Optional.empty();
        mH = Optional.empty();
        mSize = 0;
    }

    /*** SIZE ***/

    /** Returns the number of key/value pairs stored in this tree. O(1). */
    public int size() {
        return mSize;
    }

    /** Returns {@code true} if this tree stores no pairs. O(1). */
    public boolean isEmpty() {
        return size() == 0;
    }

    /*** GETTERS / SETTERS ***/

    /** Returns the child selected by the top bit of {@code index}, if it exists. O(1). */
    private Optional<RadixTree<T>> getChild(final long index) {
        return ((index & kMask) == 0) ? mL : mH;
    }

    /** Returns the child selected by the top bit of {@code index}, creating it if needed. O(1). */
    private RadixTree<T> initChild(final long index) {
        if ((index & kMask) == 0) {
            mL = mL.or(() -> Optional.of(new RadixTree<>()));
            return mL.orElseThrow();
        }
        mH = mH.or(() -> Optional.of(new RadixTree<>()));
        return mH.orElseThrow();
    }

    /**
     * Looks up the value stored under {@code index}. O(b).
     *
     * @param index the key
     * @return the stored value, or empty if the key is absent
     */
    public Optional<T> get(final long index) {
        if (index == 0) {
            return mVal;
        }
        return getChild(index).flatMap(c -> c.get(index << 1));
    }

    /**
     * Stores {@code x} under {@code index}, replacing any previous value. O(b).
     *
     * @param index the key
     * @param x     the value to store; must not be {@code null}
     * @return the value previously stored under {@code index}, or empty
     * @throws NullPointerException if {@code x} is {@code null}
     */
    public Optional<T> set(final long index, final T x) {
        // Validate before creating any nodes so a rejected call leaves no empty,
        // never-pruned chain behind.
        Objects.requireNonNull(x, "x");
        if (index == 0) {
            final Optional<T> previous = mVal;
            mVal = Optional.of(x);
            if (previous.isEmpty()) {
                ++mSize;
            }
            return previous;
        }
        final Optional<T> result = initChild(index).set(index << 1, x);
        if (result.isEmpty()) {
            ++mSize;
        }
        return result;
    }

    /**
     * Removes the value stored under {@code index}. O(b).
     *
     * @param index the key
     * @return the removed value, or empty if the key was absent
     */
    public Optional<T> remove(final long index) {
        if (index == 0) {
            final Optional<T> previous = mVal;
            mVal = Optional.empty();
            if (previous.isPresent()) {
                --mSize;
            }
            return previous;
        }
        final Optional<T> result = getChild(index).flatMap(c -> c.remove(index << 1));
        if (result.isPresent()) {
            --mSize;
            // Drop children that just became empty so searches never descend
            // into dead subtrees (this keeps min/next/max/prev at O(b)).
            mL = mL.filter(t -> !t.isEmpty());
            mH = mH.filter(t -> !t.isEmpty());
        }
        return result;
    }

    /*** KEY TRANSLATION ***/

    /** Converts a pair from the low child's key space back to this node's key space. O(1). */
    private static <U> Pair<U> isL(final Pair<U> p) {
        return new Pair<>(p.key >>> 1, p.value);
    }

    /** Converts a pair from the high child's key space back to this node's key space. O(1). */
    private static <U> Pair<U> isH(final Pair<U> p) {
        return new Pair<>(p.key >>> 1 | kMask, p.value);
    }

    /*** SEARCH FORWARDS ***/

    /**
     * Returns the pair with the smallest unsigned key. O(b).
     *
     * @return the minimum pair, or empty if the tree is empty
     */
    public Optional<Pair<T>> min() {
        // Own key (relative 0) < every low-child key < every high-child key.
        return Pair.of(0, mVal)
                .or(() -> mL.flatMap(t -> t.min()).map(RadixTree::isL))
                .or(() -> mH.flatMap(t -> t.min()).map(RadixTree::isH));
    }

    /**
     * Returns the pair with the smallest unsigned key that is greater than or
     * equal to {@code index}. O(b).
     *
     * @param index the inclusive lower bound
     * @return the matching pair, or empty if there is none
     */
    public Optional<Pair<T>> next(final long index) {
        if (index == 0 && mVal.isPresent()) {
            return Pair.of(0, mVal);
        }
        if ((index & kMask) != 0) {
            // Only high-child keys can be >= an index whose top bit is set.
            return mH.flatMap(t -> t.next(index << 1)).map(RadixTree::isH);
        }
        // Otherwise try the low child first; every high-child key is larger.
        return mL.flatMap(t -> t.next(index << 1)).map(RadixTree::isL)
                .or(() -> mH.flatMap(t -> t.min()).map(RadixTree::isH));
    }

    /*** SEARCH BACKWARDS ***/

    /**
     * Returns the pair with the largest unsigned key. O(b).
     *
     * @return the maximum pair, or empty if the tree is empty
     */
    public Optional<Pair<T>> max() {
        return mH.flatMap(t -> t.max()).map(RadixTree::isH)
                .or(() -> mL.flatMap(t -> t.max()).map(RadixTree::isL))
                .or(() -> Pair.of(0, mVal));
    }

    /**
     * Returns the pair with the largest unsigned key that is less than or equal
     * to {@code index}. O(b).
     *
     * @param index the inclusive upper bound
     * @return the matching pair, or empty if there is none
     */
    public Optional<Pair<T>> prev(final long index) {
        if (index == 0) {
            // Every other key in this subtree is non-zero, hence larger.
            return Pair.of(0, mVal);
        }
        final Optional<Pair<T>> below;
        if ((index & kMask) != 0) {
            // Search the high child first; every low-child key is smaller than index.
            below = mH.flatMap(t -> t.prev(index << 1)).map(RadixTree::isH)
                    .or(() -> mL.flatMap(t -> t.max()).map(RadixTree::isL));
        } else {
            below = mL.flatMap(t -> t.prev(index << 1)).map(RadixTree::isL);
        }
        // This node's own key (relative 0) is the smallest here: last resort.
        return below.or(() -> Pair.of(0, mVal));
    }

    /*** ITERATORS ***/

    /** Returns an iterator over all pairs in ascending unsigned key order. */
    @Override
    public Iterator<Pair<T>> iterator() {
        return new Cursor(true);
    }

    /** Returns an iterator over all pairs in descending unsigned key order. */
    public Iterator<Pair<T>> reverseIterator() {
        return new Cursor(false);
    }

    /**
     * Iterator shared by {@link #iterator()} and {@link #reverseIterator()}.
     * It caches the element it will return next and advances with
     * {@link RadixTree#next(long)} or {@link RadixTree#prev(long)}.
     */
    private final class Cursor implements Iterator<Pair<T>> {
        private final boolean mAscending;
        private Optional<Pair<T>> mUpcoming;
        private Optional<Pair<T>> mLastReturned = Optional.empty();

        private Cursor(final boolean ascending) {
            mAscending = ascending;
            mUpcoming = ascending ? RadixTree.this.min() : RadixTree.this.max();
        }

        @Override
        public boolean hasNext() {
            return mUpcoming.isPresent();
        }

        @Override
        public Pair<T> next() {
            // Throws NoSuchElementException when exhausted, as Iterator requires.
            final Pair<T> current = mUpcoming.orElseThrow();
            if (mAscending) {
                // -1 is the largest unsigned key; key + 1 would wrap around to 0.
                mUpcoming = (current.key == -1)
                        ? Optional.empty()
                        : RadixTree.this.next(current.key + 1);
            } else {
                // 0 is the smallest unsigned key; key - 1 would wrap around to -1.
                mUpcoming = (current.key == 0)
                        ? Optional.empty()
                        : RadixTree.this.prev(current.key - 1);
            }
            mLastReturned = Optional.of(current);
            return current;
        }

        /** Removes the element most recently returned by {@link #next()}. */
        @Override
        public void remove() {
            final Pair<T> last = mLastReturned.orElseThrow(() -> new IllegalStateException(
                    "remove() requires a preceding call to next()"));
            RadixTree.this.remove(last.key);
            mLastReturned = Optional.empty();
        }
    }

    /*** TESTS ***/

    /**
     * Small demo: fills a tree with random even keys plus one odd key, removes
     * every even key while iterating, then prints what is left.
     */
    public static void main(String... args) {
        RadixTree<Integer> table = new RadixTree<>();
        table.set(6941, 100);
        for (int a = 0; a < 10000; ++a) {
            table.set(2 * (long) (Math.random() * Integer.MAX_VALUE), a);
        }
        for (Pair<Integer> l : table) {
            System.out.println(l.key + ": " + l.value +
                    " | size: " + table.size());
            if (l.key % 2 == 0) {
                table.remove(l.key);
            }
        }
        System.out.println();
        for (Pair<Integer> l : table) {
            System.out.println(l.key + ": " + l.value);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment