【loj6261】一个人的高三楼 NTT+组合数学

花果子念报 2018-04-04

题目描述

给出一个长度为 $n$ 的数列,求它 $k$ 阶前缀和的每一项 $S_i^{(k)}$ 模 $998244353$ 的结果。

$n\le 10^5$ ,$k\le 2^{60}$ 。


题解

NTT+组合数学

设 $F_k(x)=\sum\limits_{i=0}^{n-1}S_{i+1}^{(k)}x^i$ ,

设 $G(x)=\sum\limits_{i=0}^{n-1}x^i$ 。

那么可以发现 $F_k(x)\times G(x)\equiv F_{k+1}(x)\ (\text{mod}\ x^n)$ 。

因此所求的 $S_i^{(k)}$ 就是 $(F_0(x)\times (G(x))^k)[i-1]$ ,其中 $F_0(x)$ 为 $0$ 阶前缀和即原数列的生成函数 $\sum\limits_{i=0}^{n-1}a_{i+1}x^i$ 。

使用多项式快速幂,时间复杂度为 $O(n\log n\log k)$ ,常数较大可能无法通过。

进一步考虑 $(G(x))^k[i]$ 的意义:根据卷积的定义,相当于有 $k$ 个数,每一个都在 $[0,k-1]$ 之间,每个数的和为 $i$ 的方案数。由于 $i\le k-1$ ,因此每个数范围的上界可以忽略,就相当于 $k$ 个自然数的和为 $i$ 的方案数。

这个方案数通过组合数学的插板法求出:$i+k-1$ 个位置选出 $k-1$ 个作为板,其余每一段作为每个数的大小。方案数为 $C_{i+k-1}^{k-1}$ 。

那么如何求这些组合数呢?根据 $C_n^m=\frac{n!}{m!(n-m)!}$ 有公式 $C_n^m=C_{n-1}^m\times\frac n{n-m}$ 。因此初始第 $0$ 项为 $1$ ,第 $i$ 项可直接由第 $i-1$ 项推出。

剩下的就好办了,直接使用NTT求 $F_0(x)$ 和 $(G(x))^k$ 的卷积即可。

时间复杂度 $O(n\log n)$ 。

#include <cstdio>
#include <algorithm>
#define N 262155
#define mod 998244353
using namespace std;
typedef long long ll;
ll a[N] , b[N];
ll pow(ll x , ll y)
{
	ll ans = 1;
	while(y)
	{
		if(y & 1) ans = ans * x % mod;
		x = x * x % mod , y >>= 1;
	}
	return ans;
}
void ntt(ll *a , int n , ll flag)
{
	int i , j , k;
	for(i = k = 0 ; i < n ; i ++ )
	{
		if(i > k) swap(a[i] , a[k]);
		for(j = n >> 1 ; (k ^= j) < j ; j >>= 1);
	}
	for(k = 2 ; k <= n ; k <<= 1)
	{
		ll wn = pow(3 , (mod - 1) / k * flag);
		for(i = 0 ; i < n ; i += k)
		{
			ll w = 1 , t;
			for(j = i ; j < i + (k >> 1) ; j ++ , w = w * wn % mod)
				t = w * a[j + (k >> 1)] % mod , a[j + (k >> 1)] = (a[j] - t + mod) % mod , a[j] = (a[j] + t) % mod;
		}
	}
}
int main()
{
	int n , len = 1 , i;
	ll k;
	scanf("%d%lld" , &n , &k) , k %= mod;
	for(i = 0 ; i < n ; i ++ ) scanf("%lld" , &a[i]);
	b[0] = 1;
	for(i = 1 ; i < n ; i ++ ) b[i] = b[i - 1] * (k + i - 1) % mod * pow(i , mod - 2) % mod;
	while(len < 2 * n) len <<= 1;
	ntt(a , len , 1) , ntt(b , len , 1);
	for(i = 0 ; i < len ; i ++ ) a[i] = a[i] * b[i] % mod;
	ntt(a , len , mod - 2);
	for(i = 0 ; i < n ; i ++ ) printf("%lld\n" , a[i] * pow(len , mod - 2) % mod);
	return 0;
}