理解题意
询问一颗树中两点之间的最短距离: 树中任意两点路径唯一, 实际就是求两点路径的距离.
一种做法是首先初始化$dist$数组 — $dist[u]: u$到根的距离.
设$u, v$的最近公共祖先为$p$, $u — v$的距离为: $dist[u] + dist[v] - 2\times dist[p]$.
在初始化$dist$数组后, 问题转变为求树上两点的最近公共祖先.
$Tarjan—$离线求$LCA$
离线
在线和离线这两个概念出现在包含询问的问题中:
-
在线: 每次询问都可以得到询问结果, 询问和对数据的操作可以同时进行.
-
离线: 先把所有询问读入, 在一次性输出询问结果.
算法过程
基于$dfs$算法, 在$dfs$过程中, 将所有节点分为三类(也可以按函数是否在栈中分类):
-
节点已被遍历且已被回溯(所有子节点均被遍历). 该节点对应的$dfs$函数栈帧已入栈且已出栈.
-
节点已被遍历但未被回溯(存在子节点未被遍历). 该节点对应的$dfs$函数栈帧已入栈且仍在栈中.
-
节点还未被遍历. 该节点对应的$dfs$函数栈帧还未入栈.
设$dfs$按从左向右的顺序遍历子树, 分别用灰色、蓝色和绿色表示上述节点:
观察仍在栈中的节点(蓝色)与已经出栈节点(灰色节点)的公共祖先:
对当前遍历的节点$u$(位于栈帧顶部, 箭头指向的节点), 与已经出栈的节点$v$, 其公共祖先为$v$
的仍在栈中的父节点. 可以将仍在栈中的节点作为其灰色子节点的代表(可利用并查集实现), 每次
询问当前遍历节点$u$与某灰色节点时, 它们的公共祖先即灰色节点的代表节点.
时间复杂度: 对于每个节点只会遍历一次, 合并一次, 查询一次. 而合并和查询可通过并查集实现, 可
认为时间是线性的, 所以时间复杂度为$O(n + m)$, 其中$n$为节点个数, $m$为查询次数.
算法理解
算法对节点的标记的过程与$dfs$算法过程同步, 考虑一颗具有普遍性的树:
若以蓝色节点为根, $dfs$的遍历过程可认为是“根$\rightarrow$左$\rightarrow$根$\rightarrow$右$…$”, 所以当查询
蓝色节点与灰色节点的公共祖先时, 按$dfs$的顺序自然就是仍在函数栈中的根. 我们
不用对这颗树的具体结构做任何假设, 所以可以认为算法具有普遍性.
代码实现
- 注意: 我们每次关注的是即将出栈(子节点均为灰色节点)即栈顶节点的查询, 所有节点的颜色是动态变化的.
灰色节点合并过程在其刚出栈后, 在它的父节点函数中执行(仍在栈中)将其合并至父节点.
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
typedef pair<int, int> pii;
const int N = 1e4 + 10, M = 2 * N;
int n, m;
int h[N], e[M], w[M], ne[M], idx;
int p[N];
int st[N]; //0: 未入栈; 1:在栈中; 2:已出栈
int res[M]; //res[i]: 第i次询问的结果
int dist[N]; //dist[u]: 根到u的距离
vector<pii> query[N]; //query[u]: first:询问的另一个顶点; second:询问的编号(第几次)
void add(int u, int v, int c)
{
e[idx] = v, w[idx] = c, ne[idx] = h[u], h[u] = idx ++;
}
int find(int x)
{
return x == p[x] ? x : p[x] = find(p[x]);
}
void dfs(int u, int fa)
{//fa: 上次dfs遍历节点 防止向根循环遍历
for( int i = h[u]; ~i; i = ne[i] )
{
int v = e[i];
if( v != fa )
{
dist[v] = dist[u] + w[i];
dfs(v, u);
}
}
}
void tarjan(int u)
{
st[u] = 1; //u入栈
for( int i = h[u]; ~i; i = ne[i] )
{
int v = e[i];
if( !st[v] )
{//v还未入栈
tarjan(v);
p[v] = u; //v已经出栈 其父节点u仍然在栈中
}
}
for( auto item : query[u] )
{
int v = item.first, id = item.second;
if( st[v] == 2 )
{//v是灰色节点
int anc = find(v);
res[id] = dist[u] + dist[v] - 2 * dist[anc];
}
}
st[u] = 2; //u出栈
}
int main()
{
cin >> n >> m;
memset(h, -1, sizeof h);
for( int i = 1; i < n; i ++ )
{
int u, v, c;
cin >> u >> v >> c;
add(u, v, c); add(v, u, c);
}
for( int i = 1; i <= m; i ++ )
{
int u, v;
cin >> u >> v;
if( u != v)// 如果u == v, 结果为0, 因为res初始值为0, 可以忽略此次询问
{
query[u].push_back({v, i});
query[v].push_back({u, i}); //不知道谁先被遍历 所以两种情况均要考虑
}
}
dfs(1, -1); //预处理dist root = 1
for( int i = 1; i <= n; i ++ ) p[i] = i; //预处理并查集
tarjan(1);
for( int i = 1; i <= m; i ++ ) cout << res[i] << endl;
return 0;
}