线段树–区间最值操作例题
维护一个序列$a$,有以下操作:
- `0 l r t`:对区间 $[l,r]$ 内的每个数执行 $a_i=\min(a_i,t)$;
- `1 l r`:输出区间 $[l,r]$ 的最大值;
- `2 l r`:输出区间 $[l,r]$ 的和。
多组数据,$T \le 100,n \le 10^6,\sum m \le 10^6$
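分析:每个节点维护区间最大值 $ma$、严格次大值 $se$、最大值出现次数 $cnt$ 以及区间和 $sum$。对一个被完全覆盖的节点执行 $a_i=\min(a_i,t)$ 时:
- 若 $t \ge ma$,该节点内没有数会改变,直接返回;
- 若 $se < t < ma$,只有等于 $ma$ 的 $cnt$ 个数变为 $t$,因此 $sum$ 减少 $cnt\cdot(ma-t)$,$ma$ 变为 $t$,打上懒标记即可;
- 否则($t \le se$)无法在该节点直接更新,需递归到左右子节点处理后再合并信息。

资料来源于 OI Wiki“区间最值操作”一节。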
代码:
#include <algorithm>
#include <climits>
#include <cstdio>
#include <cstring>
#include <iostream>
using namespace std;
using ll = long long;

// Maximum sequence length supported (n <= 10^6, plus a little slack).
constexpr int N = 1000005;

// One node of the segment tree. Field order matters: build() assigns
// `{l, r}` as an aggregate, which relies on l and r being the first members.
struct Node
{
    int l, r;    // closed index range [l, r] covered by this node
    int ma, se;  // range maximum, and the largest value strictly below it
                 // (a sentinel when the range holds a single distinct value)
    ll sum;      // sum of the range
    int maxcnt;  // number of elements in the range equal to ma
    int flag;    // most recent chmin cap recorded on this node; -1 when none
} tr[N * 4];     // 1-indexed heap layout: children of u are 2u and 2u+1

int n, m;        // current test case: sequence length and operation count
int w[N];        // NOTE(review): not referenced anywhere in this file
// Fast integer reader from stdin.
// Skips every character up to the next decimal digit. If the character right
// before the first digit is '-', the value is negated. Consumes and discards
// the single character that follows the last digit.
// @param x  receives the parsed value, or 0 when end of input is reached first.
// @return   true if a number was read, false on end of input (existing callers
//           may keep ignoring the result).
template<typename T> bool in(T &x)
{
    int ch = getchar();  // int, not char: EOF must stay distinct from any byte
    bool neg = false;
    x = 0;
    while (ch != EOF && (ch < '0' || ch > '9'))
    {
        neg = (ch == '-');  // only a '-' directly before the digits counts
        ch = getchar();
    }
    if (ch == EOF) return false;  // stop at end of input instead of spinning on EOF
    while (ch >= '0' && ch <= '9')
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    if (neg) x = -x;
    return true;
}
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
if(tr[u << 1].ma == tr[u << 1 | 1].ma)
{
tr[u].maxcnt = tr[u << 1].maxcnt + tr[u << 1 | 1].maxcnt;
tr[u].ma = tr[u << 1].ma;
tr[u].se = max(tr[u << 1].se, tr[u << 1 | 1].se);
}
else if(tr[u << 1].ma > tr[u << 1 | 1].ma)
{
tr[u].maxcnt = tr[u << 1].maxcnt;
tr[u].ma = tr[u << 1].ma;
tr[u].se = max(tr[u << 1].se, tr[u << 1 | 1].ma);
}
else if(tr[u << 1].ma < tr[u << 1 | 1].ma)
{
tr[u].maxcnt = tr[u << 1 | 1].maxcnt;
tr[u].ma = tr[u << 1 | 1].ma;
tr[u].se = max(tr[u << 1].ma, tr[u << 1 | 1].se);
}
}
void pushtag(int u, int k)
{
if(tr[u].ma <= k) return;
tr[u].sum += (1ll * k - tr[u].ma) * (ll)tr[u].maxcnt;
tr[u].ma = tr[u].flag = k;
}
// Push node u's pending chmin cap(s) down to its two children.
// Every cap applied to u has already lowered tr[u].ma. A child's maximum is
// above tr[u].ma exactly when a cap has not reached it yet. So capping both
// children at tr[u].ma applies all pending caps at once, and is a no-op when
// nothing is pending. This does not rely on a sentinel stored in `flag`, so a
// legitimate cap value of -1 cannot be mistaken for "no tag".
// pushtag's precondition (k > child's se) holds because
// tr[u].ma > tr[u].se >= each child's se.
void pushdown(int u)
{
    pushtag(u << 1, tr[u].ma);
    pushtag(u << 1 | 1, tr[u].ma);
    tr[u].flag = -1;  // keep `flag` meaning "nothing pending" for any reader
}
// Build the subtree rooted at u over positions [l, r]. Leaf values are read
// from stdin in left-to-right order.
void build(int u, int l, int r)
{
    tr[u] = {l, r};  // zero the aggregates, then set the range
    tr[u].flag = -1;
    if (l == r)
    {
        int xx;
        in(xx);
        tr[u].ma = xx;
        tr[u].sum = xx;
        tr[u].maxcnt = 1;
        // A leaf has no strict second maximum. INT_MIN is below every
        // readable value, so any cap k that touches the leaf (k < ma) passes
        // modify's `k > se` test, even when values are negative.
        tr[u].se = INT_MIN;
        return;
    }
    int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}
// Apply a_i = min(a_i, k) for every i in [l, r] within node u's subtree.
// A fully covered node with se < k < ma is updated in O(1) by pushtag.
// Otherwise the update is split between the children.
void modify(int u, int l, int r, int k)
{
    if (k >= tr[u].ma) return;  // nothing in this subtree exceeds k
    const bool covered = l <= tr[u].l && r >= tr[u].r;
    if (tr[u].l == tr[u].r)
    {
        // A leaf has no children to recurse into: cap it directly whatever
        // its `se` sentinel is, and never call pushdown on it.
        if (covered) pushtag(u, k);
        return;
    }
    if (covered && k > tr[u].se)
    {
        pushtag(u, k);  // only the maximum elements change
        return;
    }
    pushdown(u);
    int mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) modify(u << 1, l, r, k);
    if (r > mid) modify(u << 1 | 1, l, r, k);
    pushup(u);
}
// Return the maximum of a[l..r] within node u's subtree.
int query_max(int u, int l, int r)
{
    if (l <= tr[u].l && r >= tr[u].r) return tr[u].ma;
    pushdown(u);
    int mid = (tr[u].l + tr[u].r) >> 1;
    // INT_MIN is the identity for max, so ranges whose values are all below
    // -1 are reported correctly (a -1 start value would mask them).
    int res = INT_MIN;
    if (l <= mid) res = std::max(res, query_max(u << 1, l, r));
    if (r > mid) res = std::max(res, query_max(u << 1 | 1, l, r));
    return res;
}
// Return the sum of a[l..r] within node u's subtree.
ll query_sum(int u, int l, int r)
{
    const int lo = tr[u].l;
    const int hi = tr[u].r;
    if (l <= lo && hi <= r) return tr[u].sum;  // fully covered

    pushdown(u);  // children must reflect pending caps before being read
    const int mid = (lo + hi) >> 1;
    const ll left_part = (l <= mid) ? query_sum(u << 1, l, r) : 0;
    const ll right_part = (r > mid) ? query_sum(u << 1 | 1, l, r) : 0;
    return left_part + right_part;
}
// Handle one test case. Read n and m, build the tree from the next n
// integers, then run m operations:
//   0 l r t : chmin on [l, r];  1 l r : print range max;  2 l r : print range sum.
void solve()
{
    in(n), in(m);
    build(1, 1, n);
    while (m-- != 0)
    {
        int op, l, r;
        in(op), in(l), in(r);
        switch (op)
        {
        case 0:
        {
            int t;
            in(t);
            modify(1, l, r, t);
            break;
        }
        case 1:
            printf("%d\n", query_max(1, l, r));
            break;
        case 2:
            printf("%lld\n", query_sum(1, l, r));
            break;
        default:
            break;  // unknown opcodes are ignored
        }
    }
}
// Entry point: the first integer is the number of test cases that follow.
int main()
{
    int cases;
    in(cases);
    for (; cases != 0; --cases) solve();
    return 0;
}