这种离散化很方便,而且速度也比视频讲解里的要快
红黑树TreeMap维护一个升序哈希表,key是原数据,value是原数据对应的下标,直接a[下标] = i++
即可
Java
/**
* 操作a或b或者a和b结果都一样
* 先进行离散化,缩小值域并且使得两个数组元素完全重叠
* 如果a严格升序,那么结果则为b的逆序对数量
* 但是a不一定升序,所以构造一个等价的升序数组,将a映射过去
* 令a[i] = i,保存映射结果即为c[a[i]] = i,意思是按照顺序把a[i]这个数映射为i,这样,只要我们有初始数字,就可以通过c数组求出对应的映射
* 对于b,要接收这个映射,求出每个b[j]的映射,即b[j] = c[b[j]]
*
*/
import java.io.*;
import java.util.*;
public class Main {
static final int N = (int) 1e5 + 10;
static final int M = 99999997;
static Integer[] a = new Integer[N];
static Integer[] b = new Integer[N];
static Integer[] p = new Integer[N];
static Integer[] c = new Integer[N];
static Map<Integer, Integer> map = new TreeMap<>();
static int n;
static void work(Integer[] a) {
// 相当易懂且速度比下面快的离散化方式
map.clear();
for (int i = 1; i <= n; i++)
map.put(a[i], i);
int i = 1;
for (var integer : map.entrySet()) {
a[integer.getValue()] = i++;
}
// 非常难懂的离散化方式
// for (int i = 1; i <= n; i++) p[i] = i;
// Arrays.sort(p, 1, n + 1, (x, y) -> a[x] - a[y]);
// for (int i = 1; i <= n; i++) a[p[i]] = i;
}
static int merge_sort(int l, int r) {
if (l >= r)
return 0;
int mid = (l + r) / 2;
int res = (merge_sort(l, mid) + merge_sort(mid + 1, r)) % M;
int i = l, j = mid + 1, k = 0;
while (i <= mid && j <= r)
if (b[i] < b[j]) p[k++] = b[i++];
else {
p[k++] = b[j++];
res = (res + mid - i + 1) % M;
}
while (i <= mid)
p[k++] = b[i++];
while (j <= r)
p[k++] = b[j++];
for (i = 0, j = l; i < k; i++, j++)
b[j] = p[i];
return res;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
PrintWriter pr = new PrintWriter(new BufferedWriter(new OutputStreamWriter(System.out)));
n = Integer.parseInt(br.readLine());
String[] s = br.readLine().split(" ");
for (int i = 1; i <= n; i++) {
a[i] = Integer.parseInt(s[i - 1]);
}
s = br.readLine().split(" ");
for (int i = 1; i <= n; i++) {
b[i] = Integer.parseInt(s[i - 1]);
}
// 离散化
work(a);
work(b);
// 映射,
for (int i = 1; i <= n; i++) c[a[i]] = i;
for (int i = 1; i <= n; i++) b[i] = c[b[i]];
pr.println(merge_sort(1, n));
pr.flush();
}
}
C++
#include <iostream>
#include <map>
using namespace std;
const int N = 1e5 + 10, M = 99999997;
int a[N], b[N], c[N], p[N];
map<int, int> mp;
int n;
void work(int a[])
{
mp.clear();
for (int i = 1; i <= n; i++)
mp.insert({a[i], i});
int i = 1;
for (auto m : mp)
a[m.second] = i++;
}
int merge_sort(int l, int r)
{
if (l >= r)
return 0;
int mid = l + r >> 1;
int res = (merge_sort(l, mid) + merge_sort(mid + 1, r)) % M;
int i = l, j = mid + 1, k = 0;
while (i <= mid && j <= r)
{
if (b[i] <= b[j])
p[k++] = b[i++];
else
{
res = (res + mid - i + 1) % M;
p[k++] = b[j++];
}
}
while (i <= mid)
p[k++] = b[i++];
while (j <= r)
p[k++] = b[j++];
for (i = 0, j = l; i < k; i++, j++)
b[j] = p[i];
return res;
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
cin >> n;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i <= n; i++)
cin >> b[i];
work(a), work(b);
for (int i = 1; i <= n; i++)
c[a[i]] = i;
for (int i = 1; i <= n; i++)
b[i] = c[b[i]];
cout << merge_sort(1, n);
return 0;
}