AcWing
  • 首页
  • 课程
  • 题库
  • 更多
    • 竞赛
    • 题解
    • 分享
    • 问答
    • 应用
    • 校园
  • 关闭
    历史记录
    清除记录
    猜你想搜
    AcWing热点
  • App
  • 登录/注册

AcWing 2506. 拦截导弹    原题链接    简单

作者: 作者的头像   Union_Find ,  2025-06-10 16:28:10 · 福建 ,  所有人可见 ,  阅读 5


1


题意

求一个序列的二维最长不上升子序列的长度,以及每个点称为该子序列的长度。

保证方案数在 double 范围内。

$1 \le n \le 5 \times 10^4$。

分析

我们容易想到一个 dp。定义 $f_i$ 表示以 $i$ 结尾的二维最长不上升子序列(后文称为 LDS)的长度。我们就可以设计出 $O(n^2)$ 的转移。

$$f_i = \max_{j<i,a_j \ge a_i,b_j \ge b_i} f_j + 1$$

然后求每个点的概率。容易想到要求经过每个点的 LDS 个数。于是我们定义 $g_i$ 表示以 $i$ 结尾的 LDS 个数。然后我们同样可以容易得到 $O(n^2)$ 的转移。

$$g_i = \sum_{j<i,a_j \ge a_i,b_j \ge b_i} [f_j + 1 = f_i]g_j$$

但是这只是以 $i$ 结尾的答案,我们要的是经过 $i$ 的答案,于是我们要再求一个以 $i$ 为开头 的 LDS 长度及个数。

为了区分,定义以 $i$ 为结尾的 LDS 是 $f1_i,g1_i$,以 $i$ 为开头的 LDS 是 $f2_i,g2_i$。二者转移是类似的。

首先可以求出 LDS 的长度,及 $f1_i$ 的最大值,设为 $maxn$。

现在求概率,我们要先求出总的方案数,设为 $sum$。

$$sum = \sum_{i=1}^n [f1_i = maxn]g1_i$$

然后我们判断一个点的方案数,就是 $g1_i \times g2_i$。但是前提是 $f1_i + f2_i - 1 = maxn$,及经过 $i$ 的 LDS 长度是 $maxn$。概率就是 $\frac{g1_i \times g2_i}{sum}$。

int n, maxn, f1[N], f2[N];
double sum, g1[N], g2[N];
struct Point{
    int h, w;   
}p[N];
signed main(){
    n = rd();
    for (int i = 1; i <= n; i++) p[i] = Point{rd(), rd()};
    for (int i = 1; i <= n; i++){
        f1[i] = 1;
        for (int j = 1; j < i; j++) if (p[j].h >= p[i].h && p[j].w >= p[i].w) f1[i] = max(f1[i], f1[j] + 1);
        for (int j = 1; j < i; j++) if (p[j].h >= p[i].h && p[j].w >= p[i].w && f1[j] + 1 == f1[i]) g1[i] += g1[j];
        if (!g1[i]) g1[i]++;
    }
    for (int i = n; i >= 1; i--){
        f2[i] = 1;
        for (int j = i + 1; j <= n; j++) if (p[j].h <= p[i].h && p[j].w <= p[i].w) f2[i] = max(f2[i], f2[j] + 1);
        for (int j = i + 1; j <= n; j++) if (p[j].h <= p[i].h && p[j].w <= p[i].w && f2[j] + 1 == f2[i]) g2[i] += g2[j];
        if (!g2[i]) g2[i]++;
    }
    for (int i = 1; i <= n; i++) maxn = max(maxn, f1[i]);
    printf ("%lld\n", maxn);
    for (int i = 1; i <= n; i++) if (f1[i] == maxn) sum += g1[i];
    for (int i = 1; i <= n; i++) if (f1[i] + f2[i] - 1 == maxn) printf ("%.5lf ", g1[i] * g2[i] / sum); else printf ("0.00000 ");
    return 0;
}

然后 T 飞了。这个时间复杂度是 $O(n^2)$ 的。我们考虑优化。

考虑一维的时候我们可以用二分或线段树维护来做到单点加,区间最大值。这样子就是 $O(n\log n)$ 的。但是现在是二维的,难道我们用二维线段树?

确实可以,但是你想写吗?

我们用 cdq 分治,可以做到时间 $O(n\log^2n)$,空间 $O(n)$。实际上,这是一道 cdq 优化 dp 的板子。

套路的,我们要分治解决这个问题。以 $f1_i$ 和 $g1_i$ 为例。设当前区间是 $[l,r]$,中点是 $mid$。

  • 求区间 $[l,mid]$ 的 dp 值。
  • 将 $[l,mid]$ 的 dp 值转移给 $[mid+1,r]$ 的 dp 值。
  • 求 $[mid+1,r]$ 的 dp 值。

第 $1$ 步和第 $3$ 步都是递归求解,主要是第 $2$ 步。

考虑转移的限制有三个,为 $j<i,a_j \ge a_i,b_j \ge b_i$。其中我们用分治处理了第一个限制,然后我们暴力按照 $a$ 排序,双指针处理解决掉第 $2$ 个限制。现在只剩第 $3$ 个限制。我们可以直接用线段树单点修改区间查询解决。

首先考虑线段树中维护最大的 $f$ 以及这些位置对应的 $g$ 的和。用一个结构体封装起来,如下代码。

struct Tree{
    int mx;double ct;
}tr[N << 2];
Tree operator +(const Tree &a, const Tree &b){
    if (a.mx > b.mx) return a;
    if (a.mx < b.mx) return b;
    return Tree{a.mx, a.ct + b.ct};
}
void clr(int p, int l, int r){
    if (!tr[p].mx) return ;// 这一句很关键,否则时间复杂度会退化到 O(n^2log n)
    tr[p] = Tree{0, 0};
    if (l == r) return ;
    int mid = (l + r) >> 1;
    clr(p << 1, l, mid), clr(p << 1 | 1, mid + 1, r);
}
void upd(int p, int l, int r, int x, Tree k){
    if (l == r) return tr[p] = tr[p] + k, void(0);
    int mid = (l + r) >> 1;
    if (x <= mid) upd(p << 1, l, mid, x, k);
    else upd(p << 1 | 1, mid + 1, r, x, k);
    tr[p] = tr[p << 1] + tr[p << 1 | 1];
}
Tree qry(int p, int l, int r, int nl, int nr){
    if (nl <= l && r <= nr) return tr[p];
    int mid = (l + r) >> 1;
    if (nl <= mid && mid < nr) return qry(p << 1, l, mid, nl, nr) + qry(p << 1 | 1, mid + 1, r, nl, nr);
    if (nl <= mid) return qry(p << 1, l, mid, nl, nr);
    return qry(p << 1 | 1, mid + 1, r, nl, nr);
}

这里 $mx$ 就是最大的 $f$,$ct$ 就是个数。

看一下主要的 cdq 分治代码。

struct Point{
    int id, h, w;
}p[N];
bool operator <(const Point &a, const Point &b){return a.id < b.id;}
il bool cmp(Point a, Point b){return (a.h != b.h) ? (a.h > b.h) : (a.id < b.id);}
void cdq1(int l, int r){
    if (l == r) return ;
    sort (p + l, p + r + 1);// 按照编号排序
    int mid = (l + r) >> 1;
    cdq1(l, mid);
    sort (p + l, p + mid + 1, cmp), sort(p + mid + 1, p + r + 1, cmp);// 按照第一维排序
    clr(1, 1, n);// 清空线段树
    for (int i = l, j = mid + 1; j <= r; j++){// 双指针
        for (; i <= mid && p[i].h >= p[j].h; i++) upd(1, 1, n, p[i].w, f[p[i].id]);
        Tree t = qry(1, 1, n, p[j].w, n);
        t.mx++;
        f[p[j].id] = f[p[j].id] + t;// 转移
    }
    cdq1(mid + 1, r);
}

这个过程很简单,只要理解了上面所说的应该都不难。

完整代码。

#include <bits/stdc++.h>
using namespace std;
#define int long long
#define il inline
#define N 100005
il int rd(){
    int s = 0, w = 1;
    char ch = getchar();
    for (;ch < '0' || ch > '9'; ch = getchar()) if (ch == '-') w = -1;
    for (;ch >= '0' && ch <= '9'; ch = getchar()) s = ((s << 1) + (s << 3) + ch - '0');
    return s * w;
}
int n, nls[N], cnt, maxn;
double sum;
struct Point{
    int id, h, w;
}p[N];
bool operator <(const Point &a, const Point &b){return a.id < b.id;}
il bool cmp(Point a, Point b){return (a.h != b.h) ? (a.h > b.h) : (a.id < b.id);}
struct Tree{
    int mx;double ct;
}tr[N << 2], f[N], g[N];
Tree operator +(const Tree &a, const Tree &b){
    if (a.mx > b.mx) return a;
    if (a.mx < b.mx) return b;
    return Tree{a.mx, a.ct + b.ct};
}
void clr(int p, int l, int r){
    if (!tr[p].mx) return ;
    tr[p] = Tree{0, 0};
    if (l == r) return ;
    int mid = (l + r) >> 1;
    clr(p << 1, l, mid), clr(p << 1 | 1, mid + 1, r);
}
void upd(int p, int l, int r, int x, Tree k){
    if (l == r) return tr[p] = tr[p] + k, void(0);
    int mid = (l + r) >> 1;
    if (x <= mid) upd(p << 1, l, mid, x, k);
    else upd(p << 1 | 1, mid + 1, r, x, k);
    tr[p] = tr[p << 1] + tr[p << 1 | 1];
}
Tree qry(int p, int l, int r, int nl, int nr){
    if (nl <= l && r <= nr) return tr[p];
    int mid = (l + r) >> 1;
    if (nl <= mid && mid < nr) return qry(p << 1, l, mid, nl, nr) + qry(p << 1 | 1, mid + 1, r, nl, nr);
    if (nl <= mid) return qry(p << 1, l, mid, nl, nr);
    return qry(p << 1 | 1, mid + 1, r, nl, nr);
}
void cdq1(int l, int r){
    if (l == r) return ;
    sort (p + l, p + r + 1);
    int mid = (l + r) >> 1;
    cdq1(l, mid);
    sort (p + l, p + mid + 1, cmp), sort(p + mid + 1, p + r + 1, cmp);
    clr(1, 1, n);
    for (int i = l, j = mid + 1; j <= r; j++){
        for (; i <= mid && p[i].h >= p[j].h; i++) upd(1, 1, n, p[i].w, f[p[i].id]);
        Tree t = qry(1, 1, n, p[j].w, n);
        t.mx++;
        f[p[j].id] = f[p[j].id] + t;
    }
    cdq1(mid + 1, r);
}
void cdq2(int l, int r){
    if (l == r) return ;
    sort (p + l, p + r + 1);
    int mid = (l + r) >> 1;
    cdq2(mid + 1, r);
    sort (p + l, p + mid + 1, cmp), sort(p + mid + 1, p + r + 1, cmp);
    clr(1, 1, n);
    for (int i = mid, j = r; i >= l; i--){
        for (; j > mid && p[j].h <= p[i].h; j--) upd(1, 1, n, p[j].w, g[p[j].id]);
        Tree t = qry(1, 1, n, 1, p[i].w);
        t.mx++;
        g[p[i].id] = g[p[i].id] + t;
    }
    cdq2(l, mid);
}
signed main(){
    n = rd();
    for (int i = 1; i <= n; i++) p[i] = Point{i, rd(), rd()}, nls[i] = p[i].w, f[i] = g[i] = Tree{1, 1};
    sort (nls + 1, nls + n + 1);
    cnt = unique(nls + 1, nls + n + 1) - nls - 1;
    for (int i = 1; i <= n; i++) p[i].w = lower_bound(nls + 1, nls + cnt + 1, p[i].w) - nls;
    cdq1(1, n), cdq2(1, n);
    for (int i = 1; i <= n; i++) maxn = max(maxn, f[i].mx);
    printf ("%lld\n", maxn);
    for (int i = 1; i <= n; i++) if (f[i].mx == maxn) sum += f[i].ct;
    for (int i = 1; i <= n; i++) if (f[i].mx + g[i].mx - 1 == maxn) printf ("%.5lf ", f[i].ct * g[i].ct / sum); else printf ("0.00000 ");
    return 0;
}

0 评论

App 内打开
你确定删除吗?
1024
x

© 2018-2025 AcWing 版权所有  |  京ICP备2021015969号-2
用户协议  |  隐私政策  |  常见问题  |  联系我们
AcWing
请输入登录信息
更多登录方式: 微信图标 qq图标 qq图标
请输入绑定的邮箱地址
请输入注册信息