头像

心里没有一点AC数

兰州大学 - 网易浚源工作室 - 西山居 - $\href{https://www.fogsail.net}{Blog}$




离线:2天前


最近来访(424)
用户头像
活在梦里_3
用户头像
daybreak
用户头像
Susu
用户头像
秦淮岸灯火阑珊
用户头像
寻求debug锦囊
用户头像
也.
用户头像
itdef
用户头像
hansk
用户头像
1564269628
用户头像
好好吃饭好好睡觉
用户头像
mo18183
用户头像
Aleksee
用户头像
-账号已注销-
用户头像
涤生
用户头像
淮南之橘丶
用户头像
𓆡𓆝𓆟𓆜𓆞_18
用户头像
我舅是太阳
用户头像
Sora_skyline
用户头像
delicacy
用户头像
whlbest


拆点最短路

要求从 $s \to t$ 的最便宜路径,不难想到将油价作为边权,由于油箱容量 $c \in [0, C]$
油箱容量 $c$ 也应该作为状态维度
$(u, c)$ 二元组表示当前汽车在点 $u$ 处,并且油箱有油 $c$ 单位

对于当前点 $v$,根据在 $v$ 处是否加油,存在 $2$ 种 $(u, v)$ 的状态转移

  • $(u, \ c+d(u, v)) \xrightarrow{0} (v, c)$,此时 $c + d(u, v) \leqslant C$
  • $(v, c-1) \xrightarrow{p_v} (v, c)$,在 $(v, c-1)$ 和 $(v, c)$ 间连一条权值为 $p_v$ 的边
    实际上这里将 $v$ 这个点拆成 $2$ 个状态,当然 $c-1 \geqslant 0 \ \textbf{and} \ c \leqslant C$

这样可以以 $(s, 0)$ 作为起点 $S$,执行 $\text{dijkstra}$,求出单源最短路
$\forall u, c: \quad d[(u, c)]$,最后答案就是 $\displaystyle\min_{c \in [0, C]} d(t, c)$

如果用哈希实现,将结构体作为 $key$,代码如下

class A {
public:
    int u, c;
    A(int u = 0, int c = 0) : u(u), c(c) {}
    bool operator< (const A &rhs) const {
        return u < rhs.u || (u == rhs.u && c < rhs.c);
    }
};

struct hashfun {
    std::size_t operator() (const A &a) const {
        return ( (hash<int>()(a.u))
            ^ (hash<int>()(a.c) << 1) );
    }
};

struct eq {
    bool operator() (const A &lhs, const A &rhs) const {
        return lhs.u == rhs.u && lhs.c == rhs.c;
    }
};

unordered_map<A, int, hashfun, eq> mp;

int get(const A &a) {
    return mp[a] ? mp[a] : (mp[a] = ++n);
}

如果本例先建出隐式图,再进行 $\text{dijkstra}$ 的话,每次都 $\text{build}$ 会造成算法开销过大
实际上,我们在 $\text{dijkstra}$ 的时候没有必要把整张图都建出来,因为有一些节点根本用不到
$\text{dijkstra}$ 的时候优先队列维护二元组 $(d[(u, c)], (u, c))$,对于当前状态 $(u, c)$
检查 $\forall \ v: \ (u, v) \in E$,对于边 $\forall v : \ (u, v)$,存在转移

  • $(u, c) \xrightarrow{p_u} (u, c+1)$,当然 $c+1 \leqslant C$,此时转移的权值代价为 $p_u$
  • $(u, c) \xrightarrow{0} (v, c-e(u, v))$,必须 $c-e(u, v) \geqslant 0$,转移权值为 $0$

起点为 $S = (s, 0)$,最后的答案为
$\forall c: \ \displaystyle\min_{c \in [0, C]}d[(t, c)]$
另外,由于 $\text{dijkstra}$ 不会有负权边,所以取出二叉堆堆顶元素 $(u, c)$,如果发现 $u = t$
那么应该立即结束 $\text{dijkstra}$,不需要继续扩展了,后续点的代价 $d$ 一定比 $d[(u, c)]$ 大
后续节点也用不到

const int maxn = 2e3 + 10, inf = 0x3f3f3f3f;
const int maxm = 2e4 + 10;

int n, m, C, s, t, p[maxn], res = inf;

namespace Graph {
    int idx = 1;
    int head[maxn], ver[maxm], e[maxm], ne[maxm];

    void add(int x, int y, int z) {
        ver[++idx] = y, e[idx] = z, ne[idx] = head[x], head[x] = idx;
    }

    void init() {
        idx = 1;
        memset(head, 0, sizeof head);
    }
};

class A {
public:
    int u, c;
    A(int u, int c) : u(u), c(c) {}
    bool operator< (const A &rhs) const {
        return u < rhs.u || (u == rhs.u && c < rhs.c);
    }
};

struct hashfun {
    int operator() (const A &a) const {
        return a.u * 1003 + a.c;
    }
};

struct eqfun {
    bool operator() (const A &lhs, const A &rhs) const {
        return lhs.u == rhs.u && lhs.c == rhs.c;
    }
};

unordered_map<A, int, hashfun, eqfun> mp;
int tot = 0;

inline int get(const A &a) {
    return mp[a] ? mp[a] : (mp[a] = ++tot);
}

const int N = 1e5 + 10;
int d[N], vis[N];
typedef pair<int, A> PII;
// (d[ states ], (u, c))

void dijkstra() {
    using namespace Graph;
    memset(d, inf, sizeof d);
    memset(vis, 0, sizeof vis);
    res = inf;

    priority_queue<PII, vector<PII>, greater<PII> > q;
    int ss = get(A(s, 0));
    d[ss] = 0;
    q.push(PII(d[ss], A(s, 0)));

    while (q.size()) {
        int D = q.top().first;
        auto x = q.top().second; q.pop();

        if (x.u == t) {
            res = min(res, D);
            return;
        }

        int sx = get(x);
        if (vis[sx]) continue;
        vis[sx] = true;

        if (x.c + 1 <= C) {
            A y(x.u, x.c+1);
            int sy = get(y);

            if (d[sy] > D + p[x.u]) {
                d[sy] = D + p[x.u];
                q.push(PII(d[sy], y));
            }
        }

        for (int i = head[x.u]; i; i = ne[i]) {
            int v = ver[i];
            if (x.c >= e[i]) {
                A y(v, x.c-e[i]);
                int sy = get(y);

                if (d[sy] > D + 0) {
                    d[sy] = D;
                    q.push(PII(d[sy], y));
                }
            }
        }
    }
    return;
}

int main() {
    freopen("input.txt", "r", stdin);
    cin >> n >> m;
    using namespace Graph;
    init();

    for (int i = 0; i < n; i++) scanf("%d", &p[i]);
    while (m--) {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        add(x, y, z), add(y, x, z);
    }

    // query and dijkstra
    int q;
    cin >> q;
    while (q--) {
        scanf("%d%d%d", &C, &s, &t);
        dijkstra();

        if (res == inf) puts("impossible");
        else printf("%d\n", res);
    }
}



题意,有一个长度为 $n$ 的序列,每个点权值为在范围 $[1, m]$ 内,并且保证每个值都出现过
对每个 $i(1 \leqslant i \leqslant m)$ 询问包含权值 $[1, i]$ 的最小区间长度

具体来说,$f(i, x)$ 表示以 $[i \cdots)$ 作为左端点,并且包含 $[1\cdots x]$ 所有值(即编号为 $[1\cdots x]$ 的网络)
此时最近的右端点,记为 $f(i, x)$
另外用 $\textbf{A}_x[\cdots]$ 表示 $x$ 网络出现的位置

  • 很容易想到一种 $O(n^2)$ 的算法,状态转移方程是 $f(i, x) = \max (f(i, x-1), \min (\textbf{A}_x[\cdots]))$
    用 $\text{last}$ 记录 $\textbf{A}_x[\cdots]$ 中离 $i$ 最近的位置
    那么 $f(i, x) = \max (f(i, x-1), \text{last})$

  • 考虑优化,对于 $\textbf{for} \ \forall x \in [1\cdots m]$
    尝试维护 $\forall i \in [1\cdots n]$ 中的 $f_x(i)$,其中保证 $[i\cdots f_x(i)]$ 中 $[1\cdots x]$ 的每个数都出现
    这样可以用 $\min (f_x(i) - i + 1)$ 来更新全局的 $res$

  • $\textbf{A}[x] = \{p_1, \cdots p_{j-1}, p_j \cdots p_k \}$,当 $p_{j-1} \to p_j$
    转移的时候,考虑区间 $[p_{j-1}+1\cdots p_j]$
    对于 $\forall \ i \leqslant p_{j-1}$,$p_j$ 处出现的 $x$ 不影响 $f_x(i)$ 的值
    而 $\forall \ i \in [p_{j-1} + 1 \cdots p_j]$,要更新所有的 $f_x(i) \leftarrow \max (f_x(i), p_j)$
    但注意到对于 $\forall i \in [p_{j-1}+1 \cdots p_j]$,$p_j$ 只会影响到 $f_x(i) < p_j$ 的 $i$
    也就是说,我们只需要更新 $[p_{j-1}+1, p_j]$ 的某一段子区间即可,不妨设为 $[p_{j-1}+1, pos]$
    注意单调性,对于 $i_1 < i_2$,有 $f_x(i_1) \leqslant f_x(i_2)$
    $i \leqslant pos$, $f_x(i) \leqslant f_x(pos)$,只需要修改 $[p_{j-1}+1, pos]$ 即可

  • 要求出 $pos$,可以考虑在区间 $i \in [1\cdots n]$ 上建线段树,节点 $u$ 表示区间 $[t_u.l, t_u.r]$
    $t_u.f$ 维护 $\forall i \in [t_u.l, t_u.r]$ 的 $f_x(i)$ 信息
    在线段树上二分查找,找到满足 $t_u.f < p_j$ 的最大叶节点,$pos \leftarrow t_u.l$

  • 接下来在线段树找到区间 $u = [p_{j-1}+1, pos]$,并且执行区间赋值 $t_u.f \leftarrow p_j$
    此时 $i \in [t_u.l, t_u.r] = [p_{j-1}+1, pos]$ 所有的 $f_x(i)$ 都相等
    最短的包含 $[1\cdots x]$ 所有数的区间的长度 $\textbf{len}(x) = t_u.f - t_u.r + 1$
    自底向上 $\textbf{pull}$ 取 $\min$,最后 $t_1.len$ 就是 $\textbf{len}(x)$ 对应的答案
    特别地,$\textbf{A}[x] = \{p_1, \cdots, p_k \}$,对于 $i \in [p_{k}+1, \cdots n]$ 这一段,
    不存在满足 $[1\cdots x]$ 所有数都出现的区间段,这里只要将这个区间段的 $t_u.f \leftarrow +\infty$
    这样更新 $\min$ 的时候就用不到这一段了

  • 综上所述,线段树维护 $f$,表示 $f_x(i)$,另外需维护 $\min$,表示 $[t_u.l, t_u.r]$ 对应的最小 $\text{len}(x)$
    初始化 $\min$ 为 $+\infty$,$f$ 为 $0$,每次处理完 $\textbf{A}={p_1\cdots p_k}$ 时
    将 $[p_k+1, n]$ 的 $f$ 置为 $+\infty$

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>

using namespace std;
typedef long long ll;

#define Cmp(a, b) memcmp(a, b, sizeof(b))
#define Cpy(a, b) memcpy(a, b, sizeof(b))
#define Set(a, v) memset(a, v, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define _forS(i, l, r) for(set<int>::iterator i = (l); i != (r); i++)
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define _forPlus(i, l, d, r) for(int i = (l); i + d < (r); i++)
#define lowbit(i) (i & (-i))
#define MPR(a, b) make_pair(a, b)

pair<int, int> crack(int n) {
    int st = sqrt(n);
    int fac = n / st;
    while (n % st) {
        st += 1;
        fac = n / st;
    }

    return make_pair(st, fac);
}

inline ll qpow(ll a, int n) {
    ll ans = 1;
    for(; n; n >>= 1) {
        if(n & 1) ans *= 1ll * a;
        a *= a;
    }
    return ans;
}

template <class T>
inline bool chmax(T& a, T b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}

ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

ll ksc(ll a, ll b, ll mod) {
    ll ans = 0;
    for(; b; b >>= 1) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a * 2) % mod;
    }
    return ans;
}

ll ksm(ll a, ll b, ll mod) {
    ll ans = 1 % mod;
    a %= mod;

    for(; b; b >>= 1) {
        if (b & 1) ans = ksc(ans, a, mod);
        a = ksc(a, a, mod);
    }

    return ans;
}

template <class T>
inline bool chmin(T& a, T b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
    int n = a.size(), m = b.size();
    int i;
    for(i = 0; i < n && i < m; i++) {
        if (a[i] < b[i]) return true;
        else if (b[i] < a[i]) return false;
    }
    return (i == n && i < m);
}

// ============================================================== //

const int maxn = 2e5 + 10, inf = 0x3f3f3f3f;
int n, m;
vector<int> G[maxn];

class SegTree {
public:
    struct node {
        int l, r;
        int tag, len, f;
    };
    int tot = 0;
    vector<node> t;

    SegTree() = default;
    SegTree(int _tot) : tot(_tot) {
        assert(tot > 0);
        t.resize(tot << 2);
    }

    void build(int p, int l, int r) {
        t[p].l = l, t[p].r = r;
        if (l >= r) {
            t[p].f = 0, t[p].len = inf;
            return;
        }
        int mid = (l + r) >> 1;
        build(p<<1, l, mid);
        build(p<<1|1, mid+1, r);
    }

    void push(int p) {
        if (!t[p].tag) return;

        int fv = t[p].tag; t[p].tag = 0;

        t[p<<1].tag = fv;
        t[p<<1].f = fv;
        t[p<<1].len = fv - t[p<<1].r + 1;

        t[p<<1|1].tag = fv;
        t[p<<1|1].f = fv;
        t[p<<1|1].len = fv - t[p<<1|1].r + 1;
    }

    void pull(int p) {
        t[p].len = min(t[p<<1].len, t[p<<1|1].len);
        t[p].f = min(t[p<<1].f, t[p<<1|1].f);
    }

    void change(int p, const int l, const int r, int fv) {
        if (l <= t[p].l && t[p].r <= r) {
            t[p].tag = fv;
            // change
            t[p].f = fv;
            t[p].len = fv - t[p].r + 1;

            return;
        }
        push(p);
        int mid = (t[p].l + t[p].r) >> 1;
        if (l <= mid) change(p<<1, l, r, fv);
        if (r > mid) change(p<<1|1, l, r, fv);

        pull(p);
        return;
    }

    int find(int p, const int l, const int r, int fv) {
        if (t[p].l == t[p].r) return t[p].l;
        push(p);

        int mid = (t[p].l + t[p].r) >> 1;
        int res = -1;
        if (r > mid && t[p<<1|1].f < fv) res = find(p<<1|1, l, r, fv);
        if (res != -1) return res;
        if (l <= mid && t[p<<1].f < fv) res = find(p<<1, l, r, fv);

        return res;
    }
} seg(maxn);

void solve() {
    for (int x = 1; x <= m; x++) {
        for (int j = 1; j < G[x].size(); j++) {
            int l = G[x][j-1] + 1, r = G[x][j];
            // find p in [l, r], change [l, p]
            int pos = -1;
            pos = seg.find(1, l, r, r);
            if (pos == -1) continue;
            seg.change(1, l, pos, r);
        }
        seg.change(1, G[x].back()+1, n, inf);
        int res = seg.t[1].len;
        printf("%d ", res);
    }
}

int main() {
    //freopen("input.txt", "r", stdin);
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= m; i++) G[i].push_back(0);
    for (int i = 1; i <= n; i++) {
        int x; scanf("%d", &x);
        G[x].push_back(i);
    }

    // build
    seg.build(1, 1, n);

    // solve
    solve();
}



  • 操作中涉及到将所有员工的工资都增加,容易想到用一个全局变量 $\Delta$ 来维护
    $\text{splay}$ 中维护的值为 $x$,那么员工真实的工资数值为 $x + \Delta$

  • 招员工和员工离职,就是 $\text{splay}$ 中常见的添加,删除操作

  • 比较难处理的是员工离职的情况,员工工资 $x + \Delta < \min$ 最低工资标准会离职

  • 在 splay 两端加入两个哨兵,$-\infty, \ +\infty$,splay 执行删除区间操作
    待删区间的左端点就是 $L = -\infty$,右端点是第一个满足 $t_R.v \geqslant \min-\Delta$ 的 $R$
    删除 $[L, R]$ 区间,将 $R$ splay 到根,并且将 $L$ splay 成 $R$ 的儿子,删除 $R$ 右子树即可

  • 找到 $\geqslant \min - \Delta$ 的最小数,实际上就是执行平衡二叉树的二分查找
    如果根节点 $u$ 满足, $u$ 是备选答案,接着去 $u$ 的左子树尝试找更小的

const int maxn = 1e5 + 10;
const int inf = 0x3f3f3f3f;

int n, m, L, R, delta = 0;

// get_k return dat

class Splay {
public:
    struct node {
        int son[2], pa, sz;
        int dat;

        void init(int _pa, int _dat) {
            pa = _pa, dat = _dat;
            sz = 1;
        }
    };

    int tot;
    int idx = 0, root = 0;
    vector<node> t;
    Splay() = default;
    Splay(int _tot) : tot(_tot) {
        t.resize(tot);
        idx = 0, root = 0;
    }

    inline void pull(int u) {
        t[u].sz = t[t[u].son[0]].sz + t[t[u].son[1]].sz + 1;
    }

    void rotate(int x) {
        int y = t[x].pa, z = t[y].pa;
        int k = t[y].son[1] == x;
        t[z].son[t[z].son[1] == y] = x, t[x].pa = z;
        t[y].son[k] = t[x].son[k^1], t[t[x].son[k^1]].pa = y;
        t[x].son[k^1] = y, t[y].pa = x;
        pull(y), pull(x);
    }

    void splay(int x, int k) {
        while (t[x].pa != k) {
            int y = t[x].pa, z = t[y].pa;
            if (z != k) {
                if ((t[y].son[1] == x) ^ (t[z].son[1] == y)) rotate(x);
                else rotate(y);
            }
            rotate(x);
        }

        if (k == 0) root = x;
    }

    int insert(int v) {
        int u = root, p = 0;
        while (u) p = u, u = t[u].son[v > t[u].dat];
        u = ++idx;
        if (p) t[p].son[v > t[p].dat] = u;
        t[u].init(p, v);
        splay(u, 0);
        return u;
    }

    int find(int x) {
        int u = root, res;
        while (u) {
            if (t[u].dat >= x) res = u, u = t[u].son[0];
            else u = t[u].son[1];
        }
        return res;
    }

    int get_k(int k) {
        int u = root;
        while (u) {
            if (t[t[u].son[0]].sz >= k) u = t[u].son[0];
            else if (t[t[u].son[0]].sz + 1 == k) return t[u].dat;
            else k -= t[t[u].son[0]].sz + 1, u = t[u].son[1];
        }
        return -1;
    } 


} spl(maxn);

int main() {
    freopen("input.txt", "r", stdin);
    scanf("%d%d", &n, &m);
    // build
    L = spl.insert(-inf), spl.insert(inf);

    int tot = 0;
    for (int i = 0; i < n; i++) {
        char op[2];
        int k;
        scanf("%s%d", op, &k);

        // get data
        if (op[0] == 'I') {
            if (k >= m) tot++, k -= delta, spl.insert(k);
        }
        else if (op[0] == 'A') {
            delta += k;
        }
        else if (op[0] == 'S') {
            delta -= k;
            R = spl.find(m - delta);
            spl.splay(R, 0), spl.splay(L, R);
            spl.t[L].son[1] = 0;
            spl.pull(L), spl.pull(R);
        }
        else {
            // F
            if (k > spl.t[spl.root].sz - 2) puts("-1");
            else {
                int res = spl.get_k(spl.t[spl.root].sz - k);
                printf("%d\n", res + delta);
            }
        }
    }
    printf("%d\n", tot - spl.t[spl.root].sz + 2);
}


活动打卡代码 AcWing 950. 郁闷的出纳员

  • 操作中涉及到将所有员工的工资都增加,容易想到用一个全局变量 $\Delta$ 来维护
    $\text{splay}$ 中维护的值为 $x$,那么员工真实的工资数值为 $x + \Delta$

  • 招员工和员工离职,就是 $\text{splay}$ 中常见的添加,删除操作

  • 比较难处理的是员工离职的情况,员工工资 $x + \Delta < \min$ 最低工资标准会离职

  • 在 splay 两端加入两个哨兵,$-\infty, \ +\infty$,splay 执行删除区间操作
    待删区间的左端点就是 $L = -\infty$,右端点是第一个满足 $t_R.v \geqslant \min-\Delta$ 的 $R$
    删除 $[L, R]$ 区间,将 $R$ splay 到根,并且将 $L$ splay 成 $R$ 的儿子,删除 $R$ 右子树即可

  • 找到 $\geqslant \min - \Delta$ 的最小数,实际上就是执行平衡二叉树的二分查找
    如果根节点 $u$ 满足, $u$ 是备选答案,接着去 $u$ 的左子树尝试找更小的

const int maxn = 1e5 + 10;
const int inf = 0x3f3f3f3f;

int n, m, L, R, delta = 0;

// get_k return dat

class Splay {
public:
    struct node {
        int son[2], pa, sz;
        int dat;

        void init(int _pa, int _dat) {
            pa = _pa, dat = _dat;
            sz = 1;
        }
    };

    int tot;
    int idx = 0, root = 0;
    vector<node> t;
    Splay() = default;
    Splay(int _tot) : tot(_tot) {
        t.resize(tot);
        idx = 0, root = 0;
    }

    inline void pull(int u) {
        t[u].sz = t[t[u].son[0]].sz + t[t[u].son[1]].sz + 1;
    }

    void rotate(int x) {
        int y = t[x].pa, z = t[y].pa;
        int k = t[y].son[1] == x;
        t[z].son[t[z].son[1] == y] = x, t[x].pa = z;
        t[y].son[k] = t[x].son[k^1], t[t[x].son[k^1]].pa = y;
        t[x].son[k^1] = y, t[y].pa = x;
        pull(y), pull(x);
    }

    void splay(int x, int k) {
        while (t[x].pa != k) {
            int y = t[x].pa, z = t[y].pa;
            if (z != k) {
                if ((t[y].son[1] == x) ^ (t[z].son[1] == y)) rotate(x);
                else rotate(y);
            }
            rotate(x);
        }

        if (k == 0) root = x;
    }

    int insert(int v) {
        int u = root, p = 0;
        while (u) p = u, u = t[u].son[v > t[u].dat];
        u = ++idx;
        if (p) t[p].son[v > t[p].dat] = u;
        t[u].init(p, v);
        splay(u, 0);
        return u;
    }

    int find(int x) {
        int u = root, res;
        while (u) {
            if (t[u].dat >= x) res = u, u = t[u].son[0];
            else u = t[u].son[1];
        }
        return res;
    }

    int get_k(int k) {
        int u = root;
        while (u) {
            if (t[t[u].son[0]].sz >= k) u = t[u].son[0];
            else if (t[t[u].son[0]].sz + 1 == k) return t[u].dat;
            else k -= t[t[u].son[0]].sz + 1, u = t[u].son[1];
        }
        return -1;
    } 


} spl(maxn);

int main() {
    freopen("input.txt", "r", stdin);
    scanf("%d%d", &n, &m);
    // build
    L = spl.insert(-inf), spl.insert(inf);

    int tot = 0;
    for (int i = 0; i < n; i++) {
        char op[2];
        int k;
        scanf("%s%d", op, &k);

        // get data
        if (op[0] == 'I') {
            if (k >= m) tot++, k -= delta, spl.insert(k);
        }
        else if (op[0] == 'A') {
            delta += k;
        }
        else if (op[0] == 'S') {
            delta -= k;
            R = spl.find(m - delta);
            spl.splay(R, 0), spl.splay(L, R);
            spl.t[L].son[1] = 0;
            spl.pull(L), spl.pull(R);
        }
        else {
            // F
            if (k > spl.t[spl.root].sz - 2) puts("-1");
            else {
                int res = spl.get_k(spl.t[spl.root].sz - k);
                printf("%d\n", res + delta);
            }
        }
    }
    printf("%d\n", tot - spl.t[spl.root].sz + 2);
}



splay.001.jpeg
splay.002.jpeg
splay.003.jpeg

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>

using namespace std;
typedef long long ll;

#define Cmp(a, b) memcmp(a, b, sizeof(b))
#define Cpy(a, b) memcpy(a, b, sizeof(b))
#define Set(a, v) memset(a, v, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define _forS(i, l, r) for(set<int>::iterator i = (l); i != (r); i++)
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define _forPlus(i, l, d, r) for(int i = (l); i + d < (r); i++)
#define lowbit(i) (i & (-i))
#define MPR(a, b) make_pair(a, b)

pair<int, int> crack(int n) {
    int st = sqrt(n);
    int fac = n / st;
    while (n % st) {
        st += 1;
        fac = n / st;
    }

    return make_pair(st, fac);
}

inline ll qpow(ll a, int n) {
    ll ans = 1;
    for(; n; n >>= 1) {
        if(n & 1) ans *= 1ll * a;
        a *= a;
    }
    return ans;
}

template <class T>
inline bool chmax(T& a, T b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}

ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

ll ksc(ll a, ll b, ll mod) {
    ll ans = 0;
    for(; b; b >>= 1) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a * 2) % mod;
    }
    return ans;
}

ll ksm(ll a, ll b, ll mod) {
    ll ans = 1 % mod;
    a %= mod;

    for(; b; b >>= 1) {
        if (b & 1) ans = ksc(ans, a, mod);
        a = ksc(a, a, mod);
    }

    return ans;
}

template <class T>
inline bool chmin(T& a, T b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
    int n = a.size(), m = b.size();
    int i;
    for(i = 0; i < n && i < m; i++) {
        if (a[i] < b[i]) return true;
        else if (b[i] < a[i]) return false;
    }
    return (i == n && i < m);
}

// ============================================================== //

const int maxn = 1e5 + 10;
int n, m;

class Splay {
public:
    struct node {
        int son[2], pa;
        int sz, tag, v;

        void init(int _v, int _pa) {
            v = _v, pa = _pa;
            sz = 1;
        }
    };

    int tot;
    int root = 0, idx = 0;

    vector<node> t;

    Splay() = default;
    Splay(int _n) : tot(_n) {
        t.resize(tot);
        root = 0, idx = 0;
    }

    inline void mark(int p) {
        t[p].tag = 1;
    }

    void push(int p) {
        if (!t[p].tag) return;
        swap(t[p].son[0], t[p].son[1]);
        t[t[p].son[0]].tag ^= 1, t[t[p].son[1]].tag ^= 1;
        t[p].tag = 0;
        return;
    }

    void pull(int p) {
        t[p].sz = t[t[p].son[0]].sz + t[t[p].son[1]].sz + 1;
    }

    void insert(int v) {
        int u = root, p = 0;
        while (u) p = u, u = t[u].son[v > t[u].v];

        u = ++idx;
        if (p > 0) t[p].son[v > t[p].v] = u;
        t[u].init(v, p);
        splay(u, 0);
    }

    void rotate(int x) {
        int y = t[x].pa, z = t[y].pa;
        int k = t[y].son[1] == x;

        // connect
        t[z].son[t[z].son[1] == y] = x, t[x].pa = z;
        t[y].son[k] = t[x].son[k^1], t[t[x].son[k^1]].pa = y;
        t[x].son[k^1] = y, t[y].pa = x;
        pull(y), pull(x);
    }

    void splay(int x, int k) {
        while (t[x].pa != k) {
            int y = t[x].pa, z = t[y].pa;
            if (z != k) {
                if ((t[z].son[1] == y) ^ (t[y].son[1] == x)) rotate(x);
                else rotate(y);
            }
            rotate(x);
        }
        if (k == 0) root = x;
    }

    int get_k(int v) {
        int u = root;
        while (true) {
            push(u);
            if (t[t[u].son[0]].sz >= v) u = t[u].son[0];
            else if (t[t[u].son[0]].sz + 1 == v) return u;
            else v -= t[t[u].son[0]].sz + 1, u = t[u].son[1];
        }
        return -1;
    }

    void out(int p) {
        push(p);
        if (t[p].son[0]) out(t[p].son[0]);
        if (t[p].v >= 1 && t[p].v <= n) printf("%d ", t[p].v);
        if (t[p].son[1]) out(t[p].son[1]);
    }

} spl(maxn);

int main() {
    freopen("input.txt", "r", stdin);
    scanf("%d%d", &n, &m);

    // build splay
    for (int i = 0; i <= n+1; i++) spl.insert(i);

    while (m--) {
        int li, ri;
        scanf("%d%d", &li, &ri);
        int l = spl.get_k(li), r = spl.get_k(ri+2);

        spl.splay(l, 0), spl.splay(r, l);
        spl.mark(spl.t[r].son[0]);
    }
    spl.out(spl.root);
}


活动打卡代码 AcWing 2437. Splay

splay.001.jpeg
splay.002.jpeg
splay.003.jpeg

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>

using namespace std;
typedef long long ll;

#define Cmp(a, b) memcmp(a, b, sizeof(b))
#define Cpy(a, b) memcpy(a, b, sizeof(b))
#define Set(a, v) memset(a, v, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define _forS(i, l, r) for(set<int>::iterator i = (l); i != (r); i++)
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define _forPlus(i, l, d, r) for(int i = (l); i + d < (r); i++)
#define lowbit(i) (i & (-i))
#define MPR(a, b) make_pair(a, b)

pair<int, int> crack(int n) {
    int st = sqrt(n);
    int fac = n / st;
    while (n % st) {
        st += 1;
        fac = n / st;
    }

    return make_pair(st, fac);
}

inline ll qpow(ll a, int n) {
    ll ans = 1;
    for(; n; n >>= 1) {
        if(n & 1) ans *= 1ll * a;
        a *= a;
    }
    return ans;
}

template <class T>
inline bool chmax(T& a, T b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}

ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

ll ksc(ll a, ll b, ll mod) {
    ll ans = 0;
    for(; b; b >>= 1) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a * 2) % mod;
    }
    return ans;
}

ll ksm(ll a, ll b, ll mod) {
    ll ans = 1 % mod;
    a %= mod;

    for(; b; b >>= 1) {
        if (b & 1) ans = ksc(ans, a, mod);
        a = ksc(a, a, mod);
    }

    return ans;
}

template <class T>
inline bool chmin(T& a, T b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
    int n = a.size(), m = b.size();
    int i;
    for(i = 0; i < n && i < m; i++) {
        if (a[i] < b[i]) return true;
        else if (b[i] < a[i]) return false;
    }
    return (i == n && i < m);
}

// ============================================================== //

const int maxn = 1e5 + 10;
int n, m;

class Splay {
public:
    struct node {
        int son[2], pa;
        int sz, tag, v;

        void init(int _v, int _pa) {
            v = _v, pa = _pa;
            sz = 1;
        }
    };

    int tot;
    int root = 0, idx = 0;

    vector<node> t;

    Splay() = default;
    Splay(int _n) : tot(_n) {
        t.resize(tot);
        root = 0, idx = 0;
    }

    inline void mark(int p) {
        t[p].tag = 1;
    }

    void push(int p) {
        if (!t[p].tag) return;
        swap(t[p].son[0], t[p].son[1]);
        t[t[p].son[0]].tag ^= 1, t[t[p].son[1]].tag ^= 1;
        t[p].tag = 0;
        return;
    }

    void pull(int p) {
        t[p].sz = t[t[p].son[0]].sz + t[t[p].son[1]].sz + 1;
    }

    void insert(int v) {
        int u = root, p = 0;
        while (u) p = u, u = t[u].son[v > t[u].v];

        u = ++idx;
        if (p > 0) t[p].son[v > t[p].v] = u;
        t[u].init(v, p);
        splay(u, 0);
    }

    void rotate(int x) {
        int y = t[x].pa, z = t[y].pa;
        int k = t[y].son[1] == x;

        // connect
        t[z].son[t[z].son[1] == y] = x, t[x].pa = z;
        t[y].son[k] = t[x].son[k^1], t[t[x].son[k^1]].pa = y;
        t[x].son[k^1] = y, t[y].pa = x;
        pull(y), pull(x);
    }

    void splay(int x, int k) {
        while (t[x].pa != k) {
            int y = t[x].pa, z = t[y].pa;
            if (z != k) {
                if ((t[z].son[1] == y) ^ (t[y].son[1] == x)) rotate(x);
                else rotate(y);
            }
            rotate(x);
        }
        if (k == 0) root = x;
    }

    int get_k(int v) {
        int u = root;
        while (true) {
            push(u);
            if (t[t[u].son[0]].sz >= v) u = t[u].son[0];
            else if (t[t[u].son[0]].sz + 1 == v) return u;
            else v -= t[t[u].son[0]].sz + 1, u = t[u].son[1];
        }
        return -1;
    }

    void out(int p) {
        push(p);
        if (t[p].son[0]) out(t[p].son[0]);
        if (t[p].v >= 1 && t[p].v <= n) printf("%d ", t[p].v);
        if (t[p].son[1]) out(t[p].son[1]);
    }

} spl(maxn);

int main() {
    freopen("input.txt", "r", stdin);
    scanf("%d%d", &n, &m);

    // build splay
    for (int i = 0; i <= n+1; i++) spl.insert(i);

    while (m--) {
        int li, ri;
        scanf("%d%d", &li, &ri);
        int l = spl.get_k(li), r = spl.get_k(ri+2);

        spl.splay(l, 0), spl.splay(r, l);
        spl.mark(spl.t[r].son[0]);
    }
    spl.out(spl.root);
}



考虑前缀和 $s[i]$,原问题等价于对于一个点 $p$,找到 $s[1\cdots p-1]$ 的一个点 $p’$
使得 $s[p’] \oplus s[p]$ 最大,可以考虑使用可持久化 $\text{trie}$

对于可持久化 $\text{trie}$,$\text{root}(p-1)$ 这个版本中只有区间 $[1\cdots p-1]$ 的信息
利用主席树的思想,可以解决可持久化 $\text{trie}$ 的 $\text{k-query}$ 问题

对于区间 $[1\cdots r]$,要找到一个 $l \in [1\cdots r-1]$,使得 $s_l \oplus s_r$ 为第 $k$ 大
令 $p \leftarrow \text{root}(r-1), \ val \leftarrow s[r]$,从高位到低位检查 $val$ 的第 $b$ 位 $c$
如果 $\textbf{size}(t(p, c \oplus 1)) \geqslant k$,那么 $res += (1 \ll b), \ p \leftarrow t(p, c\oplus 1)$
否则的话,$k’ \leftarrow k - \textbf{size}(t(p, c\oplus 1))$,$p \leftarrow t(p, c)$,递归在 $t(p, c)$ 子树查找第 $k’$ 大

需要注意的是边界,想要让 $r = 1$ 时有意义,必须提前在 $\text{trie}$ 树中插入 $\text{insert}(\text{root}(0), 0)$
表示在 $\text{root}(0)$ 初始化插入一个每个位都是 $0$ 的数

具体来说

  • 对于 $H$ 位的数 $val$,由于要统计 $\text{size}$ 信息,所以递归地插入
    $\textbf{insert}(pre, p, H, val)$,递归的边界是 $H < 0, \text{size}(p) = \text{size}(pre) + 1$
    ($H = 0$ 时插入最后一个字符 $c$,递归执行 $t(p, c)$ 之后,边界 $H = -1$)

  • $res \leftarrow (i, rk(i))$,表示在区间 $[1\cdots i-1]$ 中找到一个 $j$
    使得 $res = s_j \oplus x_i$ 为第 $rk(i)$ 大,很显然一开始 $rk(i) = 1$

  • 建立一个优先队列 $\text{que}$,对于 $\forall \ r \in [1, n]$
    将 $\textbf{ask}(\text{root}(r-1), rk(r), s[r])$ 的结果放入 $\text{que}$ 中

  • 取出堆顶元素,此时堆中最大元素假设为 $(res, p)$
    表示此时 $\exists l \in [1, p)$,使得 $s_l \oplus s_p$ 为第 $1$ 大,其值为 $res$
    将其累加到答案中,删掉 $s_l$,注意要接着找到 $[1, l-1] \cup [l+1, p)$ 中第 $1$ 大,将其放入堆中
    注意到 $[1, l-1] \cup [l+1, p)$ 中的第 $1$ 大,等价于 $[1, p)$ 中的第 $2$ 大
    由此在编程实现上可以更简单一些,一开始令 $rk(i) = 1$,取出堆顶元素 $(res, p)$ 之后
    执行查询 $res’ \leftarrow \textbf{ask}(\text{root}(p), ++rk(p), s[p])$,再继续将 $res’$ 放入堆中

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>
#pragma GCC optimize(2)

using namespace std;
typedef long long ll;

#define Cmp(a, b) memcmp(a, b, sizeof(b))
#define Cpy(a, b) memcpy(a, b, sizeof(b))
#define Set(a, v) memset(a, v, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define _forS(i, l, r) for(set<int>::iterator i = (l); i != (r); i++)
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define _forPlus(i, l, d, r) for(int i = (l); i + d < (r); i++)
#define lowbit(i) (i & (-i))
#define MPR(a, b) make_pair(a, b)

pair<int, int> crack(int n) {
    int st = sqrt(n);
    int fac = n / st;
    while (n % st) {
        st += 1;
        fac = n / st;
    }

    return make_pair(st, fac);
}

inline ll qpow(ll a, int n) {
    ll ans = 1;
    for(; n; n >>= 1) {
        if(n & 1) ans *= 1ll * a;
        a *= a;
    }
    return ans;
}

template <class T>
inline bool chmax(T& a, T b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}

ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

ll ksc(ll a, ll b, ll mod) {
    ll ans = 0;
    for(; b; b >>= 1) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a * 2) % mod;
    }
    return ans;
}

ll ksm(ll a, ll b, ll mod) {
    ll ans = 1 % mod;
    a %= mod;

    for(; b; b >>= 1) {
        if (b & 1) ans = ksc(ans, a, mod);
        a = ksc(a, a, mod);
    }

    return ans;
}

template <class T>
inline bool chmin(T& a, T b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
    int n = a.size(), m = b.size();
    int i;
    for(i = 0; i < n && i < m; i++) {
        if (a[i] < b[i]) return true;
        else if (b[i] < a[i]) return false;
    }
    return (i == n && i < m);
}

// ============================================================== //

typedef pair<ll, int> PII;
const int maxn = 500000 + 10, N = maxn * 35;
const int H = 33;
int n, k, rk[maxn], root[maxn];
ll s[maxn];
priority_queue<PII> heap;

// insert(pre, p, H, val)
// ask(root(p), rk, &ans)

class Trie {
public:
    int tot;
    int t[N][2], sz[N];

    Trie() {
        tot = 0;
        memset(t, 0, sizeof t);
        memset(sz, 0, sizeof sz);
    }

    void insert(int pre, int p, int H, ll val) {
        if (H < 0) {
            sz[p] = sz[pre] + 1;
            return;
        }
        int c = val >> H & 1;
        if (pre) t[p][c^1] = t[pre][c^1];
        t[p][c] = ++tot;
        insert(t[pre][c], t[p][c], H-1, val);
        sz[p] = sz[t[p][c]] + sz[t[p][c^1]];
    }

    void ask(int p, int rk, int H, ll val, ll &res) {
        if (H < 0) return;
        int c = val >> H & 1;
        if (sz[ t[p][c^1] ] >= rk) {
            res = (res << 1 | 1);
            ask(t[p][c^1], rk, H-1, val, res);
        }
        else {
            res <<= 1;
            ask(t[p][c], rk - sz[t[p][c^1]], H-1, val, res);
        }
    }
} trie;

void solve() {
    for (int i = 1; i <= n; i++) {
        ll res = 0;
        trie.ask(root[i-1], rk[i], H, s[i], res);
        heap.push({res, i});
    }
    ll ans = 0;
    while (k--) {
        auto x = heap.top(); heap.pop();
        ans += x.first;
        int r = x.second;
        ll res = 0;
        trie.ask(root[r-1], ++rk[r], H, s[r], res);
        heap.push({res, r});
    }
    printf("%lld\n", ans);
}

int main() {
    freopen("input.txt", "r", stdin);
    // init
    memset(root, 0, sizeof root);

    scanf("%d%d", &n, &k);
    for (int i = 1; i <= n; i++) {
        ll x;
        scanf("%lld", &x);
        s[i] = s[i-1] ^ x;
        rk[i] = 1;
    }

    // per trie
    root[0] = ++trie.tot;
    trie.insert(0, root[0], H, 0);
    for (int i = 1; i <= n; i++) {
        root[i] = ++trie.tot;
        trie.insert(root[i-1], root[i], H, s[i]);
    }

    // solve
    solve();
}



活动打卡代码 AcWing 256. 最大异或和

trie.001.jpeg

  • 用 $\text{root}[\cdots]$ 数组来定位每个根节点,不妨设 $\text{trie}$ 中字符集合为 $C$,全体字符集为 $A$
    当前插入第 $i$ 个字符串,即执行第 $i$ 个版本,$p \leftarrow \text{root}(i-1)$

  • 新建一个节点 $q$,即 $q = root(i) = ++tot$,假设当前插入字符 $S_j$

  • 如果 $p \neq 0$,对于 $\forall c \in {C}, \ c \neq S_j$,令 $t(q, c) \leftarrow t(p, c)$
    (这一步可以检查 $\forall c \in {A}$ 全体字符集,令 $t(q, c) \leftarrow t(p, c)$,因为下一步 $t(q, S_j)$ 会重新定位到新开的节点上)

  • 新建一个节点,令 $t(q, S_j) = ++tot$,即除了 $S_j$ 指针不同外,$p, q$ 的其余信息完全相同

  • $p \leftarrow t(p, S_j), q \leftarrow t(q, S_j), j \leftarrow j+1$ 直到字符串结尾

  • 可以类似引入一个异或前缀和的概念

$$
\bigoplus_{i = p}^{n} a_p = S_n \oplus S_{p-1}
$$

  • 对于添加操作,很简单 $S_{n+1} = S_n \oplus x, \ n = n+1$

  • 如果令 $p’ = p-1, \ l-1 \leqslant p’ \leqslant r-1$,询问操作实际上就是令 $val = x \oplus S_n$
    求一个位置 $p$,满足 $l-1 \leqslant p \leqslant r-1$,使得 $S_p \oplus val$ 最大
    这个问题如果没有 $p \in [l-1, r-1]$ 的限制,就是最大异或和问题

  • 对于 $p \leqslant r-1$,可以借鉴主席树思想,对 $\text{trie}$ 进行可持久化,在第 $r-1$ 个版本
    即 $\text{root}(r-1)$ 中查询最大异或和路径
    ($p = \text{root}(r-1)$,从高位到低位尽可能沿着和 $val$ 相反的位走)

  • 对于 $p \geqslant l-1$,只要保证异或和路径上所经过点的时间戳 $\geqslant l-1$ 即可
    对于插入操作 $\text{insert}(pre, p, i)$ 表示插入第 $i$ 个字符串
    $k$ 从高位到低位遍历,此时第 $k$ 位的字符为 $c = S_i >> k \& 1$

  • 如果 $pre \neq 0$,令 $t(p, c\oplus 1) \leftarrow t(pre, c\oplus 1)$

  • $t(p, c) = ++tot$,于此同时标记节点时间戳 $dfn(p) = i$,然后和主席树一样同步往下走
    $p \leftarrow t(p, c), \ pre \leftarrow t(pre, c), \textbf{then} \ dfn(p) = i$

const int N = 600000 + 5;
const int maxn = N * 25;
int n, m, s[N], root[N];

class Trie {
public:
    int t[maxn][2], dfn[maxn];
    int tot;

    Trie() {
        tot = 0;
        memset(t, 0, sizeof 0);
        memset(dfn, 0, sizeof 0);
        dfn[0] = -1;
    }

    void insert(int pre, int p, int ver) {
        dfn[p] = ver;
        for (int k = 25; k >= 0; k--) {
            int c = s[ver] >> k & 1;
            if (pre) t[p][c^1] = t[pre][c^1];
            t[p][c] = ++tot;
            p = t[p][c], pre = t[pre][c];
            dfn[p] = ver;
        }
    }

    int ask(int p, int val, int lim) {
        for (int k = 25; k >= 0; k--) {
            int c = val >> k & 1;
            if (dfn[ t[p][c^1] ] >= lim) p = t[p][c^1];
            else p = t[p][c];
        }
        return s[dfn[p]] ^ val;
    }
} trie;

int main() {
    freopen("input.txt", "r", stdin);
    cin >> n >> m;

    // init
    for (int i = 1; i <= n; i++) {
        int x;
        scanf("%d", &x);
        s[i] = s[i-1] ^ x;
        root[i] = ++trie.tot;
        trie.insert(root[i-1], root[i], i);
    }
    while (m--) {
        char cmd[2];
        scanf("%s", cmd);
        if (cmd[0] == 'A') {
            int x;
            scanf("%d", &x);
            root[++n] = ++trie.tot;
            s[n] = s[n-1] ^ x;
            trie.insert(root[n-1], root[n], n);
        }
        else {
            int l, r, x;
            scanf("%d%d%d", &l, &r, &x);
            int res = trie.ask(root[r-1], s[n]^x, l-1);
            printf("%d\n", res);
        }
    }
}



trie.001.jpeg

  • 用 $\text{root}[\cdots]$ 数组来定位每个根节点,不妨设 $\text{trie}$ 中字符集合为 $C$,全体字符集为 $A$
    当前插入第 $i$ 个字符串,即执行第 $i$ 个版本,$p \leftarrow \text{root}(i-1)$

  • 新建一个节点 $q$,即 $q = root(i) = ++tot$,假设当前插入字符 $S_j$

  • 如果 $p \neq 0$,对于 $\forall c \in {C}, \ c \neq S_j$,令 $t(q, c) \leftarrow t(p, c)$
    (这一步可以检查 $\forall c \in {A}$ 全体字符集,令 $t(q, c) \leftarrow t(p, c)$,因为下一步 $t(q, S_j)$ 会重新定位到新开的节点上)

  • 新建一个节点,令 $t(q, S_j) = ++tot$,即除了 $S_j$ 指针不同外,$p, q$ 的其余信息完全相同

  • $p \leftarrow t(p, S_j), q \leftarrow t(q, S_j), j \leftarrow j+1$ 直到字符串结尾

  • 可以类似引入一个异或前缀和的概念

$$
\bigoplus_{i = p}^{n} a_p = S_n \oplus S_{p-1}
$$

  • 对于添加操作,很简单 $S_{n+1} = S_n \oplus x, \ n = n+1$

  • 如果令 $p’ = p-1, \ l-1 \leqslant p’ \leqslant r-1$,询问操作实际上就是令 $val = x \oplus S_n$
    求一个位置 $p$,满足 $l-1 \leqslant p \leqslant r-1$,使得 $S_p \oplus val$ 最大
    这个问题如果没有 $p \in [l-1, r-1]$ 的限制,就是最大异或和问题

  • 对于 $p \leqslant r-1$,可以借鉴主席树思想,对 $\text{trie}$ 进行可持久化,在第 $r-1$ 个版本
    即 $\text{root}(r-1)$ 中查询最大异或和路径
    ($p = \text{root}(r-1)$,从高位到低位尽可能沿着和 $val$ 相反的位走)

  • 对于 $p \geqslant l-1$,只要保证异或和路径上所经过点的时间戳 $\geqslant l-1$ 即可
    对于插入操作 $\text{insert}(pre, p, i)$ 表示插入第 $i$ 个字符串
    $k$ 从高位到低位遍历,此时第 $k$ 位的字符为 $c = S_i >> k \& 1$

  • 如果 $pre \neq 0$,令 $t(p, c\oplus 1) \leftarrow t(pre, c\oplus 1)$

  • $t(p, c) = ++tot$,于此同时标记节点时间戳 $dfn(p) = i$,然后和主席树一样同步往下走
    $p \leftarrow t(p, c), \ pre \leftarrow t(pre, c), \textbf{then} \ dfn(p) = i$

const int N = 600000 + 5;
const int maxn = N * 25;
int n, m, s[N], root[N];

class Trie {
public:
    int t[maxn][2], dfn[maxn];
    int tot;

    Trie() {
        tot = 0;
        memset(t, 0, sizeof 0);
        memset(dfn, 0, sizeof 0);
        dfn[0] = -1;
    }

    void insert(int pre, int p, int ver) {
        dfn[p] = ver;
        for (int k = 25; k >= 0; k--) {
            int c = s[ver] >> k & 1;
            if (pre) t[p][c^1] = t[pre][c^1];
            t[p][c] = ++tot;
            p = t[p][c], pre = t[pre][c];
            dfn[p] = ver;
        }
    }

    int ask(int p, int val, int lim) {
        for (int k = 25; k >= 0; k--) {
            int c = val >> k & 1;
            if (dfn[ t[p][c^1] ] >= lim) p = t[p][c^1];
            else p = t[p][c];
        }
        return s[dfn[p]] ^ val;
    }
} trie;

int main() {
    freopen("input.txt", "r", stdin);
    cin >> n >> m;

    // init
    for (int i = 1; i <= n; i++) {
        int x;
        scanf("%d", &x);
        s[i] = s[i-1] ^ x;
        root[i] = ++trie.tot;
        trie.insert(root[i-1], root[i], i);
    }
    while (m--) {
        char cmd[2];
        scanf("%s", cmd);
        if (cmd[0] == 'A') {
            int x;
            scanf("%d", &x);
            root[++n] = ++trie.tot;
            s[n] = s[n-1] ^ x;
            trie.insert(root[n-1], root[n], n);
        }
        else {
            int l, r, x;
            scanf("%d%d%d", &l, &r, &x);
            int res = trie.ask(root[r-1], s[n]^x, l-1);
            printf("%d\n", res);
        }
    }
}


活动打卡代码 AcWing 161. 电话列表

将所有字符串插入 trie 中
接着对每个字符串执行查询,遍历 $str[1\cdots n-1]$,如果 trie 中表示某个字符的节点 $p$,$vis(p) \neq 0$
那么就互为前缀

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>

using namespace std;
typedef long long ll;

#define Cmp(a, b) memcmp(a, b, sizeof(b))
#define Cpy(a, b) memcpy(a, b, sizeof(b))
#define Set(a, v) memset(a, v, sizeof(a))
#define debug(x) cout << #x << ": " << x << endl
#define _forS(i, l, r) for(set<int>::iterator i = (l); i != (r); i++)
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define _forPlus(i, l, d, r) for(int i = (l); i + d < (r); i++)
#define lowbit(i) (i & (-i))
#define MPR(a, b) make_pair(a, b)

pair<int, int> crack(int n) {
    int st = sqrt(n);
    int fac = n / st;
    while (n % st) {
        st += 1;
        fac = n / st;
    }

    return make_pair(st, fac);
}

inline ll qpow(ll a, int n) {
    ll ans = 1;
    for(; n; n >>= 1) {
        if(n & 1) ans *= 1ll * a;
        a *= a;
    }
    return ans;
}

template <class T>
inline bool chmax(T& a, T b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}

ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

ll ksc(ll a, ll b, ll mod) {
    ll ans = 0;
    for(; b; b >>= 1) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a * 2) % mod;
    }
    return ans;
}

ll ksm(ll a, ll b, ll mod) {
    ll ans = 1 % mod;
    a %= mod;

    for(; b; b >>= 1) {
        if (b & 1) ans = ksc(ans, a, mod);
        a = ksc(a, a, mod);
    }

    return ans;
}

template <class T>
inline bool chmin(T& a, T b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
    int n = a.size(), m = b.size();
    int i;
    for(i = 0; i < n && i < m; i++) {
        if (a[i] < b[i]) return true;
        else if (b[i] < a[i]) return false;
    }
    return (i == n && i < m);
}

// ============================================================== //

const int maxn = 1500000 + 10;
const int N = 1e5 + 10;
int n;
char str[N][15];

class Trie {
public:
    int t[maxn][10];
    int vis[maxn];
    int tot;
    Trie() {
        tot = 1;
        memset(t, 0, sizeof t);
        memset(vis, 0, sizeof vis);
    }
    void clear() {
        tot = 1;
        memset(t, 0, sizeof t);
        memset(vis, 0, sizeof vis);
    }

    void insert(const char *str) {
        int p = 1, len = strlen(str);
        for (int i = 0; i < len; i++) {
            int c = str[i]-'0';
            if (!t[p][c]) t[p][c] = ++tot;
            p = t[p][c];
        }
        vis[p]++;
    }

    bool query(const char *str) {
        int p = 1, len = strlen(str);
        for (int i = 0; i < len-1; i++) {
            int c = str[i]-'0';
            p = t[p][c];
            if (vis[p]) return false;
        }
        return true;
    }

} trie;

int main() {
    freopen("input.txt", "r", stdin);
    int T;
    cin >> T;
    while (T--) {
        trie.clear();
        scanf("%d", &n);

        for (int i = 0; i < n; i++) {
            scanf("%s", str[i]);
            trie.insert(str[i]);
        }

        bool ok = true;
        for (int i = 0; i < n; i++) {
            if (trie.query(str[i]) == false) {
                ok = false;
                break;
            }
        }
        ok == true ? puts("YES") : puts("NO");
    }
}