题解:关于树的路径统计问题
一、题目背景
给定一棵有 (n) 个节点的树,树中节点编号从 (1) 到 (n),有 (m) 条路径。对于每条路径,给出路径的两个端点 (a) 和 (b)。要求计算经过节点的路径数量,对于一个节点,如果有路径经过它,且该路径的两个端点都在以该节点为根的子树之外,那么该节点的贡献为 (m);如果有路径经过它,且该路径的一个端点在以该节点为根的子树内,另一个端点在子树外,那么该节点的贡献为 (1)。最终输出所有节点的贡献之和。
二、代码整体结构
代码主要由以下几个部分组成:
1. 头文件和全局变量的定义,用于存储树的结构信息、节点深度、祖先关系、路径标记以及结果变量等。
2. add
函数,用于向树中添加边。
3. bfs
函数,通过广度优先搜索计算每个节点的深度以及各级祖先。
4. lca
函数,用于求解两个节点的最近公共祖先。
5. dfs
函数,通过深度优先搜索计算每个节点的贡献,并累加得到最终结果。
6. main
函数,负责读取输入数据,调用上述函数进行处理,并输出结果。
三、代码逐段分析
(一)头文件和全局变量
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 100010, M = N * 2;
int n, m;
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][17];
int d[N];
int q[N];
int ans;
- 头文件:
#include <cstdio>
:提供标准输入输出函数,如scanf
和printf
。#include <cstring>
:用于字符串操作,如memset
函数。#include <iostream>
:提供 C++ 风格的输入输出流,如cin
和cout
。#include <algorithm>
:包含一些常用的算法函数,如swap
和max
等。using namespace std;
:使用标准命名空间,以便直接使用标准库中的函数和类型。
- 常量定义:
const int N = 100010
:定义节点的最大数量。const int M = N * 2
:由于是无向树(边是无向的),边的最大数量是节点数量的两倍。
- 变量定义:
int n, m;
:n
表示树的节点数,m
表示路径的数量。int h[N], e[M], ne[M], idx;
:用于构建邻接表存储树的结构。h[i]
表示节点i
的第一条边在e
和ne
数组中的下标;e[j]
表示第j
条边的另一端点;ne[j]
表示与第j
条边同起点的下一条边的下标;idx
是边的计数器,用于记录当前已添加边的数量。int depth[N], fa[N][17];
:depth[i]
记录节点i
在树中的深度;fa[i][k]
表示节点i
的2^k
级祖先(k
从0
到16
),通过这种方式可以快速向上跳跃祖先节点。int d[N];
:用于标记每个节点被路径经过的情况,通过对路径端点和最近公共祖先的操作来统计。int q[N];
:作为队列,用于广度优先搜索(BFS)过程中存储待处理的节点。int ans;
:用于存储最终的结果,即所有节点的贡献之和。
(二)add
函数
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}
该函数用于向邻接表中添加一条从节点 a
到节点 b
的边。具体操作是将节点 b
存储到 e[idx]
,将原来节点 a
的第一条边的下标存储到 ne[idx]
,然后更新 h[a]
为当前边的下标 idx
,最后 idx
自增 1
,以便添加下一条边。
(三)bfs
函数
void bfs()
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[1] = 1;
int hh = 0, tt = 0;
q[0] = 1;
while (hh <= tt)
{
int t = q[hh ++];
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
q[ ++ tt] = j;
fa[j][0] = t;
for (int k = 1; k <= 16; k ++ )
fa[j][k] = fa[fa[j][k - 1]][k - 1];
}
}
}
}
- 初始化:
memset(depth, 0x3f, sizeof depth);
:将所有节点的深度初始化为一个较大的值(0x3f
表示十六进制的较大数,这里用于表示未访问过的节点深度)。depth[0] = 0, depth[1] = 1;
:将虚拟节点0
的深度设为0
(这里的0
节点是为了方便处理,实际无意义),根节点1
的深度设为1
。int hh = 0, tt = 0;
:初始化队列的头指针hh
和尾指针tt
为0
。q[0] = 1;
:将根节点1
放入队列中。
- BFS 过程:
while (hh <= tt)
:当队列不为空时,进行循环。int t = q[hh ++];
:取出队头节点t
,并将头指针hh
后移一位。- 对于节点
t
的每一条邻接边(通过for (int i = h[t]; ~i; i = ne[i])
遍历):- 若邻接节点
j
(j = e[i]
)的深度大于节点t
的深度加1
,说明找到了一条更短的路径到达节点j
。 - 更新节点
j
的深度为depth[t] + 1
,将节点j
加入队列(q[ ++ tt] = j;
),并设置节点j
的0
级祖先为节点t
(fa[j][0] = t;
)。 - 通过内层循环
for (int k = 1; k <= 16; k ++ )
计算节点j
的2^k
级祖先,利用递推关系fa[j][k] = fa[fa[j][k - 1]][k - 1];
,即节点j
的2^k
级祖先是节点j
的2^(k - 1)
级祖先的2^(k - 1)
级祖先。
- 若邻接节点
(四)lca
函数
int lca(int a, int b)
{
if (depth[a] < depth[b]) swap(a, b);
for (int k = 16; k >= 0; k -- )
if (depth[fa[a][k]] >= depth[b])
a = fa[a][k];
if (a == b) return a;
for (int k = 16; k >= 0; k -- )
if (fa[a][k] != fa[b][k])
{
a = fa[a][k];
b = fa[b][k];
}
return fa[a][0];
}
- 深度对齐:
if (depth[a] < depth[b]) swap(a, b);
:确保节点a
的深度大于等于节点b
的深度,如果不是则交换a
和b
。- 通过循环
for (int k = 16; k >= 0; k -- )
,从2^16
级祖先开始,若节点a
的2^k
级祖先的深度大于等于节点b
的深度,则将节点a
向上提升到其2^k
级祖先(a = fa[a][k];
),这样可以使节点a
和节点b
处于同一深度。
- 寻找最近公共祖先:
if (a == b) return a;
:如果此时a
和b
相等,说明已经找到了最近公共祖先,直接返回a
。- 再次通过循环
for (int k = 16; k >= 0; k -- )
,从2^16
级祖先开始,若节点a
和节点b
的2^k
级祖先不相等,则将节点a
和节点b
同时向上提升到它们的2^k
级祖先(a = fa[a][k]; b = fa[b][k];
)。 - 循环结束后,
a
和b
的父节点(即fa[a][0]
)就是它们的最近公共祖先,返回fa[a][0]
。
(五)dfs
函数
int dfs(int u, int father)
{
int res = d[u];
for (int i = h[u]; ~i; i = ne[i])
{
int j = e[i];
if (j != father)
{
int s = dfs(j, u);
if (s == 0) ans += m;
else if (s == 1) ans ++;
res += s;
}
}
return res;
}
- 函数接受当前节点
u
和其父节点father
作为参数。 - 初始化
res
为节点u
的路径标记值d[u]
。 - 遍历节点
u
的所有邻接节点j
(通过for (int i = h[u]; ~i; i = ne[i])
遍历邻接表):- 如果邻接节点
j
不是节点u
的父节点(j != father
),则递归调用dfs(j, u)
计算子树j
的路径标记值s
。 - 根据
s
的值更新结果ans
:如果s
为0
,说明子树j
中没有路径端点,且有路径经过节点u
且两个端点都在子树j
之外,此时ans
增加m
;如果s
为1
,说明子树j
中有一个路径端点,此时ans
增加1
。 - 将
s
的值累加到res
中。
- 如果邻接节点
- 最后返回
res
,表示以节点u
为根的子树中路径端点的数量。
(六)main
函数
int main()
{
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h);
for (int i = 0; i < n - 1; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
add(a, b), add(b, a);
}
bfs();
for (int i = 0; i < m; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
int p = lca(a, b);
d[a] ++, d[b] ++, d[p] -= 2;
}
dfs(1, -1);
printf("%d\n", ans);
return 0;
}
- 输入树的信息:
scanf("%d%d", &n, &m);
:读取树的节点数n
和路径数m
。memset(h, -1, sizeof h);
:将邻接表头数组h
初始化为-1
,表示每个节点的第一条边还未添加。- 通过循环
for (int i = 0; i < n - 1; i ++ )
读取树的n - 1
条边的两个端点a
和b
,并通过add(a, b), add(b, a);
在邻接表中添加双向边。
- 计算节点深度和祖先关系:
bfs();
:以节点1
为起点,调用bfs
函数进行广度优先搜索,计算每个节点的深度和各级祖先。
- 处理路径信息:
- 通过循环
for (int i = 0; i < m; i ++ )
处理每一条路径:- 读取路径的两个端点
a
和b
(scanf("%d%d", &a, &b);
)。 - 调用
lca(a, b)
函数计算它们的最近公共祖先p
。 - 对路径端点
a
和b
以及最近公共祖先p
的路径标记数组d
进行操作:d[a] ++, d[b] ++, d[p] -= 2;
,这样做的目的是为了后续通过dfs
函数统计每个节点的贡献。
- 读取路径的两个端点
- 通过循环
- 深度优先搜索计算结果:
dfs(1, -1);
:从根节点1
开始进行深度优先搜索,计算所有节点的贡献,并累加到ans
中。
- 输出结果:
printf("%d\n", ans);
:输出最终的结果,即所有节点的贡献之和。
四、时间复杂度和空间复杂度分析
(一)时间复杂度
- 构建树的邻接表:添加
n - 1
条边,每次添加边的操作时间复杂度为 (O(1)),所以构建树的邻接表的时间复杂度为 (O(n))。 - BFS 过程:每个节点最多入队和出队一次,每条边最多被访问两次,时间复杂度为 (O(n + m’))(这里的
m'
是边数,因为是树,边数为n - 1
,所以时间复杂度为 (O(n)))。 - 计算最近公共祖先:每次查询时,通过两次循环,每次循环最多执行
17
次(因为k
从16
到0
),有m
次查询,所以计算最近公共祖先的时间复杂度为 (O(m \times 2 \times 17) = O(m))(这里忽略常数项)。 - 深度优先搜索计算贡献:每个节点最多被访问一次,时间复杂度为 (O(n))。
- 总体时间复杂度为 (O(n + m))。
(二)空间复杂度
- 邻接表存储树结构:边的数量最多为
M = N * 2
,节点的数量为N
,所以邻接表的空间复杂度为 (O(N + M) = O(N))(因为M = 2N
)。 - 存储深度和祖先信息:
depth[N]
数组占用 (O(N)) 的空间,fa[N][17]
数组占用 (O(N \times 17) = O(N)) 的空间。 - 路径标记数组
d[N]
:占用 (O(N)) 的空间。 - 队列
q[N]
:占用 (O(N)) 的空间。 - 总体空间复杂度为 (O(N))。