题目描述
给你一个 n
个节点的无向无根树,节点编号从 0
到 n - 1
。给你整数 n
和一个长度为 n - 1
的二维整数数组 edges
,其中 edges[i] = [a_i, b_i]
表示树中节点 a_i
和 b_i
之间有一条边。再给你一个长度为 n
的数组 coins
,其中 coins[i]
可能为 0
也可能为 1 ,1 表示节点 i 处有一个金币。
一开始,你需要选择树中任意一个节点出发。你可以执行下述操作任意次:
- 收集距离当前节点距离为
2
以内的所有金币,或者 - 移动到树中一个相邻节点。
你需要收集树中所有的金币,并且回到出发节点,请你返回最少经过的边数。
如果你多次经过一条边,每一次经过都会给答案加一。
样例
输入:coins = [1,0,0,0,0,1], edges = [[0,1],[1,2],[2,3],[3,4],[4,5]]
输出:2
解释:从节点 2 出发,收集节点 0 处的金币,移动到节点 3,收集节点 5 处的金币,然后移动回节点 2。
输入:coins = [0,0,0,1,1,0,0,1], edges = [[0,1],[0,2],[1,3],[1,4],[2,5],[5,6],[5,7]]
输出:2
解释:从节点 0 出发,收集节点 4 和 3 处的金币,移动到节点 2 处,收集节点 7 处的金币,移动回节点 0。
限制
n == coins.length
1 <= n <= 3 * 10^4
0 <= coins[i] <= 1
edges.length == n - 1
edges[i].length == 2
0 <= a_i, b_i < n
a_i != b_i
edges
表示一棵合法的树。
算法
(两次递归遍历) $O(n)$
- 将无根树转为有根树,并令 $0$ 为根节点。
- 设 $f(u)$ 表示以 $u$ 为根的子树,从 $u$ 出发,收集子树的所有金币经过的最少边数。$f_1(u)$ 表示以 $u$ 为根的子树,距离 $u$ 恰好为 $1$ 的节点中金币的数量。$f_2(u)$ 表示以 $u$ 为根的子树,距离 $u$ 为 $2$ 或以上的节点中金币的数量。
- 第一次递归遍历,从 $0$ 开始针对每个点求出 $f$、$f_1$ 和 $f_2$。假设当前点为 $u$,对于子节点 $v$,如果 $f_2(v) > 0$,则 $f(u) = f(u) + f(v) + 2$,表示需要走到节点 $v$ 获取 $v$ 及其子树的所有金币后,再返回。$f_1(u) = f_1(u) + coins(v)$。$f_2(u) = f_2(u) + f_1(v) + f_2(v)$。
- 第二次递归遍历,求解以每个点作为起始点时的最终答案。这里需要引入三个值 $\text{fa_}f$,$\text{fa_}f_1$ 和 $\text{fa_}f_2$,含义是来自于父节点的 $f$、$f_1$ 和 $f_2$,表示起始点从 $u$ 换到 $v$ 过程中,$u$ 原本作为 $v$ 的父节点变到子节点时的值。
- 第二次递归时,对于当前点 $u$,更新答案,如果 $\text{fa_}f_2 > 0$,则用 $f_u + \text{fa_}f + 2$ 更新答案。否则,用 $f(u)$ 更新答案。这是因为 $fa$ 虽然是名义的父节点,在已经是「子节点」了,需要按照之前的过程把这个子节点的值统计到答案中。
- 接着枚举 $u$ 的子节点 $v$,转移时,$\text{new_fa_}f_1 = f_1(u) - coins(v) + coins(fa)$,这是因为,距离 $u$ 恰好为 $1$ 的金币个数来自于原来 $f_1(u)$ 中除去 $v$ 的,还有来自于 $u$ 父节点的。同理,计算 $\text{new_fa_}f$ 和 $\text{new_fa_}f_2$ 也是采用同样的思路,但需要考虑 $f_2(v)$ 的情况,具体可以参考代码。
时间复杂度
- 两次递归遍历,每次遍历每个节点仅访问一次,故总时间复杂度为 $O(n)$。
空间复杂度
- 需要 $O(n)$ 的额外空间存储图,$f$ 数组和递归的系统栈。
C++ 代码
class Solution {
private:
vector<vector<int>> graph;
vector<int> f, f1, f2;
int ans;
void solve1(int u, int fa, vector<int> &coins) {
for (int v : graph[u]) {
if (v == fa)
continue;
solve1(v, u, coins);
if (f2[v] > 0)
f[u] += f[v] + 2;
f1[u] += coins[v];
f2[u] += f1[v] + f2[v];
}
}
void solve2(int u, int fa, int fa_f, int fa_f1, int fa_f2, vector<int> &coins) {
if (fa_f2 > 0) ans = min(ans, f[u] + fa_f + 2);
else ans = min(ans, f[u]);
for (int v : graph[u]) {
if (v == fa)
continue;
int new_fa_f = fa_f + (fa_f2 > 0 ? 2 : 0) + f[u];
int new_fa_f1 = f1[u] - coins[v] + (fa != -1 ? coins[fa] : 0);
int new_fa_f2 = fa_f1 + fa_f2 + f2[u];
if (f2[v] > 0) {
new_fa_f -= f[v] + 2;
new_fa_f2 -= f1[v] + f2[v];
}
solve2(v, u, new_fa_f, new_fa_f1, new_fa_f2, coins);
}
}
public:
int collectTheCoins(vector<int>& coins, vector<vector<int>>& edges) {
const int n = coins.size();
graph.resize(n);
for (const auto &e : edges) {
graph[e[0]].push_back(e[1]);
graph[e[1]].push_back(e[0]);
}
f.resize(n, 0);
f1.resize(n, 0);
f2.resize(n, 0);
solve1(0, -1, coins);
ans = INT_MAX;
solve2(0, -1, 0, 0, 0, coins);
return ans;
}
};