分治FFT
就离谱.
前置知识
例题
主要思路
像 CDQ 分治一样, 先算左边, 算完了统计对右边的贡献, 再算右边.
具体而言就是比如对于 $g_{0,\dots,3}=\{0,3,1,2\}$, 计算 $f_{1,2,3}$ 满足
$$ f_0=1\\f_n=\sum_{i=1}^nf_{n-i}g_i $$
(实际就是例题的第一组数据). 最初有 $f_{0,\dots,3}=\{1,0,0,0\}$.
- 从中间分开, 算左边. 此时有 $f_{0,1}=\{1,0\}$, 长度为 $2$, 不用再分了.
将左边的 $f$ 区间和 $g$ 区间 ($f_{0,1}$ 和 $g_{0,1}$) 卷起来得到 $\dots,3$($\dots$ 表示卷积结果的左边, 不需要), 将右边加到当前 $f$ 区间的右边, 得到新的 $f_{0,1}=\{1,3\}$. - 计算左边对右边的贡献.
将整个 $f$ 区间和 $g$ 区间 ($_{0,\dots3}$) 卷起来得到 $\dots,10,5$, 同样将右边加到 $f$ 区间的右边, 得到新的 $f_{0,\dots,3}=\{1,3,10,5\}$. - 接下来算右边.
与左边类似地, 将右边 ($_{2,3}$) 卷起来得到 $\dots,25$, 同样将右边加到当前 $f$ 区间的右边, 得到新的 $f_{2,3}=\{10,35\}$. - 整理一下以上过程, 发现得到了 $f_{0,\dots,3}=\{1,3,10,35\}$.
可以写出伪代码
function 分治FFT(l,r,gn) /*gn表示$\lg当前区间长度$*/
if l>=n or gn=0 then
return
end if
iv←r-l的逆元
m←(l+r)/2
分治FFT(l,m,gn-1) /*算左边*/
memcpy(a,f+l,sizeof(int)*((r-l)>>1)); /*用f[]填充a[]的前半段*/
memset(a+((r-l)>>1),0x00,sizeof(int)*((r-l)>>1)); /*用0填充a[]的后半段*/
memcpy(b,g,sizeof(int)*(r-l)); /*用g[]填充b[]*/
times(a,b,1<<gn); /*把a[]和b[]卷起来*/
for i←m to r-1 by step 1 do
f[i]←f[i]+a[i-l] /*加到右边,该取模取模*/
end for
分治FFT(m,r,gn-1) /*算右边*/
end function
Code
例题的, 感觉 NTT 会比较快 (也许).
#include <algorithm>
#include <cstdio>
#include <cstring>
using namespace std;
const int jly=998244353;
inline int ksm(int a,int b) {
int r=1;
while (b) {
if (b&1) r=(long long)r*a%jly;
a=(long long)a*a%jly; b>>=1;
}
return r;
}
inline int mjly (int x) {
return x<jly?x:x-jly;
}
namespace NTTspace {
const int grt=3;
int qow[400005],inv[400005];
inline void NTT(int *a,int n,bool f,bool u=false) {
static int k,r[400005];
int g0,gx,b,c;
if (u) {
k=0;
for (int i=2;i<n;i<<=1) ++k;
for (int i=0;i<n;++i)
r[i]=(r[i>>1]>>1)|((i&1)<<k);
}
for (int i=0;i<n;++i) if (i<r[i])
swap(a[i],a[r[i]]);
for (int i=1;i<n;i<<=1) {
g0=f?inv[i<<1]:qow[i<<1];
for (int j=0;j<n;j+=(i<<1)) {
gx=1;
for (int k=0;k<i;++k) {
b=a[j+k]; c=(long long)gx*a[i+j+k]%jly;
a[j+k]=mjly(b+c); a[i+j+k]=mjly(b-c+jly);
gx=(long long)g0*gx%jly;
}
}
}
if (f) {
int iv=ksm(n,jly-2);
for (int i=0;i<n;++i)
a[i]=(long long)a[i]*iv%jly;
}
}
}
namespace bxt {
int n,m,k;
int f[400005],g[400005];
int a[400005],b[400005];
}
inline void init() {
using namespace NTTspace;
int iv=ksm(grt,jly-2);
for (int i=1;i<262145;i<<=1) {
qow[i]=ksm(grt,(jly-1)/i);
inv[i]=ksm(iv,(jly-1)/i);
}
}
inline void times(int *a,int *b,int m) {
using namespace NTTspace;
int n=1; for (;n<m;n<<=1);
NTT(a,n,false,true); NTT(b,n,false);
for (int i=0;i<n;++i) a[i]=(long long)a[i]*b[i]%jly;
NTT(a,n,true);
}
void getans(int l,int r,int gn) {
using namespace bxt;
if (l>=n||(!gn)) return;
int iv=ksm(r-l,jly-2),m=(l+r)>>1;
getans(l,m,gn-1);
memcpy(a,f+l,sizeof(int)*((r-l)>>1));
memset(a+((r-l)>>1),0x00,sizeof(int)*((r-l)>>1));
memcpy(b,g,sizeof(int)*(r-l));
times(a,b,1<<gn);
for (int i=m;i<r;++i) f[i]=mjly(f[i]+a[i-l]);
getans(m,r,gn-1);
}
int main() {
using namespace bxt; init();
scanf("%d",&m); f[0]=1;
for (int i=1;i<m;++i) scanf("%d",&g[i]);
for (n=1;n<m;n<<=1) ++k; getans(0,n,k);
for (int i=0;i<m;++i)
printf("%d ",f[i]<0?f[i]+jly:f[i]);
return 0;
}