头像

线段树天下第一


访客:928

离线:16小时前



线段树果真是强,区间修改,区间查询
注意线段树刚刚开始的build的初始化
sum:区间的和,但是忽略祖宗节点的add懒标记
add:懒标记, 表示的是子节点的懒标记,自身的标记已经是加上了(这个是属于个人习惯hh)
除了build, 其他搜索的子区间都是直接l, r, 没必要吧mid 更新进去, 因为他不一定是到mid啊, right 可能比mid小
有些地方的sum千万别忘了 += (r - l)*d; modify ,pushdown;

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int N = 101010;
struct Node{
    int l, r;
    LL sum, add;    ///qu jian he ,yi ji lan biao ji
}tr[4*N];
int n, m;
int w[N];

void pushup(int u){
    tr[u].sum = tr[u<<1].sum + tr[u<<1|1].sum;
}

void pushdown(int u){
    Node &root = tr[u], &left = tr[u<<1], &right = tr[u<<1|1];  ///yin yong
    if(root.add == 0)   return;
    left.add += root.add;   left.sum += (left.r - left.l + 1) * root.add;
    right.add += root.add;  right.sum += (right.r - right.l + 1) * root.add;
    root.add = 0;
}

void build(int u, int l, int r){
    if(l == r){
        tr[u] = {l, r, w[l], 0};            ///千万看准了,不是w[u],是w[r]或者w[l]
    }else{
        int mid = l + r >> 1;
        tr[u] = {l, r, 0, 0};
        build(u << 1, l, mid);  build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void modify(int u, int l, int r, int d){
    if(tr[u].l >= l && tr[u].r <= r){
        tr[u].add += d; tr[u].sum += (LL)(tr[u].r - tr[u].l + 1) * d;
    }else{
        pushdown(u);            ///lazy
        int mid = tr[u].l + tr[u].r >> 1;
        if(l <= mid)    modify(u << 1, l, r, d);
        if(r > mid)     modify(u << 1 | 1, l, r, d);
        pushup(u);              ///adjust
    }
}

LL query(int u, int l, int r){
    if(tr[u].l >= l && tr[u].r <= r)    return tr[u].sum;   ///没必要在进行pushdown 操作了
    pushdown(u);
    LL sum = 0;
    int mid  = tr[u].l + tr[u].r >> 1;

    if(l <= mid)
        sum += query(u << 1, l, r);
    if(r > mid) 
        sum += query(u << 1 | 1, l, r);
    /*
    if(l <= mid)
        sum += query(u << 1, l, mid);   因为他不一定是到mid啊, right 可能比mid小
    if(r > mid) 
        sum += query(u << 1 | 1, mid + 1, r);
    */
    return sum;
}

int main(){
    cin>>n>>m;
    for(int i = 1; i <= n; i++) scanf("%d", &w[i]);

    build(1, 1, n);

    char op[5];
    int l, r, d;
    while(m--){
        scanf("%s%d%d", op, &l, &r);
        if(*op == 'C'){
            scanf("%d", &d);
            modify(1, l, r, d);
        }else{
            printf("%lld\n", query(1, l, r));
        }
    }
    return 0;
}



活动打卡代码 AcWing 1275. 最大数

注意事项

1.build 注意递归终点

2.modify注意pushup

3.modify递归终点跳出,直接return,免得被pushup

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int N = 202020;
struct Node{
    int l, r;
    int v;  ///max
}node[N*4];

void pushup(int u){
    node[u].v = max(node[u<<1].v, node[u<<1|1].v);
}

void build(int u, int l, int r){
    node[u] = {l, r, int(-2e9)};
    if(l == r)  return;     ///the end!!!
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}

void modify(int u, int x, int v){
    int mid = node[u].l + node[u].r >> 1;
    if(node[u].l == node[u].r){
        node[u].v = v;
        return;     ///加上一个return 免得被pushup
    }
    else if(x <= mid)   modify(u << 1, x, v);
    else    modify(u << 1 | 1, x, v);
    pushup(u);
}

int query(int u, int l, int r){
    if(node[u].l >= l && node[u].r <= r)    return node[u].v;
    int mid = node[u].l + node[u].r >> 1;
    int res = -2e9;
    if(l <= mid)    res = query(u << 1, l, r);
    if(r > mid)     res = max(res, query(u << 1 | 1, l, r));
    return res;
}

int main(){
    LL m, p, a, t, n;
    char op[5];
    cin>>m>>p;
    n = a = 0;

    build(1, 1, N - 20);
    for(int i = 0; i < m; i++){
        scanf("%s%d", op, &t);
        if(*op == 'Q'){
            a = query(1, n - t + 1, n);
            //printf("%lld %lld %lld\n", n - t + 1, n, a);
            printf("%lld\n", a);
        }else{///add
            //cout<<((a + t) % p + p) % p<<endl;
            modify(1, ++n, ((a + t) % p + p) % p);
        }
    }
    /*cout<<endl<<n<<endl;
    a = query(1, 1, 1);
    printf("%lld\n", a);
    a = query(1, 2, 2);
    printf("%lld\n", a);
    a = query(1, 3, 3);
    printf("%lld\n", a);
    a = query(1, 4, 4);
    printf("%lld\n", a);*/
    return 0;
}




活动打卡代码 AcWing 1237. 螺旋折线

自己慢慢

分情况讨论

就可以了

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
int main(){
    LL x, y;
    LL k, res;
    LL tmp;
    while(cin>>x>>y){
        if(y >= 0 && abs(x) <= y){///up
            k = 2 * y;
            tmp = x + y;
            res = tmp + k * (k - 1);
        }else if(x >= 0 && abs(y) <= x){///right
            k = 2 * x;
            tmp = x - y;
            res = tmp + k * k;
        }else if(y <= 0 && y <= x + 1){///down
            k = -2 * y + 1;
            tmp = abs(x + y);
            res = tmp + k * (k - 1);
        }else{///left
            k = -2 * x - 1;
            tmp = abs(y - x - 1);
            res = tmp + k * k;
        }
        cout<<res<<endl;
    }
}


活动打卡代码 AcWing 1232. 三体攻击

二分 加 三维数组的差分

#include <cstdio>
#include <cstring>

const int N = 1001010;
typedef long long LL;

LL a[2*N], b[2*N], bp[2*N];
int op[N][7];
int A, B, C;
int move[8][4] = {///注意是减一,不是加一
    { 0,  0, -1, -1}, 
    { 0, -1,  0, -1}, 
    { 0, -1, -1,  1}, 
    {-1,  0,  0, -1},
    {-1,  0, -1,  1}, 
    {-1, -1,  0,  1}, 
    {-1, -1, -1, -1}
};
int get(int i, int j, int k){
    return i * B * C + j * C + k;
}

bool Check(int mid){
    memcpy(b, bp, sizeof bp);
    for(int i = 1; i <= mid; i++){
        ///这一步处理很是重要
        int x1 = op[i][0], x2 = op[i][1], y1 = op[i][2], y2 = op[i][3], z1 = op[i][4], z2 = op[i][5], h = op[i][6];
        b[get(x1    ,y1    ,z1    )] -= h;
        b[get(x1    ,y1    ,z2 + 1)] += h;
        b[get(x1    ,y2 + 1,z1    )] += h;
        b[get(x1    ,y2 + 1,z2 + 1)] -= h;
        b[get(x2 + 1,y1    ,z1    )] += h;
        b[get(x2 + 1,y1    ,z2 + 1)] -= h;
        b[get(x2 + 1,y2 + 1,z1    )] -= h;
        b[get(x2 + 1,y2 + 1,z2 + 1)] += h;
    }

    for(int i = 1; i <= A; i++){
        for(int j = 1; j <= B; j++){
            for(int k = 1; k <= C; k++){
                for(int t = 0; t < 7; t++){
                    b[get(i, j, k)] -= b[get(i+move[t][0], j+move[t][1], k+move[t][2])] * move[t][3];
                }
                if(b[get(i, j, k)] < 0) return true;
            }
        }   
    }
    return false;
}

int main(){
    int m;
    scanf("%d%d%d%d", &A, &B, &C, &m);

    for(int i = 1; i <= A; i++){
        for(int j = 1; j <= B; j++){
            for(int k = 1; k <= C; k++){
                scanf("%lld", &a[get(i, j, k)]);
            }
        }
    }

    for(int i = 1; i <= m; i++){
        scanf("%d %d %d %d %d %d %d", op[i] + 0, op[i] + 1, op[i] + 2, op[i] + 3, op[i] + 4, op[i] + 5, op[i] + 6);
    }

    ///三维差分
    for(int i = 1; i <= A; i++){
        for(int j = 1; j <= B; j++){
            for(int k = 1; k <= C; k++){
                bp[get(i, j, k)] = a[get(i, j, k)];
                for(int t = 0; t < 7; t++){
                    bp[get(i, j, k)] += a[get(i+move[t][0], j+move[t][1], k+move[t][2])] * move[t][3];
                }
            }
        }
    }
    ///二分
    int l = 1, r = m, mid;
    while(l < r){
        mid = l + r >> 1;
        if(Check(mid))  r = mid;
        else    l = mid + 1;
    }
    printf("%d\n", l);
    return 0;
}


活动打卡代码 AcWing 798. 差分矩阵

#include <bits/stdc++.h>
using namespace std;

const int N = 1010;
int a[N][N], b[N][N];
int n, m, q;

int main(){
    cin>>n>>m>>q;

    memset(a, 0, sizeof a);
    for(int i = 1; i <= n; i++){
        for(int j = 1; j <= m; j++){
            scanf("%d", &a[i][j]);
        }
    }

    for(int i = 1; i <= n; i++){
        for(int j = 1; j <= m; j++){
            b[i][j] = a[i][j] - a[i-1][j] - a[i][j-1] + a[i-1][j-1];
        }
    }

    int x1, x2, y1, y2, c;
    while(q--){
        scanf("%d%d%d%d%d", &x1, &y1, &x2, &y2, &c);
        b[x1][y1] += c;
        b[x1][y2+1] -= c;
        b[x2+1][y1] -= c;
        b[x2+1][y2+1] += c;
    }

    for(int i = 1; i <= n; i++){
        for(int j = 1; j <= m; j++){
            b[i][j] += b[i-1][j] + b[i][j-1] - b[i-1][j-1];
            printf("%d ", b[i][j]);
        }
        cout<<endl;
    }
    return 0;
}


活动打卡代码 AcWing 797. 差分

#include <bits/stdc++.h>
using namespace std;

const int N = 101010;
int a[N], b[N];
int n, m;

int main(){
    cin>>n>>m;
    for(int i = 1; i <= n; i++) scanf("%d", a + i);

    for(int i = 1; i <= n; i++) b[i] = a[i] - a[i-1];

    int l, r, c;
    for(int i = 0; i < m; i++){
        scanf("%d%d%d", &l, &r, &c);
        b[l] += c;
        b[r+1] -= c;
    }

    for(int i = 1; i <= n; i++) b[i] += b[i-1];
    for(int i = 1; i <= n; i++) printf("%d ", b[i]);
    cout<<endl;
    return 0;
}


活动打卡代码 AcWing 1228. 油漆面积

线段树的传参

只有在build需要mid(建树的时候分配负责区域),
在query, modify时候使用mid是错误的,主要是因为他们传递的参数。和我们建表的mid不一定同步
有可能造成一定区间的丢失,倘若是同步的,那么mid也是可以的

本题需要注意点

1.首先是y2–, 将原本图上的点映射为正方格
2.而且本题需要化简,即把完全覆盖的modify的子区间递归省去,否则会超时,主要是因为线段树的常数比较大

#include <bits/stdc++.h>
using namespace std;

const int N = 10101;

int n;
struct Segment{
    int x, y1, y2, k;   ///k表示的是他是矩形的开头,还是结尾
    bool operator < (const Segment &S){
        return x < S.x;
    }
}seg[2*N];

struct Node{
    int l, r;
    int len;    ///[l, r]上面的长度;
    int cnt;    ///完整出现的次数;
}tr[4*N];

void pushup(int u){///一定要注意这个len的讨论
    if(tr[u].cnt)   tr[u].len = tr[u].r - tr[u].l + 1;
    else if(tr[u].l == tr[u].r) tr[u].len = 0;
    else tr[u].len = tr[u << 1].len + tr[u << 1 | 1].len;
}

void build(int u, int l, int r){
    tr[u].l = l, tr[u].r = r;
    if(l == r)  return;
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
}

void modify(int u, int l, int r, int k){//区间修改

    if(tr[u].l == tr[u].r){ ///递归 终点
        tr[u].cnt += k; pushup(u);
        return;
    }

    int mid = tr[u].l + tr[u].r >> 1;
    if(tr[u].l >= l && tr[u].r <= r){//完全覆盖,没有必要递归了
        tr[u].cnt += k;    
        //modify(u << 1, l, r, k), modify(u << 1 | 1, l, r, k);   ///这一个步骤是可以省去的
        ///这一步不省去的话,会超时。。。。。。我笑了md
        pushup(u);
    }else{//部分覆盖
        if(l <= mid)    modify(u << 1, l, r, k);    ///一定要注意递归到哪里去了没有mid
        if(r > mid)     modify(u << 1 | 1, l, r, k);    ///因为这个的修改的传参是有问题的,所以说l + r不一定对应我们的区间
                                                        ///因此不可以使用mid
        pushup(u);
    }
}

int query(int u, int l, int r){
    int sum = 0;
    if(tr[u].l >= l && tr[u].r <= r)    return tr[u].len;
    int mid = tr[u].l + tr[u].r >> 1;
    if(mid >= l)    sum += query(u << 1, l, mid);               ///这个的传参可以使用mid,因为我的传参和建表是一样的
    if(mid < r)     sum += query(u << 1 | 1, mid + 1, r);       ///但是最好也别这样用,hxd
    return sum;
}

int main(){
    cin>>n;
    int m = 0, x1, x2, y1, y2;
    for(int i = 1; i <= n; i++){
        scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
        seg[m++] = {x1, y1, y2, 1};
        seg[m++] = {x2, y1, y2, -1};
    }

    sort(seg, seg + m);

    build(1, 0, 10000);

    int res = 0;
    //cout<<query(1, 0, 10000)<<endl;
    for(int i = 0; i < m; i++){
        if(i){
            res += (seg[i].x - seg[i-1].x) * query(1, 0, 10000);
            //cout<<seg[i].x - seg[i-1].x<<' '<<query(1, 0, 10000)<<"|";
        }
        ///为了使用线段树,这里y2应当减去一,是的这个点目前对应的是一个矩形,
        ///而不是真正图形上的点
        modify(1, seg[i].y1, seg[i].y2 - 1, seg[i].k);
    }

    cout<<res<<endl;
    return 0;
}


活动打卡代码 AcWing 1215. 小朋友排队

身高是需要偏移量的
证明比较关键,其次就是树状数组是如何使用的了
树状数组存储的是C[i]:这个数出现的次数的部分区间维护,所以说需要开大一些
找每一个小孩子左面比他高的,右面比他矮的,
这个左右是如何实现的呢,是使用树状数组的动态维护实现的
LL
O(N * log M)

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int N = 100010, M = 1000010;
int w[N], cnt[N];
int tr[M];
int n;

int lowbit(int x){
    return x & -x;
}

void add(int x, int y = 1){
    for(int i = x; i < M; i += lowbit(i))   tr[i] += y;
}

int query(int x){
    int sum = 0;
    for(int i = x; i; i -= lowbit(i))   sum += tr[i];
    return sum;
}

int main(){
    cin>>n;
    for(int i = 1; i <= n; i++) scanf("%d", &w[i]), w[i]++;
    memset(cnt, 0, sizeof cnt);

    /// w[i] left > num, 注意这个遍历顺序
    for(int i = 1; i <= n; i++){        
        cnt[i] += query(M - 1) - query(w[i]);
        add(w[i]);
    }

    memset(tr, 0, sizeof tr);
    /// w[i] right < num
    for(int i = n; i; i--){
        cnt[i] += query(w[i] - 1);
        add(w[i]);
    }

    LL res = 0;
    for(int i = 1; i <= n; i++) res += LL(cnt[i]) * (cnt[i] + 1) / 2;
    cout<<res<<endl;
    return 0;
}



求大神解答




对于这样的题目一定要先看看

区间

是从哪里开始的
注意modify那里

不是if , else

#include <bits/stdc++.h>
using namespace std;

const int N = 101010;

int w[N];
struct Node{
    int l, r, minv;
}tr[4*N];
int m, n;

void pushup(int u){
    tr[u].minv = max(tr[u << 1].minv, tr[u << 1 | 1].minv);
}

void build(int u, int l, int r){
    tr[u].l = l, tr[u].r = r;
    if(l == r){
        tr[u].minv = w[l];    return;
    }
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

int query(int u, int l, int r){
    if(l <= tr[u].l && r >= tr[u].r)  return tr[u].minv;
    int mid = tr[u].l + tr[u].r >> 1;
    int sum = -0x3f3f3f3f;
    ///注意这里不是if else
    if(l <= mid)    sum = query(u << 1, l, r);
    if(r > mid)    sum = max(sum, query(u << 1 | 1, l, r));
    return sum;
}

void modify(int u, int x, int y){
    if(tr[u].l == tr[u].r){
        tr[u].minv = y; return;
    }  
    int mid = tr[u].l + tr[u].r >> 1;
    if(x <= mid)    modify(u << 1, x, y);
    else    modify(u << 1 | 1, x, y);
    pushup(u);
}
int main(){
    cin>>n>>m;
    for(int i = 1; i <= n; i++) scanf("%d", w + i);
    build(1, 1, n);

    int x, y;
    while(m--){
        scanf("%d%d", &x, &y);
        printf("%d\n", query(1, x, y));
    }

    return 0;
}