题目描述
给定一棵包含 $n$ 个节点的有根无向树,节点编号互不相同,但不一定是 $1 \sim n$。
有 $m$ 个询问,每个询问给出了一对节点的编号 $x$ 和 $y$,询问 $x$ 与 $y$ 的祖孙关系。
对于每一个询问,若$x$ 是 $y$的祖先则输出 $1$,若 $y$ 是 $x$ 的祖先则输出 $2$,否则输出 $0$。
解题思路
$\qquad$显然这题可以转化为:(记$p$ 为$x$ 和 $y$的最近公共祖先),如果$x=p$输出$1$,如果$y=p$输出$2$,否则输出$0$,然后就可以用倍增求$LCA$了
倍增
#include <iostream>
#include <cstring>
#include <queue>
using namespace std;
const int N = 1e5 + 10;
int h[N], e[N], ne[N], idx;
int depth[N], f[N][20], n, m, root;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
void prework(int root)
{
memset(depth, 0x3f, sizeof depth);
depth[0] = 0, depth[root] = 1;
queue<int> q; q.push(root);
while (q.size())
{
auto t = q.front(); q.pop();
for (int i = h[t]; ~i; i = ne[i])
{
int j = e[i];
if (depth[j] > depth[t] + 1)
{
depth[j] = depth[t] + 1;
f[j][0] = t, q.push(j);
for (int k = 1; k <= 15; k ++ )
f[j][k] = f[f[j][k - 1]][k - 1];
}
}
}
}
int lca(int x, int y)
{
if (depth[x] < depth[y]) swap(x, y);
for (int i = 15; ~i; i -- )
{
int tmp = f[x][i];
if (depth[tmp] >= depth[y]) x = tmp;
}
if (x == y) return x;
for (int i = 15; ~i; i -- )
if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
return f[x][0];
}
int main()
{
scanf("%d", &n);
memset(h, -1, sizeof h);
for (int i = 1; i <= n; i ++ )
{
int a, b;
scanf("%d%d", &a, &b);
if (b == -1) root = a;
add(a, b), add(b, a);
}
prework(root);
scanf("%d", &m);
while (m -- )
{
int x, y;
scanf("%d%d", &x, &y);
int p = lca(x, y);
if (x == p) puts("1");
else if (y == p) puts("2");
else puts("0");
}
return 0;
}
Tarjan
#include <iostream>
#include <vector>
#include <cstring>
#include <queue>
using namespace std;
using PII = pair<int, int>;
const int N = 1e5 + 10;
int h[N], e[N], ne[N], w[N], idx;
int st[N], dist[N], n, m, p[N], res[N];
vector<PII> query[N];
int x[N], y[N], root;
void add(int a, int b)
{
e[idx] = b, ne[idx] = h[a], h[a] = idx ++ ;
}
int find(int x)
{
if (x == p[x]) return x;
return p[x] = find(p[x]);
}
void tarjan(int cur)
{
st[cur] = 1;
for (int i = h[cur]; ~i; i = ne[i])
{
int j = e[i];
if (!st[j]) tarjan(j), p[j] = cur;
}
for (auto [q, id] : query[cur])
if (st[q] == 2) res[id] = find(q);
st[cur] = 2;
}
int main()
{
scanf("%d", &n);
memset(h, -1, sizeof h);
for (int i = 1; i <= N; i ++ ) p[i] = i;
for (int i = 1; i <= n; i ++ )
{
int u, v;
scanf("%d%d", &u, &v);
if (v == -1) root = u;
else add(u, v), add(v, u);
}
scanf("%d", &m);
for (int i = 1; i <= m; i ++ )
{
scanf("%d%d", &x[i], &y[i]);
query[x[i]].push_back({y[i], i});
query[y[i]].push_back({x[i], i});
}
tarjan(root);
for (int i = 1; i <= m; i ++ )
{
if (res[i] == x[i]) puts("1");
else if (res[i] == y[i]) puts("2");
else puts("0");
}
return 0;
}