题目描述
给你一个大小为 m x n
的整数矩阵 grid
和一个大小为 k
的数组 queries
。
找出一个大小为 k
的数组 answer
,且满足对于每个整数 queres[i]
,你从矩阵 左上角 单元格开始,重复以下过程:
- 如果
queries[i]
严格 大于你当前所处位置单元格,如果该单元格是第一次访问,则获得 1 分,并且你可以移动到所有4
个方向(上、下、左、右)上任一 相邻 单元格。 - 否则,你不能获得任何分,并且结束这一过程。
在过程结束后,answer[i]
是你可以获得的最大分数。注意,对于每个查询,你可以访问同一个单元格 多次。
返回结果数组 answer
。
样例
输入:grid = [[1,2,3],[2,5,7],[3,5,1]], queries = [5,6,2]
输出:[5,8,1]
解释:上图展示了每个查询中访问并获得分数的单元格。
输入:grid = [[5,2,1],[1,1,2]], queries = [3]
输出:[0]
解释:无法获得分数,因为左上角单元格的值大于等于 3。
限制
m == grid.length
n == grid[i].length
2 <= m, n <= 1000
4 <= m * n <= 10^5
k == queries.length
1 <= k <= 10^4
1 <= grid[i][j], queries[i] <= 10^6
算法
(离线排序,并查集) $O(mn \log (mn) + k \log k)$
- 每个查询相当于找到与左上角的连通块的大小,满足连通块内所有元素的值都严格小于
queries[i]
。 - 将矩阵的遍历下标按元素从小到大排序,将询问的下标按询问的值从小到大排序。
- 初始化并查集并记录连通集的大小。
- 对于排序后每个询问 $q$,按顺序遍历矩阵下标,并将满足条件的位置进行四连通合并。操作完毕后与左上角位置连通的集合大小就是当前询问的答案。
时间复杂度
- 排序的时间复杂度为 $O(mn \log (mn) + k \log k)$。
- 统计答案的时间复杂度为 $O(mn + k)$。
- 故总时间复杂度为 $O(mn \log (mn) + k \log k)$。
空间复杂度
- 需要 $O(mn + k)$ 的额外空间存储并查集,矩阵下标,询问下标和答案。
C++ 代码
const int dx[] = {0, 1, 0, -1};
const int dy[] = {1, 0, -1, 0};
class Solution {
private:
vector<int> f, sz;
int find(int x) {
return x == f[x] ? x : f[x] = find(f[x]);
}
void merge(int x, int y) {
int fx = find(x), fy = find(y);
if (fx == fy)
return;
if (sz[fx] < sz[fy]) {
sz[fy] += sz[fx];
f[fx] = fy;
} else {
sz[fx] += sz[fy];
f[fy] = fx;
}
}
public:
vector<int> maxPoints(vector<vector<int>>& grid, vector<int>& queries) {
const int m = grid.size(), n = grid[0].size(), k = queries.size();
f.resize(m * n);
sz.resize(m * n, 1);
vector<int> idx(m * n);
for (int i = 0; i < m; i++)
for (int j = 0; j < n; j++) {
f[i * n + j] = i * n + j;
idx[i * n + j] = i * n + j;
}
sort(idx.begin(), idx.end(), [&](int x, int y) {
return grid[x / n][x % n] < grid[y / n][y % n];
});
vector<int> q(k);
for (int i = 0; i < k; i++)
q[i] = i;
sort(q.begin(), q.end(), [&](int x, int y) {
return queries[x] < queries[y];
});
vector<int> ans(k, 0);
for (int i = 0, j = 0; i < k; i++) {
while (j < m * n) {
int x = idx[j] / n, y = idx[j] % n;
if (grid[x][y] >= queries[q[i]])
break;
for (int t = 0; t < 4; t++) {
int tx = x + dx[t], ty = y + dy[t];
if (tx < 0 || tx >= m || ty < 0 || ty >= n)
continue;
if (grid[tx][ty] >= queries[q[i]])
continue;
merge(idx[j], tx * n + ty);
}
j++;
}
if (grid[0][0] < queries[q[i]])
ans[q[i]] = sz[find(0)];
}
return ans;
}
};