本机与luogu用时差距相反 ?
查看原帖
本机与luogu用时差距相反 ?
486056
xuzishuai楼主2024/12/29 10:33

我的迭代写法本机将luogu大样例下下来跑了 2s , luogu 上 TLE 。 递归写法同一组数据本机跑了 3s , luogu上 ac

且luogu上迭代写法用时是递归的两倍

迭代写法 code :

#include<iostream>
#include<algorithm>
#include<vector>
#include<time.h>
#include<assert.h>
#define fq(i,d,u) for(int i(d); i<=u; ++i)
#define fr(i,u,d) for(int i(u); i>=d; --i)
using namespace std;

int n,m;

class poly {
    protected:
    typedef const int ci;
    ci P = 998244353,g = 3;
    int qpow(int x,int k) {
        int ans(1);
        while(k) {
            if(k & 1) ans = 1ll * ans * x % P;
            x = 1ll * x * x % P; k >>= 1;
        } return ans;
    }
    std::vector<int> a;

    public :
    poly() {a = {1};} poly(ci x) { a = {x}; } poly(const std::vector<int> v) {a = v;}
    poly(const poly &b) { a = b.a; }
    poly getl(int n) { poly f(*this); f.resize(n); return f; }
    int &operator[](const int i) {return a[i];}
    size_t size() { return a.size(); }
    void resize(int n) { a.resize(n,0); }

    void print() {for(int p : a) {cout << p << " ";} cout << "\n"; }

    void ntt(int lim,int *a,bool mode,int *r) { //mode == 1 ? intt else ntt
        for(int i(0); i < lim; ++i) if(i < r[i]) std::swap(a[i],a[r[i]]);
        for(int len(2); len <= lim; len <<= 1) {
            int w1 = (mode ? qpow(g,P - 1 - (P - 1) / len) : qpow(g,(P - 1) / len));
            for(int i(0),w(1); i < (len >> 1); ++i , w = 1ll * w * w1 % P) 
                for(int bg(0); bg < lim; bg += len) {
                    int x = a[i | bg] , y = a[bg | i | (len >> 1)];
                    a[i | bg] = (x + 1ll * w * y % P) % P;
                    a[bg | i | len / 2] = (x - 1ll * w * y % P) % P;
                }
        }
        for(int i(0); i < lim; ++i) if(a[i] < 0) a[i] += P;
    }
    poly operator*(poly &b) {
        poly c;
        int lim = (int)a.size() + (int)b.size() - 1,lg(0);
        for(lg = 0; (1 << lg) < lim; ++lg); 
        lim = 1 << lg;

        int r[lim]; r[0] = 0; 
        for(int i(1); i < lim; ++i) r[i] = r[i >> 1] >> 1 | (i & 1) << (lg - 1);

        int f[lim],g[lim];
        for(int i(0); i < lim; ++i) {
            if(i < (int)a.size()) f[i] = a[i]; else f[i] = 0;
            if(i < (int)b.size()) g[i] = b[i]; else g[i] = 0;
        }
        ntt(lim,f,0,r); ntt(lim,g,0,r);
        for(int i(0); i < lim; ++i) f[i] = 1ll * f[i] * g[i] % P;
        ntt(lim,f,1,r); int inv = qpow(lim,P - 2);
        for(int i(0); i < lim; ++i) f[i] = 1ll * f[i] * inv % P;

        int n = a.size() + b.size() - 1; c.resize(n);
        for(int i(0); i < n; ++i) c[i] = f[i];
        return c;
    }
    
    poly operator+(poly &b) {
        poly c; c.resize(std::max(a.size(),b.size()));
        int n = std::min(a.size(),b.size());
        for(int i(0); i < n; ++i) c[i] = a[i] + b[i];
        return c;
    }
};

int main(){

    // freopen("input.txt","r",stdin); freopen("output.txt","w",stdout);
    auto bg = clock();
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);

    poly a; poly b;
    cin >> n >> m; n += 1; m += 1;
    a.resize(n); b.resize(m);
    for(int i(0); i < n; ++i) cin >> a[i];
    for(int i(0); i < m; ++i) cin >> b[i];
    (a * b).print();
    
    cerr << clock() - bg << "ms\n";

    return 0;
}

递归写法 code :

#include<iostream>
#include<algorithm>
#include<vector>
#include<assert.h>
#include<time.h>
#define fq(i,d,u) for(int i(d); i<=u; ++i)
#define fr(i,u,d) for(int i(u); i>=d; --i)
using namespace std;
typedef long long ll;
const ll P = 998244353;
const int N = 2e6;

int n,m,a[N << 2],b[N << 2];

ll qpow(ll x,ll k) {
    ll ans(1);
    while(k) {
        if(k & 1) ans = ans * x % P;
        x = x * x % P; k >>= 1;
    } return ans;
}

void ntt(int lim,int a[],bool fg) {
    if(lim == 1) return ;
    int a0[lim >> 1],a1[lim >> 1];
    for(int i(0); i < lim; ++i) if(i & 1) a1[i >> 1] = a[i]; else a0[i >> 1] = a[i];
    ntt(lim >> 1,a0,fg); ntt(lim >> 1,a1,fg);

    ll w(1),w1 = qpow(3,(P - 1) / lim); if(fg) w1 = qpow(w1,lim - 1);
    for(int i(0); i < (lim >> 1); ++i,w = w * w1 % P) {
        a[i] = (a0[i] + 1ll * w * a1[i]) % P;
        a[i + (lim >> 1)] = (a0[i] - 1ll * w * a1[i]) % P;
    }
}

int main(){

    auto bg = clock();
    cin >> n >> m; 
    for(int i(0); i <= n; ++i) cin >> a[i];
    for(int i(0); i <= m; ++i) cin >> b[i];
    n = n + m; int lim(1); for(; lim <= n; lim <<= 1);
    cerr << lim << "\n";
    ntt(lim,a,0); ntt(lim,b,0); 
    // for(int i(0); i < lim; ++i) a[i] = (a[i] + P) % P,b[i] = (b[i] + P) % P;
    for(int i(0); i < lim; ++i) a[i] = (1ll * a[i] * b[i]) % P;
    ntt(lim,a,1); ll inv = qpow(lim,P - 2); for(int i(0); i < lim; ++i) a[i] = (1ll * a[i] * inv % P + P) % P;
    for(int i(0); i <= n; ++i) cout << a[i] << " ";
    cerr << clock() - bg << "ms\n";

    return 0;
}
2024/12/29 10:33
加载中...