相关链接
题目传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=4231
数据生成器:http://paste.ubuntu.com/24366714/
神犇题解:http://www.cnblogs.com/clrs97/p/5467637.html
解题报告
首先,如果某次出现的位置跨过了$LCA$(即同时用到了$u \to LCA$和$LCA \to v$两段上的边),那么它只会用到$LCA$两侧各至多$|S|-1$条边
因此可以把这一段长度不超过$2(|S|-1)$的串单独取出来,暴力跑$KMP$,单次询问的代价为$O(|S|)$
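剩下的就是$S$完整落在$u \to LCA$(向上)或$LCA \to v$(向下)这一段上的出现次数;向上那段等价于在从上往下读的路径上匹配反串$S^R$,并且每一段都能写成两个"从根出发的路径"上的出现次数之差(例如向上那段:在$u$处的结果,减去在$u$的深度为$dep_{LCA}+|S|-1$的祖先处的结果;若$u$不够深,这一段的贡献为$0$)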
这不难让我们想到广义后缀自动机,但似乎这题并不能用
考虑另一个方法,把所有模式串建成AC自动机
然后在原树上$DFS$,进入一个点时,将"根到该点的路径串"在AC自动机上对应的结点权值$+1$
退出来的时候将其$-1$;需要询问时,统计模式串终止结点在fail树上的子树权值和(用fail树的DFS序加树状数组维护)就可以了,它恰好等于当前根路径上以该模式串结尾的位置个数
总时间复杂度:$O(n \log n + \sum |S|)$
Code
#include <bits/stdc++.h>
#define LL long long
using namespace std;

// BZOJ 4231: for every query (u, v, S), count the occurrences of S in the
// string spelled by the edge labels along the tree path u -> v.
//
// The path is split at lca = LCA(u, v):
//  * occurrences lying entirely on the upward part u -> lca are occurrences
//    of reverse(S) on the downward path lca -> u;
//  * occurrences lying entirely on the downward part lca -> v are
//    occurrences of S on the downward path lca -> v;
//  * occurrences crossing lca use at most len-1 edges on each side and are
//    counted directly with KMP in dif().
// The first two kinds are answered offline. Every pattern and its reverse go
// into one AC automaton. The tree is walked from the root while a BIT holds
// +1 at the AC state of every node on the current root path. A query at node
// x sums the BIT over the fail-tree subtree of the pattern's terminal state.
// Each part is then the difference of two such root-path queries.

const int N = 100009;  // array bound for tree nodes and query ids
const int M = 600009;  // array bound for AC nodes, edge slots and KMP buffers

int n, m;
int head[N], nxt[M], to[M], cost[M];  // adjacency list; cost = label 0..25
int U[N];                             // U[i] = u as read for query i
int pp[N][2];  // pp[i][0]: AC node of S_i, pp[i][1]: AC node of reverse(S_i)
int ans[N];
int dep[N], fa[N][20];  // depth (root has depth 1) and binary-lifting table
int C[N];               // C[x] = label of the edge between x and its parent
// qry[x]: (signed query id, pattern index into pp) evaluated at node x during
// solve(); a negative id means the count is subtracted from the answer.
vector<pair<int, int> > qry[N];

// Aho-Corasick automaton over 'a'..'z' with a Fenwick tree laid over the
// fail tree, so "sum of weights in the fail-subtree of node p" is a range
// query. Node 1 is the root; node 0 is a sentinel whose every transition
// leads to the root.
class AC_Automaton {
    int dfs_cnt;                 // number of nodes in the fail tree
    int ch[M][26];               // goto function (completed by build())
    int fail[M];                 // fail links
    int in[M], out[M];           // fail-subtree of x occupies [in[x], out[x]]
    int sub_sz[M];               // fail-subtree sizes (scratch for build())
    vector<int> sn[M];           // sn[x]: children of x in the fail tree

    // Point-update / range-sum Fenwick tree over positions 1..sz.
    struct Fenwick_Tree {
        int sum[M], sz;
        inline int lowbit(int x) { return x & -x; }
        inline void modify(int w, int delta) {
            for (int i = w; i <= sz; i += lowbit(i)) sum[i] += delta;
        }
        // Sum over positions l..r inclusive.
        inline int query(int l, int r) {
            int ret = 0;
            for (int i = l - 1; i; i -= lowbit(i)) ret -= sum[i];
            for (int i = r; i; i -= lowbit(i)) ret += sum[i];
            return ret;
        }
    } BIT;

public:
    // Computes fail links, completes the goto function, and assigns each
    // node a contiguous [in, out] interval of its fail-subtree. Must be
    // called once, after all insert() calls.
    inline void build() {
        for (int i = 0; i < 26; i++) ch[0][i] = 1;
        fail[1] = 0;
        // BFS; `order` doubles as the queue and records the visiting order.
        vector<int> order(1, 1);
        for (size_t h = 0; h < order.size(); h++) {
            int w = order[h];
            sn[fail[w]].push_back(w);
            for (int i = 0; i < 26; i++) {
                if (ch[w][i]) {
                    order.push_back(ch[w][i]);
                    fail[ch[w][i]] = ch[fail[w]][i];
                } else {
                    ch[w][i] = ch[fail[w]][i];
                }
            }
        }
        // A fail parent is strictly shallower than its child, so it appears
        // earlier in BFS order. Accumulate subtree sizes bottom-up, then hand
        // out intervals top-down. This replaces a recursive DFS whose depth
        // could reach the longest pattern length.
        for (size_t h = 0; h < order.size(); h++) sub_sz[order[h]] = 1;
        for (size_t h = order.size() - 1; h > 0; h--)
            sub_sz[fail[order[h]]] += sub_sz[order[h]];
        in[1] = 1;
        for (size_t h = 0; h < order.size(); h++) {
            int w = order[h], pos = in[w] + 1;
            for (size_t k = 0; k < sn[w].size(); k++) {
                in[sn[w][k]] = pos;
                pos += sub_sz[sn[w][k]];
            }
            out[w] = in[w] + sub_sz[w] - 1;
        }
        dfs_cnt = (int)order.size();
        BIT.sz = dfs_cnt;
    }

    // Inserts the 1-indexed, NUL-terminated lowercase string s[1..] into the
    // trie and returns its terminal node.
    inline int insert(char *s) {
        static int cnt = 1;
        int w = 1, len = strlen(s + 1);
        for (int i = 1; i <= len; i++) {
            int c = s[i] - 'a';
            if (!ch[w][c]) ch[w][c] = ++cnt;
            w = ch[w][c];
        }
        return w;
    }

    // Total weight in the fail-subtree of node p.
    inline int query(int p) { return BIT.query(in[p], out[p]); }

    // Adds delta to the weight of node p.
    inline void modify(int p, int delta) { BIT.modify(in[p], delta); }

    // Goto transition from state w on letter c (0..25); valid after build().
    inline int move(int w, int c) { return ch[w][c]; }
} ac;

// Reads one (optionally negative) decimal integer from stdin.
inline int read() {
    char c = getchar();
    int f = 1, ret = 0;
    while (c < '0' || c > '9') {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c <= '9' && c >= '0') {
        ret = ret * 10 + c - '0';
        c = getchar();
    }
    return ret * f;
}

// Adds the undirected edge u - v with label c. Edge ids start at 2 and come
// in pairs, so both directions share the same label.
inline void AddEdge(int u, int v, int c) {
    static int E = 1;
    cost[E + 1] = cost[E + 2] = c;
    to[++E] = v; nxt[E] = head[u]; head[u] = E;
    to[++E] = u; nxt[E] = head[v]; head[v] = E;
}

// Lowest common ancestor via binary lifting (requires pre() and the fa table).
inline int LCA(int a, int b) {
    if (dep[a] < dep[b]) swap(a, b);
    for (int j = 19; ~j; j--)
        if (dep[fa[a][j]] >= dep[b]) a = fa[a][j];
    if (a == b) return a;
    for (int j = 19; ~j; j--)
        if (fa[a][j] != fa[b][j]) a = fa[a][j], b = fa[b][j];
    return fa[a][0];
}

// Roots the tree: fills fa[x][0], dep[x] and C[x] below w (parent f).
// NOTE: recursion depth equals the tree height (up to n).
void pre(int w, int f) {
    fa[w][0] = f;
    dep[w] = dep[f] + 1;
    for (int i = head[w]; i; i = nxt[i]) {
        if (to[i] != f) {
            C[to[i]] = cost[i];
            pre(to[i], w);
        }
    }
}

// Walks the tree from w, whose root-path string drives the automaton to
// state p. The weights of the states of all nodes on the current root path
// are +1 in the BIT while the subtree is processed, and are removed on exit.
// NOTE: recursion depth equals the tree height (up to n).
void solve(int w, int p) {
    for (int i = (int)qry[w].size() - 1; i >= 0; i--) {
        int id = qry[w][i].first, sign = 1;
        if (id < 0) sign = -1, id = -id;
        ans[id] += ac.query(pp[id][qry[w][i].second]) * sign;
    }
    for (int i = head[w]; i; i = nxt[i]) {
        if (dep[to[i]] > dep[w]) {  // only descend to children
            int tmp = ac.move(p, cost[i]);
            ac.modify(tmp, 1);
            solve(to[i], tmp);
            ac.modify(tmp, -1);
        }
    }
}

// Counts the occurrences of s[1..len] that cross lca on the path u -> v.
// Side effect: u (resp. v) is moved up to its ancestor at depth
// dep[lca] + len - 1 when it is deeper than that. After this, each side
// holds at most len-1 edges, so every match in the shortened path string
// necessarily spans lca.
inline int dif(int &u, int &v, int lca, char *s, int len) {
    static char ss[M];   // shortened path string, 1-indexed
    static int NXT[M];   // KMP failure function of s, 1-indexed
    int tot = 0, TOT, ret = 0;

    // u side: climb to depth dep[lca] + len - 1, then read labels upward.
    int w = u, l = dep[u] - dep[lca] - len + 1;
    if (l > 0) {
        for (int j = 0; l; l >>= 1, ++j)
            if (l & 1) w = fa[w][j];
        u = w;
    }
    while (w != lca) ss[++tot] = C[w] + 'a', w = fa[w][0];

    // v side: same climb, then fill ss from the end backwards so it reads
    // downward from lca to v.
    w = v;
    l = dep[v] - dep[lca] - len + 1;
    if (l > 0) {
        for (int j = 0; l; l >>= 1, ++j)
            if (l & 1) w = fa[w][j];
        v = w;
    }
    TOT = (tot += dep[w] - dep[lca]);
    while (w != lca) ss[tot--] = C[w] + 'a', w = fa[w][0];

    // Failure function: NXT[i] = longest proper border of s[1..i]. NXT[1]
    // is never written and stays 0.
    for (int i = 1, k; i <= len; i++) {
        for (k = NXT[i]; k && s[k + 1] != s[i + 1]; k = NXT[k]);
        NXT[i + 1] = k + (s[k + 1] == s[i + 1]);
    }
    // Scan ss[1..TOT]. After a full match, s[len + 1] is the terminating
    // NUL, which forces a fall-back through NXT.
    for (int i = 1, k = 0; i <= TOT; i++) {
        while (k && s[k + 1] != ss[i]) k = NXT[k];
        k += s[k + 1] == ss[i];
        ret += k == len;
    }
    return ret;
}

int main() {
    n = read();
    m = read();
    for (int i = 1, u, v; i < n; i++) {
        u = read();
        v = read();
        char c[2];
        scanf("%s", c);
        AddEdge(u, v, c[0] - 'a');
    }
    pre(1, 1);
    for (int j = 1; j <= 19; j++)
        for (int i = 1; i <= n; i++) fa[i][j] = fa[fa[i][j - 1]][j - 1];

    // static: keeps this 300 KB buffer off the stack used by the recursion.
    static char pat[300009];
    for (int i = 1, u, v, lca, ll; i <= m; i++) {
        U[i] = u = read();
        v = read();
        lca = LCA(u, v);
        scanf("%s", pat + 1);
        pp[i][0] = ac.insert(pat);
        ll = strlen(pat + 1);
        // Add the root-path counts at the original endpoints...
        qry[u].push_back(make_pair(i, 1));
        qry[v].push_back(make_pair(i, 0));
        // ...count the lca-crossing matches (dif moves u/v upward)...
        ans[i] += dif(u, v, lca, pat, ll);
        // ...and subtract the root-path counts at the moved endpoints.
        qry[u].push_back(make_pair(-i, 1));
        qry[v].push_back(make_pair(-i, 0));
        reverse(pat + 1, pat + ll + 1);
        pp[i][1] = ac.insert(pat);
    }
    ac.build();
    solve(1, 1);
    for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
    return 0;
}