给定 $n$ 个正整数 $a_1, a_2, \dots, a_n$,求这些数异或表示出的数的集合中,第 $k$ 小的是多少(不能选空集)。
事实上和模板几乎一样。
还是考虑从高位到低位贪心确定,首先如果基内有 $c$ 个元素,那么它们可以组合出 $2^c$ 种数。
那么我们只要考虑 $k$ 和 $2^{c-1}$ 的关系就知道这一位是 $0$ 还是 $1$ 了。
具体地,设到当前确定的数为 $x$,到了第 $u$ 位,接下来还有 $c$ 个基要处理(因为有一些位的基可能不存在所以要记录 $c$)。
如果 $x$ 的第 $u$ 位已经是 $1$,那么选上这一位的基底会让 $x$ 更小,不选会更大,根据 $k$ 属于哪边递归求解。
#include <bits/stdc++.h>
using namespace std;
const int N = 70;
int n, m, cnt;
long long x, sum, k, base[N];
bool flag;
bool chk(long long x, int i) { return (x >> i) & 1; }
void insert(long long x) {
for (int i = 60; i >= 0; i--) {
if (!chk(x, i)) continue;
if (!base[i]) return base[i] = x, cnt++, void();
x ^= base[i];
}
flag = 1;
}
long long query(int u, long long x, int k, int c) {
if (u == -1) return x;
if (!base[u]) return query(u - 1, x, k, c);
long long mid = (1ll << c - 1);
if (chk(x, u)) { // 选了更小
if (k <= mid) return query(u - 1, x ^ base[u], k, c - 1);
else return query(u - 1, x, k - mid, c - 1);
} else { // 不选更小
if (k <= mid) return query(u - 1, x, k, c - 1);
else return query(u - 1, x ^ base[u], k - mid, c - 1);
}
}
void solve() {
flag = 0, cnt = 0; for (int i = 0; i <= 60; i++) base[i] = 0ll;
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%lld", &x), insert(x);
scanf("%d", &m);
while (m--) {
scanf("%lld", &k);
if (!flag) k++; // 如果不能异或出 0 就要 +1
if (k > (1ll << cnt)) puts("-1");
else printf("%lld\n", query(60, 0, k, cnt) );
}
}
int main() {
int T; scanf("%d", &T);
for (int Tid = 1; Tid <= T; Tid++) {
printf("Case #%d:\n", Tid);
solve();
}
return 0;
}
这样写代码真的很丑……但有一种更优雅的做法可以让代码变得较为简单。
考虑按照之前的方法求出的基如何再简化:相当于拿高位的基再和低位的基异或,把高位基消掉尽可能多低位的 $1$。
说着挺绕的,但事实上按照高斯消元的思维模式很好理解:
10011
01110
00110
00001
这样的一组基,我们把上面的不断异或掉下面的,就会得到:
10010
01000
00110
00001
这样转化有什么用?
这样每一组基的最高位,在矩阵上的位置设为 $(x,y)$,则 $(x,y)$ 上方一定都是 $0$,因为它之前的基一定可以通过异或它把这一位消掉。
相当于少一层分讨,因为不需要特判 $x$ 的第 $u$ 位是否是 $1$ 了,选了一定更大,不选一定更小。
因此将询问 $k$ 的二进制表示为 $1$ 的那些基底异或起来就是答案。
因为之前 $mid = 2^{c-1}$,所以要把全 $0$ 行删掉再给 $k$ 做二进制分解,也就是删掉 base 中为 $0$ 的元素。
#include <bits/stdc++.h>
using namespace std;
const int N = 70;
int n, m, cnt, tot;
long long x, sum, k, base[N];
bool flag;
bool chk(long long x, int i) { return (x >> i) & 1; }
void insert(long long x) {
for (int i = 60; i >= 0; i--) {
if (!chk(x, i)) continue;
if (!base[i]) return base[i] = x, cnt++, void();
x ^= base[i];
}
flag = 1;
}
void init() {
for (int i = 60; i >= 0; i--)
for (int j = i - 1; j >= 0; j--)
if (chk(base[i], j)) base[i] ^= base[j];
tot = 0;
for (int i = 0; i <= 60; i++)
if (base[i]) base[tot++] = base[i]; // 一定要记得加这个,因为之前 mid = 2^{c-1},所以要把全 0 行删掉
}
long long query(long long k) {
if (k >> cnt) return -1;
long long res = 0;
for (int i = 0; i < tot; i++) res ^= chk(k, i) * base[i];
return res;
}
void solve() {
flag = 0, cnt = 0; for (int i = 0; i <= 60; i++) base[i] = 0ll;
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%lld", &x), insert(x);
init();
scanf("%d", &m);
while (m--) scanf("%lld", &k), printf("%lld\n", query(k - flag) );
}
int main() {
int T; scanf("%d", &T);
for (int Tid = 1; Tid <= T; Tid++) {
printf("Case #%d:\n", Tid);
solve();
}
return 0;
}