Skip to content

Instantly share code, notes, and snippets.

@ftiasch
Last active September 9, 2015 14:16
Show Gist options
  • Save ftiasch/2ae976347be07da2ab5f to your computer and use it in GitHub Desktop.
Save ftiasch/2ae976347be07da2ab5f to your computer and use it in GitHub Desktop.
2014 Beijing Regional Onsite Contest Problem J Just a Challenge
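/*
* Suppose l = aa..a and r = bb..b; the problem is then equivalent to counting the pairs of
* tree vertices whose distance is exactly m. From this it is easy to see that the problem
* calls for divide and conquer on the tree.
*
* When dividing at vertex u, let S = s1 s2 ... sk be the string on the path between v and u.
* S has to be compared lexicographically with the length-k prefix and the length-k suffix of l and r.
*
* (After taking the difference of two counts, only the upper bound r needs to be discussed.)
* 1. Build a suffix automaton for reverse(r).
* DFS the tree starting from u and follow the corresponding transitions in the automaton.
* Because the right set of the current state does not necessarily contain prefix 0,
* precompute for every state its nearest ancestor (along parent links) whose right set contains prefix 0;
* the length of that ancestor is the longest common prefix of S and r. A single comparison then decides the order.
*
* 2. Build the suffix tree of r. DFS the tree starting from u and follow the corresponding transitions in the suffix tree.
* Note that the walk may reach a position that does not exist in the suffix tree, so the node where it branched off must be recorded.
* With a precomputed DFS numbering of the suffix tree, the lexicographic order of two strings is easy to compare.
*
* Time complexity: O(n log n)
*/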
#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cassert>
#include <vector>
#define ALL(v) (v).begin(), (v).end()
// Capacity of every statically sized array below (per-vertex arrays hold N entries,
// per-edge arrays hold 2 * N because each undirected edge is stored twice).
const int N = 100000;
// n, m                : read per test case in main; solve treats m as the length of the bound string.
// edge_count          : number of directed edge slots handed out so far by add_edge.
// first_edge[u]       : index of the most recently added edge leaving u, or -1 (main fills it with -1).
// to / character      : target vertex and letter (offset from 'a') of each directed edge.
// next_edge[e]        : next edge in the same adjacency list as e, or -1 at the end.
// lower_bound / upper_bound : the two bound strings as letter offsets from 'a', filled by read_string.
int n, m, edge_count, first_edge[N], to[N << 1], character[N << 1], next_edge[N << 1], lower_bound[N], upper_bound[N];
// Inserts the directed edge u -> v carrying letter w at the head of u's
// adjacency list, consuming the next free edge slot.
void add_edge(int u, int v, int w)
{
    const int slot = edge_count ++;
    next_edge[slot] = first_edge[u];
    character[slot] = w;
    to[slot] = v;
    first_edge[u] = slot;
}
// Reads one whitespace-delimited token from stdin and stores it in bound as
// letter offsets from 'a' (bound[i] = token[i] - 'a').
//
// bound must have room for at least N entries. Entries beyond the token's
// length are left untouched. If no token can be read, bound is not modified.
void read_string(int* bound)
{
    // N letters plus the terminating NUL: a token of exactly N letters would
    // overflow an N-byte buffer.
    static char buffer[N + 1];
    // Build "%<N>s" so that scanf can never write past the end of buffer.
    char format[16];
    snprintf(format, sizeof(format), "%%%ds", N);
    if (scanf(format, buffer) != 1) {
        // Nothing was read: do not parse whatever an earlier call left behind.
        return;
    }
    for (int i = 0; buffer[i]; ++ i) {
        bound[i] = buffer[i] - 'a';
    }
}
struct State;
// Every State registers itself here on construction, in creation order.
std::vector <State*> states;
// A node of the automaton built incrementally through extend(). The same
// objects later double as tree nodes via parent / children.
struct State {
    // Creates a state whose longest string has the given length, with no
    // transitions, no children and no parent, and appends it to `states`.
    State(int length)
        : length(length), right(-1), has_suffix(false), parent(NULL), quick_parent(NULL)
    {
        for (int c = 0; c < 2; ++ c) {
            go[c] = NULL;
            children[c] = NULL;
        }
        states.push_back(this);
    }

    // Appends `token` after the string represented by this state and returns
    // the newly created state. `start` is the automaton's initial state and
    // becomes the parent when no shorter suffix already has a `token` edge.
    State* extend(State* start, int token) {
        State* fresh = new State(length + 1);
        State* cursor = this;
        // Give every suffix that lacks a `token` transition one into `fresh`.
        for (; cursor && !cursor->go[token]; cursor = cursor->parent) {
            cursor->go[token] = fresh;
        }
        if (cursor == NULL) {
            fresh->parent = start;
            return fresh;
        }
        State* target = cursor->go[token];
        if (cursor->length + 1 == target->length) {
            // The existing target already represents exactly the needed suffix.
            fresh->parent = target;
            return fresh;
        }
        // Otherwise split: a copy of target with the shorter length takes over
        // the transitions that arrive through `cursor` and its ancestors.
        State* clone = new State(cursor->length + 1);
        for (int c = 0; c < 2; ++ c) {
            clone->go[c] = target->go[c];
        }
        clone->parent = target->parent;
        target->parent = clone;
        fresh->parent = clone;
        for (; cursor && cursor->go[token] == target; cursor = cursor->parent) {
            cursor->go[token] = clone;
        }
        return fresh;
    }

    int length, right, begin, end;
    bool has_suffix;
    State* parent;
    State* quick_parent;
    State* go[2];
    int min_length[2];
    State* children[2];
};
// Sort predicate: a goes before b when a's length is strictly larger, so
// std::sort with it yields states in order of decreasing length.
bool compare(State *a, State *b)
{
    return b->length < a->length;
}
// Assigns entry/exit timestamps over the tree formed by State::children.
// p->begin receives the counter on entry and p->end the counter after both
// child subtrees have been numbered (child 0 before child 1).
// Returns the next unused counter value; a NULL node consumes nothing.
int label(State* p, int count)
{
    if (!p) {
        return count;
    }
    p->begin = count;
    int next = label(p->children[0], count + 1);
    next = label(p->children[1], next);
    p->end = next;
    return next + 1;
}
// Initial state of the automaton built by solve for the current bound.
State* start;
// location[i]: the state reached in solve right after feeding bound[i]
// (letters are fed from index m - 1 down to 0).
State* location[N];
// visited[u]: set by divide once u has been used as a division root.
bool visited[N];
// Filled by prepare: size[u] is the vertex count of u's subtree, balance[u] the
// largest child subtree (divide later folds in the remainder above u).
int size[N], balance[N];
// Vertices reached by the latest prepare call, in visiting order.
std::vector <int> nodes;
// Walks the not-yet-visited part of the tree hanging below u (arriving from p),
// recording each vertex in `nodes` and computing, for every vertex reached,
// size[] (subtree vertex count) and balance[] (largest child subtree).
void prepare(int p, int u)
{
    nodes.push_back(u);
    int total = 1;
    int heaviest = 0;
    for (int e = first_edge[u]; e != -1; e = next_edge[e]) {
        const int child = to[e];
        if (child == p || visited[child]) {
            continue;
        }
        prepare(u, child);
        total += size[child];
        heaviest = std::max(heaviest, size[child]);
    }
    size[u] = total;
    balance[u] = heaviest;
}
// The bound string currently being processed (set by solve before divide runs).
int* bound;
// Per-depth tallies split into three buckets; merge treats bucket 0 as "less",
// bucket 1 as "equal" and bucket 2 as "greater".
// prefix / suffix             : totals over the subtrees already merged at the current division root.
// prefix_delta / suffix_delta : tallies of the subtree that search is currently scanning.
// max_depth                   : deepest depth search recorded in the current subtree.
int prefix[N][3], suffix[N][3], max_depth, prefix_delta[N][3], suffix_delta[N][3];
// Edge letters on the path from the division root to the vertex search is at.
std::vector <int> string;
// Sign of x: -1 for negative, 0 for zero, 1 for positive.
int signum(int x)
{
    if (x < 0) {
        return -1;
    }
    return x > 0 ? 1 : 0;
}
// Depth-first walk away from the current division root that classifies every
// vertex lying at most m edges from it.
//
// For the vertex u at depth k (k = string.size() once w has been pushed) one
// tally is added to each of two tables, in bucket 0 (less), 1 (equal) or
// 2 (greater), which is how merge consumes them:
//   prefix_delta[k][*] : the path letters read from u back toward the root,
//                        compared with bound[0 .. k-1];
//   suffix_delta[k][*] : the path letters read from the root out to u, compared
//                        with the bound suffix represented by location[m - k].
//
// p        vertex we arrived from (never re-entered)
// u        vertex being entered
// w        letter on the edge p -> u
// s        automaton state before consuming w
// lcp      match length carried along the automaton walk
// t        node reached so far in the tree formed by State::children
// diverge  0 while the path still lies inside that tree; otherwise an even
//          offset plus the letter at which the path left it (only the low bit
//          is inspected later)
void search(int p, int u, int w, State* s, int lcp, State* t, int diverge)
{
string.push_back(w);
int depth = (int)string.size();
// Vertices deeper than m are neither counted nor expanded.
if (depth <= m) {
// First vertex seen at this depth in the current subtree: clear stale counters.
if (depth > max_depth) {
max_depth = depth;
for (int j = 0; j < 3; ++ j) {
prefix_delta[depth][j] = suffix_delta[depth][j] = 0;
}
}
// min_length[w] is the length of the state that really owns the w-transition
// (filled in by solve); the match cannot be longer than that before the new
// letter extends it by one.
lcp = std::min(lcp, s->min_length[w]) + 1;
s = s->go[w];
// The match can never exceed the longest string of the state we landed in.
lcp = std::min(lcp, s->length);
// quick_parent is the nearest ancestor flagged has_suffix; l is the number of
// leading bound letters matched by the path read from u toward the root.
int l = std::min(lcp, s->quick_parent->length);
if (l == depth) {
// Every letter of the path matches.
prefix_delta[depth][1] ++;
} else {
// First mismatch: path letter string[depth - 1 - l] against bound[l]. With two
// letters (State::go has two slots) the difference is -1 or +1, i.e. bucket 0 or 2.
prefix_delta[depth][1 + string[depth - 1 - l] - bound[l]] ++;
}
if (diverge) {
// Already off the tree: stay off, keeping the low bit (the diverging letter).
diverge += 2;
} else if (depth > t->length) {
// The string of node t is used up, so the next letter must select a child.
if (t->children[w]) {
t = t->children[w];
} else {
diverge = 2 + w;
}
} else if (bound[t->right + depth - 1] != w) {
// Still inside t's own string: compare with the bound letter at that position.
diverge = 2 + w;
}
// q: node recorded by solve for the bound suffix starting at index m - depth.
State* q = location[m - depth];
if (t == q && !diverge) {
suffix_delta[depth][1] ++;
} else if (t->begin <= q->begin && q->end <= t->end) {
// q lies inside t's begin/end interval (set by label), so the path must have
// left the tree; the diverging letter picks bucket 0 or 2.
assert(diverge);
suffix_delta[depth][(diverge & 1) << 1] ++;
} else {
// Disjoint intervals: decide by which node was numbered first by label.
suffix_delta[depth][(q->begin < t->begin) << 1] ++;
}
for (int iterator = first_edge[u]; ~iterator; iterator = next_edge[iterator]) {
int v = to[iterator];
if (v != p && !visited[v]) {
int w = character[iterator];
search(u, v, w, s, lcp, t, diverge);
}
}
}
// Restore the shared path buffer for the caller.
string.pop_back();
}
// Combines prefix tallies at depth i with suffix tallies at depth m - i for
// every i in [l, r] and returns the number of combinations that rank below
// the bound (or equal to it as well, when `equal` is true):
//   - a "less" prefix pairs with any suffix;
//   - an "equal" prefix pairs with a "less" suffix, plus an "equal" suffix
//     when `equal` is set.
long long merge(bool equal, int prefix[N][3], int suffix[N][3], int l, int r)
{
    long long total = 0;
    for (int head = l; head <= r; ++ head) {
        const int* pre = prefix[head];
        const int* suf = suffix[m - head];
        const int any_suffix = suf[0] + suf[1] + suf[2];
        const int accepted_suffix = suf[0] + (equal ? suf[1] : 0);
        total += static_cast<long long>(pre[0]) * any_suffix;
        total += static_cast<long long>(pre[1]) * accepted_suffix;
    }
    return total;
}
// Divide-and-conquer step on the unvisited component containing `root`.
// Picks the vertex with the smallest balance as the actual division root,
// counts the qualifying paths of m edges passing through it (one end in the
// subtree being scanned, the other in an earlier subtree or the root itself),
// then recurses into the remaining pieces. Components with at most m vertices
// contribute 0. `equal` is forwarded to merge.
long long divide(int root, bool equal)
{
    nodes.clear();
    prepare(-1, root);
    const int component = size[root];
    if (component <= m) {
        return 0;
    }
    // Fold the part above each vertex into its balance and keep the best vertex.
    const int reached = (int)nodes.size();
    for (int k = 0; k < reached; ++ k) {
        const int u = nodes[k];
        balance[u] = std::max(balance[u], component - size[u]);
        if (balance[u] < balance[root]) {
            root = u;
        }
    }
    for (int d = 0; d <= m; ++ d) {
        std::fill(prefix[d], prefix[d] + 3, 0);
        std::fill(suffix[d], suffix[d] + 3, 0);
    }
    // The empty path at the root itself matches both ways.
    prefix[0][1] = 1;
    suffix[0][1] = 1;
    long long answer = 0;
    for (int e = first_edge[root]; e != -1; e = next_edge[e]) {
        const int v = to[e];
        if (visited[v]) {
            continue;
        }
        max_depth = 0;
        search(root, v, character[e], start, 0, start, 0);
        // Pair the new subtree with everything merged so far, in both roles,
        // before its own tallies are added to the totals.
        answer += merge(equal, prefix_delta, suffix, 1, max_depth);
        answer += merge(equal, prefix, suffix_delta, m - max_depth, m - 1);
        for (int d = 0; d <= max_depth; ++ d) {
            for (int b = 0; b < 3; ++ b) {
                prefix[d][b] += prefix_delta[d][b];
                suffix[d][b] += suffix_delta[d][b];
            }
        }
    }
    visited[root] = true;
    for (int e = first_edge[root]; e != -1; e = next_edge[e]) {
        const int v = to[e];
        if (!visited[v]) {
            answer += divide(v, equal);
        }
    }
    return answer;
}
// Builds the automaton and tree structures for one bound string and returns
// divide(0, equal) for it. Per merge, the result counts paths ranking strictly
// below the bound, plus those equal to it when `equal` is true.
//
// bound  the m letters of the bound, as offsets from 'a'
// equal  whether paths equal to the bound are counted
long long solve(int* bound, bool equal)
{
start = new State(0);
{
// Feed the bound back to front; location[i] is the state reached right after
// consuming bound[i], and right records that index.
State* p = start;
for (int i = m - 1; i >= 0; -- i) {
p = p->extend(start, bound[i]);
location[i] = p;
p->right = i;
}
// The state for the whole bound is the seed of the has_suffix flag.
p->has_suffix = true;
}
// Longest states first, so every state is handled before its parent.
std::sort(ALL(states), compare);
for (int i = 0; i < (int)states.size(); ++ i) {
State* p = states[i];
State* q = p->parent;
if (q) {
// Push right and has_suffix up the parent link and register p as the child
// of q selected by the bound letter that follows q's string at that position.
q->right = p->right;
q->has_suffix |= p->has_suffix;
q->children[bound[q->right + q->length]] = p;
}
}
// Shortest states first, so a parent's data is final before its children read it.
for (int i = (int)states.size() - 1; i >= 0; -- i) {
State* p = states[i];
// quick_parent: nearest state on the parent chain (itself included) with has_suffix.
if (p->has_suffix) {
p->quick_parent = p;
} else {
p->quick_parent = p->parent->quick_parent;
}
for (int j = 0; j < 2; ++ j) {
if (p->go[j]) {
// A real transition: it is usable with p's full length.
p->min_length[j] = p->length;
} else {
// Missing transition: inherit the parent's (already completed) one, or fall
// back to the initial state with length 0 when there is no parent.
State* q = p->parent;
if (q) {
p->go[j] = q->go[j];
p->min_length[j] = q->min_length[j];
} else {
p->go[j] = start;
p->min_length[j] = 0;
}
}
}
}
// Number the children tree so search can compare nodes by begin/end.
label(start, 0);
std::fill(visited, visited + n, false);
// Publish the bound to search through the global of the same name.
::bound = bound;
long long result = divide(0, equal);
// Release every state created for this bound (each one registered itself in states).
for (int i = 0; i < (int)states.size(); ++ i) {
delete states[i];
}
states.clear();
return result;
}
int main()
{
int tests;
scanf("%d", &tests);
while (tests --) {
scanf("%d%d", &n, &m);
edge_count = 0;
std::fill(first_edge, first_edge + n, -1);
for (int i = 0; i < n - 1; ++ i) {
int a, b;
char buffer[2];
scanf("%d%d%s", &a, &b, buffer);
a --;
b --;
int c = *buffer - 'a';
add_edge(a, b, c);
add_edge(b, a, c);
}
read_string(lower_bound);
read_string(upper_bound);
long long result = 0;
result += solve(upper_bound, true);
result -= solve(lower_bound, false);
std::cout << result << std::endl;
}
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment