算法1:前缀和
时间复杂度:O(n)。
用空间换时间。时间最短。约166 ms。
cnt[i]
表示在A中i这个值出现多少次,前缀和数组s[i] == cnt[0] + cnt[1] + ... + cnt[i]
,表示在
A中,0~i出现了多少次。
根据题目,A中i最大为10^5。
C++版本:
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 100010;
int a[N],b[N],c[N];
int cnt[N],s[N];
int as[N],cs[N];// 分别表示在A、C中有多少个数>、<b[i]
int main(){
int n;
scanf("%d",&n);
for (int i = 0;i < n;i++) scanf("%d",&a[i]),a[i]++;//a[i]可以为0,+1避免前缀和问题
for (int i = 0;i < n;i++) scanf("%d",&b[i]),b[i]++;//a、b、c只需考虑相对大小
for (int i = 0;i < n;i++) scanf("%d",&c[i]),c[i]++;
//注意循环细节
// 求as[N]
for (int i = 0;i < n;i++) cnt[a[i]]++;
for (int i = 1;i < N;i++) s[i] = s[i-1] + cnt[i];// 求cnt[]的前缀和,从1开始避免越界
for (int i = 0;i < n;i++) as[i] = s[b[i]-1];
// 求cs[N]
memset(cnt,0,sizeof cnt);
memset(s,0,sizeof s);
for (int i = 0;i < n;i++) cnt[c[i]]++;
for (int i = 1;i < N;i++) s[i] = s[i-1] + cnt[i];
for (int i = 0;i < n;i++) cs[i] = s[N-1] - s[b[i]];
LL res = 0;
for (int i = 0;i < n;i++) res += (LL)as[i]*cs[i];//一个数转LL就行
// 注意:LL(as[i]*cs[i])是错的,有精度损失
printf("%lld",res);
return 0;
}
Java版本:
参考题解:https://www.acwing.com/solution/content/7392/
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
public class Main{
static int N = 100010;
static int[] a = new int[N];
static int[] b = new int[N];
static int[] c = new int[N];
static int[] acnt = new int[N];//acnt和ccnt开两个数组,因为开一个不能完全覆盖
static int[] ccnt = new int[N];
static int[] as = new int[N];
static int[] cs = new int[N];
static int[] s = new int[N];//s只需开一个,因为能完全覆盖
public static void main(String[] args) throws NumberFormatException,IOException{
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(reader.readLine().trim());//去掉空格、回车
String[] s1 = reader.readLine().split(" ");
String[] s2 = reader.readLine().split(" ");
String[] s3 = reader.readLine().split(" ");
for (int i = 1;i <= n;i++) a[i] = Integer.parseInt(s1[i-1]) + 1;
//String转int
for (int i = 1;i <= n;i++) b[i] = Integer.parseInt(s2[i-1]) + 1;
for (int i = 1;i <= n;i++) c[i] = Integer.parseInt(s3[i-1]) + 1;
for (int i = 1;i <= n;i ++) acnt[a[i]]++;
for (int i = 1;i <= N-1;i ++) s[i]= s[i-1] + acnt[i];
for (int i = 1;i <= n;i ++) as[i] = s[b[i]-1];
for (int i = 1;i <= n;i ++) ccnt[c[i]]++;
for (int i = 1;i <= N-1;i ++) s[i] = s[i-1] + ccnt[i];
for (int i = 1;i <= n;i ++) cs[i] = s[N-1] - s[b[i]];
long res = 0;
for (int i = 1;i <= n;i ++) res += (long)as[i]*cs[i];
System.out.println(res);
}
}
算法2:二分
时间复杂度:O(n*logn)。约703 ms。时间最长。
C++STL版本:
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
int a[N],b[N],c[N];
int main(){
int n;
cin >> n;
for (int i = 0;i < n;i++) scanf("%d",&a[i]);
for (int i = 0;i < n;i++) scanf("%d",&b[i]);
for (int i = 0;i < n;i++) scanf("%d",&c[i]);
sort(a,a+n);
sort(b,b+n);
sort(c,c+n);
LL res = 0;
for (int i = 0;i < n;i++){
int la = lower_bound(a,a+n,b[i]) - a;//在数组a中找比b[i]小的数
int rc = upper_bound(c,c+n,b[i]) - c;//在数组c中找比b[i]大的数
if (la == 0 || rc == n) continue;//这句可以不加,计算会得到0,加上能优化几毫秒
res += LL(la)*(n-rc);
}
printf("%lld",res);
return 0;
}
Java手写版本:
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.Arrays;
public class Main{
static int N = 100010;
static int[] a = new int[N];
static int[] b = new int[N];
static int[] c = new int[N];
public static void main(String[] args) throws NumberFormatException,IOException{
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(reader.readLine().trim());
String[] s1 = reader.readLine().split(" ");
String[] s2 = reader.readLine().split(" ");
String[] s3 = reader.readLine().split(" ");
for (int i = 1;i <= n;i++) a[i] = Integer.parseInt(s1[i-1]);
for (int i = 1;i <= n;i++) b[i] = Integer.parseInt(s2[i-1]);
for (int i = 1;i <= n;i++) c[i] = Integer.parseInt(s3[i-1]);
Arrays.sort(a,1,n + 1);// 对a[1]到a[n]从小到大排序
Arrays.sort(b,1,n + 1);
Arrays.sort(c,1,n + 1);
long res = 0;
// 求满足最小的<=b[i]的下标
for (int i = 1;i <= n;i++){
int la = 0,ra = n+1;
while (la < ra){
int mid = (la + ra) >> 1;
if (a[mid] < b[i]) la = mid + 1;
else ra = mid;
}
// 求满足最小的>=b[i]的下标
int lc = 0,rc = n + 1;
while(lc < rc)
{
int mid = (lc + rc + 1) >> 1;
if(c[mid] <= b[i]) lc = mid;
else rc = mid - 1;
}
if (la == 0 || lc == n+1) continue;
res += (long)(la-1)*(n-lc);
}
System.out.println(res);
}
}
算法3:双指针
时间复杂度:O(n)。约444 ms。
参考题解:https://www.acwing.com/solution/content/19218/
只需要将二分部分修改为双指针就行。
分别对数组a和c进行指针扫描。
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
int a[N],b[N],c[N];
int main(){
int n;
cin >> n;
for (int i = 0;i < n;i++) scanf("%d",&a[i]);
for (int i = 0;i < n;i++) scanf("%d",&b[i]);
for (int i = 0;i < n;i++) scanf("%d",&c[i]);
sort(a,a+n);
sort(b,b+n);
sort(c,c+n);
LL res = 0;
int la = 0,rc = 0;
for (int i = 0;i < n;i++){
int key = b[i];
while (la <= n-1 && a[la] < b[i]) la++;
while (rc <= n-1 && c[rc] <= b[i]) rc++;
res += LL(la)*(n-rc);
}
printf("%lld",res);
return 0;
}