以下讲解出自我的数据结构笔记:我的数据结构笔记
IV. k-D Tree(KDT , k-Dimension Tree)
KDT 是一种特殊的二叉搜索树,可以高效处理 $k$ 维空间信息。
其相对于 CDQ 分治的优点是可以高度模板化,而不需要额外写函数;另外还可以带修。
在算法竞赛中,一般 $k = 2$。下文中如果无特殊说明,$k = 2$。
首先假设平面上有这么一堆点:
我要维护这一堆点。
首先有一个简单的想法:找到一个点,以这个点为分界点,在它左边(横坐标比他小)的分到他的左子树,剩下的分到右子树。左右递归处理。
但是这样做是不对的。比如在上图中我选点,依次选 $F, E, B, D, C, A$,树高直接干到 $O(n)$。这是非常不好的,因为这意味着我查询的时候可能要遍历所有点。
所以想到了另外的一种方法:每次按照横坐标排序,找到横坐标中位数所对应的点。把这个点当做划分点。这样左右两边的点数就相等了(?)。
乍一看,这样每次都能减少一半的点,树高就是 $\log n + O(1)$ 了。但是其实不然。比如你想这样一种情况:假设所有点的横坐标都相等,那么每次有 $n - 1$ 个点被划分到左边。所以还是会被卡成一条链。
因此需要更牛逼的优化:交替建树。
交替建树的思想是这样的:首先按照 $x$ 坐标排序,选择中位数作为划分点。与刚才做法不一样的是,他的左右儿子应该选择按照 $y$ 坐标排序,选择中位数作为划分点。接下来再按照 $x$,再按照 $y$,以此类推。由于排序键值交替变化,所以叫做交替建树。
可以发现,这样做,构建出的 K-D Tree 高度就是 $\log n + O(1)$ 了。
比如刚才那个图中,建树过程是这样的:
建出的树形态是这样的:
于是,KDT 就以划分的方式,维护了平面上的 $n$ 个点。
有人说 KDT 像线段树,但我觉得更像平衡树。因为线段树的非叶节点是不存储信息的,而 KDT 存储信息。另外,KDT 也具有可二分性。对于一个点 $k$,假设其划分依据为 $z(z = 1, 2)$ 维,那么其左边的 $z$ 维小于它,右面的大于它。这也与 BST 更为相似。
KDT 建完了,那么该如何进行操作呢?
说起来非常玄学,KDT 进行操作的方法就像是暴力剪枝。对于一个点来说,对于每个维度 $z(z \in \mathbb{Z}, z \le k)$,需要维护这个维度意义下最靠左的点和最靠右的点。这样相当于把整个平面划分成了若干个矩形。
如果当前子树对应的矩形与所求矩形没有交点,则不继续搜索其子树;如果当前子树对应的矩形完全包含在所求矩形内,返回当前子树内所有点的权值和;否则,判断当前点是否在所求矩形内,更新答案并递归在左右子树中查找答案(这段话来自 OI-wiki,因为写的太好了就直接摘过来了)。
时间复杂度我不会证明啊,有兴趣看 OI-wiki 吧。最后结论是 $T(n) = O(n ^ {1 - \frac{1}{k}})$。
于是这道题就很好写了。直接上个 KDT 矩形求和就可以了。
注意,写这道题的时候,千万不要用类似二维前缀和的形式求四遍矩形和。这样会导致常数爆炸。
需要稍微卡常。
#pragma GCC optimize("Ofast")
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <vector>
#define x first
#define y second
#define max(a, b) (a > b ? a : b)
#define min(a, b) (a < b ? a : b)
#define all(p) p.begin(), p.end()
#define rep(i, a, b) for (register int i = (a); i <= (b); i ++ )
using namespace std;
typedef long long LL;
const int N = 100010;
int n, m, K, rt;
struct node {
int ls, rs, d[2], mn[2], mx[2], v; LL s;
bool operator < (const node &t)const { return d[K] < t.d[K]; }
}tr[N];
#define ls tr[u].ls
#define rs tr[u].rs
inline void chkmin(int &a, int b) { a = min(a, b); }
inline void chkmax(int &a, int b) { a = max(a, b); }
void pushup(int u) {
rep(i, 0, 1) { tr[u].mn[i] = tr[u].mx[i] = tr[u].d[i], tr[u].s = tr[u].v;
if (ls) chkmin(tr[u].mn[i], tr[ls].mn[i]), chkmax(tr[u].mx[i], tr[ls].mx[i]),
tr[u].s += tr[ls].s;
if (rs) chkmin(tr[u].mn[i], tr[rs].mn[i]), chkmax(tr[u].mx[i], tr[rs].mx[i]),
tr[u].s += tr[rs].s;
}
}
inline int build(int l, int r, int k) {
if (l > r) return 0; int u = l + r >> 1;
K = k; nth_element(tr + l, tr + u, tr + r + 1);
ls = build(l, u - 1, k ^ 1); rs = build(u + 1, r, k ^ 1);
pushup(u); return u;
}
bool in(int x, int l, int r) { return x >= l and x <= r; }
inline LL ask(int u, int x1, int y1, int x2, int y2) {
if (tr[u].mn[0] > x2 or tr[u].mn[1] > y2) return 0;
if (tr[u].mx[0] < x1 or tr[u].mx[1] < y1) return 0;
if (in(tr[u].mn[0], x1, x2) and in(tr[u].mx[0], x1, x2)
and in(tr[u].mn[1], y1, y2) and in(tr[u].mx[1], y1, y2)) return tr[u].s;
LL s = ask(ls, x1, y1, x2, y2) + ask(rs, x1, y1, x2, y2);
if (in(tr[u].d[0], x1, x2) and in(tr[u].d[1], y1, y2)) s += tr[u].v; return s;
}
int main() {
scanf("%d%d", &n, &m);
rep(i, 1, n) scanf("%d%d%d", &tr[i].d[0], &tr[i].d[1], &tr[i].v);
int rt = build(1, n, 0); rep(i, 1, m) {
int x1, y1, x2, y2; scanf("%d%d%d%d", &x1, &y1, &x2, &y2);
printf("%lld\n", ask(rt, x1, y1, x2, y2));
} return 0;
}