Range in range (solution 3)
import java.util.*; | |
/**
 * Fixed-length bit set backed by 64-bit words, plus a bitset-based
 * subset-sum helper ({@link #canMake(int[], int)}).
 *
 * <p>Bit {@code i} is stored in word {@code i >>> 6} at bit position {@code i & 63}.
 */
public class BitVector {
    /** Number of bits this vector was created with. */
    public int n;
    /** Number of backing words; always equal to {@code bits.length}. */
    public int m;
    /** Backing words. */
    public long[] bits;

    /**
     * Creates a vector of {@code length} bits, all cleared.
     *
     * @param length number of bits
     */
    public BitVector(int length) {
        n = length;
        bits = new long[(n + 63) >>> 6];
        m = bits.length;
    }

    /**
     * Sets the bit at index {@code at} to 1. The index is not range-checked.
     *
     * @param at bit index
     */
    public void set(int at) {
        bits[at >>> 6] |= 1L << (at & 63);
    }

    /**
     * Sets or clears the bit at index {@code at}. The index is not range-checked.
     *
     * @param at bit index
     * @param s  {@code true} to set the bit, {@code false} to clear it
     */
    public void set(int at, boolean s) {
        long bit = 1L << (at & 63);
        if (s) {
            bits[at >>> 6] |= bit;
        } else {
            bits[at >>> 6] &= ~bit;
        }
    }

    /**
     * Reads the bit at index {@code at}.
     *
     * @param at bit index
     * @return {@code true} if the bit is set; {@code false} if it is clear or if the
     *         index falls beyond the backing words
     */
    public boolean get(int at) {
        int word = at >>> 6;
        if (word >= bits.length) {
            return false;
        }
        return ((bits[word] >>> (at & 63)) & 1L) == 1L;
    }

    /**
     * Returns a new vector of {@code n + l} bits in which every set bit of this
     * vector has moved {@code l} positions towards higher indices.
     *
     * @param l shift distance, must be non-negative
     * @return the shifted copy; this vector is not modified
     * @throws IllegalArgumentException if {@code l} is negative
     */
    public BitVector shiftLeft(int l) {
        if (l < 0) {
            throw new IllegalArgumentException("shift distance must be non-negative: " + l);
        }
        BitVector ret = new BitVector(n + l);
        int big = l >>> 6;    // whole words to skip
        int small = l & 63;   // remaining shift inside a word
        for (int i = 0; i < m; i++) {
            ret.bits[i + big] |= bits[i] << small;
            // The bits shifted out of the top of word i spill into the next word.
            // small == 0 is excluded because ">>> 64" would be a shift by 0 in Java.
            if (small != 0 && i + big + 1 < ret.m) {
                ret.bits[i + big + 1] |= bits[i] >>> (64 - small);
            }
        }
        return ret;
    }

    /**
     * Returns the bitwise OR of this vector and {@code o}; the result has
     * {@code max(n, o.n)} bits.
     *
     * @param o the other vector
     * @return a new vector; neither operand is modified
     */
    public BitVector or(BitVector o) {
        BitVector ans = new BitVector(Math.max(n, o.n));
        for (int i = 0; i < ans.m; i++) {
            long word = 0L;
            if (i < m) {
                word |= bits[i];
            }
            if (i < o.m) {
                word |= o.bits[i];
            }
            ans.bits[i] = word;
        }
        return ans;
    }

    /**
     * Decides whether some subset of {@code x} sums to exactly {@code n}.
     *
     * <p>Equal values are grouped and each group of {@code c} copies is split into
     * chunks of 1, 2, 4, ... copies plus a remainder, so that every multiple from 0
     * to {@code c} is still reachable with only O(log c) shift-and-or steps.
     *
     * @param x non-negative item values
     * @param n target sum
     * @return {@code true} if a subset (possibly empty) of {@code x} sums to {@code n}
     * @throws IllegalArgumentException if {@code x} contains a negative value
     */
    public static boolean canMake(int[] x, int n) {
        if (n < 0) {
            // Sums of non-negative values are never negative.
            return false;
        }
        Map<Integer, Integer> cnt = new HashMap<>();
        for (int value : x) {
            if (value < 0) {
                throw new IllegalArgumentException("negative value is not supported: " + value);
            }
            cnt.put(value, cnt.getOrDefault(value, 0) + 1);
        }
        List<Integer> t = new ArrayList<>();
        for (Map.Entry<Integer, Integer> entry : cnt.entrySet()) {
            int k = entry.getKey();
            int remaining = entry.getValue();
            for (int l = 1; l <= remaining; l *= 2) {
                remaining -= l;
                addWeight(t, k, l, n);
            }
            if (remaining >= 1) {
                addWeight(t, k, remaining, n);
            }
        }
        Collections.sort(t);
        BitVector v = new BitVector(1);
        v.set(0);
        for (int ti : t) {
            v = v.or(v.shiftLeft(ti));
        }
        return v.get(n);
    }

    /**
     * Appends the chunk weight {@code value * copies} to {@code weights} unless it
     * exceeds {@code limit}. A chunk heavier than the target can never be part of a
     * sum equal to the target, and every multiple that does fit is still reachable
     * from the smaller chunks. The product is formed in {@code long} so that large
     * values cannot wrap around to a negative {@code int}.
     */
    private static void addWeight(List<Integer> weights, int value, int copies, int limit) {
        long weight = (long) value * copies;
        if (weight <= limit) {
            weights.add((int) weight);
        }
    }
}
import data_structure.WaveletTree; | |
import java.util.Arrays; | |
import java.util.TreeMap; | |
public class CountRangeCoveringRangeOnlineWavelet { | |
TreeMap<Integer,Integer> yMap; | |
TreeMap<Integer,Integer> xMap; | |
WaveletTree waveletTree; | |
/** | |
* Prepare the online query. | |
* | |
* @cost O(nlogn) where n = ranges.length | |
* | |
* @warning Destroys the given target ranges array. | |
* | |
* @param ranges target ranges | |
*/ | |
public CountRangeCoveringRangeOnlineWavelet(int[][] ranges) { | |
this.yMap = compressY(ranges); | |
this.xMap = compressUniqueX(ranges); | |
int n = ranges.length; | |
int[] ys = new int[n]; | |
for (int i = 0; i < n ; i++) { | |
ys[i] = ranges[i][1]; | |
} | |
this.waveletTree = new WaveletTree(ys); | |
} | |
private static TreeMap<Integer,Integer> compressY(int[][] ranges) { | |
Arrays.sort(ranges, (a, b) -> a[1] - b[1]); | |
int n = ranges.length; | |
TreeMap<Integer,Integer> yMap = new TreeMap<>(); | |
int yidx = 0; | |
for (int i = 0 ; i < n ; ) { | |
int base = ranges[i][1]; | |
yMap.put(base, yidx); | |
int j = i; | |
while (j < n && ranges[j][1] == base) { | |
ranges[j][1] = yidx; | |
j++; | |
} | |
yidx++; | |
i = j; | |
} | |
yMap.put(Integer.MAX_VALUE, yidx); | |
return yMap; | |
} | |
private static TreeMap<Integer,Integer> compressUniqueX(int[][] ranges) { | |
Arrays.sort(ranges, (a, b) -> a[0] - b[0]); | |
int n = ranges.length; | |
TreeMap<Integer,Integer> xMap = new TreeMap<>(); | |
for (int i = 0 ; i < n ; ) { | |
int base = ranges[i][0]; | |
xMap.put(base, i); | |
int j = i; | |
while (j < n && ranges[j][0] == base) { | |
j++; | |
} | |
i = j; | |
} | |
xMap.put(Integer.MAX_VALUE, n); | |
return xMap; | |
} | |
/** | |
* Counts how many ranges in [l,r). | |
* | |
* @cost O(logn) where n = ranges.length | |
* | |
* @param l start of the query range(inclusive) | |
* @param r end of the query range(exclusive) | |
* @return number of target ranges covered by [l,r) | |
*/ | |
public int query(int l, int r) { | |
int xFrom = xMap.ceilingEntry(l).getValue(); | |
int xTo = xMap.higherEntry(r).getValue(); | |
int yFrom = yMap.ceilingEntry(l).getValue(); | |
int yTo = yMap.higherEntry(r).getValue(); | |
return waveletTree.rangeCount(xFrom, xTo, yFrom, yTo); | |
} | |
} |
import utils.BitVector; | |
import java.util.*; | |
public class WaveletTree { | |
BitVector[] which; | |
public int z; | |
public int level; | |
public int[] from; | |
public int[] to; | |
public int[][][] rank; | |
public int[] blen; | |
public long[] blenMask; | |
public long[] mask; | |
public WaveletTree(int[] values) { | |
int max = 0; | |
for (int i = 0; i < values.length ; i++) { | |
max = Math.max(max, values[i]); | |
} | |
int n = values.length; | |
z = Integer.highestOneBit(max); | |
int zz = z; | |
level = 0; | |
while (zz >= 1) { | |
zz >>= 1; | |
level++; | |
} | |
which = new BitVector[level]; | |
for (int i = 0; i < level; i++) { | |
which[i] = new BitVector(n); | |
} | |
blen = new int[which[0].bits.length]; | |
blenMask = new long[blen.length]; | |
for (int i = 0; i < blen.length; i++) { | |
blen[i] = Math.min(64, n-64*i); | |
blenMask[i] = (1L<<blen[i])-1; | |
} | |
mask = new long[64]; | |
for (int i = 0; i < 64; i++) { | |
mask[i] = (1L<<i)-1; | |
} | |
from = new int[4*z]; | |
to = new int[4*z]; | |
to[1] = n; | |
int[][] buffer = new int[3][n]; | |
buffer[0] = values.clone(); | |
int bi = 0; | |
int lv = 0; | |
for (int i = 1 ; i <= 2*z-1 ; i++) { | |
int mask = (z>>lv); | |
int f = lv%2; | |
int t = 1-f; | |
int tmpbi = 0; | |
from[i*2] = bi; | |
for (int k = from[i] ; k < to[i] ; k++) { | |
if ((buffer[f][k] & mask) == 0) { | |
buffer[t][bi++] = buffer[f][k]; | |
} else { | |
which[lv].set(k); | |
buffer[2][tmpbi++] = buffer[f][k]; | |
} | |
} | |
to[i*2] = bi; | |
from[i*2+1] = bi; | |
to[i*2+1] = bi+tmpbi; | |
System.arraycopy(buffer[2], 0, buffer[t], bi, tmpbi); | |
bi += tmpbi; | |
if (i+1 == Math.pow(2, lv+1)) { | |
lv++; | |
bi = 0; | |
} | |
} | |
// prec rank | |
int bl = which[0].bits.length; | |
rank = new int[level][bl+1][2]; | |
for (int i = 0 ; i < level ; i++) { | |
for (int j = 0; j < bl; j++) { | |
int cnt = Long.bitCount(which[i].bits[j]); | |
rank[i][j+1][1] += rank[i][j][1] + cnt; | |
rank[i][j+1][0] += rank[i][j][0] + (blen[j] - cnt); | |
} | |
} | |
} | |
private int _rank(int idx, int fr, int to, int bit) { | |
return _rank(idx, to, bit) - _rank(idx, fr, bit); | |
} | |
private int _rank(int idx, int pos, int bit) { | |
int px = pos>>>6; | |
int cnt = rank[idx][px][bit]; | |
int cnt2 = ((pos&63) == 0) ? 0 : Long.bitCount(which[idx].bits[px] & mask[pos&63]); | |
return cnt + (bit == 0 ? (pos&63) - cnt2 : cnt2); | |
} | |
private int _select(int idx, int ith, int bit) { | |
int fr = 0; | |
int to = rank[idx].length; | |
while (to - fr > 1) { | |
int med = (to + fr) / 2; | |
if (rank[idx][med][bit] < ith) { | |
fr = med; | |
} else { | |
to = med; | |
} | |
} | |
int left = ith - rank[idx][fr][bit]; | |
int ct = fr * 64; | |
long L = which[idx].bits[fr+1]; | |
for (int x = 0 ; x < 64 ; x++) { | |
int f = (int)((L >> x) & 1) ^ bit; | |
left -= f; | |
if (left == 0) { | |
return ct + x; | |
} | |
} | |
throw new RuntimeException("whoaaaa"); | |
} | |
public int rank(int pos, int value) { | |
int node = 1; | |
for (int lv = 0 ; lv < level ; lv++) { | |
int bit = (value >> (level-lv-1)) & 1; | |
pos = _rank(lv, from[node], from[node] + pos, bit); | |
if (bit == 0) { | |
node = node * 2; | |
} else { | |
node = node * 2 + 1; | |
} | |
} | |
return pos; | |
} | |
public int rangeKth(int ql, int qr, int k) { | |
int node = 1; | |
for (int lv = 0 ; lv < level ; lv++) { | |
int r0 = _rank(lv, from[node], from[node]+ql, 0); | |
int r1 = _rank(lv, from[node]+ql, from[node]+qr, 0); | |
if (k < r1) { | |
node = node * 2; | |
ql = r0; | |
qr = ql + r1; | |
} else { | |
node = node * 2 + 1; | |
k -= r1; | |
int L =qr-ql; | |
ql = ql-r0; | |
qr = ql+(L-r1); | |
} | |
} | |
return node - z * 2; | |
} | |
public int rangeCount(int ql, int qr, int min, int max) { | |
return rangeCount(ql, qr, min, max, 0, 1, 0, z*2); | |
} | |
public int rangeCount(int ql, int qr, int min, int max, int lv, int node, int nodeMin, int nodeMax) { | |
if (min <= nodeMin && nodeMax <= max) { | |
return qr-ql; | |
} | |
if (nodeMax <= min || max <= nodeMin) { | |
return 0; | |
} | |
int r0 = _rank(lv, from[node], from[node]+ql, 0); | |
int r1 = _rank(lv, from[node]+ql, from[node]+qr, 0); | |
int nodeMed = (nodeMin + nodeMax) / 2; | |
int L = qr - ql; | |
int left = rangeCount(r0, r0+r1, min, max, lv+1, node*2, nodeMin, nodeMed); | |
int right = rangeCount(ql-r0, ql-r0+L-r1, min, max, lv+1, node*2+1, nodeMed, nodeMax); | |
return left + right; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment