题目描述
给你一棵 n
个节点的树(连通无向无环的图),节点编号从 0
到 n - 1
且恰好有 n - 1
条边。
给你一个长度为 n
下标从 0 开始的整数数组 vals
,分别表示每个节点的值。同时给你一个二维整数数组 edges
,其中 edges[i] = [a_i, b_i]
表示节点 a_i
和 b_i
之间有一条 无向 边。
一条 好路径 需要满足以下条件:
- 开始节点和结束节点的值 相同。
- 开始节点和结束节点中间的所有节点值都 小于等于 开始节点的值(也就是说开始节点的值应该是路径上所有节点的最大值)。
请你返回不同好路径的数目。
注意,一条路径和它反向的路径算作 同一 路径。比方说,0 -> 1
与 1 -> 0
视为同一条路径。单个节点也视为一条合法路径。
样例
输入:vals = [1,3,2,1,3], edges = [[0,1],[0,2],[2,3],[2,4]]
输出:6
解释:总共有 5 条单个节点的好路径。
还有 1 条好路径:1 -> 0 -> 2 -> 4。
(反方向的路径 4 -> 2 -> 0 -> 1 视为跟 1 -> 0 -> 2 -> 4 一样的路径)
注意 0 -> 2 -> 3 不是一条好路径,因为 vals[2] > vals[0]。
输入:vals = [1,1,2,2,3], edges = [[0,1],[1,2],[2,3],[2,4]]
输出:7
解释:总共有 5 条单个节点的好路径。
还有 2 条好路径:0 -> 1 和 2 -> 3。
输入:vals = [1], edges = []
输出:1
解释:这棵树只有一个节点,所以只有一条好路径。
限制
n == vals.length
1 <= n <= 3 * 10^4
0 <= vals[i] <= 10^5
edges.length == n - 1
edges[i].length == 2
0 <= a_i, b_i < n
a_i != b_i
edges
表示一棵合法的树。
算法
(并查集) $O(n \log n + m)$
- 将节点按照值从小到大排序,按顺序每次遍历值相同的一组点。
- 遍历时,如果当前点的值大于等于其出边的点的值,则将当前点与其出边的点合并。
- 加入结束后,再次遍历当前这组点,统计出每个集合内等于当前点的值的个数,然后累加答案。
时间复杂度
- 排序需要 $O(n \log n)$ 的时间。
- 建邻接表需要 $O(n + m)$ 的时间。
- 并查集单次操作时间复杂度近似常数,每个点仅访问常数次,时间复杂度为 $O(n)$。
- 故总时间复杂度为 $O(n \log n + m)$。
空间复杂度
- 需要 $O(n)$ 的额外空间存储邻接表和并查集。
C++ 代码
class Solution {
private:
vector<int> f, sz;
int find(int x) {
return x == f[x] ? x : f[x] = find(f[x]);
}
void merge(int x, int y) {
int fx = find(x), fy = find(y);
if (fx == fy)
return;
if (sz[fx] < sz[fy]) {
f[fx] = fy;
sz[fy] += sz[fx];
} else {
f[fy] = fx;
sz[fx] += sz[fy];
}
}
public:
int numberOfGoodPaths(vector<int>& vals, vector<vector<int>>& edges) {
const int n = vals.size();
vector<vector<int>> graph(n);
for (const auto &e : edges) {
graph[e[0]].push_back(e[1]);
graph[e[1]].push_back(e[0]);
}
f.resize(n);
sz.resize(n, 1);
for (int i = 0; i < n; i++)
f[i] = i;
vector<int> rk(n);
for (int i = 0; i < n; i++)
rk[i] = i;
sort(rk.begin(), rk.end(), [&](int x, int y) {
return vals[x] < vals[y];
});
int ans = 0;
vector<int> cnt(n, 0);
for (int i = 1, j = 0; i <= n; i++) {
if (i < n && vals[rk[i]] == vals[rk[j]])
continue;
for (int k = j; k < i; k++) {
int p = rk[k];
for (int u : graph[p])
if (vals[u] <= vals[p])
merge(u, p);
}
for (int k = j; k < i; k++) {
cnt[find(rk[k])]++;
ans += cnt[find(rk[k])];
}
for (int k = j; k < i; k++)
cnt[find(rk[k])]--;
j = i;
}
return ans;
}
};
排序需要算时间复杂度吗
需要
已修正