Reference problem (BZOJ 4530): http://www.lydsy.com/JudgeOnline/problem.php?id=4530
// Offline solution for a sequence of "A u v" (add edge) / "Q u v" (query edge) operations.
//
// All operations are read first. The final forest (every "A" edge) is DFS-numbered once,
// giving each vertex an entry time beg[] and the last entry time in its subtree out[].
// A union-find structure tracks the components as edges are added in order. Each component
// root owns a mergeable segment tree counting the entry times of the component's vertices.
//
// For "Q u v", u is made the deeper endpoint (dep[u] >= dep[v]). The vertices of u's current
// component whose entry time lies in [beg[u], out[u]] form u's side of the edge, so the program
// prints side * (component size - side).
// NOTE(review): this assumes, as the problem presumably guarantees, that the queried edge was
// already added and that the "A" edges form a forest.
#include <cstdio>
#include <utility>

typedef long long LL;

const int N = 200009;   // capacity for vertices/operations; the edge arrays store 2 entries per edge
const int M = N * 21;   // segment-tree node pool; each insert allocates one node per tree level

int n, m;
int vis[N], head[N], nxt[N], to[N];        // adjacency lists of the final forest
int dep[N], beg[N], out[N], sz[N], fa[N];  // DFS depth/entry/exit, DSU size/parent

// One input operation: t == 1 for "A u v", t == 0 for "Q u v".
struct Data {
    int t, x, y;
    Data() {}
    Data(bool a, int b, int c) : t(a), x(b), y(c) {}
} opt[N];

// Reads one (optionally negative) decimal integer from stdin; returns 0 if EOF comes first.
inline int read() {
    int c = std::getchar();   // int, not char, so EOF is representable
    int sign = 1, ret = 0;
    while (c != EOF && (c < '0' || c > '9')) {
        sign = (c == '-') ? -1 : 1;   // only a '-' directly before the digits counts
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        ret = ret * 10 + (c - '0');
        c = std::getchar();
    }
    return ret * sign;
}

// Adds the undirected edge (u, v) as two directed entries; edge ids start at 2.
inline void AddEdge(int u, int v) {
    static int E = 1;
    to[++E] = v; nxt[E] = head[u]; head[u] = E;
    to[++E] = u; nxt[E] = head[v]; head[v] = E;
}

// Iterative pre-order DFS from `root` over the final forest. It assigns beg/out (entry time of
// the vertex and largest entry time in its subtree) and dep (root depth is 1). The entry-time
// counter continues across calls, so every vertex gets a distinct time in [1, n].
// An explicit stack is used so long paths cannot overflow the call stack.
inline void DFS(int root) {
    static int stk[N], cur[N];   // vertex stack, next adjacency entry to try for each vertex
    static int clk = 0;
    int top = 0;
    vis[root] = 1;
    beg[root] = ++clk;
    dep[root] = 1;
    cur[root] = head[root];
    stk[++top] = root;
    while (top) {
        int w = stk[top];
        if (cur[w]) {
            int v = to[cur[w]];
            cur[w] = nxt[cur[w]];
            if (!vis[v]) {   // in a forest the only visited neighbour is the parent
                vis[v] = 1;
                beg[v] = ++clk;
                dep[v] = dep[w] + 1;
                cur[v] = head[v];
                stk[++top] = v;
            }
        } else {
            out[w] = clk;   // every vertex of w's subtree has now been numbered
            --top;
        }
    }
}

// Returns the DSU root of x, compressing the path iteratively (no recursion depth limit).
inline int find(int x) {
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (fa[x] != root) {
        int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}

// Pool of dynamically allocated segment trees over positions [1, n]; root[p] is the tree of
// DSU root p, and sum[] counts inserted positions per node. There is no constructor: the
// members rely on zero-initialisation, which holds because the only instance (SGT) is global.
class SegmentTree {
    int cnt, ch[M][2], sum[M], root[N];
public:
    // Adds position v to the tree owned by p.
    void insert(int p, int v) { insertAt(root[p], 1, n, v); }
    // Folds b's tree into a's tree, reusing b's nodes; b's tree must not be used afterwards.
    void merge(int a, int b) { root[a] = mergeNodes(root[a], root[b]); }
    // Counts positions in [l, r] stored in the tree owned by p.
    int query(int p, int l, int r) { return queryRange(root[p], 1, n, l, r); }
private:
    int mergeNodes(int a, int b) {
        if (!a || !b) return a + b;   // one side empty: keep the other subtree as is
        sum[a] += sum[b];
        ch[a][0] = mergeNodes(ch[a][0], ch[b][0]);
        ch[a][1] = mergeNodes(ch[a][1], ch[b][1]);
        return a;
    }
    // Creates the path to leaf p in an empty tree. Split: left [l, mid-1], right [mid, r].
    void insertAt(int &w, int l, int r, int p) {
        sum[w = ++cnt] = 1;
        if (l < r) {
            int mid = (l + r + 1) >> 1;
            if (p < mid) insertAt(ch[w][0], l, mid - 1, p);
            else insertAt(ch[w][1], mid, r, p);
        }
    }
    int queryRange(int w, int l, int r, int L, int R) {
        if (!w) return 0;
        if (L <= l && r <= R) return sum[w];
        int mid = (l + r + 1) >> 1, ret = 0;
        if (L < mid) ret += queryRange(ch[w][0], l, mid - 1, L, R);
        if (mid <= R) ret += queryRange(ch[w][1], mid, r, L, R);
        return ret;
    }
} SGT;

int main() {
    n = read();
    m = read();
    for (int i = 1; i <= m; i++) {
        char cmd[3];
        if (std::scanf("%2s", cmd) != 1) {   // truncated input: keep only the operations read
            m = i - 1;
            break;
        }
        int u = read(), v = read();
        if (cmd[0] == 'A') AddEdge(u, v);
        opt[i] = Data(cmd[0] == 'A', u, v);
    }
    for (int i = 1; i <= n; i++) {
        if (!vis[i]) DFS(i);
    }
    for (int i = 1; i <= n; i++) {
        sz[i] = 1;
        fa[i] = i;
        SGT.insert(i, beg[i]);
    }
    for (int i = 1; i <= m; i++) {
        int u = opt[i].x, v = opt[i].y;
        if (opt[i].t == 1) {
            int ru = find(u), rv = find(v);
            if (ru == rv) continue;   // merging a tree into itself would double its counts
            if (sz[ru] < sz[rv]) std::swap(ru, rv);   // union by size keeps DSU chains short
            SGT.merge(ru, rv);
            sz[ru] += sz[rv];
            fa[rv] = ru;
        } else {
            if (dep[u] < dep[v]) std::swap(u, v);   // u: deeper endpoint, i.e. the child side
            int root = find(u);
            int p1 = SGT.query(root, beg[u], out[u]);
            int p2 = sz[root] - p1;
            std::printf("%lld\n", static_cast<LL>(p1) * p2);
        }
    }
    return 0;
}