头像

心里没有一点AC数

兰州大学 - 网易浚源工作室 - 西山居 - $\href{https://www.fogsail.net}{Blog}$




离线:1天前


最近来访(1130)
用户头像
卡卡罗特ss
用户头像
Gzm1317
用户头像
吃饱喝足不学习
用户头像
LCY_67
用户头像
wmh123
用户头像
蓬蒿人
用户头像
LiuGuangZhou
用户头像
老老老帅比
用户头像
D_K_D
用户头像
yjn1187
用户头像
pvenlambda
用户头像
acwing_1000
用户头像
智障也有春天
用户头像
月色
用户头像
Dumby_cat
用户头像
一万小时定律
用户头像
cdm
用户头像
222222
用户头像
人生如戏ba
用户头像
2021263933


很迷,赛中用 vector 被卡了,一度以为算法出了问题,整半天都用dfs序 + 线段树硬是没搞出来
赛后用了普通的暴力 dfs 就过了,只不过换了一个 unordered_set,很迷

算法分析

  • 首先把图建出来,然后枚举删除哪一条边,不妨设删除了边 $(x, y)$
    那么这个时候 $x, y$ 分别形成独立的树,可以 $dfs(x), dfs(y)$ 分别求出以 $x, y$ 为根的子树
    所有点权值的异或和,不妨设为 $(sx, sy)$

  • 然后呢,去遍历 $x$ 的子树,再 $dfs$ 一遍,这样对于 $x$ 的子树,有如下情况
    $x$ 的某个子树 $u$,$u$ 的整棵子树的异或和,可以一边 $dfs$ 的时候一边返回
    不妨设为 $t$,这样原来的图就被分为三个部分 $(sy, t, sx \oplus t)$,按题意更新全局的最小值即可
    $y$ 的子树同理

#pragma GCC optimize(2)
const int maxn = 1000 + 10;
class Solution {
public:
    int minimumScore(vector<int>& nums, vector<vector<int>>& edges) {
        int m = edges.size(), n = nums.size();
        unordered_set<int> G[maxn];

        function<int(int, int)> dfs = [&](int x, int fa) {
            int res = nums[x];
            for (auto y : G[x]) {
                if (y == fa) continue;
                res ^= dfs(y, x);
            }
            return res;
        };

        int ans = 1e9;

        function<int(int, int, const int, const int, const int)> sub = [&](int x, int fa, const int root, const int sx, const int sy) {
            int t = nums[x];
            for (auto y : G[x]) {
                if (y == fa) continue;
                t ^= sub(y, x, root, sx, sy);
            }

            if (x != root) {
                int mx = max( {sy, sx ^ t, t} ), mn = min( {sy, sx ^ t, t} );
                ans = min(ans, mx - mn);
            } 

            return t;
        };

        for (int i = 0; i < m; i++) {
            int x = edges[i][0], y = edges[i][1];
            G[x].insert(y), G[y].insert(x);
        }


        for (int i = 0; i < m; i++) {
            // delete edges[i]
            int x = edges[i][0], y = edges[i][1];
            G[x].erase(y), G[y].erase(x);

            int sx = dfs(x, -1), sy = dfs(y, -1);
            int tot = sx ^ sy;

            // (sy, sub(sx), sx^sub(sx)), (sx, sub(sy), sy^sub(sy))
            // array<int, 3> help;
            sub(x, -1, x, sx, sy), sub(y, -1, y, sy, sx);

            G[x].insert(y), G[y].insert(x);
        }
        return ans;
    }
};



算法分析

本例中需要支持如下操作,在 $n \times m$ 的方阵中,如果让 $(x, y)$ 出方阵,那么

  • 找到标号第 $x$ 大的行,并且在该行中找到第 $y$ 大的列,将对应的元素 $(x, y)$ 取出
    然后把该行中 $[y+1\cdots m]$ 整体往前挪一位

  • 然后操作第 $m$ 列,将第 $m$ 列 $[x+1\cdots n]$ 整体往上挪一位

  • 最后把取出来的元素插入 $(n, m)$ 处

注意到我们需要对 $n$ 行,以及第 $m$ 列求前缀第 $k$ 大
可以考虑一开始建 $n+1$ 棵线段树,其中每一棵线段树初始只有一个根节点
这样第 $[1\cdots n]$ 棵线段树维护 $[1 \to n]$ 行,第 $n+1$ 棵线段树维护第 $m$ 列

将线段树写成动态开点,这样就可以尝试维护操作了,维护区间内被删除了多少个数 $cnt$

  • 对于删除 $(x, y)$,就是在第 $x$ 棵线段树 $tr(x)$ 中找到第 $y$ 大的数 $p$,然后返回 $p$
    如果要对应到方阵中的编号,那么就是 $id = (x-1)\cdot m + p$
    然后删除这个点,采用懒惰删除的方法,对线段树这个点标记为删除,即 $u.\text{cnt} + 1$

  • 值得注意的是,懒惰删除的时候,和一般线段树不一样,我们 $\text{pull}(u)$ 的时候
    因为是动态开点的,$u$ 的左子树或者右子树可能不存在,要加判一下子树存在
    以左子树为例,如果 $u.l \neq \text{null}$,那么直接 $u.\text{cnt} += u.l.\text{cnt}$
    否则的话左子树不存在,也就是说我们并没有修改过左子树对应的区间,$u.\text{cnt} += 0$

  • 修改的时候,比如要出列第 $y$ 个人,只需要把 $root \to y$ 路径上的点都开出来就可以
    唯一需要注意的是,递归子树的时候,如果子树是 $\textbf{null}$,那么要先 $\textbf{new}$ 出来节点然后递归

  • 接着考虑查询第 $k$ 大,同样需注意,递归子树遇到 $\textbf{null}$ 要先将其 $\textbf{new}$ 出来
    假设当前递归到 $(u, l, r)$,那么左子树中的元素个数是 $tot = mid - l + 1$
    如果左子树非空,那么 $tot$ 还要扣除掉已经删除的元素,即 $tot -= u.l.\text{cnt}$
    否则,说明左子树对应的区间我们没有修改过,自然对应的 $\text{cnt} = 0$

  • 然后执行很常见的二分查找逻辑,如果 $k \leqslant tot$,那么在左子树查找第 $k$ 大
    否则去右子树找第 $k - tot$ 大

综上所述

  • 通过线段树维护,可以得到要出列的点编号,假设是 $id = (x-1) \cdot m + y$
    可以用一个 $\text{vector}$ 存储被取出的点,$\text{vector}[x] \leftarrow {y}$ 表示
    第 $x$ 棵线段树被删除点的情况就存储到 $\text{vector}[x]$ 中

  • 接着是查询第 $x$ 大的下标 $p$,如果是叶子节点,就直接返回 $p = l$,否则二分查询左右子树

  • 接下来有一些坑点,注意到对 $x$ 行修改,得到的前 $y$ 大 $\in [1, m-1]$
    最后一列我们是单独拿出来处理的,如果 $p \geqslant m$,那么查询到的这个人是
    “向左看齐” 之后由最后一列第 $x$ 大填补空缺补充上来的

  • 由此对 $(x, y)$ 出列操作,还需要额外维护最后一列填补空缺的人的编号
    即在第 $n+1$ 棵线段树中查找第 $x$ 大,此时这个编号的人不需要出列,而是填补空缺

由此程序实现需要注意的

  • 对最后一列操作,即第 $n+1$ 棵线段树,维护两种情况

  • 第一种是 $solve(1)$
    这一列的第 $x$ 个人出列,此时 $\text{vector}[n+1] \leftarrow {x}$,这个人要重新插入到末尾

  • 另外一种是 $solve(0)$
    前 $m-1$ 列的人出队,假设说是 $(x, y)$,在第 $n+1$ 棵线段树找到第 $x$ 大之后
    这个人是要填补空缺的,不需要真正出列,不需要将其放在 $\text{vector}[n+1]$ 的末尾
    而是要放到 $\text{vector}[x]$ 的末尾

  • 对其他列操作,加入是 $(x, y)$,在第 $x$ 棵线段树中找到第 $y$ 大的编号 $id$,做两件事情
    第一是在将 $id$ 加入到第 $n+1$ 棵线段树对应的 $\text{vector}$ 末尾,即 $\text{vector}[n+1] \leftarrow {id}$
    第二就是填补空缺,在第 $n+1$ 棵线段树中找到第 $x$ 大的编号 $id2$,$\text{vector}[x] \leftarrow {id2}$

这样所有信息都维护好了

  • 如果对第 $m$ 列操作,查询第 $x$ 大的下标是 $p$,$p \leqslant n$ 直接返回 $p$
    否则是出列过之后又进来的,返回 $\text{vector}[n+1][p-n-1]$

  • 如果对让 $(x, y)$ 出列,$x \leqslant m-1$,那么同样查询第 $y$ 大的下标 $p$
    $p \leqslant m-1$ 直接返回,否则返回 $\text{vector}[x][p-m]$\

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>
#include <functional>
#include <bit>
#include <random>
#include <numeric>

using namespace std;
typedef long long ll;

#define debug(x) cout << #x << ": " << x << endl
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define lowbit(i) (i & (-i))
#define fill_v(f, v) fill(f.begin(), f.end(), v)

template<class T> 
inline void debug_v(const vector<T> &vec) {
    printf("vec: ");
    for (auto u : vec) cout << u << " ";
    cout << endl;
}

template<class T>
inline int cntOne(const T x) {
    bitset<64> res(x);
    return res.count();
}

inline ll qpow(ll a, int n) {
    ll ans = 1;
    for(; n; n >>= 1) {
        if(n & 1) ans *= 1ll * a;
        a *= a;
    }
    return ans;
}

template <class T>
inline bool chmax(T& a, T b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}

ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

ll ksc(ll a, ll b, ll mod) {
    ll ans = 0;
    for(; b; b >>= 1) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a * 2) % mod;
    }
    return ans;
}

ll ksm(ll a, ll b, ll mod) {
    ll ans = 1 % mod;
    a %= mod;

    for(; b; b >>= 1) {
        if (b & 1) ans = ksc(ans, a, mod);
        a = ksc(a, a, mod);
    }

    return ans;
}

template <class T>
inline bool chmin(T& a, T b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
    int n = a.size(), m = b.size();
    int i;
    for(i = 0; i < n && i < m; i++) {
        if (a[i] < b[i]) return true;
        else if (b[i] < a[i]) return false;
    }
    return (i == n && i < m);
}

constexpr int P = 1e9 + 7;

int norm(int x) {
    if (x < 0) x += P;
    if (x >= P) x -= P;
    return x;
}

template<class T>
T power(T a, int b) {
    T res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res *= a;
        a *= a;
    }
    return res;
}

struct mint {
    int x;
    mint(int x = 0) : x(norm(x)) {}

    int val() const {
        return x;
    }
    mint operator-() const {
        return mint(norm(P-x));
    }
    mint &operator *= (const mint &rhs) {
        x = (ll)(x) * rhs.x % P;
        return *this;
    }
    mint &operator += (const mint &rhs) {
        x = norm(x + rhs.x);
        return *this;
    }
    mint &operator -= (const mint &rhs) {
        x = norm(x - rhs.x);
        return *this;
    }
    mint &operator /= (const mint &rhs) {
        return *this *= rhs.inv();
    }
    mint inv() const {
        assert(x != 0);
        return power(*this, P-2);
    }
    friend mint operator* (const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res *= rhs;
        return res;
    }
    friend mint operator+ (const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res += rhs;
        return res;
    }
    friend mint operator- (const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res -= rhs;
        return res;
    }
    friend mint operator/ (const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res /= rhs;
        return res;
    }
};

struct Int {
    static constexpr int B = 10;
    vector<int> X;
    int size() const {
        return (int)X.size();
    }

    Int(int x = 0) {
        while (x) {
            X.push_back(x % B), x /= B;
        }
    }

    Int(string str) {
        reverse(str.begin(), str.end());
        for (auto u : str) X.push_back(u-'0');
    }

    friend Int operator+ (const Int &lhs, const Int &rhs) {
        if (lhs.size() < rhs.size()) return rhs + lhs;
        Int res;

        int t = 0;
        for (int i = 0; i < lhs.size(); i++) {
            t += lhs.X[i];
            if (i < rhs.size()) t += rhs.X[i];
            res.X.push_back(t % B), t /= B;
        }
        if (t) res.X.push_back(t);
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator- (const Int &lhs, const Int &rhs) {
        Int res;
        int t = 0;
        for (int i = 0; i < lhs.size(); i++) {
            t = lhs.X[i] - t;
            if (i < rhs.size()) t -= rhs.X[i];
            res.X.push_back((t + B) % B);

            if (t < 0) t = 1;
            else t = 0;
        }
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator* (const Int &lhs, int b) {
        Int res;
        int t = 0;
        for (int i = 0; i < lhs.X.size() || t; i++) {
            if (i < lhs.X.size()) t += lhs.X[i] * b;
            res.X.push_back(t % B), t /= B;
        }
        return res;
    }

    friend Int operator/ (const Int &lhs, int b) {
        Int res;
        int r = 0;
        for (int i = lhs.X.size()-1; i >= 0; i--) {
            r = r * B + lhs.X[i];
            res.X.push_back(r / b), r %= b;
        }
        reverse(res.X.begin(), res.X.end());
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator* (const Int &lhs, const Int &rhs) {
        Int res;
        res.X.resize(lhs.size() + rhs.size() + B);
        fill(res.X.begin(), res.X.end(), 0);

        for (int i = 0; i < lhs.size(); i++) {
            for (int j = 0; j < rhs.size(); j++) {
                res.X[i+j] += lhs.X[i] * rhs.X[j];
            }
        }
        for (int i = 0; i < res.X.size(); i++) {
            if (res.X[i] >= B) res.X[i+1] += res.X[i] / B, res.X[i] %= B;
        }

        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator/ (const Int& lhs, const Int &rhs) {
        int dv = lhs.size() - rhs.size();
        assert(dv >= 0);

        Int res;
        res.X.resize(dv+1);
        fill(res.X.begin(), res.X.end(), 0);

        // append suffix zero
        Int a = lhs, b = rhs;
        reverse(b.X.begin(), b.X.end());
        for (int i = 0; i < dv; i++) b.X.push_back(0);
        reverse(b.X.begin(), b.X.end());

        for (int i = 0; i <= dv; i++) {
            while (b < a) {
                a = a-b;
                res.X[dv-i]++;
            }
            b.X.erase(b.X.begin());
        }
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend bool operator< (const Int &lhs, const Int &rhs) {
        if (lhs.size() < rhs.size()) return true;
        if (lhs.size() > rhs.size()) return false;

        if (vector<int>(lhs.X.rbegin(), lhs.X.rend()) <= vector<int>(rhs.X.rbegin(), rhs.X.rend())) return true;
        return false;
    }
    void out() {
        if (X.size() == 0) {
            puts("0");
            return;
        }
        reverse(X.begin(), X.end());
        for (auto x : X) printf("%d", x);
        printf("\n");
    }
};

Int max_int(const Int &A, const Int &B) {
    if (A < B) return B;
    else return A;
}

int get_root(int P) {
    function<vector<int>(int x)> divide = [&](int x) {
        vector<int> primes;
        for (int i = 2; i <= sqrt(x); i++) {
            if (x % i) continue;

            primes.push_back(i);
            while (x % i == 0) x /= i;
        }
        if (x > 1) primes.push_back(x);

        return primes;
    };

    vector<int> pr = divide(P-1);

    for (ll g = 2; g <= P-1; g++) {
        bool ok = true;

        for (auto p : pr) {
            if (ksm(g, (1LL * P-1)/p, P) == 1) {
                ok = false; break;
            }
        }

        if (ok) return g;
    }
    return -1;
}

namespace NTT {
    const int G = 3;

    vector<int> rev;
    void ntt(vector<mint> &a, int op) {
        int n = a.size();

        if ((int)rev.size() != n) {
            rev.resize(n);
            // int k = __builtin_ctz(n);
            //int k = countr_zero((unsigned int)n);
           int k = 0; 

            for (int i = 0; i < n; i++) {
                rev[i] = rev[i>>1] >> 1 | (i&1) << (k-1);
            }
        }

        for (int i = 0; i < n; i++) {
            if (rev[i] < i) swap(a[i], a[rev[i]]);
        }

        // swap
        for (int mid = 1; mid < n; mid <<= 1) {
            mint gn = power(mint(G), (P - 1) / (mid << 1));
            if (op == -1) gn = gn.inv();

            for (int i = 0; i < n; i += mid * 2) {
                mint gnk = 1;

                for (int j = 0; j < mid; j++) {
                    mint u = a[i+j], v = gnk * a[i+mid+j];
                    a[i+j] = u + v, a[i+mid+j] = u - v;
                    gnk = gnk * gn;
                }
            }
        }

        if (op == -1) {
            mint inv = mint((int)a.size()).inv();
            for (int i = 0; i < n; i++) {
                a[i] *= inv;
            }
        }
    }

    void dft(vector<mint> &a) {
        ntt(a, 1);
    }

    void idft(vector<mint> &a) {
        ntt(a, -1);
    }
};

struct Poly {
    vector<mint> a;
    Poly() {}
    Poly(const vector<mint> &a) : a(a) {}
    Poly(const initializer_list<mint> &a) : a(a) {}
    int size() const {
        return a.size();
    }
    void resize(int n) {
        a.resize(n);
    }

    mint operator[] (int idx) const {
        if (idx < 0 || idx >= size()) {
            return 0;
        }
        return a[idx];
    }
    mint& operator[] (int idx) {
        return a[idx];
    }

    Poly mulxk(int k) const {
        auto b = a;
        b.insert(b.begin(), k, 0);
        return Poly(b);
    }
    Poly modxk(int k) const {
        k = min(k, size());
        return Poly(vector<mint>(a.begin(), a.begin() + k));
    }
    Poly divxk(int k) const {
        if (size() <= k) {
            return Poly();
        }
        return Poly(vector<mint>(a.begin() + k, a.end()));
    }

    friend Poly operator+ (const Poly &a, const Poly &b) {
        vector<mint> res(max(a.size(), b.size()));
        for (int i = 0; i < (int)res.size(); i++) {
            res[i] = a[i] + b[i];
        }
        return Poly(res);
    }

    friend Poly operator- (const Poly &a, const Poly &b) {
        vector<mint> res(max(a.size(), b.size()));
        for (int i = 0; i < (int)res.size(); i++) {
            res[i] = a[i] - b[i];
        }
        return Poly(res);
    }

    friend Poly operator* (Poly a, Poly b) {
        using namespace NTT;

        if (a.size() == 0 || b.size() == 0) {
            return Poly();
        }

        int sz = 1, tot = a.size() + b.size() - 1;
        while (sz < tot) sz *= 2;

        a.a.resize(sz), b.a.resize(sz);
        dft(a.a), dft(b.a);

        for (int i = 0; i < sz; i++) {
            a.a[i] = a[i] * b[i];
        }

        idft(a.a);
        a.resize(tot);
        return a;
    }

    friend Poly operator* (Poly a, mint b) {
        for (int i = 0; i < (int)a.size(); i++) {
            a[i] *= b;
        }
        return a;
    }

    friend Poly operator* (mint a, Poly b) {
        for (int i = 0; i < (int)b.size(); i++) {
            b[i] *= a;
        }
        return b;
    }

    Poly &operator+= (Poly b) {
        return (*this) = (*this) + b;
    }
    Poly &operator-= (Poly b) {
        return (*this) = (*this) - b;
    }
    Poly &operator*= (Poly b) {
        return (*this) = (*this) * b;
    }

    Poly deriv() const {
        if (a.empty()) return Poly();

        vector<mint> res(size() - 1);
        for (int i = 0; i < size()-1; i++) {
            res[i] = (i + 1) * a[i + 1];
        }

        return Poly(res);
    }

    Poly integr() const {
        vector<mint> res(size() + 1);
        for (int i = 0; i < size(); i++) {
            res[i + 1] = a[i] / (i + 1);
        }
        return Poly(res);
    }
};

mt19937_64 mrand(random_device{}());
int rnd(int x) {
    return mrand() % x;
}

// ============================================================== //

struct info {
    info *l = nullptr, *r = nullptr;
    ll cnt;

    bool leaf() const {
        return !l && !r;
    }

    // delete child
    info& operator--() {
        if (l) delete l;
        if (r) delete r;
        l = r = nullptr;

        return *this;
    }

    // update info
    info& operator++() {
        int now = 0;
        if (l != nullptr) now += l->cnt;
        if (r != nullptr) now += r->cnt;
        cnt = now;
        return *this;
    }

    void setval(int val) {
        cnt += val;
    }

    ~info() {
        --(*this);
    }
};

// 动态开点线段树,两点注意
// 对一个区间 setval 的时候,它的所有值都相同,这个时候删除子节点,调用 p--
// 自顶向下递归的时候,如果没有遇到目标区间 [l, r],那么要 push 开点
// push 和普通线段树比,动态开点的 push 一般是要新建节点的

// 另外如果区间所有值都相等,那么这个区间用 leaf 标记
struct SegTree {
    info* root;

    SegTree() {
        root = new info();
    }
    ~SegTree() {
        if (root != nullptr) {
            --(*root);
            delete root;
        }
    }

    void pull(info &p) {
        ++p;
    }

    // 节点 p 对应区间是 [l, r],将 a[pos] -> val

    void change(info &p, int l, int r, int x, int val) {
        if (l >= r) {
            p.setval(val);
            return;
        }

        int mid = (l + r) >> 1;
        if (x <= mid) {
            if (!p.l) p.l = new info();
            change(*(p.l), l, mid, x, val);
        }
        else {
            if (!p.r) p.r = new info();
            change(*(p.r), mid+1, r, x, val);
        }

        pull(p);
    }

    int query(info &p, int l, int r, int x) {
        if (l >= r) return l;
        int mid = (l + r) >> 1;

        int tot = mid - l + 1;
        if (p.l) tot -= p.l->cnt;

        // 递归下去的时候动态开点
        if (x <= tot) {
            if (!p.l) p.l = new info();
            return query(*p.l, l, mid, x);
        }
        else {
            if (!p.r) p.r = new info();
            return query(*p.r, mid+1, r, x-tot);
        }
    }
};

// usage: SegTree<4 * maxn> tr(a, n)

int n, m, q, lim;
const int maxn = 3e5 + 5;
vector<vector<ll> > extra(maxn);

vector<SegTree> tr(maxn);

ll solve0(int x, bool del) {
    // tr[n+1] xth element
    // idx = 0 only operate col m
    // idx = other ? (x, y) insert to (n, m)

    int p = tr[n+1].query(*tr[n+1].root, 1, lim, x);
    tr[n+1].change(*tr[n+1].root, 1, lim, p, 1);

    ll ans = 0;
    if (p <= n) ans = 1LL * p * m;
    else ans = extra[n+1][p-n-1];

    if (del) extra[n+1].push_back(ans);

    return ans;
}

ll solve1(int x, int y) {
    int p = tr[x].query(*tr[x].root, 1, lim, y);
    tr[x].change(*tr[x].root, 1, lim, p, 1);

    ll ans = 0;
    if (p < m) ans = 1LL * (x-1) * m + p;
    else ans = extra[x][p-m];
    extra[n+1].push_back(ans);

    ll now = solve0(x, 0);
    extra[x].push_back(now);

    return ans;
}

// usage: SegTree<4 * maxn> tr(a, n)

int main() {
    //freopen("input.txt", "r", stdin);

    cin >> n >> m >> q;
    lim = max(n, m) + q;

    while (q--) {
        int x, y;
        scanf("%d%d", &x, &y);

        ll ans = 0;
        if (y == m) ans = solve0(x, 1);
        else ans = solve1(x, y);

        printf("%lld\n", ans);
    }
}



算法设计
有 $n$ 个位置需要染色,可以考虑如下分组,先从 $m$ 种颜色中任选 $k$ 种,每一种元素是一组,然后剩下的颜色分一组
对于 $[1\cdots k]$ 组,每一组放 $s$ 个元素,表示 $k$ 种颜色,每一种都染了 $s$ 个位置
那么剩下的呢?剩下 $(n - ks)$ 个位置呢?怎么染色?
实际上,剩下的位置可以染剩下的 $(m - k)$ 种颜色中的任意一种,这样我们可以得到一个初步的表达式

$\displaystyle \binom{m}{k} \cdot (m-k)^{n-ks} \cdot \frac{n!}{(s!)^{k} \cdot (n - ks)!}$

其中 $\displaystyle \frac{n!}{(s!)^k \cdot (n - ks)!}$ 表示分组排列的方案数

但是,这并不是我们需要的,题目中要求 $k$ 种颜色恰好出现了 $s$ 次
上面的表达式可以保证一定有 $k$ 种颜色出现了 $s$ 次,但是对于剩下的 $(m - k)$ 种颜色呢?
对于未染色的 $(n - ks)$ 个位置,我们是从剩下的 $(m - k)$ 种颜色中任意选取,所以
还可能存在 $k$ 种颜色之外的其他颜色,也染了 $s$ 个位置
上面的表达式计算,求出的出现 $s$ 次的颜色种类 $\geqslant k$ 种,不妨记为 $f(k)$

$\displaystyle f(k) = \binom{m}{k} \cdot (m-k)^{n-ks} \cdot \frac{n!}{(s!)^{k} \cdot (n - ks)!}$

$f(k)$ 表示出现 $s$ 次的颜色至少有 $k$ 种的方案数

我们需要的是出现 $s$ 次的颜色恰好有 $k$ 种的方案,不妨记为 $g(k)$,那么出现 $s$ 次的颜色最多有多少种呢?
不难想到最多有 $\displaystyle N = \min \left(m, \left\lfloor \frac{n}{s} \right\rfloor \right)$

$\displaystyle f(k) = \sum_{i = k}^{N} g(i)$,只可惜这样还是不对

已知出现 $s$ 次的颜色恰好有 $i$ 种,对应的方案数为 $g(i)$,但 $f(k)$ 对应的是
有 $k$ 种颜色一定要出现 $s$ 次,剩下的没有限制,仅仅把 $g(i)$ 加起来会漏解

实际上,对于恰好有 $i$ 种颜色出现 $s$ 次,方案数为 $g(i)$,那么推 $f(k)$ 时应该这样考虑
$i$ 种颜色中 $\displaystyle \binom{i}{k}$ 任意选择 $k$ 种颜色一定出现 $s$ 次,剩下 $(i-k)$ 种颜色出现次数没有限制

$\displaystyle f(k) = \sum_{i = k}^{N} g(i) \cdot \binom{i}{k}$

接下来考虑如何计算,我们需要知道的是 $g(k)$,最后的答案是 $\displaystyle \sum_{k = 0}^{N} (g(k) \cdot w_k)$

$\displaystyle f(k) = \sum_{i = k}^{N} \binom{i}{k} \cdot g(i)$,二项式反演可以得到

$\displaystyle g(k) = \sum_{i = k}^{N} (-1)^{i - k} \binom{i}{k} \cdot f(i)$

化简可得
$\displaystyle k! g(k) = \sum_{i = k}^{N} \left(\frac{(-1)^{i - k}}{(i-k)!} \right) \left(i! f(i) \right)$

令 $A(i) = i! f(i)$,$B(i) = \displaystyle \frac{(-1)^i}{i!}$

那么 $\displaystyle g(k) = \frac{1}{k!} \sum_{i = k}^{N} A(i)B(i-k)$

下面考虑化简该式
$g(k) = \displaystyle \frac{1}{k!} \sum_{k \leqslant i \leqslant N} A(i)B(i-k)$ 并不是标准的卷积形式

做变量替换,令 $i’ = i - k$,那么变量范围为 $0 \leqslant i’ \leqslant N-k$

$\displaystyle g(k) = \frac{1}{k!} \sum_{0 \leqslant i \leqslant N-k} A(i+k)B(i)$,但可惜这也不是标准卷积形式

令 $A(i+k) = A’(N-k-i)$,这样 $\displaystyle g(k) = \frac{1}{k!} \sum_{0 \leqslant i \leqslant N-k} A’(N-k-i)B(i)$

这就是标准的卷积形式,只要做变换 $A’(x) = A(N-x)$,这样就可以得到

$\displaystyle g(k) = \frac{1}{k!} \sum_{0 \leqslant i \leqslant N-k} A’(N-k-i)B(i)$,用 $\text{NTT}$ 求卷积

int N, n, m, s;
int w[maxm];

void solve() {
    N = min(m, n/s);
    vector<mint> f(N+1, 0);
    for (int i = 0; i <= N; i++) f[i] = binom(m, i) * power(mint(m-i), n-s*i) * fac[n] * power(infac[s], i) * (infac[n-s*i]);
    vector<mint> _a(N+1, 0), _b(N+1, 0);
    for (int i = 0; i <= N; i++) _a[i] = fac[i] * f[i];
    for (int i = 0; i <= N; i++) _b[i] = ( (i & 1) ? (-1) : 1 ) * infac[i];
    for (int i = 0; i <= N; i++) if (i < N - i) swap(_a[i], _a[N-i]);

    Poly A(_a), B(_b);
    Poly C = A * B;

    mint ans = 0;
    for (int k = 0; k <= N; k++) ans = ans + mint(w[k]) * infac[k] * C[N-k];
    printf("%d\n", ans.x);
}

二项式反演的推导(1)

$\displaystyle f_n = \sum_{i = 0}^n \binom{n}{i} g_i \Leftrightarrow g_n = \sum_{i = 0}^{n} (-1)^{n-i} \binom{n}{i} f_i$
这个可以用生成函数来推

$\displaystyle \frac{f_n}{n!} = \sum_{i = 0}^{n} \frac{g_i}{i!} \frac{1}{(n - i)!}$,可以知道

$\displaystyle F = \left<\frac{f_n}{n!}\right>$ 的 $\text{EGF}$ 为 $\displaystyle F_i = \sum_{i = 0}^{\infty} f_ix^i$

$\displaystyle G = \left<\frac{g_n}{n!} \right>$ 的 $\text{EGF}$ 为 $\displaystyle G_i = \sum_{i = 0}^{\infty} g_ix^i$

根据 $\displaystyle e^x = \sum_{i = 0}^{\infty} \frac{x^i}{i!}, \quad e^{-x} = \sum_{i = 0}^{\infty} (-1)^i \frac{x^i}{i!}$

我们有 $F = G \cdot e^{x}$,由此 $G = F \cdot e^{-x}$,根据生成函数的系数展开

$\displaystyle \frac{g_n}{n!} = \sum_{i = 0}^{n} \frac{f_i}{i!} \cdot (-1)^{n-i} \frac{1}{(n-i)!}$,证毕

二项式反演的推导(2)

$\displaystyle f_n = \sum_{i = 0}^{n} (-1)^i \binom{n}{i}g_i \Longleftrightarrow g_n = \sum_{i = 0}^{n} (-1)^i\binom{n}{i}f_i$

这里使用待定系数的方法来构造

$\displaystyle g_n = \sum_{i = 0}^{n} t(n, i)\cdot f_i, \quad t(n, i)$ 为待定系数

$\displaystyle f_n = \sum_{i = 0}^{n} (-1)^i \binom{n}{i} \sum_{j = 0}^{i} t(i, j)\cdot f_j = \sum_{j = 0}^{n} f_j \sum_{i = j}^{n} (-1)^i \binom{n}{i}t(i, j)$

令 $i’ = i - j, \quad i = i’+j$ 代换

$\displaystyle f_n = \sum_{j = 0}^{n} f_j \sum_{i = 0}^{n - j} (-1)^{i+j} \binom{n}{i+j} t(i+j, j)$

注意到 $j = n$ 的时候,$f_n = f_n$,于是我们必须构造出

$\displaystyle [j = n] = \sum_{i = 0}^{n - j} (-1)^{i+j} \binom{n}{i+j} t(i+j, j) \quad \textbf{(1)}$

注意到 $(1-1)^n$ 的展开,有 $\displaystyle [n = 0] = \sum_{i = 0}^{n}(-1)^i \binom{n}{i}$

$\displaystyle [j = n] = \sum_{i = 0}^{n - j} (-1)^i \binom{n - j}{i} \quad \textbf{(2)}$

$\textbf{(1), (2)}$ 对比,可以知道 $t(i+j, j)$ 中有组合数形式,我们配凑

$\displaystyle [j = n] = \sum_{i = 0}^{n - j} (-1)^i \binom{n-j}{i} \binom{n}{j} = \sum_{i = 0}^{n - j} (-1)^i \binom{n}{n - j} \binom{n-j}{i}$

$\displaystyle \quad \quad = \sum_{i = 0}^{n - j} (-1)^i \binom{n}{i} \binom{n-i}{n-j-i} = \sum_{i = 0}^{n - j} (-1)^i \binom{n}{i} \binom{n-i}{j}$

其中,我们根据恒等式 $\displaystyle \binom{n}{i+j} \binom{i+j}{i} = \binom{n}{i} \binom{n-i}{j}$

所以有 $\displaystyle [j = n] = \sum_{i = 0}^{n-j} (-1)^i \binom{n}{i+j} \binom{i+j}{i}$,从而有

$\displaystyle t(i+j, j) = (-1)^j \cdot \binom{i+j}{i} = (-1)^j \binom{i+j}{j}$,从而有

$\displaystyle t(i, j) = (-1)^j \binom{i}{j}$

完整代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>
#include <functional>
#pragma GCC optimize(2)


using namespace std;
typedef long long ll;

#define debug(x) cout << #x << ": " << x << endl
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define lowbit(i) (i & (-i))
#define fill_v(f, v) fill(f.begin(), f.end(), v)

template<class T> 
inline void debug_v(const vector<T> &vec) {
    printf("vec: ");
    for (auto u : vec) cout << u << " ";
    cout << endl;
}

template<class T>
inline int cntOne(const T x) {
    bitset<64> res(x);
    return res.count();
}

inline ll qpow(ll a, int n) {
    ll ans = 1;
    for(; n; n >>= 1) {
        if(n & 1) ans *= 1ll * a;
        a *= a;
    }
    return ans;
}

template <class T>
inline bool chmax(T& a, T b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}

ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

ll ksc(ll a, ll b, ll mod) {
    ll ans = 0;
    for(; b; b >>= 1) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a * 2) % mod;
    }
    return ans;
}

ll ksm(ll a, ll b, ll mod) {
    ll ans = 1 % mod;
    a %= mod;

    for(; b; b >>= 1) {
        if (b & 1) ans = ksc(ans, a, mod);
        a = ksc(a, a, mod);
    }

    return ans;
}

template <class T>
inline bool chmin(T& a, T b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
    int n = a.size(), m = b.size();
    int i;
    for(i = 0; i < n && i < m; i++) {
        if (a[i] < b[i]) return true;
        else if (b[i] < a[i]) return false;
    }
    return (i == n && i < m);
}

constexpr int P = 1004535809;

int norm(int x) {
    if (x < 0) x += P;
    if (x >= P) x -= P;
    return x;
}

template<class T>
T power(T a, int b) {
    T res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res *= a;
        a *= a;
    }
    return res;
}

struct mint {
    int x;
    mint(int x = 0) : x(norm(x)) {}

    int val() const {
        return x;
    }
    mint operator-() const {
        return mint(norm(P-x));
    }
    mint &operator *= (const mint &rhs) {
        x = (ll)(x) * rhs.x % P;
        return *this;
    }
    mint &operator += (const mint &rhs) {
        x = norm(x + rhs.x);
        return *this;
    }
    mint &operator -= (const mint &rhs) {
        x = norm(x - rhs.x);
        return *this;
    }
    mint &operator /= (const mint &rhs) {
        return *this *= rhs.inv();
    }
    mint inv() const {
        assert(x != 0);
        return power(*this, P-2);
    }
    friend mint operator* (const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res *= rhs;
        return res;
    }
    friend mint operator+ (const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res += rhs;
        return res;
    }
    friend mint operator- (const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res -= rhs;
        return res;
    }
    friend mint operator/ (const mint &lhs, const mint &rhs) {
        mint res = lhs;
        res /= rhs;
        return res;
    }
};

struct Int {
    static constexpr int B = 10;
    vector<int> X;
    int size() const {
        return (int)X.size();
    }

    Int(int x = 0) {
        while (x) {
            X.push_back(x % B), x /= B;
        }
    }

    Int(string str) {
        reverse(str.begin(), str.end());
        for (auto u : str) X.push_back(u-'0');
    }

    friend Int operator+ (const Int &lhs, const Int &rhs) {
        if (lhs.size() < rhs.size()) return rhs + lhs;
        Int res;

        int t = 0;
        for (int i = 0; i < lhs.size(); i++) {
            t += lhs.X[i];
            if (i < rhs.size()) t += rhs.X[i];
            res.X.push_back(t % B), t /= B;
        }
        if (t) res.X.push_back(t);
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator- (const Int &lhs, const Int &rhs) {
        Int res;
        int t = 0;
        for (int i = 0; i < lhs.size(); i++) {
            t = lhs.X[i] - t;
            if (i < rhs.size()) t -= rhs.X[i];
            res.X.push_back((t + B) % B);

            if (t < 0) t = 1;
            else t = 0;
        }
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator* (const Int &lhs, int b) {
        Int res;
        int t = 0;
        for (int i = 0; i < lhs.X.size() || t; i++) {
            if (i < lhs.X.size()) t += lhs.X[i] * b;
            res.X.push_back(t % B), t /= B;
        }
        return res;
    }

    friend Int operator/ (const Int &lhs, int b) {
        Int res;
        int r = 0;
        for (int i = lhs.X.size()-1; i >= 0; i--) {
            r = r * B + lhs.X[i];
            res.X.push_back(r / b), r %= b;
        }
        reverse(res.X.begin(), res.X.end());
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator* (const Int &lhs, const Int &rhs) {
        Int res;
        res.X.resize(lhs.size() + rhs.size() + B);
        fill(res.X.begin(), res.X.end(), 0);

        for (int i = 0; i < lhs.size(); i++) {
            for (int j = 0; j < rhs.size(); j++) {
                res.X[i+j] += lhs.X[i] * rhs.X[j];
            }
        }
        for (int i = 0; i < res.X.size(); i++) {
            if (res.X[i] >= B) res.X[i+1] += res.X[i] / B, res.X[i] %= B;
        }

        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator/ (const Int& lhs, const Int &rhs) {
        int dv = lhs.size() - rhs.size();
        assert(dv >= 0);

        Int res;
        res.X.resize(dv+1);
        fill(res.X.begin(), res.X.end(), 0);

        // append suffix zero
        Int a = lhs, b = rhs;
        reverse(b.X.begin(), b.X.end());
        for (int i = 0; i < dv; i++) b.X.push_back(0);
        reverse(b.X.begin(), b.X.end());

        for (int i = 0; i <= dv; i++) {
            while (b < a) {
                a = a-b;
                res.X[dv-i]++;
            }
            b.X.erase(b.X.begin());
        }
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend bool operator< (const Int &lhs, const Int &rhs) {
        if (lhs.size() < rhs.size()) return true;
        if (lhs.size() > rhs.size()) return false;

        if (vector<int>(lhs.X.rbegin(), lhs.X.rend()) <= vector<int>(rhs.X.rbegin(), rhs.X.rend())) return true;
        return false;
    }
    void out() {
        if (X.size() == 0) {
            puts("0");
            return;
        }
        reverse(X.begin(), X.end());
        for (auto x : X) printf("%d", x);
        printf("\n");
    }
};

Int max_int(const Int &A, const Int &B) {
    if (A < B) return B;
    else return A;
}

// ============================================================== //


int get_root(int P) {
    function<vector<int>(int x)> divide = [&](int x) {
        vector<int> primes;
        for (int i = 2; i <= sqrt(x); i++) {
            if (x % i) continue;

            primes.push_back(i);
            while (x % i == 0) x /= i;
        }
        if (x > 1) primes.push_back(x);

        return primes;
    };

    vector<int> pr = divide(P-1);

    for (ll g = 2; g <= P-1; g++) {
        bool ok = true;

        for (auto p : pr) {
            if (ksm(g, (1LL * P-1)/p, P) == 1) {
                ok = false; break;
            }
        }

        if (ok) return g;
    }
    return -1;
}

namespace NTT {
    const int G = 3;

    vector<int> rev;
    void ntt(vector<mint> &a, int op) {
        int n = a.size();

        if ((int)rev.size() != n) {
            rev.resize(n);
            int k = __builtin_ctz(n);

            for (int i = 0; i < n; i++) {
                rev[i] = rev[i>>1] >> 1 | (i&1) << (k-1);
            }
        }

        for (int i = 0; i < n; i++) {
            if (rev[i] < i) swap(a[i], a[rev[i]]);
        }

        // swap
        for (int mid = 1; mid < n; mid <<= 1) {
            mint gn = power(mint(G), (P - 1) / (mid << 1));
            if (op == -1) gn = gn.inv();

            for (int i = 0; i < n; i += mid * 2) {
                mint gnk = 1;

                for (int j = 0; j < mid; j++) {
                    mint u = a[i+j], v = gnk * a[i+mid+j];
                    a[i+j] = u + v, a[i+mid+j] = u - v;
                    gnk = gnk * gn;
                }
            }
        }

        if (op == -1) {
            mint inv = mint((int)a.size()).inv();
            for (int i = 0; i < n; i++) {
                a[i] *= inv;
            }
        }
    }

    void dft(vector<mint> &a) {
        ntt(a, 1);
    }

    void idft(vector<mint> &a) {
        ntt(a, -1);
    }
};

struct Poly {
    vector<mint> a;
    Poly() {}
    Poly(const vector<mint> &a) : a(a) {}
    Poly(const initializer_list<mint> &a) : a(a) {}
    int size() const {
        return a.size();
    }
    void resize(int n) {
        a.resize(n);
    }

    mint operator[] (int idx) const {
        if (idx < 0 || idx >= size()) {
            return 0;
        }
        return a[idx];
    }
    mint& operator[] (int idx) {
        return a[idx];
    }

    Poly mulxk(int k) const {
        auto b = a;
        b.insert(b.begin(), k, 0);
        return Poly(b);
    }
    Poly modxk(int k) const {
        k = min(k, size());
        return Poly(vector<mint>(a.begin(), a.begin() + k));
    }
    Poly divxk(int k) const {
        if (size() <= k) {
            return Poly();
        }
        return Poly(vector<mint>(a.begin() + k, a.end()));
    }

    friend Poly operator+ (const Poly &a, const Poly &b) {
        vector<mint> res(max(a.size(), b.size()));
        for (int i = 0; i < (int)res.size(); i++) {
            res[i] = a[i] + b[i];
        }
        return Poly(res);
    }

    friend Poly operator- (const Poly &a, const Poly &b) {
        vector<mint> res(max(a.size(), b.size()));
        for (int i = 0; i < (int)res.size(); i++) {
            res[i] = a[i] - b[i];
        }
        return Poly(res);
    }

    friend Poly operator* (Poly a, Poly b) {
        using namespace NTT;

        if (a.size() == 0 || b.size() == 0) {
            return Poly();
        }

        int sz = 1, tot = a.size() + b.size() - 1;
        while (sz < tot) sz *= 2;

        a.a.resize(sz), b.a.resize(sz);
        dft(a.a), dft(b.a);

        for (int i = 0; i < sz; i++) {
            a.a[i] = a[i] * b[i];
        }

        idft(a.a);
        a.resize(tot);
        return a;
    }

    friend Poly operator* (Poly a, mint b) {
        for (int i = 0; i < (int)a.size(); i++) {
            a[i] *= b;
        }
        return a;
    }

    friend Poly operator* (mint a, Poly b) {
        for (int i = 0; i < (int)b.size(); i++) {
            b[i] *= a;
        }
        return b;
    }

    Poly &operator+= (Poly b) {
        return (*this) = (*this) + b;
    }
    Poly &operator-= (Poly b) {
        return (*this) = (*this) - b;
    }
    Poly &operator*= (Poly b) {
        return (*this) = (*this) * b;
    }

    Poly deriv() const {
        if (a.empty()) return Poly();

        vector<mint> res(size() - 1);
        for (int i = 0; i < size()-1; i++) {
            res[i] = (i + 1) * a[i + 1];
        }

        return Poly(res);
    }

    Poly integr() const {
        vector<mint> res(size() + 1);
        for (int i = 0; i < size(); i++) {
            res[i + 1] = a[i] / (i + 1);
        }
        return Poly(res);
    }
};

const int maxn = 1e7 + 5;
const int maxm = 1e5 + 5;

mint inv[maxn], fac[maxn], infac[maxn];

void pre() {
    inv[0] = inv[1] = 1;
    for (int i = 2; i < maxn; i++) inv[i] = mint(P - P/i) * inv[P%i];

    fac[0] = infac[0] = 1;
    for (int i = 1; i < maxn; i++) fac[i] = fac[i-1] * mint(i);
    for (int i = 1; i < maxn; i++) infac[i] = infac[i-1] * inv[i];
}

inline mint binom(int x, int y) {
    if (x < y) return mint(0);
    return fac[x] * infac[x-y] * infac[y];
}

int N, n, m, s;
int w[maxm];

void solve() {
    N = min(m, n/s);
    vector<mint> f(N+1, 0);
    for (int i = 0; i <= N; i++) f[i] = binom(m, i) * power(mint(m-i), n-s*i) * fac[n] * power(infac[s], i) * (infac[n-s*i]);
    vector<mint> _a(N+1, 0), _b(N+1, 0);
    for (int i = 0; i <= N; i++) _a[i] = fac[i] * f[i];
    for (int i = 0; i <= N; i++) _b[i] = ( (i & 1) ? (-1) : 1 ) * infac[i];
    for (int i = 0; i <= N; i++) if (i < N - i) swap(_a[i], _a[N-i]);

    Poly A(_a), B(_b);
    Poly C = A * B;

    mint ans = 0;
    for (int k = 0; k <= N; k++) ans = ans + mint(w[k]) * infac[k] * C[N-k];
    printf("%d\n", ans.x);
}

int main() {
    //freopen("input.txt", "r", stdin);
    cin >> n >> m >> s;

    for (int i = 0; i <= m; i++) scanf("%d", &w[i]);

    pre();
    solve();
}



这个问题一看没有什么思路,但区间修改不难想到线段树,相邻的两次操作如下
$(U_i \cdot U_j) \bmod P$,考虑打表,每出现一个新的数就新开一个状态编码,打表之后
$idx = 32$,也就是说,在模 $P$ 意义下
一个数 $x$ 乘以 ${U_0, U_1, \cdots U_4}$ 中的任意一个乘以若干次,最终只会有 $32$ 种不同的结果,记为 $Y[32]$
若干次 $\times U_i$ 的操作,等价于 $x \times Y[j], j \in [1, 32]$

基于此,可以用类似有限状态自动机模型,来构造线段树区间修改,修改等价于状态转移
csp201912-5.jpg

  • 维护转移矩阵,$g(i, j) = k$ 表示 $Y_i \times U_j \to Y_k$,即编码为 $i$ 的数乘以 $U_j$ 之后
    转移到编码为 $k$ 的数,其中 $i \in [1, 32], j \in [0, 5)$,为了转移方便和统一,可以都用编码 $\textbf{mp}$ 表示
    $Y_i \times U_j = Y_k$,需要维护 $g(i, \textbf{mp}[U_j]) = k$

  • 线段树的叶子结点,一开始的时候的 $f$ 值是 $f(A_l) = (l \bmod P) \bmod 2019$
    题中需要问询执行若干次 $A_l \times U_j = A_l’$ 之后的 $\displaystyle\sum_{l}f(A_l’)$,根据前面的分析
    对于每个 $l$,$A_l’$ 只有 $32$ 种不同的取值,预处理为 $Y[1\cdots 32]$,所以考虑拆点
    将每个叶子结点拆成 $32$ 个不同的子节点,对于每个 $A_l$,可以预处理出 $f(l, j), j = [1\cdots 32]$
    $f(l, j) = ((l \cdot Y_j) \bmod P) \bmod 2019$,这里可以优化,不必要每个都执行乘法
    因为 $l \in [1, n]$,可以递推,对每一个 $Y_j, j \in [1, 32]$,$(l \cdot Y_j) \bmod P = ((l-1) \cdot Y_j + Y_j) \bmod P$
    用一个 $res$ 记录一下 $(l \cdot Y_j) \bmod P$,递推求解就可以

  • 特别地,$f(l, 0)$ 表示对叶子结点不做任何修改,$f(l, 0) = A_l$

有了叶子结点之后,考虑 $\textbf{pull}$ 操作,线段树结点维护 $f$ 值,不妨设当前结点为 $[l, r]$
相应的值被修改为 $[A_l’, A_{l+1}’, \cdots, A_r’]$,结点 $f$ 值即维护 $\displaystyle\sum_{i = 1}^{r} f(A_i’)$
$\textbf{pull}$ 的时候只要把子区间对应的 $f[j], j \in [0, 32]$ 相加就可以

csp201912-5-02.jpg

对于每一次操作,先考虑区间修改 $\textbf{upd}(l, r, k)$,即 $\times U_k$ 的操作
难点在于,区间 $[l, r]$ 的子区间可能处于不同的状态 ${sta1, sta2, sta3, \cdots }$,在 $\textbf{pull}$ 的时候怎么合并?

当一个区间 $\times U_k$ 发生状态转移的时候,很可能子区间对应 ${g(sta1, k), g(sta2, k), \cdots}$ 多个状态同步转移
所以转移的时候,$32$ 个状态要同时转移
$\forall p \in [1, 32], g(i_p, \textbf{mp}[U_k]) = j_p$,有 $\textbf{f}(i_1, i_2, \cdots, i_{32}) \longrightarrow \textbf{f}’(j_1, j_2, \cdots, j_{32})$
$\textbf{pull}$ 的时候子区间对应的 $32$ 位,挨个对应合并即可,这样假设区间乘以若干个 $U_k$ 之后,等价于 $\times Y_s$
此时如果子区间之前已经被修改过了,因为 $\textbf{pull}$ 操作的存在,子区间和当前区间的 $Y_s$ 已经被更新成修改后的值
我们问询 $\textbf{f}’(s) \leftarrow \textbf{f}(g(s, k))$,它对应转移后的结果

另外,如何回答 $i \in [l, r]$ 区间的 $s = \sum f(A_i)$,那要注意到另外一个很重要的性质
因为本例对应的状态是有限的,仅有 $32$ 种,所以一定存在这样的转移
$A_l \times (U_{p1} \cdot U_{p2} \cdots U_{pk}) \to A_l$,即 $Y_s = U_{p1} \cdot U_{p2} \cdots U_{pk} = 1$,$s$ 为不动点
(否则的话状态应该是无限的)
也就是说 $g(i, s) = i$,初始化时候一定有 $A_l \cdot Y_s = A_l$,$f(s)$ 这个状态表示区间映射到自身
这样只要打表找到 $Y_s = 1$ 的下标 $s$, $f(s)$ 就表示区间和,经过若干次修改之后、
$\textbf{f}(s)$ 也会被更新为修改为 $A’$ 之后的区间 $\sum \textbf{f}(A_i’)$

那么这个问题大致就解决了,具体来说,执行 $\textbf{upd}(l, r, k)$ 操作的时候,找到对应的区间结点 $p$
将 $p$ 的状态向量 $f$ 更新为 $f’$,具体来说,$\forall i \in [1, 32], \textbf{f}’(i) \longleftarrow \textbf{f}(g(i, \textbf{mp}[U_k]))$
同时打标记,如果 $tag = 0$,那么 $tag = \textbf{mp}[U_k]$,否则,$tag \leftarrow g(tag, \textbf{mp}[U_k])$

延迟标记维护区间发生的一系列 $\times (U_{p1} \cdot U_{p2} \cdots U_{pk})$ 的累计状态转移,$\textbf{push}$ 的时候
记录当前区间的 $tag$,如果子区间的 $tag’ = 0$,那么 $tag’ = tag$,否则 $tag’ = g(tag’, tag)$,然后清空当前区间的 $tag$
问询区间 $sum$ 的时候,直接找到不动点 $s$ 并返回 $\textbf{f}(s)$

特别地,有些情形下,这种转移没有不动点,需要手动构造
构造的方法也很简单,预处理以及初始化线段树的时候,令 $\forall l \in [1, n], \ f(l, 0) = A_l$
线段树维护 $f$ 值,$f(0)$ 表示区间没有任何转移的时候的 $sum$ 值
构造转移矩阵,$\forall i \in [1, 32], \ g(0, i) = i$,这样 $\textbf{upd}(l, r, k)$ 的时候,就可以转移 $f$ 了
$\textbf{f}’(0) \longleftarrow \textbf{f}(g(0, \textbf{mp}[U_k]))$,查询区间和的时候,返回修改后的 $\textbf{f}’(0)$ 即可
但是本例不能这样做,是因为 $s = 28$ 是不动点,即 $28$ 这个状态,对应 $\textbf{f}(28)$ 将区间映射为它自己
本例中的修改,可能存在其他状态转移到 $s = 28$,我们的 $\textbf{f}(28)$ 要加上这部分贡献才能得到正确的结果
如果手动构造 $f(0)$ 为不动点,原问题不存在从其他状态向 $0$ 状态的转移,会漏掉一部分解

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cstdlib>
#include <algorithm>
#include <queue>
#include <vector>
#include <stack>
#include <map>
#include <set>
#include <sstream>
#include <iomanip>
#include <cmath>
#include <bitset>
#include <assert.h>
#include <unordered_map>
#pragma GCC optimize(2)


using namespace std;
typedef long long ll;

#define debug(x) cout << #x << ": " << x << endl
#define _rep(i, l, r) for(int i = (l); i <= (r); i++)
#define _for(i, l, r) for(int i = (l); i < (r); i++)
#define _forDown(i, l, r) for(int i = (l); i >= r; i--)
#define debug_(ch, i) printf(#ch"[%d]: %d\n", i, ch[i])
#define debug_m(mp, p) printf(#mp"[%d]: %d\n", p->first, p->second)
#define debugS(str) cout << "dbg: " << str << endl;
#define debugArr(arr, x, y) _for(i, 0, x) { _for(j, 0, y) printf("%c", arr[i][j]); printf("\n"); }
#define lowbit(i) (i & (-i))
#define fill_v(f, v) fill(f.begin(), f.end(), v)

template<class T> 
inline void debug_v(const vector<T> &vec) {
    printf("vec: ");
    for (auto u : vec) cout << u << " ";
    cout << endl;
}

inline int cntOne(unsigned int x) {
    return __builtin_popcount(x);
}
inline int cntOnell(unsigned long long x) {
    return __builtin_popcountll(x); 
}


inline ll qpow(ll a, int n) {
    ll ans = 1;
    for(; n; n >>= 1) {
        if(n & 1) ans *= 1ll * a;
        a *= a;
    }
    return ans;
}

template <class T>
inline bool chmax(T& a, T b) {
    if(a < b) {
        a = b;
        return true;
    }
    return false;
}

ll gcd(ll a, ll b) {
    return b == 0 ? a : gcd(b, a % b);
}

ll ksc(ll a, ll b, ll mod) {
    ll ans = 0;
    for(; b; b >>= 1) {
        if (b & 1) ans = (ans + a) % mod;
        a = (a + a) % mod;
    }
    return ans;
}

ll ksm(ll a, ll b, ll mod) {
    ll ans = 1 % mod;
    a %= mod;

    for(; b; b >>= 1) {
        if (b & 1) ans = ksc(ans, a, mod);
        a = ksc(a, a, mod);
    }

    return ans;
}

template <class T>
inline bool chmin(T& a, T b) {
    if(a > b) {
        a = b;
        return true;
    }
    return false;
}

template<class T>
bool lexSmaller(vector<T> a, vector<T> b) {
    int n = a.size(), m = b.size();
    int i;
    for(i = 0; i < n && i < m; i++) {
        if (a[i] < b[i]) return true;
        else if (b[i] < a[i]) return false;
    }
    return (i == n && i < m);
}

constexpr int P = 998244353;

int norm(int x) {
    if (x < 0) x += P;
    if (x >= P) x -= P;
    return x;
}

template<class T>
T power(T a, int b) {
    T res = 1;
    for (; b; b >>= 1) {
        if (b & 1) res *= a;
        a *= a;
    }
    return res;
}

struct Z {
    int x;
    Z(int x = 0) : x(norm(x)) {}

    int val() const {
        return x;
    }
    Z operator-() const {
        return Z(norm(P-x));
    }
    Z &operator *= (const Z &rhs) {
        x = (ll)(x) * rhs.x % P;
        return *this;
    }
    Z &operator += (const Z &rhs) {
        x = norm(x + rhs.x);
        return *this;
    }
    Z &operator -= (const Z &rhs) {
        x = norm(x - rhs.x);
        return *this;
    }
    Z &operator /= (const Z &rhs) {
        return *this *= rhs.inv();
    }
    Z inv() const {
        assert(x != 0);
        return power(*this, P-2);
    }
    friend Z operator* (const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res *= rhs;
        return res;
    }
    friend Z operator+ (const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res += rhs;
        return res;
    }
    friend Z operator- (const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res -= rhs;
        return res;
    }
    friend Z operator/ (const Z &lhs, const Z &rhs) {
        Z res = lhs;
        res /= rhs;
        return res;
    }
};

struct Int {
    static constexpr int B = 10;
    vector<int> X;
    int size() const {
        return (int)X.size();
    }

    Int(int x = 0) {
        while (x) {
            X.push_back(x % B), x /= B;
        }
    }

    Int(string str) {
        reverse(str.begin(), str.end());
        for (auto u : str) X.push_back(u-'0');
    }

    friend Int operator+ (const Int &lhs, const Int &rhs) {
        if (lhs.size() < rhs.size()) return rhs + lhs;
        Int res;

        int t = 0;
        for (int i = 0; i < lhs.size(); i++) {
            t += lhs.X[i];
            if (i < rhs.size()) t += rhs.X[i];
            res.X.push_back(t % B), t /= B;
        }
        if (t) res.X.push_back(t);
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator- (const Int &lhs, const Int &rhs) {
        Int res;
        int t = 0;
        for (int i = 0; i < lhs.size(); i++) {
            t = lhs.X[i] - t;
            if (i < rhs.size()) t -= rhs.X[i];
            res.X.push_back((t + B) % B);

            if (t < 0) t = 1;
            else t = 0;
        }
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator* (const Int &lhs, int b) {
        Int res;
        int t = 0;
        for (int i = 0; i < lhs.X.size() || t; i++) {
            if (i < lhs.X.size()) t += lhs.X[i] * b;
            res.X.push_back(t % B), t /= B;
        }
        return res;
    }

    friend Int operator/ (const Int &lhs, int b) {
        Int res;
        int r = 0;
        for (int i = lhs.X.size()-1; i >= 0; i--) {
            r = r * B + lhs.X[i];
            res.X.push_back(r / b), r %= b;
        }
        reverse(res.X.begin(), res.X.end());
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator* (const Int &lhs, const Int &rhs) {
        Int res;
        res.X.resize(lhs.size() + rhs.size() + B);
        fill(res.X.begin(), res.X.end(), 0);

        for (int i = 0; i < lhs.size(); i++) {
            for (int j = 0; j < rhs.size(); j++) {
                res.X[i+j] += lhs.X[i] * rhs.X[j];
            }
        }
        for (int i = 0; i < res.X.size(); i++) {
            if (res.X[i] >= B) res.X[i+1] += res.X[i] / B, res.X[i] %= B;
        }

        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend Int operator/ (const Int& lhs, const Int &rhs) {
        int dv = lhs.size() - rhs.size();
        assert(dv >= 0);

        Int res;
        res.X.resize(dv+1);
        fill(res.X.begin(), res.X.end(), 0);

        // append suffix zero
        Int a = lhs, b = rhs;
        reverse(b.X.begin(), b.X.end());
        for (int i = 0; i < dv; i++) b.X.push_back(0);
        reverse(b.X.begin(), b.X.end());

        for (int i = 0; i <= dv; i++) {
            while (b < a) {
                a = a-b;
                res.X[dv-i]++;
            }
            b.X.erase(b.X.begin());
        }
        while (res.X.size() > 1 && res.X.back() == 0) res.X.pop_back();
        return res;
    }

    friend bool operator< (const Int &lhs, const Int &rhs) {
        if (lhs.size() < rhs.size()) return true;
        if (lhs.size() > rhs.size()) return false;

        if (vector<int>(lhs.X.rbegin(), lhs.X.rend()) <= vector<int>(rhs.X.rbegin(), rhs.X.rend())) return true;
        return false;
    }
    void out() {
        reverse(X.begin(), X.end());
        for (auto x : X) printf("%d", x);
        printf("\n");
    }
};

Int max_int(const Int &A, const Int &B) {
    if (A < B) return B;
    else return A;
}

typedef __int128_t int128;

// vector<int> v
// vector<int>(v).swap(v)

template <class T>
void read(T &x) {
    x = 0;
    bool fl = 0;
    char c = getchar();
    while (c < '0' || c > '9') {
        if (c == '-') fl = 1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = (x<<1) + (x<<3) + (c^48);
        c = getchar();
    }
    if (fl) x = -x;
}

template <class T>
inline void write(T x) {
    if (x < 0) putchar('-'), x = -x;
    if (x > 9) write(x / 10);
    putchar(x % 10 + '0');
}

// ============================================================== //


int n, q;
const int maxn = 1e6 + 1;
const int maxm = 33;
typedef unsigned long long ull;

const ull U[5] = {
    314882150829468584,
    427197303358170108,
    1022292690726729920,
    1698479428772363217,
    2006101093849356424
};

const ull mod = 2009731336725594113;

unordered_map<ull, int> mp;
int idx = 0;
int trans[maxm][maxm];
ull rm[maxm];

void pre() {
    queue<ull> que;
    for (int i = 0; i < 5; i++) {
        mp[U[i]] = ++idx, rm[idx] = U[i];
        que.push(U[i]);
    }
    while (que.size()) {
        int128 x = (int128) que.front(); que.pop();
        for (int i = 0; i < 5; i++) {
            ull y = x * U[i] % mod;
            if (mp[y]) continue;
            mp[y] = ++idx, rm[idx] = y;
            que.push(y);
        }
    }

    for (int i = 1; i <= 32; i++) for (int j = 1; j <= 32; j++) {
        trans[i][j] = trans[j][i] = mp[(int128) rm[i] * rm[j] % mod];
    }
}

int a[maxn][maxm];
void pre2() {
    for (int j = 1; j <= 32; j++) {
        ull res = 0;
        for (int i = 1; i <= n; i++) {
            res = (res + rm[j]) % mod;
            a[i][j] = res % 2019;
        }
    }
}

void dbg() {
    for (int i = 0; i <= n; i++) {
        for (int j = 0; j <= 32; j++) printf("%d ", a[i][j]);
        printf("\n");
    }

    for (auto x : mp) printf("#val:%llu           #id:%d\n", x.first, x.second);
}

struct Node {
    int l, r, tag;
    int f[maxm];
} t[maxn << 2];

inline void pull(int p) {
    for (int i = 1; i <= 32; i++) {
        t[p].f[i] = t[p<<1].f[i] + t[p<<1|1].f[i];
    }
}

void build(int p, int l, int r) {
    t[p].l = l, t[p].r = r;
    if (l >= r) {
        for (int i = 1; i <= 32; i++) t[p].f[i] = a[l][i];
        //assert(a[l][28] == l);
        return;
    }
    int mid = (l + r) >> 1;
    build(p<<1, l, mid), build(p<<1|1, mid+1, r);
    pull(p);
}

inline void push(int p) {
    if (t[p].tag == 0) return;
    int tg = t[p].tag;

    #define ls (p<<1)
    #define rs (p<<1|1)

    t[ls].tag == 0 ? t[ls].tag = tg : t[ls].tag = trans[t[ls].tag][tg];
    t[rs].tag == 0 ? t[rs].tag = tg : t[rs].tag = trans[t[rs].tag][tg];

    static int f2[maxm];
    //memset(f2, 0, sizeof f2);
    for (int i = 1; i <= 32; i++) f2[i] = t[ls].f[trans[i][tg]];
    memcpy(t[ls].f, f2, sizeof f2);

    // memset(f2, 0, sizeof f2);
    for (int i = 1; i <= 32; i++) f2[i] = t[rs].f[trans[i][tg]];
    memcpy(t[rs].f, f2, sizeof f2);

    t[p].tag = 0;
}

void upd(int p, const int l, const int r, int k) {
    if (l <= t[p].l && t[p].r <= r) {
        static int f2[maxm];
        // memset(f2, 0, sizeof f2);
        for (int i = 1; i <= 32; i++) f2[i] = t[p].f[ trans[i][ k+1 ] ];
        memcpy(t[p].f, f2, sizeof f2);

        t[p].tag == 0 ? t[p].tag = k+1 : t[p].tag = trans[ t[p].tag ][ k+1 ];
        return;
    }

    push(p);
    int mid = (t[p].l + t[p].r) >> 1;
    if (l <= mid) upd(p<<1, l, r, k);
    if (r > mid) upd(p<<1|1, l, r, k);

    pull(p);
}

int query(int p, const int l, const int r) {
    if (l <= t[p].l && t[p].r <= r) {
        return t[p].f[28];
    }
    push(p);
    int ans = 0;
    int mid = (t[p].l + t[p].r) >> 1;
    if (l <= mid) ans += query(p<<1, l, r);
    if (r > mid) ans += query(p<<1|1, l, r);
    return ans;
}

int main() {
    //freopen("input.txt", "r", stdin);
    read(n), read(q);

    pre();
    pre2();
    // debug(idx), dbg();
    build(1, 1, n);

    while (q--) {
        int l, r;
        read(l), read(r);
        //write(l), printf(" "), write(r), printf("\n");

        int s = query(1, l, r);
        printf("%d\n", s);
        int k = s % 5;
        upd(1, l, r, k);
    }
}



根据 $1 \leqslant N \leqslant 16$,很容易想到这一定是一个状态压缩
$0$ 表示宝石还未被处理,$1$ 表示已经被处理,那么起始状态为 $(00\cdots 0)$,终态为 $(11\cdots 1)$
令 $S = (1 \ll n) - 1$,从 $dp(0) \to dp(S)$,$dp(S)$ 就是答案

涉及枚举子集的操作,比如枚举 $S$ 的子集 $S_0$,代码如下

for (int S = 1; S < (1<<n); S++) {
    dp[S] = 初始化
    for (int S0 = S; S0; S0 = (S0-1) & S) {
        dp[S] = max(dp[S], dp[S^S0] + f(S0))
        // 其中 f(S0) 是将 S0 并入集合中的代价
    }
}

回到该问题,对于某一状态 $S$,它之前的状态假设为 $S’$,那么要考虑 $S’ \to S$ 新加入了哪些元素?
刚刚加入的元素一定是 $S$ 的一个子集,可以考虑枚举子集 $S_0$,状态转移方程如下
$dp(S) = \min (dp(S), dp(S \oplus S_0) + f(S_0))$,其中 $f(S_0)$ 是将 $S_0$ 并入集合的最小代价

接下来需要考虑 $f(S_0)$ 怎么求,首先 $S_0$ 必须合法,也就是说新加入的元素 $i$,$\sum_{i} A_i = 0$
这很简单,对 $S_0$ 中为 $1$ 的每一位 $i$ 求 $sum = \displaystyle\sum_{i} A_i$
如果 $sum = 0$ 才合法,否则的话不合法

将 $S_0$ 为 $1$ 的位取出来,假设有 $C$ 个为 $1$ 的位,求出这些位能量转移的最小代价,就是 $f(S_0)$
而这又是一个部分点的最小生成树问题,只需要将 $m$ 条边按权值排序,然后依次检查每一条边
对于边 $(x, y)$,只有 $S_0$ 中表示 $x, y$ 的位同时为 $1$ 的时候,将边权值累加到结果中
另外注意,累加的时候记得统计生成树边的个数 $cnt$,只有 $cnt = C-1$ 的时候才合法

至此,先预处理 $f(S)$,然后执行 $dp$,$dp((1 \ll n)-1)$ 就是答案

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>
#include <cassert>

using namespace std;
typedef long long ll;

const int maxn = 1000;
const int N = (1<<16) + 5;
const int inf = 0x3f3f3f3f;
int n, m;

struct Edge {
    int x, y, z;
};

vector<Edge> edges;
vector<int> a(20, 0);
vector<ll> f(N, 0);

int kruskal(const int S) {
    int sum = 0, C = 0;
    for (int i = 0; i < n; i++) if (S >> i & 1) sum += a[i], C++;
    if (sum != 0) return inf;

    vector<int> pa(20);
    for (int i = 0; i <= n; i++) pa[i] = i;
    function<int(int)> fp = [&](int x) {
        return pa[x] == x ? x : pa[x] = fp(pa[x]);
    };

    sort(edges.begin(), edges.end(), [](const Edge a, const Edge b) {
        return a.z < b.z;
    });

    int cnt = 0, res = 0;
    for (auto e : edges) {
        if ( !(S >> e.x & 1) || !(S >> e.y & 1)) continue;
        int x = fp(e.x), y = fp(e.y);
        if (x == y) continue;

        pa[x] = y, cnt++;
        res += e.z;
    }

    return cnt == C-1 ? res : inf;
}

void solve() {
    vector<ll> dp(N, inf);
    dp[0] = 0;

    for (int S = 1; S < (1<<n); S++) {
        for (int S0 = S; S0; S0 = (S0-1) & S) {
            dp[S] = min(dp[S], dp[S^S0] + f[S0]);
        }
    }

    ll res = dp[(1<<n)-1];
    if (res >= inf) puts("Impossible");
    else printf("%lld\n", res);
}

void init() {
    edges.clear();
    a.clear(), f.clear();
}

int main() {
    //freopen("input.txt", "r", stdin);
    cin >> n >> m;
    init();

    for (int i = 0; i < n; i++) scanf("%d", &a[i]);
    for (int i = 0; i < m; i++) {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        edges.push_back(Edge{x, y, z});
    }
    assert(edges.size() == m);

    // then prework
    for (int S = 1; S < (1<<n); S++) f[S] = kruskal(S);

    // then solve
    solve();
}


活动打卡代码 AcWing 388. 四叶草魔杖

根据 $1 \leqslant N \leqslant 16$,很容易想到这一定是一个状态压缩
$0$ 表示宝石还未被处理,$1$ 表示已经被处理,那么起始状态为 $(00\cdots 0)$,终态为 $(11\cdots 1)$
令 $S = (1 \ll n) - 1$,从 $dp(0) \to dp(S)$,$dp(S)$ 就是答案

涉及枚举子集的操作,比如枚举 $S$ 的子集 $S_0$,代码如下

for (int S = 1; S < (1<<n); S++) {
    dp[S] = 初始化
    for (int S0 = S; S0; S0 = (S0-1) & S) {
        dp[S] = max(dp[S], dp[S^S0] + f(S0))
        // 其中 f(S0) 是将 S0 并入集合中的代价
    }
}

回到该问题,对于某一状态 $S$,它之前的状态假设为 $S’$,那么要考虑 $S’ \to S$ 新加入了哪些元素?
刚刚加入的元素一定是 $S$ 的一个子集,可以考虑枚举子集 $S_0$,状态转移方程如下
$dp(S) = \min (dp(S), dp(S \oplus S_0) + f(S_0))$,其中 $f(S_0)$ 是将 $S_0$ 并入集合的最小代价

接下来需要考虑 $f(S_0)$ 怎么求,首先 $S_0$ 必须合法,也就是说新加入的元素 $i$,$\sum_{i} A_i = 0$
这很简单,对 $S_0$ 中为 $1$ 的每一位 $i$ 求 $sum = \displaystyle\sum_{i} A_i$
如果 $sum = 0$ 才合法,否则的话不合法

将 $S_0$ 为 $1$ 的位取出来,假设有 $C$ 个为 $1$ 的位,求出这些位能量转移的最小代价,就是 $f(S_0)$
而这又是一个部分点的最小生成树问题,只需要将 $m$ 条边按权值排序,然后依次检查每一条边
对于边 $(x, y)$,只有 $S_0$ 中表示 $x, y$ 的位同时为 $1$ 的时候,将边权值累加到结果中
另外注意,累加的时候记得统计生成树边的个数 $cnt$,只有 $cnt = C-1$ 的时候才合法

至此,先预处理 $f(S)$,然后执行 $dp$,$dp((1 \ll n)-1)$ 就是答案

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>
#include <cassert>

using namespace std;
typedef long long ll;

const int maxn = 1000;
const int N = (1<<16) + 5;
const int inf = 0x3f3f3f3f;
int n, m;

struct Edge {
    int x, y, z;
};

vector<Edge> edges;
vector<int> a(20, 0);
vector<ll> f(N, 0);

int kruskal(const int S) {
    int sum = 0, C = 0;
    for (int i = 0; i < n; i++) if (S >> i & 1) sum += a[i], C++;
    if (sum != 0) return inf;

    vector<int> pa(20);
    for (int i = 0; i <= n; i++) pa[i] = i;
    function<int(int)> fp = [&](int x) {
        return pa[x] == x ? x : pa[x] = fp(pa[x]);
    };

    sort(edges.begin(), edges.end(), [](const Edge a, const Edge b) {
        return a.z < b.z;
    });

    int cnt = 0, res = 0;
    for (auto e : edges) {
        if ( !(S >> e.x & 1) || !(S >> e.y & 1)) continue;
        int x = fp(e.x), y = fp(e.y);
        if (x == y) continue;

        pa[x] = y, cnt++;
        res += e.z;
    }

    return cnt == C-1 ? res : inf;
}

void solve() {
    vector<ll> dp(N, inf);
    dp[0] = 0;

    for (int S = 1; S < (1<<n); S++) {
        for (int S0 = S; S0; S0 = (S0-1) & S) {
            dp[S] = min(dp[S], dp[S^S0] + f[S0]);
        }
    }

    ll res = dp[(1<<n)-1];
    if (res >= inf) puts("Impossible");
    else printf("%lld\n", res);
}

void init() {
    edges.clear();
    a.clear(), f.clear();
}

int main() {
    //freopen("input.txt", "r", stdin);
    cin >> n >> m;
    init();

    for (int i = 0; i < n; i++) scanf("%d", &a[i]);
    for (int i = 0; i < m; i++) {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        edges.push_back(Edge{x, y, z});
    }
    assert(edges.size() == m);

    // then prework
    for (int S = 1; S < (1<<n); S++) f[S] = kruskal(S);

    // then solve
    solve();
}



如果没有通信卫星的话,那么很显然 $D$ 是原图 $G = (N, M)$ 最小生成树中权值最大的边

现在有 $S$ 个通信卫星,实际上是求 $G$ 的一个子图 $G’ = (N-S)$,即子图中有 $N-S$ 个点
同时还必须满足子图 $G’$ 的生成树 $T’$ 尽量小,生成树中权值最大的边就是所求的 $D$

最小生成树证明中,有一个结论
图 $G(N, M)$ 的最小生成树,一定包括权值最小的 $N-1$ 条边
那么子图 $G’$ 的最小生成树一定包括最小的 $N-S-1$ 条边

可以这么设计算法,执行 $\text{kruskal}$ 的时候,对边从小到大排序,并初始化 $cnt = 0$,当并查集 $x \neq y$ 时
令 $pa[x] = y, \ cnt++$,当 $cnt = N-S$ 的时候,此时边 $e$ 的权值就是答案

说明,此时维护的子图有 $N-S$ 个节点,有 $N-S-1$ 条边,另外还需要一条边连到“带有卫星的节点”
即子图外的节点,所以总共是最小的 $N-S$ 条边

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>


using namespace std;
typedef long long ll;

const int maxn = 500 + 10;
int n, m, s;

struct Edge {
    int x, y;
    double z; 
};
vector<Edge> edges;

void init() {
    edges.clear();
}
typedef pair<int, int> PII;
PII a[maxn];

inline double dist(int i, int j) {
    return sqrt( (a[i].first-a[j].first)*(a[i].first-a[j].first) + (a[i].second-a[j].second)*(a[i].second-a[j].second) );
}

void kruskal() {
    vector<int> pa(maxn, 0);
    for (int i = 0; i <= n; i++) pa[i] = i;
    function<int(int)> fp = [&](int x) {
        return pa[x] == x ? x : pa[x] = fp(pa[x]);
    };
    // debug(edges.size());

    sort(edges.begin(), edges.end(), [](const Edge a, const Edge b) {
        return a.z < b.z;
    });

    double res = 0.0;
    int cnt = 0;
    for (auto e : edges) {
        int x = fp(e.x), y = fp(e.y);
        if (x == y) continue;

        pa[x] = y, res = max(res, e.z);
        if (++cnt == n-s) break;
    }
    printf("%.2lf\n", res);
}

int main() {
    //freopen("input.txt", "r", stdin);
    int cas;
    cin >> cas;

    while (cas--) {
        scanf("%d%d", &s, &n);
        init();

        for (int i = 1; i <= n; i++) {
            int x, y;
            scanf("%d%d", &x, &y);
            a[i] = PII(x, y);
        }

        // build
        for (int i = 1; i <= n; i++) {
            for (int j = i+1; j <= n; j++) {
                double z = dist(i, j);
                // debug(z);
                edges.push_back(Edge{i, j, z});
            }
        }

        // kruskal
        if (n-s <= 0) {
            puts("0");
            continue;
        }
        kruskal();
    }
}


活动打卡代码 AcWing 387. 北极网络

如果没有通信卫星的话,那么很显然 $D$ 是原图 $G = (N, M)$ 最小生成树中权值最大的边

现在有 $S$ 个通信卫星,实际上是求 $G$ 的一个子图 $G’ = (N-S)$,即子图中有 $N-S$ 个点
同时还必须满足子图 $G’$ 的生成树 $T’$ 尽量小,生成树中权值最大的边就是所求的 $D$

最小生成树证明中,有一个结论
图 $G(N, M)$ 的最小生成树,一定包括权值最小的 $N-1$ 条边
那么子图 $G’$ 的最小生成树一定包括最小的 $N-S-1$ 条边

可以这么设计算法,执行 $\text{kruskal}$ 的时候,对边从小到大排序,并初始化 $cnt = 0$,当并查集 $x \neq y$ 时
令 $pa[x] = y, \ cnt++$,当 $cnt = N-S$ 的时候,此时边 $e$ 的权值就是答案

说明,此时维护的子图有 $N-S$ 个节点,有 $N-S-1$ 条边,另外还需要一条边连到“带有卫星的节点”
即子图外的节点,所以总共是最小的 $N-S$ 条边

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <cmath>


using namespace std;
typedef long long ll;

const int maxn = 500 + 10;
int n, m, s;

struct Edge {
    int x, y;
    double z; 
};
vector<Edge> edges;

void init() {
    edges.clear();
}
typedef pair<int, int> PII;
PII a[maxn];

inline double dist(int i, int j) {
    return sqrt( (a[i].first-a[j].first)*(a[i].first-a[j].first) + (a[i].second-a[j].second)*(a[i].second-a[j].second) );
}

void kruskal() {
    vector<int> pa(maxn, 0);
    for (int i = 0; i <= n; i++) pa[i] = i;
    function<int(int)> fp = [&](int x) {
        return pa[x] == x ? x : pa[x] = fp(pa[x]);
    };
    // debug(edges.size());

    sort(edges.begin(), edges.end(), [](const Edge a, const Edge b) {
        return a.z < b.z;
    });

    double res = 0.0;
    int cnt = 0;
    for (auto e : edges) {
        int x = fp(e.x), y = fp(e.y);
        if (x == y) continue;

        pa[x] = y, res = max(res, e.z);
        if (++cnt == n-s) break;
    }
    printf("%.2lf\n", res);
}

int main() {
    //freopen("input.txt", "r", stdin);
    int cas;
    cin >> cas;

    while (cas--) {
        scanf("%d%d", &s, &n);
        init();

        for (int i = 1; i <= n; i++) {
            int x, y;
            scanf("%d%d", &x, &y);
            a[i] = PII(x, y);
        }

        // build
        for (int i = 1; i <= n; i++) {
            for (int j = i+1; j <= n; j++) {
                double z = dist(i, j);
                // debug(z);
                edges.push_back(Edge{i, j, z});
            }
        }

        // kruskal
        if (n-s <= 0) {
            puts("0");
            continue;
        }
        kruskal();
    }
}



不难想到这是一个最短路树问题,从源点 $1$ 开始执行 $\text{dijkstra}$ 并且得到距离向量 $f[\cdots]$
$(x, y)$ 是最短路树上的边当且仅当 $f(y) = f(x) + e(x, y)$

可以根据 $\text{prim}$ 算法的思想,设计出如下算法

  • 根据点 $i \in [1\cdots n]$ 的 $f(i)$ 排序,一开始令集合 $T$ 中只有一个元素 $1$,即 $\text{inq}(1) = 1$

  • 枚举所有的节点 $\forall x \in [1\cdots n]$,找到第一个不在 $T$ 中的节点 $x$,即 $\text{inq}(x) = 0$ 的点,并初始化 $cnt = 0$

  • 接下来需要统计出 $T$ 中有多少个点 $p$ 满足 $d(x) = d(p) + (p, x)$,满足的话就令 $cnt++$
    根据乘法原理,令 $res = res \cdot cnt$,然后把 $x$ 加入 $T$ 中,即 $\text{inq}(x) = 1$
    执行完之后,$res$ 就是答案

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#include <cmath>

using namespace std;

#define fill_v(f, v) fill(f.begin(), f.end(), v)
typedef long long ll;

const int maxn = 1e3 + 10;
const int maxm = 1e6 + 10;
const int inf = 0x3f3f3f3f;
const ll mod = pow(2, 31)-1;

namespace Graph {
    int idx = 1;
    int h[maxn], ver[maxm], e[maxm], ne[maxm];

    void initG() {
        idx = 1;
        memset(h, 0, sizeof h);
    }

    void add(int x, int y, int z) {
        ver[++idx] = y, e[idx] = z, ne[idx] = h[x], h[x] = idx;
    }
}

using namespace Graph;
int g[maxn][maxn], n, m;
typedef pair<int, int> PII;
vector<int> f(maxn, inf), vis(maxn, 0);
void dijkstra(int s, vector<int> &f) {
    fill_v(f, inf), fill_v(vis, 0);

    f[s] = 0;
    priority_queue<PII, vector<PII>, greater<PII> > q;
    q.push(PII(f[s], s));

    while (q.size()) {
        auto x = q.top().second; q.pop();
        if (vis[x]) continue;
        vis[x] = 1;

        for (int i = h[x]; i; i = ne[i]) {
            int y = ver[i];
            if (f[y] > f[x] + e[i]) {
                f[y] = f[x] + e[i];
                q.push(PII(f[y], y));
            }
        }
    }
}

void solve() {
    vector<PII> a(maxn, PII());
    for (int i = 1; i <= n; i++) a[i] = PII(f[i], i);
    sort(a.begin()+1, a.begin()+1+n);

    vector<int> inq(maxn, 0);
    vector<int> nodes;
    inq[a[1].second] = 1, nodes.push_back(a[1].second);

    ll ans = 1LL;
    for (int i = 1; i <= n; i++) {
        if (inq[a[i].second]) continue;

        int x = a[i].second, cnt = 0;
        for (auto u : nodes) {
            if (f[u] + g[u][x] == f[x]) cnt++;
        }
        ans = ans * cnt % mod;
        inq[x] = 1, nodes.push_back(x);
    }
    printf("%lld\n", ans);
}

void init() {
    memset(g, inf, sizeof g);
}

int main() {
    //freopen("input.txt", "r", stdin);
    initG(), init();

    scanf("%d%d", &n, &m);
    while (m--) {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        add(x, y, z), add(y, x, z);
        g[x][y] = g[y][x] = min(g[y][x], z);
    }

    // dijkstra and solve
    dijkstra(1, f);
    solve();
}


活动打卡代码 AcWing 349. 黑暗城堡

不难想到这是一个最短路树问题,从源点 $1$ 开始执行 $\text{dijkstra}$ 并且得到距离向量 $f[\cdots]$
$(x, y)$ 是最短路树上的边当且仅当 $f(y) = f(x) + e(x, y)$

可以根据 $\text{prim}$ 算法的思想,设计出如下算法

  • 根据点 $i \in [1\cdots n]$ 的 $f(i)$ 排序,一开始令集合 $T$ 中只有一个元素 $1$,即 $\text{inq}(1) = 1$

  • 枚举所有的节点 $\forall x \in [1\cdots n]$,找到第一个不在 $T$ 中的节点 $x$,即 $\text{inq}(x) = 0$ 的点,并初始化 $cnt = 0$

  • 接下来需要统计出 $T$ 中有多少个点 $p$ 满足 $d(x) = d(p) + (p, x)$,满足的话就令 $cnt++$
    根据乘法原理,令 $res = res \cdot cnt$,然后把 $x$ 加入 $T$ 中,即 $\text{inq}(x) = 1$
    执行完之后,$res$ 就是答案

#include <iostream>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#include <cmath>

using namespace std;

#define fill_v(f, v) fill(f.begin(), f.end(), v)
typedef long long ll;

const int maxn = 1e3 + 10;
const int maxm = 1e6 + 10;
const int inf = 0x3f3f3f3f;
const ll mod = pow(2, 31)-1;

namespace Graph {
    int idx = 1;
    int h[maxn], ver[maxm], e[maxm], ne[maxm];

    void initG() {
        idx = 1;
        memset(h, 0, sizeof h);
    }

    void add(int x, int y, int z) {
        ver[++idx] = y, e[idx] = z, ne[idx] = h[x], h[x] = idx;
    }
}

using namespace Graph;
int g[maxn][maxn], n, m;
typedef pair<int, int> PII;
vector<int> f(maxn, inf), vis(maxn, 0);
void dijkstra(int s, vector<int> &f) {
    fill_v(f, inf), fill_v(vis, 0);

    f[s] = 0;
    priority_queue<PII, vector<PII>, greater<PII> > q;
    q.push(PII(f[s], s));

    while (q.size()) {
        auto x = q.top().second; q.pop();
        if (vis[x]) continue;
        vis[x] = 1;

        for (int i = h[x]; i; i = ne[i]) {
            int y = ver[i];
            if (f[y] > f[x] + e[i]) {
                f[y] = f[x] + e[i];
                q.push(PII(f[y], y));
            }
        }
    }
}

void solve() {
    vector<PII> a(maxn, PII());
    for (int i = 1; i <= n; i++) a[i] = PII(f[i], i);
    sort(a.begin()+1, a.begin()+1+n);

    vector<int> inq(maxn, 0);
    vector<int> nodes;
    inq[a[1].second] = 1, nodes.push_back(a[1].second);

    ll ans = 1LL;
    for (int i = 1; i <= n; i++) {
        if (inq[a[i].second]) continue;

        int x = a[i].second, cnt = 0;
        for (auto u : nodes) {
            if (f[u] + g[u][x] == f[x]) cnt++;
        }
        ans = ans * cnt % mod;
        inq[x] = 1, nodes.push_back(x);
    }
    printf("%lld\n", ans);
}

void init() {
    memset(g, inf, sizeof g);
}

int main() {
    //freopen("input.txt", "r", stdin);
    initG(), init();

    scanf("%d%d", &n, &m);
    while (m--) {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        add(x, y, z), add(y, x, z);
        g[x][y] = g[y][x] = min(g[y][x], z);
    }

    // dijkstra and solve
    dijkstra(1, f);
    solve();
}