Last active
December 20, 2020 02:38
-
-
Save rosuH/d2f65783fe16fe7653fe9f0519fec442 to your computer and use it in GitHub Desktop.
SegmentTree
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
public interface Merger<E> { | |
E merge(E a, E b); | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.Arrays;
import java.util.Objects;
public class SegmentTree<E> { | |
private E[] data; | |
private E[] tree; | |
private Merger merger; | |
public SegmentTree(E[] arr, Merger<E> merger) { | |
this.merger = merger; | |
data = (E[]) new Object[arr.length]; | |
for (int i = 0; i < arr.length; i++) { | |
data[i] = arr[i]; | |
} | |
tree = (E[]) new Object[arr.length * 4]; | |
buildSegmentTree(0, 0, data.length - 1); | |
} | |
// 在 treeIndex 的位置创建表示区间[l..r]的线段树 | |
private void buildSegmentTree(int treeIndex, int l, int r) { | |
// 先考虑递归到底 | |
if (l == r) { | |
tree[treeIndex] = data[l]; | |
return; | |
} | |
int leftTreeIndex = leftChild(treeIndex); | |
int rightTreeIndex = rightChild(treeIndex); | |
int mid = l + (r - l) / 2; | |
// 创建左右子树 | |
buildSegmentTree(leftTreeIndex, l, mid); | |
buildSegmentTree(rightTreeIndex, mid + 1, r); | |
tree[treeIndex] = (E) merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]); | |
} | |
public int getSize() { | |
return data.length; | |
} | |
private E get(int index) { | |
if (index < 0 || index >= data.length) { | |
throw new IllegalArgumentException("Index is illegal."); | |
} | |
return data[index]; | |
} | |
private int leftChild(int index) { | |
return 2 * index + 1; | |
} | |
private int rightChild(int index) { | |
return 2 * index + 2; | |
} | |
public E query(int queryL, int queryR) { | |
if (queryL < 0 || queryL >= data.length || queryR < 0 || queryR >= data.length) { | |
throw new IllegalArgumentException("Index is illegal"); | |
} | |
return query(0, 0, data.length + 1, queryL, queryR); | |
} | |
// 在以 treeIndex 为根的线段树中[l..r]的返回内,搜区间[queryL,queryR]的值 | |
private E query(int treeIndex, int l, int r, int queryL, int queryR) { | |
// 先考虑查询到底的情况 | |
if (l == queryL && r == queryR) { | |
System.out.println("Query " + treeIndex + " == " + tree[treeIndex]); | |
return tree[treeIndex]; | |
} | |
int mid = l + (r - l) / 2; | |
int leftTreeIndex = leftChild(treeIndex); | |
int rightTreeIndex = rightChild(treeIndex); | |
if (queryL >= mid + 1) { | |
return query(rightTreeIndex, mid + 1, r, queryL, queryR); | |
} else if (queryR <= mid) { | |
return query(leftTreeIndex, l, mid, queryL, queryR); | |
} else { | |
// 一部分落在左孩子,一部分落在右孩子 | |
E leftRect = query(leftTreeIndex, l, mid, queryL, mid); | |
E rightRect = query(rightTreeIndex, mid + 1, r, mid + 1, queryR); | |
return (E) merger.merge(leftRect, rightRect); | |
} | |
} | |
public void set(int index, E e) { | |
if (index < 0 || index >= data.length) { | |
throw new IllegalArgumentException("Index is illegal."); | |
} | |
data[index] = e; | |
set(0, 0, data.length - 1, index, e); | |
} | |
private void set(int treeIndex, int l, int r, int index, E e) { | |
if (l == r) { | |
tree[treeIndex] = e; | |
} | |
int mid = l + (r - l) / 2; | |
int leftChildIndex = leftChild(treeIndex); | |
int rightChildIndex = rightChild(treeIndex); | |
if (index >= mid + 1) { | |
set(rightChildIndex, mid + 1, r, index, e); | |
} else { | |
set(leftChildIndex, l, mid, index, e); | |
} | |
tree[treeIndex] = (E) merger.merge(tree[leftChildIndex], tree[rightChildIndex]); | |
} | |
@Override | |
public String toString() { | |
return "SegmentTree{" + | |
"data=" + Arrays.toString(data) + | |
", tree=" + Arrays.toString(tree) + | |
", merger=" + merger + | |
'}'; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment