Range in range (solution 3)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.*; | |
/**
 * Fixed-length bit vector backed by 64-bit words.
 *
 * <p>Bit {@code i} lives in word {@code i >>> 6} at position {@code i & 63} (least significant
 * bit first). {@link #set(int)} does not bounds-check, so callers should keep indices in
 * {@code [0, n)}.
 */
public class BitVector {
    /** Logical length in bits. */
    public int n;
    /** Number of words in {@link #bits}, i.e. {@code ceil(n / 64)}. */
    public int m;
    /** Backing storage. */
    public long[] bits;

    /**
     * Creates an all-zero vector of {@code length} bits.
     *
     * @param length number of bits
     */
    public BitVector(int length) {
        n = length;
        bits = new long[(n + 63) >>> 6];
        m = bits.length;
    }

    /** Sets bit {@code at} to 1. */
    public void set(int at) {
        bits[at >>> 6] |= 1L << (at & 63);
    }

    /** Sets bit {@code at} to 1 when {@code s} is true, otherwise clears it. */
    public void set(int at, boolean s) {
        if (s) {
            bits[at >>> 6] |= 1L << (at & 63);
        } else {
            bits[at >>> 6] &= ~(1L << (at & 63));
        }
    }

    /**
     * Reads bit {@code at}. Indices whose word index {@code at >>> 6} falls past the backing
     * array read as false.
     */
    public boolean get(int at) {
        int big = at >>> 6;
        if (big >= bits.length) {
            return false;
        }
        return ((bits[big] >>> (at & 63)) & 1) == 1;
    }

    /**
     * Returns a new vector of length {@code n + l} whose bit {@code i + l} equals this vector's
     * bit {@code i}; the low {@code l} bits are zero. Expects {@code l >= 0}.
     */
    public BitVector shiftLeft(int l) {
        BitVector ret = new BitVector(n + l);
        int big = l >>> 6;
        int small = l & 63;
        for (int i = 0; i < m; i++) {
            ret.bits[i + big] |= bits[i] << small;
        }
        if (small >= 1) {
            // Carry the high bits of each source word into the next destination word.
            for (int i = 0; i + big + 1 < ret.m; i++) {
                ret.bits[i + big + 1] |= (bits[i] >>> (64 - small));
            }
        }
        return ret;
    }

    /** Returns a new vector of length {@code max(n, o.n)} holding the bitwise OR. */
    public BitVector or(BitVector o) {
        BitVector ans = new BitVector(Math.max(n, o.n));
        for (int i = 0; i < ans.m; i++) {
            if (i < m) {
                ans.bits[i] = bits[i];
            }
            if (i < o.m) {
                ans.bits[i] |= o.bits[i];
            }
        }
        return ans;
    }

    /**
     * Performs {@code this |= this << shift} in place, discarding every bit that would land at a
     * position {@code >= n}. A shift of 0 is a no-op.
     */
    private void orShiftLeftInPlace(long shift) {
        if (shift <= 0 || shift >= n) {
            return; // nothing changes, or every shifted bit falls outside [0, n)
        }
        int big = (int) (shift >>> 6);
        int small = (int) (shift & 63);
        // Walk from the highest word down so each source word is read before it is updated;
        // this makes the in-place update equal to OR-ing with a shifted copy of the old bits.
        for (int i = m - 1; i >= big; i--) {
            long w = bits[i - big] << small;
            if (small != 0 && i - big - 1 >= 0) {
                w |= bits[i - big - 1] >>> (64 - small);
            }
            bits[i] |= w;
        }
        int tail = n & 63;
        if (tail != 0) {
            bits[m - 1] &= (1L << tail) - 1; // clear bits shifted past the logical end
        }
    }

    /**
     * Decides whether some sub-multiset of {@code x} (each element used at most once) sums to
     * exactly {@code n}.
     *
     * <p>Equal values are grouped. Each multiplicity {@code c} is split into chunks
     * {@code 1, 2, 4, ...} plus a remainder, so any count in {@code [0, c]} is expressible.
     * Each chunk then becomes one 0/1 item applied with a shift-or on the bitset of reachable
     * sums. The bitset is capped at {@code n + 1} bits because larger sums can never contribute,
     * and items heavier than {@code n} are skipped.
     *
     * @param x item values; must be non-negative
     * @param n target sum
     * @return true iff {@code n} is reachable; false whenever {@code n < 0}
     * @throws IllegalArgumentException if {@code x} contains a negative value
     */
    public static boolean canMake(int[] x, int n) {
        if (n < 0) {
            return false;
        }
        Map<Integer, Integer> cnt = new HashMap<>();
        for (int xi : x) {
            if (xi < 0) {
                throw new IllegalArgumentException("canMake requires non-negative values, got " + xi);
            }
            cnt.merge(xi, 1, Integer::sum);
        }
        BitVector reach = new BitVector(n + 1);
        reach.set(0);
        for (Map.Entry<Integer, Integer> e : cnt.entrySet()) {
            // long arithmetic: value * chunk can exceed Integer.MAX_VALUE
            long value = e.getKey();
            long remaining = e.getValue();
            for (long chunk = 1; chunk <= remaining; chunk *= 2) {
                remaining -= chunk;
                reach.orShiftLeftInPlace(value * chunk);
            }
            if (remaining >= 1) {
                reach.orShiftLeftInPlace(value * remaining);
            }
        }
        return reach.get(n);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import data_structure.WaveletTree; | |
import java.util.Arrays; | |
import java.util.TreeMap; | |
public class CountRangeCoveringRangeOnlineWavelet { | |
TreeMap<Integer,Integer> yMap; | |
TreeMap<Integer,Integer> xMap; | |
WaveletTree waveletTree; | |
/** | |
* Prepare the online query. | |
* | |
* @cost O(nlogn) where n = ranges.length | |
* | |
* @warning Destroys the given target ranges array. | |
* | |
* @param ranges target ranges | |
*/ | |
public CountRangeCoveringRangeOnlineWavelet(int[][] ranges) { | |
this.yMap = compressY(ranges); | |
this.xMap = compressUniqueX(ranges); | |
int n = ranges.length; | |
int[] ys = new int[n]; | |
for (int i = 0; i < n ; i++) { | |
ys[i] = ranges[i][1]; | |
} | |
this.waveletTree = new WaveletTree(ys); | |
} | |
private static TreeMap<Integer,Integer> compressY(int[][] ranges) { | |
Arrays.sort(ranges, (a, b) -> a[1] - b[1]); | |
int n = ranges.length; | |
TreeMap<Integer,Integer> yMap = new TreeMap<>(); | |
int yidx = 0; | |
for (int i = 0 ; i < n ; ) { | |
int base = ranges[i][1]; | |
yMap.put(base, yidx); | |
int j = i; | |
while (j < n && ranges[j][1] == base) { | |
ranges[j][1] = yidx; | |
j++; | |
} | |
yidx++; | |
i = j; | |
} | |
yMap.put(Integer.MAX_VALUE, yidx); | |
return yMap; | |
} | |
private static TreeMap<Integer,Integer> compressUniqueX(int[][] ranges) { | |
Arrays.sort(ranges, (a, b) -> a[0] - b[0]); | |
int n = ranges.length; | |
TreeMap<Integer,Integer> xMap = new TreeMap<>(); | |
for (int i = 0 ; i < n ; ) { | |
int base = ranges[i][0]; | |
xMap.put(base, i); | |
int j = i; | |
while (j < n && ranges[j][0] == base) { | |
j++; | |
} | |
i = j; | |
} | |
xMap.put(Integer.MAX_VALUE, n); | |
return xMap; | |
} | |
/** | |
* Counts how many ranges in [l,r). | |
* | |
* @cost O(logn) where n = ranges.length | |
* | |
* @param l start of the query range(inclusive) | |
* @param r end of the query range(exclusive) | |
* @return number of target ranges covered by [l,r) | |
*/ | |
public int query(int l, int r) { | |
int xFrom = xMap.ceilingEntry(l).getValue(); | |
int xTo = xMap.higherEntry(r).getValue(); | |
int yFrom = yMap.ceilingEntry(l).getValue(); | |
int yTo = yMap.higherEntry(r).getValue(); | |
return waveletTree.rangeCount(xFrom, xTo, yFrom, yTo); | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import utils.BitVector; | |
import java.util.*; | |
public class WaveletTree { | |
BitVector[] which; | |
public int z; | |
public int level; | |
public int[] from; | |
public int[] to; | |
public int[][][] rank; | |
public int[] blen; | |
public long[] blenMask; | |
public long[] mask; | |
public WaveletTree(int[] values) { | |
int max = 0; | |
for (int i = 0; i < values.length ; i++) { | |
max = Math.max(max, values[i]); | |
} | |
int n = values.length; | |
z = Integer.highestOneBit(max); | |
int zz = z; | |
level = 0; | |
while (zz >= 1) { | |
zz >>= 1; | |
level++; | |
} | |
which = new BitVector[level]; | |
for (int i = 0; i < level; i++) { | |
which[i] = new BitVector(n); | |
} | |
blen = new int[which[0].bits.length]; | |
blenMask = new long[blen.length]; | |
for (int i = 0; i < blen.length; i++) { | |
blen[i] = Math.min(64, n-64*i); | |
blenMask[i] = (1L<<blen[i])-1; | |
} | |
mask = new long[64]; | |
for (int i = 0; i < 64; i++) { | |
mask[i] = (1L<<i)-1; | |
} | |
from = new int[4*z]; | |
to = new int[4*z]; | |
to[1] = n; | |
int[][] buffer = new int[3][n]; | |
buffer[0] = values.clone(); | |
int bi = 0; | |
int lv = 0; | |
for (int i = 1 ; i <= 2*z-1 ; i++) { | |
int mask = (z>>lv); | |
int f = lv%2; | |
int t = 1-f; | |
int tmpbi = 0; | |
from[i*2] = bi; | |
for (int k = from[i] ; k < to[i] ; k++) { | |
if ((buffer[f][k] & mask) == 0) { | |
buffer[t][bi++] = buffer[f][k]; | |
} else { | |
which[lv].set(k); | |
buffer[2][tmpbi++] = buffer[f][k]; | |
} | |
} | |
to[i*2] = bi; | |
from[i*2+1] = bi; | |
to[i*2+1] = bi+tmpbi; | |
System.arraycopy(buffer[2], 0, buffer[t], bi, tmpbi); | |
bi += tmpbi; | |
if (i+1 == Math.pow(2, lv+1)) { | |
lv++; | |
bi = 0; | |
} | |
} | |
// prec rank | |
int bl = which[0].bits.length; | |
rank = new int[level][bl+1][2]; | |
for (int i = 0 ; i < level ; i++) { | |
for (int j = 0; j < bl; j++) { | |
int cnt = Long.bitCount(which[i].bits[j]); | |
rank[i][j+1][1] += rank[i][j][1] + cnt; | |
rank[i][j+1][0] += rank[i][j][0] + (blen[j] - cnt); | |
} | |
} | |
} | |
private int _rank(int idx, int fr, int to, int bit) { | |
return _rank(idx, to, bit) - _rank(idx, fr, bit); | |
} | |
private int _rank(int idx, int pos, int bit) { | |
int px = pos>>>6; | |
int cnt = rank[idx][px][bit]; | |
int cnt2 = ((pos&63) == 0) ? 0 : Long.bitCount(which[idx].bits[px] & mask[pos&63]); | |
return cnt + (bit == 0 ? (pos&63) - cnt2 : cnt2); | |
} | |
private int _select(int idx, int ith, int bit) { | |
int fr = 0; | |
int to = rank[idx].length; | |
while (to - fr > 1) { | |
int med = (to + fr) / 2; | |
if (rank[idx][med][bit] < ith) { | |
fr = med; | |
} else { | |
to = med; | |
} | |
} | |
int left = ith - rank[idx][fr][bit]; | |
int ct = fr * 64; | |
long L = which[idx].bits[fr+1]; | |
for (int x = 0 ; x < 64 ; x++) { | |
int f = (int)((L >> x) & 1) ^ bit; | |
left -= f; | |
if (left == 0) { | |
return ct + x; | |
} | |
} | |
throw new RuntimeException("whoaaaa"); | |
} | |
public int rank(int pos, int value) { | |
int node = 1; | |
for (int lv = 0 ; lv < level ; lv++) { | |
int bit = (value >> (level-lv-1)) & 1; | |
pos = _rank(lv, from[node], from[node] + pos, bit); | |
if (bit == 0) { | |
node = node * 2; | |
} else { | |
node = node * 2 + 1; | |
} | |
} | |
return pos; | |
} | |
public int rangeKth(int ql, int qr, int k) { | |
int node = 1; | |
for (int lv = 0 ; lv < level ; lv++) { | |
int r0 = _rank(lv, from[node], from[node]+ql, 0); | |
int r1 = _rank(lv, from[node]+ql, from[node]+qr, 0); | |
if (k < r1) { | |
node = node * 2; | |
ql = r0; | |
qr = ql + r1; | |
} else { | |
node = node * 2 + 1; | |
k -= r1; | |
int L =qr-ql; | |
ql = ql-r0; | |
qr = ql+(L-r1); | |
} | |
} | |
return node - z * 2; | |
} | |
public int rangeCount(int ql, int qr, int min, int max) { | |
return rangeCount(ql, qr, min, max, 0, 1, 0, z*2); | |
} | |
public int rangeCount(int ql, int qr, int min, int max, int lv, int node, int nodeMin, int nodeMax) { | |
if (min <= nodeMin && nodeMax <= max) { | |
return qr-ql; | |
} | |
if (nodeMax <= min || max <= nodeMin) { | |
return 0; | |
} | |
int r0 = _rank(lv, from[node], from[node]+ql, 0); | |
int r1 = _rank(lv, from[node]+ql, from[node]+qr, 0); | |
int nodeMed = (nodeMin + nodeMax) / 2; | |
int L = qr - ql; | |
int left = rangeCount(r0, r0+r1, min, max, lv+1, node*2, nodeMin, nodeMed); | |
int right = rangeCount(ql-r0, ql-r0+L-r1, min, max, lv+1, node*2+1, nodeMed, nodeMax); | |
return left + right; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment