多项式乘法
两个多项式
求
即求解
朴素做法:O(n*m)
即两式系数依次相乘,其中其A、B式的表示方法为系数表示法。
##表示多项式的方法
系数表示法 和 点值表示法
- 系数表示法:
- 点值表示法:, 其中
点值表示法原理:
n个点确定一条n - 1次直线
点值表示法的优点:只需要把x相同的点对应的y值进行运算,时间复杂度降为
系数转换成点值称为求值 点值转换成系数称为插值
但怎么把系数表示转换成点值表示?
暴力转换还是
欧拉公式与单位根
欧拉公式:
单位根:
单位根满足:
即第k个n次单位根为
消去引理:
折半引理:平方的集合为的集合,有图可知。
由于其折半性,我们可以考虑把当做
利用其折半性,用某些方法可以分治的做达到的转换
DFT
DFT(离散傅里叶变换)
作用在一个长度为n复数序列,将其变换为另一个复数序列
可以把前者视为系数序列,后者视为点值的y,即把DFT过程视为求值 则逆DFT过程视为插值
DFT的公式定义:
逆DFT:
两者不一定前负后正,只要正负号相反即可。
FFT
现在考虑多项式的求值,即DFT
普通DFT为
用单位根的某些性质可以优化到
假设n为2的某次幂
对于序列进行拆分
可以得到
带入
由上,只要算出了前n/2,后n/2也能直接得出。
至于插值,令, 所得结果都除以n即可。
运用上述式子,进行递归分治求解即可
(递归的代码,下面参考网址的各位大佬都有,就不写了)
递归代码很慢,可以改成迭代形式
迭代代码(UOJ#34多项式乘法):
#include <cstdio>
#include <cmath>
#include <iostream>
#define MAXN 300000
#define m_p std::make_pair
#define x first
#define y second
int n, m;
const double pi = M_PI;
inline void getint(int &x) {
x = 0; int y = 1; char ch = getchar();
while(ch < '0' || ch > '9') { if(ch == '-') y = -1; ch = getchar(); }
while(ch >= '0' && ch <= '9') { x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar(); }
x *= y;
}
namespace np {
typedef std::pair<double, double> cp;
int rev[MAXN];
int n, m;
cp operator + (const cp &a, const cp &b) { return m_p(a.x + b.x, a.y + b.y); }
cp operator - (const cp &a, const cp &b) { return m_p(a.x - b.x, a.y - b.y); }
cp operator * (const cp &a, const cp &b) { return m_p(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x); }
inline void init(int x) {
for(n = m = 1; n < x; n <<= 1, m++);
n <<= 1;
for(int i = 1; i < n; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (m - 1));
}
inline void dft(cp *a, int f) {
for(int i = 0; i < n; i++) if(i < rev[i]) swap(a[i], a[rev[i]]);
for(int i = 2; i <= n; i <<= 1) {
cp wn = m_p(cos(2 * pi / i), sin(2 * f * pi / i));
for(int j = 0; j < n; j += i) {
cp w = m_p(1.0, 0);
for(int k = j; k < j + i / 2; k++) {
cp u = a[k], v = w * a[k + i / 2];
a[k] = u + v;
a[k + i / 2] = u - v;
w = w * wn;
}
}
}
if(f == -1) for(int i = 0; i < n; i++) a[i].x /= double(n);
}
inline void fft(cp *a, cp *b, cp *c) {
dft(a, 1);
dft(b, 1);
for(int i = 0; i < n; i++) c[i] = a[i] * b[i];
dft(c, -1);
}
};
np::cp a[MAXN], b[MAXN], ans[MAXN];
int main() {
getint(n); getint(m);
int tmp;
for(int i = 0; i <= n; i++) {
getint(tmp);
a[i].x = (double)tmp;
}
for(int i = 0; i <= m; i++) {
getint(tmp);
b[i].x = (double)tmp;
}
np::init(std::max(n, m) + 1);
np::fft(a, b, ans);
for(int i = 0; i <= n + m; i++) printf("%d ", (int)(ans[i].x + 0.5));
}
参考博客:
阮行止/blue:FFT入门
riteme:有关多项式的算法
xlightgod:【uoj34】多项式乘法