建议先看懂这个视频之后再来食用本帖
题目描述
给定两个有序数组$nums1$和$nums2$,长度分别为$m,n$。请找出它们的中位数,要求时间复杂度为 $O(log(min \{ n, m \} ))$,空间复杂度为$O(1)$。
样例1
nums1 = [1, 3]
nums2 = [2]
中位数是 2.0
样例2
nums1 = [1, 2]
nums2 = [3, 4]
中位数是 (2 + 3) / 2 = 2.5
算法1
(二分查找) $O(log(min \{ n, m \} ))$
本文仅对Y总的题解(即参考文献1
)里的第二种解法作一点补充。
1. 虚拟元素
在Y总的题解里,通过向数组$A_1$加入虚拟元素$@$转换成数组$A’_1$, 即
$ A_1:[1,2,3,4,5] => A’_1:[@,1,@,2,@,3,@,4,@,5,@] $
$\quad$那么原来数组里的分割点
可能在两数之间,也可能在某个数字上,经过这种转化之后,我们
就可以将分割点
用数组$A’_1$的下标来表示了,这方便了枚举分割点
,这是添加虚拟元素这步操作的动机
。
2. 平均分割数组
$\quad$在讲完分割点
坐标化之后,那么问题就变成了:
$\quad$如何在两个原数组里找到两个分割点来平均地分割原数组,也就是使得两个原数组里的元素可以被平均地分配在两个分割点的左右两边?。
$\quad$在Y总的题解里,证明$C_1+C_2=N_1+N_2$时,他在证明过程里提到:”除了$C_1$和$C_2$以外,共有$2N_1+2N_2$个元素,要平均分配到左右两边,因此左边共有$N_1+N_2$个元素.”。这里的做法只是平均地分割了虚拟元素数组($A’_1$和$A’_2$),而不是上面说的平均地分割原数组
(即$A_1$和$A_2$)。
$\quad$可以证明:
$\qquad$若$C_1,C_2$平均地分割了虚拟元素数组$A’_1$和$A’_2$,那么$C_1,C_2$一定也同时平均分割了原数组$A_1$和$A_2$。
证明:
1. 首先,两个原数组长度都是奇数时,对应的$C_1$和$C_2$均为虚拟元素。我们采用整体法
,将在分割点前的元素两两打包为@,数字
,在分割点之后的元素两两打包为数字,@
,这两种打包元素
并没有本质上的区别,打包元素
的数量总共是$N_1+N_2$,为偶数。那么平均分配虚拟元素数组$A’_1$和$A’_2$就等价于平均地分割原数组
(即$A_1$和$A_2$),因为两个分割点的左侧均不存在半个打包元素
。证明两个原数组的长度都是偶数的情况与此类似,就不再赘述了。
2. 当一个原数组长度为奇数,另一个原数组长度为偶数时。长度为奇数的原数组的分割点
在数组元素上;长度为偶数的原数组的分割点
在数组元素之间。因此在$A’_1$和$A’_2$中,对应的两个分割点
$C_1$和$C_2$一个是数字,一个是虚拟元素。假定$C_1$为数字,$C_2$为虚拟元素,那么除了在分割点$C_1$左右相邻的两个$@$单独作为两个半个打包元素
之外,剩下的元素从用上面相同的打包方式(即分割点前的元素两两打包为@,数字
,在分割点之后的元素两两打包为数字,@
)。$C_1$和$C_2$两个分割点左边总共有整数个打包元素,然后我们又平均分割了虚拟数组元素$A’_1$和$A’_2$,很显然,我们也同时平均分割了两个原数组$A_1$和$A_2$。 证毕!
3. 计算中位数
$\quad$解决完如何分割
之后,接下来的问题就是:
$\quad$如何计算中位数?
$\quad$Y总的题解里给出了进一步的条件,也即除了需要满足$C_1+C_2=N_1+N_2$之外,还需要满足:
$$ L_1 \le R_1 \&\& L_1 \le R_2 \&\& L_2 \le R_1 \&\& L_2 \le R_2 $$
值得注意的是,在Y总的题解里,对$L_1,R_1,L_2,R_2$的定义中,都包含了分割点。在$C_1,C_2$均为
虚拟元素时,正确性很容易理解。而当其中一个分割点是数字,另一个是虚拟元素时,包含分割点的做法相当于一个有序数组,数组长度为奇数时,将中位数同时切分给左右两边一边一个。
如[2,4,6],切分后左边为[2,4], 右边为[4,6]。这样就方便统一使用 $median=(max(L_1,L_2) + min(R_1,R_2)) / 2.0$ 公式来计算中位数。需要注意的是:我们在前面讨论平均分割问题
时,当分割点$C_1$为数字,$C_2$为虚拟元素时,我们没有将分割点$C_1$上的数字一分为二左右各一份,因为平均分割问题
关注的重点是两个分割点是不是将原数组里的数字平均地分割在两个分割点的两侧,无论分割点是不是数字
。只要做到了平均分割原数组
,$C_1$、$C_2$就有可能是我们要找的分割点。这样就将分割
与计算
分离开来,也即先考虑如何正确分割,正确分割之后再考虑如何计算。
4.扩展问题讨论
-
注意这个算法也可以用来找出两个有序数组的
上中位数
,此时公式为$median=max(L_1,L_2)$。感兴趣的可以用牛客网的 在两个长度相等的排序数组中找到上中位数 这道题练练手。 -
注意
:代码里提供了更加通用的版本$Kth$版(即寻找到这两个有序数组的第$K$小的数),二分查找时的边界推导如下:
$\quad$类比前面的$C_1+C_2=N_1+N_2$,此时公式应为:$C_1+C_2=2K-1$。又因为$0 \le C_1 \le 2N_1$且$0 \le C_2 \le 2N_2$,容易推得:$max(0, 2K-1-2N_1) \le C_2 \le min(2N_2,2K-1)$。
很显然,算法复杂度依旧为$O(log(min \{ n, m \} ))$。搞懂之后可以用 多数组第 K 小数 练习一下。
参考文献
C++ 代码
1. 中位数版之原始版
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
return getMedian(nums1, nums2);
}
double getMedian(vector<int>& a, vector<int>& b){
const int N1 = a.size();
const int N2 = b.size();
if (N1 < N2) return getMedian(b, a);
const int N0 = N1 + N2;
int lo = 0, hi = N2 * 2;
while (lo <= hi) {
int mid2 = (lo + hi) / 2;
int mid1 = N0 - mid2;
int L1 = (mid1 == 0) ? INT_MIN : a[(mid1-1)/2];
int L2 = (mid2 == 0) ? INT_MIN : b[(mid2-1)/2];
int R1 = (mid1 == N1 * 2) ? INT_MAX : a[(mid1)/2];
int R2 = (mid2 == N2 * 2) ? INT_MAX : b[(mid2)/2];
if (L1 > R2) lo = mid2 + 1;
else if (L2 > R1) hi = mid2 - 1;
else return (max(L1,L2) + min(R1, R2)) / 2.0;
}
return -1.0;
}
};
2. 中位数版之始末坐标版
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
return getMedian(nums1, 0, nums1.size() - 1, nums2, 0, nums2.size() - 1);
}
double getMedian(vector<int>& a, int start1, int end1, vector<int>& b, int start2, int end2){
const int N1 = end1 - start1 + 1;
const int N2 = end2 - start2 + 1;
if(N1 < N2) return getMedian(b, start2, end2, a, start1, end1);
const int N0 = N1 + N2;
int low = 0, high = 2 * N2;
while(low <= high){
int mid2 = (low + high) >> 1;
int mid1 = N0 - mid2;
int L1 = (mid1 == 0) ? INT_MIN : a[start1 + (mid1 - 1) / 2];
int L2 = (mid2 == 0) ? INT_MIN : b[start2 + (mid2 - 1) / 2];
int R1 = (mid1 == 2 * N1) ? INT_MAX : a[start1 + mid1 / 2];
int R2 = (mid2 == 2 * N2) ? INT_MAX : b[start2 + mid2 / 2];
if(L1 > R2){
low = mid2 + 1;
}
else if(L2 > R1){
high = mid2 - 1;
}
else {
return (max(L1, L2) + min(R1, R2)) / 2.0;
}
}
return -1.0;
}
};
3. $Kth$版
3.1 $Kth$版之原始版
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
const int n = nums1.size();
const int m = nums2.size();
if((n + m) & 1){
return getKthNum((n + m + 1) / 2, nums1, nums2);
}
return (getKthNum((n + m) / 2, nums1, nums2) + getKthNum(1 + (n + m) / 2, nums1, nums2)) / 2.0;
}
// find Kth small number, e.g. [1,3,5], 1st small number is 1, 2nd small number is 3.
int getKthNum(int kth, vector<int>& a, vector<int>& b){
const int N1 = a.size();
const int N2 = b.size();
if (N1 < N2) return getKthNum(kth, b, a);
const int N0 = 2 * kth - 1;
int lo = max(0, N0 - 2 * N1), hi = min(2 * N2, N0);
while (lo <= hi) {
int cut2 = (lo + hi) / 2;
int cut1 = N0 - cut2;
int L1 = (cut1 == 0) ? INT_MIN : a[(cut1-1)/2];
int L2 = (cut2 == 0) ? INT_MIN : b[(cut2-1)/2];
int R1 = (cut1 == N1 * 2) ? INT_MAX : a[(cut1)/2];
int R2 = (cut2 == N2 * 2) ? INT_MAX : b[(cut2)/2];
if (L1 > R2) lo = cut2 + 1;
else if (L2 > R1) hi = cut2 - 1;
// also could be `return max(L1,L2);`
else return (max(L1,L2) + min(R1, R2)) / 2;
}
return -1;
}
};
3.2 $Kth$版之始末坐标版
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
const int n = nums1.size();
const int m = nums2.size();
if((n + m) & 1){
return getKthNum((n + m + 1) / 2, nums1, 0, n - 1, nums2, 0, m - 1);
}
return (getKthNum((n + m) / 2, nums1, 0, n - 1, nums2, 0, m - 1) + getKthNum(1 + (n + m) / 2, nums1, 0, n - 1, nums2, 0, m - 1)) / 2.0;
}
// find Kth small number, e.g. [1,3,5], 1st small number is 1, 2nd small number is 3.
int getKthNum(int kth, vector<int>& a, int start1, int end1, vector<int>& b, int start2, int end2){
const int N1 = end1 - start1 + 1;
const int N2 = end2 - start2 + 1;
if(N1 < N2) return getKthNum(kth, b, start2, end2, a, start1, end1);
const int N0 = 2 * kth - 1;
int low = max(0, N0 - 2 * N1), high = min(2 * N2, N0);
while(low <= high){
int cut2 = (low + high) >> 1;
int cut1 = N0 - cut2;
int L1 = (cut1 == 0) ? INT_MIN : a[start1 + (cut1 - 1) / 2];
int L2 = (cut2 == 0) ? INT_MIN : b[start2 + (cut2 - 1) / 2];
int R1 = (cut1 == 2 * N1) ? INT_MAX : a[start1 + cut1 / 2];
int R2 = (cut2 == 2 * N2) ? INT_MAX : b[start2 + cut2 / 2];
if(L1 > R2){
low = cut2 + 1;
}
else if(L2 > R1){
high = cut2 - 1;
}
else {
// also could be `return max(L1,L2);`
return (max(L1, L2) + min(R1, R2)) / 2;
}
}
return -1;
}
};