留着复习
j = heavyson[i]
gi,1 lightson[i] 可选可不选
gi,0 lightson[i] 都不可选
fi,0 = gi,1 + max(fj,0, fj,1)
fi,1 = gi,0 + a[i] + fj,0
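$$\begin{bmatrix} a & b \\ c & d \end{bmatrix} * \begin{bmatrix} e & f \\ g & h \end{bmatrix} = \begin{bmatrix} \max(a+e,\ b+g) & \max(a+f,\ b+h) \\ \max(c+e,\ d+g) & \max(c+f,\ d+h) \end{bmatrix}$$
$$\begin{bmatrix} g_{i,1} & g_{i,1} \\ g_{i,0}+a[i] & -\infty \end{bmatrix} * \begin{bmatrix} f_{j,0} \\ f_{j,1} \end{bmatrix} = \begin{bmatrix} f_{i,0} \\ f_{i,1} \end{bmatrix}$$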
链尾的转移矩阵等价于f矩阵
修改步骤
- 树上单点转移矩阵修改
- 定位链尾
- 查询重链转移矩阵 求出链顶f矩阵
- 用链顶的f矩阵修改链顶父节点的转移矩阵 (链顶父节点一定不是另一条链的链尾)
一路修改至根节点
查询时直接查询根节点所在重链的矩阵连乘积
#include <algorithm>
#include <cstring>
#include <iostream>
using namespace std;
using i64 = long long;

// Array capacities: N for per-node arrays, M for the directed-edge pool.
const int N = 100110;
const int M = 200210;
// Large positive constant; not referenced by the code visible in this file.
const int INF = 0x3f3f3f3f;

// a[u]: weight of node u (read in main, overwritten by modify_path).
int a[N];

// Linked-list adjacency storage filled by insert():
//   e[k]  - target node of edge slot k
//   ne[k] - next edge slot in the same list (-1 terminates)
//   h[u]  - first edge slot of node u
//   idx   - next free edge slot
int e[M];
int ne[M];
int h[N];
int idx;

// Numbering assigned by dfs2(): id[u] is the position of node u,
// seq[p] is the node at position p, ts is the running position counter.
int seq[N];
int id[N];
int ts;

// Per-node data from dfs1()/dfs2():
//   dep  - depth, son - heavy child (0 if none), top - top node of u's chain,
//   tail - for a chain top t, the largest position on that chain,
//   fa   - parent, sz - subtree size.
int dep[N];
int son[N];
int top[N];
int tail[N];
int fa[N];
int sz[N];

// f[u][0] / f[u][1]: DP values of u's subtree with u excluded / included,
// computed once in dfs2().
i64 f[N][2];

// n: number of nodes, m: number of update operations (both read in main).
int n;
int m;
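/*
Notation: j = heavy son of i.
g[i][1]: contribution of i's light sons when each may or may not be chosen
         (sum of max(f[c][0], f[c][1]) over light sons c)
g[i][0]: contribution of i's light sons when none of them is chosen
         (sum of f[c][0] over light sons c)
f[i][0] = g[i][1] + max(f[j][0], f[j][1])
f[i][1] = g[i][0] + a[i] + f[j][0]

(max, +) matrix product:
[a b] * [e f] = [max(a+e, b+g)  max(a+f, b+h)]
[c d]   [g h]   [max(c+e, d+g)  max(c+f, d+h)]

Transition along the heavy edge i -> j:
[g[i][1]         g[i][1]] * [f[j][0]] = [f[i][0]]
[g[i][0] + a[i]  -INF   ]   [f[j][1]]   [f[i][1]]
*/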
// 2x2 matrix over the (max, +) semiring: "addition" is max, "multiplication"
// is ordinary +. v[u] holds the matrix attached to tree node u.
struct Matrix {
    int val[2][2];

    // Fills every byte with 0xC1 (the byte value of -0x3f), so each entry
    // starts at 0xC1C1C1C1 = -1044266559 and acts as "minus infinity".
    // With a 32-bit int, two such entries can still be added without overflow.
    Matrix() {
        std::memset(val, -0x3f, sizeof val);
    }

    // (max, +) product: res[i][j] = max over k of (val[i][k] + b.val[k][j]).
    // Takes the operand by const reference and is itself const, so it can be
    // applied to const matrices and avoids copying the right-hand side.
    Matrix operator*(const Matrix& b) const {
        Matrix res;
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 2; ++j) {
                res.val[i][j] = std::max(val[i][0] + b.val[0][j],
                                         val[i][1] + b.val[1][j]);
            }
        }
        return res;
    }
} v[N];
// Segment-tree node: covers positions [l, r] and stores the (max, +) product
// of the matrices at those positions, in left-to-right order.
struct Node {
    int l;
    int r;
    Matrix mat;
} tr[4 * N];
// Adds the directed edge u -> v at the front of u's adjacency list.
void insert(int u, int v) {
    const int slot = idx++;  // claim the next free edge slot
    e[slot] = v;
    ne[slot] = h[u];         // old head becomes the successor
    h[u] = slot;
}
// Stores in u the (max, +) product of the left and right nodes' matrices.
// Only u.mat is written; u.l and u.r are left untouched.
void pushup(Node& u, Node& l, Node& r) {
    Matrix combined = l.mat * r.mat;
    u.mat = combined;
}
void pushup(int u) {
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r) {
tr[u] = {l, r};
if (l == r) return tr[u].mat = v[seq[l]], void();
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int x) {
if (tr[u].l == tr[u].r) return tr[u].mat = v[seq[x]], void();
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x);
else modify(u << 1 | 1, x);
pushup(u);
}
// Returns a node whose mat is the ordered (max, +) product of the matrices at
// positions [l, r], searching the subtree of segment-tree index u.
// Callers must pass a range that intersects tr[u]'s range.
Node query(int u, int l, int r) {
    const Node& cur = tr[u];
    if (l <= cur.l && cur.r <= r) return cur;

    const int mid = (cur.l + cur.r) >> 1;
    const int left_child = u << 1;
    const int right_child = u << 1 | 1;
    if (r <= mid) return query(left_child, l, r);
    if (l > mid) return query(right_child, l, r);

    // The range straddles mid: combine both halves, left factor first.
    Node left = query(left_child, l, r);
    Node right = query(right_child, l, r);
    Node res;
    // Give the merged node a defined range instead of returning (and copying)
    // uninitialized l/r fields.
    res.l = left.l;
    res.r = right.r;
    pushup(res, left, right);
    return res;
}
// First pass: records parent, depth and subtree size of every node below u
// and picks, for each node, the child with the largest subtree as son[].
void dfs1(int u, int father) {
    fa[u] = father;
    dep[u] = dep[father] + 1;
    sz[u] = 1;

    for (int edge = h[u]; edge != -1; edge = ne[edge]) {
        const int child = e[edge];
        if (child == father) continue;  // skip the edge back to the parent

        dfs1(child, u);
        sz[u] += sz[child];
        // Strict comparison: the first child seen wins ties.
        if (sz[child] > sz[son[u]]) son[u] = child;
    }
}
void dfs2(int u, int t) {
id[u] = ++ts, seq[ts] = u, top[u] = t;
tail[t] = max(tail[t], ts);
f[u][0] = 0, f[u][1] = a[u];
v[u].val[1][0] = a[u];
v[u].val[0][0] = v[u].val[0][1] = 0;
if (!son[u]) return;
dfs2(son[u], t);
f[u][0] += max(f[son[u]][0], f[son[u]][1]);
f[u][1] += f[son[u]][0];
for (int i = h[u]; ~i; i = ne[i]) {
int j = e[i];
if (j == fa[u] || j == son[u]) continue;
dfs2(j, j);
f[u][0] += max(f[j][0], f[j][1]);
f[u][1] += f[j][0];
v[u].val[0][0] += max(f[j][0], f[j][1]);
v[u].val[1][0] += f[j][0];
}
v[u].val[0][1] = v[u].val[0][0];
}
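/*
Notation: j = heavy son of i.
g[i][1]: contribution of i's light sons when each may or may not be chosen
         (sum of max(f[c][0], f[c][1]) over light sons c)
g[i][0]: contribution of i's light sons when none of them is chosen
         (sum of f[c][0] over light sons c)
f[i][0] = g[i][1] + max(f[j][0], f[j][1])
f[i][1] = g[i][0] + a[i] + f[j][0]

(max, +) matrix product:
[a b] * [e f] = [max(a+e, b+g)  max(a+f, b+h)]
[c d]   [g h]   [max(c+e, d+g)  max(c+f, d+h)]

Transition along the heavy edge i -> j:
[g[i][1]         g[i][1]] * [f[j][0]] = [f[i][0]]
[g[i][0] + a[i]  -INF   ]   [f[j][1]]   [f[i][1]]
*/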
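/*
Note: nodes inside a chain store transition matrices, while the chain tail
stores its f matrix (its first column is f[tail][0], f[tail][1]).
*/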
void modify_path(int u, int k) {
v[u].val[1][0] += k - a[u];
a[u] = k;
Matrix before, after;
while (u) {
before = query(1, id[top[u]], tail[top[u]]).mat;
modify(1, id[u]);
after = query(1, id[top[u]], tail[top[u]]).mat;
u = fa[top[u]]; // 跳上去一定不是链尾
v[u].val[0][0] += max(after.val[0][0], after.val[1][0]) - max(before.val[0][0], before.val[1][0]); ;
v[u].val[0][1] = v[u].val[0][0];
v[u].val[1][0] += after.val[0][0] - before.val[0][0];
}
}
signed main() {
cin.tie(0)->sync_with_stdio(false);
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
memset(h, -1, sizeof h);
for (int i = 1; i < n; i++) {
int u, v;
cin >> u >> v;
insert(u, v);
insert(v, u);
}
dfs1(1, 0), dfs2(1, 1), build(1, 1, n);
while (m--) {
int u, v;
cin >> u >> v;
modify_path(u, v);
Node res = query(1, id[1], tail[1]);
cout << max(res.mat.val[0][0], res.mat.val[1][0]) << '\n';
}
return 0;
}
%%%