可以先了解下用分治法求多项式的值 链接

求 $A(x)$ 在 $w_j,j=0,1,2...,n-1$ 处的值:

设多项式

$$A(x)=a_0+a_1*x+a_2*{x^2}+a_3*{x^3}+a_4*{x^4}+a_5*{x^5}+ \dots+a_{n-2}*x^{n-2}+a_{n-1}*x^{n-1}$$

按下标的奇偶性分类,设

$$A_1(x)=a_0+a_2*{x}+a_4*{x^2}+\dots+a_{n-2}*x^{\frac{n}{2}-1}$$

$$A_2(x)=a_1+a_3*{x}+a_5*{x^2}+ \dots+a_{n-1}*x^{\frac{n}{2}-1}$$

那么不难得到 $A(x)=A_1(x^2)+xA_2(x^2)$.

我们将 $\omega_n^k (k<\frac{n}{2})$ 代入得

$$\begin{aligned}
A(\omega_n^k) &=A_1(\omega_n^{2k})+\omega_n^kA_2(\omega_n^{2k}) \\
&=A_1(\omega_{\frac{n}{2}}^{k})+\omega_n^kA_2(\omega_{\frac{n}{2}}^{k})
\end{aligned}$$

同理,将 $\omega_n^{k+\frac{n}{2}}$ 代入得

$$\begin{aligned}
A(\omega_n^{k+\frac{n}{2}}) &=A_1(\omega_n^{2k+n})+\omega_n^{k+\frac{n}{2}}(\omega_n^{2k+n}) \\
&=A_1(\omega_n^{2k}*\omega_n^n)-\omega_n^kA_2(\omega_n^{2k}*\omega_n^n) \\
&=A_1(\omega_n^{2k})-\omega_n^kA_2(\omega_n^{2k})
\end{aligned}$$

两个式子只有符号不同,也就是说,算出在前 $n$ 个点的值就能得到后 $n$ 个点的值,相当于问题规模减半了。

所以可以递归的实现,直到多项式仅剩一个常数项,这时候我们直接返回就好啦!

这样时间复杂度为 $O(nlogn)$.

FFT算法的伪代码:

1.求值 $A(w_j), B(w_j)$,$j=0,1,..2n-1$

2. 计算 $C(w_j)$,$j=0,1,...,2n-1$

3. 构造多项式

$D(x)=C(w_0) + C(w_1)x+...+C(w_{2n-1})x^{2n-1}$

4. 计算所有的 $D(w_j)$,$j=0,1,...2n-1$

5. 利用下式计算 $C(x)$ 的系数 $c_j$

$D(w_j) = 2nc_{2n-j},\ j=1,...,2n-1$

$D(w_0) = 2nc_0$

递归版

fft函数 $type=1$ 进去的时候 $a$ 数组存的是系数,返回时存的是计算出来的值,$a[i]=A(w_i)$;

$type=-1$ 时相反。

这里预处理了sin和cos值,大概能快2、3倍

// luogu-judger-enable-o2
#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
const int MAXN = 4 * 60000 + 10;  、、开4倍空间
inline int read() {
    char c = getchar(); int x = 0, f = 1;
    while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();}
    while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
    return x * f;
}
const double Pi = acos(-1.0);
const double Eps = 1e-8;
double ccos[MAXN], ssin[MAXN];
struct complex {
    double x, y;
    complex (double xx = 0, double yy = 0) {x = xx, y = yy;}
} a[MAXN], b[MAXN];
complex operator + (complex a, complex b) { return complex(a.x + b.x , a.y + b.y);}
complex operator - (complex a, complex b) { return complex(a.x - b.x , a.y - b.y);}
complex operator * (complex a, complex b) { return complex(a.x * b.x - a.y * b.y , a.x * b.y + a.y * b.x);} //不懂的看复数的运算那部分
void fast_fast_tle(int limit, complex *a, int type) {
    if (limit == 1) return ; //只有一个常数项
    complex a1[limit >> 1], a2[limit >> 1];
    for (int i = 0; i < limit; i += 2) //根据下标的奇偶性分类
        a1[i >> 1] = a[i], a2[i >> 1] = a[i + 1];
    fast_fast_tle(limit >> 1, a1, type);
    fast_fast_tle(limit >> 1, a2, type);
    complex Wn = complex(ccos[limit] , type * ssin[limit]), w = complex(1, 0);
    //complex Wn = complex(cos(2.0 * Pi / limit) , type * sin(2.0 * Pi / limit)), w = complex(1, 0);
    //Wn为单位根,w表示幂
    for (int i = 0; i < (limit >> 1); i++, w = w * Wn) //这里的w相当于公式中的k
        a[i] = a1[i] + w * a2[i],
               a[i + (limit >> 1)] = a1[i] - w * a2[i]; //利用单位根的性质,O(1)得到另一部分
}

char s[MAXN];
int res[MAXN];

int main() {
    int N = read();
    scanf("%s", s);
    for (int i = 0; i < N; i++) a[i].x = s[N-1-i]-'0';
    scanf("%s", s);
    for (int i = 0; i < N; i++) b[i].x = s[N-1-i]-'0';

    //for(int i = 0;i < N;i++)  printf("%f ", a[i]);

    int limit = 1; while (limit <= 2*N) limit <<= 1;

    for(int i = 1;i <= limit;i++)
    {
        ccos[i] = cos(2.0 * Pi / i);
        ssin[i] = sin(2.0 * Pi / i);
    }

    fast_fast_tle(limit, a, 1);
    fast_fast_tle(limit, b, 1);
    //后面的1表示要进行的变换是什么类型
    //1表示从系数变为点值
    //-1表示从点值变为系数
    //至于为什么这样是对的,可以参考一下c向量的推导过程,
    for (int i = 0; i <= limit; i++)
        a[i] = a[i] * b[i];
    fast_fast_tle(limit, a, -1);

    for(int i = 0;i <= 2*N;i++)  res[i] = int(a[i].x/limit+0.5);

    int tmp = 0;    //进位
    for(int i = 0;i <= 2*N;i++)
    {
        res[i] += tmp;
        tmp = res[i] / 10;
        res[i] = res[i] % 10;
    }

    bool flag = false;
    for (int i = 2*N; i >= 0; i--)
    {
        //printf("%f ", a[i].x);
       if(res[i])  flag = true;
        if(flag)  printf("%d", res[i]);  //按照我们推倒的公式,这里还要除以n
    }
    return 0;
}

非递归版

这个很容易发现点什么吧?

  • 每个位置分治后的最终位置为其二进制翻转后得到的位置

这样的话我们可以先把原序列变换好,把每个数放在最终的位置上,再一步一步向上合并。

一句话就可以 $O(n)$ 预处理出位置 $i$ 最终的位置 $rev[i]$:

//原理也很简单,将高bit-1位(也就是i/2)反转,再将第一位补到最高位。

fo(i,0,n-1)rev[i]=(rev[i>>1]>>1)|((i&1)<<(bit-1));
// luogu-judger-enable-o2
#include<iostream>
#include<cstdio>
#include<cmath>
using namespace std;
const int MAXN = 4e6 + 10;
inline int read() {
    char c = getchar(); int x = 0, f = 1;
    while (c < '0' || c > '9') {if (c == '-')f = -1; c = getchar();}
    while (c >= '0' && c <= '9') {x = x * 10 + c - '0'; c = getchar();}
    return x * f;
}
const double Pi = acos(-1.0);
struct complex {
    double x, y;
    complex (double xx = 0, double yy = 0) {x = xx, y = yy;}
} a[MAXN], b[MAXN];
complex operator + (complex a, complex b) { return complex(a.x + b.x , a.y + b.y);}
complex operator - (complex a, complex b) { return complex(a.x - b.x , a.y - b.y);}
complex operator * (complex a, complex b) { return complex(a.x * b.x - a.y * b.y , a.x * b.y + a.y * b.x);} //不懂的看复数的运算那部分
int N, M;
int bit, r[MAXN];
int limit = 1;
void fast_fast_tle(complex *A, int type) {
    for (int i = 0; i < limit; i++)
        if (i < r[i]) swap(A[i], A[r[i]]); //求出要迭代的序列
    for (int mid = 1; mid < limit; mid <<= 1) { //待合并区间的长度的一半
        complex Wn( cos(Pi / mid) , type * sin(Pi / mid) ); //单位根
        for (int R = mid << 1, j = 0; j < limit; j += R) { //R是区间的长度,j表示前已经到哪个位置了
            complex w(1, 0); //
            for (int k = 0; k < mid; k++, w = w * Wn) { //枚举左半部分
                complex x = A[j + k], y = w * A[j + mid + k]; //蝴蝶效应
                A[j + k] = x + y;
                A[j + mid + k] = x - y;
            }
        }
    }
}
int main() {
    int N = read(), M = read();
    for (int i = 0; i <= N; i++) a[i].x = read();
    for (int i = 0; i <= M; i++) b[i].x = read();
    while (limit <= N + M) limit <<= 1, bit++;
    for (int i = 0; i < limit; i++)
        r[i] = ( r[i >> 1] >> 1 ) | ( (i & 1) << (bit - 1) ) ;
    fast_fast_tle(a, 1);
    fast_fast_tle(b, 1);
    for (int i = 0; i <= limit; i++) a[i] = a[i] * b[i];
    fast_fast_tle(a, -1);
    for (int i = 0; i <= N + M; i++)
        printf("%d ", (int)(a[i].x / limit + 0.5));
    return 0;
}

参考链接:

(建议直接看大佬的,我都是copy过来的,整理一下思路而已)

https://www.cnblogs.com/zwfymqz/p/8244902.html

2. https://blog.csdn.net/enjoy_pascal/article/details/81478582

02-13 01:30