相关链接
题目传送门:http://codeforces.com/problemset/problem/715/C
官方题解:http://codeforces.com/blog/entry/47169
解题报告
要求统计树上路径,那么基本上确定是DP或者树分治了
想一想,好像DP的状态不好表示的样子,于是就直接点分治啦!
考虑对于每一个中心,统计经过该点符合要求的路径数
很明显需要将路径剖成两半,一半扔到map里,另一半直接查
但这题还有需要注意的就是如何去掉两段都在同一子树的非法情况
似乎直接像之前一样在子树的根部调用cal()
直接剪掉的方法不管用了
于是可以先DFS一遍,统计所有信息
然后再处理每一个子树的时候,先DFS一遍,把该子树的信息给去掉
查询完成之后,再DFS一遍把信息给加回去
Code
#include<bits/stdc++.h> #define LL long long using namespace std; const int N = 200000 + 9; int head[N],nxt[N],cost[N],to[N],REV[N]; int n,MOD; LL vout; inline int read() { char c=getchar(); int f=1,ret=0; while (c<'0'||c>'9') {if(c=='-')f=-1;c=getchar();} while (c<='9'&&c>='0') {ret=ret*10+c-'0';c=getchar();} return ret * f; } inline void Add_Edge(int u, int v, int w) { static int T = 0; to[++T] = v; nxt[T] = head[u]; head[u] = T; cost[T] = w; to[++T] = u; nxt[T] = head[v]; head[v] = T; cost[T] = w; } void gcd(int a, LL &x, int b, LL &y) { if (!b) {x = 1, y = 0;} else {gcd(b,y,a%b,x);y-=a/b*x;} } inline int gcd(int a, int b) { static LL x,y; gcd(a,x,b,y); return (x % MOD + MOD) % MOD; } namespace Node_Decomposition{ #define ND Node_Decomposition const int INF = 1e9; int tot,node_sz,root,cur; int sum[N],dep[N],vis[N]; map<int,int> cnt; void Get_Root(int w, int f) { sum[w] = 1; int mx = 0; for (int i=head[w];i;i=nxt[i]) { if (to[i] != f && !vis[to[i]]) { Get_Root(to[i], w); sum[w] += sum[to[i]]; mx = max(mx, sum[to[i]]); } } mx = max(mx, node_sz - sum[w]); if (mx < cur) cur = mx, root = w; } void DFS(int w, int f, int delta, LL p, int val) { cnt[val] += delta; for (int i=head[w];i;i=nxt[i]) { if (!vis[to[i]] && to[i] != f) { DFS(to[i], w, delta, p * 10 % MOD, (val + cost[i] * p) % MOD); } } } void cal(int w, int f, int t, LL val) { vout += cnt[(-val*REV[t]%MOD+MOD)%MOD]; for (int i=head[w];i;i=nxt[i]) { if (!vis[to[i]] && to[i] != f) { cal(to[i], w, t+1, (val * 10 + cost[i]) % MOD); } } } void solve(int w, int sz) { vis[w] = 1; cnt.clear(); for (int i=head[w];i;i=nxt[i]) { if (!vis[to[i]]) { DFS(to[i], w, 1, 10 % MOD, cost[i] % MOD); } } vout += cnt[0]; cnt[0]++; for (int i=head[w];i;i=nxt[i]) { if (!vis[to[i]]) { DFS(to[i], w, -1, 10 % MOD, cost[i] % MOD); cal(to[i], w, 1, cost[i] % MOD); DFS(to[i], w, 1, 10 % MOD, cost[i] % MOD); } } for (int i=head[w];i;i=nxt[i]) { if (!vis[to[i]]) { node_sz = sum[to[i]] > sum[w] ? sz - sum[w] : sum[to[i]]; cur = INF; Get_Root(to[i], w); solve(root, node_sz); } } } inline void solve() { cur = INF; node_sz = n; Get_Root(1,1); solve(root,n); } }; int main() { n = read(); MOD = read(); for (int i=1,u,v,w;i<n;i++) { u = read(); v = read(); w = read(); Add_Edge(u + 1, v + 1, w); } REV[0] = 1; REV[1] = gcd(10, MOD); for (int i=2;i<=n;i++) REV[i] = (LL)REV[i-1] * REV[1] % MOD; ND::solve(); printf("%lld\n",vout); return 0; }