【日常小测】生日礼物

相关链接

题目传送门:https://oi.qizy.tech/wp-content/uploads/2017/03/4237864728368.png
斯特林数相关:https://en.wikipedia.org/wiki/Stirling_number

解题报告

答案显然就是两个序列卷积起来
其中一个序列由组合数构成:$g_i=\binom{n_2-1}{i-1}\binom{m}{i}$
另一个序列则需要处理出第二类斯特林数
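众所周知,同一行的第二类斯特林数满足卷积形式 $S(n,k)=\sum_{j=0}^{k}\frac{(-1)^j}{j!}\cdot\frac{(k-j)^n}{(k-j)!}$,因此整行 $S(n,0),\dots,S(n,n)$ 可以用一次 $NTT$ 卷积求出
于是一共做两次基于 $NTT$ 的卷积即可,时间复杂度 $O(n \log n)$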

Code

#include<bits/stdc++.h>
#define LL long long
using namespace std;
 
// Capacity of every global table below.
const int N = 600009;
// NTT-friendly prime: MOD - 1 = 3 * 2^18, so power-of-two transforms of
// length up to 2^18 are possible.
const int MOD = 786433;
// Base from which the roots of unity used by ntt() are derived.
const int RT = 13;

// Right-hand operand of the final convolution performed in main().
int g[N];
// Left-hand operand of the final convolution; afterwards holds the values
// printed for each query.
int f[N];
// POW[i] = i! mod MOD (filled by prework()).
int POW[N];
// REV[i] = (i!)^{-1} mod MOD (filled by prework()).
int REV[N];
// Bit-reversal permutation for the current transform length (filled by solve()).
int pos[N];

// Parameters of the current test case, read in main().
int n1, n2, m;
// Not used: main() declares its own local query counter with this name.
int q;
// For 1 <= i < N: vis[i] == 0 exactly when i is prime (filled by prework()).
int vis[N];
// Scratch operands for the Stirling-number convolution in update().
int a[N];
int b[N];
 
// Reads the next decimal integer from stdin.
// Any non-digit characters before it are skipped; a '-' among them makes
// the result negative.
// Returns 0 if end of input is reached before any digit appears.
// The character is held in an int so that EOF can be distinguished from
// real input. A plain char can never compare equal to EOF, so the skip
// loop would otherwise spin forever on truncated input.
inline int read() {
    int c = getchar();
    int sign = 1;   // named `sign` so it does not shadow the global array `f`
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') sign = -1;
        c = getchar();
    }
    int ret = 0;
    while (c >= '0' && c <= '9') {
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
    return ret * sign;
}
 
// Modular exponentiation by repeated squaring: returns w^t mod MOD.
// Pow(w, 0) == 1. Every caller in this file passes w >= 0 and t >= 0.
inline int Pow(int w, int t) {
    long long base = w;
    long long result = 1;
    while (t != 0) {
        if (t & 1)
            result = result * base % MOD;
        base = base * base % MOD;
        t >>= 1;
    }
    return (int)result;
}
 
inline int C(int a, int b) {
    if (a > b) return 0;
    return ((LL)POW[b] * REV[a] % MOD) * REV[b-a] % MOD;
}
 
// One-time table initialisation, called before the first test case.
//   vis[i] : 1 when i == 1 or i is composite, 0 when i is prime (1 <= i < N).
//   POW[i] : i! mod MOD.
//   REV[i] : (i!)^{-1} mod MOD.
// REV is built from a single Fermat inverse of (N-1)! followed by a downward
// sweep REV[i] = REV[i+1] * (i+1). This is valid because N - 1 < MOD, so
// (N-1)! is not divisible by the prime MOD.
inline void prework() {
    // Mark every proper multiple of every step >= 2.
    for (int step = 2; step < N; ++step)
        for (int mult = step + step; mult < N; mult += step)
            vis[mult] = 1;
    vis[1] = 1;

    POW[0] = 1;
    for (int i = 1; i < N; ++i)
        POW[i] = (int)((long long)POW[i - 1] * i % MOD);

    REV[0] = 1;
    REV[N - 1] = Pow(POW[N - 1], MOD - 2);
    for (int i = N - 2; i >= 1; --i)
        REV[i] = (int)(REV[i + 1] * (i + 1LL) % MOD);
}

// In-place iterative radix-2 number-theoretic transform over Z/MOD.
//   a   : coefficient array; only a[0 .. len-1] is read or written.
//   len : a power of two, at most 2^18 (since MOD - 1 = 3 * 2^18). pos[]
//         must already hold the bit-reversal permutation for this len
//         (solve() builds it).
//   rev : 0 selects the forward transform. Non-zero selects the inverse
//         transform, which also scales by len^{-1} and leaves every result
//         in [0, MOD).
// Forward-transform outputs may be negative (within (-MOD, MOD)); they are
// only meaningful modulo MOD.
inline void ntt(int *a, int len, int rev = 0) {
	for (int i = 0; i < len; i++)
		if (pos[i] < i) std::swap(a[i], a[pos[i]]);
	for (int l = 2; l <= len; l <<= 1) {
		// Principal l-th root of unity. RT is a quadratic non-residue mod MOD,
		// so RT^((MOD-1)/l) has order exactly l for each power of two l <= 2^18.
		int wn = Pow(RT, (MOD - 1) / l);
		if (rev) wn = Pow(wn, MOD - 2);
		const int half = l >> 1;
		for (int i = 0; i < len; i += l) {
			int w = 1;
			for (int j = 0; j < half; j++, w = (LL)w * wn % MOD) {
				const int tmp = (LL)w * a[i + j + half] % MOD;
				a[i + j + half] = (a[i + j] - tmp) % MOD;
				a[i + j] = (a[i + j] + tmp) % MOD;
			}
		}
	}
	if (rev) {
		// Scale exactly the transformed range a[0 .. len-1] by len^{-1}.
		// a[len] is outside the transform and must not be touched.
		const int inv = Pow(len, MOD - 2);
		for (int i = 0; i < len; i++)
			a[i] = ((LL)a[i] * inv % MOD + MOD) % MOD;
	}
}

// Polynomial multiplication via NTT.
// The transform length len is the smallest power of two greater than l.
// Afterwards a[0 .. len-1] holds the cyclic product of a and b, which equals
// the ordinary product when deg a + deg b <= l. b is left in transformed
// (frequency) form.
// Assumes entries of a and b above their degree, up to len - 1, are zero
// (the callers clear the buffers with memset).
inline void solve(int l, int *a, int *b) {
	int len = 1;
	int bits = 0;   // log2(len)
	while (len <= l) {
		len <<= 1;
		++bits;
	}
	// Bit-reversal table: reverse i >> 1, then put i's low bit on top.
	for (int i = 0; i < len; ++i) {
		pos[i] = pos[i >> 1] >> 1;
		if (i & 1) pos[i] |= 1 << (bits - 1);
	}
	ntt(a, len);
	ntt(b, len);
	for (int i = 0; i < len; ++i)
		a[i] = (LL)a[i] * b[i] % MOD;
	ntt(a, len, 1);
}
 
inline void update() {
	memset(f,0,sizeof(f)); memset(g,0,sizeof(g));
	memset(a,0,sizeof(a)); memset(b,0,sizeof(b));
    for (int i=1;i<=n2;i++) 
        g[i] = (LL)C(i-1, n2-1) * C(i, m) % MOD;
    for (int i=0;i<=n1;i++) {
		a[i] = (i&1)? MOD-REV[i]: REV[i];
		b[i] = (LL)Pow(i, n1) * REV[i] % MOD;
	}
	solve(n1 << 1, a, b);
    for (int i=1;i<=n1;i++) {
        if (vis[i]) f[i] = i==n1? 1: C(i-1, n1);
        else f[i] = a[i]; 
    }
}
 
// Entry point. Builds the lookup tables once, then handles T test cases.
// Each test case reads n1, n2 and m, builds f and g, convolves them into f,
// and answers the point queries that follow. Each answer is f[k] reduced
// into [0, MOD).
int main() {
    prework();
    int T = read();
    while (T--) {
        n1 = read();
        n2 = read();
        m = read();
        update();
        solve(n1 + n2, f, g);
        int queries = read();
        while (queries--) {
            const int k = read();
            printf("%d\n", (f[k] + MOD) % MOD);
        }
    }
    return 0;
}