头像

滑稽_ωノ


访客:26828

离线:8小时前


新鲜事 原文

试一下AcWing的新功能
图片



原题链接:【模板】可持久化数组(可持久化线段树/平衡树)
AgOH大佬的视频课

题目描述

维护这样的一个长度为 $n$ 的数组,支持如下两种操作:
1. 在某个历史版本上修改某一个位置上的值
2. 访问某个历史版本上的某一位置的值

此外,每进行一次操作(对于操作 2,即为生成一个完全一样的版本,不作任何改动),就会生成一个新的版本。版本编号即为当前操作的编号(从 1 开始编号,版本 0 表示初始状态数组)。
输入中每个操作为 `v 1 x val`(在版本 $v$ 的基础上把位置 $x$ 改为 $val$)或 `v 2 x`(查询版本 $v$ 中位置 $x$ 的值)。

输入样例

5 10
59 46 14 87 41
0 2 1
0 1 1 14
0 1 1 57
0 1 1 88
4 2 4
0 2 5
0 2 4
4 2 1
2 2 2
1 1 5 91

输出样例

59
87
41
87
88
46

可持久化线段树

线段树里什么信息都不需要维护,只需要在叶节点存对应下标的值即可

时间复杂度 $O((n + m)\log n)$

C++ 代码

#include<cstdio>

// Fast reader: returns the next decimal integer from stdin.
// Non-digit characters before the number are skipped; any '-' seen while
// skipping makes the result negative (enough for well-formed judge input).
// If EOF is reached before any digit, returns 0 instead of looping forever
// (the previous version kept calling getchar() on EOF indefinitely).
inline int read(){
    int x = 0, f = 1;
    int ch = getchar();             // int, not char, so EOF stays distinguishable
    while(ch != EOF && (ch < '0' || ch > '9')){
        if(ch == '-')
            f = -1;
        ch = getchar();
    }
    while(ch >= '0' && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return x * f;
}

// Writes x in decimal to stdout, with no trailing separator.
// Works in unsigned arithmetic so that:
//   - 0 prints "0" (the old loop emitted nothing for 0), and
//   - INT_MIN is handled (negating it as an int overflows).
inline void write(int x){
    char buf[12];                         // 10 digits of 2^31 plus slack
    unsigned tmp = static_cast<unsigned>(x);
    if(x < 0){
        putchar('-');
        tmp = 0u - tmp;                   // magnitude; well-defined even for INT_MIN
    }
    int cnt = 0;
    do{
        buf[cnt++] = static_cast<char>('0' + tmp % 10);
        tmp /= 10;
    }while(tmp > 0);
    while(cnt > 0) putchar(buf[--cnt]);
}

// Upper bound on the array length and on the number of versions.
constexpr int N = 1000010;

int n;              // array length
int m;              // number of operations (each creates one new version)
int a[N];           // initial array, 1-indexed

// Persistent segment tree node. Internal nodes only link to their children;
// leaves store the array value at their position.
struct Node{
    int l;          // left child index (0 = none)
    int r;          // right child index (0 = none)
    int val;        // value stored at a leaf
};
Node tr[N * 40];    // node pool; index 0 is the null node

int root[N];        // root[i]: root node of version i
int tot;            // pool nodes allocated so far

// Builds the version-0 tree over a[l..r] and writes its root index into u.
void build(int &u, int l, int r)
{
    tot += 1;
    u = tot;
    if(l != r)
    {
        const int mid = (l + r) / 2;
        build(tr[u].l, l, mid);
        build(tr[u].r, mid + 1, r);
        return;
    }
    tr[u].val = a[l];
}

// Path-copying point update. Creates a new tree (root written into u) that
// equals the tree rooted at ver except that position x now holds val.
// Only the nodes on the root-to-leaf path are cloned; everything else is shared.
void modify(int ver, int &u, int l, int r, int x, int val)
{
    ++ tot;
    u = tot;
    tr[u] = tr[ver];                  // clone the old node, then diverge
    if(l == r)
    {
        tr[u].val = val;
        return;
    }
    const int mid = (l + r) / 2;
    if(x > mid)
        modify(tr[ver].r, tr[u].r, mid + 1, r, x, val);
    else
        modify(tr[ver].l, tr[u].l, l, mid, x, val);
}

// Returns the value stored at position x in the tree rooted at u,
// where u covers positions [l, r]. Walks down iteratively to the leaf.
int query(int u, int l, int r, int x)
{
    while(l != r)
    {
        const int mid = (l + r) / 2;
        if(x <= mid)
        {
            u = tr[u].l;
            r = mid;
        }
        else
        {
            u = tr[u].r;
            l = mid + 1;
        }
    }
    return tr[u].val;
}

int main()
{
    n = read();
    m = read();
    for(int i = 1; i <= n; ++ i)
        a[i] = read();

    build(root[0], 1, n);

    // Operation i always produces version i, derived from version v.
    for(int i = 1; i <= m; ++ i)
    {
        const int v = read();
        const int op = read();
        const int x = read();
        if(op != 1)
        {
            // Read-only access: version i is simply an alias of version v.
            write(query(root[v], 1, n, x));
            putchar('\n');
            root[i] = root[v];
            continue;
        }
        const int val = read();
        modify(root[v], root[i], 1, n, x, val);
    }
    return 0;
}



原题链接: 二逼平衡树
肖然大佬的博客

题目描述

著名的树套树板子题


下标线段树套值域平衡树

时间复杂度 $O(n\log^3 n)$

C++ 代码

#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<vector>

using namespace std;

constexpr int N = 100010;
constexpr int INF = 2147483647;   // printed when no predecessor / successor exists

// Treap node (split/merge treap). One pool serves the treaps of every
// outer segment-tree node.
struct Node{
    int l;          // left child (0 = none)
    int r;          // right child (0 = none)
    int val;        // stored value; the treap is ordered by it
    int key;        // random heap priority
    int sz;         // subtree size
};
Node tr[N * 18];

int n;              // sequence length
int m;              // number of operations
int a[N];           // current sequence, 1-indexed
int tot;            // treap nodes allocated so far

int get_node(int val)
{
    int u = ++ tot;
    tr[u].val = val;
    tr[u].key = rand();
    tr[u].sz = 1;
    return u;
}

void pushup(int u)
{
    tr[u].sz = tr[tr[u].l].sz + tr[tr[u].r].sz + 1;
}

// Splits the treap rooted at u by value. x receives every node with
// val <= `val`, y receives the rest. Sizes are refreshed on the way back up.
void split(int u, int val, int &x, int &y)
{
    if(u == 0)
    {
        x = 0;
        y = 0;
        return;
    }
    if(tr[u].val > val)
    {
        // u and its right subtree belong to y; keep splitting the left side.
        y = u;
        split(tr[u].l, val, x, tr[u].l);
    }
    else
    {
        // u and its left subtree belong to x; keep splitting the right side.
        x = u;
        split(tr[u].r, val, tr[u].r, y);
    }
    pushup(u);
}

// Merges treaps x and y, where every value in x is <= every value in y,
// and returns the new root. The node with the larger priority becomes the parent.
int merge(int x, int y)
{
    if(x == 0 || y == 0)  return x | y;
    if(tr[x].key <= tr[y].key)
    {
        tr[y].l = merge(x, tr[y].l);
        pushup(y);
        return y;
    }
    tr[x].r = merge(tr[x].r, y);
    pushup(x);
    return x;
}

int x, y, z;
void insert(int &root, int val)
{
    split(root, val, x, y);
    root = merge(merge(x, get_node(val)), y);
}

// Removes one occurrence of val from the treap rooted at root.
// If val is absent, z ends up empty and the treap is reassembled unchanged.
void dele(int &root, int val)
{
    split(root, val, x, y);                    // x: <= val, y: > val
    split(x, val - 1, x, z);                   // x: < val,  z: == val
    const int equalRoot = z;
    z = merge(tr[equalRoot].l, tr[equalRoot].r);   // drop exactly one node equal to val
    const int lower = merge(x, z);
    root = merge(lower, y);
}

// Returns how many values in the treap rooted at root are strictly less than val.
// The treap is split and then merged back, so its contents are unchanged.
int get_rank(int &root, int val)
{
    split(root, val - 1, x, y);        // x holds every value <= val - 1
    const int smaller = tr[x].sz;
    root = merge(x, y);
    return smaller;
}

// Returns the k-th smallest value (1-based) in the treap rooted at root.
// If k is out of range the walk falls off to the null node and tr[0].val is returned.
int get_val(int &root, int k)
{
    int u = root;
    while(u != 0)
    {
        const int leftSize = tr[tr[u].l].sz;
        if(k == leftSize + 1)  return tr[u].val;
        if(k <= leftSize)
        {
            u = tr[u].l;
            continue;
        }
        k -= leftSize + 1;
        u = tr[u].r;
    }
    return tr[0].val;
}

// Returns the largest value strictly less than val, or tr[0].val when none exists.
int get_prev(int &root, int val)
{
    split(root, val - 1, x, y);
    int u = x;
    for(; tr[u].r != 0; u = tr[u].r)  {}     // rightmost node of the "< val" part
    const int found = tr[u].val;
    root = merge(x, y);
    return found;
}

// Returns the smallest value strictly greater than val, or tr[0].val when none exists.
int get_next(int &root, int val)
{
    split(root, val, x, y);
    int u = y;
    for(; tr[u].l != 0; u = tr[u].l)  {}     // leftmost node of the "> val" part
    const int found = tr[u].val;
    root = merge(x, y);
    return found;
}

// root[u]: treap holding the values of segment-tree node u's range (heap numbering).
int root[N << 2];

// Builds the outer segment tree. Every node's treap receives all of a[l..r].
void build(int u, int l, int r)
{
    for(int pos = l; pos <= r; pos ++)
        insert(root[u], a[pos]);
    if(l < r)
    {
        const int mid = (l + r) / 2;
        build(2 * u, l, mid);
        build(2 * u + 1, mid + 1, r);
    }
}

// Counts values strictly less than val among a[l..r];
// u is the segment-tree node covering [ul, ur].
int query_rank(int u, int ul, int ur, int l, int r, int val)
{
    if(l <= ul && ur <= r)  return get_rank(root[u], val);

    const int mid = (ul + ur) / 2;
    int cnt = 0;
    if(l <= mid)  cnt += query_rank(2 * u, ul, mid, l, r, val);
    if(mid < r)   cnt += query_rank(2 * u + 1, mid + 1, ur, l, r, val);
    return cnt;
}

// Returns the k-th smallest value in a[x..y] by binary searching the answer
// in [0, 1e8]. The answer is the largest v with fewer than k values < v.
// Assumes every value lies in [0, 1e8].
int query_val(int x, int y, int k)
{
    int lo = 0, hi = 100000000;
    while(lo < hi)
    {
        const int mid = (lo + hi + 1) / 2;
        const bool fewerThanK = query_rank(1, 1, n, x, y, mid) < k;
        if(fewerThanK)  lo = mid;
        else  hi = mid - 1;
    }
    return lo;
}

// Replaces the value at position x with val in every treap on the
// root-to-leaf path of the outer segment tree.
// a[x] is read as the OLD value, so the caller must update a[x] only after this returns.
void update(int u, int ul, int ur, int x, int val)
{
    while(true)
    {
        dele(root[u], a[x]);
        insert(root[u], val);
        if(ul == ur)  break;

        const int mid = (ul + ur) / 2;
        if(x <= mid)
        {
            u = 2 * u;
            ur = mid;
        }
        else
        {
            u = 2 * u + 1;
            ul = mid + 1;
        }
    }
}

// Largest value strictly less than val in a[l..r], or -1 if there is none.
// Empty treaps report tr[0].val; this assumes it has been set to -1 and that
// real values are non-negative.
int query_prev(int u, int ul, int ur, int l, int r, int val)
{
    if(l <= ul && ur <= r)  return get_prev(root[u], val);

    const int mid = (ul + ur) / 2;
    int best = -1;
    if(l <= mid)  best = query_prev(2 * u, ul, mid, l, r, val);
    if(mid < r)
    {
        const int cand = query_prev(2 * u + 1, mid + 1, ur, l, r, val);
        if(cand > best)  best = cand;
    }
    return best;
}

// Smallest value strictly greater than val in a[l..r], or -1 if there is none.
// -1 doubles as "missing", so it never wins against a real candidate.
int query_next(int u, int ul, int ur, int l, int r, int val)
{
    if(l <= ul && ur <= r)  return get_next(root[u], val);

    const int mid = (ul + ur) / 2;
    int best = -1;
    if(l <= mid)  best = query_next(2 * u, ul, mid, l, r, val);
    if(mid < r)
    {
        const int cand = query_next(2 * u + 1, mid + 1, ur, l, r, val);
        const bool take = (best == -1) || (cand != -1 && cand < best);
        if(take)  best = cand;
    }
    return best;
}

int main()
{
    srand(114514);          // fixed seed -> deterministic treap priorities
    tr[0].val = -1;         // the null node reports -1, meaning "not found"

    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i ++)  scanf("%d", &a[i]);

    build(1, 1, n);

    while(m --)
    {
        int op, l, r, k, val;
        scanf("%d", &op);
        switch(op)
        {
        case 1:     // rank of val within a[l..r]
            scanf("%d%d%d", &l, &r, &val);
            printf("%d\n", query_rank(1, 1, n, l, r, val) + 1);
            break;
        case 2:     // k-th smallest within a[l..r]
            scanf("%d%d%d", &l, &r, &k);
            printf("%d\n", query_val(l, r, k));
            break;
        case 3:     // point assignment a[k] = val
            scanf("%d%d", &k, &val);
            update(1, 1, n, k, val);
            a[k] = val;
            break;
        case 4:     // predecessor of val within a[l..r]
        {
            scanf("%d%d%d", &l, &r, &val);
            const int found = query_prev(1, 1, n, l, r, val);
            printf("%d\n", found == -1 ? -INF : found);
            break;
        }
        case 5:     // successor of val within a[l..r]
        {
            scanf("%d%d%d", &l, &r, &val);
            const int found = query_next(1, 1, n, l, r, val);
            printf("%d\n", found == -1 ? INF : found);
            break;
        }
        default:
            break;
        }
    }
    return 0;
}



原题链接: 洛谷P4197 Peaks
肖然大佬的视频教程

题目描述

给出一个 $n$ 个点、$m$ 条边的无向图,点和边都带权,有 $q$ 个询问。
每次询问求从点 $v$ 出发、只经过边权小于等于 $x$ 的边所能抵达的所有点中,第 $k$ 大的点权;若能抵达的点不足 $k$ 个,输出 $-1$。

$n \leq 10^5$
$m, q \leq 5 \times 10^5$
点权、边权 $\leq 10^9$

输入样例

10 11 4
1 2 3 4 5 6 7 8 9 10
1 4 4
2 5 3
9 8 2
7 8 10
7 1 4
6 7 1
6 4 8
2 1 5
10 8 10
3 4 7
3 4 6
1 5 2
1 5 6
1 5 8
8 9 2

输出样例

6
1
-1
8

离线线段树合并

首先把点权离散化,在每个点上建一棵线段树,维护点权区间的 $size$

然后把边按照边权排序,把询问按照边权 $x$ 排序

由于询问是递增的,我们就可以把边权小于等于 $x$ 的每一条边所连接的点合并
并查集合并连通块,同时合并线段树

对于每个询问我们在合并之后的线段树内二分即可找到第 $k$ 大的点

时间复杂度 $O(m\log m + q\log q + n\log n)$(边与询问排序,加上线段树合并的总代价)

C++ 代码

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<vector>

using namespace std;

constexpr int N = 100010;   // max vertices
constexpr int M = 500010;   // max edges / queries

int n, m, k;                // vertices, edges, queries
int h[N];                   // vertex heights, 1-indexed

// Height values used for coordinate compression. They are expected to be
// sorted and de-duplicated before find() is called.
vector<int> hs;

// Returns the compressed index of height x (position of the first element >= x in hs).
int find(int x)
{
    const auto it = lower_bound(hs.begin(), hs.end(), x);
    return static_cast<int>(it - hs.begin());
}

// Undirected weighted edge; sorting orders edges by ascending weight.
struct Edge{
    int a, b, w;
    bool used;      // currently never read
    bool operator<(const Edge &other) const
    {
        return w < other.w;
    }
};
Edge e[M];

int p[N];           // union-find parent array

// Union-find lookup with full path compression: after the call, every vertex
// on the path from x points straight at the representative.
int fand(int x)
{
    int rep = x;
    while(p[rep] != rep)  rep = p[rep];
    while(p[x] != rep)
    {
        const int next = p[x];
        p[x] = rep;
        x = next;
    }
    return rep;
}

// Offline query: starting at vertex `root` and using only edges of weight <= x,
// report the k-th largest reachable height. num is the query's input position.
struct Query{
    int root, x, k, num;
    bool operator<(const Query &other) const
    {
        return x < other.x;
    }
};
Query q[M];

// Node of a value-domain segment tree; sz counts vertices whose compressed
// height falls inside the node's range.
struct Node{
    int l, r;       // children (0 = none)
    int sz;
};
Node tr[N * 18];

int root[N];        // root[v]: segment tree of v's component (meaningful when v is a DSU root)
int tot;            // nodes allocated so far

// Adds one vertex with compressed height x to the tree rooted at u, which
// covers [l, r]. Missing nodes along the path are created on demand.
void insert(int &u, int l, int r, int x)
{
    if(u == 0)  u = ++ tot;
    ++ tr[u].sz;
    if(l < r)
    {
        const int mid = (l + r) / 2;
        if(x > mid)  insert(tr[u].r, mid + 1, r, x);
        else  insert(tr[u].l, l, mid, x);
    }
}

// Merges segment tree y into x over range [l, r] and returns the merged root.
// x's nodes are reused in place, and subtrees present only in y are linked in directly.
int merge(int x, int y, int l, int r)
{
    if(x == 0)  return y;
    if(y == 0)  return x;

    tr[x].sz += tr[y].sz;
    if(l == r)  return x;

    const int mid = (l + r) / 2;
    tr[x].l = merge(tr[x].l, tr[y].l, l, mid);
    tr[x].r = merge(tr[x].r, tr[y].r, mid + 1, r);
    return x;
}

// Returns the k-th largest height stored in the tree rooted at u (range [l, r]),
// or -1 if the tree holds fewer than k vertices. Descends iteratively,
// preferring the right (larger) half.
int query(int u, int l, int r, int k)
{
    if(tr[u].sz < k)  return -1;
    while(l < r)
    {
        const int mid = (l + r) / 2;
        const int rightSize = tr[tr[u].r].sz;
        if(rightSize >= k)
        {
            u = tr[u].r;
            l = mid + 1;
        }
        else
        {
            k -= rightSize;
            u = tr[u].l;
            r = mid;
        }
    }
    return hs[l];
}

int res[M];         // answers, indexed by original query order

int main()
{
    scanf("%d%d%d", &n, &m, &k);
    for(int i = 1; i <= n; i ++)
    {
        scanf("%d", &h[i]);
        hs.push_back(h[i]);
        p[i] = i;
    }
    // Coordinate-compress the heights.
    sort(hs.begin(), hs.end());
    hs.erase(unique(hs.begin(), hs.end()), hs.end());
    const int hi = static_cast<int>(hs.size()) - 1;

    for(int i = 0; i < m; i ++)  scanf("%d%d%d", &e[i].a, &e[i].b, &e[i].w);
    sort(e, e + m);

    for(int i = 0; i < k; i ++)
    {
        scanf("%d%d%d", &q[i].root, &q[i].x, &q[i].k);
        q[i].num = i;
    }
    sort(q, q + k);

    // Every vertex starts as its own component containing only its own height.
    for(int i = 1; i <= n; i ++)  insert(root[i], 0, hi, find(h[i]));

    int j = 0;
    for(int i = 0; i < k; i ++)
    {
        // Union in every edge whose weight fits under this query's limit.
        for(; j < m && e[j].w <= q[i].x; j ++)
        {
            const int pa = fand(e[j].a), pb = fand(e[j].b);
            if(pa == pb)  continue;
            root[pb] = merge(root[pa], root[pb], 0, hi);
            p[pa] = pb;
        }
        res[q[i].num] = query(root[fand(q[i].root)], 0, hi, q[i].k);
    }

    for(int i = 0; i < k; i ++)  printf("%d\n", res[i]);
    return 0;
}



class Solution {
public:
    // Returns (largest - 1) * (second largest - 1) over the elements of nums.
    // Both trackers start at 0, so this assumes at least two positive elements.
    int maxProduct(vector<int>& nums) {
        int largest = 0, second = 0;
        for(const int v : nums)
        {
            if(v > largest)
            {
                second = largest;
                largest = v;
            }
            else if(v > second)
            {
                second = v;
            }
        }
        return (largest - 1) * (second - 1);
    }
};




constexpr int N = 100010;        // capacity of the gap arrays below
constexpr int mod = 1e9 + 7;     // answer modulus

int a[N];       // horizontal gaps
int b[N];       // vertical gaps

class Solution {
public:
    // Area of the largest piece after cutting an h x w cake along the given
    // horizontal and vertical cuts, modulo 1e9 + 7.
    // The caller's vectors are not modified: gaps are computed on sorted copies,
    // and no global scratch arrays are used.
    int maxArea(int h, int w, vector<int>& horizontalCuts, vector<int>& verticalCuts) {
        const long long kMod = 1000000007LL;
        const long long tallest = maxGap(horizontalCuts, h);
        const long long widest = maxGap(verticalCuts, w);
        return static_cast<int>(tallest * widest % kMod);   // product can exceed 32 bits
    }

private:
    // Largest distance between consecutive cut positions on [0, length].
    // Both ends (0 and length) count as implicit cuts. Works on a copy of cuts.
    static int maxGap(vector<int> cuts, int length) {
        sort(cuts.begin(), cuts.end());
        int best = 0;
        int prev = 0;
        for(const int c : cuts)
        {
            best = max(best, c - prev);
            prev = c;
        }
        return max(best, length - prev);
    }
};




constexpr int N = 100010;

// Adjacency list stored as linked arrays: h[u] is u's first edge, ne[i] the
// next edge, e[i] the neighbour, and to[i] the tag passed as c.
int h[N], e[N], ne[N], to[N], idx;

// Prepends edge a -> b, tagged with c, to a's list.
void add(int a, int b, int c)
{
    e[idx] = b;
    to[idx] = c;
    ne[idx] = h[a];
    h[a] = idx;
    ++ idx;
}

int res;        // tree edges whose tag differs from the child they lead to

// Walks the tree from u (parent fa). For each child edge, counts it when its
// tag to[i] is not the child itself. With minReorder's tagging (tag = original
// road destination), that means the road already points toward u.
void dfs(int u, int fa)
{
    for(int i = h[u]; i != -1; i = ne[i])
    {
        const int child = e[i];
        if(child == fa)  continue;
        res += (to[i] != child);
        dfs(child, u);
    }
}

class Solution {
public:
    // Minimum number of roads to reverse so that every city can reach city 0.
    int minReorder(int n, vector<vector<int>>& connections) {
        memset(h, -1, sizeof h);
        idx = 0;
        res = 0;

        for(const auto &road : connections)
        {
            const int from = road[0], dest = road[1];
            // Store both directions, each tagged with the road's real destination.
            add(from, dest, dest);
            add(dest, from, dest);
        }

        dfs(0, -1);

        // res tree edges already point toward city 0; the rest must be flipped.
        return n - 1 - res;
    }
};




constexpr int N = 10;       // max number of colours

int n;                      // number of colours
int sum;                    // total number of balls
int a[N];                   // a[i]: balls of colour i

double fact[50];            // fact[i] = i!

// Number of distinct orderings of a multiset with per-colour counts
// a[0..n-1]: (sum of a)! / prod(a[i]!). The parameter shadows the global a.
double get(int a[])
{
    int total = 0;
    for(int i = 0; i < n; i ++)  total += a[i];

    double ways = fact[total];
    for(int i = 0; i < n; i ++)  ways /= fact[a[i]];
    return ways;
}

int l[N], r[N];   // l[i] / r[i]: balls of colour i placed in the left / right box
double up;        // accumulated count of favourable orderings
double down;      // total count of orderings

// Enumerates, for colours u..n-1, how many balls of each colour go to the
// left box; cnt is the number of balls already on the left. For every split
// with exactly sum/2 balls per box and an equal number of distinct colours
// in each box, adds (orderings of left box) * (orderings of right box) to up.
void dfs(int u, int cnt)
{
    if(u != n)
    {
        // Put `take` balls of colour u on the left and the rest on the right.
        for(int take = 0; take <= a[u]; take ++)
        {
            l[u] = take;
            r[u] = a[u] - take;
            dfs(u + 1, cnt + take);
        }
        return;
    }

    if(cnt * 2 != sum)  return;          // both boxes must hold the same number of balls

    int leftColours = 0, rightColours = 0;
    for(int i = 0; i < n; i ++)
    {
        leftColours += (l[i] != 0);
        rightColours += (r[i] != 0);
    }
    if(leftColours == rightColours)  up += get(l) * get(r);
}

class Solution {
public:
    // Probability that, after shuffling all balls and splitting them evenly into
    // two boxes, both boxes contain the same number of distinct colours.
    // Computed as (favourable orderings) / (all orderings of the multiset).
    // Writes nothing to stdout; the result is only returned.
    double getProbability(vector<int>& balls) {
        memset(l, 0, sizeof l);
        memset(r, 0, sizeof r);

        fact[0] = 1;
        for(int i = 1; i < 50; i ++)  fact[i] = fact[i - 1] * i;

        n = balls.size();
        sum = 0;
        for(int i = 0; i < n; i ++)
        {
            a[i] = balls[i];
            sum += a[i];
        }

        up = 0;
        dfs(0, 0);              // favourable orderings
        down = get(a);          // all orderings of the multiset

        return up / down;
    }
};


活动打卡代码 AcWing 353. 雨天的尾巴

离散化 树上差分 线段树合并

#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>

using namespace std;

constexpr int N = 100010;       // max vertices
constexpr int M = N * 2;        // directed edge slots (two per tree edge)

// Adjacency list stored as linked arrays: h[u] is u's first edge,
// e[i] the edge target, ne[i] the next edge in the same list.
int h[N], e[M], ne[M], idx;

// Prepends the directed edge a -> b to a's list.
void add(int a, int b)
{
    e[idx] = b;
    ne[idx] = h[a];
    h[a] = idx;
    idx ++;
}

int f[N][18];   // f[u][k]: 2^k-th ancestor of u (0 once past the root)
int dep[N];     // depth; the caller initialises the root's depth

// DFS from u (parent fa). Fills depth and the binary-lifting table for
// every descendant of u.
void dfs(int u, int fa)
{
    for(int i = h[u]; i != -1; i = ne[i])
    {
        const int child = e[i];
        if(child == fa)  continue;

        dep[child] = dep[u] + 1;
        f[child][0] = u;
        for(int k = 1; k < 18; k ++)
        {
            const int mid = f[child][k - 1];
            f[child][k] = f[mid][k - 1];
        }

        dfs(child, u);
    }
}

// Lowest common ancestor by binary lifting; needs dep/f filled by dfs.
int lca(int a, int b)
{
    if(dep[a] < dep[b])  swap(a, b);

    // Lift a up to b's depth.
    for(int k = 17; k >= 0; k --)
    {
        const int anc = f[a][k];
        if(dep[anc] >= dep[b])  a = anc;
    }
    if(a == b)  return a;

    // Lift both until they sit just below their common ancestor.
    for(int k = 17; k >= 0; k --)
    {
        if(f[a][k] == f[b][k])  continue;
        a = f[a][k];
        b = f[b][k];
    }
    return f[a][0];
}

// Node of a dynamically allocated value-domain segment tree (one tree per vertex).
struct Node{
    int l, r;       // compressed value range [l, r] covered by this node
    int id, d;      // most frequent type in the range, and its count

    int ls, rs;     // left / right child (0 = none)
};
Node tr[N * 50];

int n, m;
int root[N], tot;

// Allocates a node covering [l, r] with all other fields zeroed.
int get_node(int l, int r)
{
    const int u = ++ tot;
    tr[u] = Node{};
    tr[u].l = l;
    tr[u].r = r;
    return u;
}

void pushup(int u)
{
    Node &left = tr[tr[u].ls], &right = tr[tr[u].rs];
    if(left.d >= right.d)  tr[u].id = left.id, tr[u].d = left.d;
    else  tr[u].id = right.id, tr[u].d = right.d;
}

// Adds c to the count of type x in the tree rooted at u (covering [l, r]).
// Missing nodes are created on demand, and maxima are refreshed on the way up.
void insert(int &u, int l, int r, int x, int c)
{
    if(!u)  u = get_node(l, r);

    if(l < r)
    {
        const int mid = (l + r) / 2;
        if(x > mid)  insert(tr[u].rs, mid + 1, r, x, c);
        else  insert(tr[u].ls, l, mid, x, c);
        pushup(u);
        return;
    }

    tr[u].id = l;
    tr[u].d += c;
}

// Merges tree y into tree x, reusing x's nodes. Leaf counts are summed and
// maxima recomputed; returns the merged root.
int merge(int x, int y)
{
    if(!x || !y)  return x + y;

    if(tr[x].l != tr[x].r)
    {
        tr[x].ls = merge(tr[x].ls, tr[y].ls);
        tr[x].rs = merge(tr[x].rs, tr[y].rs);
        pushup(x);
    }
    else
    {
        tr[x].d += tr[y].d;
    }
    return x;
}

int res[N];     // declared but not read anywhere in this program

// Post-order pass: merges each child's tree into u's. Afterwards root[u] holds
// the summed difference counts for u's whole subtree. The return value of merge
// is ignored, which relies on root[u] being non-zero (main allocates a root for every vertex).
void dfs2(int u, int fa)
{
    for(int i = h[u]; i != -1; i = ne[i])
    {
        const int child = e[i];
        if(child == fa)  continue;
        dfs2(child, u);
        merge(root[u], root[child]);
    }
}

int x[N], y[N], z[N];   // operation i: path x[i]..y[i] receives type z[i]

// Type values used for coordinate compression (sorted and de-duplicated in main).
vector<int> v;

// Returns the compressed index of type value x.
int find(int x)
{
    const auto it = lower_bound(v.begin(), v.end(), x);
    return static_cast<int>(it - v.begin());
}

int main()
{
    memset(h, -1, sizeof h);

    scanf("%d%d", &n, &m);
    for(int i = 1; i < n; i ++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b),  add(b, a);
    }

    // Vertex 1 is the root. dep[0] stays 0, so lifting past the root in lca is harmless.
    dep[1] = 1;
    dfs(1, -1);

    for(int i = 0; i < m; i ++)
    {
        scanf("%d%d%d", &x[i], &y[i], &z[i]);
        v.push_back(z[i]);
    }

    // Compress the type values. Value 0 is added as a sentinel at index 0.
    // Assuming every z[i] >= 1, index 0 never receives a count. With pushup
    // breaking ties toward the left child, a vertex whose best count is 0
    // therefore reports id 0 and prints v[0] == 0.
    v.push_back(0);
    sort(v.begin(), v.end());
    v.erase(unique(v.begin(), v.end()), v.end());
    const int hi = static_cast<int>(v.size()) - 1;     // largest compressed index

    for(int i = 1; i <= n; i ++)  root[i] = get_node(0, hi);

    // Tree difference on each path: +1 at both endpoints, -1 at the LCA
    // and -1 at the LCA's parent (if any).
    for(int i = 0; i < m; i ++)
    {
        const int a = x[i], b = y[i], c = find(z[i]);

        if(a == b)
        {
            insert(root[a], 0, hi, c, 1);
            if(f[a][0])  insert(root[f[a][0]], 0, hi, c, -1);
        }
        else
        {
            insert(root[a], 0, hi, c, 1);
            insert(root[b], 0, hi, c, 1);

            const int p = lca(a, b);
            insert(root[p], 0, hi, c, -1);
            if(f[p][0])  insert(root[f[p][0]], 0, hi, c, -1);
        }
    }

    // Accumulate the differences bottom-up; root[i] then describes vertex i.
    dfs2(1, -1);

    for(int i = 1; i <= n; i ++)  printf("%d\n", v[tr[root[i]].id]);
    return 0;
}



题目描述

给定一棵有根树,每次给出两个节点 $a, b$,询问它们的祖孙关系:若 $a$ 是 $b$ 的祖先输出 1,若 $b$ 是 $a$ 的祖先输出 2,否则输出 0。


欧拉序(包含回溯的 dfs 序)+ ST 表求 LCA

$dfs$ 预处理出每个节点的深度 $dep[u]$
第一次遍历到这个节点在第几步 $first[u]$(代码中记为 `dfn[u]`)
第几步遍历到哪个节点(考虑回溯)$a[i]$

st表lca.png

$lca(a, b)$ 一定为 $first[a]$ 到 $first[b]$ 之间遍历过的深度最小的点

预处理出 $a[i]$ 之后即可预处理 $st$ 表存储第 $i$ 步及之后的 $2^j$ 步遍历到的深度最小的节点

对于每个 $lca$ 询问,我们都可以转化为 $st$ 表中区间最小数的查询,即可实现 $O(1)$ 问答

时间复杂度 $O(n\log n + m)$

C++ 代码

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>

using namespace std;

constexpr int N = 40010;
constexpr int M = N * 2;

// Adjacency list of the undirected tree, stored as linked arrays.
int h[N], e[M], ne[M], idx;

// Prepends the directed edge a -> b to a's list.
void add(int a, int b)
{
    ne[idx] = h[a];
    e[idx] = b;
    h[a] = idx ++;
}

int n, m;
int root;           // the vertex whose parent is given as -1

int dep[N];         // depth of each vertex (main sets the root to 1)
int dfn[N];         // dfn[u]: step at which u is first visited in the Euler tour
int step;           // current length of the Euler tour
int a[M];           // a[i]: vertex visited at step i (revisits included)

// Euler tour: records u when it is first entered and again after returning
// from each child, and records the first-visit step in dfn[u].
void dfs(int u, int fa)
{
    a[++ step] = u;
    dfn[u] = step;

    for(int i = h[u]; i != -1; i = ne[i])
    {
        const int child = e[i];
        if(child == fa)  continue;
        dep[child] = dep[u] + 1;
        dfs(child, u);
        a[++ step] = u;         // back at u after finishing this child
    }
}

int f[M][17];
void init()
{
    for(int j = 0; j < 17; j ++)
        for(int i = 1; i + (1 << j) - 1 <= n * 2; i ++)
            if(!j)  f[i][j] = a[i];
            else
            {
                if(dep[f[i][j - 1]] < dep[f[i + (1 << j - 1)][j - 1]])  f[i][j] = f[i][j - 1];
                else  f[i][j] = f[i + (1 << j - 1)][j - 1];
            }
}

int query(int l, int r)
{
    if(l > r)  swap(l, r);
    int k = log2(r - l + 1);
    int a = f[l][k], b = f[r - (1 << k) + 1][k];
    if(dep[a] < dep[b])  return a;
    return b;
}

int lca(int a, int b)
{
    return query(dfn[a], dfn[b]);
}

int main()
{
    memset(h, -1, sizeof h);
    memset(dfn, 0x3f, sizeof dfn);

    scanf("%d", &n);
    for(int i = 0; i < n; i ++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        if(b == -1)  root = a;
        else  add(a, b), add(b, a);
    }

    dep[root] = 1;
    dfs(root, -1);
    init();

    //for(int i = 1; i <= n; i ++)  printf("dfn[%d] = %d\n", i, dfn[i]);  puts("");
    //for(int i = 1; i <= step; i ++)  printf("%d ", a[i]);  puts("");

    int m;
    scanf("%d", &m);
    while(m --)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        int t = lca(a, b);
        if(t == a)  puts("1");
        else if(t == b)  puts("2");
        else  puts("0");
    }
    return 0;
}

倍增求lca

时间复杂度 $O((n + m)\log n)$

C++ 代码

#include<iostream>
#include<algorithm>
#include<cstring>
#include<queue>

using namespace std;

constexpr int N = 40010;
constexpr int M = N * 2;

// Adjacency list of the undirected tree, stored as linked arrays.
int h[N], e[M], ne[M], idx;

// Prepends the directed edge a -> b to a's list.
void add(int a, int b)
{
    e[idx] = b, ne[idx] = h[a], h[a] = idx ++;
}

int n, m;
int depth[N];       // depth[u], root = 1; 0 means "not discovered yet" (and is vertex 0's depth)
int fa[N][17];      // fa[u][k]: 2^k-th ancestor of u, 0 once past the root

// BFS from root that assigns depths and fills the binary-lifting table.
// A vertex's row fa[u][*] is filled when u is discovered, which always
// happens before any of u's children are discovered.
void bfs(int root)
{
    depth[root] = 1;

    queue<int> pending;
    pending.push(root);
    while(!pending.empty())
    {
        const int u = pending.front();
        pending.pop();

        for(int i = h[u]; i != -1; i = ne[i])
        {
            const int j = e[i];
            if(depth[j] != 0)  continue;        // already discovered

            depth[j] = depth[u] + 1;
            pending.push(j);

            fa[j][0] = u;
            for(int k = 1; k < 17; k ++)  fa[j][k] = fa[fa[j][k - 1]][k - 1];
        }
    }
}

int lca(int a, int b)
{
    if(depth[a] < depth[b])  swap(a, b);

    for(int i = 16; i >= 0; i --)
        if(depth[fa[a][i]] >= depth[b])
            a = fa[a][i];

    if(a == b)  return a;

    for(int i = 16; i >= 0; i --)
        if(fa[a][i] != fa[b][i])
        {
            a = fa[a][i];
            b = fa[b][i];
        }
    return fa[a][0];
}

int main()
{
    scanf("%d", &n);

    memset(h, -1, sizeof h);
    int root;                       // set by the input line whose parent is -1
    for(int i = 0; i < n; i ++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        if(b != -1)
        {
            add(a, b);
            add(b, a);
            continue;
        }
        root = a;
    }

    bfs(root);
    scanf("%d", &m);
    while(m --)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        const int p = lca(a, b);
        // "1" when a is the LCA, "2" when b is the LCA, "0" otherwise.
        const char *verdict = (p == a) ? "1" : (p == b ? "2" : "0");
        puts(verdict);
    }
    return 0;
}