题面
一根长为 n 的无色纸条,每个位置依次编号为 1,2,3,…,n ,m 次操作,第 i 次操作把纸条的一段区间 [l,r] (l <= r , l,r ∈ {1,2,3,…,n})涂成颜色 i ,最后一定要把纸条涂满颜色,问最终的纸条有多少种可能的模样。
输入为两个数 n,m ,输出为你的答案
m <= n <= 1e6
题解
不考虑先前染的颜色被覆盖这件事情。如果某种颜色在最终的序列中出现了 x 次,那么我们就直接认为在染这种颜色的时候,我们只染了 x 个格子。
但这样一来每次染色的格子就不再是连续的一段了。不过如果我们把给一段格子染色认为是在已被染色的颜色序列中插入一段,那么一切都显得简单而明晰了!
首先我们可以想到一个 DP, d p [ i ] [ j ] dp[i][j] dp[i][j] 表示纸条长度为 j,强制最后出现 i 种颜色时,这 i 种颜色的方案数(也就是说先不乘 C ( m , i ) C(m,i) C(m,i) 之类的),那么有如下转移:
d p [ i ] [ j ] ⋅ ( j + 1 ) → d p [ i + 1 ] [ k ] ( k > j ) dp[i][j]\cdot (j+1)\;\rightarrow\;dp[i+1][k](k>j) dp[i][j]⋅(j+1)→dp[i+1][k](k>j)
(可以看作是在格子间隙中插入了一段颜色为 i+1 的)
我们最后要求的是 ∑ i = 1 m d p [ i ] [ n ] ∗ C ( m − 1 , i − 1 ) \sum_{i=1}^{m}dp[i][n]*C(m-1,i-1) ∑i=1mdp[i][n]∗C(m−1,i−1) 因为最后一种颜色必须出现所以是 C ( m − 1 , i − 1 ) C(m-1,i-1) C(m−1,i−1)。
用前缀和优化可以做到 n 方,接下来我们想想怎么优化。
化式子无比艰难,我们不如感性分析一下。看上方的转移,从 d p [ 0 ] [ 0 ] dp[0][0] dp[0][0] 转移过来,我们首先选了一个数 k > 0 转移到了 d p [ 1 ] [ k ] dp[1][k] dp[1][k],然后还可以再选一个数 k’ > k 来转移到 d p [ 2 ] [ k ′ ] dp[2][k'] dp[2][k′],此时的贡献为 ( k + 1 ) (k+1) (k+1),于是再选个数 k’’ > k’,对 d p [ 3 ] [ k ′ ′ ] dp[3][k''] dp[3][k′′] 产生 ( k + 1 ) ( k ′ + 1 ) (k+1)(k'+1) (k+1)(k′+1) 的贡献…… 对于 d p [ i ] [ n ] dp[i][n] dp[i][n] 来说,相当于我们在 [ 1 , n − 1 ] [1,n-1] [1,n−1] 中选了 i − 1 i-1 i−1 个数,把它们都+1,然后乘起来,这样的所有方案的乘积和。
所以, d p [ i ] [ n ] dp[i][n] dp[i][n] 就等于 F n − 1 = ∏ i = 1 n − 1 ( x + ( i + 1 ) ) F_{n-1}=\prod_{i=1}^{n-1}(x+(i+1)) Fn−1=∏i=1n−1(x+(i+1)) 这个多项式的 n − i n - i n−i 次项,这里稍微转化理解一下,不难明白,相当于不选产生 1 的贡献,选产生 i+1 的贡献,k 次项系数表示 k 个数不选。
于是可以分治 NTT 做,求出这个多项式, O ( n log 2 n ) O(n\log^2n) O(nlog2n),很可惜还是过不了。
实际上我们乘的这 n-1 个多项式是有规律的,我们可以推一推。
假设我们已经求出了 F t F_t Ft ,我们要求 F 2 t F_{2t} F2t ,有这个式子:
F 2 t = ∏ i = 1 2 t ( x + i + 1 ) = ∏ i = 1 t ( x + i + 1 ) ∏ i = 1 t ( x + ( i + t ) + 1 ) = F t ∏ i = 1 t ( x + ( i + t ) + 1 ) F_{2t}=\prod_{i=1}^{2t}(x+i+1)=\prod_{i=1}^{t}(x+i+1)\prod_{i=1}^{t}(x+(i+t)+1)\\ =F_t\prod_{i=1}^{t}(x+(i+t)+1) F2t=∏i=12t(x+i+1)=∏i=1t(x+i+1)∏i=1t(x+(i+t)+1)=Ft∏i=1t(x+(i+t)+1)
不妨就设右边那坨为 F t ′ F_t' Ft′,那么 F 2 t = F t ⋅ F t ′ F_{2t}=F_t\cdot F_{t}' F2t=Ft⋅Ft′ ,我们知道 F t ′ F_t' Ft′ 的话就可以 NTT 了,现在来推 F t ′ F_t' Ft′:
我们先令 X = x + t X = x+t X=x+t 来换个元,
F t ′ = ∏ i = 1 t ( x + i + t + 1 ) = ∏ i = 1 t ( ( x + t ) + i + 1 ) = ∏ i = 1 t ( X + i + 1 ) F_t'=\prod_{i=1}^{t}(x+i+t+1)=\prod_{i=1}^{t}((x+t)+i+1)=\prod_{i=1}^{t}(X+i+1) Ft′=∏i=1t(x+i+t+1)=∏i=1t((x+t)+i+1)=∏i=1t(X+i+1)
然后会发现它是跟 F t F_t Ft 一样的形式,我们把 F t F_t Ft 的每项系数带入:
F t ′ = ∑ i = 0 t F t [ i ] X i = ∑ i = 0 t A [ i ] ( x + t ) i F_t'=\sum_{i=0}^{t}F_t[i]X^i=\sum_{i=0}^{t}A[i](x+t)^i Ft′=∑i=0tFt[i]Xi=∑i=0tA[i](x+t)i
利用二项式定理变成这样:
∑ i = 0 t F t [ i ] ∑ j = 0 i x j t i − j C ( i , j ) \sum_{i=0}^{t}F_t[i]\sum_{j=0}^{i}x^jt^{i-j}C(i,j) ∑i=0tFt[i]∑j=0ixjti−jC(i,j)
换个枚举顺序:
∑ j = 0 t x j ∑ i = j t F t [ i ] t i − j C ( i , j ) \sum_{j=0}^{t}x^j\sum_{i=j}^{t}F_t[i]t^{i-j}C(i,j) ∑j=0txj∑i=jtFt[i]ti−jC(i,j)
我们令 B [ i ] = F t [ t − i ] B[i]=F_t[t-i] B[i]=Ft[t−i] 翻转一下:
∑ j = 0 t x j ∑ i = j t B [ t − i ] t i − j C ( i , j ) \sum_{j=0}^{t}x^j\sum_{i=j}^{t}B[t-i]t^{i-j}C(i,j) ∑j=0txj∑i=jtB[t−i]ti−jC(i,j)
→ ∑ j = 0 t x j ∑ i = j t B [ t − i ] t i − j i ! j ! ( i − j ! ) \rightarrow\sum_{j=0}^{t}x^j\sum_{i=j}^{t}B[t-i]t^{i-j}\frac{i!}{j!(i-j!)} →∑j=0txj∑i=jtB[t−i]ti−jj!(i−j!)i!
右边就变成了一个卷积的形式,最终这样应该更好理解:
F t ′ = ∑ j = 0 t ( 1 j ! ⋅ ( ∑ i = j t B [ t − i ] t i − j i ! ( i − j ) ! ) ) ⋅ x j F_t'=\sum_{j=0}^{t}\left(\frac{1}{j!}\cdot(\sum_{i=j}^{t}B[t-i]t^{i-j}\frac{i!}{(i-j)!})\right)\cdot x^{j} Ft′=∑j=0t(j!1⋅(∑i=jtB[t−i]ti−j(i−j)!i!))⋅xj
求出它之后我们就可以求出 F 2 t F_{2t} F2t 了,这样递归,复杂度 T ( n ) = T ( n 2 ) + O ( n log n ) = O ( n log n ) T(n)=T(\frac{n}{2})+O(n\log n)=O(n\log n) T(n)=T(2n)+O(nlogn)=O(nlogn) ,可以过。
CODE
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
#define MAXN 1000005
#define DB double
#define LL long long
#define ENDL putchar('\n')
LL read() {
LL f=1,x=0;char s = getchar();
while(s < '0' || s > '9') { if(s=='-')f = -f;s = getchar();}
while(s >= '0' && s <= '9') { x=x*10+(s-'0');s = getchar();}
return x*f;
}
const int MOD = 998244353;
const int proot = 3;
int n,m,i,j,s,o,k;
int dp[MAXN];
int fac[MAXN],inv[MAXN],invf[MAXN];
int C(int n,int m) {
if(m > n || n < 0) return 0;
return fac[n] *1ll* invf[n-m] % MOD *1ll* invf[m] % MOD;
}
int xm[MAXN<<2],rev[MAXN<<2],om;
int qkpow(int a,int b) {
int res = 1;
while(b > 0) {
if(b & 1) res = res *1ll* a % MOD;
a = a *1ll* a % MOD; b >>= 1;
}return res;
}
void NTT(int *s,int n,int op) {
for(int i = 1;i < n;i ++) {
rev[i] = ((rev[i>>1]>>1) | ((i & 1) ? (n>>1):0));
if(rev[i] < i) swap(s[rev[i]],s[i]);
}
om = qkpow(proot,(MOD-1)/n); xm[0] = 1;
if(op<0) om = qkpow(om,MOD-2);
for(int i = 1;i <= n;i ++) xm[i] = xm[i-1] *1ll* om % MOD;
for(int k = 2,t = (n>>1);k <= n;k <<= 1,t >>= 1) {
for(int j = 0;j < n;j += k) {
for(int i = j,l=0;i < j+(k>>1);i ++,l += t) {
int A = s[i],B = s[i+(k>>1)];
s[i] = (A + xm[l] *1ll* B % MOD) % MOD;
s[i+(k>>1)] = (A +MOD- xm[l] *1ll* B % MOD) % MOD;
}
}
}
int invn = qkpow(n,MOD-2);
if(op < 0) for(int i = 0;i < n;i ++) s[i] = s[i] *1ll* invn % MOD;
return ;
}
int A[MAXN<<2],B[MAXN<<2],cc[MAXN<<2];
void solve(int n) {
if(n == 1) {
A[0] = 2;A[1] = 1;return ;
}
int md = n>>1;
solve(md);
int po = 1,le = 1;
while(le <= md*2) le <<= 1;
for(int i = 0;i <= md;i ++) {
B[i] = A[md-i] *1ll* fac[md-i] % MOD;
cc[i] = po *1ll* invf[i] % MOD;
po = po *1ll* md % MOD;
}
NTT(B,le,1);NTT(cc,le,1);
for(int i = 0;i <= le;i ++) B[i] = B[i] *1ll* cc[i] % MOD,cc[i] = 0;
NTT(B,le,-1);
for(int i = md+1;i <= le;i ++) B[i] = 0;
for(int i = 0;i <= md;i ++) B[i] = B[i] *1ll* invf[md-i] % MOD;
for(int i = 0;i+i <= md;i ++) swap(B[i],B[md-i]);
while(le <= n) le <<= 1;
NTT(B,le,1); NTT(A,le,1);
for(int i = 0;i <= le;i ++) A[i] = A[i] *1ll* B[i] % MOD,B[i] = 0;
NTT(A,le,-1); for(int i = n+1;i <= le;i ++) A[i] = 0;
if(n & 1) {
for(int i = n;i > 0;i --) {
A[i] = (A[i] *1ll* (n+1) % MOD + A[i-1]) % MOD;
}
A[0] = A[0] *1ll* (n+1) % MOD;
}
return ;
}
int main() {
freopen("color.in","r",stdin);
freopen("color.out","w",stdout);
n = read();m = read();
fac[0] = fac[1] = inv[0] = inv[1] = invf[0] = invf[1] = 1;
for(int i = 2;i <= max(n,m);i ++) {
fac[i] = fac[i-1] *1ll* i % MOD;
inv[i] = (MOD-inv[MOD % i]) *1ll* (MOD/i) % MOD;
invf[i] = invf[i-1] *1ll* inv[i] % MOD;
}
solve(n-1);
int ans = 0;
for(int i = 1;i <= min(n,m);i ++) {
(ans += A[n-i] *1ll* C(m-1,i-1) % MOD) %= MOD;
}
printf("%d",ans);
return 0;
}