Create a gist now

Instantly share code, notes, and snippets.

What would you like to do?
Range in range (solution 3)
import java.util.*;
public class BitVector {
public int n;
public int m;
public long[] bits;
public BitVector(int length) {
n = length;
bits = new long[(n+63)>>>6];
m = bits.length;
}
public void set(int at) {
bits[at>>>6] |= 1L<<(at&63);
}
public void set(int at, boolean s) {
if (s) {
bits[at>>>6] |= 1L<<(at&63);
} else {
bits[at>>>6] &= ~(1L<<(at&63));
}
}
public boolean get(int at) {
int big = at >>> 6 ;
if (big >= bits.length) {
return false;
}
return ((bits[big] >>> (at&63)) & 1) == 1;
}
public BitVector shiftLeft(int l) {
BitVector ret = new BitVector(n+l);
int big = l >>> 6;
int small = l & 63;
for (int i = 0; i < m ; i++) {
ret.bits[i+big] |= bits[i] << small;
}
if (small >= 1) {
for (int i = 0; i+big+1 < ret.m; i++) {
ret.bits[i+big+1] |= (bits[i] >>> (64-small));
}
}
return ret;
}
public BitVector or(BitVector o) {
BitVector ans = new BitVector(Math.max(n, o.n));
for (int i = 0; i < ans.m ; i++) {
if (i < m) {
ans.bits[i] = bits[i];
}
if (i < o.m) {
ans.bits[i] |= o.bits[i];
}
}
return ans;
}
public static boolean canMake(int[] x, int n) {
Map<Integer,Integer> cnt = new HashMap<>();
for (int i = 0; i < x.length ; i++) {
cnt.put(x[i], cnt.getOrDefault(x[i],0)+1);
}
List<Integer> t = new ArrayList<>();
for (int k : cnt.keySet()) {
int v = cnt.get(k);
int l = 1;
while (l <= v) {
v -= l;
t.add(k * l);
l *= 2;
}
if (v >= 1) {
t.add(k * v);
}
}
Collections.sort(t);
BitVector v = new BitVector(1);
v.set(0);
for (int ti : t) {
v = v.or(v.shiftLeft(ti));
}
return v.get(n);
}
}
import data_structure.WaveletTree;
import java.util.Arrays;
import java.util.TreeMap;
/**
 * Answers, online, how many of a fixed set of target ranges {@code [x, y)} lie inside a query
 * range. Endpoints are coordinate-compressed and queried through a {@link WaveletTree}.
 */
public class CountRangeCoveringRangeOnlineWavelet {
    /** Right end (y) value -> its rank among the distinct right ends. */
    TreeMap<Integer,Integer> yMap;
    /** Left end (x) value -> index of the first range with that left end, in left-end order. */
    TreeMap<Integer,Integer> xMap;
    /** Compressed right ends, laid out in left-end order. */
    WaveletTree waveletTree;
    /** Number of target ranges, i.e. the x index one past the last range. */
    private final int size;

    /**
     * Prepare the online query.
     *
     * @cost O(nlogn) where n = ranges.length
     *
     * @warning Destroys the given target ranges array (reorders it and overwrites each y with its
     *          compressed rank).
     *
     * @param ranges target ranges, each {@code {x, y}}
     */
    public CountRangeCoveringRangeOnlineWavelet(int[][] ranges) {
        this.yMap = compressY(ranges);
        this.xMap = compressUniqueX(ranges);
        this.size = ranges.length;
        int[] ys = new int[size];
        for (int i = 0; i < size; i++) {
            ys[i] = ranges[i][1];
        }
        this.waveletTree = new WaveletTree(ys);
    }

    /**
     * Sorts by y, replaces every y with its rank among the distinct y values, and returns the map
     * from original y to rank.
     */
    private static TreeMap<Integer,Integer> compressY(int[][] ranges) {
        // Integer.compare: subtracting endpoints overflows for far-apart values and mis-sorts.
        Arrays.sort(ranges, (a, b) -> Integer.compare(a[1], b[1]));
        int n = ranges.length;
        TreeMap<Integer,Integer> yMap = new TreeMap<>();
        int yidx = 0;
        for (int i = 0; i < n; ) {
            int base = ranges[i][1];
            yMap.put(base, yidx);
            int j = i;
            while (j < n && ranges[j][1] == base) {
                ranges[j][1] = yidx;
                j++;
            }
            yidx++;
            i = j;
        }
        return yMap;
    }

    /**
     * Sorts by x (stable) and returns the map from each distinct x to the index of its first
     * occurrence.
     */
    private static TreeMap<Integer,Integer> compressUniqueX(int[][] ranges) {
        Arrays.sort(ranges, (a, b) -> Integer.compare(a[0], b[0]));
        int n = ranges.length;
        TreeMap<Integer,Integer> xMap = new TreeMap<>();
        for (int i = 0; i < n; ) {
            int base = ranges[i][0];
            xMap.put(base, i);
            int j = i;
            while (j < n && ranges[j][0] == base) {
                j++;
            }
            i = j;
        }
        return xMap;
    }

    /**
     * Counts the target ranges covered by {@code [l, r)}. These are the targets with
     * {@code l <= x <= r} and {@code l <= y <= r}, which is {@code l <= x && y <= r} when every
     * target has {@code x <= y}.
     *
     * @cost O(logn) where n = ranges.length
     *
     * @param l start of the query range(inclusive)
     * @param r end of the query range(exclusive)
     * @return number of target ranges covered by [l,r)
     */
    public int query(int l, int r) {
        int yCount = yMap.size();
        int xFrom = valueAtOrAbove(xMap, l, size);
        int xTo = valueAbove(xMap, r, size);
        int yFrom = valueAtOrAbove(yMap, l, yCount);
        int yTo = valueAbove(yMap, r, yCount);
        return waveletTree.rangeCount(xFrom, xTo, yFrom, yTo);
    }

    /** Value of the smallest key >= {@code key}, or {@code fallback} when there is none. */
    private static int valueAtOrAbove(TreeMap<Integer,Integer> map, int key, int fallback) {
        Integer k = map.ceilingKey(key);
        return k == null ? fallback : map.get(k);
    }

    /**
     * Value of the smallest key > {@code key}, or {@code fallback} when there is none (including
     * {@code key == Integer.MAX_VALUE}).
     */
    private static int valueAbove(TreeMap<Integer,Integer> map, int key, int fallback) {
        Integer k = map.higherKey(key);
        return k == null ? fallback : map.get(k);
    }
}
import utils.BitVector;
import java.util.*;
/**
 * Pointer-free wavelet tree over an int array, supporting rank, k-th smallest in a position range,
 * and counting positions whose value falls in a value range. Assumes values are non-negative.
 *
 * <p>Nodes use heap numbering (root 1, children {@code 2i} and {@code 2i+1}). Each level stores
 * one bit per position. Bit {@code k} of {@code which[lv]} is 1 when the element at position
 * {@code k} of that level's order goes to the right child.
 */
public class WaveletTree {
    /** Per-level routing bits (1 = goes right). */
    BitVector[] which;
    /** Largest power of two not exceeding the maximum value, but at least 1; values lie in [0, 2z). */
    public int z;
    /** Number of levels, i.e. bits inspected per value. */
    public int level;
    /** {@code [from[node], to[node])}: the slice of its level occupied by {@code node}. */
    public int[] from;
    public int[] to;
    /** {@code rank[lv][w][b]}: number of {@code b}-bits in words {@code [0, w)} of {@code which[lv]}. */
    public int[][][] rank;
    /** Number of valid (in-range) bits in each backing word. */
    public int[] blen;
    /** Mask of the valid bits in each backing word. */
    public long[] blenMask;
    /** {@code mask[i] = (1L << i) - 1}: the low {@code i} bits. */
    public long[] mask;

    /**
     * Builds the tree. {@code values} is copied, not modified.
     *
     * @param values non-negative values; may be empty or all zero
     */
    public WaveletTree(int[] values) {
        int n = values.length;
        int max = 0;
        for (int v : values) {
            max = Math.max(max, v);
        }
        // At least one level, so an all-zero (or empty) input still gets a valid tree.
        z = Math.max(1, Integer.highestOneBit(max));
        level = Integer.numberOfTrailingZeros(z) + 1;
        which = new BitVector[level];
        for (int i = 0; i < level; i++) {
            which[i] = new BitVector(n);
        }
        blen = new int[which[0].bits.length];
        blenMask = new long[blen.length];
        for (int i = 0; i < blen.length; i++) {
            blen[i] = Math.min(64, n - 64 * i);
            // 1L << 64 == 1L in Java (shift distance is taken mod 64), so a full word needs -1L.
            blenMask[i] = blen[i] == 64 ? -1L : (1L << blen[i]) - 1;
        }
        mask = new long[64];
        for (int i = 0; i < 64; i++) {
            mask[i] = (1L << i) - 1;
        }
        from = new int[4 * z];
        to = new int[4 * z];
        to[1] = n;
        // buffer[0] and buffer[1] alternate as the source/destination level arrays.
        // buffer[2] is scratch space for each node's right-going elements.
        int[][] buffer = new int[3][n];
        buffer[0] = values.clone();
        int bi = 0;
        int lv = 0;
        for (int i = 1; i <= 2 * z - 1; i++) {
            int bitMask = z >> lv;
            int src = lv % 2;
            int dst = 1 - src;
            int ones = 0;
            from[i * 2] = bi;
            for (int k = from[i]; k < to[i]; k++) {
                int v = buffer[src][k];
                if ((v & bitMask) == 0) {
                    buffer[dst][bi++] = v;
                } else {
                    which[lv].set(k);
                    buffer[2][ones++] = v;
                }
            }
            to[i * 2] = bi;
            from[i * 2 + 1] = bi;
            to[i * 2 + 1] = bi + ones;
            System.arraycopy(buffer[2], 0, buffer[dst], bi, ones);
            bi += ones;
            // Node i is the last one on its level exactly when i + 1 is a power of two.
            if (((i + 1) & i) == 0) {
                lv++;
                bi = 0;
            }
        }
        // Precompute per-word prefix counts for O(1) rank.
        int words = which[0].bits.length;
        rank = new int[level][words + 1][2];
        for (int i = 0; i < level; i++) {
            for (int j = 0; j < words; j++) {
                int cnt = Long.bitCount(which[i].bits[j]);
                rank[i][j + 1][1] = rank[i][j][1] + cnt;
                rank[i][j + 1][0] = rank[i][j][0] + (blen[j] - cnt);
            }
        }
    }

    /** Number of {@code bit}s at positions {@code [fr, to)} of level {@code idx}. */
    private int _rank(int idx, int fr, int to, int bit) {
        return _rank(idx, to, bit) - _rank(idx, fr, bit);
    }

    /**
     * Number of {@code bit}s at positions {@code [0, pos)} of level {@code idx}. Uses the word
     * prefix table plus a popcount of the partial word.
     */
    private int _rank(int idx, int pos, int bit) {
        int px = pos >>> 6;
        int cnt = rank[idx][px][bit];
        int cnt2 = ((pos & 63) == 0) ? 0 : Long.bitCount(which[idx].bits[px] & mask[pos & 63]);
        return cnt + (bit == 0 ? (pos & 63) - cnt2 : cnt2);
    }

    /**
     * Position, within the whole level array, of the {@code ith} (1-based) occurrence of
     * {@code bit} in level {@code idx}.
     *
     * @throws RuntimeException if the level has fewer than {@code ith} such bits
     */
    private int _select(int idx, int ith, int bit) {
        // Find the word fr with fewer than ith matches before it and at least ith through it.
        int fr = 0;
        int to = rank[idx].length;
        while (to - fr > 1) {
            int med = (to + fr) >>> 1;
            if (rank[idx][med][bit] < ith) {
                fr = med;
            } else {
                to = med;
            }
        }
        if (fr >= blen.length) {
            throw new RuntimeException("whoaaaa");
        }
        int left = ith - rank[idx][fr][bit];
        int ct = fr * 64;
        long word = which[idx].bits[fr];
        // Scan only the valid bits, so padding zeros past n are never selected.
        for (int x = 0; x < blen[fr]; x++) {
            if (((word >>> x) & 1) == bit) {
                left--;
                if (left == 0) {
                    return ct + x;
                }
            }
        }
        throw new RuntimeException("whoaaaa");
    }

    /**
     * Counts occurrences of {@code value} among positions {@code [0, pos)}.
     *
     * @param pos   prefix length
     * @param value value to count; expected to be in {@code [0, 2z)}
     */
    public int rank(int pos, int value) {
        int node = 1;
        for (int lv = 0; lv < level; lv++) {
            int bit = (value >> (level - lv - 1)) & 1;
            pos = _rank(lv, from[node], from[node] + pos, bit);
            node = bit == 0 ? node * 2 : node * 2 + 1;
        }
        return pos;
    }

    /**
     * Returns the {@code k}-th (0-based) smallest value among positions {@code [ql, qr)}.
     */
    public int rangeKth(int ql, int qr, int k) {
        int node = 1;
        for (int lv = 0; lv < level; lv++) {
            int r0 = _rank(lv, from[node], from[node] + ql, 0);
            int r1 = _rank(lv, from[node] + ql, from[node] + qr, 0);
            if (k < r1) {
                node = node * 2;
                ql = r0;
                qr = ql + r1;
            } else {
                node = node * 2 + 1;
                k -= r1;
                int len = qr - ql;
                ql = ql - r0;
                qr = ql + (len - r1);
            }
        }
        // Leaves occupy node indices [2z, 4z); the offset is the value itself.
        return node - z * 2;
    }

    /**
     * Counts positions in {@code [ql, qr)} whose value lies in {@code [min, max)}.
     */
    public int rangeCount(int ql, int qr, int min, int max) {
        return rangeCount(ql, qr, min, max, 0, 1, 0, z * 2);
    }

    /**
     * Recursive helper for {@link #rangeCount(int, int, int, int)}. {@code [ql, qr)} is relative
     * to {@code node}, which covers values {@code [nodeMin, nodeMax)} at level {@code lv}.
     */
    public int rangeCount(int ql, int qr, int min, int max, int lv, int node, int nodeMin, int nodeMax) {
        // An empty (or inverted) slice contains nothing; also prunes useless recursion.
        if (qr <= ql) {
            return 0;
        }
        if (min <= nodeMin && nodeMax <= max) {
            return qr - ql;
        }
        if (nodeMax <= min || max <= nodeMin) {
            return 0;
        }
        int r0 = _rank(lv, from[node], from[node] + ql, 0);
        int r1 = _rank(lv, from[node] + ql, from[node] + qr, 0);
        int nodeMed = (nodeMin + nodeMax) >>> 1;
        int len = qr - ql;
        int left = rangeCount(r0, r0 + r1, min, max, lv + 1, node * 2, nodeMin, nodeMed);
        int right = rangeCount(ql - r0, ql - r0 + len - r1, min, max, lv + 1, node * 2 + 1, nodeMed, nodeMax);
        return left + right;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment