头像

fujang

大连海事大学




在线 


活动打卡代码 AcWing 1277. 维护序列

fujang
23小时前
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 100010;

int n, p, m;
int w[N];
struct Node {
    int l, r;
    int sum, add, mul;
} tr[N * 4];

void pushup(int u) {
    tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}

void eval(Node &t, int add, int mul) {
    t.sum = ((ll)t.sum * mul + (ll)(t.r - t.l + 1) * add) % p;
    t.mul = (ll)t.mul * mul % p;
    t.add = ((ll)t.add * mul + add) % p;
}

void pushdown(int u) {
    eval(tr[u << 1], tr[u].add, tr[u].mul);
    eval(tr[u << 1 | 1], tr[u].add, tr[u].mul);
    tr[u].add = 0, tr[u].mul = 1;
}

void build(int u, int l, int r) {
    if (l == r)
        tr[u] = {l, r, w[r], 0, 1};
    else {
        tr[u] = {l, r, 0, 0, 1};
        int mid = l + r >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void modify(int u, int l, int r, int add, int mul) {
    if (tr[u].l >= l && tr[u].r <= r)
        eval(tr[u], add, mul);
    else {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (l <= mid) modify(u << 1, l, r, add, mul);
        if (r > mid) modify(u << 1 | 1, l, r, add, mul);
        pushup(u);
    }
}

int query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;

    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    int sum = 0;
    if (l <= mid) sum = query(u << 1, l, r);
    if (r > mid) sum = (sum + query(u << 1 | 1, l, r)) % p;
    return sum;
}

int main() {
    scanf("%d%d", &n, &p);
    for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
    build(1, 1, n);
    scanf("%d", &m);
    while (m--) {
        int t, l, r, d;
        scanf("%d%d%d", &t, &l, &r);
        if (t == 1) {
            scanf("%d", &d);
            modify(1, l, r, 0, d);
        } else if (t == 2) {
            scanf("%d", &d);
            modify(1, l, r, d, 1);
        } else
            printf("%d\n", query(1, l, r));
    }
    return 0;
}




fujang
23小时前
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 100010;

int n, m;
int w[N];
struct Node {
    int l, r;
    ll sum, add;
}tr[N * 4];

void pushup(int u) {
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void pushdown(int u) {
    auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
    if (root.add) {
        left.add += root.add, left.sum += (ll)(left.r - left.l + 1) * root.add;
        right.add += root.add, right.sum += (ll)(right.r - right.l + 1) * root.add;
        root.add = 0;
    }
}

void build(int u, int l, int r) {
    if (l == r) tr[u] = {l, r, w[r], 0};
    else {
        tr[u] = {l, r};
        int mid = l + r >> 1;
        build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

void modify(int u, int l, int r, int d) {
    if (tr[u]. l >= l && tr[u].r <= r) {
        tr[u].sum += (ll)(tr[u].r - tr[u].l + 1) * d;
        tr[u].add += d;
    } else {
        pushdown(u);
        int mid = tr[u].l + tr[u].r >> 1;
        if (l <= mid) modify(u << 1, l, r, d);
        if (r > mid) modify(u << 1 | 1, l, r, d);
        pushup(u);
    }
}

ll query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    ll sum = 0;
    if (l <= mid) sum = query(u << 1, l, r);
    if (r > mid) sum += query(u << 1 | 1, l, r);
    return sum;
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> w[i];
    build(1, 1, n);
    string op;
    int l, r, d;
    while (m--) {
        cin >> op >> l >> r;
        if (op[0] == 'C') {
            cin >> d;
            modify(1, l, r, d);
        } else cout << query(1, l, r) << endl;
    }
    return 0;
}



fujang
1天前
#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 5e5 + 10;
typedef long long ll;
ll w[maxn];

ll gcd(ll x, ll y) {
    return y ? gcd(y, x % y) : x;
}

struct Node {
    int l, r;
    ll sum, d;
} tr[maxn << 2];

void pushup(Node &u, Node &l, Node &r) {
    u.sum = l.sum + r.sum;
    u.d = gcd(l.d, r.d);
}

void pushup(int u) {
    pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}

void build(int u, int l, int r) {
    if (l == r) {
        tr[u].l = tr[u].r = l;
        tr[u].sum = w[l] - w[l - 1];
        tr[u].d = w[l] - w[l - 1];
        return;
    }
    tr[u].l = l, tr[u].r = r;
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

void modify(int u, int x, ll v) {
    if (tr[u].l == x && tr[u].r == x) {
        tr[u] = {x, x, tr[u].sum + v, tr[u].sum + v};
        return;
    }
    int mid = tr[u].l + tr[u].r >> 1;
    if (x <= mid) modify(u << 1, x, v);
    else modify(u << 1 | 1, x, v);
    pushup(u);
}

Node query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u];
    else {
        int mid = tr[u].l + tr[u].r >> 1;
        if (r <= mid) return query(u << 1, l, r);
        else if (l > mid) return query(u << 1 | 1, l, r);
        else {
            Node res;
            auto left = query(u << 1, l, r);
            auto right = query(u << 1 | 1, l, r);
            pushup(res, left, right);
            return res;
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int n, m;
    cin >> n >> m;

    for (int i = 1; i <= n; i++) cin >> w[i];
    build (1, 1, n);
    while (m--) {
        string s;
        ll x, y, z;
        cin >> s;
        if (s[0] == 'Q') {
            cin >> x >> y;
            Node res = query(1, 1, x);
            Node tmp({0, 0, 0, 0});
            if (x + 1 <= y) tmp = query(1, x + 1, y);
            cout << abs(gcd(res.sum, tmp.d)) << endl;
        } else {
            cin >> x >> y >> z;
            modify(1, x, z);
            if (y + 1 <= n) modify(1, y + 1, -z);
        }
    }
}



fujang
1天前
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int maxn = 5e5 + 10;
int a[maxn];
struct Node {
    int l, r;
    ll tmax, lmax, rmax, sum;
} tree[maxn << 2];

void pushup(Node &u, Node &l, Node &r) {
    u.sum = l.sum + r.sum;
    u.lmax = max(l.lmax, l.sum + r.lmax);
    u.rmax = max(r.rmax, r.sum + l.rmax);
    u.tmax = max({l.tmax, r.tmax, l.rmax + r.lmax});
}

void pushup(int u) {
    pushup(tree[u], tree[u << 1], tree[u << 1 | 1]);
}

void build(int u, int l, int r) {
    if (l == r) {
        tree[u] = {l, r, a[l], a[l], a[l], a[l]};
        return;
    } 
    tree[u].l = l, tree[u].r = r;
    int mid = l + r >> 1;
    build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

void modify(int u, int x, int v) {
    if (tree[u].l == x && tree[u].r == x) {
        tree[u] = {x, x, v, v, v, v};
        return;
    }
    int mid = tree[u].l + tree[u].r >> 1;
    if (x <= mid) modify(u << 1, x, v);
    if (x > mid) modify(u << 1 | 1, x, v);
    pushup(u);
}

Node query(int u, int l, int r) {
    if (tree[u].l >= l && tree[u].r <= r) return tree[u];
    else {
        int mid = tree[u].l + tree[u].r >> 1;
        Node res, left, right;
        if (r <= mid) return query(u << 1, l, r);
        else if (l > mid) return query(u << 1 | 1, l, r);
        else {
            left = query(u << 1, l, r);
            right = query(u << 1 | 1, l, r);
            pushup(res, left, right);
            return res;
        }
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int n, m;
    cin >> n >> m;
    for (int i = 1; i <= n; i++) cin >> a[i];
    build(1, 1, n);
    int x, y, z;
    while (m--) {
        cin >> x >> y >> z;
        if (x == 1) {
            if (y > z) swap(y, z);
            cout << query(1, y, z).tmax << endl;
        } else {
            modify(1, y, z);
        }
    }
    return 0;
}



fujang
3天前

数论——容斥原理、莫比乌斯函数

1、容斥原理:

  • 时间复杂度为$O(2^N)$,下面会有证明。
  • 举一个简单的例子:用韦恩图来思考,求$S1$、$S2$、$S3$三个集合的原有元素的并集,那么结果为:$S1+S2+S3-S1 \cap S2-S1 \cap S3-S2 \cap S3+S1 \cap S2 \cap S3$。
  • 以此类推到$N$个圆的交集:用容斥原理的方法答案为所有单个集合的元素个数-所有两个集合互相交集的元素个数+所有三个集合互相交集的元素个数…
  • 我们知道容斥原理公式一共涉及到的元素个数为:$C_N^1+C_N^2+C_N^3+…+C_N^N$。因为$C_N^0+C_N^1+C_N^2+C_N^3+…+C_N^N=2^n$,因此$C_N^1+C_N^2+C_N^3+…+C_N^N=2^n-1$,因此容斥原理公式一共涉及到的元素个数为$2^n-1$。关于此公式($C_N^0+C_N^1+C_N^2+C_N^3+…+C_N^N=2^n$)的证明,我们可以假设等号左边为对于$N$个物品所有选法的总个数,等号右边考虑每个物品选与不选两种情况,因此等式成立。
  • 因此容斥原理的时间复杂度为$O(2^N)$。
  • 容斥原理的证明:对于容斥原理$|S1 \cup S2 \cup … \cup SN|=\sum_{i=1}^N{Si}-\sum_{i, j}^N{Si \cap Sj}+\sum_{i, j, k}^N{Si \cap Sj \cap Sk}+…$
    对于一个元素$x$,它在$k$个集合中,$1 \leq k \leq N$,它本身被选择的次数为$C_k^1-C_k^2+C_k^3-…+(-1)^{k-1}C_k^k$。我们知道一个结论:$C_k^1-C_k^2+C_k^3-…+(-1)^{k-1}C_k^k=1$,因此对于每一个元素$x$,它只被计算了$1$次,证毕。

例题:AcWing 890. 能被整除的数

给定一个整数$n$和$m$个不同的质数$p1,p2,…,pm$。请你求出$1$到$n$中能被$p1,p2,…,pm$中的至少一个数整除的整数有多少个。

首先我们知道,在$N$个数中能被$x$整除的数的个数为$\lfloor{N/x}\rfloor$。

因此我们只需要根据容斥原理,求出可以被单个元素整除的个数之和-可以被两个元素整除的个数之和+可以被三个元素整除的个数之和…我们用位运算来求得答案,时间复杂度为$O(2^N)$。

#include <bits/stdc++.h>
#define int long long
using namespace std;
int p[20];

void work() {
    int n, m;
    cin >> n >> m;
    for (int i = 0; i < m; i++) cin >> p[i];

    int res = 0;
    for (int i = 1; i < 1 << m; i++) {
        int t = 1, s = 0;
        for (int j = 0; j < m; j++)
            if (i >> j & 1) {
                if (t * p[j] > n) {
                    t = -1;
                    break;
                } 
                t *= p[j];
                s++;
            }
        if (t != -1) {
            if (s % 2) res += n / t;
            else res -= n / t;
        }
    }
    cout << res << endl;
}

int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int cas;
    cas = 1;
    while (cas--) work();
}

例题:AcWing 214. Devu和鲜花

有$N$个盒子,第$i$个盒子中有$Ai$枝花。同一个盒子内的花颜色相同,不同盒子内的花颜色不同。要从这些盒子中选出$M$枝花组成一束,求共有多少种方案。若两束花每种颜色的花的数量都相同,则认为这两束花是相同的方案。结果需对$10^9+7$取模之后方可输出。

我们先考虑:从$N$个盒子选$M$枝花,每个盒子的花的个数为无限个,问一共能选多少枝花?

此问题等价于从$N$个盒子选$M+N枝$花,那么每个盒子至少选$1$枝。那么此问题由等价于把$N+M$个点分成$N$份,我们可以用隔板法来做,一共有$N+M-1$个空隙,有$N-1$个板子,因此答案为$C_{N+M-1}^{N-1}$。

拓展到此问题,第$i$个盒子中有$Ai$枝花。那么我们可以反过来考虑,用总共的答案$C_{N+M-1}^{N-1}$减去其中第$i$个盒子被拿走了大于$Ai$枝花的方案。第$i$个盒子被拿走了大于$Ai$枝花的方案数为:假设此盒子已经被拿走了$Ai+1$枝花,那么等价于前面的问题,从$N$个盒子中共拿走$M-Ai-1$枝花的方案数,等价于从N个盒子拿走$M-Ai-1+N$的方案数,每个盒子至少被拿$1$枝,因此答案为$C_{M-Ai-1+N-1}^{N-1}$。

根据容斥原理来做,可知答案为总共的$C_{N+M-1}^{N-1}$减去所有$1$个盒子不满足的加上所有$2$个盒子不满足的减去所有$3$个盒子不满足的…

#include <bits/stdc++.h>
#define int long long
using namespace std;
int A[20];
constexpr int mod = 1e9 + 7;
int down = 1;

int qmi(int a, int b, int p) {
    int res = 1;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return res;
}

int C(int a, int b) {
    if (a < b) return 0;
    int up = 1;
    for (int i = a; i > a - b; i--) up = i % mod * up % mod;
    return up * down % mod;
}

void work() {
    int n, m;
    cin >> n >> m;
    for (int i = 0; i < n; i++) cin >> A[i];
    for (int j = 1; j <= n - 1; j++) down = j * down % mod;
    down = qmi(down, mod - 2, mod);

    int res = 0;
    for (int i = 0; i < 1 << n; i++) {
        int a = n + m - 1, b = n - 1;
        int sign = 1;
        for (int j = 0; j < n; j++)
            if (i >> j & 1) {
                sign *= -1;
                a -= A[j] + 1;
            }
        res = (res + C(a, b) * sign) % mod;
    }
    cout << (res % mod + mod) % mod << endl;
}

int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int cas;
    cas = 1;
    while (cas--) work();
}

2、莫比乌斯函数:


我们举例一道经典的应用题,求$1$到$N$中与$a$互质的数的个数:那么根据容斥原理,设$S_i$为$1$到$N$中和$a$有公因子i的数的个数,答案为$N-S_2-S_3-S_5-S_7..+S_{2,3}+S_{3,5}+S_{2,5}+…$,我们可以惊奇的发现,其答案为$\sum_{i=0}^{N}{u(i)*S_i}$。

我们可以根据线性筛质数在$O(N)$的时间内算出前$N$个数的莫比乌斯数。

int primes[N], cnt;
bool st[N];
int mobius[N];

void init(int n) {
    mobius[1] = 1;
    for (int i = 2; i <= n; i++) {
        if (!st[i]) {
            primes[cnt++] = i;
            mobius[i] = -1;
        }
        for (int j = 0; primes[j] * i <= n; j++) {
            int t = primes[j] * i;
            st[t] = true;
            if (i % primes[j] == 0) {
                mobius[t] = 0;
                break;
            }
            mobius[t] = mobius[i] * -1;
        }
    }
}

AcWing 215. 破译密码

对于给定的整数$a$,$b$和$d$,有多少正整数对$x$,$y$,满足$x \leq a$,$y\leq b$,并且$gcd(x, y)=d$。$5e4$组询问,$a、b$的数据范围为$5e4$。

根据数据的范围,我们可以推断出时间复杂度为$O(nlogn)$。

每次询问的问题等价于:有多少正整数对$x$,$y$,满足$x \leq a/d$,$y\leq b/d$,并且$gcd(x, y)=1$。那么根据容斥原理:值为$min(a/d, b/d)-S_2-S_3-S_5+S_{2,3}+S_{2,5}+S_{3,5}-S{2,3,5}…$,其答案为$\sum_{i=0}^{min(a/d, b/d)}{u(i)*S_i}$,也就是$\sum^{min(a,b)}_i=a/i∗b/i∗u[i]$。我们可以推断出时间复杂度为$O(n)$。

考虑把上述过程优化,发现,这个式子中虽然i要枚举$N$次,但是实际上因为整除的原因$ai$的值很少,只有$2\sqrt a$个。

因为$a/1 、a/2、a/3、…$是单调递减的,并且有的值相同,所以整个序列一共有有$2\sqrt a$个值。(证明:在分母为$1$到$\sqrt a$之间,值的个数为$a / \sqrt a$个值,在$\sqrt a +1$到$n$之间,值的个数为$a / \sqrt a$个值)。

设$g(x)$表示$a/x$的取值不变的最大的$x$值,那么$a/x=a/g(x)$,并且$a/x>a/(g(x)+1)$,其中$g(x)=a/(a/x)$。

证明$a/x=a/g(x)$:

证明:$g(x)=a/(a/x)$:

综上:将原来的序列分成$2\sqrt a$段,而且每次都会跳一段,所以总共会跳$2\sqrt a$次,时间复杂度就是$O(\sqrt a)$。

加上询问后总的时间复杂度为$O(nlogn)$。

#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int N = 50010;

int primes[N], cnt;
bool st[N];
int mobius[N], sum[N];

void init(int n) {
    mobius[1] = 1;
    for (int i = 2; i <= n; i++) {
        if (!st[i]) {
            primes[cnt++] = i;
            mobius[i] = -1;
        }
        for (int j = 0; primes[j] * i <= n; j++) {
            int t = primes[j] * i;
            st[t] = true;
            if (i % primes[j] == 0) {
                mobius[t] = 0;
                break;
            }
            mobius[t] = mobius[i] * -1;
        }
    }
    for (int i = 1; i <= n; i++) sum[i] = sum[i - 1] + mobius[i];
}

void work() {
    int a, b, d;
    cin >> a >> b >> d;
    a /= d, b /= d;
    int n = min(a, b);
    ll res = 0;
    for (int l = 1, r; l <= n; l = r + 1) {
        r = min(n, min(a / (a / l), b / (b / l)));
        res += (sum[r] - sum[l - 1]) * (ll)(a / l) * (b / l);
    }
    cout << res << endl;
}

int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    init(N - 1);
    int cas;
    cin >> cas;
    while (cas--) work();
}


活动打卡代码 AcWing 215. 破译密码

fujang
3天前
#include <bits/stdc++.h>
typedef long long ll;
#define int long long
using namespace std;
const int N = 50010;

int primes[N], cnt;
bool st[N];
int mobius[N], sum[N];

void init(int n) {
    mobius[1] = 1;
    for (int i = 2; i <= n; i++) {
        if (!st[i]) {
            primes[cnt++] = i;
            mobius[i] = -1;
        }
        for (int j = 0; primes[j] * i <= n; j++) {
            int t = primes[j] * i;
            st[t] = true;
            if (i % primes[j] == 0) {
                mobius[t] = 0;
                break;
            }
            mobius[t] = mobius[i] * -1;
        }
    }
    for (int i = 1; i <= n; i++) sum[i] = sum[i - 1] + mobius[i];
}

void work() {
    int a, b, d;
    cin >> a >> b >> d;
    a /= d, b /= d;
    int n = min(a, b);
    ll res = 0;
    for (int l = 1, r; l <= n; l = r + 1) {
        r = min(n, min(a / (a / l), b / (b / l)));
        res += (sum[r] - sum[l - 1]) * (ll)(a / l) * (b / l);
    }
    cout << res << endl;
}

int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    init(N - 1);
    int cas;
    cin >> cas;
    while (cas--) work();
}


活动打卡代码 AcWing 214. Devu和鲜花

fujang
3天前
#include <bits/stdc++.h>
#define int long long
using namespace std;
int A[20];
constexpr int mod = 1e9 + 7;
int down = 1;

int qmi(int a, int b, int p) {
    int res = 1;
    while (b) {
        if (b & 1) res = res * a % p;
        a = a * a % p;
        b >>= 1;
    }
    return res;
}

int C(int a, int b) {
    if (a < b) return 0;
    int up = 1;
    for (int i = a; i > a - b; i--) up = i % mod * up % mod;
    return up * down % mod;
}

void work() {
    int n, m;
    cin >> n >> m;
    for (int i = 0; i < n; i++) cin >> A[i];
    for (int j = 1; j <= n - 1; j++) down = j * down % mod;
    down = qmi(down, mod - 2, mod);

    int res = 0;
    for (int i = 0; i < 1 << n; i++) {
        int a = n + m - 1, b = n - 1;
        int sign = 1;
        for (int j = 0; j < n; j++)
            if (i >> j & 1) {
                sign *= -1;
                a -= A[j] + 1;
            }
        res = (res + C(a, b) * sign) % mod;
    }
    cout << (res % mod + mod) % mod << endl;
}

int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int cas;
    cas = 1;
    while (cas--) work();
}


活动打卡代码 AcWing 890. 能被整除的数

fujang
3天前
#include <bits/stdc++.h>
#define int long long
using namespace std;
int p[20];

void work() {
    int n, m;
    cin >> n >> m;
    for (int i = 0; i < m; i++) cin >> p[i];

    int res = 0;
    for (int i = 1; i < 1 << m; i++) {
        int t = 1, s = 0;
        for (int j = 0; j < m; j++)
            if (i >> j & 1) {
                if (t * p[j] > n) {
                    t = -1;
                    break;
                } 
                t *= p[j];
                s++;
            }
        if (t != -1) {
            if (s % 2) res += n / t;
            else res -= n / t;
        }
    }
    cout << res << endl;
}

int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    int cas;
    cas = 1;
    while (cas--) work();
}


活动打卡代码 AcWing 756. 蛇形矩阵

fujang
5天前
#include <iostream>

using namespace std;
const int N = 105;

int a[N][N];
int n, m;

int main() {
    cin >> n >> m;

    int left = 0, right = m - 1, top = 0, bottom = n - 1;
    int k = 1;
    while (left <= right && top <= bottom) {
        for (int i = left ; i <= right; i ++) {
            a[top][i] = k ++;
        }
        for (int i = top + 1; i <= bottom; i ++) {
            a[i][right] = k ++;
        }
        for (int i = right - 1; i >= left && top < bottom; i --) {
            a[bottom][i] = k ++;
        }
        for (int i = bottom - 1; i > top && left < right; i --) {
            a[i][left] = k ++;
        }
        left ++, right --, top ++, bottom --;
    }
    for (int i = 0; i < n; i ++) {
        for (int j = 0; j < m; j ++) {
            cout << a[i][j] << " ";
        }
        cout << endl;
    }
    return 0;
}


活动打卡代码 AcWing 352. 闇の連鎖

fujang
6天前
#include <bits/stdc++.h>
using namespace std;
const int maxn = 100010, M = maxn * 2;
int n, m;
int h[maxn], e[M], ne[M], idx;
int depth[maxn], fa[maxn][17];
int d[maxn];
int ans;

void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void bfs() {
    memset(depth, 0x3f, sizeof depth);
    depth[0] = 0, depth[1] = 1;
    queue<int> q;
    q.push(1);
    while (q.size()) {
        int t = q.front();
        q.pop();
        for (int i = h[t]; ~i; i = ne[i]) {
            int j = e[i];
            if (depth[j] > depth[t] + 1) {
                depth[j] = depth[t] + 1;
                q.push(j);
                fa[j][0] = t;
                for (int k = 1; k <= 16; k++)
                    fa[j][k] = fa[fa[j][k - 1]][k - 1];
            }
        }
    }
}

int lca(int a, int b) {
    if (depth[a] < depth[b]) swap(a, b);
    for (int k = 16; k >= 0; k--)
        if (depth[fa[a][k]] >= depth[b])
            a = fa[a][k];

    if (a == b) return a;
    for (int k = 16; k >= 0; k--)
        if (fa[a][k] != fa[b][k]) {
            a = fa[a][k];
            b = fa[b][k];
        }
    return fa[a][0];
}

int dfs(int u, int father) {
    int res = d[u];
    for (int i = h[u]; i != -1; i = ne[i]) {
        int j = e[i];
        if (j != father) {
            int s = dfs(j, u);
            if (s == 0) ans += m;
            else if (s == 1) ans++;
            res += s;
        }
    }
    return res;
}

int32_t main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    cin >> n >> m;
    memset(h, -1, sizeof h);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        cin >> a >> b;
        add(a, b), add(b, a);
    }

    bfs();

    for (int i = 0; i < m; i++) {
        int a, b;
        cin >> a >> b;
        int p = lca(a, b);
        d[a]++, d[b]++, d[p] -= 2;
    }
    dfs(1, -1);
    cout << ans << endl;
    return 0;
}