分治FFT

Jul 13th, 2020
  • 在其它设备中阅读本文章

就离谱.

前置知识

FFTNTT,CDQ 分治 (能口胡就行).

例题

主要思路

像 CDQ 分治一样, 先算左边, 算完了统计对右边的贡献, 再算右边.

具体而言就是比如对于 $g_{0,\dots,3}=\{0,3,1,2\}$, 计算 $f_{1,2,3}$ 满足

$$ f_0=1\\f_n=\sum_{i=1}^nf_{n-i}g_i $$

(实际就是例题的第一组数据). 最初有 $f_{0,\dots,3}=\{1,0,0,0\}$.

  • 从中间分开, 算左边. 此时有 $f_{0,1}=\{1,0\}$, 长度为 $2$, 不用再分了.
    将左边的 $f$ 区间和 $g$ 区间 ($f_{0,1}$ 和 $g_{0,1}$) 卷起来得到 $\dots,3$($\dots$ 表示卷积结果的左边, 不需要), 将右边加到当前 $f$ 区间的右边, 得到新的 $f_{0,1}=\{1,3\}$.
  • 计算左边对右边的贡献.
    将整个 $f$ 区间和 $g$ 区间 ($_{0,\dots3}$) 卷起来得到 $\dots,10,5$, 同样将右边加到 $f$ 区间的右边, 得到新的 $f_{0,\dots,3}=\{1,3,10,5\}$.
  • 接下来算右边.
    与左边类似地, 将右边 ($_{2,3}$) 卷起来得到 $\dots,25$, 同样将右边加到当前 $f$ 区间的右边, 得到新的 $f_{2,3}=\{10,35\}$.
  • 整理一下以上过程, 发现得到了 $f_{0,\dots,3}=\{1,3,10,35\}$.

可以写出伪代码

function 分治FFT(l,r,gn) /*gn表示$\lg当前区间长度$*/
    if l>=n or gn=0 then
        return
    end if
    iv←r-l的逆元
    m←(l+r)/2
    分治FFT(l,m,gn-1) /*算左边*/
    memcpy(a,f+l,sizeof(int)*((r-l)>>1)); /*用f[]填充a[]的前半段*/
    memset(a+((r-l)>>1),0x00,sizeof(int)*((r-l)>>1)); /*用0填充a[]的后半段*/
    memcpy(b,g,sizeof(int)*(r-l)); /*用g[]填充b[]*/
    times(a,b,1<<gn); /*把a[]和b[]卷起来*/
    for i←m to r-1 by step 1 do
        f[i]←f[i]+a[i-l] /*加到右边,该取模取模*/
    end for
    分治FFT(m,r,gn-1) /*算右边*/
end function

Code

例题的, 感觉 NTT 会比较快 (也许).

#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;

const int jly=998244353;

inline int ksm(int a,int b) {
    int r=1;
    while (b) {
        if (b&1) r=(long long)r*a%jly;
        a=(long long)a*a%jly; b>>=1;
    }
    return r;
}

inline int mjly (int x) {
    return x<jly?x:x-jly;
}

namespace NTTspace {
    const int grt=3;
    int qow[400005],inv[400005];

    inline void NTT(int *a,int n,bool f,bool u=false) {
        static int k,r[400005];
        int g0,gx,b,c;
        if (u) {
            k=0;
            for (int i=2;i<n;i<<=1) ++k;
            for (int i=0;i<n;++i)
                r[i]=(r[i>>1]>>1)|((i&1)<<k);
        }
        for (int i=0;i<n;++i) if (i<r[i])
            swap(a[i],a[r[i]]);
        for (int i=1;i<n;i<<=1) {
            g0=f?inv[i<<1]:qow[i<<1];
            for (int j=0;j<n;j+=(i<<1)) {
                gx=1;
                for (int k=0;k<i;++k) {
                    b=a[j+k]; c=(long long)gx*a[i+j+k]%jly;
                    a[j+k]=mjly(b+c); a[i+j+k]=mjly(b-c+jly);
                    gx=(long long)g0*gx%jly;
                }
            }
        }
        if (f) {
            int iv=ksm(n,jly-2);
            for (int i=0;i<n;++i)
                a[i]=(long long)a[i]*iv%jly;
        }
    }
}

namespace bxt {
    int n,m,k;
    int f[400005],g[400005];
    int a[400005],b[400005];
}

inline void init() {
    using namespace NTTspace;
    int iv=ksm(grt,jly-2);
    for (int i=1;i<262145;i<<=1) {
        qow[i]=ksm(grt,(jly-1)/i);
        inv[i]=ksm(iv,(jly-1)/i);
    }
}

inline void times(int *a,int *b,int m) {
    using namespace NTTspace;
    int n=1; for (;n<m;n<<=1);
    NTT(a,n,false,true); NTT(b,n,false);
    for (int i=0;i<n;++i) a[i]=(long long)a[i]*b[i]%jly;
    NTT(a,n,true);
}

void getans(int l,int r,int gn) {
    using namespace bxt;
    if (l>=n||(!gn)) return;
    int iv=ksm(r-l,jly-2),m=(l+r)>>1;
    getans(l,m,gn-1);
    memcpy(a,f+l,sizeof(int)*((r-l)>>1));
    memset(a+((r-l)>>1),0x00,sizeof(int)*((r-l)>>1));
    memcpy(b,g,sizeof(int)*(r-l));
    times(a,b,1<<gn);
    for (int i=m;i<r;++i) f[i]=mjly(f[i]+a[i-l]);
    getans(m,r,gn-1);
}

int main() {
    using namespace bxt; init();
    scanf("%d",&m); f[0]=1;
    for (int i=1;i<m;++i) scanf("%d",&g[i]);
    for (n=1;n<m;n<<=1) ++k; getans(0,n,k);
    for (int i=0;i<m;++i)
        printf("%d ",f[i]<0?f[i]+jly:f[i]);
    return 0;
}

owo

mo-ha