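// 网络管理 (Network Management: report the k-th largest vertex weight on a
// tree path, with point updates to vertex weights)
// 前置内容 (prerequisites):
//   第K小数 (静态区间第k大) — static range k-th order statistic
//   Count on a tree (静态树链第k大) — static tree-path k-th order statistic
//   动态排名 (动态区间第k大) — dynamic range k-th order statistic
// dfs序+树状数组套主席树+树上差分
//   (DFS order + Fenwick tree of value segment trees + prefix differences on the tree)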
#include <algorithm>
#include <cstring>
#include <iostream>
#include <vector>
using namespace std;
// Capacity limits. N sizes every per-vertex / per-query array. M sizes the
// directed-edge slots (2 * (n - 1) are used) and the discretisation buffer
// (n initial weights plus one value per update query).
const int N = 80010;
const int M = 160001;

int n, m;  // vertex count, query count

// Node of the dynamically allocated value segment trees. l / r are child
// indices into tr (0 means "no child"); cnt is the signed number of weights
// falling into the node's value range.
struct Node {
    int l, r, cnt;
};
Node tr[N * 300];

// One query, stored so all values can be discretised offline.
// k == 0: set the weight of vertex a to b; otherwise: k-th largest weight on
// the path between a and b.
struct Query {
    int k, a, b;
};
Query q[N];

int nums[M];  // candidate weight values; sorted and de-duplicated in main
int tot;      // number of valid entries in nums
int w[N];     // weight of each vertex (becomes an index into nums)

int root[N];   // root[i]: segment-tree root attached to Fenwick slot i
int node_idx;  // presumably the tr allocation counter — verify against the segment-tree insert

// Scratch lists of segment-tree roots for the four path endpoints, 1-indexed,
// with n1..n4 holding how many entries of each are valid.
int tr1[N], tr2[N], tr3[N], tr4[N];
int n1, n2, n3, n4;

int e[M], ne[M], h[N], idx;    // adjacency list: edge target, next edge, head per vertex, edge count
int fa[N][17], dep[N], sz[N];  // 2^i-th ancestors, depth, subtree size
int dfn[N], ts;                // DFS entry timestamp per vertex, timestamp counter
// Maps a raw weight to its index in the sorted, de-duplicated nums[0..tot-1]:
// the first position whose value is >= x (tot - 1 when no value is, tot >= 1).
int find(int x) {
    int lo = 0;
    int hi = tot - 1;
    while (hi > lo) {
        const int mid = (lo + hi) / 2;
        if (nums[mid] < x)
            lo = mid + 1;  // answer lies strictly right of mid
        else
            hi = mid;      // mid is a candidate
    }
    return lo;
}
// Prepends the directed edge u -> v to u's adjacency list. Lists are chained
// through ne[] and terminated by -1 (h[] is memset to -1 before use).
void insert(int u, int v) {
    const int slot = idx;
    e[slot] = v;
    ne[slot] = h[u];
    h[u] = slot;
    idx = slot + 1;
}
// Value of the lowest set bit of x (0 for x == 0); the Fenwick-tree step size.
inline int lowbit(int x) {
    const int negated = ~x + 1;  // two's-complement negation
    return x & negated;
}
int lca(int a, int b) {
if (dep[a] < dep[b]) swap(a, b);
for (int k = 16; ~k; k--)
if (dep[fa[a][k]] >= dep[b])
a = fa[a][k];
if (a == b) return a;
for (int k = 16; ~k; k--)
if (fa[a][k] != fa[b][k])
a = fa[a][k], b = fa[b][k];
return fa[a][0];
}
void dfs(int u, int father) {
dfn[u] = ++ts, sz[u] = 1;
fa[u][0] = father, dep[u] = dep[father] + 1;
for (int i = 1; i <= 16; i++)
fa[u][i] = fa[fa[u][i - 1]][i - 1];
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == father) continue;
dfs(j, u);
sz[u] += sz[j];
}
}
void insert(int& p, int l, int r, int x, int v) {
if (!p) p = ++idx;
tr[p].cnt += v;
if (l == r) return;
int mid = l + r >> 1;
if (x <= mid) insert(tr[p].l, l, mid, x, v);
else insert(tr[p].r, mid + 1, r, x, v);
}
// Fenwick-tree point update: adds v to the count of discretised value x in the
// segment tree of every Fenwick slot covering position pos (pos, pos +
// lowbit(pos), ... up to n), so prefix queries at pos..n observe it.
void add(int pos, int x, int v) {
    while (pos <= n) {
        insert(root[pos], 0, tot - 1, x, v);
        pos += lowbit(pos);
    }
}
// Finds the k-th smallest discretised value within [l, r] of the multiset
//   (trees in tr1) + (trees in tr2) - (trees in tr3) - (trees in tr4),
// where each list holds n1..n4 segment-tree node ids (1-indexed). All lists
// descend together and are overwritten in place, one level per step; the
// result is the corresponding original weight nums[...].
int query(int l, int r, int k) {
    int* const lists[4] = {tr1, tr2, tr3, tr4};
    const int sizes[4] = {n1, n2, n3, n4};
    const int signs[4] = {1, 1, -1, -1};

    while (l != r) {
        const int mid = (l + r) >> 1;

        // Net number of values falling into the left half [l, mid].
        int left_total = 0;
        for (int g = 0; g < 4; ++g)
            for (int i = 1; i <= sizes[g]; ++i)
                left_total += signs[g] * tr[tr[lists[g][i]].l].cnt;

        const bool go_left = left_total >= k;
        for (int g = 0; g < 4; ++g)
            for (int i = 1; i <= sizes[g]; ++i) {
                const int node = lists[g][i];
                lists[g][i] = go_left ? tr[node].l : tr[node].r;
            }

        if (go_left) {
            r = mid;
        } else {
            l = mid + 1;
            k -= left_total;
        }
    }
    return nums[r];
}
// Loads the Fenwick prefix roots for positions u and v (counted positively)
// and p and q (counted negatively) into tr1..tr4 / n1..n4. Then returns the
// k-th smallest value over the full discretised range via query().
int query_main(int u, int v, int p, int q, int k) {
    // Walk the Fenwick prefix chain of `pos`, storing roots 1-indexed in `out`.
    auto gather = [](int pos, int* out, int& count) {
        count = 0;
        while (pos != 0) {
            out[++count] = root[pos];
            pos -= lowbit(pos);
        }
    };
    gather(u, tr1, n1);
    gather(v, tr2, n2);
    gather(p, tr3, n3);
    gather(q, tr4, n4);
    return query(0, tot - 1, k);
}
int main() {
cin.tie(0), cout.tie(0)->sync_with_stdio(false);
cin >> n >> m;
for (int i = 1; i <= n; i++) {
cin >> w[i];
nums[tot++] = w[i];
}
memset(h, -1, sizeof h);
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
insert(u, v), insert(v, u);
}
for (int i = 1; i <= m; i++) {
auto& [k, a, b] = q[i];
cin >> k >> a >> b;
if (!k) nums[tot++] = b;
}
sort(nums, nums + tot);
tot = unique(nums, nums + tot) - nums;
dfs(1, 0);
for (int i = 1; i <= n; i++) {
w[i] = find(w[i]);
add(dfn[i], w[i], 1);
add(dfn[i] + sz[i], w[i], -1);
}
for (int i = 1; i <= m; i++) {
auto& [k, a, b] = q[i];
if (k) {
int p = lca(a, b), q = fa[p][0];
k = dep[a] + dep[b] - dep[p] - dep[q] - k + 1;
if (k <= 0) cout << "invalid request!\n";
else cout << query_main(dfn[a], dfn[b], dfn[p], dfn[q], k) << '\n';
}
else {
add(dfn[a], w[a], -1);
add(dfn[a] + sz[a], w[a], 1);
w[a] = find(b);
add(dfn[a], w[a], 1);
add(dfn[a] + sz[a], w[a], -1);
}
}
return 0;
}