题意简述：
在从 $a$ 到 $b$ 的有向路径上先选一个点买入，再在买入点之后（沿 $a \to b$ 方向）选一个点卖出，
使收益最大（无利可图时答案为 $0$）。每次查询之后，把这条路径上所有点的价格加 $v$（链加）。
做法：对每一段链维护四个量——从左往右走的最大收益、从右往左走的最大收益、链上最大值、链上最小值。
合并（pushup）时，跨越两段的收益 = 卖出一侧的最大值 − 买入一侧的最小值；
再与左右两段各自的同方向收益取 max 即可。
树链剖分和 LCT 均可实现。
树剖跳链查询时要注意带方向合并：$a$ 一侧的链是自下而上走的，需要先交换它两个方向的收益，再与 $b$ 一侧拼接。
（$O(\log n)$ 的 LCT 实测反而比 $O(\log^2 n)$ 的树剖慢约三倍，推测是 splay 的常数较大。）
树链剖分：$O(n\log^2 n)$，实测 $2000\,\mathrm{ms}$
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <utility>
using namespace std;
// Capacities: N bounds the vertex count, M the number of directed edge slots
// (every undirected tree edge is inserted twice).
constexpr int N = 50010;
constexpr int M = 100010;
// Sentinel used as the identity for min when merging path segments.
constexpr int INF = 1e9;

// Adjacency lists stored in arrays: h[u] is u's first edge slot (-1 = none),
// e[i] is the target of slot i, ne[i] the next slot, idx the next free slot.
int e[M];
int ne[M];
int h[N];
int idx;

// w[u]: value of vertex u as read; nw[p]: value stored at DFS position p.
int w[N];
int nw[N];

// Heavy-light data: heavy child, parent, subtree size, depth, chain head,
// and the running DFS timestamp.
int son[N];
int fa[N];
int sz[N];
int dep[N];
int top[N];
int ts;

// id[u]: DFS position of vertex u (heavy chains are contiguous).
int id[N];

// n: number of vertices, m: number of operations.
int n;
int m;
// Segment-tree node over the DFS order.
//   l, r   : the index range [l, r] this node covers
//   ls     : best profit walking right-to-left (buy at a larger index,
//            sell at a smaller one)
//   rs     : best profit walking left-to-right (buy at a smaller index,
//            sell at a larger one)
//   mx, mn : maximum / minimum value in the range
//   add    : pending addition not yet applied to the children
// Callers brace-initialise Node positionally, so the field order is part of
// the interface and must not change.
struct Node {
    int l;
    int r;
    int ls;
    int rs;
    int mx;
    int mn;
    int add;
} tr[N << 2];
// Adds the directed edge u -> v by prepending a new slot to u's list.
void insert(int u, int v) {
    const int slot = idx++;
    e[slot] = v;
    ne[slot] = h[u];
    h[u] = slot;
}
// First heavy-light pass over the subtree of u.
// Records parent, depth and subtree size, and picks son[u] as the child with
// the strictly largest subtree. On ties, the child seen first in adjacency
// order wins.
void dfs1(int u, int father, int depth) {
    fa[u] = father;
    dep[u] = depth;
    sz[u] = 1;
    for (int slot = h[u]; slot != -1; slot = ne[slot]) {
        const int child = e[slot];
        if (child != father) {
            dfs1(child, u, depth + 1);
            sz[u] += sz[child];
            // son[u] starts at 0 and sz[0] == 0, so any child beats "none".
            if (sz[son[u]] < sz[child]) son[u] = child;
        }
    }
}
// Second heavy-light pass: assigns DFS positions with the heavy child first,
// so each heavy chain occupies a contiguous position range. Also records
// u's chain head t and copies u's value to its position in nw.
void dfs2(int u, int t) {
    top[u] = t;
    id[u] = ++ts;
    nw[ts] = w[u];
    if (son[u] == 0) return;  // no children
    dfs2(son[u], t);          // heavy child extends the current chain
    for (int slot = h[u]; slot != -1; slot = ne[slot]) {
        const int child = e[slot];
        if (child == fa[u] || child == son[u]) continue;
        dfs2(child, child);   // every light child heads a new chain
    }
}
// Combines two adjacent segments, l (smaller indices) followed by r, into u.
// Each direction's profit is the better of the two halves' own profit and the
// cross trade that buys in one half and sells in the other. Every result is
// computed before anything is written, so u may alias l or r (query_path
// calls pushup(left, tmp, left)).
void pushup(Node& u, Node& l, Node& r) {
    const int bestLeftToRight = std::max(std::max(l.rs, r.rs), r.mx - l.mn);  // buy in l, sell in r
    const int bestRightToLeft = std::max(std::max(l.ls, r.ls), l.mx - r.mn);  // buy in r, sell in l
    const int highest = std::max(l.mx, r.mx);
    const int lowest = std::min(l.mn, r.mn);
    u.rs = bestLeftToRight;
    u.ls = bestRightToLeft;
    u.mx = highest;
    u.mn = lowest;
}
void pushup(int u) {
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r) {
tr[u] = {l, r, 0, 0, nw[l], nw[l], 0};
if (l == r) return;
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
// Applies "+add to every value" to node u.
// Max and min shift by add, and the tag accumulates for the children.
// ls/rs are differences of two values, so they stay unchanged.
void pushadd(Node& u, int add) {
    u.mx += add;
    u.mn += add;
    u.add += add;
}
void pushdown(int u) {
auto &rt = tr[u], &l = tr[u << 1], &r = tr[u << 1 | 1];
if (rt.add) {
pushadd(l, rt.add);
pushadd(r, rt.add);
rt.add = 0;
}
}
void modify(int u, int l, int r, int v) {
if (tr[u].l >= l && tr[u].r <= r) {
pushadd(tr[u], v);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, v);
if (r > mid) modify(u << 1 | 1, l, r, v);
pushup(u);
}
// Returns the aggregate of positions [l, r] within node u's range.
// Assumes [l, r] overlaps that range. Only ls, rs, mx and mn are meaningful
// in the result; when two halves are merged, l/r/add are left at 0.
Node query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u];
    int mid = tr[u].l + tr[u].r >> 1;
    pushdown(u);
    if (r <= mid) return query(u << 1, l, r);
    if (l > mid) return query(u << 1 | 1, l, r);
    Node left = query(u << 1, l, r);
    Node right = query(u << 1 | 1, l, r);
    // Value-initialise: pushup() writes only ls/rs/mx/mn. Without this,
    // l/r/add would be returned (copied) indeterminate.
    Node res{};
    pushup(res, left, right);
    return res;
}
// Adds c to every vertex on the tree path between u and v.
void modify_path(int u, int v, int c) {
    while (top[u] != top[v]) {
        // Lift whichever endpoint has the deeper chain head.
        if (dep[top[v]] > dep[top[u]]) std::swap(u, v);
        modify(1, id[top[u]], id[u], c);
        u = fa[top[u]];
    }
    // Both endpoints are on one chain now; make v the shallower one.
    if (dep[v] > dep[u]) std::swap(u, v);
    modify(1, id[v], id[u], c);
}
// Returns, in .rs, the best profit of one buy-then-sell trade along the
// directed path u -> v. The sale happens at the buy vertex or a later one.
// The result is never negative, because single positions carry ls = rs = 0.
//
// `left` accumulates the chain segments on the current u's side and `right`
// those on the current v's side. Each is ordered from the top of the path
// (near the LCA) down to its endpoint. Whenever u and v are swapped, the
// accumulators swap too, and `rev` records that `left` now holds the
// original v's side.
Node query_path(int u, int v) {
    // Identity for pushup(): with mx = -INF and mn = INF an empty side never
    // yields a cross trade. A mx of 0 would, whenever all path values are
    // negative: 0 - left.mn > 0. Assumes values stay within [-INF, INF].
    const Node kEmptySide = {0, 0, 0, 0, -INF, INF, 0};
    Node left = kEmptySide, right = kEmptySide;
    bool rev = false;
    while (top[u] != top[v]) {
        if (dep[top[u]] < dep[top[v]]) {
            std::swap(left, right);
            std::swap(u, v);
            rev = !rev;
        }
        Node seg = query(1, id[top[u]], id[u]);
        pushup(left, seg, left);  // seg lies above everything already in left
        u = fa[top[u]];
    }
    if (dep[u] < dep[v]) {
        std::swap(left, right);
        std::swap(u, v);
        rev = !rev;
    }
    Node seg = query(1, id[v], id[u]);  // v is the LCA here
    pushup(left, seg, left);
    if (rev) std::swap(left, right);
    // Now left = original u's side and right = original v's side, both
    // ordered LCA -> endpoint. Travelling u -> v walks u's side backwards,
    // so exchange its two directions before concatenating left + right.
    std::swap(left.ls, left.rs);
    Node res = {0, 0, 0, 0, 0, 0, 0};
    pushup(res, left, right);
    return res;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", w + i);
memset(h, -1, sizeof h);
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
insert(u, v);
insert(v, u);
}
dfs1(1, -1, 0);
dfs2(1, 1);
build(1, 1, n);
scanf("%d", &m);
while (m--) {
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
printf("%d\n", query_path(a, b).rs);
modify_path(a, b, c);
}
return 0;
}
LCT：$O(n\log n)$，实测 $6000\,\mathrm{ms}$
#include <iostream>
using namespace std;
// Node pool capacity. Vertices use indices 1..n, and index 0 is the null node.
constexpr int N = 100010;
// n: number of vertices, m: number of operations.
int n;
int m;
// Link-cut-tree node (one vertex inside a splay tree).
//   s[0], s[1] : left / right splay children; p : parent pointer
//   mx, mn     : maximum / minimum value in this splay subtree
//   val        : this vertex's own value
//   ls         : best profit walking the subtree's in-order sequence
//                right-to-left
//   rs         : best profit walking it left-to-right
//   add        : addition already applied here, still pending for children
//   rev        : reversal already applied here, still pending for children
struct Node {
    int s[2], p;
    int mx, mn;
    int val;
    int ls, rs;
    int add;
    bool rev;
    // Clears both lazy tags.
    void cleartag() {
        rev = false;
        add = 0;
    }
} tr[N];
bool isroot(int u) {
return tr[tr[u].p].s[0] != u && tr[tr[u].p].s[1] != u;
}
void pushup(int x) {
tr[x].mx = max(max(tr[tr[x].s[0]].mx, tr[tr[x].s[1]].mx), tr[x].val);
tr[x].mn = min(min(tr[tr[x].s[0]].mn, tr[tr[x].s[1]].mn), tr[x].val);
tr[x].ls = max(max(tr[tr[x].s[0]].ls, tr[tr[x].s[1]].ls),
max(tr[tr[x].s[0]].mx, tr[x].val) - min(tr[tr[x].s[1]].mn, tr[x].val));
tr[x].rs = max(max(tr[tr[x].s[0]].rs, tr[tr[x].s[1]].rs),
max(tr[tr[x].s[1]].mx, tr[x].val) - min(tr[tr[x].s[0]].mn, tr[x].val));
}
void pushadd(int x, int v) {
tr[x].val += v;
tr[x].add += v;
tr[x].mx += v;
tr[x].mn += v;
}
void pushrev(int x) {
swap(tr[x].ls, tr[x].rs);
swap(tr[x].s[0], tr[x].s[1]);
tr[x].rev ^= 1;
}
// Forwards x's pending addition and reversal to its non-null children, then
// clears both tags on x. The addition and the reversal touch disjoint
// fields, so applying them child by child is equivalent to applying each
// tag to both children in turn.
void pushdown(int x) {
    const int children[2] = {tr[x].s[0], tr[x].s[1]};
    for (const int c : children) {
        if (c == 0) continue;
        if (tr[x].add) pushadd(c, tr[x].add);
        if (tr[x].rev) pushrev(c);
    }
    tr[x].cleartag();
}
// Rotates x above its splay parent y, keeping the in-order sequence intact.
// k is the side of y that x hangs on. x's inner child (s[k ^ 1]) moves to
// y's side k, and y becomes x's child on side k ^ 1.
void rotate(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
// isroot(y) must be evaluated before any pointer changes. If y is a splay
// root, z is not reached through a child link, so only x.p is updated.
if (!isroot(y)) tr[z].s[tr[z].s[1] == y] = x;
tr[x].p = z;
// When x's inner child is null (0), this also writes tr[0].p.
tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y;
tr[x].s[k ^ 1] = y, tr[y].p = x;
// Recompute bottom-up: y is now x's child, so it must go first.
pushup(y), pushup(x);
}
// Splays x to the root of its splay tree.
void splay(int x) {
// Ancestors of x up to its splay root, used to push tags down top-first.
static int stk[N];
int tt = 0, r = x;
stk[++tt] = r;
while (!isroot(r)) stk[++tt] = r = tr[r].p;
// Apply pending tags from the splay root down to x, so that every rotation
// below sees correct child pointers and aggregates.
while (tt) pushdown(stk[tt--]);
while (!isroot(x)) {
int y = tr[x].p, z = tr[y].p;
if (!isroot(y)) {
// x and y on different sides (zig-zag): rotate x twice.
// Same side (zig-zig): rotate y first, then x.
if ((tr[z].s[1] == y) ^ (tr[y].s[1] == x)) rotate(x);
else rotate(y);
}
rotate(x);
}
}
// Standard LCT access. Walking from x upwards through parent pointers, each
// visited node is splayed and its right child replaced by the previously
// visited node. For x itself that right child becomes null, so x's deeper
// part is cut off. Finally x is splayed to the root of the resulting
// splay tree.
void access(int x) {
    const int start = x;
    int below = 0;
    while (x) {
        splay(x);
        tr[x].s[1] = below;
        pushup(x);
        below = x;
        x = tr[x].p;
    }
    splay(start);
}
// Re-roots x's tree at x: access(x), then reverse x's whole splay tree so
// that x becomes the leftmost (shallowest) vertex.
void makeroot(int x) {
    access(x); pushrev(x);
}
// Returns the root of x's tree. After access(x), this is the leftmost node
// of x's splay tree. Tags are pushed down while descending, and the result
// is splayed before returning.
int findroot(int x) {
    access(x);
    int cur = x;
    for (; tr[cur].s[0] != 0; cur = tr[cur].s[0]) pushdown(cur);
    splay(cur);
    return cur;
}
// Isolates the x-y path: y becomes the tree root, and after access(x) the
// path y ... x is exactly x's splay tree, with x at its splay root.
void split(int x, int y) {
    makeroot(y); access(x);
}
// Adds the edge x-y by hanging x (re-rooted) under y. If x and y are already
// in the same tree, nothing is linked.
void link(int x, int y) {
    makeroot(x);
    const bool alreadyConnected = findroot(y) == x;
    if (!alreadyConnected) tr[x].p = y;
}
int main() {
scanf("%d", &n);
tr[0].mx = -2e9, tr[0].mn = 2e9;
for (int i = 1; i <= n; i++) {
scanf("%d", &tr[i].val);
tr[i].mx = tr[i].mn = tr[i].val;
}
for (int i = 1; i < n; i++) {
int u, v;
scanf("%d%d", &u, &v);
link(u, v);
}
scanf("%d", &m);
while (m--) {
int a, b, v;
scanf("%d%d%d", &a, &b, &v);
split(a, b);
printf("%d\n", tr[a].ls);
pushadd(a, v);
}
return 0;
}