题目描述
难度分:$2000$
输入$T(\leq 10^4)$表示$T$组数据。所有数据的$n$之和$\leq 2 \times 10^5$。
每组数据输入$n(1 \leq n \leq 2 \times 10^5)$和一个$0$~$n-1$的排列$p$。
输出$p$有多少个非空连续子数组$b$,满足$mex(b) > median(b)$。
注:$mex(b)$为不在$b$中的最小非负整数,例如$mex([1,0,3])=2$。
注:$median(b)$为$b$的中位数,例如$median([0,1,2,3])=1$。如果有两个中位数,取小的那个。
输入样例
8
1
0
2
1 0
3
1 0 2
4
0 2 1 3
5
3 1 0 2 4
6
2 0 4 1 3 5
8
3 7 2 6 0 1 5 4
4
2 0 1 3
输出样例
1
2
4
4
8
8
15
6
算法
双指针
很有意思的一道题,要想$mex(b) \gt median(b)$成立,那么$[0,med]$都应该在数组$b$当中。而如果$med$是$b$的中位数,那此时已经有$med$个比$med$小的数在数组中($[0,med)$),还需要$med$或者$med+1$个比$med$大的数在数组中才行,这样的话数组$b$的长度就有两种可能:$2 \times med+1$和$2 \times med + 2$。
因此可以从小到大枚举$med \in [0,n)$,计算在中位数为$med$的情况下,有多少个数组满足$mex(b) \gt med$,而要知道有多少个数组满足这个条件,只需要找到最小的那个数组就好了,这个最小的数组左右边界往外扩张到底的所有数组都符合。
对于一个给定的中位数$med$,要求数组$b$要包括这个$med$(为了快速得到$med$的位置,可以先预处理出一个$pos$数组,$pos[i]$表示$i$在排列$p$中的位置),在$med \in [0,n)$的过程中逐渐扩张左右边界。如果对于中位数$med$得到的最小数组是$[l,r]$,那么此时数组的左边界$left \in [1,l]$,数组的右边界$right \in [r,n]$,而$right$根据$b$的长度有两种可能:$right=left+(2 \times med+1)-1$,$right=left+(2 \times med+2)-1$。
因此$1 \leq left \leq l$,$r \leq left + len - 1 \leq n$,得$max(1,r-len+1) \leq left \leq min(l, n-len+1)$,其中$len=2 \times med+1$或$2 \times med+2$。$med$对答案的贡献为$min(l, n-len+1)-max(1,r-len+1)+1$,$med$依次取完$0$到$n-1$的所有值之后把所有贡献加起来就是最终答案。
复杂度分析
时间复杂度
遍历$med \in [0,n)$就统计出了答案,而对于每个$med$,贡献的计算都是$O(1)$,因此整个算法的时间复杂度就是$O(n)$。
空间复杂度
额外空间主要是$pos$数组,它的空间消耗是线性的,因此算法的额外空间复杂度为$O(n)$。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 200010;
int t, n, p[N], pos[N];
void solve() {
int left = 1e9, right = 0;
LL ans = 0;
for(int med = 0; med < n; med++) {
// 要包含med这个数必须要获得区间[left,right]
left = min(left, pos[med]), right = max(right, pos[med]);
// 第一种可能的区间长度
int len = 2*med + 2;
if(right - left + 1 <= len) {
int l = max(1, right - len + 1), r = min(left, n - len + 1);
ans += max(0, r - l + 1);
}
// 第二种可能的区间长度
len = 2*med + 1;
if(right - left + 1 <= len) {
int l = max(1, right - len + 1), r = min(left, n - len + 1);
ans += max(0, r - l + 1);
}
}
printf("%lld\n", ans);
}
int main() {
scanf("%d", &t);
while(t--) {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &p[i]);
pos[p[i]] = i;
}
solve();
}
return 0;
}