【日常小测】tree

题目大意

给定一棵大小为$n(n \le 3000)$的树,带边权
请你给出一个长度为$k(k \le n)$的序列$\{a_i\}$
要求序列中两两元素不同,请你最小化$\sum\limits_{i=1}^{k-1}{dis(a_i,a_{i+1})}$

解题报告

我们看一看这个式子可以发现实际上是要我们给出一个大小为$k$的连通块
问你这个连通块内的边权乘上$2$再减掉直径的最小值是多少

于是我们定义$h_{i,j}$为$i$的子树中选$j$个点,还没有考虑直径的最小花费
$f_{i,j}$表示在$i$的子树里,选$j$个点,直径包含$i$的最小花费
$g_{i,j}$表示在$i$的子树里,选$j$个点,已经考虑过直径且直径不包含$i$的最小花费
然后我们暴力$DP$就可以了

我们的时间复杂度是$T(n)=\sum\limits_{i}{\sum\limits_{fa_i=fa_j}{size_i \cdot size_j}}$
仔细想一想的话,每一堆点只会在$LCA$处被计算,时间复杂度是$O(n^2)$的

Code

#include<bits/stdc++.h>
#define LL long long
using namespace std;

const int N = 3009;
const int M = N << 1;
const LL INF = 1e18;

int n,k,head[N],nxt[M],to[M],cost[M],sz[N];
LL vout=INF,f[N][N],g[N][N],h[N][N];

inline int read() {
	char c=getchar(); int ret=0,f=1;
	while (c<'0'||c>'9') {if(c=='-')f=-1;c=getchar();}
	while (c<='9'&&c>='0') {ret=ret*10+c-'0';c=getchar();}
	return ret*f;
}

inline void AddEdge(int u, int v, int c) {
	static int E = 1; cost[E+1] = cost[E+2] = c;
	to[++E] = v; nxt[E] = head[u]; head[u] = E;
	to[++E] = u; nxt[E] = head[v]; head[v] = E;
}

void update(int w, int fa) { 
	fill(f[w], f[w]+1+N, INF);
    fill(g[w], g[w]+1+N, INF);
	fill(h[w], h[w]+1+N, INF);
	f[w][1] = g[w][1] = h[w][1] = 0; sz[w] = 1;
	for (int i=head[w],t;i;i=nxt[i]) {
		if ((t=to[i]) != fa) {
			update(to[i], w); 
			for (int x=sz[w],tmp=cost[i]<<1;x;x--) {
				for (int j=1;j<=sz[t];j++) {
					g[w][x+j] = min(g[w][x+j], g[w][x] + h[t][j] + tmp);
					g[w][x+j] = min(g[w][x+j], f[w][x] + f[t][j] + cost[i]);
					g[w][x+j] = min(g[w][x+j], h[w][x] + g[t][j] + tmp);
				}
			}
			for (int x=sz[w],tmp=cost[i]<<1;x;x--) {
				for (int j=1;j<=sz[t];j++) {
					f[w][x+j] = min(f[w][x+j], h[w][x] + f[t][j] + cost[i]);
					f[w][x+j] = min(f[w][x+j], f[w][x] + h[t][j] + tmp);
				}
			} 
			for (int x=sz[w],tmp=cost[i]<<1;x;x--) {
				for (int j=1;j<=sz[t];j++) 
					h[w][x+j] = min(h[w][x+j], h[w][x] + h[t][j] + tmp);
			} 
			sz[w] += sz[t];
		}	
	}
	vout = min(vout, f[w][k]);
	vout = min(vout, g[w][k]); 
}

int main() {
	n = read(); k = read();
	for (int i=2,u,v;i<=n;i++) {
		u = read(); v = read();
		AddEdge(u, v, read());
	}
	update(1, 1);
	printf("%lld\n",vout);
	return 0;
}