Last active
September 9, 2015 14:16
-
-
Save ftiasch/2ae976347be07da2ab5f to your computer and use it in GitHub Desktop.
2014 Beijing Regional Onsite Contest Problem J Just a Challenge
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
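/*
 * If l = aa..a and r = bb..b, the problem reduces to counting vertex pairs at
 * distance exactly m in the tree, which clearly calls for tree divide-and-conquer
 * (centroid decomposition).
 *
 * When splitting at centroid u, let S = s1 s2 ... sk be the string read from v to u;
 * S must be compared lexicographically with the length-k prefix and the length-k
 * suffix of l and of r.
 *
 * (The answer is count(<= r) - count(< l), so it suffices to discuss the upper bound r.)
 * 1. Build a suffix automaton of reverse(r).
 *    DFS the tree from u, following the matching transitions in the automaton.
 *    The right set of the current state does not necessarily contain prefix 0,
 *    so precompute, for every state, its nearest ancestor (along parent links)
 *    whose right set contains prefix 0. That ancestor's length is the longest
 *    common prefix of S and r, after which a single comparison settles the order.
 *
 * 2. Build a suffix tree of r. DFS the tree from u, following the matching
 *    transitions in the suffix tree.
 *    The walk may step to a position that does not exist in the suffix tree, so
 *    record the node where it diverged.
 *    With a precomputed DFS numbering of the suffix tree, comparing two strings
 *    lexicographically becomes easy.
 *
 * Time complexity: O(n log n)
 */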
#include <algorithm> | |
#include <iostream> | |
#include <cstdio> | |
#include <cstring> | |
#include <cassert> | |
#include <vector> | |
#define ALL(v) (v).begin(), (v).end() | |
// Capacity for vertices and for the length of each bound string.
const int N = 100000;
// Vertex count of the tree and the required path length (also |l| = |r|).
int n;
int m;
// Linked-list adjacency: each tree edge is stored as two directed edges.
int edge_count;
int first_edge[N];
int to[N << 1];
int character[N << 1];
int next_edge[N << 1];
// The two bound strings, stored as letter offsets from 'a'.
int lower_bound[N];
int upper_bound[N];
void add_edge(int u, int v, int w) | |
{ | |
to[edge_count] = v; | |
character[edge_count] = w; | |
next_edge[edge_count] = first_edge[u]; | |
first_edge[u] = edge_count ++; | |
} | |
void read_string(int* bound) | |
{ | |
static char buffer[N]; | |
scanf("%s", buffer); | |
for (int i = 0; buffer[i]; ++ i) { | |
bound[i] = buffer[i] - 'a'; | |
} | |
} | |
struct State;

// Registry of every State ever constructed; the constructor appends itself.
// solve() sorts this list and deletes its entries when it is done.
std::vector <State*> states;

/*
 * A state of a suffix automaton over the two-letter alphabet {0, 1}.
 * solve() additionally fills `children`, `right`, `begin` and `end`, so the
 * same objects also serve as the nodes of a suffix tree.
 */
struct State {
    // Every member starts defined. solve() and label() overwrite the derived
    // fields later, but nothing is left uninitialised in the meantime.
    explicit State(int length) : length(length), right(-1), begin(0), end(0), has_suffix(false), parent(NULL), quick_parent(NULL) {
        memset(go, 0, sizeof(go));
        memset(min_length, 0, sizeof(min_length));
        memset(children, 0, sizeof(children));
        states.push_back(this);
    }
    /*
     * Online suffix-automaton extension.
     * `this` is the state of the whole string read so far and `start` is the
     * initial state. Appends `token` and returns the state of the extended string.
     * When the existing target of a transition is too long, it is split by
     * cloning (`nq`).
     */
    State* extend(State* start, int token) {
        State *p = this;
        State *np = new State(length + 1);
        while (p && !p->go[token]) {
            p->go[token] = np;
            p = p->parent;
        }
        if (!p) {
            np->parent = start;
        } else {
            State *q = p->go[token];
            if (p->length + 1 == q->length) {
                np->parent = q;
            } else {
                State *nq = new State(p->length + 1);
                memcpy(nq->go, q->go, sizeof(q->go));
                nq->parent = q->parent;
                np->parent = q->parent = nq;
                while (p && p->go[token] == q) {
                    p->go[token] = nq;
                    p = p->parent;
                }
            }
        }
        return np;
    }
    int length;          // length of the longest string represented by this state
    int right;           // start index in the bound of one occurrence (-1 until solve() sets it)
    int begin;           // DFS entry number assigned by label()
    int end;             // DFS exit number assigned by label()
    bool has_suffix;     // set by solve() on the final state and its suffix-link ancestors
    State* parent;       // suffix link (NULL for the initial state)
    State* quick_parent; // nearest suffix-link ancestor (inclusive) with has_suffix set
    State* go[2];        // transitions by letter; solve() fills missing ones from the parent
    int min_length[2];   // per letter: length of the nearest ancestor owning that transition
    State* children[2];  // suffix-tree children keyed by the next letter
};
/* Sort predicate: orders states by decreasing `length` (longest first). */
bool compare(State *a, State *b)
{
    return b->length < a->length;
}
/*
 * Numbers the suffix tree rooted at p, recursing through `children`.
 * On entry a node takes the next counter as `begin`; after its subtree it
 * takes the next one as `end`. Hence x lies in y's subtree iff
 * y->begin <= x->begin && x->end <= y->end.
 * Returns the next unused counter; a null p consumes nothing.
 */
int label(State* p, int count)
{
    if (!p) {
        return count;
    }
    int next = count;
    p->begin = next++;
    next = label(p->children[0], next);
    next = label(p->children[1], next);
    p->end = next++;
    return next;
}
State* start; | |
State* location[N]; | |
bool visited[N]; | |
int size[N], balance[N]; | |
std::vector <int> nodes; | |
void prepare(int p, int u) | |
{ | |
size[u] = 1; | |
balance[u] = 0; | |
nodes.push_back(u); | |
for (int iterator = first_edge[u]; ~iterator; iterator = next_edge[iterator]) { | |
int v = to[iterator]; | |
if (v != p && !visited[v]) { | |
prepare(u, v); | |
size[u] += size[v]; | |
balance[u] = std::max(balance[u], size[v]); | |
} | |
} | |
} | |
int* bound; | |
int prefix[N][3], suffix[N][3], max_depth, prefix_delta[N][3], suffix_delta[N][3]; | |
std::vector <int> string; | |
/* Sign of x: -1, 0 or 1. (Not called anywhere in this file.) */
int signum(int x)
{
    if (x > 0) {
        return 1;
    }
    return x < 0 ? -1 : 0;
}
void search(int p, int u, int w, State* s, int lcp, State* t, int diverge) | |
{ | |
string.push_back(w); | |
int depth = (int)string.size(); | |
if (depth <= m) { | |
if (depth > max_depth) { | |
max_depth = depth; | |
for (int j = 0; j < 3; ++ j) { | |
prefix_delta[depth][j] = suffix_delta[depth][j] = 0; | |
} | |
} | |
lcp = std::min(lcp, s->min_length[w]) + 1; | |
s = s->go[w]; | |
lcp = std::min(lcp, s->length); | |
int l = std::min(lcp, s->quick_parent->length); | |
if (l == depth) { | |
prefix_delta[depth][1] ++; | |
} else { | |
prefix_delta[depth][1 + string[depth - 1 - l] - bound[l]] ++; | |
} | |
if (diverge) { | |
diverge += 2; | |
} else if (depth > t->length) { | |
if (t->children[w]) { | |
t = t->children[w]; | |
} else { | |
diverge = 2 + w; | |
} | |
} else if (bound[t->right + depth - 1] != w) { | |
diverge = 2 + w; | |
} | |
State* q = location[m - depth]; | |
if (t == q && !diverge) { | |
suffix_delta[depth][1] ++; | |
} else if (t->begin <= q->begin && q->end <= t->end) { | |
assert(diverge); | |
suffix_delta[depth][(diverge & 1) << 1] ++; | |
} else { | |
suffix_delta[depth][(q->begin < t->begin) << 1] ++; | |
} | |
for (int iterator = first_edge[u]; ~iterator; iterator = next_edge[iterator]) { | |
int v = to[iterator]; | |
if (v != p && !visited[v]) { | |
int w = character[iterator]; | |
search(u, v, w, s, lcp, t, diverge); | |
} | |
} | |
} | |
string.pop_back(); | |
} | |
/*
 * Combines length-i pieces from `prefix` with length-(m - i) pieces from
 * `suffix`, for i in [l, r].
 * Returns the number of combinations whose concatenation is smaller than the
 * bound (or smaller-or-equal when `equal` is true).
 * Class index: 0 = smaller, 1 = equal, 2 = greater than the matching piece of
 * the bound.
 */
long long merge(bool equal, int prefix[N][3], int suffix[N][3], int l, int r)
{
    long long total = 0;
    for (int i = l; i <= r; ++ i) {
        const int* head = prefix[i];
        const int* tail = suffix[m - i];
        // A smaller head decides the comparison regardless of the tail.
        const int any_tail = tail[0] + tail[1] + tail[2];
        // An equal head defers to the tail.
        const int good_tail = equal ? tail[0] + tail[1] : tail[0];
        total += static_cast<long long>(head[0]) * any_tail
               + static_cast<long long>(head[1]) * good_tail;
    }
    return total;
}
long long divide(int root, bool equal) | |
{ | |
nodes.clear(); | |
prepare(-1, root); | |
{ | |
int s = size[root]; | |
if (s <= m) { | |
return 0; | |
} | |
for (int i = 0; i < (int)nodes.size(); ++ i) { | |
int u = nodes[i]; | |
balance[u] = std::max(balance[u], s - size[u]); | |
if (balance[u] < balance[root]) { | |
root = u; | |
} | |
} | |
} | |
long long result = 0; | |
for (int i = 0; i <= m; ++ i) { | |
for (int j = 0; j < 3; ++ j) { | |
prefix[i][j] = suffix[i][j] = 0; | |
} | |
} | |
prefix[0][1] = suffix[0][1] = 1; | |
for (int iterator = first_edge[root]; ~iterator; iterator = next_edge[iterator]) { | |
int v = to[iterator]; | |
if (!visited[v]) { | |
int w = character[iterator]; | |
max_depth = 0; | |
search(root, v, w, start, 0, start, 0); | |
result += merge(equal, prefix_delta, suffix, 1, max_depth); | |
result += merge(equal, prefix, suffix_delta, m - max_depth, m - 1); | |
for (int i = 0; i <= max_depth; ++ i) { | |
for (int j = 0; j < 3; ++ j) { | |
prefix[i][j] += prefix_delta[i][j]; | |
suffix[i][j] += suffix_delta[i][j]; | |
} | |
} | |
} | |
} | |
visited[root] = true; | |
for (int iterator = first_edge[root]; ~iterator; iterator = next_edge[iterator]) { | |
int v = to[iterator]; | |
if (!visited[v]) { | |
result += divide(v, equal); | |
} | |
} | |
return result; | |
} | |
long long solve(int* bound, bool equal) | |
{ | |
start = new State(0); | |
{ | |
State* p = start; | |
for (int i = m - 1; i >= 0; -- i) { | |
p = p->extend(start, bound[i]); | |
location[i] = p; | |
p->right = i; | |
} | |
p->has_suffix = true; | |
} | |
std::sort(ALL(states), compare); | |
for (int i = 0; i < (int)states.size(); ++ i) { | |
State* p = states[i]; | |
State* q = p->parent; | |
if (q) { | |
q->right = p->right; | |
q->has_suffix |= p->has_suffix; | |
q->children[bound[q->right + q->length]] = p; | |
} | |
} | |
for (int i = (int)states.size() - 1; i >= 0; -- i) { | |
State* p = states[i]; | |
if (p->has_suffix) { | |
p->quick_parent = p; | |
} else { | |
p->quick_parent = p->parent->quick_parent; | |
} | |
for (int j = 0; j < 2; ++ j) { | |
if (p->go[j]) { | |
p->min_length[j] = p->length; | |
} else { | |
State* q = p->parent; | |
if (q) { | |
p->go[j] = q->go[j]; | |
p->min_length[j] = q->min_length[j]; | |
} else { | |
p->go[j] = start; | |
p->min_length[j] = 0; | |
} | |
} | |
} | |
} | |
label(start, 0); | |
std::fill(visited, visited + n, false); | |
::bound = bound; | |
long long result = divide(0, equal); | |
for (int i = 0; i < (int)states.size(); ++ i) { | |
delete states[i]; | |
} | |
states.clear(); | |
return result; | |
} | |
int main() | |
{ | |
int tests; | |
scanf("%d", &tests); | |
while (tests --) { | |
scanf("%d%d", &n, &m); | |
edge_count = 0; | |
std::fill(first_edge, first_edge + n, -1); | |
for (int i = 0; i < n - 1; ++ i) { | |
int a, b; | |
char buffer[2]; | |
scanf("%d%d%s", &a, &b, buffer); | |
a --; | |
b --; | |
int c = *buffer - 'a'; | |
add_edge(a, b, c); | |
add_edge(b, a, c); | |
} | |
read_string(lower_bound); | |
read_string(upper_bound); | |
long long result = 0; | |
result += solve(upper_bound, true); | |
result -= solve(lower_bound, false); | |
std::cout << result << std::endl; | |
} | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment