P5158 【模板】多项式快速插值 AI题解
多项式快速插值完整详解
一、问题与拉格朗日插值公式
给定 nnn 个点 (xi,yi)(x_i, y_i)(xi,yi),求 n−1n-1n−1 次多项式 f(x)f(x)f(x) 满足 f(xi)=yif(x_i) = y_if(xi)=yi。
拉格朗日插值公式:
f(x)=∑i=1nyi∏j≠ix−xjxi−xjf(x) = \sum_{i=1}^n y_i \prod_{j \neq i} \frac{x - x_j}{x_i - x_j}f(x)=i=1∑nyij=i∏xi−xjx−xj
直接计算是 O(n2)O(n^2)O(n2) 的,需要优化。
二、快速插值算法框架
2.1 基本思路
设:
- M(x)=∏i=1n(x−xi)M(x) = \prod_{i=1}^n (x - x_i)M(x)=∏i=1n(x−xi)
- Mi(x)=M(x)x−xi=∏j≠i(x−xj)M_i(x) = \frac{M(x)}{x - x_i} = \prod_{j \neq i} (x - x_j)Mi(x)=x−xiM(x)=∏j=i(x−xj)
- wi=yi∏j≠i(xi−xj)=yiM′(xi)w_i = \frac{y_i}{\prod_{j \neq i} (x_i - x_j)} = \frac{y_i}{M'(x_i)}wi=∏j=i(xi−xj)yi=M′(xi)yi
则:
f(x)=∑i=1nwiMi(x)f(x) = \sum_{i=1}^n w_i M_i(x)f(x)=i=1∑nwiMi(x)
算法步骤:
- 计算 M(x)=∏i=1n(x−xi)M(x) = \prod_{i=1}^n (x - x_i)M(x)=∏i=1n(x−xi)
- 计算 M′(x)M'(x)M′(x)
- 多点求值计算 M′(xi)M'(x_i)M′(xi)
- 计算 wi=yi⋅[M′(xi)]−1mod 998244353w_i = y_i \cdot [M'(x_i)]^{-1} \mod 998244353wi=yi⋅[M′(xi)]−1mod998244353
- 计算 f(x)=∑wiMi(x)f(x) = \sum w_i M_i(x)f(x)=∑wiMi(x)
步骤5用分治:
设 fl,r(x)=∑i=lrwi∏j=l,j≠ir(x−xj)f_{l,r}(x) = \sum_{i=l}^r w_i \prod_{j=l, j\neq i}^r (x - x_j)fl,r(x)=∑i=lrwi∏j=l,j=ir(x−xj)
则:
fl,r=fl,mid⋅Mmid+1,r+fmid+1,r⋅Ml,midf_{l,r} = f_{l,mid} \cdot M_{mid+1,r} + f_{mid+1,r} \cdot M_{l,mid}fl,r=fl,mid⋅Mmid+1,r+fmid+1,r⋅Ml,mid
其中 Ml,r=∏i=lr(x−xi)M_{l,r} = \prod_{i=l}^r (x - x_i)Ml,r=∏i=lr(x−xi)
三、基础数学知识
3.1 模运算与逆元
模数 p=998244353p = 998244353p=998244353 是质数,费马小定理:ap−1≡1(modp)a^{p-1} \equiv 1 \pmod{p}ap−1≡1(modp)
逆元:a−1≡ap−2(modp)a^{-1} \equiv a^{p-2} \pmod{p}a−1≡ap−2(modp)
代码中的快速幂:
inline long long pow2(long long a1, long long b1) {
long long c1 = 1;
while (b1 != 0) {
if (b1 % 2 == 1) c1 = c1 * a1 % mod;
a1 = a1 * a1 % mod;
b1 /= 2;
}
return c1;
}
3.2 Barrett Reduction
普通取模用 % 较慢,Barrett Reduction 用乘法和移位代替除法:
a mod p=a−⌊a/p⌋×pa \bmod p = a - \lfloor a/p \rfloor \times pamodp=a−⌊a/p⌋×p
近似计算 ⌊a/p⌋\lfloor a/p \rfloor⌊a/p⌋:
⌊a/p⌋≈⌊(a×m)/264⌋,m=⌊264/p⌋\lfloor a/p \rfloor \approx \lfloor (a \times m) / 2^{64} \rfloor, \quad m = \lfloor 2^{64}/p \rfloor⌊a/p⌋≈⌊(a×m)/264⌋,m=⌊264/p⌋
代码实现:
const long long rr = (__int128)((__int128)1 << 64) / mod;
inline int mo(long long a1) {
a1 -= ((__int128)a1 * rr >> 64) * mod;
return a1 >= mod ? a1 - mod : a1;
}
四、快速数论变换(NTT)
4.1 原根与单位根
p=998244353=119×223+1p = 998244353 = 119 \times 2^{23} + 1p=998244353=119×223+1,原根 g=3g = 3g=3。
nnn 次单位根:ωn=g(p−1)/n\omega_n = g^{(p-1)/n}ωn=g(p−1)/n
NTT 变换:
A(k)=∑j=0n−1ajωnjk(k=0,1,…,n−1)A(k) = \sum_{j=0}^{n-1} a_j \omega_n^{jk} \quad (k=0,1,\dots,n-1)A(k)=j=0∑n−1ajωnjk(k=0,1,…,n−1)
逆变换:
aj=1n∑k=0n−1A(k)ωn−jka_j = \frac{1}{n} \sum_{k=0}^{n-1} A(k) \omega_n^{-jk}aj=n1k=0∑n−1A(k)ωn−jk
4.2 递归NTT实现
代码使用递归实现,基于 Cooley-Tukey 算法:
inline void ntt(long long n1, long long *a1, long long b1) {
if (n1 <= 1) return;
// 奇偶分裂
long long ax[n1/2], ay[n1/2];
for (int i = 0; i <= n1-1; i += 2) {
ax[i/2] = a1[i];
ay[i/2] = a1[i+1];
}
ntt(n1/2, ax, b1);
ntt(n1/2, ay, b1);
// 合并
long long ww, w = 1;
if (b1 == 1) ww = pow2(3ll, (mod-1) / n1); // 正变换
else ww = pow2(pow2(3ll, (mod-1) / n1), mod-2); // 逆变换
for (int i = 0; i <= n1/2-1; i++, w = w * ww % mod) {
a1[i] = (ax[i] + w * ay[i]) % mod;
a1[i + n1/2] = (ax[i] - w * ay[i] + moc) % mod;
}
}
其中 moc = 1ll*mod*mod 用于防止负数。
4.3 优化:预处理旋转因子
uv[n1][0] 存储 ωn1\omega_{n1}ωn1,uv[n1][1] 存储 ωn1−1\omega_{n1}^{-1}ωn1−1,避免重复计算。
五、多项式基本操作
5.1 多项式乘法
NTT 将多项式乘法 O(n2)O(n^2)O(n2) 降为 O(nlogn)O(n\log n)O(nlogn):
- 将多项式系数补零到 2k2^k2k 长度
- 分别 NTT
- 点值相乘
- 逆 NTT
5.2 多项式求逆
给定 A(x)A(x)A(x),求 B(x)B(x)B(x) 使 A(x)B(x)≡1(modxn)A(x)B(x) \equiv 1 \pmod{x^n}A(x)B(x)≡1(modxn)
牛顿迭代法:已知 B0(x)B_0(x)B0(x) 满足 A(x)B0(x)≡1(modx⌈n/2⌉)A(x)B_0(x) \equiv 1 \pmod{x^{\lceil n/2 \rceil}}A(x)B0(x)≡1(modx⌈n/2⌉),
则:
B(x)≡2B0(x)−A(x)B02(x)(modxn)B(x) \equiv 2B_0(x) - A(x)B_0^2(x) \pmod{x^n}B(x)≡2B0(x)−A(x)B02(x)(modxn)
代码在 ntt2 中实现求逆:
a[0] = pow2(g[0], mod-2); // 初始逆元
while (nn <= n-m+1) {
nn *= 2;
nn2 = pow2(nn*2ll % mod, mod-2);
// 计算 A(x)B_0(x)
ntt(nn*2, a, 1);
ntt(nn*2, b, 1);
for (int i = 0; i <= nn*2-1; i++) {
a[i] = a[i] * (2ll - b[i] * a[i] % mod + mod) % mod;
}
ntt(nn*2, a, -1);
for (int i = 0; i <= nn*2-1; i++) {
a[i] = a[i] * nn2 % mod;
if (i >= nn) a[i] = 0;
}
}
5.3 多项式求导
若 A(x)=∑i=0naixiA(x) = \sum_{i=0}^n a_i x^iA(x)=∑i=0naixi,则 A′(x)=∑i=1niaixi−1A'(x) = \sum_{i=1}^n i a_i x^{i-1}A′(x)=∑i=1niaixi−1
代码中:
for (int i = 1; i <= nn; i++) {
ff[i-1] = v[1][i] * i % mod; // M'(x) 系数
}
5.4 多项式取模
给定 F(x)F(x)F(x) 和 G(x)G(x)G(x),求 Q(x)Q(x)Q(x) 和 R(x)R(x)R(x) 使:
F(x)=Q(x)G(x)+R(x),degR<degGF(x) = Q(x)G(x) + R(x), \quad \deg R < \deg GF(x)=Q(x)G(x)+R(x),degR<degG
算法:
- 设 n=degFn = \deg Fn=degF, m=degGm = \deg Gm=degG
- 反转:FR(x)=xnF(1/x)F_R(x) = x^n F(1/x)FR(x)=xnF(1/x), GR(x)=xmG(1/x)G_R(x) = x^m G(1/x)GR(x)=xmG(1/x)
- 求 GR(x)G_R(x)GR(x) 模 xn−m+1x^{n-m+1}xn−m+1 的逆 H(x)H(x)H(x)
- QR(x)=FR(x)H(x) mod xn−m+1Q_R(x) = F_R(x)H(x) \bmod x^{n-m+1}QR(x)=FR(x)H(x)modxn−m+1
- 反转得 Q(x)Q(x)Q(x)
- R(x)=F(x)−Q(x)G(x)R(x) = F(x) - Q(x)G(x)R(x)=F(x)−Q(x)G(x)
代码在 ntt2 中实现,用于多点求值。
六、分治构建多项式乘积
6.1 构建 Ml,r(x)=∏i=lr(x−xi)M_{l,r}(x) = \prod_{i=l}^r (x - x_i)Ml,r(x)=∏i=lr(x−xi)
ntt0 函数:
inline void ntt0(long long n1, long long *a1, long long b1, long long l, long long r) {
if (l == r) { // 叶子节点
a1[0] = (-xx[l] + mod) % mod; // -x_i
a1[1] = 1; // 系数1
v[b1].push_back((-xx[l] + mod) % mod);
v[b1].push_back(1);
return;
}
long long mid = (l+r)/2;
long long ax[n1*2+1], ay[n1*2+1];
ntt0(n1/2, ax, b1*2, l, mid);
ntt0(n1/2, ay, b1*2+1, mid+1, r);
// 合并:多项式乘法
ntt(n1*2, ax, 1);
ntt(n1*2, ay, 1);
for (int i = 0; i <= n1*2-1; i++) ax[i] = ax[i] * ay[i] % mod;
ntt(n1*2, ax, -1);
long long n12 = pow2(n1*2ll % mod, mod-2);
for (int i = 0; i <= n1*2-1; i++) {
ax[i] = ax[i] * n12 % mod;
v[b1].push_back(ax[i]); // 存储到vector
a1[i] = ax[i];
}
}
结果存在 v[b1] 中,b1 是节点编号。
七、多点求值
7.1 原理
计算 F(xi)F(x_i)F(xi) 对每个 xix_ixi:
F(x) mod (x−xi)=F(xi)F(x) \bmod (x-x_i) = F(x_i)F(x)mod(x−xi)=F(xi)
对区间 [l,r][l,r][l,r],设 Gl,r(x)=∏i=lr(x−xi)G_{l,r}(x) = \prod_{i=l}^r (x - x_i)Gl,r(x)=∏i=lr(x−xi)
计算 R(x)=F(x) mod Gl,r(x)R(x) = F(x) \bmod G_{l,r}(x)R(x)=F(x)modGl,r(x),则:
∀i∈[l,r],F(xi)=R(xi)\forall i \in [l,r], \quad F(x_i) = R(x_i)∀i∈[l,r],F(xi)=R(xi)
因为 F(x)=Q(x)Gl,r(x)+R(x)F(x) = Q(x)G_{l,r}(x) + R(x)F(x)=Q(x)Gl,r(x)+R(x),而 Gl,r(xi)=0G_{l,r}(x_i)=0Gl,r(xi)=0
7.2 实现
ntt2 函数:
inline void ntt2(long long n1, long long *a1, long long b1, long long l, long long r) {
if (l + 2000 >= r) { // 小范围暴力
for (int i = l; i <= r; i++) ui[i-l+1] = 1;
for (int i = 0; i <= n1/2; i++) {
for (int u = l; u <= r; u++) {
yy[u] = (yy[u] + ui[u-l+1] * a1[i]);
ui[u-l+1] = ui[u-l+1] * xx[u] % mod;
}
if (i % 6 == 0) { // 每6次取模一次,优化
for (int u = l; u <= r; u++) yy[u] %= mod;
}
}
for (int u = l; u <= r; u++) yy[u] %= mod;
return;
}
// 大范围:分治取模
long long mid = (l+r)/2;
// 取左子树多项式
for (int i = 0; i < v[b1*2].size(); i++) g[i] = v[b1*2][i];
// 计算 a1 mod g
// ... 多项式取模代码
ntt2(n1/2, cff, b1*2, l, mid); // 递归左子树
ntt2(n1/2, cff2, b1*2+1, mid+1, r); // 递归右子树
}
八、快速插值合并
8.1 原理
已知 wi=yi/M′(xi)w_i = y_i / M'(x_i)wi=yi/M′(xi),求:
fl,r(x)=∑i=lrwi∏j=l,j≠ir(x−xj)f_{l,r}(x) = \sum_{i=l}^r w_i \prod_{j=l, j\neq i}^r (x - x_j)fl,r(x)=i=l∑rwij=l,j=i∏r(x−xj)
递推关系:
fl,r=fl,mid⋅Mmid+1,r+fmid+1,r⋅Ml,midf_{l,r} = f_{l,mid} \cdot M_{mid+1,r} + f_{mid+1,r} \cdot M_{l,mid}fl,r=fl,mid⋅Mmid+1,r+fmid+1,r⋅Ml,mid
8.2 实现
ntt3 函数:
inline void ntt3(long long n1, long long *f1, long long *m1, long long l, long long r) {
if (l == r) { // 叶子
f1[0] = uq[l]; // w_l
m1[0] = (mod - xq[l]) % mod; // -x_l
m1[1] = 1; // 系数1
return;
}
long long mid = (l+r)/2;
long long fx[n1*2+1], fy[n1*2+1], mx[n1*2+1], my[n1*2+1];
ntt3(n1/2, fx, mx, l, mid);
ntt3(n1/2, fy, my, mid+1, r);
// f = fx*my + fy*mx
// m = mx*my
long long nn = n1*2;
ntt(nn, fx, 1); ntt(nn, fy, 1);
ntt(nn, mx, 1); ntt(nn, my, 1);
long long nn2 = pow2(nn, mod-2);
for (int i = 0; i <= nn; i++) f1[i] = (fx[i] * my[i]) % mod;
ntt(nn, f1, -1);
for (int i = 0; i <= nn; i++) f1[i] = f1[i] * nn2 % mod;
long long fg[nn+1];
for (int i = 0; i <= nn; i++) fg[i] = (fy[i] * mx[i]) % mod;
ntt(nn, fg, -1);
for (int i = 0; i <= nn; i++) fg[i] = fg[i] * nn2 % mod;
for (int i = 0; i <= nn; i++) f1[i] = (f1[i] + fg[i]) % mod;
for (int i = 0; i <= nn; i++) m1[i] = mx[i] * my[i] % mod;
ntt(nn, m1, -1);
for (int i = 0; i <= nn; i++) m1[i] = m1[i] * nn2 % mod;
}
九、主函数流程
int main() {
cin >> n;
for (int i = 1; i <= n; i++) {
scanf("%lld%lld", &xq[i], &yq[i]);
xx[i] = xq[i];
}
// 1. 构造 M(x) = ∏(x - x_i)
ntt0(nn/2, qwerty, 1, 1, m);
// 2. 求导得到 M'(x)
for (int i = 1; i <= nn; i++) {
ff[i-1] = v[1][i] * i % mod;
}
// 3. 多点求值计算 M'(x_i)
ntt2(nn, ff, 1, 1, m);
// 4. 计算权重 w_i = y_i / M'(x_i)
for (int i = 1; i <= m; i++) {
uq[i] = yq[i] * pow2(yy[i], mod-2) % mod;
}
// 5. 分治合并得到最终多项式
ntt3(nn/2, fe, me, 1, m);
// 输出结果
for (int i = 0; i <= n-1; i++) {
cout << fe[i] << " ";
}
return 0;
}
十、关键细节
10.1 数组大小
- 多项式乘法后次数为两多项式次数之和
- 递归深度为 logn\log nlogn
- 原代码开 600005600005600005 大小,因为 n≤100000n \leq 100000n≤100000,需要 2⌈log2n⌉×22^{\lceil \log_2 n \rceil} \times 22⌈log2n⌉×2 以上
10.2 优化技巧
- 小范围暴力:区间较小时直接霍纳法求值
- 记忆化旋转因子:
uv数组存储 ωn\omega_nωn 和 ωn−1\omega_n^{-1}ωn−1 - Barrett Reduction:加速取模
- 延迟取模:累加多次后再取模(
i%6==0)
10.3 注意事项
- 所有运算在模 998244353998244353998244353 下进行
- 逆元不存在时(M′(xi)=0M'(x_i)=0M′(xi)=0)不会发生,因为 xix_ixi 互不相同
- 多项式次数不足 n−1n-1n−1 时高位补零
- 输出系数从低次到高次
十一、复杂度分析
- 构建 M(x)M(x)M(x):O(nlog2n)O(n \log^2 n)O(nlog2n)
- 求导:O(n)O(n)O(n)
- 多点求值:O(nlog2n)O(n \log^2 n)O(nlog2n)
- 计算权重:O(nlogmod)O(n \log mod)O(nlogmod)
- 插值合并:O(nlog2n)O(n \log^2 n)O(nlog2n)
总复杂度:O(nlog2n)O(n \log^2 n)O(nlog2n)
十二、总结
快速插值算法结合了:
- 分治思想
- NTT 加速多项式乘法
- 多项式求逆
- 多项式取模
- 多点求值技巧
每个部分都需要精心实现,注意边界条件和优化细节。通过分治,将 O(n2)O(n^2)O(n2) 的拉格朗日插值优化到 O(nlog2n)O(n \log^2 n)O(nlog2n),是多项式高级算法的经典应用。
AtomGit 是由开放原子开源基金会联合 CSDN 等生态伙伴共同推出的新一代开源与人工智能协作平台。平台坚持“开放、中立、公益”的理念,把代码托管、模型共享、数据集托管、智能体开发体验和算力服务整合在一起,为开发者提供从开发、训练到部署的一站式体验。
更多推荐



所有评论(0)