在知乎看到了大佬严格鸽的换根$dp$写法和我之前通常的写法不太一样,对此总结出两种写法(我不太会画图所以仅看文字理解起来可能有点困难$qwq$)
以换根$dp$的经典题目为例:https://www.luogu.com.cn/problem/P3478
定义$f[u]$:以$u$为根的子树的节点深度之和 $\space \space $ $sz[u]$:以$u$为根的子树的节点总个数
通常进行换根$dp$之前都需要先$dfs$一次,求出一些信息,此题中不难得到如下转移:
$f[u] = \sum_v f[v] + sz[v]$, $sz[u] = 1 + \sum_v sz[v]$
写法$1$(我通常的写法):
定义$g[u]$:以$u$为根时除去$u$的子树中的点其余点的深度之和,即考虑$u$往上的所有节点对答案的贡献,则以$u$为根的所有节点的深度之和为$f[u] + g[u]$
在第二次$dfs$时,可以得到如下$g$数组的转移:$g[v] = g[u] + f[u] - (f[v] + sz[v]) + n - sz[v]$
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
void solve()
{
int n;
cin >> n;
vector<vector<int>> e(n + 1);
for (int i = 0; i < n - 1; i ++ )
{
int a, b;
cin >> a >> b;
e[a].push_back(b), e[b].push_back(a);
}
vector<LL> f(n + 1), g(n + 1);
vector<int> sz(n + 1);
auto dfs1 = [&] (auto self, int u, int fa) -> void
{
sz[u] = 1;
for (auto v : e[u])
if (v != fa)
{
self(self, v, u);
f[u] += f[v] + sz[v];
sz[u] += sz[v];
}
};
auto dfs2 = [&] (auto self, int u, int fa) -> void
{
for (auto v : e[u])
if (v != fa)
{
g[v] = g[u] + f[u] - (f[v] + sz[v]) + n - sz[v];
self(self, v, u);
}
};
dfs1(dfs1, 1, -1);
dfs2(dfs2, 1, -1);
LL mv = 0;
int res = 0;
for (int i = 1; i <= n; i ++ )
if (f[i] + g[i] > mv)
{
mv = f[i] + g[i];
res = i;
}
cout << res << "\n";
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T = 1;
while (T -- ) solve();
return 0;
}
写法$2$(严格鸽的写法):
不再定义$g$数组,具体的,分为如下两个步骤:
$1$、将根从$u$换为$v$时,要去掉$v$的贡献,考虑第一次$dfs$时是如何转移的,这里就如何去掉对应的贡献,不难得到如下代码:
f[u] -= sz[v] + f[v];
sz[u] -= sz[v];
然后把$u$连到$v$的下面,即把$u$作为$v$的儿子,此时$v$就有了$u$对应的贡献,不难得到如下代码:
f[v] += f[u] + sz[u];
sz[v] += sz[u];
这个时候已经计算出了以$v$为根的答案,继续$dfs$下去即可,注意回溯的时候需要把$f[u]$和$sz[u]$还原
参考资料:https://zhuanlan.zhihu.com/p/580249398
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
void solve()
{
int n;
cin >> n;
vector<vector<int>> e(n + 1);
for (int i = 0; i < n - 1; i ++ )
{
int a, b;
cin >> a >> b;
e[a].push_back(b), e[b].push_back(a);
}
vector<LL> f(n + 1), g(n + 1);
vector<int> sz(n + 1);
auto dfs1 = [&] (auto self, int u, int fa) -> void
{
sz[u] = 1;
for (auto v : e[u])
if (v != fa)
{
self(self, v, u);
f[u] += f[v] + sz[v];
sz[u] += sz[v];
}
};
LL mv = 0;
int res = 0;
auto dfs2 = [&] (auto self, int u, int fa) -> void
{
if (f[u] > mv)
{
mv = f[u];
res = u;
}
for (auto v : e[u])
if (v != fa)
{
int t1, t2;
f[u] -= sz[v] + f[v];
t1 = sz[v] + f[v];
sz[u] -= sz[v];
t2 = sz[v];
f[v] += f[u] + sz[u];
sz[v] += sz[u];
self(self, v, u);
f[u] += t1;
sz[u] += t2;
}
};
dfs1(dfs1, 1, -1);
dfs2(dfs2, 1, -1);
cout << res << "\n";
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T = 1;
while (T -- ) solve();
return 0;
}