Skip to content

Instantly share code, notes, and snippets.

@winger winger/k_tree.java Secret
Created May 6, 2016

Embed
What would you like to do?
import java.io.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
public class k_tree {
/** Prime modulus (10^9 + 7) applied to every count computed in this file. */
static final long MOD = 1_000_000_007L;

/**
 * Computes {@code x^pow mod mod} by binary (square-and-multiply) exponentiation.
 *
 * <p>The base is first reduced into {@code [0, mod)}, so any {@code long} base is accepted,
 * including negative ones and ones large enough that squaring them directly would overflow.
 * The intermediate products stay below {@code mod^2}, so {@code mod} must not exceed
 * {@code 3_037_000_499} (MOD satisfies this).
 *
 * @param x   the base; any long value
 * @param pow the exponent; values {@code <= 0} yield {@code 1 % mod}
 * @param mod the modulus; must be positive
 * @return {@code x^pow} reduced into {@code [0, mod)}
 */
static long modPow(long x, long pow, long mod) {
    long base = x % mod;
    if (base < 0) {
        // Java's % takes the sign of the dividend; shift negatives into [0, mod).
        base += mod;
    }
    // 1 % mod (not 1) so that mod == 1 correctly yields 0.
    long result = 1 % mod;
    for (long e = pow; e > 0; e >>= 1) {
        if ((e & 1) == 1) {
            result = result * base % mod;
        }
        base = base * base % mod;
    }
    return result;
}
static class Graph {
    /**
     * Row {@code i - k} lists the k vertices (0-based; {@code solve} sorts each row) that
     * vertex {@code i} is connected to.
     */
    int[][] es;
    /** Number of vertices each added vertex is attached to (also the initial clique size). */
    int k;
}
static class Partition {
    /** Block label of each position; labels must lie in {@code [0, p.length)}. */
    int[] p;
    /** Memoised result of {@link #spanningTrees()}; 0 means "not computed yet". */
    long spanningTreesCache = 0;

    public Partition(int[] p) {
        this.p = p;
    }

    /** Two partitions are equal when their label arrays are element-wise equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        Partition other = (Partition) o;
        return Arrays.equals(p, other.p);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(p);
    }

    /**
     * Returns, modulo MOD, the product over all blocks of {@code size^(size - 2)}. By Cayley's
     * formula, this is the number of ways to choose a spanning tree inside every block of a
     * complete graph on these positions. Blocks of size 1 contribute a factor of 1.
     */
    public long spanningTrees() {
        if (spanningTreesCache != 0) {
            return spanningTreesCache;
        }
        int[] blockSizes = new int[p.length];
        for (int label : p) {
            blockSizes[label]++;
        }
        long product = 1;
        for (int size : blockSizes) {
            if (size > 1) {
                product = product * modPow(size, size - 2, MOD) % MOD;
            }
        }
        spanningTreesCache = product;
        return spanningTreesCache;
    }
}
static class Partitions {
    /** Every set partition of k positions, as restricted growth strings in lexicographic order. */
    ArrayList<Partition> list = new ArrayList<>();
    /** Reverse lookup: partition -> its position in {@link #list}. */
    HashMap<Partition, Integer> index = new HashMap<>();

    public Partitions(int k) {
        genPartitions(0, 0, new int[k]);
    }

    /**
     * Fills positions {@code i..p.length-1} of {@code p} in every valid way and records each
     * completed labelling.
     *
     * @param i the next position to label
     * @param m the number of distinct labels used in {@code p[0..i-1]}; position i may reuse
     *          any of them or open the new label {@code m}
     * @param p scratch labelling, shared across the recursion (cloned at each leaf)
     */
    private void genPartitions(int i, int m, int[] p) {
        if (i < p.length) {
            for (int label = 0; label <= m; label++) {
                p[i] = label;
                genPartitions(i + 1, label == m ? m + 1 : m, p);
            }
            return;
        }
        Partition leaf = new Partition(p.clone());
        index.put(leaf, list.size());
        list.add(leaf);
    }
}
static class Facet {
    /** Vertex indices of this facet; equality and hashing use only this array. */
    int[] vs;
    /** One count per partition in the {@code Partitions} passed to the constructor. */
    long[] cs;

    /**
     * Creates a facet whose counts are all zero except the last entry, which is 1. With the
     * generation order of {@code Partitions}, the last entry is the all-singletons partition.
     */
    public Facet(int[] vs, Partitions pk) {
        this.vs = vs;
        long[] counts = new long[pk.list.size()];
        counts[counts.length - 1] = 1;
        this.cs = counts;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o == null || o.getClass() != getClass()) {
            return false;
        }
        return Arrays.equals(vs, ((Facet) o).vs);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(vs);
    }
}
static class DSU {
    /** Parent pointers ({@code p}) and union-by-rank ranks ({@code r}). */
    int[] p, r;
    /** Number of disjoint sets currently present. */
    int comps;

    /** Creates {@code n} singleton sets {@code {0}, ..., {n-1}}. */
    DSU(int n) {
        comps = n;
        r = new int[n];
        p = new int[n];
        for (int v = 0; v < n; v++) {
            p[v] = v;
        }
    }

    /** Returns the representative of {@code i}'s set, pointing every node on the path at it. */
    int get(int i) {
        int root = i;
        while (p[root] != root) {
            root = p[root];
        }
        int cur = i;
        while (p[cur] != root) {
            int next = p[cur];
            p[cur] = root;
            cur = next;
        }
        return root;
    }

    /** Merges the sets containing {@code i} and {@code j} (no-op if they already coincide). */
    void unite(int i, int j) {
        int a = get(i);
        int b = get(j);
        if (a == b) {
            return;
        }
        if (r[a] < r[b]) {
            p[a] = b;
        } else {
            p[b] = a;
            if (r[a] == r[b]) {
                r[a]++;
            }
        }
        comps--;
    }
}
/**
 * Counts, modulo MOD, the spanning trees of the graph described by {@code g}. It does this with
 * a dynamic programme over connectivity partitions of k-vertex "facets".
 *
 * <p>Graph model used by this method:
 * <ul>
 *   <li>Vertices 0..g.k-1 form the initial facet f0. Its internal edges are accounted for by
 *       Partition.spanningTrees, which treats every pair as adjacent.</li>
 *   <li>Vertex i (g.k <= i < g.k + g.es.length) contributes one edge to each vertex listed in
 *       g.es[i - g.k].</li>
 * </ul>
 *
 * <p>Each Facet's cs[v] is a count modulo MOD, indexed by a partition v (from pk) of the facet's
 * k vertices in vs order. The answer is f0.cs[0]. Partition 0 of pk is the all-zero labelling (a
 * single block), so that entry counts configurations where all of f0 is connected.
 *
 * <p>NOTE(review): assumes each g.es row is sorted (solve sorts it), names only vertices smaller
 * than the attached vertex, and equals a facet registered below. If the fBase lookup misses,
 * fBase is null and the DP throws NullPointerException. TODO confirm against the input spec.
 *
 * @param g the graph description (g.k and g.es)
 * @return f0.cs[0] after every attached vertex has been folded in
 */
static long solveFast(Graph g) {
// pk: partitions of a facet's k positions. pk1: partitions of the k + 1 positions
// {g.es[i - g.k][0..k-1], i}, with the attached vertex i at position g.k.
Partitions pk = new Partitions(g.k), pk1 = new Partitions(g.k + 1);
int[][][] transitions = new int[pk1.list.size()][pk.list.size()][g.k + 1];
int[][] eTransitions = new int[pk1.list.size()][g.k];
int[] reductions = new int[pk1.list.size()];
for (int u = 0; u < pk1.list.size(); u++) {
Partition pu = pk1.list.get(u);
// transitions[u][v][e]: the pk1 partition obtained by overlaying facet partition v onto the
// k positions other than e (taken in order), or -1 if that would close a cycle.
// DSU nodes 0..g.k are pu's block labels; nodes g.k+1..2*g.k are pv's block labels.
for (int v = 0; v < pk.list.size(); v++) {
Partition pv = pk.list.get(v);
loop: for (int e = 0; e <= g.k; e++) {
DSU dsu = new DSU(2 * g.k + 1);
for (int i = 0, j = 0; i < g.k + 1; ++i) {
if (i == e) {
continue;
}
// This pu block is already joined to this pv block through an earlier position,
// so joining them again would create a cycle.
if (dsu.get(pu.p[i]) == dsu.get(pv.p[j] + g.k + 1)) {
transitions[u][v][e] = -1;
continue loop;
}
dsu.unite(pu.p[i], pv.p[j] + g.k + 1);
j++;
}
// Relabel components in order of first appearance, giving a restricted-growth
// labelling that can be looked up in pk1.index.
int[] p = new int[g.k + 1];
int[] col = new int[2 * g.k + 1];
Arrays.fill(col, -1);
int cols = 0;
for (int i = 0; i < g.k + 1; ++i) {
int c = dsu.get(pu.p[i]);
if (col[c] == -1) {
col[c] = cols++;
}
p[i] = col[c];
}
transitions[u][v][e] = pk1.index.get(new Partition(p));
}
}
// eTransitions[u][e]: the partition after adding the edge between position e and position g.k
// (the attached vertex), or -1 if both are already in one block (cycle). On this fresh DSU the
// get() comparison is equivalent to pu.p[e] == pu.p[g.k].
for (int e = 0; e < g.k; e++) {
DSU dsu = new DSU(g.k + 1);
if (dsu.get(pu.p[e]) == dsu.get(pu.p[g.k])) {
eTransitions[u][e] = -1;
continue;
}
dsu.unite(pu.p[e], pu.p[g.k]);
int[] p = new int[g.k + 1];
// Only g.k + 1 DSU nodes exist here; the 2 * g.k + 1 size is larger than needed but harmless.
int[] col = new int[2 * g.k + 1];
Arrays.fill(col, -1);
int cols = 0;
for (int i = 0; i < g.k + 1; ++i) {
int c = dsu.get(pu.p[i]);
if (col[c] == -1) {
col[c] = cols++;
}
p[i] = col[c];
}
eTransitions[u][e] = pk1.index.get(new Partition(p));
}
// reductions[u]: drop position g.k and keep the first g.k labels as a pk partition (a prefix
// of a restricted-growth string is still one). If position g.k is alone in its block, its
// component would lose contact with every other position, so the state is discarded (-1).
int cLast = 0;
for (int i = 0; i < g.k + 1; ++i) {
if (pu.p[i] == pu.p[g.k]) {
cLast++;
}
}
if (cLast == 1) {
reductions[u] = -1;
} else {
reductions[u] = pk.index.get(new Partition(Arrays.copyOf(pu.p, g.k)));
}
}
HashMap<Facet, Facet> facetsSet = new HashMap<>();
// f0: the initial facet 0..g.k-1. cs[u] = number of spanning forests of the complete graph on
// these vertices whose components match partition u (a product of Cayley counts).
int[] f0ar = new int[g.k];
for (int i = 0; i < g.k; ++i) {
f0ar[i] = i;
}
Facet f0 = new Facet(f0ar, pk);
Arrays.fill(f0.cs, 0);
for (int u = 0; u < pk.list.size(); u++) {
f0.cs[u] = pk.list.get(u).spanningTrees();
}
facetsSet.put(f0, f0);
// facets[i - g.k][e]: g.es[i - g.k] with its e-th entry removed and i appended last. This
// matches the position order used by transitions for that e. Each facet is registered in
// facetsSet so that a later g.es row can find it by vertex set.
Facet[][] facets = new Facet[g.es.length][g.k];
for (int i = g.k; i < g.k + g.es.length; ++i) {
for (int e = 0; e < g.k; ++e) {
int[] far = new int[g.k];
for (int t = 0, it = 0; t < g.k; ++t) {
if (t != e) {
far[it++] = g.es[i - g.k][t];
}
}
far[g.k - 1] = i;
facets[i - g.k][e] = new Facet(far, pk);
facetsSet.put(facets[i - g.k][e], facets[i - g.k][e]);
}
}
// Process attached vertices from last to first. If rows only reference earlier vertices
// (TODO confirm), the facets created by vertex i already hold their final counts here.
for (int i = g.k + g.es.length - 1; i >= g.k; --i) {
// d[u]: count for partition u of the k + 1 positions. Start from the all-singletons
// partition (last in pk1.list) with no edges chosen.
long[] d = new long[pk1.list.size()];
d[d.length - 1] = 1;
// The Facet built here is only a lookup key (equality uses vs alone).
Facet fBase = facetsSet.get(new Facet(g.es[i - g.k], pk));
// Combine the k + 1 facets of {g.es[i - g.k], i}: for e < g.k, the facet without position e;
// for e == g.k, the base facet itself.
for (int e = 0; e <= g.k; ++e) {
Facet f = e < g.k ? facets[i - g.k][e] : fBase;
long[] d1 = new long[pk1.list.size()];
for (int u = 0; u < pk1.list.size(); ++u) {
for (int v = 0; v < pk.list.size(); v++) {
int w = transitions[u][v][e];
if (w != -1) {
d1[w] = (d1[w] + d[u] * f.cs[v]) % MOD;
}
}
}
d = d1;
}
// Each of the g.k edges (position e, vertex i) is either skipped (d1 starts as a copy of d)
// or added through eTransitions.
for (int e = 0; e < g.k; ++e) {
long[] d1 = d.clone();
for (int u = 0; u < pk1.list.size(); ++u) {
int w = eTransitions[u][e];
if (w != -1) {
d1[w] = (d1[w] + d[u]) % MOD;
}
}
d = d1;
}
// Overwrite the base facet's counts with d projected onto its k positions. fBase.cs was
// read above (e == g.k), so several vertices attached to the same facet compose.
Arrays.fill(fBase.cs, 0);
for (int u = 0; u < pk1.list.size(); u++) {
int v = reductions[u];
if (v != -1) {
fBase.cs[v] = (fBase.cs[v] + d[u]) % MOD;
}
}
}
// Partition 0 of pk is the single-block labelling: all of f0 connected.
return f0.cs[0];
}
/**
 * Reads {@code n} and {@code k}, then {@code n - k} rows of {@code k} 1-based vertex indices.
 * Each row is converted to 0-based and sorted. The value of {@code solveFast} is printed on its
 * own line.
 *
 * @param in  token source
 * @param out destination for the single result line
 * @throws IOException if reading from {@code in} fails
 */
public static void solve(Input in, PrintWriter out) throws IOException {
    int n = in.nextInt();
    int k = in.nextInt();
    Graph graph = new Graph();
    graph.k = k;
    graph.es = new int[n - k][k];
    for (int[] row : graph.es) {
        for (int j = 0; j < row.length; j++) {
            row[j] = in.nextInt() - 1;
        }
        Arrays.sort(row);
    }
    out.println(solveFast(graph));
}
/**
 * Entry point: solves a single instance read from standard input and writes the answer to
 * standard output.
 *
 * @param args ignored
 * @throws IOException if reading standard input fails
 */
public static void main(String[] args) throws IOException {
    // try-with-resources closes (and therefore flushes) the writer even if solve throws.
    try (PrintWriter out = new PrintWriter(System.out)) {
        solve(new Input(new BufferedReader(new InputStreamReader(System.in))), out);
    }
}
static class Input {
    /** Characters that separate tokens. */
    private static final String SEPARATORS = " \n\r\t";

    BufferedReader in;
    StringBuilder sb = new StringBuilder();

    public Input(BufferedReader in) {
        this.in = in;
    }

    /** Reads tokens from an in-memory string. */
    public Input(String s) {
        this(new BufferedReader(new StringReader(s)));
    }

    private static boolean isSeparator(int c) {
        return SEPARATORS.indexOf(c) >= 0;
    }

    /**
     * Returns the next whitespace-delimited token, or {@code null} if the stream ends before
     * any non-separator character. The separator that terminates a token is consumed.
     */
    public String next() throws IOException {
        sb.setLength(0);
        int c = in.read();
        while (c != -1 && isSeparator(c)) {
            c = in.read();
        }
        if (c == -1) {
            return null;
        }
        while (c != -1 && !isSeparator(c)) {
            sb.append((char) c);
            c = in.read();
        }
        return sb.toString();
    }

    public int nextInt() throws IOException {
        return Integer.parseInt(next());
    }

    public long nextLong() throws IOException {
        return Long.parseLong(next());
    }

    public double nextDouble() throws IOException {
        return Double.parseDouble(next());
    }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.