题目描述
难度分:$2400$
输入$T(\leq 10^4)$表示$T$组数据。所有数据的$n$之和$\leq 3 \times 10^5$。
每组数据输入$n$,$k(1 \leq k \leq n \leq 3 \times 10^5)$和长为$n$的数组$a(1 \leq a[i] \leq 10^9)$。
- 从$a$中找到一个长度恰好为$k$的子序列$b$。
- 选定一个长度$L$,将$b$划分成长为$L$的前缀$p$,长为$k-L$的后缀$q$,其中$0 \leq L \leq k$。
- 最小化$max(sum(p), sum(q))$。
输出这个最小值。
注:子序列不一定连续。
输入样例
6
5 4
1 10 1 1 1
5 3
1 20 5 15 3
5 3
1 20 3 15 5
10 6
10 8 20 14 3 8 6 4 16 11
10 5
9 9 2 13 15 19 4 9 13 12
1 1
1
输出样例
2
6
5
21
18
1
算法
二分答案+前后缀分解
要求的是最大值的最小值,很容易就往二分答案上想,再把它“合理化”一点就基本确定了做法。因为$a$数组中的所有元素都是正数,肯定是越加越多,$p$或$q$为空直接就使$max(sum(p), sum(q))$最大化了,所以要最大是很容易达成的。
然后我们考虑一下在二分时怎么去$check$一个给定的最小$limit=max(sum(p), sum(q))$能够达成,也很容易想到前后缀分解,定义两个状态:
- $pre[i]$表示从前缀$[1,i]$中选择若干元素,使得其累加和不超过$limit$的最长子序列长度。
- $suf[i]$表示从后缀$[i,n]$中选择若干元素,使得其累加和不超过$limit$的最长子序列长度。
这只需要正序+倒序遍历数组$a$,做两遍反悔贪心就可以预处理出来,用一个大根堆存当前算入累加和$sum$中的元素,只要发现$sum \gt limit$了,就从堆中删除那个最大的元素,把$sum$降至不超过$limit$。而只要存在一个$i \in [1,n]$,使得$pre[i]+suf[i+1] \geq k$成立。就肯定能凑出一个长度为$k$的子序列,使得$max(sum(p), sum(q)) \leq limit$。
复杂度分析
时间复杂度
二分的下界为$0$,上界为$A=\Sigma_{i=1}^{n}a[i]$,因此二分的时间复杂度为$O(log_2A)$。对于一个给定的最小值$mid$,需要进行前后缀分解,但在遍历的过程中还需要利用大根堆来进行反悔贪心,因此$check$的时间复杂度是$O(nlog_2n)$。
综上,算法整体的时间复杂度为$O(nlog_2nlog_2A)$。
空间复杂度
空间消耗主要在于$check$函数中的前后缀信息数组,以及反悔贪心过程中的大根堆。它们的空间都是$O(n)$的,所以算法整体的额外空间复杂度就是$O(n)$。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
using namespace std;
typedef long long LL;
const int N = 300010;
int t, n, k, a[N], pre[N], suf[N];
bool check(LL limit) {
priority_queue<int> heap;
LL presum = 0;
for(int i = 1; i <= n; i++) {
presum += a[i];
heap.push(a[i]);
if(presum > limit) {
presum -= heap.top();
heap.pop();
}
pre[i] = heap.size();
}
while(!heap.empty()) heap.pop();
LL sufsum = 0;
suf[n + 1] = 0;
for(int i = n; i >= 1; i--) {
sufsum += a[i];
heap.push(a[i]);
if(sufsum > limit) {
sufsum -= heap.top();
heap.pop();
}
suf[i] = heap.size();
}
for(int i = 1; i <= n; i++) {
if(pre[i] + suf[i + 1] >= k) return true;
}
return false;
}
int main() {
scanf("%d", &t);
while(t--) {
scanf("%d%d", &n, &k);
LL tot = 0;
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
tot += a[i];
}
LL l = 0, r = tot;
while(l < r) {
LL mid = l + ((r - l)>>1);
if(check(mid)) {
r = mid;
}else {
l = mid + 1;
}
}
printf("%lld\n", r);
}
return 0;
}