AcWing
  • 首页
  • 课程
  • 题库
  • 更多
    • 竞赛
    • 题解
    • 分享
    • 问答
    • 应用
    • 校园
  • 关闭
    历史记录
    清除记录
    猜你想搜
    AcWing热点
  • App
  • 登录/注册

线段树套Splay

作者: 作者的头像   Ayanami_Rei ,  2025-02-02 23:54:03 ,  所有人可见 ,  阅读 11


0


在平衡树中寻找前驱和后继的正确性证明（见下图）

IMG_1785.jpg

注意事项

线段树的每个节点对应一个区间，并用一棵Splay维护该区间内的所有数
因为序列中可能存在重复元素，所以Splay的每个节点需要同时维护cnt（该值出现的次数）和size（子树中元素的总个数，按重数计算）

将val所在的节点旋转到根节点（若val不在树中，则旋转查找路径上的最后一个节点）

保证该节点一定存在：若val在树中，它就是val所在的节点；否则它是val在树中的前驱或后继之一（在中序序列中与val相邻，但不一定是两者中与val差距更小的那个）

// Splay the node "closest on the search path" to val up to the root.
// The walk stops at the node holding val, or at the last node whose child
// in val's direction is empty; that node then becomes the new root.
void upper(int &root, int val) {
    int cur = root;
    for (;;) {
        if (tr[cur].val == val) break;                 // exact match found
        int next = tr[cur].son[val > tr[cur].val];     // 0 = go left, 1 = go right
        if (!next) break;                              // path ends here
        cur = next;
    }

    splay(root, cur, 0);
}

计算前驱和后继

// Return the index of the node with the largest value strictly below val.
// Assumes the tree holds some value smaller than val (the full program
// inserts a -INF sentinel into every tree for this purpose).
int get_pre(int &root, int val) {
    upper(root, val);

    // The new root is already smaller than val: it is the predecessor.
    int top = root;
    if (val > tr[top].val) return top;

    // Otherwise the predecessor is the right-most node of the left subtree.
    int cur = tr[top].son[0];
    for (int next = tr[cur].son[1]; next; next = tr[cur].son[1]) cur = next;
    return cur;
}

// Return the index of the node with the smallest value strictly above val.
// Assumes the tree holds some value larger than val (the full program
// inserts an INF sentinel into every tree for this purpose).
int get_suff(int &root, int val) {
    upper(root, val);

    // The new root is already larger than val: it is the successor.
    int top = root;
    if (val < tr[top].val) return top;

    // Otherwise the successor is the left-most node of the right subtree.
    int cur = tr[top].son[1];
    for (int next = tr[cur].son[0]; next; next = tr[cur].son[0]) cur = next;
    return cur;
}

完整代码

#include <iostream>
#include <algorithm>

using namespace std;

// NUMBER bounds both the Splay node pool (tr) and the segment-tree arrays.
// INF is 2^31 - 1, written as an integer literal: the expression `2e31 - 1`
// is the double 2*10^31, and converting that to int is out of range
// (undefined behaviour; some compilers produce INT_MIN).
const int NUMBER = 2000010, INF = 2147483647;

int number, op_number;  // sequence length and number of operations (read in main)
int arr[NUMBER];        // current sequence, 1-indexed

// One Splay-tree node. Equal values share a node; cnt counts the copies.
struct Node {
    int son[2], pre, val;  // children (son[0] left, son[1] right), parent, key
    int cnt, size;         // copies of val in this node; total copies in subtree

    // Initialise a freshly allocated leaf with parent _pre and key _val.
    // son[] is not reset: fresh nodes come from the zero-initialised global pool.
    void init(int _pre, int _val) {
        pre = _pre;
        val = _val;
        cnt = size = 1;
    }
} tr[NUMBER];
int idx;  // last allocated index in tr; index 0 is used as the null node
// Per segment-tree node: covered range [lb, rb] and root of its Splay tree.
int lb[NUMBER], rb[NUMBER], id[NUMBER];

// Recompute the subtree size of node _index from its own multiplicity and
// the (already correct) sizes of its two children.
void push_up(int _index) {
    auto &cur = tr[_index];
    cur.size = cur.cnt + tr[cur.son[0]].size + tr[cur.son[1]].size;
}

// Rotate node _index above its parent, preserving the in-order (BST) order.
// k records which side of the parent _index hangs on; _index's child on the
// opposite side (son[k ^ 1]) is handed over to the parent.
void rotate(int _index) {
    int pre = tr[_index].pre;
    int grand = tr[pre].pre;
    int k = tr[pre].son[1] == _index;  // 1 if _index is pre's right child

    // _index takes pre's place under grand. When grand == 0 this writes into
    // tr[0], the null node.
    tr[grand].son[tr[grand].son[1] == pre] = _index, tr[_index].pre = grand;
    // pre adopts _index's inner child (if that child is 0, tr[0].pre is written).
    tr[pre].son[k] = tr[_index].son[k ^ 1], tr[tr[_index].son[k ^ 1]].pre = pre;
    // pre becomes _index's child on the opposite side.
    tr[_index].son[k ^ 1] = pre, tr[pre].pre = _index;

    // pre is now below _index, so its size must be refreshed first.
    push_up(pre), push_up(_index);
}

// Rotate node _index upward until its parent is k.
// k == 0 means "make _index the root", in which case root is updated.
void splay(int &root, int _index, int k) {
    for (int pre = tr[_index].pre; pre != k; pre = tr[_index].pre) {
        int grand = tr[pre].pre;

        if (grand != k) {
            // Do pre and _index hang on different sides of their parents?
            bool zigzag = (tr[grand].son[1] == pre) != (tr[pre].son[1] == _index);
            if (zigzag) {
                rotate(_index);  // zig-zag: rotate the node itself twice
            } else {
                rotate(pre);     // zig-zig: rotate the parent first
            }
        }
        rotate(_index);
    }

    if (k == 0) root = _index;
}

void insert(int &root, int val) {
    int _node = root, pre = 0;
    while (_node && tr[_node].val != val) {
        pre = _node;
        _node = tr[_node].son[val > tr[_node].val];
    }

    if (_node && tr[_node].val == val) {
        tr[_node].cnt++;
        splay(root, _node, 0);
        return;
    }

    _node = ++idx;
    if (pre) tr[pre].son[val > tr[pre].val] = _node;
    tr[_node].init(pre, val);
    splay(root, _node, 0);
}

// Splay to the root either the node holding val, or (if val is absent) the
// last node reached by the BST search for val.
void find_node(int &root, int val) {
    int cur = root;
    while (tr[cur].val != val) {
        int dir = val > tr[cur].val;   // 1 = right, 0 = left
        int next = tr[cur].son[dir];
        if (next == 0) break;          // search path ends at cur
        cur = next;
    }

    splay(root, cur, 0);
}

// Index of the node with the largest value strictly smaller than val.
// Assumes such a node exists (build() inserts a -INF sentinel, and val > -INF).
int get_pre(int &root, int val) {
    find_node(root, val);

    int top = root;
    if (val > tr[top].val) return top;   // the root itself is already smaller

    // Right-most node of the left subtree.
    int cur = tr[top].son[0];
    for (int next = tr[cur].son[1]; next; next = tr[cur].son[1]) cur = next;
    return cur;
}

// Index of the node with the smallest value strictly greater than val.
// Assumes such a node exists (build() inserts an INF sentinel, and val < INF).
int get_suff(int &root, int val) {
    find_node(root, val);

    int top = root;
    if (val < tr[top].val) return top;   // the root itself is already larger

    // Left-most node of the right subtree.
    int cur = tr[top].son[1];
    for (int next = tr[cur].son[0]; next; next = tr[cur].son[0]) cur = next;
    return cur;
}

// Remove one copy of val from the Splay tree rooted at root (no-op if absent).
// Technique: splay val's predecessor to the root and its successor directly
// beneath it; the successor's left subtree then contains only val's node.
// Relies on the -INF / INF sentinels inserted by build() so both neighbours exist.
void remove(int &root, int val) {
    int pre = get_pre(root, val);
    int suff = get_suff(root, val);
    splay(root, pre, 0), splay(root, suff, pre);

    int _node = tr[suff].son[0];
    if (!_node) return;  // val is not in this tree

    if (tr[_node].cnt > 1) {
        // Several copies: drop one; splaying re-runs push_up along the path.
        tr[_node].cnt--;
        splay(root, _node, 0);
        return;
    }

    // Last copy: detach the node (its slot in tr is never reused) and fix sizes.
    tr[suff].son[0] = 0;
    push_up(suff), push_up(pre);
    splay(root, suff, 0);
}

// Replace one occurrence of val1 by val2 inside a single Splay tree:
// first delete the old value, then insert the new one.
void update(int &root, int val1, int val2) {
    remove(root, val1), insert(root, val2);
}

// Count the elements (with multiplicity) strictly smaller than val in the
// Splay tree rooted at root. The final -1 discounts the -INF sentinel that
// build() inserts. Takes root by value and does not splay.
int get_rank(int root, int val) {
    int smaller = 0;

    for (int cur = root; cur != 0;) {
        auto &node = tr[cur];
        if (node.val >= val) {
            cur = node.son[0];
            continue;
        }
        // node and its whole left subtree are smaller than val.
        smaller += node.cnt + tr[node.son[0]].size;
        cur = node.son[1];
    }

    return smaller - 1;
}

// Build segment-tree node _index covering arr[l..r]. Each node owns a Splay
// tree (root stored in id[_index]) with every value of its range, plus the
// -INF and INF sentinels.
void build(int _index, int l, int r) {
    lb[_index] = l;
    rb[_index] = r;

    int &root = id[_index];
    insert(root, -INF);
    insert(root, INF);
    for (int pos = l; pos <= r; pos++) insert(root, arr[pos]);

    if (l != r) {
        int mid = (l + r) >> 1;
        build(_index * 2, l, mid);
        build(_index * 2 + 1, mid + 1, r);
    }
}

// Point update: in every segment node on the path from _index down to the
// leaf for pos, replace the value arr[pos] by val. Reads the OLD value from
// arr[pos], so the caller must update arr[pos] only afterwards.
void modify(int _index, int pos, int val) {
    for (int cur = _index;;) {
        update(id[cur], arr[pos], val);
        if (lb[cur] == rb[cur]) break;

        int mid = (lb[cur] + rb[cur]) >> 1;
        cur = (pos <= mid) ? cur * 2 : cur * 2 + 1;
    }
}

// Number of elements strictly smaller than val among arr[l..r]: the sum of
// the per-tree counts over the segment nodes that exactly cover [l, r].
int get_rank(int _index, int l, int r, int val) {
    bool covered = l <= lb[_index] && rb[_index] <= r;
    if (covered) return get_rank(id[_index], val);

    int mid = (lb[_index] + rb[_index]) >> 1;
    int left_part = (l <= mid) ? get_rank(_index << 1, l, r, val) : 0;
    int right_part = (mid < r) ? get_rank(_index << 1 | 1, l, r, val) : 0;
    return left_part + right_part;
}

// Largest value strictly smaller than val among arr[l..r]; yields -INF
// (the sentinel) when no such element exists.
int get_pre(int _index, int l, int r, int val) {
    if (l <= lb[_index] && rb[_index] <= r) {
        int node = get_pre(id[_index], val);
        return tr[node].val;
    }

    int mid = (lb[_index] + rb[_index]) >> 1;
    int best = -INF;
    if (l <= mid) {
        int cand = get_pre(_index << 1, l, r, val);
        if (cand > best) best = cand;
    }
    if (mid < r) {
        int cand = get_pre(_index << 1 | 1, l, r, val);
        if (cand > best) best = cand;
    }
    return best;
}

// Smallest value strictly greater than val among arr[l..r]; yields INF
// (the sentinel) when no such element exists.
int get_suff(int _index, int l, int r, int val) {
    if (l <= lb[_index] && rb[_index] <= r) {
        int node = get_suff(id[_index], val);
        return tr[node].val;
    }

    int mid = (lb[_index] + rb[_index]) >> 1;
    int best = INF;
    if (l <= mid) {
        int cand = get_suff(_index << 1, l, r, val);
        if (cand < best) best = cand;
    }
    if (mid < r) {
        int cand = get_suff(_index << 1 | 1, l, r, val);
        if (cand < best) best = cand;
    }
    return best;
}

// Reads the sequence and answers op_number queries:
//   1 l r val : rank of val in arr[l..r] (1 + number of smaller elements)
//   2 l r k   : k-th smallest value in arr[l..r] (binary search on the value)
//   3 pos val : set arr[pos] = val
//   4 l r val : predecessor of val in arr[l..r]
//   5 l r val : successor of val in arr[l..r]
// Answers are written with '\n' rather than endl: endl flushes the stream on
// every line, which is costly for many queries.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr), cout.tie(nullptr);

    cin >> number >> op_number;
    for (int i = 1; i <= number; ++i) cin >> arr[i];

    build(1, 1, number);

    while (op_number--) {
        int op;
        cin >> op;

        if (op == 1) {
            int l, r, val;
            cin >> l >> r >> val;
            cout << get_rank(1, l, r, val) + 1 << '\n';
        }
        else if (op == 2) {
            int l, r, k;
            cin >> l >> r >> k;

            // Largest x in [0, 1e8] with (#elements < x) + 1 <= k; the search
            // range hard-codes the assumption that values lie in [0, 1e8].
            int left = 0, right = 1e8;
            while (left < right) {
                int mid = left + right + 1 >> 1;
                if (get_rank(1, l, r, mid) + 1 <= k) left = mid;
                else right = mid - 1;
            }

            cout << left << '\n';
        }
        else if (op == 3) {
            int pos, val;
            cin >> pos >> val;
            modify(1, pos, val);  // reads the old arr[pos], so update arr afterwards
            arr[pos] = val;
        }
        else if (op == 4) {
            int l, r, val;
            cin >> l >> r >> val;
            cout << get_pre(1, l, r, val) << '\n';
        }
        else {
            int l, r, val;
            cin >> l >> r >> val;
            cout << get_suff(1, l, r, val) << '\n';
        }
    }

    return 0;
}

0 评论

App 内打开
你确定删除吗?
1024
x

© 2018-2025 AcWing 版权所有  |  京ICP备2021015969号-2
用户协议  |  隐私政策  |  常见问题  |  联系我们
AcWing
请输入登录信息
更多登录方式: 微信图标 qq图标 qq图标
请输入绑定的邮箱地址
请输入注册信息