相关链接
题目传送门:http://codeforces.com/problemset/problem/715/C
官方题解:http://codeforces.com/blog/entry/47169
解题报告
要求统计树上路径,那么基本上确定是DP或者树分治了
想一想,好像DP的状态不好表示的样子,于是就直接点分治啦!
考虑对于每一个中心,统计经过该点符合要求的路径数
很明显需要将路径剖成两半,一半扔到map里,另一半直接查
但这题还有需要注意的就是如何去掉两段都在同一子树的非法情况
似乎直接像之前一样在子树的根部调用cal()
直接剪掉的方法不管用了
于是可以先DFS一遍,统计所有信息
然后再处理每一个子树的时候,先DFS一遍,把该子树的信息给去掉
查询完成之后,再DFS一遍把信息给加回去
Code
#include<bits/stdc++.h>
#define LL long long
using namespace std;
const int N = 200000 + 9;
int head[N],nxt[N],cost[N],to[N],REV[N];
int n,MOD; LL vout;
inline int read() {
char c=getchar(); int f=1,ret=0;
while (c<'0'||c>'9') {if(c=='-')f=-1;c=getchar();}
while (c<='9'&&c>='0') {ret=ret*10+c-'0';c=getchar();}
return ret * f;
}
inline void Add_Edge(int u, int v, int w) {
static int T = 0;
to[++T] = v; nxt[T] = head[u]; head[u] = T; cost[T] = w;
to[++T] = u; nxt[T] = head[v]; head[v] = T; cost[T] = w;
}
void gcd(int a, LL &x, int b, LL &y) {
if (!b) {x = 1, y = 0;}
else {gcd(b,y,a%b,x);y-=a/b*x;}
}
inline int gcd(int a, int b) {
static LL x,y;
gcd(a,x,b,y);
return (x % MOD + MOD) % MOD;
}
namespace Node_Decomposition{
#define ND Node_Decomposition
const int INF = 1e9;
int tot,node_sz,root,cur;
int sum[N],dep[N],vis[N];
map<int,int> cnt;
void Get_Root(int w, int f) {
sum[w] = 1; int mx = 0;
for (int i=head[w];i;i=nxt[i]) {
if (to[i] != f && !vis[to[i]]) {
Get_Root(to[i], w);
sum[w] += sum[to[i]];
mx = max(mx, sum[to[i]]);
}
}
mx = max(mx, node_sz - sum[w]);
if (mx < cur) cur = mx, root = w;
}
void DFS(int w, int f, int delta, LL p, int val) {
cnt[val] += delta;
for (int i=head[w];i;i=nxt[i]) {
if (!vis[to[i]] && to[i] != f) {
DFS(to[i], w, delta, p * 10 % MOD, (val + cost[i] * p) % MOD);
}
}
}
void cal(int w, int f, int t, LL val) {
vout += cnt[(-val*REV[t]%MOD+MOD)%MOD];
for (int i=head[w];i;i=nxt[i]) {
if (!vis[to[i]] && to[i] != f) {
cal(to[i], w, t+1, (val * 10 + cost[i]) % MOD);
}
}
}
void solve(int w, int sz) {
vis[w] = 1; cnt.clear();
for (int i=head[w];i;i=nxt[i]) {
if (!vis[to[i]]) {
DFS(to[i], w, 1, 10 % MOD, cost[i] % MOD);
}
}
vout += cnt[0]; cnt[0]++;
for (int i=head[w];i;i=nxt[i]) {
if (!vis[to[i]]) {
DFS(to[i], w, -1, 10 % MOD, cost[i] % MOD);
cal(to[i], w, 1, cost[i] % MOD);
DFS(to[i], w, 1, 10 % MOD, cost[i] % MOD);
}
}
for (int i=head[w];i;i=nxt[i]) {
if (!vis[to[i]]) {
node_sz = sum[to[i]] > sum[w] ? sz - sum[w] : sum[to[i]];
cur = INF; Get_Root(to[i], w);
solve(root, node_sz);
}
}
}
inline void solve() {
cur = INF; node_sz = n;
Get_Root(1,1);
solve(root,n);
}
};
int main() {
n = read(); MOD = read();
for (int i=1,u,v,w;i<n;i++) {
u = read(); v = read(); w = read();
Add_Edge(u + 1, v + 1, w);
}
REV[0] = 1; REV[1] = gcd(10, MOD);
for (int i=2;i<=n;i++)
REV[i] = (LL)REV[i-1] * REV[1] % MOD;
ND::solve();
printf("%lld\n",vout);
return 0;
}