头像

kuzi




离线:1天前


最近来访(48)
用户头像
Reinhart
用户头像
SnowMan
用户头像
寻觅梦
用户头像
Aigrl
用户头像
moreexcellent
用户头像
填海难....填心更难
用户头像
scboy
用户头像
那你呢
用户头像
rech
用户头像
美少女

活动打卡代码 AcWing 2521. 数颜色

kuzi
11天前
//2021/11/24
#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <cmath>

using namespace std;

// N bounds positions / queries / modifications; S bounds color values.
const int N = 10010, S = 1000010;

// n: sequence length, m: number of operations, mq / mc: number of queries /
// modifications read so far, len: block size used by get().
int n, m, mq, mc, len;
// w: current sequence (1-indexed), cnt[v]: occurrences of color v inside the
// current window, ans: answers indexed by the query's original id.
int w[N], cnt[S], ans[N];
// Range query [l, r]; t = number of modifications issued before it.
struct Query
{
    int id, l, r, t;
}q[N];
// Point modification: set position p to color c.
struct Modify
{
    int p, c;
}c[N];

// Index of the block that position x falls into (block size = global len).
int get(int x)
{
    const int block = x / len;
    return block;
}

bool cmp(const Query& a, const Query& b)
{
    int al = get(a.l), ar = get(a.r);
    int bl = get(b.l), br = get(b.r);
    if (al != bl) return al < bl;
    if (ar != br) return ar < br;
    return a.t < b.t;
}

// Puts one occurrence of color x into the window; res counts distinct colors.
void add(int x, int& res)
{
    if (cnt[x]++ == 0) ++res;
}

// Removes one occurrence of color x from the window; res counts distinct colors.
void del(int x, int& res)
{
    if (--cnt[x] == 0) --res;
}

// Mo's algorithm with modifications: reads the sequence and all operations,
// answers the queries offline, then prints the answers in input order.
// "Q l r" asks for the number of distinct colors in [l, r]; any other letter
// is treated as a modification "p c" (set w[p] = c).
int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; i ++ ) scanf("%d", &w[i]);
    for (int i = 0; i < m; i ++ )
    {
        char op[2];
        int a, b;
        scanf("%s%d%d", op, &a, &b);
        // A query records how many modifications precede it (its time stamp).
        if (*op == 'Q') mq ++, q[mq] = {mq, a, b, mc};
        else c[ ++ mc] = {a, b};
    }

    // Block size: cbrt(n * mc) + 1.
    len = cbrt((double)n * mc) + 1;
    sort(q + 1, q + mq + 1, cmp);

    // The window is [j, i] at time t; res = distinct colors inside it.
    for (int i = 0, j = 1, t = 0, k = 1, res = 0; k <= mq; k ++ )
    {
        int id = q[k].id, l = q[k].l, r = q[k].r, tm = q[k].t;
        while (i < r) add(w[ ++ i], res);
        while (i > r) del(w[i -- ], res);
        while (j < l) del(w[j ++ ], res);
        while (j > l) add(w[ -- j], res);
        // Move forward in time: apply modifications t+1 .. tm.
        while (t < tm)
        {
            t ++ ;
            if (c[t].p >= j && c[t].p <= i)
            {
                del(w[c[t].p], res);// a modification = remove the old value, then add the new one
                add(c[t].c, res);
            }
            // Swap instead of assign: c[t].c now holds the old value, so undoing
            // this modification later is exactly the same operation.
            swap(w[c[t].p], c[t].c);
        }
        // Move back in time: undo modifications tm+1 .. t with the same swap trick.
        while (t > tm)
        {
            if (c[t].p >= j && c[t].p <= i)
            {
                del(w[c[t].p], res);
                add(c[t].c, res);
            }
            swap(w[c[t].p], c[t].c);
            t -- ;
        }
        ans[id] = res;
    }

    for (int i = 1; i <= mq; i ++ ) printf("%d\n", ans[i]);
    return 0;
}



kuzi
11天前

线段树套线段树:求二维区间的最大值、最小值与和。注意每次 query 前都要重新初始化累加器 maxV、minV、sumV。

向下是x轴(n),向右是y轴(m)

#include <cstdio>
#include <cmath>
#include <cstring>
#include <iostream>
#include <algorithm>
#define ll long long
using namespace std;
const int INF = 0x3f3f3f3f;
const int N = 1024 + 5;
// Node tables of the 2D segment tree: first index = x-tree node, second index
// = y-tree node. minV / maxV / sumV are the accumulators filled by queryY().
// NOTE(review): every [N<<2][N<<2] long long table is 4116*4116*8 bytes
// (~135 MB). With four of them (MAX, MIN, a, SUM) this is ~540 MB, which may
// exceed typical memory limits. a[][] only needs [N][N].
ll MAX[N << 2][N << 2], minV, maxV,MIN[N<<2][N<<2];// maintain max / min
ll a[N<<2][N<<2];// initial matrix (leaf values)
ll SUM[N<<2][N<<2],sumV;// maintain sums
int n,m;
// Merges x-children (deep*2, deep*2+1) into x-node `deep`, at y-node `rt`.
void pushupX(int deep, int rt)
{
    const int lc = deep << 1;
    const int rc = lc | 1;
    MAX[deep][rt] = max(MAX[lc][rt], MAX[rc][rt]);
    MIN[deep][rt] = min(MIN[lc][rt], MIN[rc][rt]);
    SUM[deep][rt] = SUM[lc][rt] + SUM[rc][rt];
}
// Merges y-children (rt*2, rt*2+1) into y-node `rt`, inside x-node `deep`.
void pushupY(int deep, int rt)
{
    const int lc = rt << 1;
    const int rc = lc | 1;
    MAX[deep][rt] = max(MAX[deep][lc], MAX[deep][rc]);
    MIN[deep][rt] = min(MIN[deep][lc], MIN[deep][rc]);
    SUM[deep][rt] = SUM[deep][lc] + SUM[deep][rc];
}
// Builds the y-tree of x-node `deep` over columns [ly, ry] (y-node rt).
// flag != -1: the x-node is the leaf for row `flag`, so leaves copy a[flag][..].
// flag == -1: leaves are merged from the two x-children.
void buildY(int ly, int ry, int deep, int rt, int flag)
{
    if (ly != ry)
    {
        const int mid = (ly + ry) >> 1;
        buildY(ly, mid, deep, rt << 1, flag);
        buildY(mid + 1, ry, deep, rt << 1 | 1, flag);
        pushupY(deep, rt);
        return;
    }
    if (flag == -1)
    {
        pushupX(deep, rt);
        return;
    }
    MAX[deep][rt] = MIN[deep][rt] = SUM[deep][rt] = a[flag][ly];
}
// Builds the x-tree over rows [lx, rx] (x-node `deep`). Children are built
// first, so an inner node's y-tree can be merged from them.
void buildX(int lx, int rx, int deep)
{
    if (lx != rx)
    {
        const int mid = (lx + rx) >> 1;
        buildX(lx, mid, deep << 1);
        buildX(mid + 1, rx, deep << 1 | 1);
    }
    buildY(1, m, deep, 1, lx == rx ? lx : -1);
}
// Point update of column Y inside x-node `deep` (y-node rt spans [ly, ry]).
// flag != 0: this is the row's leaf x-node, so the value is assigned.
// flag == 0: the leaf is recomputed from the x-children.
void updateY(int Y, int val, int ly, int ry, int deep, int rt, int flag)
{
    if (ly != ry)
    {
        const int mid = (ly + ry) >> 1;
        if (Y > mid)
            updateY(Y, val, mid + 1, ry, deep, rt << 1 | 1, flag);
        else
            updateY(Y, val, ly, mid, deep, rt << 1, flag);
        pushupY(deep, rt);
        return;
    }
    // Check the problem statement: assignment vs. adding to the cell.
    if (!flag)
        pushupX(deep, rt);
    else
        MAX[deep][rt] = MIN[deep][rt] = SUM[deep][rt] = val;
}
// Point update of cell (X, Y) to val. It descends the x-tree to the leaf row,
// then repairs each x-node's y-tree on the way back up.
void updateX(int X, int Y, int val, int lx, int rx, int deep)
{
    if (lx != rx)
    {
        const int mid = (lx + rx) >> 1;
        if (X > mid)
            updateX(X, Y, val, mid + 1, rx, deep << 1 | 1);
        else
            updateX(X, Y, val, lx, mid, deep << 1);
    }
    updateY(Y, val, 1, m, deep, 1, lx == rx ? 1 : 0);
}
// Accumulates columns [Yl, Yr] of x-node `deep` (y-node rt spans [ly, ry])
// into the globals minV / maxV / sumV.
void queryY(int Yl, int Yr, int ly, int ry, int deep, int rt)
{
    const bool covered = Yl <= ly && ry <= Yr;
    if (covered)
    {
        minV = min(MIN[deep][rt], minV);
        maxV = max(MAX[deep][rt], maxV);
        sumV += SUM[deep][rt];
        return;
    }
    const int mid = (ly + ry) >> 1;
    if (mid >= Yl) queryY(Yl, Yr, ly, mid, deep, rt << 1);
    if (Yr > mid) queryY(Yl, Yr, mid + 1, ry, deep, rt << 1 | 1);
}
// Finds the x-nodes covering rows [Xl, Xr] (node rt spans [lx, rx]) and
// queries each one's y-tree for columns [Yl, Yr].
void queryX(int Xl, int Xr, int Yl, int Yr, int lx, int rx, int rt)
{
    const bool covered = Xl <= lx && rx <= Xr;
    if (covered)
    {
        queryY(Yl, Yr, 1, m, rt, 1);
        return;
    }
    const int mid = (lx + rx) >> 1;
    if (mid >= Xl) queryX(Xl, Xr, Yl, Yr, lx, mid, rt << 1);
    if (Xr > mid) queryX(Xl, Xr, Yl, Yr, mid + 1, rx, rt << 1 | 1);
}
ll w[N][N],sum[N][N];
int main()
{
    scanf("%d%d",&n,&m);
    int h1,w1,h2,w2;
    cin>>h1>>w1>>h2>>w2;
    for(int i=1;i<=n;i++)
        for(int j=1;j<=m;j++)
            {
                scanf("%lld",&w[i][j]);
                sum[i][j]=sum[i-1][j]+sum[i][j-1]-sum[i-1][j-1]+w[i][j];
            }
    for(int i=min(h2,h1);i<=n;i++)
        for(int j=min(w2,w1);j<=m;j++)
        {
            a[i][j]=sum[i][j]-sum[i-min(h2,h1)][j]-sum[i][j-min(w2,w1)]+sum[i-min(h2,h1)][j-min(w1,w2)];
        }
    buildX(1,n,1);
    ll ans=0;
    for(int i=h1;i<=n;i++)
        for(int j=w1;j<=m;j++)
        {
            maxV=0;
            int x=i-h1+min(h1,h2),y=j-w1+min(w1,w2);
            queryX(x,i,y,j,1,n,1);//填入x1,x2,y1,y2注意x1<=x2 y1<=y2
            ll now=sum[i][j]-sum[i-h1][j]-sum[i][j-w1]+sum[i-h1][j-w1];
            ans=max(ans,now-maxV);
        }
    cout<<ans;
    return 0;
}



kuzi
23天前

求解 n 个同余方程 $x \equiv m_i \pmod{a_i}$(输入每行依次给出 $a_i$ 和 $m_i$),其中模数 $a_i$ 不一定两两互质。

//2021/11/12
#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
int n;
// Extended Euclid: returns d = gcd(a, b) and sets x, y so that a*x + b*y = d.
ll exgcd(ll a, ll b, ll &x, ll &y){
    if (!b) {
        x = 1;
        y = 0;
        return a;
    }
    ll px, py;
    const ll d = exgcd(b, a % b, px, py);
    x = py;
    y = px - a / b * py;
    return d;
}
// Remainder of a modulo b, shifted into [0, b) when b > 0.
ll inline mod(ll a, ll b){
    ll r = a % b;
    r += b;
    return r % b;
}
int main(){
    while(~scanf("%lld",&n))
    {
        ll a1, m1;
        bool f=1;
        scanf("%lld%lld", &a1, &m1);
        for(int i = 1; i < n; i++){
            ll a2, m2, k1, k2;
            scanf("%lld%lld", &a2, &m2);
            ll d = exgcd(a1, -a2, k1, k2);
            if((m2 - m1) % d)
            {
                f=0;
            }
            k1 = mod(k1 * (m2 - m1) / d, abs(a2 / d));
            m1 = k1 * a1 + m1;
            a1 = abs(a1 / d * a2);
        }
        if(f)printf("%lld\n", m1);
        else puts("-1");
    }
    return 0;
}


活动打卡代码 AcWing 1298. 曹冲养猪

kuzi
23天前

中国剩余定理:题目保证模数 $m_i$ 两两互质。

截屏2021-11-12 下午3.11.51.png

求最小正整数解,(res%M+M)%M即可

//2021/11/12
#include <iostream>

using namespace std;

typedef long long ll;

const int N = 15;

int n;
// m[i]: i-th modulus, a[i]: i-th remainder (x ≡ a[i] mod m[i]).
int a[N], m[N];

// Extended Euclid without returning the gcd: afterwards a*x + b*y == gcd(a, b).
void exgcd(ll a, ll b, ll &x, ll &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return;
    }
    ll px, py;
    exgcd(b, a % b, px, py);
    x = py;
    y = px - a / b * py;
}
int main() {
    cin >> n;
    ll M = 1;
    for (int i = 0; i < n; ++ i) {
        cin >> m[i] >> a[i];
        M *= m[i];  //读入mi的同时计算M
    }
    ll res = 0;
    for (int i = 0; i < n; ++ i) {
        LL Mi = M / m[i];   //计算Mi = M/mi
        LL ti, y;
        //这一步是求逆元,根据逆元公式的衍生公式可以得到 ti * Mi + y * mi = 1
        exgcd(Mi, m[i], ti, y);
        res += a[i] * Mi * ti;  //计算的同时累加到res中(上述公式里有个sum需要累加)
    }
    cout << (res % M + M) % M << endl;  //对于任意x+kM都是满足要求的解,但目标是输出最小的正整数x,因此取模即可
    return 0;
}
#include <iostream>

using namespace std;

typedef long long ll;

const int N = 15;

int n;
// m[i]: i-th modulus, a[i]: i-th remainder (x ≡ a[i] mod m[i]).
int a[N], m[N];

// Extended Euclid without returning the gcd: afterwards a*x + b*y == gcd(a, b).
void exgcd(ll a, ll b, ll &x, ll &y) {
    if (b == 0) {
        x = 1;
        y = 0;
        return;
    }
    ll px, py;
    exgcd(b, a % b, px, py);
    x = py;
    y = px - a / b * py;
}
int main() {
    cin >> n;
    ll M = 1;
    for (int i = 0; i < n; ++ i) {
        cin >> m[i] >> a[i];
        M *= m[i];  
    }
    ll res = 0;
    for (int i = 0; i < n; ++ i) {
        ll Mi = M / m[i]; 
        ll ti, y;
        exgcd(Mi, m[i], ti, y);
        res += a[i] * Mi * ti;
    }
    cout << (res % M + M) % M << endl; 
    return 0;
}

模数不一定互质版本,可能会有无解的情况

#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
int n;
// Extended Euclid: returns d = gcd(a, b) and sets x, y so that a*x + b*y = d.
ll exgcd(ll a, ll b, ll &x, ll &y){
    if (!b) {
        x = 1;
        y = 0;
        return a;
    }
    ll px, py;
    const ll d = exgcd(b, a % b, px, py);
    x = py;
    y = px - a / b * py;
    return d;
}
// Remainder of a modulo b, shifted into [0, b) when b > 0.
ll inline mod(ll a, ll b){
    ll r = a % b;
    r += b;
    return r % b;
}
int main(){
    while(~scanf("%lld",&n))
    {
        ll a1, m1;
        bool f=1;
        scanf("%lld%lld", &a1, &m1);
        for(int i = 1; i < n; i++){
            ll a2, m2, k1, k2;
            scanf("%lld%lld", &a2, &m2);
            ll d = exgcd(a1, -a2, k1, k2);
            if((m2 - m1) % d)
            {
                f=0;
            }
            k1 = mod(k1 * (m2 - m1) / d, abs(a2 / d));
            m1 = k1 * a1 + m1;
            a1 = abs(a1 / d * a2);
        }
        if(f)printf("%lld\n", m1);
        else puts("-1");
    }
    return 0;
}


活动打卡代码 AcWing 3133. 串珠子

kuzi
1个月前

一个置换:旋转,对称等

polya:统计每个置换有多少个循环。

//10/26

//这里填你的代码^^
//注意代码要放在两组三个点之间,才可以正确显示代码高亮哦~

截屏2021-10-26 下午1.21.42.png



活动打卡代码 AcWing 3028. 最小圆覆盖

kuzi
1个月前
//2021/10/22
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>

// Shorthands so a PDD point can be written as p.x / p.y.
#define x first
#define y second

using namespace std;

// 2D point / vector.
typedef pair<double, double> PDD;
const int N = 100010;
// Tolerance used by sign() and dcmp().
const double eps = 1e-12;
const double PI = acos(-1);

int n;
// Input points.
PDD q[N];
// A circle: center p, radius r.
struct Circle
{
    PDD p;
    double r;
};

// Sign of x with tolerance eps: 0 when |x| < eps, otherwise -1 or 1.
int sign(double x)
{
    if (fabs(x) < eps) return 0;
    return x < 0 ? -1 : 1;
}

// Three-way comparison of x and y with tolerance eps: 0 when they are within
// eps of each other, otherwise -1 (x < y) or 1.
int dcmp(double x, double y)
{
    const double diff = x - y;
    if (fabs(diff) < eps) return 0;
    return x < y ? -1 : 1;
}

// Vector difference a - b.
PDD operator- (PDD a, PDD b)
{
    const double dx = a.x - b.x;
    const double dy = a.y - b.y;
    return PDD(dx, dy);
}

// Vector sum a + b.
PDD operator+ (PDD a, PDD b)
{
    const double sx = a.x + b.x;
    const double sy = a.y + b.y;
    return PDD(sx, sy);
}

// Scales vector a by t.
PDD operator* (PDD a, double t)
{
    return PDD(a.x * t, a.y * t);
}

// Divides vector a by t.
PDD operator/ (PDD a, double t)
{
    return PDD(a.x / t, a.y / t);
}

// 2D cross product a × b (the z component).
double operator* (PDD a, PDD b)
{
    const double lhs = a.x * b.y;
    const double rhs = a.y * b.x;
    return lhs - rhs;
}

// Rotates vector a by angle -b (clockwise by b radians).
PDD rotate(PDD a, double b)
{
    const double cs = cos(b), sn = sin(b);
    return PDD(a.x * cs + a.y * sn, -a.x * sn + a.y * cs);
}

// Euclidean distance between points a and b.
double get_dist(PDD a, PDD b)
{
    const double dx = a.x - b.x, dy = a.y - b.y;
    const double squared = dx * dx + dy * dy;
    return sqrt(squared);
}

// Intersection of the lines p + s*v and q + s*w. Expects v × w != 0
// (non-parallel lines); otherwise the division is by zero.
PDD get_line_intersection(PDD p, PDD v, PDD q, PDD w)
{
    const double t = w * (p - q) / (v * w);
    return p + v * t;
}

// Perpendicular bisector of segment ab, as (midpoint, direction vector).
pair<PDD, PDD> get_line(PDD a, PDD b)
{
    const PDD mid = (a + b) / 2;
    const PDD dir = rotate(b - a, PI / 2);
    return make_pair(mid, dir);
}

// Circumcircle of triangle abc: the center is where the perpendicular
// bisectors of ab and ac meet.
Circle get_circle(PDD a, PDD b, PDD c)
{
    const pair<PDD, PDD> bisAB = get_line(a, b);
    const pair<PDD, PDD> bisAC = get_line(a, c);
    const PDD center = get_line_intersection(bisAB.first, bisAB.second, bisAC.first, bisAC.second);
    return {center, get_dist(center, a)};
}

int main()
{
    scanf("%d", &n);
    for (int i = 0; i < n; i ++ ) scanf("%lf%lf", &q[i].x, &q[i].y);
    random_shuffle(q, q + n);

    Circle c({q[0], 0});
    for (int i = 1; i < n; i ++ )
        if (dcmp(c.r, get_dist(c.p, q[i])) < 0)
        {
            c = {q[i], 0};
            for (int j = 0; j < i; j ++ )
                if (dcmp(c.r, get_dist(c.p, q[j])) < 0)
                {
                    c = {(q[i] + q[j]) / 2, get_dist(q[i], q[j]) / 2};
                    for (int k = 0; k < j; k ++ )
                        if (dcmp(c.r, get_dist(c.p, q[k])) < 0)
                            c = get_circle(q[i], q[j], q[k]);
                }
        }

    printf("%.10lf\n", c.r);
    printf("%.10lf %.10lf\n", c.p.x, c.p.y);
    return 0;
}


活动打卡代码 AcWing 3132. 食物

kuzi
1个月前

截屏2021-10-15 下午6.14.47.png
截屏2021-10-17 下午7.29.43.png
截屏2021-10-18 下午7.13.14.png

指数型生成函数(多重集的排列问题)中,方案数是 $\frac{x^n}{n!}$ 前的系数:每一项的分母必须是对应的阶乘,这样得到的系数就是答案。

把 $x$ 看作 $(-1,1)$ 之间的数,于是 $1+x+x^2+\cdots+x^n=\frac{1-x^{n+1}}{1-x}$;当 $n\to\infty$ 时 $x^{n+1}\to 0$,所以 $1+x+x^2+\cdots=\frac{1}{1-x}$。

//2021/10/15
#include<bits/stdc++.h>
using namespace std;
const int N=510,P=10007;
char s[N];
typedef long long ll;
// Reads a (very long) decimal n and prints n(n+1)(n+2)/6 mod P = C(n+2, 3) mod P.
// Since P is prime and 6 is invertible mod P, dividing the product of the
// residues by 6 gives the same value modulo P.
int main(){
    cin>>s;
    ll n=0;
    // Horner's rule: reduce the decimal string modulo P digit by digit.
    for (const char* p = s; *p; ++p)
        n = (n * 10 + *p - '0') % P;
    const ll prod = n * (n + 1) * (n + 2);
    cout << prod / 6 % P << endl;
}

截屏2021-10-15 下午5.30.39.png



活动打卡代码 AcWing 3123. 高精度乘法II

kuzi
1个月前
//2021/10/14
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

const int N = 300010;
const double PI = acos(-1);

int n, m;
// Minimal complex number for the FFT: x = real part, y = imaginary part.
struct Complex
{
    double x, y;
    Complex operator+ (const Complex& o) const
    {
        return Complex{x + o.x, y + o.y};
    }
    Complex operator- (const Complex& o) const
    {
        return Complex{x - o.x, y - o.y};
    }
    Complex operator* (const Complex& o) const
    {
        const double re = x * o.x - y * o.y;
        const double im = x * o.y + y * o.x;
        return Complex{re, im};
    }
}a[N], b[N];
// rev: bit-reversal permutation, tot = 1 << bit = transform length.
int rev[N], bit, tot;

// In-place iterative FFT of a[0..tot), using the global bit-reversal table rev[].
// inv = 1: forward transform. inv = -1: inverse transform WITHOUT the 1/tot
// scaling (the caller divides by tot afterwards).
void fft(Complex a[], int inv)
{
    // Reorder into bit-reversed positions so the butterflies can run in place.
    for (int i = 0; i < tot; i ++ )
        if (i < rev[i])
            swap(a[i], a[rev[i]]);
    for (int mid = 1; mid < tot; mid <<= 1)
    {
        // (2*mid)-th root of unity; the sign of the imaginary part picks the direction.
        auto w1 = Complex({cos(PI / mid), inv * sin(PI / mid)});
        for (int i = 0; i < tot; i += mid * 2)
        {
            auto wk = Complex({1, 0});
            for (int j = 0; j < mid; j ++, wk = wk * w1)
            {
                // Butterfly: combine the two half-length transforms.
                auto x = a[i + j], y = wk * a[i + j + mid];
                a[i + j] = x + y, a[i + j + mid] = x - y;
            }
        }
    }
}
char s1[N],s2[N];
int res[N];
// Multiplies two non-negative decimal integers with FFT and prints the product.
int main()
{
    scanf("%s%s",s1,s2);
    n=strlen(s1)-1,m=strlen(s2)-1;  // highest digit index (= polynomial degree)
    // Store digits little-endian: coefficient i multiplies 10^i.
    for (int i = 0; i <= n; i ++ ) a[i].x=s1[n-i]-'0';
    for (int i = 0; i <= m; i ++ ) b[i].x=s2[m-i]-'0';
    while ((1 << bit) < n + m + 1) bit ++;
    tot = 1 << bit;
    // rev[0] is already 0. Starting at 1 avoids the shift by bit - 1 == -1
    // (undefined behaviour) when both numbers have one digit (tot == 1).
    for (int i = 1; i < tot; i ++ )
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    fft(a, 1), fft(b, 1);
    for (int i = 0; i < tot; i ++ ) a[i] = a[i] * b[i];
    fft(a, -1);
    // Scale by 1/tot, round each coefficient, and propagate the carries.
    int t=0,k=0;
    for (int i = 0; i <=n+m||t; i ++ )
    {
        t+=(a[i].x / tot + 0.5);
        res[k++]=t%10;
        t/=10;
    }
    while(k>1&&!res[k-1])k--;  // strip leading zeros, but keep at least one digit
    for(int i=k-1;i>=0;i--)printf("%d",res[i]);
    return 0;
}


活动打卡代码 AcWing 3122. 多项式乘法

kuzi
1个月前

FFT

nlogn 求两个多项式乘(卷积)

数组开4倍

使用分治时,将其他无关的位置置为0,记得循环到limit

NTT 的模数 p 一般取 998244353、1004535809、469762049(它们的原根都是 3)。

//2021/10/14
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>

using namespace std;

const int N = 300010;
const double PI = acos(-1);

int n, m;
// Minimal complex number for the FFT: x = real part, y = imaginary part.
struct Complex
{
    double x, y;
    Complex operator+ (const Complex& o) const
    {
        return Complex{x + o.x, y + o.y};
    }
    Complex operator- (const Complex& o) const
    {
        return Complex{x - o.x, y - o.y};
    }
    Complex operator* (const Complex& o) const
    {
        const double re = x * o.x - y * o.y;
        const double im = x * o.y + y * o.x;
        return Complex{re, im};
    }
}a[N], b[N];
// rev: bit-reversal permutation, tot = 1 << bit = transform length.
int rev[N], bit, tot;

// In-place iterative FFT of a[0..tot), using the global bit-reversal table rev[].
// inv = 1: forward transform. inv = -1: inverse transform WITHOUT the 1/tot
// scaling (the caller divides by tot afterwards).
void fft(Complex a[], int inv)
{
    // Reorder into bit-reversed positions so the butterflies can run in place.
    for (int i = 0; i < tot; i ++ )
        if (i < rev[i])
            swap(a[i], a[rev[i]]);
    for (int mid = 1; mid < tot; mid <<= 1)
    {
        // (2*mid)-th root of unity; the sign of the imaginary part picks the direction.
        auto w1 = Complex({cos(PI / mid), inv * sin(PI / mid)});
        for (int i = 0; i < tot; i += mid * 2)
        {
            auto wk = Complex({1, 0});
            for (int j = 0; j < mid; j ++, wk = wk * w1)
            {
                // Butterfly: combine the two half-length transforms.
                auto x = a[i + j], y = wk * a[i + j + mid];
                a[i + j] = x + y, a[i + j + mid] = x - y;
            }
        }
    }
}

// Reads polynomials A (degree n) and B (degree m) with integer coefficients
// and prints the n + m + 1 coefficients of A * B, computed with FFT.
int main()
{
    scanf("%d%d", &n, &m);
    for (int i = 0; i <= n; i ++ ) scanf("%lf", &a[i].x);
    for (int i = 0; i <= m; i ++ ) scanf("%lf", &b[i].x);
    while ((1 << bit) < n + m + 1) bit ++;
    tot = 1 << bit;
    // rev[0] is already 0. Starting at 1 avoids the shift by bit - 1 == -1
    // (undefined behaviour) when n + m == 0 (tot == 1).
    for (int i = 1; i < tot; i ++ )
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (bit - 1));
    fft(a, 1), fft(b, 1);
    for (int i = 0; i < tot; i ++ ) a[i] = a[i] * b[i];
    fft(a, -1);
    // The inverse FFT is unscaled: divide by tot, then round to the nearest
    // integer. floor(v + 0.5) also rounds negative coefficients correctly,
    // unlike truncation with a plain (int) cast.
    for (int i = 0; i <= n + m; i ++ )
        printf("%d ", (int)floor(a[i].x / tot + 0.5));

    return 0;
}

NTT:FFT 的升级版(数论变换,在模 P 意义下计算,没有浮点误差)。

#include<bits/stdc++.h>
using namespace std;
// Buffered getchar: refills a 1<<21-byte buffer from stdin with fread and
// yields EOF at the end of the input.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1<<21, stdin), p1 == p2) ? EOF : *p1++)
// XOR swap. NOTE: it zeroes both operands if x and y are the same object.
#define swap(x, y) x ^= y, y ^= x, x ^= y
#define LL long long
// P = 998244353 (NTT modulus), G = 3 (root used by NTT), Gi = 332748118 = 3^{-1} mod P (3 * Gi = P + 1).
const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118;
char buf[1 << 21], *p1 = buf, *p2 = buf;

// Parses one signed decimal integer from the input. Any non-digit characters
// before it are skipped, and a '-' among them makes the result negative.
inline int read()
{
    char ch = getchar();
    int val = 0, sgn = 1;
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sgn = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        val = val * 10 + ch - '0';
    return val * sgn;
}

int N, M, limit = 1, L, r[MAXN];
LL a[MAXN], b[MAXN];

// Binary exponentiation: returns a^k mod P.
inline LL fastpow(LL a, LL k)
{
    LL acc = 1;
    for (; k; k >>= 1, a = a * a % P)
        if (k & 1) acc = acc * a % P;
    return acc % P;
}

// In-place iterative NTT of A[0..limit) modulo P, using the bit-reversal table r[].
// type == 1: forward transform (root G). Otherwise: inverse transform (root Gi),
// WITHOUT the 1/limit scaling (the caller multiplies by limit^{-1}).
inline void NTT(LL *A, int type)
{
    // Reorder into bit-reversed positions. swap expands to the XOR-swap macro
    // at the top of this snippet; i < r[i] keeps the two operands distinct.
    for (int i = 0; i < limit; i++)
        if (i < r[i]) swap(A[i], A[r[i]]);
    for (int mid = 1; mid < limit; mid <<= 1)
    {
        // G^((P-1)/(2*mid)): a (2*mid)-th root of unity modulo P.
        LL Wn = fastpow(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int j = 0; j < limit; j += (mid << 1))
        {
            LL w = 1;
            for (int k = 0; k < mid; k++, w = (w * Wn) % P)
            {
                // Butterfly. Both values are < P, so they fit in int and x + y < 2^31.
                int x = A[j + k], y = w * A[j + k + mid] % P;
                A[j + k] = (x + y) % P,
                        A[j + k + mid] = (x - y + P) % P;
            }
        }
    }
}

int main()
{
    N = read();
    M = read();
    for (int i = 0; i <= N; i++) a[i] = (read() + P) % P;
    for (int i = 0; i <= M; i++) b[i] = (read() + P) % P;
    while (limit <= N + M) limit <<= 1, L++;
    for (int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    NTT(a, 1);
    NTT(b, 1);
    for (int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;
    NTT(a, -1);
    LL inv = fastpow(limit, P - 2);
    for (int i = 0; i <= N + M; i++)
        printf("%d ", (a[i] * inv) % P);
    return 0;
}

分治ntt,求解这种当前f需要用前面的f来计算的问题。

分治求出l-mid的f,就可以算出左区间对右区间的贡献。
截屏2021-10-15 下午3.20.58.png
截屏2021-10-15 下午3.22.21.png

#include<bits/stdc++.h>
using namespace std;
// Buffered getchar: refills a 1<<21-byte buffer from stdin with fread and
// yields EOF at the end of the input.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1<<21, stdin), p1 == p2) ? EOF : *p1++)
// XOR swap. NOTE: it zeroes both operands if x and y are the same object.
#define swap(x, y) x ^= y, y ^= x, x ^= y
#define ll long long
// P = 998244353 (NTT modulus), G = 3 (root used by NTT), Gi = 332748118 = 3^{-1} mod P (3 * Gi = P + 1).
const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118;
char buf[1 << 21], *p1 = buf, *p2 = buf;

// Parses one signed decimal integer from the input. Any non-digit characters
// before it are skipped, and a '-' among them makes the result negative.
inline int read()
{
    char ch = getchar();
    int val = 0, sgn = 1;
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sgn = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        val = val * 10 + ch - '0';
    return val * sgn;
}

int N, M, limit = 1, L, r[MAXN];
ll a[MAXN], b[MAXN];

// Binary exponentiation: returns a^k mod P.
inline ll fastpow(ll a, ll k)
{
    ll acc = 1;
    for (; k; k >>= 1, a = a * a % P)
        if (k & 1) acc = acc * a % P;
    return acc % P;
}

// Chooses the transform size for a product whose highest exponent is len:
// limit = smallest power of two greater than len, L = log2(limit). Then it
// rebuilds the bit-reversal table r[0..limit).
void ini(int len)
{
    limit = 1;
    for (L = 0; limit <= len; ++L) limit <<= 1;
    for (int i = 0; i < limit; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
}

// In-place iterative NTT of A[0..limit) modulo P, using the bit-reversal table r[].
// type == 1: forward transform (root G). Otherwise: inverse transform (root Gi),
// WITHOUT the 1/limit scaling (the caller multiplies by limit^{-1}).
inline void NTT(ll *A, int type) 
{
    // Reorder into bit-reversed positions. swap expands to the XOR-swap macro
    // at the top of this snippet; i < r[i] keeps the two operands distinct.
    for (int i = 0; i < limit; i++)
        if (i < r[i]) swap(A[i], A[r[i]]);
    for (int mid = 1; mid < limit; mid <<= 1) 
    {
        // G^((P-1)/(2*mid)): a (2*mid)-th root of unity modulo P.
        ll Wn = fastpow(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int j = 0; j < limit; j += (mid << 1)) 
        {
            ll w = 1;
            for (int k = 0; k < mid; k++, w = (w * Wn) % P) 
            {
                // Butterfly. Both values are < P, so they fit in int and x + y < 2^31.
                int x = A[j + k], y = w * A[j + k + mid] % P;
                A[j + k] = (x + y) % P,
                        A[j + k + mid] = (x - y + P) % P;
            }
        }
    }
}

ll g[MAXN], f[MAXN], ans[MAXN];

// Divide-and-conquer (CDQ) NTT on [l, r]: after the call,
// ans[i] = sum_{j < i} ans[j] * a[i - j] holds for i in [l, r], given that
// ans[l] is final on entry (the caller seeds ans[0]). The left half is
// finished first, its contribution to the right half is added with one NTT,
// and then the right half is solved.
void solve(int l, int r) 
{
    if (l >= r)return;
    int mid = l + r >> 1;
    solve(l, mid);
    int len = r - l;
    for (int i = 0; i <= len; i++)g[i] = a[i];
    for (int i = l; i <= mid; i++)f[i - l] = ans[i];
    for (int i = mid + 1; i <= r; i++)f[i - l] = 0;
    ini(len);// len is the highest exponent
    // f and g are shared buffers that can still hold transformed data from an
    // earlier call. Clear the padding up to limit so the cyclic convolution
    // only sees the intended coefficients.
    for (int i = len + 1; i < limit; i++) f[i] = g[i] = 0;
    NTT(f, 1);
    NTT(g, 1);
    for (int i = 0; i < limit; i++) f[i] = (f[i] * g[i]) % P;
    NTT(f, -1);
    ll inv = fastpow(limit, P - 2);  // the inverse NTT is unscaled
    for (int i = mid + 1; i <= r; i++)ans[i] = (ans[i] + f[i - l] * inv % P) % P;
    solve(mid + 1, r);
}

// Reads N and a[1..N-1], seeds ans[0] = 1, runs the divide-and-conquer NTT over
// [0, limit] (limit = smallest power of two greater than N), and prints ans[0..N-1].
int main() 
{
    N = read();
    for (int i = 1; i < N; i++) a[i] = (read() + P) % P;
    for (; limit <= N; L++) limit <<= 1;
    ans[0] = 1;
    solve(0, limit);
    for (int i = 0; i < N; i++) printf("%lld ", ans[i]);
    return 0;
}

N 个多项式相乘:分治 NTT。

#include<bits/stdc++.h>
using namespace std;
// Buffered getchar: refills a 1<<21-byte buffer from stdin with fread and
// yields EOF at the end of the input.
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1<<21, stdin), p1 == p2) ? EOF : *p1++)
// XOR swap. NOTE: it zeroes both operands if x and y are the same object.
#define swap(x, y) x ^= y, y ^= x, x ^= y
#define ll long long
// P = 998244353 (NTT modulus), G = 3 (root used by NTT), Gi = 332748118 = 3^{-1} mod P (3 * Gi = P + 1).
const int MAXN = 3 * 1e6 + 10, P = 998244353, G = 3, Gi = 332748118;
char buf[1 << 21], *p1 = buf, *p2 = buf;

// Parses one signed decimal integer from the input. Any non-digit characters
// before it are skipped, and a '-' among them makes the result negative.
inline int read()
{
    char ch = getchar();
    int val = 0, sgn = 1;
    for (; ch < '0' || ch > '9'; ch = getchar())
        if (ch == '-') sgn = -1;
    for (; ch >= '0' && ch <= '9'; ch = getchar())
        val = val * 10 + ch - '0';
    return val * sgn;
}

int N, M, limit = 1, L, r[MAXN];
ll a[MAXN], b[MAXN];

// Binary exponentiation: returns a^k mod P.
inline ll fastpow(ll a, ll k)
{
    ll acc = 1;
    for (; k; k >>= 1, a = a * a % P)
        if (k & 1) acc = acc * a % P;
    return acc % P;
}

// Chooses the transform size for a product whose highest exponent is len:
// limit = smallest power of two greater than len, L = log2(limit). Then it
// rebuilds the bit-reversal table r[0..limit).
void ini(int len)
{
    limit = 1;
    for (L = 0; limit <= len; ++L) limit <<= 1;
    for (int i = 0; i < limit; ++i)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
}

// In-place iterative NTT of A[0..limit) modulo P, using the bit-reversal table r[].
// type == 1: forward transform (root G). Otherwise: inverse transform (root Gi),
// WITHOUT the 1/limit scaling (the caller multiplies by limit^{-1}).
inline void NTT(ll *A, int type) 
{
    // Reorder into bit-reversed positions. swap expands to the XOR-swap macro
    // at the top of this snippet; i < r[i] keeps the two operands distinct.
    for (int i = 0; i < limit; i++)
        if (i < r[i]) swap(A[i], A[r[i]]);
    for (int mid = 1; mid < limit; mid <<= 1) 
    {
        // G^((P-1)/(2*mid)): a (2*mid)-th root of unity modulo P.
        ll Wn = fastpow(type == 1 ? G : Gi, (P - 1) / (mid << 1));
        for (int j = 0; j < limit; j += (mid << 1)) 
        {
            ll w = 1;
            for (int k = 0; k < mid; k++, w = (w * Wn) % P) 
            {
                // Butterfly. Both values are < P, so they fit in int and x + y < 2^31.
                int x = A[j + k], y = w * A[j + k + mid] % P;
                A[j + k] = (x + y) % P,
                        A[j + k + mid] = (x - y + P) % P;
            }
        }
    }
}

// Scratch buffers for the NTT multiplication in solve().
ll w[MAXN],g[MAXN];

// Writes the coefficients of prod_{i=l..r} (1 + a[i] * x) into f[0..r-l+1],
// using divide and conquer: the two half-products are multiplied with one NTT.
// f must have at least r - l + 2 entries.
void solve(int l, int r, vector<ll> &f) 
{
    if (l == r)
    {
        // Single factor 1 + a[l] * x.
        f[0]=1;
        f[1]=a[l];
        return;
    }
    int mid = l + r >> 1;
    // Degrees of the left and right half-products.
    int len1=mid-l+1,len2=r-mid;
    vector<ll>f1(len1+5),f2(len2+5);
    solve(l,mid,f1);
    solve(mid+1,r,f2);
    // w and g are shared globals, but both recursive calls have returned, so
    // they can be reused here.
    for(int i=0;i<=len1;i++)w[i]=f1[i];
    for(int i=0;i<=len2;i++)g[i]=f2[i];
    int len=len1+len2;
    ini(len);
    // Zero-pad both operands up to the transform length.
    for(int i=len1+1;i<limit;i++)w[i]=0;
    for(int i=len2+1;i<limit;i++)g[i]=0;
    NTT(w, 1);
    NTT(g, 1);
    for (int i = 0; i < limit; i++) w[i] = (w[i] * g[i]) % P;
    NTT(w, -1);
    // The inverse NTT is unscaled: multiply by limit^{-1}.
    ll inv = fastpow(limit, P - 2);
    for (int i=0;i<=len;i++)f[i] = w[i] * inv % P;
}
ll in[MAXN];
int main() 
{
    int t;
    t = read();
    in[0]=1;
    for(ll i=1;i<MAXN;i++)in[i]=in[i-1]*i%P;
    while(t--)
    {
        N = read();
        for (int i = 1; i <= N ; i++) a[i] = (read() + P) % P;
        vector<ll>ans(N+5);
        solve(1,N,ans);
        ll res=0;
        for(int i=1;i<=N;i++)
        {
            //cout<<ans[i]<<endl;
            res=(res+ans[i]*in[i]%P*in[N-i]%P)%P;
        }
        printf("%lld\n",res*fastpow(in[N],P-2)%P);
    }
    return 0;
}



活动打卡代码 AcWing 3125. 扩展BSGS

kuzi
1个月前

$a^x \equiv b \pmod{p}$
a 和 p 不互质时也可以求解。

//2021/10/13
#include <iostream>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <unordered_map>

using namespace std;

typedef long long LL;
const int INF = 1e8;

// Extended Euclid: returns d = gcd(a, b) and sets x, y so that a*x + b*y = d.
int exgcd(int a, int b, int& x, int& y)
{
    if (b == 0)
    {
        x = 1;
        y = 0;
        return a;
    }
    int px, py;
    const int d = exgcd(b, a % b, px, py);
    x = py;
    y = px - a / b * py;
    return d;
}

int bsgs(int a, int b, int p)
{
    if (1 % p == b % p) return 0;
    int k = sqrt(p) + 1;
    unordered_map<int, int> hash;
    for (int i = 0, j = b % p; i < k; i ++ )
    {
        hash[j] = i;
        j = (LL)j * a % p;
    }
    int ak = 1;
    for (int i = 0; i < k; i ++ ) ak = (LL)ak * a % p;
    for (int i = 1, j = ak; i <= k; i ++ )
    {
        if (hash.count(j)) return i * k - hash[j];
        j = (LL)j * ak % p;
    }
    return -INF;
}

// Extended BSGS: solves a^x ≡ b (mod p) even when gcd(a, p) > 1. It divides
// out common factors one at a time, adding 1 to the exponent for each step.
// Returns a negative value if there is no solution.
int exbsgs(int a, int b, int p)
{
    b = (b % p + p) % p;
    if (1 % p == b % p) return 0;
    int x, y;
    const int d = exgcd(a, p, x, y);
    if (d <= 1) return bsgs(a, b, p);
    if (b % d) return -INF;
    // x becomes the inverse of a/d modulo p/d.
    exgcd(a / d, p / d, x, y);
    const int np = p / d;
    return exbsgs(a, (LL)b / d * x % np, np) + 1;
}

// Reads "a p b" triples until "0 0 0" and prints the smallest x with
// a^x ≡ b (mod p), or "No Solution".
int main()
{
    int a, p, b;
    while (cin >> a >> p >> b, a || p || b)
    {
        const int res = exbsgs(a, b, p);
        if (res >= 0) cout << res << endl;
        else puts("No Solution");
    }
    return 0;
}
// Extended Euclid on 64-bit values: returns d = gcd(a, b) and sets x, y so
// that a*x + b*y = d.
// NOTE(review): `ll` is not declared in this snippet. It is presumably a
// long long alias from the author's template; confirm.
ll exgcd(ll a,ll b,ll &x,ll &y)
{
    if (b == 0)
    {
        x = 1;
        y = 0;
        return a;
    }
    ll px, py;
    const ll d = exgcd(b, a % b, px, py);
    x = py;
    y = px - a / b * py;
    return d;
}
// Baby-step giant-step for a^x ≡ b (mod p) on 64-bit values, intended for
// gcd(a, p) == 1. Returns a solution x >= 0, or -inf if none is found.
// NOTE(review): `inf` and `ll` are not declared in this snippet; confirm
// where they come from.
ll bsgs(ll a, ll b, ll p)
{
    if (1 % p == b % p) return 0;
    const ll k = sqrt(p) + 1;
    // Baby steps: b * a^i -> i (a later, larger i overwrites an earlier one).
    unordered_map<ll, ll> baby;
    ll cur = b % p;
    for (ll i = 0; i < k; i ++ )
    {
        baby[cur] = i;
        cur = cur * a % p;
    }
    ll ak = 1;
    for (ll i = 0; i < k; i ++ ) ak = ak * a % p;
    // Giant steps: a^(i*k) for i = 1..k.
    ll giant = ak;
    for (ll i = 1; i <= k; i ++ )
    {
        const auto it = baby.find(giant);
        if (it != baby.end()) return i * k - it->second;
        giant = giant * ak % p;
    }
    return -inf;
}

// Extended BSGS on 64-bit values: solves a^x ≡ b (mod p) even when
// gcd(a, p) > 1. Returns a negative value (-inf plus the steps taken) if
// there is no solution.
// NOTE(review): relies on `ll` and `inf`, which are not declared in this snippet.
ll exbsgs(ll a, ll b, ll p)
{
    b = (b % p + p) % p;
    if (1 % p == b % p) return 0;
    ll x, y;
    const ll d = exgcd(a, p, x, y);
    if (d <= 1) return bsgs(a, b, p);
    if (b % d) return -inf;
    // x becomes the inverse of a/d modulo p/d.
    exgcd(a / d, p / d, x, y);
    const ll np = p / d;
    return exbsgs(a, (ll)b / d * x % np, np) + 1;
}