题目描述
给定三个整数数组
A=[A1,A2,…AN],
B=[B1,B2,…BN],
C=[C1,C2,…CN],
请你统计有多少个三元组 (i,j,k)
满足:
1≤i,j,k≤N
Ai<Bj<Ck
输入格式
第一行包含一个整数 N。
第二行包含 N 个整数 A1,A2,…AN。
第三行包含 N 个整数 B1,B2,…BN。
第四行包含 N 个整数 C1,C2,…CN。
输出格式
一个整数表示答案。
数据范围
1 ≤ N ≤ 105,
0 ≤ Ai,Bi,Ci ≤ 105
样例
输入样例:
3
1 1 1
2 2 2
3 3 3
输出样例:
27
算法1
(暴力枚举) $O(n^3)$
超时,要在此基础上做优化。
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 100010;
int n;
int a[N], b[N], c[N];
int main()
{
cin >> n;
for (int i = 0; i < n; i ++ ) cin >> a[i];
for (int i = 0; i < n; i ++ ) cin >> b[i];
for (int i = 0; i < n; i ++ ) cin >> c[i];
int res = 0;
for (int i = 0; i < n; i ++ )
for (int j = 0; j < n; j ++ )
for (int k = 0; k < n; k ++ )
if (a[i] < b[j] && b[j] < c[k])
res ++ ;
cout << res;
return 0;
}
算法2
(双指针) $O(n)$
C++ 代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int N = 100010;
int n;
int num[3][N];
int main()
{
cin >> n;
for (int i = 0; i < 3; ++ i)
for (int j = 1; j <= n; ++ j)
scanf("%d", &num[i][j]);
for (int i = 0; i < 3; i ++ )
sort(num[i] + 1, num[i] + n + 1);
LL res = 0;
int a = 1, c = 1;
for (int i = 1; i <= n; ++ i)
{
int key = num[1][i];
while (a <= n && num[0][a] < key) a ++ ;
while (c <= n && num[2][c] <= key) c ++ ;
res += (LL)(a - 1) * (n - c + 1);
}
cout << res;
return 0;
}
算法2
(前缀和) $O(n)$
C++ 代码
#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
typedef long long LL;
const int N = 1e5+10;
int a[N], b[N], c[N];
int cnt[N], s[N], sa[N], sc[N];
int main()
{
int n;
cin >> n;
for (int i = 0; i < n; i ++ ) scanf("%d", &a[i]), a[i] ++ ;
for (int i = 0; i < n; i ++ ) cin >> b[i], b[i] ++ ;
for (int i = 0; i < n; i ++ ) cin >> c[i], c[i] ++ ;
for (int i = 0; i < n; i ++ ) cnt[a[i]] ++ ;
for (int i = 1; i < N; i ++ ) s[i] = s[i - 1] + cnt[i];
for (int i = 0; i < n; i ++ ) sa[i] = s[b[i] - 1];
memset(cnt, 0, sizeof cnt);
memset(s, 0, sizeof s);
for (int i = 0; i < n; i ++ ) cnt[c[i]] ++ ;
for (int i = 1; i < N; i ++ ) s[i] = s[i - 1] + cnt[i];
for (int i = 0; i < n; i ++ ) sc[i] = s[N - 1] - s[b[i]];
LL res = 0;
for (int i = 0; i < n; i ++ ) res += (LL)sa[i] * sc[i];
cout << res << endl;
return 0;
}