换根DP
此题与 树的中心 基本一致
树的高度 等价于 树根往下递归的最长链的长度
树的最小高度 等价于 以树中任何一个节点为根时的往下递归的最长链的长度的最小值
因此考虑使用换根DP
以每一个节点为根的树的高度,就等价于从该节点出发的最长链的长度
从任意节点出发有两条路径:
1. 从该点往下递归
2. 从该点到其父节点往上递归
因此我们考虑维护两个数组:d1[]
和up[]
,分别表示从某个节点往下递归的最长链的长度,和往上递归的最长链的长度
其中,d1[]
非常好维护:假设存在一条x -> y
的边,因此d1[x] = d1[y] + w[x -> y]
,因此考虑自底向上维护d1[]
数组
而up[]
数组的维护,还需要额外维护两个数组:d2[]
和p1[]
,分别表示某个节点往下递归的次长链和最长链是从哪个节点上来的
同理,由up[y] = max(up[x], d1[u] OR d2[u]) + w[x -> y]
,可知我们需要自项向下维护up[]
数组,同时:
1. 如果x
的最长链是从y
上来的,就用次大值d2[u]
更新 ( 因为d1[u]
包含从x -> y
的路径 )
2. 否则,就用最大值d1[u]
更新
C++ Code
class Solution {
public:
vector<int> findMinHeightTrees(int n, vector<vector<int>>& edges) {
vector<int> d1(n), d2(n), up(n), p1(n);
// 建图
vector<vector<int>> g(n);
for (auto& e : edges) {
int x = e[0], y = e[1];
g[x].push_back(y), g[y].push_back(x); // 无向边
}
function<void(int, int)> dfs_d = [&](int u, int fa) {
d1[u] = d2[u] = 0;
for (int x : g[u]) {
if (x == fa) continue;
dfs_d(x, u); // 自底向上维护信息
int d = d1[x] + 1;
if (d >= d1[u]) {
d2[u] = d1[u], d1[u] = d;
p1[u] = x;
} else if (d > d2[u]) {
d2[u] = d;
}
}
};
dfs_d(0, -1);
function<void(int, int)> dfs_u = [&](int u, int fa) {
for (int x : g[u]) {
if (x == fa) continue;
if (p1[u] == x) up[x] = max(up[u], d2[u]) + 1; // u 的最长链是从 x 上来的 : 用次大值更新
else up[x] = max(up[u], d1[u]) + 1; // 否则用最大值更新
dfs_u(x, u); // 自项向下维护信息
}
};
dfs_u(0, -1);
int minv = 1e9;
for (int i = 0; i < n; i ++ ) minv = min(minv, max(up[i], d1[i]));
vector<int> res;
for (int i = 0; i < n; i ++ )
if (minv == max(up[i], d1[i])) {
res.push_back(i);
}
return res;
}
};
Python3 Code
class Solution:
def findMinHeightTrees(self, n: int, edges: List[List[int]]) -> List[int]:
# 建图
g = [[] for _ in range(n + 1)]
for x, y in edges:
g[x].append(y)
g[y].append(x)
d1 = [0] * n
d2 = [0] * n
p1 = [0] * n
up = [0] * n
def dfs_d(u: int, fa: int) -> None:
for x in g[u]:
if (x != fa):
dfs_d(x, u) # 自底向上信息
d = d1[x] + 1
if d >= d1[u]:
d2[u] = d1[u]; d1[u] = d
p1[u] = x
elif d > d2[u]:
d2[u] = d
dfs_d(0, -1)
def dfs_u(u: int, fa: int) -> None:
for x in g[u]:
if (x != fa):
if (p1[u] == x):
up[x] = max(up[u], d2[u]) + 1
else:
up[x] = max(up[u], d1[u]) + 1
dfs_u(x, u) # 自项向下信息
dfs_u(0, -1)
minv = 1e9
for i in range(0, n):
minv = min(minv, max(d1[i], up[i]))
res = []
for i in range(n):
if max(up[i], d1[i]) == minv:
res.append(i)
return res