Skip to content

Instantly share code, notes, and snippets.

@tjkendev
Created April 22, 2021 16:05
Show Gist options
  • Save tjkendev/0ef98fd8afef118075be4a22daf485a0 to your computer and use it in GitHub Desktop.
Save tjkendev/0ef98fd8afef118075be4a22daf485a0 to your computer and use it in GitHub Desktop.
DSU on Tree
# DSU on Tree
# Ref: https://codeforces.com/blog/entry/44351
import random
import time
import sys
# Benchmark input: one colour in [0, C) per vertex and a random tree on N
# vertices stored as adjacency lists in G.
random.seed()
sys.setrecursionlimit(10**6)  # the DFS helpers below are recursive

N = 10000
C = 10
cols = [random.randint(0, C - 1) for _ in range(N)]
G = [[] for _ in range(N)]

# rg is a random permutation of the vertex labels; the i-th label in that
# order is attached to a uniformly chosen earlier label, which yields a
# connected graph with exactly N - 1 edges.
rg = list(range(N))
random.shuffle(rg)
for i in range(1, N):
    j = random.randint(0, i - 1)
    G[rg[i]].append(rg[j])
    G[rg[j]].append(rg[i])
# calc_sz: O(N)
def calc_sz():
    """Return the subtree size of every vertex of the tree G rooted at vertex 0.

    Reads the module-level vertex count ``N`` and adjacency lists ``G``.

    Returns:
        A list ``sz`` of length ``N`` where ``sz[v]`` is the number of
        vertices in the subtree of ``v`` (``v`` itself included), so a leaf
        has size 1 and the root has size ``N``.
    """
    sz = [0] * N

    def dfs_sz(v, p):
        # Start from 1 so the vertex counts itself; starting from 0 would make
        # every leaf, and therefore every vertex, report a size of 0.
        c = 1
        for w in G[v]:
            if w == p:
                continue
            c += dfs_sz(w, v)
        sz[v] = c
        return c

    dfs_sz(0, -1)
    return sz
# Per-vertex table handed to every solver; impl2-impl4 read it to decide which
# child of a vertex is treated as the heavy one.
sz = calc_sz()
# calc_ett: O(N)
def calc_ett():
    """Build the Euler-tour (preorder) tables of the tree G rooted at vertex 0.

    Reads the module-level vertex count ``N`` and adjacency lists ``G``.

    Returns:
        A tuple ``(st, ft, ver)`` of three lists of length ``N``:
        ``st[v]`` is the preorder index at which ``v`` is entered, ``ft[v]``
        is the index reached once the whole subtree of ``v`` has been visited
        (so the subtree occupies ``range(st[v], ft[v])``), and ``ver[i]`` is
        the vertex entered at index ``i``.
    """
    entry = [0] * N
    leave = [0] * N
    order = [0] * N

    def visit(node, parent, clock):
        # `clock` is the next free preorder index; the updated value is
        # returned so the caller continues numbering after this subtree.
        entry[node] = clock
        order[clock] = node
        clock += 1
        for nxt in G[node]:
            if nxt != parent:
                clock = visit(nxt, node, clock)
        leave[node] = clock
        return clock

    visit(0, -1, 0)
    return entry, leave, order
# Euler-tour tables (entry index, exit index, vertex at each index); they are
# passed to solver_impl4 below.
st, ft, ver = calc_ett()
# naive: O(N^2 + NC)
def solver_naive(N, G, C, cols, sz):
    """Count, for every vertex, the colours occurring in its subtree.

    Reference implementation: each vertex walks its whole subtree to fill the
    shared tally, copies it out, then walks the subtree again to undo it.

    Args:
        N: number of vertices.
        G: adjacency lists of a tree; vertex 0 is used as the root.
        C: number of colours.
        cols: ``cols[v]`` is the colour of vertex ``v``, used as an index
            into a length-``C`` list.
        sz: not read here; accepted so all solvers share one call shape.

    Returns:
        A list of ``N`` lists of length ``C``; entry ``[v][c]`` is the number
        of vertices with colour ``c`` in the subtree of ``v``.
    """
    table = [[0] * C for _ in range(N)]
    tally = [0] * C

    def bump_subtree(node, parent, delta):
        # Add `delta` to the tally for every vertex below (and including) node.
        tally[cols[node]] += delta
        for child in G[node]:
            if child == parent:
                continue
            bump_subtree(child, node, delta)

    def walk(node, parent):
        bump_subtree(node, parent, 1)
        table[node][:] = tally
        bump_subtree(node, parent, -1)
        for child in G[node]:
            if child == parent:
                continue
            walk(child, node)

    walk(0, -1)
    return table
# impl1: O(N log^2 N + NC)
def solver_impl1(N, G, C, cols, sz):
    """Count colours per subtree by merging per-vertex dictionaries.

    Every vertex reuses the colour->count dictionary of its heavy child (the
    child with the largest ``sz``) and merges the dictionaries of its other
    children into it, so only the smaller dictionaries are iterated.

    Args:
        N: number of vertices.
        G: adjacency lists of a tree; vertex 0 is used as the root.
        C: number of colours.
        cols: ``cols[v]`` is the colour of vertex ``v``.
        sz: ``sz[v]`` is the subtree size of ``v``. It only decides which
            child's dictionary is reused; the returned counts do not depend
            on it.

    Returns:
        A list of ``N`` lists of length ``C``; entry ``[v][c]`` is the number
        of vertices with colour ``c`` in the subtree of ``v``.
    """
    R = [[0] * C for _ in range(N)]
    cnts = [None] * N

    # NOTE: the caller's `sz` is used directly. Shadowing it with a local
    # all-zero list would make every comparison below see size 0 and always
    # pick the first child as the heavy one.
    def dfs(v, p):
        mx = heavy = -1
        for w in G[v]:
            if w == p:
                continue
            dfs(w, v)
            if mx < sz[w]:
                mx = sz[w]
                heavy = w

        # Share (not copy) the heavy child's dictionary; R[heavy] was already
        # written during the recursive call, so mutating it is safe.
        cv = cnts[heavy] if heavy != -1 else {}
        cnts[v] = cv
        cv[cols[v]] = cv.get(cols[v], 0) + 1

        for w in G[v]:
            if w == p or w == heavy:
                continue
            for c, val in cnts[w].items():
                cv[c] = cv.get(c, 0) + val

        Rv = R[v]
        for i in range(C):
            Rv[i] = cv.get(i, 0)

    dfs(0, -1)
    return R
# impl2: O(N log N + NC)
def solver_impl2(N, G, C, cols, sz):
    """Count colours per subtree, keeping an explicit vertex list per subtree.

    Light children are solved first and their contribution to the shared
    tally is undone; the heavy child is solved last and its contribution is
    kept. The vertex then inherits the heavy child's vertex list and appends
    itself and every vertex of the light subtrees to it.

    Args:
        N: number of vertices.
        G: adjacency lists of a tree; vertex 0 is used as the root.
        C: number of colours.
        cols: ``cols[v]`` is the colour of vertex ``v``.
        sz: ``sz[v]`` is compared between siblings to pick the heavy child.

    Returns:
        A list of ``N`` lists of length ``C``; entry ``[v][c]`` is the number
        of vertices with colour ``c`` in the subtree of ``v``.
    """
    table = [[0] * C for _ in range(N)]
    members = [None] * N
    tally = [0] * C

    def walk(node, parent, keep):
        # Heavy child: first child whose sz is strictly the largest seen.
        best = heavy = -1
        for child in G[node]:
            if child != parent and best < sz[child]:
                best, heavy = sz[child], child

        light = [child for child in G[node]
                 if child != parent and child != heavy]
        for child in light:
            walk(child, node, 0)

        if heavy == -1:
            mine = []
        else:
            walk(heavy, node, 1)
            mine = members[heavy]  # shared list, extended in place
        members[node] = mine

        mine.append(node)
        tally[cols[node]] += 1
        for child in light:
            for x in members[child]:
                tally[cols[x]] += 1
                mine.append(x)

        table[node][:] = tally

        if keep == 0:
            # Caller does not want this subtree left in the tally.
            for x in mine:
                tally[cols[x]] -= 1

    walk(0, -1, 1)
    return table
# impl3: O(N log N + NC)
def solver_impl3(N, G, C, cols, sz):
    """Count colours per subtree, re-walking only the light subtrees.

    Light children are solved first with their contribution undone; the heavy
    child is solved last and kept. While the vertex adds itself and its light
    subtrees to the tally, the heavy child is flagged so that walk skips it.

    Args:
        N: number of vertices.
        G: adjacency lists of a tree; vertex 0 is used as the root.
        C: number of colours.
        cols: ``cols[v]`` is the colour of vertex ``v``.
        sz: ``sz[v]`` is compared between siblings to pick the heavy child.

    Returns:
        A list of ``N`` lists of length ``C``; entry ``[v][c]`` is the number
        of vertices with colour ``c`` in the subtree of ``v``.
    """
    table = [[0] * C for _ in range(N)]
    tally = [0] * C
    skip = [False] * N  # True while a vertex's subtree is already in the tally

    def bump(node, parent, delta):
        # Add `delta` for node and for every descendant not behind a skip flag.
        tally[cols[node]] += delta
        for child in G[node]:
            if child != parent and not skip[child]:
                bump(child, node, delta)

    def walk(node, parent, keep):
        best = heavy = -1
        for child in G[node]:
            if child != parent and best < sz[child]:
                best, heavy = sz[child], child

        for child in G[node]:
            if child == parent or child == heavy:
                continue
            walk(child, node, 0)

        has_heavy = heavy != -1
        if has_heavy:
            walk(heavy, node, 1)
            skip[heavy] = True

        bump(node, parent, 1)
        table[node][:] = tally

        # Clear the flag before any removal so the removal below covers the
        # heavy subtree as well.
        if has_heavy:
            skip[heavy] = False
        if keep == 0:
            bump(node, parent, -1)

    walk(0, -1, 1)
    return table
# impl4: O(N log N + NC)
def solver_impl4(N, G, C, cols, sz, st, ft, ver):
    """Count colours per subtree using Euler-tour index ranges.

    Light children are solved first with their contribution undone; the heavy
    child is solved last and kept. The remaining vertices of the subtree are
    then added by scanning the tour positions before and after the heavy
    child's interval instead of walking the tree again.

    Args:
        N: number of vertices.
        G: adjacency lists of a tree; vertex 0 is used as the root.
        C: number of colours.
        cols: ``cols[v]`` is the colour of vertex ``v``.
        sz: ``sz[v]`` is compared between siblings to pick the heavy child.
        st: ``st[v]`` is the tour index at which ``v`` is entered.
        ft: ``ft[v]`` is the tour index just past the subtree of ``v``.
        ver: ``ver[i]`` is the vertex at tour index ``i``.
            The three tour tables are expected to come from a preorder walk
            from vertex 0 that visits children in ``G`` order.

    Returns:
        A list of ``N`` lists of length ``C``; entry ``[v][c]`` is the number
        of vertices with colour ``c`` in the subtree of ``v``.
    """
    table = [[0] * C for _ in range(N)]
    tally = [0] * C

    def walk(node, parent, keep):
        best = heavy = -1
        for child in G[node]:
            if child != parent and best < sz[child]:
                best, heavy = sz[child], child

        for child in G[node]:
            if child != parent and child != heavy:
                walk(child, node, 0)

        if heavy != -1:
            walk(heavy, node, 1)
            # The heavy subtree is already tallied; add the tour positions on
            # either side of its interval (the light subtrees).
            spans = ((st[node] + 1, st[heavy]), (ft[heavy], ft[node]))
            for lo, hi in spans:
                for pos in range(lo, hi):
                    tally[cols[ver[pos]]] += 1

        tally[cols[node]] += 1
        table[node][:] = tally

        if keep == 0:
            for pos in range(st[node], ft[node]):
                tally[cols[ver[pos]]] -= 1

    walk(0, -1, 0)
    return table
# Run every solver on the same random tree, check that all result tables are
# identical, and print the elapsed seconds per solver.
# time.perf_counter() is used for the timestamps because it is a monotonic,
# high-resolution clock meant for measuring intervals; time.time() follows the
# system clock and can jump if that clock is adjusted mid-run.
t0 = time.perf_counter()
R0 = solver_naive(N, G, C, cols, sz)
t1 = time.perf_counter()
R1 = solver_impl1(N, G, C, cols, sz)
t2 = time.perf_counter()
R2 = solver_impl2(N, G, C, cols, sz)
t3 = time.perf_counter()
R3 = solver_impl3(N, G, C, cols, sz)
t4 = time.perf_counter()
R4 = solver_impl4(N, G, C, cols, sz, st, ft, ver)
t5 = time.perf_counter()

# Prints True only when all five solvers agree on every vertex and colour.
print(R0 == R1 == R2 == R3 == R4)
print("solver_naive:", t1 - t0)
print("solver_impl1:", t2 - t1)
print("solver_impl2:", t3 - t2)
print("solver_impl3:", t4 - t3)
print("solver_impl4:", t5 - t4)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment