题目描述
给你一棵 n
个节点的无向树,节点编号为 1
到 n
。给你一个整数 n
和一个长度为 n - 1
的二维整数数组 edges
,其中 edges[i] = [u_i, v_i]
表示节点 u_i
和 v_i
在树中有一条边。
请你返回树中的 合法路径数目。
如果在节点 a
到节点 b
之间 恰好有一个 节点的编号是质数,那么我们称路径 (a, b)
是 合法的。
注意:
- 路径
(a, b)
指的是一条从节点a
开始到节点b
结束的一个节点序列,序列中的节点 互不相同,且相邻节点之间在树上有一条边。 - 路径
(a, b)
和路径(b, a)
视为 同一条 路径,且只计入答案 一次。
样例
输入:n = 5, edges = [[1,2],[1,3],[2,4],[2,5]]
输出:4
解释:恰好有一个质数编号的节点路径有:
- (1, 2) 因为路径 1 到 2 只包含一个质数 2。
- (1, 3) 因为路径 1 到 3 只包含一个质数 3。
- (1, 4) 因为路径 1 到 4 只包含一个质数 2。
- (2, 4) 因为路径 2 到 4 只包含一个质数 2。
只有 4 条合法路径。
输入:n = 6, edges = [[1,2],[1,3],[2,4],[3,5],[3,6]]
输出:6
解释:恰好有一个质数编号的节点路径有:
- (1, 2) 因为路径 1 到 2 只包含一个质数 2。
- (1, 3) 因为路径 1 到 3 只包含一个质数 3。
- (1, 4) 因为路径 1 到 4 只包含一个质数 2。
- (1, 6) 因为路径 1 到 6 只包含一个质数 3。
- (2, 4) 因为路径 2 到 4 只包含一个质数 2。
- (3, 6) 因为路径 3 到 6 只包含一个质数 3。
只有 6 条合法路径。
限制
1 <= n <= 10^5
edges.length == n - 1
edges[i].length == 2
1 <= ui, vi <= n
- 输入保证
edges
形成一棵合法的树。
算法
(筛质数,递归遍历) $O(n)$
- 通过线性筛标记 $1$ 到 $n$ 的质数。
- 从节点 $1$ 开始递归遍历,每层递归都分别返回以当前为根的子树中,恰好存在一个质数的路径个数和不存在质数的路径个数。每次递归累计跨过当前根节点的路径个数。
- 对于当前节点 $u$,维护两个变量 $f$ 和 $g$,分别记录当前已遍历过的子树中,恰好存在一个质数的路径个数和不存在质数的路径个数。
- 对于一个递归子节点的结果 $res$,如果当前节点 $u$ 是质数,则只能累计 $(g + 1)res.g$ 到答案中(需要找到两个不存在质数的路径拼在一起,且独立的根节点也是一种路径)。同理,如果当前节点不是质数,则累加 $f \cdot res.g + (g + 1)res.f$。
- 最后返回时,如果当前节点是质数,则返回 $(g + 1, 0)$,这是因为不存在质数的路径都变成了恰好存在一个质数的路径,且不再存在没有质数的路径了。否则,返回 $(f, g + 1)$,这里注意独立的根节点也是一种不存在质数的路径端点。
时间复杂度
- 筛质数的时间复杂度为 $O(n)$。
- 递归遍历每个节点一次,累加答案的时间为常数。
- 故总时间复杂度为 $O(n)$。
空间复杂度
- 需要 $O(n)$ 的额外空间存储线性筛的数据结构,和递归的系统栈空间。
C++ 代码
#define LL long long
class Solution {
private:
vector<bool> is_not_prime;
vector<int> prime;
vector<vector<int>> graph;
LL ans;
void init_prime(int n) {
is_not_prime.resize(n + 1, false);
is_not_prime[1] = true;
for (int i = 2; i <= n; i++) {
if (!is_not_prime[i])
prime.push_back(i);
for (int j = 0; i * prime[j] <= n; j++) {
is_not_prime[i * prime[j]] = true;
if (i % prime[j] == 0)
break;
}
}
}
pair<LL, LL> solve(int u, int fa) {
LL f = 0, g = 0;
for (int v : graph[u]) {
if (v == fa)
continue;
auto res = solve(v, u);
if (is_not_prime[u])
ans += f * res.second + (g + 1) * res.first;
else
ans += (g + 1) * res.second;
f += res.first;
g += res.second;
}
if (is_not_prime[u])
return make_pair(f, g + 1);
return make_pair(g + 1, 0);
}
public:
LL countPaths(int n, vector<vector<int>>& edges) {
init_prime(n);
graph.resize(n + 1);
for (const auto &e : edges) {
graph[e[0]].push_back(e[1]);
graph[e[1]].push_back(e[0]);
}
ans = 0;
solve(1, 0);
return ans;
}
};