题目传送门：http://uoj.ac/problem/176
位运算生成树系列：边权为两端点权值的按位与（&），求最大生成树。
详细分析见之后的算法笔记。
#include<iostream>
#include<cstdio>
#define LL long long
using namespace std;
// Table size for masks/values; assumes m <= 19 and every value < 2^m.
const int N = 1<<19;
int n, m;      // number of input values, bit width
int num[N];    // num[q]: an input value that is a bit-superset of mask q (0 = none)
int fa[N];     // DSU parent, indexed by value/mask
LL vout;       // accumulated weight of the maximum spanning tree
// Read the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it.
// Returns 0 if EOF is hit before a digit. The character is kept in an int
// so EOF is distinguishable. Previously `char c` plus the skip loop spun
// forever on getchar() == EOF.
inline int read(){
	int c=getchar(), ret=0;
	while (c!=EOF && (c<'0'||c>'9')) c=getchar();
	while (c>='0'&&c<='9') ret=ret*10+(c-'0'),c=getchar();
	return ret;
}
inline int find(int w){
int f=fa[w],tmp;
while (f != fa[f]) f = fa[f];
while (w != f) tmp=fa[w],fa[w]=f,w=tmp;
return f;
}
// Maximum spanning tree with edge weight a & b, built greedily over masks.
// Masks q are visited from 2^m-1 down to 1. Input values act as DSU node ids.
// num[q] holds some input value that is a bit-superset of q (0 if none).
// Invariant: after iteration q, all input values that are supersets of q
// are in one component. Any union made at iteration q therefore joins two
// components whose best connecting edge has weight exactly q.
int main(){
	n = read(); m = read();
	// A repeated value is joined to its first copy at weight v & v == v; only
	// the first copy is kept as a node. Value 0 only has weight-0 edges and
	// is never stored (num[0] stays 0).
	for (int i=1,v;i<=n;i++)
		if (num[v=read()]) vout += v;
		else num[v] = v;
	// q == 0 is skipped: weight-0 edges do not change the sum.
	for (int q=(1<<m)-1;q;q--) {
		fa[q] = q;
		// If q itself is not a value, inherit a representative from q with one
		// extra bit set (already processed, since it is larger).
		for (int i=0;i<m && !num[q];i++) num[q] = num[q|1<<i];
		const int u = num[q];
		// Connect the representative of q with those of every one-bit extension.
		for (int i=0;i<m;i++) {
			const int v = num[q|1<<i];
			if (v && find(u) != find(v)) vout += q, fa[find(u)] = find(v);
		}
	}
	printf("%lld\n",vout);
	return 0;
}
下面补充一个 Boruvka + Trie 的版本：
#include<iostream>
#include<cstdio>
#include<cstring>
#define LL long long
#define R(x) ((x)>0)
using namespace std;
const int N = 100000 + 9;   // max number of points (+ slack)
const int M = N * 32 * 4;   // trie node pool size
int n, m;                   // number of points, bit width
int arr[N];                 // arr[i]: value of point i
int v[N];                   // v[c]: best AND weight found for component c this round (-1 = none)
int sur[N];                 // sur[c]: component reached by component c's best edge
int fa[N];                  // DSU parent
LL vout = 0;                // accumulated weight of the maximum spanning tree
// Binary trie over the values. Each value is tagged with a colour (its DSU
// component). For a value NUM of colour COL, query() finds a value x of a
// different colour maximising NUM & x.
// After merge(), the 0-child of every node also contains everything in its
// 1-child, so taking the 0 branch means "this bit of x is unconstrained".
namespace Trie{
	// root: root node id.
	// ch: child links (0 = absent).
	// MX/MN: largest/smallest colour in the subtree (0 = empty; colours are >= 1).
	// cnt: last allocated node id.
	// NUM/COL/ans_tmp: arguments and accumulator shared by the recursive helpers.
	int root,ch[M][2],MX[M],MN[M],cnt,NUM,COL,ans_tmp;
	// Reset the trie before a new round.
	// Nodes are handed out contiguously via ++cnt, so only ids 0..cnt can have
	// been written since the previous reset. Clearing that prefix restores the
	// all-zero state without memsetting the full M-sized pools (~200 MB) every
	// round.
	inline void init(){
		memset(ch,0,sizeof(ch[0])*(cnt+1));
		memset(MX,0,sizeof(MX[0])*(cnt+1));
		memset(MN,0,sizeof(MN[0])*(cnt+1));
		cnt = root = 1;
	}
	// Insert NUM with colour COL into the subtree at w, creating w if absent.
	// t is the bit that selects w's child; t == 0 means w is a leaf.
	void Insert(int &w, int t){
		if (!w) w = ++cnt;
		if (!MX[w] || MX[w] < COL) MX[w] = COL;
		if (!MN[w] || MN[w] > COL) MN[w] = COL;
		if (t) Insert(ch[w][R(NUM&t)],t>>1);
	}
	inline void insert(int num, int col){
		NUM = num; COL = col;
		// (1<<m)>>1 equals 1<<(m-1) for m >= 1. It also avoids the undefined
		// shift by -1 when m == 0, where the child is then directly a leaf.
		Insert(ch[root][R(num&(1<<m))],(1<<m)>>1);
	}
	// Union the subtree rooted at p into the subtree at w, creating nodes as needed.
	void Merge(int &w, int p){
		if (!w) w = ++cnt;
		if (!MX[w] || MX[w] < MX[p]) MX[w] = MX[p];
		if (!MN[w] || MN[w] > MN[p]) MN[w] = MN[p];
		if (ch[p][0]) Merge(ch[w][0], ch[p][0]);
		if (ch[p][1]) Merge(ch[w][1], ch[p][1]);
	}
	// Recursively make every 0-child a superset of its sibling 1-child.
	void merge(int w){
		if (ch[w][1]) Merge(ch[w][0],ch[w][1]), merge(ch[w][1]);
		if (ch[w][0]) merge(ch[w][0]);
	}
	// Greedy descent from bit t. Take the 1 branch when NUM has bit t and that
	// branch still holds a colour other than COL; otherwise take the 0 branch.
	// At the leaf, switch COL to a different colour stored there (COL is kept
	// only if the leaf holds nothing else).
	void Query(int w, int t){
		if (t) {
			if (R(NUM&t) && (MN[ch[w][1]] != COL || MX[ch[w][1]] != COL)) ans_tmp|=t, Query(ch[w][1], t>>1);
			else Query(ch[w][0],t>>1);
		} else {
			if (MN[w] != COL) COL = MN[w];
			else COL = MX[w];
		}
	}
	// Returns (best AND value for num, colour of a partner achieving it).
	inline pair<int,int> query(int num, int col){
		ans_tmp = 0; NUM = num, COL = col;
		Query(root,1<<m);
		return make_pair(ans_tmp,COL);
	}
};
inline int find(int w){
int f=fa[w], tmp;
while (fa[f] != f) f = fa[f];
while (w != f) tmp=fa[w],fa[w]=f,w=tmp;
return f;
}
// Read the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it.
// Returns 0 if EOF is hit before a digit. The character is kept in an int
// so EOF is distinguishable. Previously `char c` plus the skip loop spun
// forever on getchar() == EOF.
inline int read(){
	int c=getchar(), ret=0;
	while (c!=EOF && (c<'0'||c>'9')) c=getchar();
	while (c>='0'&&c<='9') ret=ret*10+(c-'0'),c=getchar();
	return ret;
}
int main(){
n = read(); m = read(); int cnt = n-1;
for (int i=1;i<=n;i++) fa[i] = i, arr[i] = read();
while (cnt) {
Trie::init(); memset(v,-1,sizeof(v));
for (int i=1;i<=n;i++) Trie::insert(arr[i],find(i));
Trie::merge(Trie::root);
for (int i=1;i<=n;i++) {
pair<int,int> tmp = Trie::query(arr[i],fa[i]);
if (tmp.first > v[fa[i]]) v[fa[i]] = tmp.first, sur[fa[i]] = tmp.second;
}
for (int i=1;i<=n;i++) if (~v[i] && find(i) != find(sur[i]))
vout += v[i], fa[find(i)] = find(sur[i]), cnt--;
}
printf("%lld\n",vout);
return 0;
}