(吉老师线段树)
题目背景
本题是线段树维护区间最值操作与区间历史最值的模板。
题目描述
给出一个长度为 $n$ 的数列 $A$,同时定义一个辅助数组 $B$,$B$开始与 $A$ 完全相同。接下来进行了 $m$次操作,操作有五种类型,按以下格式给出:
1 l r k
:对于所有的 $i\in[l,r]$,将 $A_i$ 加上 $k$($k$ 可以为负数)。2 l r v
:对于所有的 $i\in[l,r]$,将 $A_i$ 变成 $\min(A_i,v)$。3 l r
:求 $\sum_{i=l}^{r}A_i$。4 l r
:对于所有的 $i\in[l,r]$,求 $A_i$ 的最大值。5 l r
:对于所有的 $i\in[l,r]$,求 $B_i$ 的最大值。
在每一次操作后,我们都进行一次更新,让 $B_i\gets\max(B_i,A_i)$。
输入格式
第一行包含两个正整数 $n,m$,分别表示数列 $A$ 的长度和操作次数。
第二行包含 $n$ 个整数 $A_1,A_2,\cdots,A_n$,表示数列 $A$。
接下来 $m$ 行,每行行首有一个整数 $op$,表示操作类型;接下来两个或三个整数表示操作参数,格式见【题目描述】。
输出格式
对于 $op\in\{3,4,5\}$ 的操作,输出一行包含一个整数,表示这个询问的答案。
输入输出样例
输入 #1
5 6
1 2 3 4 5
3 2 5
1 1 3 3
4 2 4
2 3 4 1
5 1 5
3 1 4
输出 #1
14
6
6
11
数据范围:
$1\le n,m \le 5e5,-5e8\le a_i \le 5e8$
分析:
具体分析见洛谷题解
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long ll;
const int N = 500005;
struct Node
{
int l, r;
ll sum;
int maxa, se, cnt, maxb;
//区间最大值、区间次大值、区间最大值的个数、历史区间最大值
int add_a, add_a1, add_b, add_b1;
//懒标记:区间最大值加法、区间非最大值加法、区间历史最大值加法、区间非最大的历史最大值的加法
}tr[N * 4];
int n, m;
int a[N];
template<typename T>void in(T &x)
{
char ch = getchar();bool flag = 0;x = 0;
while(ch < '0' || ch > '9') flag |= (ch == '-'), ch = getchar();
while(ch <= '9' && ch >= '0') x = (x << 1) + (x << 3) + ch - '0', ch = getchar();
if(flag) x = -x;return ;
}
void pushup(Node& u, Node& l, Node& r)
{
u.maxa = max(l.maxa, r.maxa);
u.maxb = max(l.maxb, r.maxb);
u.sum = l.sum + r.sum;
if(l.maxa == r.maxa)
{
u.se = max(l.se, r.se);
u.cnt = l.cnt + r.cnt;
}
else if(l.maxa > r.maxa)
{
u.se = max(l.se, r.maxa);
u.cnt = l.cnt;
}
else if(l.maxa < r.maxa)
{
u.se = max(l.maxa, r.se);
u.cnt = r.cnt;
}
}
void pushup(int u)
{
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void update(int u, int k1, int k2, int k3, int k4)
{
//k1最大值要加的数 k2历史最大值要加的数 k3 非最大值要加的数 k4历史非最大值要加的数
tr[u].sum += 1ll * tr[u].cnt * k1 + 1ll * (tr[u].r - tr[u].l + 1 - tr[u].cnt) *k3;
tr[u].maxb = max(tr[u].maxb, tr[u].maxa + k2); // 当前最大值+k2超过了历史最大值,历史最大值更新
tr[u].add_b = max(tr[u].add_b, tr[u].add_a + k2);
tr[u].add_b1 = max(tr[u].add_b1, tr[u].add_a1 + k4);
tr[u].maxa += k1, tr[u].add_a += k1;
tr[u].add_a1 += k3;
if(tr[u].se != -1e9) tr[u].se += k3;
}
void pushdown(int u)
{
auto& root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
int maxn = max(left.maxa, right.maxa);
if(left.maxa == maxn)
update(u << 1, root.add_a, root.add_b, root.add_a1, root.add_b1);
else update(u << 1, root.add_a1, root.add_b1, root.add_a1, root.add_b1);
if(right.maxa == maxn)
update(u << 1 | 1, root.add_a, root.add_b, root.add_a1, root.add_b1);
else update(u << 1 | 1, root.add_a1, root.add_b1, root.add_a1, root.add_b1);
tr[u].add_a = tr[u].add_b = tr[u].add_a1 = tr[u].add_b1 = 0;
}
void build(int u, int l, int r)
{
tr[u] = {l, r};
if(l == r)
{
tr[u].sum = tr[u].maxa = tr[u].maxb = a[l];
tr[u].se = -1e9;
tr[u].cnt = 1;
return;
}
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify_add(int u, int l, int r, int k)
{
if(l <= tr[u].l && r >= tr[u].r)
{
update(u, k, k, k, k);
return;
}
int mid = tr[u].l + tr[u].r >> 1;
pushdown(u);
if(l <= mid) modify_add(u << 1, l, r, k);
if(r > mid) modify_add(u << 1 | 1, l, r, k);
pushup(u);
}
void modify_min(int u, int l, int r, int k)
{
if(k > tr[u].maxa) return;
if(l <= tr[u].l && r >= tr[u].r && k > tr[u].se)
{
update(u, k - tr[u].maxa, k - tr[u].maxa, 0, 0);
return;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modify_min(u << 1, l, r, k);
if(r > mid) modify_min(u << 1 | 1, l, r, k);
pushup(u);
}
ll query_sum(int u, int l, int r)
{
if(l <= tr[u].l && r >= tr[u].r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
ll res = 0;
if(l <= mid) res += query_sum(u << 1, l, r);
if(r > mid) res += query_sum(u << 1 | 1, l, r);
return res;
}
int query_max(int u, int l, int r, int op)
{
if(l <= tr[u].l && r >= tr[u].r)
{
if(op == 0) return tr[u].maxa;
return tr[u].maxb;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int res = -1e9;
if(l <= mid) res = max(res, query_max(u << 1, l, r, op));
if(r > mid) res = max(res, query_max(u << 1 | 1, l, r, op));
return res;
}
int main()
{
in(n), in(m);
for(int i = 1; i <= n; i ++ ) in(a[i]);
build(1, 1, n);
while (m -- )
{
int op, l, r, x;
in(op), in(l), in(r);
if(op == 1)
{
in(x);
modify_add(1, l, r, x);
}
else if(op == 2)
{
in(x);
modify_min(1, l, r, x);
}
else if(op == 3) printf("%lld\n", query_sum(1, l, r));
else if(op == 4) printf("%d\n", query_max(1, l, r, 0));
else if(op == 5) printf("%d\n", query_max(1, l, r, 1));
}
return 0;
}