Skip to content

Instantly share code, notes, and snippets.

@viveksb007
Created February 20, 2021 09:55
Show Gist options
  • Save viveksb007/a6d56c4a564842c4f27c8cfc5c47343e to your computer and use it in GitHub Desktop.
Save viveksb007/a6d56c4a564842c4f27c8cfc5c47343e to your computer and use it in GitHub Desktop.
import java.util.Arrays;
/**
 * Array-backed segment tree over {@code Integer} values that answers range queries by combining
 * element values with a caller-supplied {@link MergeFunction}.
 *
 * <p>Nodes use the implicit binary-tree layout: the node at {@code index} covers the segment
 * {@code [start, end]}. Its left child at {@code 2 * index + 1} covers {@code [start, mid]}, and
 * its right child at {@code 2 * index + 2} covers {@code [mid + 1, end]}.
 *
 * <p>{@code null} means "no value", for example for a segment that lies outside the queried
 * range. It is treated as the identity of the merge, so {@code mergeFunction} is never called
 * with a {@code null} argument.
 *
 * <p>{@link #findQuery} and {@link #updateElement} use the length of the array passed to them as
 * the size of the indexed range. That array should therefore be the one the tree was built from.
 */
public class SegmentTree {
    private final MergeFunction<Integer, Integer, Integer> mergeFunction;
    /** Node values in implicit binary-tree layout; the root is at index 0. */
    private final Integer[] tree;

    /**
     * Builds a segment tree over {@code arr}.
     *
     * @param arr the values to index; may be empty
     * @param mergeFunction combines the values of two adjacent segments
     * @throws IllegalArgumentException if {@code arr} is too large for the tree to fit in an array
     */
    public SegmentTree(Integer[] arr, MergeFunction<Integer, Integer, Integer> mergeFunction) {
        // A freshly allocated array is already all null, so no explicit fill is needed.
        this.tree = new Integer[findArraySizeToRepresentTree(arr.length)];
        this.mergeFunction = mergeFunction;
        if (arr.length > 0) {
            createTree(arr);
        }
    }

    /**
     * Sets {@code arr[index]} to {@code updatedValue} and recomputes every tree node whose segment
     * contains {@code index}.
     *
     * @param index position to update, in {@code [0, arr.length)}
     * @param updatedValue the new value
     * @param arr the array the tree was built from; it is modified in place
     * @throws IndexOutOfBoundsException if {@code index} is outside {@code [0, arr.length)}
     */
    public void updateElement(int index, int updatedValue, Integer[] arr) {
        // Reject bad indices loudly instead of silently leaving the tree unchanged.
        if (index < 0 || index >= arr.length) {
            throw new IndexOutOfBoundsException(
                    "index " + index + " out of bounds for length " + arr.length);
        }
        updateElementUtil(0, arr.length - 1, index, updatedValue, 0, arr);
    }

    /** Applies the update below the node at {@code index} (covering [start, end]); returns its value. */
    private Integer updateElementUtil(
            int start, int end, int updatedIndex, int updatedValue, int index, Integer[] arr) {
        if (updatedIndex < start || updatedIndex > end) {
            return tree[index]; // this segment is unaffected by the update
        }
        if (start == end) {
            arr[updatedIndex] = updatedValue;
            tree[index] = updatedValue;
            return tree[index];
        }
        int mid = midpoint(start, end);
        Integer l = updateElementUtil(start, mid, updatedIndex, updatedValue, 2 * index + 1, arr);
        Integer r = updateElementUtil(mid + 1, end, updatedIndex, updatedValue, 2 * index + 2, arr);
        tree[index] = combine(l, r);
        return tree[index];
    }

    /**
     * Returns the merge of all elements whose positions lie in {@code [left, right]}.
     *
     * <p>Bounds outside {@code [0, arr.length - 1]} are effectively clipped to the array.
     *
     * @param left inclusive lower bound of the range
     * @param right inclusive upper bound of the range
     * @param arr the array the tree was built from; only its length is read
     * @return the combined value, or {@code null} if the range covers no element (for example when
     *     {@code left > right} or {@code arr} is empty)
     */
    public Integer findQuery(int left, int right, Integer[] arr) {
        if (arr.length == 0) {
            return null;
        }
        return findQueryUtil(0, 0, arr.length - 1, left, right);
    }

    /** Queries the node at {@code index} covering [start, end]; {@code null} means no overlap. */
    private Integer findQueryUtil(int index, int start, int end, int left, int right) {
        if (start >= left && end <= right) {
            return tree[index]; // segment fully inside the query range
        }
        if (end < left || start > right) {
            return null; // segment disjoint from the query range
        }
        int mid = midpoint(start, end);
        Integer l = findQueryUtil(2 * index + 1, start, mid, left, right);
        Integer r = findQueryUtil(2 * index + 2, mid + 1, end, left, right);
        return combine(l, r);
    }

    /** Fills {@link #tree} from a non-empty {@code arr}. */
    private void createTree(Integer[] arr) {
        createTreeUtil(0, 0, arr.length - 1, arr);
    }

    /** Builds the subtree rooted at {@code index} covering [start, end]; returns its value. */
    private Integer createTreeUtil(int index, int start, int end, Integer[] arr) {
        if (start == end) {
            tree[index] = arr[start];
            return tree[index];
        }
        int mid = midpoint(start, end);
        Integer l = createTreeUtil(2 * index + 1, start, mid, arr);
        Integer r = createTreeUtil(2 * index + 2, mid + 1, end, arr);
        tree[index] = combine(l, r);
        return tree[index];
    }

    /** Merges two partial results, treating {@code null} ("no value") as the identity. */
    private Integer combine(Integer l, Integer r) {
        if (l == null) {
            return r;
        }
        if (r == null) {
            return l;
        }
        return mergeFunction.merge(l, r);
    }

    /** Midpoint of [start, end] computed without overflowing {@code start + end}. */
    private static int midpoint(int start, int end) {
        return start + (end - start) / 2;
    }

    /**
     * Returns the array length needed to store a segment tree over {@code n} elements. The result
     * is {@code 2 * p - 1}, where {@code p} is the smallest power of two that is {@code >= n}.
     *
     * <p>Uses integer bit operations instead of floating-point logarithms. The logarithm ratio is
     * subject to rounding and can misclassify exact powers of two.
     *
     * @param n number of leaf elements
     * @return the required tree array length ({@code 0} when {@code n == 0})
     * @throws IllegalArgumentException if {@code n} is negative or the size would exceed
     *     {@link Integer#MAX_VALUE}
     */
    public static int findArraySizeToRepresentTree(int n) {
        if (n < 0) {
            throw new IllegalArgumentException("n must be non-negative: " + n);
        }
        if (n == 0) {
            return 0;
        }
        long leaves = Long.highestOneBit(n);
        if (leaves < n) {
            leaves <<= 1; // round up to the next power of two
        }
        long size = 2 * leaves - 1;
        if (size > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("segment tree for n=" + n + " is too large");
        }
        return (int) size;
    }
}
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
/**
 * TestNG tests for {@link SegmentTree}.
 *
 * <p>TestNG's {@code assertEquals} takes {@code (actual, expected)}, the reverse of JUnit. The
 * computed value is therefore passed first so that failure messages report the right values.
 */
public class SegmentTreeTest {
    @Test
    public void testSegmentTreeForRangeSumQueryProblems() {
        Integer[] arr = new Integer[]{1, 3, 5, 7, 9, 11};
        SegmentTree segmentTree = new SegmentTree(arr, new SumFunction());

        // 5 + 7 + 9
        int result = segmentTree.findQuery(2, 4, arr);
        assertEquals(result, 21);

        // arr[3]: 7 -> 8, so 5 + 8 + 9
        segmentTree.updateElement(3, 8, arr);
        result = segmentTree.findQuery(2, 4, arr);
        assertEquals(result, 22);

        // arr[0]: 1 -> 3, whole array: 3 + 3 + 5 + 8 + 9 + 11
        segmentTree.updateElement(0, 3, arr);
        result = segmentTree.findQuery(0, 5, arr);
        assertEquals(result, 39);
    }

    @Test
    public void testSegmentTreeForMinRangeQueryProblems() {
        Integer[] arr = new Integer[]{1, 3, 2, 7, 9, 11};
        SegmentTree segmentTree = new SegmentTree(arr, new MinFunction());

        int result = segmentTree.findQuery(1, 5, arr);
        assertEquals(result, 2);

        segmentTree.updateElement(2, -1, arr);
        result = segmentTree.findQuery(1, 5, arr);
        assertEquals(result, -1);

        segmentTree.updateElement(5, 8, arr);
        result = segmentTree.findQuery(4, 5, arr);
        assertEquals(result, 8);
    }

    @Test
    public void testArraySizeForTreeLogic() {
        // Expected size is 2 * (smallest power of two >= n) - 1.
        assertEquals(SegmentTree.findArraySizeToRepresentTree(1), 1);
        assertEquals(SegmentTree.findArraySizeToRepresentTree(3), 7);
        assertEquals(SegmentTree.findArraySizeToRepresentTree(8), 15);
        assertEquals(SegmentTree.findArraySizeToRepresentTree(6), 15);
        assertEquals(SegmentTree.findArraySizeToRepresentTree(10), 31);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment