头像

Mintind


访客:4419

在线 


活动打卡代码 AcWing 256. 最大异或和

Mintind
1天前

又卡常…抄了一堆不知道啥东东
貌似不用fread的快读也能过

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

namespace io
{
    // Buffered stdin/stdout helpers: input is pulled in SIZE-byte chunks with
    // fread, output is accumulated in obuf and written out with fwrite.
    const int SIZE = (1 << 21) + 1;
    char ibuf[SIZE], *is, *it, obuf[SIZE], *os = obuf, *ot = os + SIZE - 1, ch[35];

    // Next input byte, refilling ibuf when exhausted; (char)EOF once input ends.
    char gc()
    {
        return (is == it ? (it = (is = ibuf) + fread(ibuf, 1, SIZE, stdin), (is == it ? EOF : *is++)) : *is++);
    }

    // Writes everything buffered so far to stdout and resets the buffer.
    void flush()
    {
        fwrite(obuf, 1, os - obuf, stdout);
        os = obuf;
    }

    // Appends one byte to the output buffer, flushing when it becomes full.
    void putc(char c)
    {
        *os++ = c;
        if (os == ot) flush();
    }

    // Reads the next (optionally negative) decimal integer.
    // Returns 0 if input ends before any digit appears, instead of looping forever.
    int read()
    {
        char c;
        int x = 0, f = 1;
        do {
            c = gc();
            if (c == (char)EOF) return 0;
            if (c == '-') f = -1;
        } while (c < '0' || c > '9');
        do {
            x = (x << 3) + (x << 1) + (c ^ 48);
            c = gc();
        } while (c >= '0' && c <= '9');
        return x * f;
    }

    // Appends x in decimal to the output buffer.
    // The magnitude is handled as unsigned so INT_MIN does not overflow on negation.
    void write(int x)
    {
        unsigned int u = static_cast<unsigned int>(x);
        if (x < 0)
        {
            putc('-');
            u = 0u - u;
        }
        int len = 0;
        do {
            ch[++len] = static_cast<char>('0' + u % 10);
            u /= 10;
        } while (u);
        while (len) putc(ch[len--]);
    }
}
using io::gc;
using io::putc;
using io::read;
using io:: write;
using io::flush;

const int N = 6e5 + 5;

// s[i]: XOR of the first i sequence elements; root[i]: trie version holding s[0..i].
int s[N], root[N];

// Larger of two ints (a small local replacement for std::max).
int max(const int &x, const int &y)
{
    if (x < y) return y;
    return x;
}

struct Trie
{
    // Every insert allocates one node per bit level k = 23..0 plus a leaf at
    // k = -1, i.e. 25 nodes per version. With up to N versions the pool must
    // hold N * 25 nodes; N * 23 overflows when most operations are appends.
    static const int NODES_PER_VERSION = 25;

    // version[p]: largest prefix index stored under node p (version[0] = -1
    // marks the null node); trie[p][b]: child of p for bit value b (0 = none).
    int version[N * NODES_PER_VERSION], trie[N * NODES_PER_VERSION][2];
    int tot;

    // Builds version q from version p by adding prefix s[i], walking bits k..0.
    // The sibling not on s[i]'s path is shared with p; the path itself is copied.
    void insert(int i, int k, int p, int &q)
    {
        if (!q) q = ++tot;
        if (k < 0)
        {
            version[q] = i;
            return;
        }

        int ch = s[i] >> k & 1;
        if (p) trie[q][ch ^ 1] = trie[p][ch ^ 1];
        insert(i, k - 1, trie[p][ch], trie[q][ch]);
        version[q] = max(version[trie[q][0]], version[trie[q][1]]);
    }

    // Returns max(val ^ s[j]) over prefixes j stored in version p with j >= limit.
    // At each bit it prefers the opposite bit if that subtree still holds such a j.
    int query(int val, int p, int limit)
    {
        for (int k = 23; k >= 0; k--)
        {
            int ch = val >> k & 1;
            if (version[trie[p][ch ^ 1]] >= limit) p = trie[p][ch ^ 1];
            else p = trie[p][ch];
        }
        return val ^ s[version[p]];
    }

    // Resets the pool and creates version root[0] containing only s[0].
    void build()
    {
        tot = 0;
        version[0] = -1;
        insert(0, 23, 0, root[0]);
    }
}trie;

int main()
{
    int n = read(), m = read();

    trie.build();

    int t = 0;
    s[0] = 0;
    for (int i = 1; i <= n; i++)
    {
        t++;
        s[t] = read();
        s[t] = s[t - 1] ^ s[t];
        trie.insert(t, 23, root[t - 1], root[t]);
    }
    while (m--)
    {
        char op = gc(); 
        while (!isalpha(op)) op = gc();
        int l, r, x;

        if (op == 'A')
        {
            x = read();
            t++;
            s[t] = s[t - 1] ^ x;
            trie.insert(t, 23, root[t - 1], root[t]);
        }
        else
        {
            l = read(), r = read(), x = read();
            write(trie.query(x ^ s[t], root[r - 1], l - 1));
            putc('\n');
        }
    }

    flush();

    return 0;
}


活动打卡代码 AcWing 255. 第K个数

Mintind
1天前

整体分治

在值域上分治,树状数组维护区间(下标)中数的个数,一定要分清楚

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 11e4 + 5, INF = 1e9;

// Fenwick tree over positions 1..N-1: point add, prefix-sum query.
struct BIT
{
    int c[N];

    int lowbit(int x)
    {
        return x & -x;
    }

    // Adds val at position x.
    void change(int x, int val)
    {
        while (x < N)
        {
            c[x] += val;
            x += lowbit(x);
        }
    }

    // Sum of positions 1..x.
    int query(int x)
    {
        int sum = 0;
        while (x != 0)
        {
            sum += c[x];
            x -= lowbit(x);
        }
        return sum;
    }
}bit;

struct Node
{
    int x, y, k, id;
}q[N], lq[N], rq[N];
int ans[N];

void solve(int lval, int rval, int st, int ed)
{
    if (st > ed) return;
    if (lval == rval)
    {
        for (int i = st; i <= ed; i++)
            ans[q[i].id] = lval;
        return;
    }

    int mid = lval + rval >> 1;
    int lt = 0, rt = 0;
    for (int i = st; i <= ed; i++)
    {
        if (q[i].id == 0)
        {
            if (q[i].y <= mid)
            {
                bit.change(q[i].x, 1);
                lq[++lt] = q[i];
            }
            else rq[++rt] = q[i];
        }
        else
        {
            int cnt = bit.query(q[i].y) - bit.query(q[i].x - 1);
            if (cnt >= q[i].k) lq[++lt] = q[i];
            else q[i].k -= cnt, rq[++rt] = q[i];
        }
    }

    for (int i = st; i <= ed; i++)
        if (q[i].id == 0 && q[i].y <= mid) bit.change(q[i].x, -1);

    for (int i = 1; i <= lt; i++)
        q[st + i - 1] = lq[i];
    for (int i = 1; i <= rt; i++)
        q[st + lt + i - 1] = rq[i];

    solve(lval, mid, st, st + lt - 1);
    solve(mid + 1, rval, st + lt, ed);
}

int main()
{
    int n, m;
    cin >> n >> m;

    for (int i = 1; i <= n; i++)
    {
        cin >> q[i].y;
        q[i].x = i;//在i位置上的数+1
        q[i].id = 0;
    }
    for (int i = n + 1; i <= n + m; i++)
    {
        cin >> q[i].x >> q[i].y >> q[i].k;
        q[i].id = i;
    }

    solve(-INF, INF, 1, n + m);

    for (int i = n + 1; i <= n + m; i++)
        cout << ans[i] << endl;

    return 0;
}

可持久化线段树(主席树)

用root[i]区分范围(下标),线段树里的l,r是值域

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 1e5 + 5;

// a[i]: compressed value of element i; nums[]: sorted distinct values;
// root[i]: tree version containing a[1..i].
int a[N], nums[N], root[N];

// Persistent segment tree over the value domain; each version stores how
// many inserted elements fall into each value range.
struct Persistable_Segment_Tree
{
    int tot;
    int lc[N * 20], rc[N * 20], cnt[N * 20];

    // Builds an all-zero tree for values [l, r] rooted at p.
    void build(int &p, int l, int r)
    {
        p = ++tot;
        if (l != r)
        {
            const int mid = (l + r) >> 1;
            build(lc[p], l, mid);
            build(rc[p], mid + 1, r);
            return;
        }
        cnt[p] = 0;
    }

    // Creates version p from version pre with one more occurrence of value x.
    void insert(int &p, int pre, int l, int r, int x)
    {
        p = ++tot;
        lc[p] = lc[pre];
        rc[p] = rc[pre];
        cnt[p] = cnt[pre];
        if (l == r)
        {
            ++cnt[p];
            return;
        }
        const int mid = (l + r) >> 1;
        if (x > mid) insert(rc[p], rc[pre], mid + 1, r, x);
        else insert(lc[p], lc[pre], l, mid, x);
        cnt[p] = cnt[lc[p]] + cnt[rc[p]];
    }

    // k-th smallest value among the elements in version p but not in version q.
    int query(int p, int q, int l, int r, int k)
    {
        while (l != r)
        {
            const int mid = (l + r) >> 1;
            const int left = cnt[lc[p]] - cnt[lc[q]];
            if (k <= left)
            {
                p = lc[p];
                q = lc[q];
                r = mid;
            }
            else
            {
                k -= left;
                p = rc[p];
                q = rc[q];
                l = mid + 1;
            }
        }
        return nums[l];
    }
}t;

int main()
{
    int n, m;
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
        nums[i] = a[i];
    }
    // Coordinate-compress the values into 1..num.
    sort(nums + 1, nums + n + 1);
    const int num = int(unique(nums + 1, nums + n + 1) - (nums + 1));
    for (int i = 1; i <= n; i++)
        a[i] = int(lower_bound(nums + 1, nums + num + 1, a[i]) - nums);

    // root[i] is the version containing a[1..i].
    t.tot = 0;
    t.build(root[0], 1, num);
    for (int i = 1; i <= n; i++)
        t.insert(root[i], root[i - 1], 1, num, a[i]);

    for (int l, r, k; m--; )
    {
        scanf("%d%d%d", &l, &r, &k);
        printf("%d\n", t.query(root[r], root[l - 1], 1, num, k));
    }

    return 0;
}


活动打卡代码 AcWing 254. 天使玩偶

Mintind
1天前

我会的不会的优化都加上了 最多只过了17个点 不加O2过不了啊QAQ

#pragma GCC optimize(2)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

// Fast input: getchar() is redirected to an fread()-backed buffer (refilled 1 << 21 bytes at a time).
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 21, stdin), p1 == p2) ? EOF : *p1++)
//#define putchar(x) (p3 - obuf < 1000000) ? (*p3++ = x) : (fwrite(obuf, p3-obuf, 1, stdout), p3 = obuf,*p3++ = x)
char buf[1 << 23], *p1 = buf, *p2 = buf, obuf[1 << 23], *p3 = obuf, ch[35];

// Reads the next decimal integer; a '-' seen before the first digit makes it negative.
int read()
{
    char c = getchar();
    bool neg = false;
    while (c < '0' || c > '9')
    {
        if (c == '-') neg = true;
        c = getchar();
    }
    int val = 0;
    while (c >= '0' && c <= '9')
    {
        val = val * 10 + (c - '0');
        c = getchar();
    }
    return neg ? -val : val;
}

// Prints x in decimal through putchar (no separator appended).
void write(int x)
{
    if (x < 0)
    {
        putchar('-');
        x = -x;
    }
    else if (x == 0)
    {
        putchar('0');
        return;
    }
    int top = 0;
    for (; x != 0; x /= 10) ch[top++] = char(x % 10 + '0');
    while (top > 0) putchar(ch[--top]);
}

const int N = 1e6 + 5, INF = 1e9;

int n, m;

inline int max(const int &x, const int &y)
{
    return x < y ? y : x;
}
inline int min(const int &x, const int &y)
{
    return y < x ? y : x;
}

// Prefix-maximum Fenwick tree. The value 0 doubles as "empty": query()
// starts from 0, and remove() resets slots to 0.
struct Binary_Indexed_Tree
{
    int c[N];

    // Fills every slot with -INF (not called by the visible code).
    void init()
    {
        for (int *p = c + 1; p != c + N; ++p) *p = -INF;
    }

    inline int lowbit(const int &x)
    {
        return x & -x;
    }

    // Raises every slot covering position x to at least val.
    void modify(const int &x, const int &val)
    {
        int i = x;
        while (i < N)
        {
            if (c[i] < val) c[i] = val;
            i += lowbit(i);
        }
    }

    // Maximum over positions [1, x]; 0 if nothing positive was inserted there.
    int query(const int &x)
    {
        int best = 0;
        for (int i = x; i != 0; i -= lowbit(i))
            best = max(best, c[i]);
        return best;
    }

    // Resets the slots on x's update path to 0, stopping at the first slot
    // already 0. This assumes the rest of that path was cleared together
    // with it, which holds when only positive values are inserted.
    void remove(const int &x)
    {
        for (int i = x; i < N && c[i] != 0; i += lowbit(i))
            c[i] = 0;
    }
}bit;

// A point insertion (k == 1) or a nearest-point query (k == 2) at (x, y);
// id indexes the operation in a[], whose ans field collects the result.
struct Node
{
    int x, y, k, id, ans;
    // Lexicographic order on (x, y, k): at equal coordinates an insertion precedes a query.
    bool operator < (const Node &rhs) const
    {
        if (x != rhs.x) return x < rhs.x;
        if (y != rhs.y) return y < rhs.y;
        return k < rhs.k;
    }
}a[N], q[N], t[N];

// CDQ divide and conquer over operation time q[l..r]. After the recursive
// calls both halves are sorted by (x, y, k). While merging, left-half
// insertions are pushed into the BIT (keyed by y, value x + y) before any
// right-half query that does not precede them is answered. Each query thereby
// gets min(x + y - (xi + yi)) over earlier points with xi <= x and yi <= y.
void solve(const int &l, const int &r)
{
    if (l == r) return;

    const int mid = (l + r) >> 1;
    solve(l, mid);
    solve(mid + 1, r);

    // Answers right-half operation e (if it is a query) against the BIT.
    auto answer = [](const Node &e)
    {
        if (e.k != 2) return;
        int best = bit.query(e.y);
        if (best == 0) best = -INF;
        int &res = a[e.id].ans;
        res = min(res, e.x + e.y - best);
    };

    int i = l, j = mid + 1, w = l;
    while (i <= mid && j <= r)
    {
        if (q[i] < q[j])
        {
            if (q[i].k == 1) bit.modify(q[i].y, q[i].x + q[i].y);
            t[w++] = q[i++];
        }
        else
        {
            answer(q[j]);
            t[w++] = q[j++];
        }
    }
    while (i <= mid) t[w++] = q[i++];
    for (; j <= r; ++j)
    {
        answer(q[j]);
        t[w++] = q[j];
    }

    // Clear this level's insertions.
    for (int p = l; p <= mid; ++p)
        if (q[p].k == 1) bit.remove(q[p].y);

    for (int p = l; p <= r; ++p) q[p] = t[p];
}

int mx, my, len;

// Runs one CDQ pass over the operations whose coordinates lie inside
// [1, mx] x [1, my]. main sets mx/my to the largest query coordinates of
// this pass, so operations outside cannot affect any query.
// If nothing qualifies (e.g. there are no queries at all), solve() is
// skipped: solve(1, 0) would recurse forever because it only stops at l == r.
void work()
{
    len = 0;
    for (int i = 1; i <= n + m; i++)
    {
        if (a[i].x <= mx && a[i].y <= my)
            q[++len] = a[i];
    }

    if (len > 0) solve(1, len);
}

int main()
{
    //fwrite(obuf, p3-obuf, 1, stdout);

    n = read(), m = read();

    mx = my = 0;
    for (int i = 1; i <= n; ++i)
    {
        a[i].x = read() + 1, a[i].y = read() + 1;
        a[i].k = 1;
        a[i].id = i;
        a[i].ans = INF;
    }
    for (int i = n + 1; i <= n + m; ++i)
    {
        a[i].k = read(), a[i].x = read() + 1, a[i].y = read() + 1;
        a[i].id = i;
        a[i].ans = INF;
        if (a[i].k == 2) mx = max(mx, a[i].x), my = max(my, a[i].y);
    }
    work();

    my = 0;
    for (int i = 1; i <= n + m; ++i)
    {
        a[i].y = N - a[i].y;
        if (a[i].k == 2) my = max(my, a[i].y);
    }
    work();

    mx = 0;
    for (int i = 1; i <= n + m; ++i)
    {
        a[i].x = N - a[i].x;
        if (a[i].k == 2) mx = max(mx, a[i].x);
    }
    work();

    my = 0;
    for (int i = 1; i <= n + m; ++i)
    {
        a[i].y = N - a[i].y;
        if (a[i].k == 2) my = max(my, a[i].y);
    }
    work();

    for (int i = n + 1; i <= n + m; ++i)
        if (a[i].k == 2) write(a[i].ans), putchar('\n');

    return 0;
}


活动打卡代码 AcWing 253. 普通平衡树

Mintind
2天前
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 1e5 + 10, INF = 1e9;

// Treap node: BST on val, heap on the random priority dat (insert/del rotate
// the larger dat upward); cnt = copies of val, size = copies in the subtree.
struct Treap
{
    int l, r;
    int val, dat, cnt, size;
}t[N];
int tot, root;

// Allocates a fresh leaf holding one copy of val with a random priority.
int new_node(int val)
{
    const int id = ++tot;
    t[id].val = val;
    t[id].dat = rand();
    t[id].size = 1;
    t[id].cnt = 1;
    return id;
}

// Recomputes p's subtree size from its children and its own count.
void update(int p)
{
    const Treap &node = t[p];
    t[p].size = node.cnt + t[node.l].size + t[node.r].size;
}

// Starts with only the sentinels: -INF at node 1 (the root) and INF at node 2 (its right child).
void build()
{
    tot = 0;
    new_node(-INF);
    new_node(INF);
    root = 1;
    t[root].r = 2;
    update(root);
}

// Right rotation at p: its left child takes p's place (p is rewritten in place).
void zig(int &p)
{
    const int child = t[p].l;
    t[p].l = t[child].r;
    t[child].r = p;
    p = child;
    update(t[child].r);
    update(child);
}

// Left rotation at p: its right child takes p's place (p is rewritten in place).
void zag(int &p)
{
    const int child = t[p].r;
    t[p].r = t[child].l;
    t[child].l = p;
    p = child;
    update(t[child].l);
    update(child);
}

// Inserts one copy of val below p, rotating a child up if its priority is larger.
void insert(int &p, int val)
{
    if (p == 0)
    {
        p = new_node(val);
        return;
    }

    if (t[p].val == val)
    {
        ++t[p].cnt;
    }
    else if (t[p].val > val)
    {
        insert(t[p].l, val);
        if (t[p].dat < t[t[p].l].dat) zig(p);
    }
    else
    {
        insert(t[p].r, val);
        if (t[p].dat < t[t[p].r].dat) zag(p);
    }
    update(p);
}

// Removes one copy of val below p (no-op if absent). A node with a single
// copy is rotated down (higher-priority child up) until it becomes a leaf.
void del(int &p, int val)
{
    if (p == 0) return;
    if (val != t[p].val)
    {
        if (val < t[p].val) del(t[p].l, val);
        else del(t[p].r, val);
        update(p);
        return;
    }
    if (t[p].cnt > 1)
    {
        --t[p].cnt;
        update(p);
        return;
    }
    if (t[p].l == 0 && t[p].r == 0)
    {
        p = 0;
        return;
    }
    if (t[p].r != 0 && t[t[p].l].dat <= t[t[p].r].dat)
    {
        zag(p);
        del(t[p].l, val);
    }
    else
    {
        zig(p);
        del(t[p].r, val);
    }
    update(p);
}

// Largest stored value strictly below val (node 1's value, -INF after build(), if none).
int get_pre(int val)
{
    int best = 1;
    for (int p = root; p != 0; p = (val < t[p].val ? t[p].l : t[p].r))
    {
        if (t[p].val == val)
        {
            // val itself is present: the rightmost node of its left subtree wins.
            int q = t[p].l;
            if (q != 0)
            {
                while (t[q].r != 0) q = t[q].r;
                best = q;
            }
            break;
        }
        if (t[p].val < val && t[best].val < t[p].val) best = p;
    }
    return t[best].val;
}

// Smallest stored value strictly above val (node 2's value, INF after build(), if none).
int get_next(int val)
{
    int best = 2;
    for (int p = root; p != 0; p = (val < t[p].val ? t[p].l : t[p].r))
    {
        if (t[p].val == val)
        {
            // val itself is present: the leftmost node of its right subtree wins.
            int q = t[p].r;
            if (q != 0)
            {
                while (t[q].l != 0) q = t[q].l;
                best = q;
            }
            break;
        }
        if (t[p].val > val && t[best].val > t[p].val) best = p;
    }
    return t[best].val;
}

// 1-based rank of val within the subtree rooted at p: the number of stored
// copies smaller than val, plus one when val is found.
int get_rank(int p, int val)
{
    int smaller = 0;
    while (p != 0)
    {
        if (val == t[p].val) return smaller + t[t[p].l].size + 1;
        if (val < t[p].val)
        {
            p = t[p].l;
            continue;
        }
        smaller += t[t[p].l].size + t[p].cnt;
        p = t[p].r;
    }
    return smaller;
}

// Value of the given 1-based rank within the subtree rooted at p (INF if out of range).
int get_val(int p, int rank)
{
    while (p != 0)
    {
        const int left = t[t[p].l].size;
        if (rank <= left)
        {
            p = t[p].l;
            continue;
        }
        if (rank <= left + t[p].cnt) return t[p].val;
        rank -= left + t[p].cnt;
        p = t[p].r;
    }
    return INF;
}

int main()
{
    build();

    int n;
    cin >> n;
    for (int opt, x; n--; )
    {
        cin >> opt >> x;
        // The -INF sentinel occupies rank 1, hence the -1 / +1 adjustments.
        if (opt == 1) insert(root, x);
        else if (opt == 2) del(root, x);
        else if (opt == 3) cout << get_rank(root, x) - 1 << endl;
        else if (opt == 4) cout << get_val(root, x + 1) << endl;
        else if (opt == 5) cout << get_pre(x) << endl;
        else if (opt == 6) cout << get_next(x) << endl;
    }
    return 0;
}



Mintind
2天前

绕题解逛了一圈,没有发现完全按书上第二种方法写的,自己写了一个放上来qwq
比直接存dis的难写一些,实测快一点 几十ms emmm

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 10005, M = 20005;

int n, k;

// Linked-star adjacency list (main resets tot to 1, so edge ids start at 2).
int tot;
int head[N], nxt[M], ver[M], edge[M];

// Appends directed edge x -> y with weight z.
void add(int x, int y, int z)
{
    const int id = ++tot;
    ver[id] = y;
    edge[id] = z;
    nxt[id] = head[x];
    head[x] = id;
}

int minx, root, all;
int sz[N], vis[N];

// Centroid search in x's unvisited component: fills sz[] and keeps in root
// the node whose largest remaining piece (a child subtree, or the
// all - sz[x] nodes on the parent side) is the smallest seen (stored in minx).
void get_root(int x, int fa)
{
    sz[x] = 1;
    int worst = 0;
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (y != fa && !vis[y])
        {
            get_root(y, x);
            sz[x] += sz[y];
            if (sz[y] > worst) worst = sz[y];
        }
    }
    if (all - sz[x] > worst) worst = all - sz[x];
    if (worst < minx)
    {
        root = x;
        minx = worst;
    }
}

int ans, len;
int cnt[N], d[N];

// A collected node: its distance to the current centroid and the branch it
// belongs to (the centroid's child it hangs under; the centroid is its own branch).
struct Node
{
    int dis, bel;
    bool operator < (const Node &x) const
    {
        return x.dis > dis;
    }
}seq[N];

// Appends (d[x], branch) for every unvisited node reachable from x to
// seq[1..len] and counts the entries per branch in cnt[].
void get_dis(int x, int fa, int bel)
{
    // The centroid and its direct children each start a new branch.
    if (fa == root || x == root) bel = x;
    ++len;
    seq[len].dis = d[x];
    seq[len].bel = bel;
    ++cnt[bel];
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (y == fa || vis[y]) continue;
        d[y] = d[x] + edge[e];
        get_dis(y, x, bel);
    }
}

// Counts pairs of collected nodes from different branches whose distances
// sum to at most k. After sorting by distance, two pointers l <= r keep
// cnt[s] = number of entries of branch s inside seq[l..r]. Whenever
// dis[l] + dis[r] <= k, every entry of seq[l..r] outside l's branch pairs
// with l. Running until l > r also returns cnt[] to all zeros.
// (Equivalent variant: keep cnt over seq[l+1..r], decrement cnt[seq[1].bel]
// up front, loop while l < r and add r - l - cnt[seq[l].bel].)
int cal()
{
    sort(seq + 1, seq + len + 1);

    int pairs = 0;
    int l = 1, r = len;
    while (l <= r)
    {
        if (seq[l].dis + seq[r].dis > k)
        {
            --cnt[seq[r].bel];
            --r;
            continue;
        }
        pairs += (r - l + 1) - cnt[seq[l].bel];
        --cnt[seq[l].bel];
        ++l;
    }
    return pairs;
}

// Point divide and conquer: counts the pairs whose path passes centroid x,
// then removes x and recurses into the centroid of each remaining piece.
void solve(int x)
{
    d[x] = 0;
    len = 0;
    get_dis(x, 0, 0);
    ans += cal();

    vis[x] = 1;
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (vis[y]) continue;
        minx = sz[y];
        all = sz[y];
        get_root(y, x);
        solve(root);
    }
}

int main()
{
    // Reads test cases until n or k is 0.
    while (cin >> n >> k, n && k)
    {
        memset(vis, 0, sizeof vis);
        memset(head, 0, sizeof head);
        tot = 1;

        for (int e = 1; e < n; e++)
        {
            int x, y, z;
            cin >> x >> y >> z;
            // Vertices are given 0-based; store them 1-based.
            add(x + 1, y + 1, z);
            add(y + 1, x + 1, z);
        }

        ans = 0;
        all = n;
        minx = n;
        get_root(1, 0);
        solve(root);
        cout << ans << endl;
    }
    return 0;
}


活动打卡代码 AcWing 252. 树

Mintind
2天前

书上做法
有点长 但是快一些 因为不用每次都在子树中get_dis

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 10005, M = 20005;

int n, k;

// Linked-star adjacency list (main resets tot to 1, so edge ids start at 2).
int tot;
int head[N], nxt[M], ver[M], edge[M];

// Appends directed edge x -> y with weight z.
void add(int x, int y, int z)
{
    const int id = ++tot;
    ver[id] = y;
    edge[id] = z;
    nxt[id] = head[x];
    head[x] = id;
}

int minx, root, all;
int sz[N], vis[N];

// Centroid search in x's unvisited component: fills sz[] and keeps in root
// the node whose largest remaining piece (a child subtree, or the
// all - sz[x] nodes on the parent side) is the smallest seen (stored in minx).
void get_root(int x, int fa)
{
    sz[x] = 1;
    int worst = 0;
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (y != fa && !vis[y])
        {
            get_root(y, x);
            sz[x] += sz[y];
            if (sz[y] > worst) worst = sz[y];
        }
    }
    if (all - sz[x] > worst) worst = all - sz[x];
    if (worst < minx)
    {
        root = x;
        minx = worst;
    }
}

int ans, len;
int cnt[N], d[N];

// A collected node: its distance to the current centroid and the branch it
// belongs to (the centroid's child it hangs under; the centroid is its own branch).
struct Node
{
    int dis, bel;
    bool operator < (const Node &x) const
    {
        return x.dis > dis;
    }
}seq[N];

// Appends (d[x], branch) for every unvisited node reachable from x to
// seq[1..len] and counts the entries per branch in cnt[].
void get_dis(int x, int fa, int bel)
{
    // The centroid and its direct children each start a new branch.
    if (fa == root || x == root) bel = x;
    ++len;
    seq[len].dis = d[x];
    seq[len].bel = bel;
    ++cnt[bel];
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (y == fa || vis[y]) continue;
        d[y] = d[x] + edge[e];
        get_dis(y, x, bel);
    }
}

// Counts pairs of collected nodes from different branches whose distances
// sum to at most k. After sorting by distance, two pointers l <= r keep
// cnt[s] = number of entries of branch s inside seq[l..r]. Whenever
// dis[l] + dis[r] <= k, every entry of seq[l..r] outside l's branch pairs
// with l. Running until l > r also returns cnt[] to all zeros.
// (Equivalent variant: keep cnt over seq[l+1..r], decrement cnt[seq[1].bel]
// up front, loop while l < r and add r - l - cnt[seq[l].bel].)
int cal()
{
    sort(seq + 1, seq + len + 1);

    int pairs = 0;
    int l = 1, r = len;
    while (l <= r)
    {
        if (seq[l].dis + seq[r].dis > k)
        {
            --cnt[seq[r].bel];
            --r;
            continue;
        }
        pairs += (r - l + 1) - cnt[seq[l].bel];
        --cnt[seq[l].bel];
        ++l;
    }
    return pairs;
}

// Point divide and conquer: counts the pairs whose path passes centroid x,
// then removes x and recurses into the centroid of each remaining piece.
void solve(int x)
{
    d[x] = 0;
    len = 0;
    get_dis(x, 0, 0);
    ans += cal();

    vis[x] = 1;
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (vis[y]) continue;
        minx = sz[y];
        all = sz[y];
        get_root(y, x);
        solve(root);
    }
}

int main()
{
    // Reads test cases until n or k is 0.
    while (cin >> n >> k, n && k)
    {
        memset(vis, 0, sizeof vis);
        memset(head, 0, sizeof head);
        tot = 1;

        for (int e = 1; e < n; e++)
        {
            int x, y, z;
            cin >> x >> y >> z;
            // Vertices are given 0-based; store them 1-based.
            add(x + 1, y + 1, z);
            add(y + 1, x + 1, z);
        }

        ans = 0;
        all = n;
        minx = n;
        get_root(1, 0);
        solve(root);
        cout << ans << endl;
    }
    return 0;
}

看题解写的
只存d 不合法的减去

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;

const int N = 10005, M = 20005;

int n, k;

// Linked-star adjacency list (main resets tot to 1, so edge ids start at 2).
int tot;
int head[N], nxt[M], ver[M], edge[M];

// Appends directed edge x -> y with weight z.
void add(int x, int y, int z)
{
    const int id = ++tot;
    ver[id] = y;
    edge[id] = z;
    nxt[id] = head[x];
    head[x] = id;
}

int ans, len;
int root, mx, all;
int sz[N], vis[N], d[N], seq[N];

// Centroid search in x's unvisited component: fills sz[] and keeps in root
// the node whose largest remaining piece is the smallest seen (stored in mx).
void get_root(int x, int fa)
{
    sz[x] = 1;
    int biggest = 0;
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (y == fa || vis[y]) continue;
        get_root(y, x);
        sz[x] += sz[y];
        if (sz[y] > biggest) biggest = sz[y];
    }
    if (all - sz[x] > biggest) biggest = all - sz[x];
    if (biggest < mx)
    {
        mx = biggest;
        root = x;
    }
}

// Appends d[] of every unvisited node reachable from x (x included) to seq[1..len].
void get_dis(int x, int fa)
{
    seq[++len] = d[x];
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (y != fa && !vis[y])
        {
            d[y] = d[x] + edge[e];
            get_dis(y, x);
        }
    }
}

// Number of unordered pairs in x's unvisited component whose distances
// (measured from x, with x itself at distance dis) sum to at most k.
int cal(int x, int dis)
{
    len = 0;
    d[x] = dis;
    get_dis(x, 0);
    sort(seq + 1, seq + len + 1);

    int res = 0;
    for (int l = 1, r = len; l < r; )
    {
        if (seq[l] + seq[r] > k)
        {
            --r;
            continue;
        }
        res += r - l;   // l pairs with every index in (l, r]
        ++l;
    }
    return res;
}

// Point divide and conquer: the pairs through centroid x are all pairs of its
// component minus those lying entirely inside one child subtree.
void solve(int x)
{
    vis[x] = 1;
    ans += cal(x, 0);
    for (int e = head[x]; e; e = nxt[e])
    {
        const int y = ver[e];
        if (vis[y]) continue;
        // Both ends in y's subtree: counted above but the path doesn't pass x.
        ans -= cal(y, edge[e]);
        all = sz[y];
        mx = n;
        get_root(y, x);
        solve(root);
    }
}

int main()
{
    // Reads test cases until n or k is 0.
    while (cin >> n >> k, n && k)
    {
        memset(vis, 0, sizeof vis);
        memset(head, 0, sizeof head);
        tot = 1;

        for (int e = 1; e < n; e++)
        {
            int x, y, z;
            cin >> x >> y >> z;
            // Vertices are given 0-based; store them 1-based.
            add(x + 1, y + 1, z);
            add(y + 1, x + 1, z);
        }

        all = mx = n;
        get_root(1, 0);
        ans = 0;
        solve(root);
        cout << ans << endl;
    }
    return 0;
}


活动打卡代码 AcWing 251. 小Z的袜子

Mintind
4天前
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
using namespace std;

typedef long long ll;

const int N = 50005;

int color[N], pos[N];
ll ans[N], p[N], num[N];

// An offline query [l, r] with its input index. Mo's ordering: by the block
// of l, then by r ascending in odd blocks and descending in even blocks.
struct Node
{
    int l, r, id;

    bool operator < (const Node &x) const
    {
        if (pos[l] != pos[x.l]) return l < x.l;
        if (pos[l] & 1) return r < x.r;
        return r > x.r;
    }
}query[N];

// Sum over colours of c * (c - 1), where c is the colour's count in the window
// (i.e. the number of ordered same-colour pairs).
ll res;

// Colour x enters the window: c * (c - 1) grows by 2c.
void add(int x)
{
    res += 2 * num[x];
    ++num[x];
}

// Colour x leaves the window: c * (c - 1) shrinks by 2(c - 1).
void del(int x)
{
    --num[x];
    res -= 2 * num[x];
}

// Greatest common divisor by Euclid's algorithm; gcd(a, 0) == a.
ll gcd(ll a, ll b)
{
    while (b != 0)
    {
        const ll rem = a % b;
        a = b;
        b = rem;
    }
    return a;
}

int main()
{
    int n, m;
    scanf("%d%d", &n, &m);

    for (int i = 1; i <= n; i++)
        scanf("%d", &color[i]);

    // Mo's algorithm blocks the *positions* 1..n: the comparator reads
    // pos[query.l], so pos must be filled for every index up to n and the
    // block size must come from n. n / sqrt(m) balances pointer movement.
    int block = (int)(n / sqrt((double)(m > 1 ? m : 1)));
    if (block < 1) block = 1;
    for (int i = 1; i <= n; i++)
        pos[i] = (i - 1) / block + 1;

    for (int i = 1; i <= m; i++)
    {
        int l, r;
        scanf("%d%d", &l, &r);
        query[i] = {l, r, i};
        // Ordered pairs of distinct socks in [l, r].
        p[i] = (ll)(r - l + 1) * (r - l);
    }
    sort(query + 1, query + m + 1);

    // Start from the empty window [1, 0]; grow before shrinking so counts
    // never go negative.
    int l = 1, r = 0;
    res = 0;
    for (int i = 1; i <= m; i++)
    {
        while (l > query[i].l) add(color[--l]);
        while (r < query[i].r) add(color[++r]);
        while (l < query[i].l) del(color[l++]);
        while (r > query[i].r) del(color[r--]);
        ans[query[i].id] = res;
    }

    for (int i = 1; i <= m; i++)
    {
        if (!p[i]) printf("0/1\n");//l = r
        else
        {
            ll d = gcd(ans[i], p[i]);
            printf("%lld/%lld\n", ans[i] / d, p[i] / d);
        }
    }

    return 0;
}


活动打卡代码 AcWing 250. 磁力块

Mintind
5天前
#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <cmath>
#include <algorithm>
using namespace std;

typedef long long ll;

const int N = 250005, T = 505;

// A magnet: d = squared distance to the start, r = squared radius, and
// m must be <= an attracting magnet's p for it to be picked up.
struct Node
{
    ll d, m, p, r;
}node[N];

// Orders magnets by squared distance.
bool cmp_d(Node a, Node b)
{
    return b.d > a.d;
}

// Orders magnets by m.
bool cmp_m(Node a, Node b)
{
    return b.m > a.m;
}

// Block i spans node[L[i]..R[i]] (L[i] advances as magnets are taken);
// D[i] = largest d in block i; vis[j] = magnet j already collected.
int L[T], R[T];
ll D[T];
int vis[N];

int main()
{
    int x0, y0, n;
    cin >> x0 >> y0 >> node[0].p >> node[0].r >> n;
    node[0].r *= node[0].r;

    for (int i = 1; i <= n; i++)
    {
        ll x, y, m, p, r;
        cin >> x >> y >> m >> p >> r;
        ll d = (x - x0) * (x - x0) + (y - y0) * (y - y0);
        r *= r;
        node[i] = {d, m, p, r};
    }

    sort(node + 1, node + n + 1, cmp_d);
    int t = 0, block = sqrt(n);
    for (int i = 1; i <= n; i += block)
    {
        L[++t] = i;
        R[t] = min(i + block - 1, n);
        D[t] = node[R[t]].d;
        sort(node + L[t], node + R[t] + 1, cmp_m);
    }

    queue<int> q;
    q.push(0);
    ll ans = 0;
    while (q.size())
    {
        int x = q.front();
        q.pop();

        ll p = node[x].p, r = node[x].r;

        for (int i = 1; i <= t; i++)
        {
            if (D[i] > r)
            {
                for (int j = L[i]; j <= R[i]; j++)
                {
                    if (!vis[j] && node[j].m <= p && node[j].d <= r)
                    {
                        vis[j] = 1;
                        q.push(j);
                        ans++;
                    }
                }
                break;
            }

            while (L[i] <= R[i] && node[L[i]].m <= p)
            {
                if (!vis[L[i]])
                {
                    vis[L[i]] = 1;
                    q.push(L[i]);
                    ans++;
                }
                L[i]++;
            }
        }
    }
    cout << ans;

    return 0;
}


活动打卡代码 AcWing 249. 蒲公英

Mintind
7天前
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <vector>
#include <algorithm>
using namespace std;

const int N = 40005, T = 1005;

// f[i][j]: mode (compressed value, smallest on ties) of blocks i..j.
int f[T][T];
// a[]: compressed values; nums[]: sorted distinct originals; pos[i]: block
// of position i; vis[]: scratch occurrence counters for init().
int a[N], nums[N], pos[N], vis[N];
// v[x]: increasing positions at which compressed value x occurs.
vector<int> v[N];

// n: array length, m: number of queries, t: block length.
int n, m, t;

// For each start block i < pos[n], scans from the start of block i to the
// end of the array and records the running mode at every block j.
void init()
{
    for (int i = 1; i < pos[n]; i++)
    {
        memset(vis, 0, sizeof vis);
        int best = 0, bestCnt = 0;
        int j = (i - 1) * t + 1;
        while (j <= n)
        {
            const int c = ++vis[a[j]];
            if (c > bestCnt || (c == bestCnt && a[j] < best))
            {
                bestCnt = c;
                best = a[j];
            }
            f[i][pos[j]] = best;
            ++j;
        }
    }
}

int count(int l, int r, int x)
{
    return upper_bound(v[x].begin(), v[x].end(), r) - lower_bound(v[x].begin(), v[x].end(), l);
}

int querry(int l, int r)
{
    int p = pos[l], q = pos[r];
    int x = 0, mx = 0;
    if (q - p <= 1)
    {
        for (int i = l; i <= r; i++)
        {
            int cnt = count(l, r, a[i]);
            if (cnt > mx || (cnt == mx && a[i] < x))
            {
                mx = cnt;
                x = a[i];
            }
        }
    }
    else
    {
        x = f[p + 1][q - 1];
        mx = count(l, r, x);
        for (int i = l; i <= p * t; i++)
        {
            int cnt = count(l, r, a[i]);
            if (cnt > mx || (cnt == mx && a[i] < x))
            {
                mx = cnt;
                x = a[i];
            }
        }
        for (int i = (q - 1) * t + 1; i <= r; i++)
        {
            int cnt = count(l, r, a[i]);
            if (cnt > mx || (cnt == mx && a[i] < x))
            {
                mx = cnt;
                x = a[i];
            }
        }
    }
    return nums[x];
}

int main()
{
    scanf("%d%d", &n, &m);
    // Block length ~ n / sqrt(n log n). For n == 1, log2(1) == 0 would divide
    // by zero and convert infinity to int (undefined behaviour), so the
    // radicand is clamped to at least 1 (the !(>=) form also catches NaN).
    double radicand = (double)n * log2((double)n);
    if (!(radicand >= 1)) radicand = 1;
    t = max(1, (int)(n / sqrt(radicand)));
    for (int i = 1; i <= n; i++)
    {
        scanf("%d", &a[i]);
        nums[i] = a[i];
        pos[i] = (i - 1) / t + 1;
    }
    // Coordinate compression; v[x] collects the positions of value x in order.
    sort(nums + 1, nums + n + 1);
    int num = unique(nums + 1, nums + n + 1) - nums - 1;
    for (int i = 1; i <= n; i++)
    {
        a[i] = lower_bound(nums + 1, nums + num + 1, a[i]) - nums;
        v[a[i]].push_back(i);
    }
    init();

    // Queries are forced online: decode them with the previous answer x.
    int l, r, x = 0;
    while (m--)
    {
        scanf("%d%d", &l, &r);
        l = (l + x - 1) % n + 1;
        r = (r + x - 1) % n + 1;
        if (l > r) swap(l, r);
        x = querry(l, r);
        printf("%d\n", x);
    }

    return 0;
}



Mintind
9天前

最大流模板 洛谷P3376 【模板】网络最大流


Edmonds-Karp增广路算法

#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <algorithm>
using namespace std;

typedef long long ll;

const int N = 205, M = 10005;
const ll INF = 1e12;

// Residual graph in linked-star form; arcs i and i ^ 1 form a forward/reverse
// pair (main starts tot at 1, so pairs are (2,3), (4,5), ...).
int head[N], nxt[M], ver[M];
ll edge[M], incf[N];
// vis[x][y]: id of the forward arc x -> y (0 if none); pre[y]: arc used to
// reach y in the last BFS; v[]: BFS visited flags.
int vis[N][N], pre[N], v[N];

int tot;
// Adds arc x -> y with capacity z followed by its zero-capacity reverse arc.
void add_edge(int x, int y, int z)
{
    const int from[2] = {x, y}, to[2] = {y, x};
    const ll cap[2] = {z, 0};
    for (int dir = 0; dir < 2; ++dir)
    {
        ++tot;
        ver[tot] = to[dir];
        edge[tot] = cap[dir];
        nxt[tot] = head[from[dir]];
        head[from[dir]] = tot;
    }
}

int s, t;
ll maxflow;

bool bfs()
{
    memset(v, 0, sizeof v);
    queue<int> q;

    q.push(s), v[s] = 1;
    incf[s] = INF;
    while (q.size())
    {
        int x = q.front();
        q.pop();
        for (int i = head[x]; i; i = nxt[i])
        {
            int y = ver[i];
            ll z = edge[i];
            if (v[y] || z == 0) continue;
            incf[y] = min(incf[x], z);
            pre[y] = i;
            if (y == t) return true;
            q.push(y), v[y] = 1;
        }
    }

    return false;
}

void update()
{
    int x = t;
    while (x != s)
    {
        int i = pre[x];
        edge[i] -= incf[t];
        edge[i ^ 1] += incf[t];
        x = ver[i ^ 1];
    }
    maxflow += incf[t];
}

int main()
{
    int n, m;
    scanf("%d%d%d%d", &n, &m, &s, &t);

    tot = 1;   // the first arc pair gets ids 2 and 3
    for (int x, y, z; m--; )
    {
        scanf("%d%d%d", &x, &y, &z);
        int &id = vis[x][y];
        if (id == 0)
        {
            add_edge(x, y, z);
            id = tot - 1;        // remember the forward arc x -> y
        }
        else edge[id] += z;      // parallel arc: merge capacities
    }

    maxflow = 0;
    while (bfs()) update();
    printf("%lld", maxflow);

    return 0;
}

Dinic算法
借助分层图,一次寻找多条增广路

#include <iostream>
#include <cstdio>
#include <cstring>
#include <queue>
#include <algorithm>
using namespace std;

typedef long long ll;

const int N = 205, M = 10005;
const ll INF = 1e12;

// Residual graph in linked-star form; arcs i and i ^ 1 form a forward/reverse
// pair. now[]: current-arc pointers; d[]: BFS levels; vis[x][y]: id of the
// forward arc x -> y (0 if none). w is not referenced by the visible code.
int head[N], now[N], nxt[M], ver[M];
int d[N], vis[N][N];
ll edge[M], w[N][N];

int tot;
// Adds arc x -> y with capacity z followed by its zero-capacity reverse arc.
void add_edge(int x, int y, int z)
{
    const int from[2] = {x, y}, to[2] = {y, x};
    const ll cap[2] = {z, 0};
    for (int dir = 0; dir < 2; ++dir)
    {
        ++tot;
        ver[tot] = to[dir];
        edge[tot] = cap[dir];
        nxt[tot] = head[from[dir]];
        head[from[dir]] = tot;
    }
}

int s, t;

// Builds BFS levels d[] (0 = unreached) from s over arcs with residual
// capacity and resets now[] for every labelled node. Stops as soon as t is
// labelled; returns whether t is reachable.
bool bfs()
{
    memset(d, 0, sizeof d);
    queue<int> q;
    d[s] = 1;
    now[s] = head[s];
    q.push(s);
    while (!q.empty())
    {
        const int x = q.front();
        q.pop();
        for (int e = head[x]; e; e = nxt[e])
        {
            const int y = ver[e];
            if (d[y] != 0 || edge[e] == 0) continue;
            d[y] = d[x] + 1;
            now[y] = head[y];
            if (y == t) return true;
            q.push(y);
        }
    }
    return false;
}

// Sends up to `flow` units from x toward t along level-increasing arcs and
// returns the amount actually sent. now[x] skips arcs already exhausted in
// this phase (current-arc optimisation); a child that can push nothing is
// cut off for the rest of the phase by clearing its level.
ll dinic(int x, ll flow)
{
    if (x == t) return flow;
    ll sent = 0;
    for (int e = now[x]; e != 0 && flow != 0; e = nxt[e])
    {
        now[x] = e;
        const int y = ver[e];
        if (edge[e] == 0 || d[y] != d[x] + 1) continue;
        const ll pushed = dinic(y, edge[e] < flow ? edge[e] : flow);
        if (pushed == 0) d[y] = 0;
        edge[e] -= pushed;
        edge[e ^ 1] += pushed;
        flow -= pushed;
        sent += pushed;
    }
    return sent;
}

int main()
{
    int n, m;
    scanf("%d%d%d%d", &n, &m, &s, &t);

    tot = 1;
    while (m--)
    {
        int x, y, z;
        scanf("%d%d%d", &x, &y, &z);
        if (!vis[x][y]) add_edge(x, y, z), vis[x][y] = tot - 1;
        else edge[vis[x][y]] += z;
    }

    ll maxflow = 0, flow;
    while (bfs())
        while (flow = dinic(s, INF))
            maxflow += flow;
    printf("%lld", maxflow);

    return 0;
}