在平衡树（Splay）中查找前驱和后继的方法及正确性证明
注意事项
线段树的每个节点对应一个区间，用一棵 Splay 维护该区间内的全部元素（即“线段树套 Splay”）。
因为区间内可能存在重复元素，所以 Splay 的每个节点除了值 `val` 之外，还要维护 `cnt`（该值出现的次数）和 `size`（以该节点为根的子树中元素的总个数，重复元素按 `cnt` 计）。
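将 `val` 所在的节点旋转到根节点（下面的 `upper`）：从根出发沿二叉搜索树的查找路径向下走，遇到值等于 `val` 的节点，或者下一步要走的儿子为空时停止，然后把停下的节点 splay 到根。因为每棵 Splay 中都预先插入了 `-INF` 和 `INF` 两个哨兵，树非空，所以这个节点一定存在；它要么就是 `val` 本身，要么是 `val` 在有序序列中的前驱（最后一步想往右走但右儿子为空），要么是 `val` 的后继（最后一步想往左走但左儿子为空）。注意它不一定是与 `val` 差值最小的节点。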
// Splay the node reached by searching for val to the root.
// The search stops at the node holding val, or at the last node whose child in
// the search direction is empty (i.e. val's predecessor or successor).
void upper(int &root, int val) {
    int cur = root;
    for (;;) {
        if (tr[cur].val == val) break;
        const int next = tr[cur].son[val > tr[cur].val];
        if (!next) break;
        cur = next;
    }
    splay(root, cur, 0);
}
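计算前驱和后继（前驱：严格小于 `val` 的最大值；后继：严格大于 `val` 的最小值）。先调用 `upper` 把上述节点转到根：若根的值小于 `val`，根本身就是前驱；否则根的值 ≥ `val`，且树中不存在介于 `val` 与根的值之间的元素，所以前驱就是根的左子树中的最大值（从左儿子出发一直向右走；由于 `-INF` 哨兵的存在，左子树非空）。后继的情况对称。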
// Return the node holding the largest value strictly less than val.
int get_pre(int &root, int val) {
    upper(root, val);
    // If the splayed root is already below val, it is the predecessor.
    if (tr[root].val < val) return root;
    // Otherwise the predecessor is the rightmost node of the root's left subtree.
    int cur = tr[root].son[0];
    for (int r = tr[cur].son[1]; r; r = tr[cur].son[1]) cur = r;
    return cur;
}
// Return the node holding the smallest value strictly greater than val.
int get_suff(int &root, int val) {
    upper(root, val);
    // If the splayed root is already above val, it is the successor.
    if (tr[root].val > val) return root;
    // Otherwise the successor is the leftmost node of the root's right subtree.
    int cur = tr[root].son[1];
    for (int l = tr[cur].son[0]; l; l = tr[cur].son[0]) cur = l;
    return cur;
}
完整代码
#include <iostream>
#include <algorithm>
using namespace std;
// NUMBER sizes the Splay node pool (tr) and the other global arrays.
// INF must be exactly 2^31 - 1 (INT_MAX): -INF / INF are inserted into every
// Splay as sentinels, and they are also the values printed by the range
// predecessor / successor queries when no element qualifies.
// Note: `2e31` would be the double 2 * 10^31, whose conversion to int is out
// of range (undefined behavior), so the constant is spelled out here.
const int NUMBER = 2000010, INF = 2147483647;
int number, op_number;  // array length and number of operations (read in main)
int arr[NUMBER];        // current array values, 1-indexed
// One Splay-tree node; index 0 in tr[] plays the role of "no node".
struct Node {
    int son[2];  // left / right child indices (0 = none)
    int pre;     // parent index (0 = none)
    int val;     // key stored at this node
    int cnt;     // how many copies of val this node represents
    int size;    // total multiplicity of the subtree rooted here (see push_up)
    // Turn this node into a single-copy leaf holding `value` under `parent`.
    // Children are not touched; fresh nodes rely on tr[] being zero-initialized.
    void init(int parent, int value) {
        cnt = 1;
        size = 1;
        val = value;
        pre = parent;
    }
} tr[NUMBER];
int idx;          // number of Splay nodes allocated so far (tr[1..idx])
int lb[NUMBER];   // left bound of each segment-tree node's interval
int rb[NUMBER];   // right bound of each segment-tree node's interval
int id[NUMBER];   // root (index into tr) of each segment-tree node's Splay
void push_up(int _index) {
tr[_index].size = tr[tr[_index].son[0]].size + tr[tr[_index].son[1]].size + tr[_index].cnt;
}
// Rotate node _index one level up, above its parent, keeping the in-order
// (sorted) sequence unchanged. Only the two nodes whose subtrees change
// (the old parent, then _index) get their sizes recomputed here.
void rotate(int _index) {
int pre = tr[_index].pre;
int grand = tr[pre].pre;
// k = 1 if _index is its parent's right child, 0 if it is the left child.
int k = tr[pre].son[1] == _index;
// Put _index where pre used to hang under grand. When pre was the root
// (grand == 0) this writes into tr[0], which is used as a scratch "null" slot,
// so tr[0].son / tr[0].pre must never be trusted by other code.
tr[grand].son[tr[grand].son[1] == pre] = _index, tr[_index].pre = grand;
// _index's inner child (side k ^ 1) moves to pre's side k.
// If that child is 0, this again only overwrites tr[0].pre.
tr[pre].son[k] = tr[_index].son[k ^ 1], tr[tr[_index].son[k ^ 1]].pre = pre;
// pre becomes _index's child on the side opposite to k.
tr[_index].son[k ^ 1] = pre, tr[pre].pre = _index;
// Bottom-up: pre is now below _index, so it must be recomputed first.
push_up(pre), push_up(_index);
}
// Splay node _index upward until its parent is k (k == 0 means "to the root").
// When splaying all the way to the root, the caller's root variable is updated.
void splay(int &root, int _index, int k) {
    for (int parent = tr[_index].pre; parent != k; parent = tr[_index].pre) {
        const int grand = tr[parent].pre;
        if (grand != k) {
            // Same direction (zig-zig): rotate the parent first.
            // Opposite directions (zig-zag): rotate the node itself first.
            const bool parent_is_right = tr[grand].son[1] == parent;
            const bool node_is_right = tr[parent].son[1] == _index;
            rotate(parent_is_right == node_is_right ? parent : _index);
        }
        rotate(_index);
    }
    if (k == 0) root = _index;
}
// Insert one copy of val into the Splay rooted at root, then splay its node to the root.
void insert(int &root, int val) {
    int parent = 0;
    int cur = root;
    // Descend until val's node is found or we step off the tree; parent trails cur.
    for (; cur != 0 && tr[cur].val != val; cur = tr[cur].son[tr[cur].val < val]) parent = cur;
    if (cur == 0) {
        // New key: allocate a fresh node and attach it under parent (if any).
        cur = ++idx;
        if (parent != 0) tr[parent].son[tr[parent].val < val] = cur;
        tr[cur].init(parent, val);
    } else {
        // Existing key: just bump its multiplicity.
        ++tr[cur].cnt;
    }
    splay(root, cur, 0);
}
// Splay to the root the node holding val or, if val is absent, the last node on
// its search path (which is val's predecessor or successor).
void find_node(int &root, int val) {
    int cur = root;
    while (tr[cur].val != val) {
        const int next = tr[cur].son[tr[cur].val < val];
        if (next == 0) break;
        cur = next;
    }
    splay(root, cur, 0);
}
// Return the node with the largest value strictly less than val.
int get_pre(int &root, int val) {
    find_node(root, val);
    const int top = root;
    if (tr[top].val < val) return top;  // the splayed root is already the predecessor
    // Otherwise: rightmost node of the left subtree.
    int node = tr[top].son[0];
    while (int right = tr[node].son[1]) node = right;
    return node;
}
// Return the node with the smallest value strictly greater than val.
int get_suff(int &root, int val) {
    find_node(root, val);
    const int top = root;
    if (tr[top].val > val) return top;  // the splayed root is already the successor
    // Otherwise: leftmost node of the right subtree.
    int node = tr[top].son[1];
    while (int left = tr[node].son[0]) node = left;
    return node;
}
// Remove one copy of val (no-op if val is not stored).
// Splaying the predecessor to the root and the successor right below it
// isolates val's node as the successor's left child.
void remove(int &root, int val) {
    const int lo = get_pre(root, val);
    const int hi = get_suff(root, val);
    splay(root, lo, 0);
    splay(root, hi, lo);
    const int target = tr[hi].son[0];
    if (target == 0) return;  // val is not in the tree
    if (tr[target].cnt > 1) {
        // Several copies: decrement, then splay so the rotations refresh sizes.
        --tr[target].cnt;
        splay(root, target, 0);
        return;
    }
    // Last copy: detach the node and fix the sizes of the two nodes above it.
    tr[hi].son[0] = 0;
    push_up(hi);
    push_up(lo);
    splay(root, hi, 0);
}
// Replace one occurrence of val1 by val2: delete the old copy, then insert the new one.
void update(int &root, int val1, int val2) {
    remove(root, val1), insert(root, val2);
}
int get_rank(int root, int val) {
int _node = root, res = 0;
while (_node) {
if (tr[_node].val < val) {
res += tr[tr[_node].son[0]].size + tr[_node].cnt;
_node = tr[_node].son[1];
}
else _node = tr[_node].son[0];
}
return res - 1;
}
// Build segment-tree node _index over arr[l..r]; each node gets its own Splay
// containing every value of its interval plus the -INF / INF sentinels.
void build(int _index, int l, int r) {
    lb[_index] = l;
    rb[_index] = r;
    int &tree = id[_index];
    // Sentinels keep the tree non-empty so predecessor/successor searches always land on a node.
    insert(tree, -INF);
    insert(tree, INF);
    for (int pos = l; pos <= r; ++pos) insert(tree, arr[pos]);
    if (l < r) {
        const int mid = (l + r) >> 1;
        build(2 * _index, l, mid);
        build(2 * _index + 1, mid + 1, r);
    }
}
// Change arr[pos] (its old value is read from arr) to val in every Splay on the
// path from segment node _index down to the leaf for pos. The caller updates arr.
void modify(int _index, int pos, int val) {
    int node = _index;
    while (true) {
        update(id[node], arr[pos], val);
        if (lb[node] == rb[node]) break;
        const int mid = (lb[node] + rb[node]) >> 1;
        node = pos <= mid ? node * 2 : node * 2 + 1;
    }
}
// Number of values strictly less than val in arr[l..r].
int get_rank(int _index, int l, int r, int val) {
    // Interval fully covered: answer directly from this node's Splay.
    if (l <= lb[_index] && rb[_index] <= r) return get_rank(id[_index], val);
    const int mid = (lb[_index] + rb[_index]) >> 1;
    const int left_part = l <= mid ? get_rank(_index * 2, l, r, val) : 0;
    const int right_part = r > mid ? get_rank(_index * 2 + 1, l, r, val) : 0;
    return left_part + right_part;
}
// Largest value strictly less than val in arr[l..r]; -INF if there is none.
int get_pre(int _index, int l, int r, int val) {
    if (l <= lb[_index] && rb[_index] <= r) {
        const int node = get_pre(id[_index], val);
        return tr[node].val;
    }
    const int mid = (lb[_index] + rb[_index]) >> 1;
    int best = -INF;
    if (l <= mid) best = max(best, get_pre(_index * 2, l, r, val));
    if (mid < r) best = max(best, get_pre(_index * 2 + 1, l, r, val));
    return best;
}
// Smallest value strictly greater than val in arr[l..r]; INF if there is none.
int get_suff(int _index, int l, int r, int val) {
    if (l <= lb[_index] && rb[_index] <= r) {
        const int node = get_suff(id[_index], val);
        return tr[node].val;
    }
    const int mid = (lb[_index] + rb[_index]) >> 1;
    int best = INF;
    if (l <= mid) best = min(best, get_suff(_index * 2, l, r, val));
    if (mid < r) best = min(best, get_suff(_index * 2 + 1, l, r, val));
    return best;
}
// Reads n, m and the array, builds the segment tree of Splays, then answers m ops:
//   1 l r val : rank of val in arr[l..r]
//   2 l r k   : k-th smallest value in arr[l..r]
//   3 pos val : arr[pos] = val
//   4 l r val : predecessor of val in arr[l..r]
//   5 l r val : successor of val in arr[l..r]
// Output uses '\n' instead of std::endl: endl forces a flush on every answer,
// which is needless overhead for many queries. cout is flushed at program exit.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cin >> number >> op_number;
    for (int i = 1; i <= number; ++i) cin >> arr[i];
    build(1, 1, number);
    while (op_number--) {
        int op;
        cin >> op;
        if (op == 1) {
            int l, r, val;
            cin >> l >> r >> val;
            // rank = (# values < val) + 1
            cout << get_rank(1, l, r, val) + 1 << '\n';
        }
        else if (op == 2) {
            int l, r, k;
            cin >> l >> r >> k;
            // Largest x with (# values < x) + 1 <= k is the k-th smallest.
            // Search range [0, 1e8]: assumes all values lie in it — TODO confirm.
            int left = 0, right = 100000000;
            while (left < right) {
                int mid = (left + right + 1) >> 1;
                if (get_rank(1, l, r, mid) + 1 <= k) left = mid;
                else right = mid - 1;
            }
            cout << left << '\n';
        }
        else if (op == 3) {
            int pos, val;
            cin >> pos >> val;
            // modify() reads the old arr[pos], so arr is updated afterwards.
            modify(1, pos, val);
            arr[pos] = val;
        }
        else if (op == 4) {
            int l, r, val;
            cin >> l >> r >> val;
            cout << get_pre(1, l, r, val) << '\n';
        }
        else {
            int l, r, val;
            cin >> l >> r >> val;
            cout << get_suff(1, l, r, val) << '\n';
        }
    }
    return 0;
}