题目描述
给你一个下标从 0 开始的整数数组 nums
。如果一对整数 x
和 y
满足以下条件,则称其为 强数对:
|x - y| <= min(x, y)
你需要从 nums
中选出两个整数,且满足:这两个整数可以形成一个强数对,并且它们的按位异或(XOR
)值是在该数组所有强数对中的 最大值。
返回数组 nums
所有可能的强数对中的 最大 异或值。
注意,你可以选择同一个整数两次来形成一个强数对。
样例
输入:nums = [1,2,3,4,5]
输出:7
解释:数组 nums 中有 11 个强数对:
(1, 1), (1, 2), (2, 2), (2, 3), (2, 4), (3, 3),
(3, 4), (3, 5), (4, 4), (4, 5) 和 (5, 5)。
这些强数对中的最大异或值是 3 XOR 4 = 7。
输入:nums = [10,100]
输出:0
解释:数组 nums 中有 2 个强数对:(10, 10) 和 (100, 100)。
这些强数对中的最大异或值是 10 XOR 10 = 0,
数对 (100, 100) 的异或值也是 100 XOR 100 = 0。
输入:nums = [500,520,2500,3000]
输出:1020
解释:数组 nums 中有 6 个强数对:(500, 500), (500, 520),
(520, 520), (2500, 2500), (2500, 3000) 和 (3000, 3000)。
这些强数对中的最大异或值是 500 XOR 520 = 1020;
另一个异或值非零的数对是 (5, 6),其异或值是 2500 XOR 3000 = 636。
限制
1 <= nums.length <= 5 * 10^4
1 <= nums[i] <= 2^20 - 1
算法
(字典树) $O(n \log n + n \log L)$
- 考虑表达式 $|x - y| \le \min(x, y)$,不妨假设 $x >= y$,则可以推导出 $y \le x \le 2y$。
- 将数组从小到大排序,按顺序遍历每个数字作为 $x$,在已遍历过的数字中找到合法的 $y$,使得 $x \text{ XOR } y$ 最大。
- 遍历时,将当前数字从高位到低位插入到 01 字典树中(深度为 20),然后考虑已遍历过的数字,如果发现存在数字 $x > 2y$,则将 $y$ 从字典树中删除。
- 求解最大的异或值,可以通过在每一层选择尽可能与 $x$ 相反的路径(如果存在)。
时间复杂度
- 排序的时间复杂度为 $O(n \log n)$。
- 遍历数组,每个数字都需要 $O(\log L)$ 的时间插入、删除以及查询字典树。其中 $L$ 为最大的数字。
- 故总时间复杂度为 $O(n \log n + n \log L)$。
空间复杂度
- 需要 $O(\log n + n \log L)$ 的额外空间存储排序的系统栈和字典树。
C++ 代码
struct Node {
Node *nxt[2];
int cnt;
Node() {
nxt[0] = nxt[1] = NULL;
cnt = 0;
}
};
class Solution {
private:
Node *root;
void insert(int x) {
Node *p = root;
for (int i = 19; i >= 0; i--) {
int t = (x >> i) & 1;
if (p->nxt[t] == NULL)
p->nxt[t] = new Node();
p = p->nxt[t];
++p->cnt;
}
}
void remove(int x) {
Node *p = root;
for (int i = 19; i >= 0; i--) {
int t = (x >> i) & 1;
p = p->nxt[t];
--p->cnt;
}
}
int find(int x) {
Node *p = root;
int res = 0;
for (int i = 19; i >= 0; i--) {
int t = (x >> i) & 1;
if (p->nxt[t ^ 1] && p->nxt[t ^ 1]->cnt > 0) {
res |= 1 << i;
p = p->nxt[t ^ 1];
} else {
p = p->nxt[t];
}
}
return res;
}
public:
int maximumStrongPairXor(vector<int>& nums) {
sort(nums.begin(), nums.end());
const int n = nums.size();
root = new Node();
int ans = 0;
for (int i = 0, j = 0; i < n; i++) {
insert(nums[i]);
while (j <= i && nums[i] > 2 * nums[j]) {
remove(nums[j]);
++j;
}
ans = max(ans, find(nums[i]));
}
return ans;
}
};