题目描述
难度分:$2000$
输入$n(3 \leq n \leq 3 \times 10^5)$和长为$n$的数组$x(-10^8 \leq x[i] \leq 10^8)$,然后输入 $k(2 \leq k \leq n-1)$。
一维数轴上有$n$个点,第$i$个点位于$x[i]$。从中选择$k$个点,最小化这$k$个点的两两距离之和。
输出这$k$个点的下标(按照输入顺序,从$1$开始)。多解输出任意解。
输入样例
3
1 100 101
2
输出样例
2 3
算法
前缀和+滑动窗口
比较容易发现的一点就是需要将所有的$x$先排序,最优解一定可以从有序$x$的某一个子数组中产生。接下来滑动长度为$k$的窗口,维护窗口内所有点两两距离和的最小值就可以了。
问题就在于如何高效求取子数组内所有点的两两距离之和?可以先预处理一个数组$s$,其中$s[i]=\Sigma_{j=1}^{i}x[j]$。再预处理出一个数组$ss$,其中$ss[i]$表示$i$与之前所有点的距离之和,即
$ss[i]=(i-1) \times x[i] - \Sigma_{j=1}^{i-1}x[j]=(i-1) \times x[i] - s[i-1]$
然后原地在$ss$上求前缀和,这样它子数组$[l,r]$(满足$r-l+1=k$)的累加和就表示对于所有$i \in [l,r]$,$i$与它前面所有点的距离之和。但是这并不是我们要求的$[l,r]$内部所有点的两两距离之和,我们还需要对每个点$i$去掉它们与$[1,l-1]$所有点的距离之和,即$(l - 1) \times x[i] - \Sigma_{j=1}^{l-1}x[j]=(l - 1) \times x[i] - s[l-1]$。$[l,r]$内部所有点的两两之间距离和就是
$mn=ss[r]-ss[l-1]-\Sigma_{i=l}^{r}((l - 1) \times x[i] - s[l-1])$
$=ss[r]-ss[l-1]-((l-1) \times (s[r]-s[l-1])-(r-l+1) \times s[l-1])$
维护$mn$的最小值,然后输出其对应的$k$个点编号就行了。
复杂度分析
时间复杂度
求前缀和,以及滑动窗口维护最小距离和的操作都是线性的,时间复杂度为$O(n)$。因此,算法的瓶颈就在于最初的排序,时间复杂度是$O(nlog_2n)$。
空间复杂度
开辟了两个前缀和数组$s$和$ss$,因此额外空间复杂度为$O(n)$。
C++ 代码
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 300010;
int n, k;
LL s[N], ss[N];
struct Node {
int index, x;
bool operator<(const Node other) const {
return x < other.x;
}
} node[N];
int main() {
scanf("%d", &n);
for(int i = 1; i <= n; i++) {
scanf("%d", &node[i].x);
node[i].index = i;
}
scanf("%d", &k);
sort(node + 1, node + n + 1);
for(int i = 1; i <= n; i++) {
s[i] = s[i - 1] + node[i].x;
}
for(int i = 1; i <= n; i++) {
ss[i] = ss[i - 1] + (i - 1LL)*node[i].x - s[i - 1];
}
LL mn = ss[k];
for(int r = k + 1; r <= n; r++) {
int l = r - k + 1;
LL temp = ss[r] - ss[l - 1] - ((l - 1LL)*(s[r] - s[l - 1]) - (r - l + 1)*s[l - 1]);
mn = min(mn, temp);
}
for(int r = k + 1; r <= n; r++) {
int l = r - k + 1;
if(mn == ss[r] - ss[l - 1] - ((l - 1LL)*(s[r] - s[l - 1]) - (r - l + 1)*s[l - 1])) {
vector<int> pos;
for(int i = l; i <= r; i++) {
pos.push_back(node[i].index);
}
sort(pos.begin(), pos.end());
for(int x: pos) printf("%d ", x);
exit(0);
}
}
vector<int> pos;
for(int i = 1; i <= k; i++) {
pos.push_back(node[i].index);
}
sort(pos.begin(), pos.end());
for(int x: pos) printf("%d ", x);
return 0;
}