相关链接
题目传送门:http://oi.cyo.ng/wp-content/uploads/2017/03/4237864728368.png
斯特林数相关:https://en.wikipedia.org/wiki/Stirling_number
解题报告
答案显然就是两个序列的卷积。
其中一个序列由组合数构成：$g_i=\binom{n_2-1}{i-1}\binom{m}{i}\ (1\le i\le n_2)$
另一个序列则需要处理出第二类斯特林数
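众所周知，第二类斯特林数的一整行可以用$NTT$预处理：由通项公式 $S(n,k)=\sum_{i=0}^{k}\frac{(-1)^i}{i!}\cdot\frac{(k-i)^n}{(k-i)!}$ 可知，$S(n,\cdot)$ 正是序列 $\frac{(-1)^i}{i!}$ 与 $\frac{i^n}{i!}$ 的卷积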
于是做两次$NTT$卷积即可，时间复杂度$O(n \log n)$
Code
// Per test case: build sequences f (length n1) and g (length n2), convolve
// them with an NTT modulo 786433, then answer point queries on the product.
#include <cstdio>
#include <cstring>
#include <utility>
using namespace std;

using LL = long long;

const int N = 600009;
// MOD = 3 * 2^18 + 1. 13 is a quadratic non-residue mod MOD (MOD = 11 mod 13),
// so RT^((MOD-1)/l) is a primitive l-th root of unity for every power of two
// l <= 2^18; transforms longer than 2^18 are NOT supported by this modulus.
const int MOD = 786433;
const int RT = 13;

int g[N], f[N], POW[N], REV[N], pos[N];
int n1, n2, m, vis[N], a[N], b[N];

// Reads one (optionally negative) decimal integer from stdin.
// Returns 0 if EOF is reached before any digit, instead of spinning forever.
inline int read() {
	int c = getchar();
	int ret = 0, sign = 1;
	while (c != EOF && (c < '0' || c > '9')) {
		if (c == '-') sign = -1;
		c = getchar();
	}
	while (c >= '0' && c <= '9') {
		ret = ret * 10 + (c - '0');
		c = getchar();
	}
	return ret * sign;
}

// Fast modular exponentiation: w^t mod MOD (t >= 0).
inline int Pow(int w, int t) {
	int ret = 1;
	for (; t; t >>= 1, w = (LL)w * w % MOD)
		if (t & 1) ret = (LL)ret * w % MOD;
	return ret;
}

// Binomial coefficient binom(n, k) mod MOD -- note the argument order:
// C(k, n) == binom(n, k). Returns 0 when k lies outside [0, n].
inline int C(int k, int n) {
	if (k < 0 || k > n) return 0;
	return (LL)POW[n] * REV[k] % MOD * REV[n - k] % MOD;
}

// One-time tables:
//   vis[i] = 1 for i == 1 and composite i, 0 for primes;
//   POW[i] = i! mod MOD, REV[i] = (i!)^{-1} mod MOD (N < MOD, so i! != 0).
inline void prework() {
	for (int i = 2; i < N; i++) {
		if (vis[i]) continue;  // sieving from primes marks every composite
		for (int j = i * 2; j < N; j += i) vis[j] = 1;
	}
	vis[1] = 1;
	POW[0] = REV[0] = 1;
	for (int i = 1; i < N; i++) POW[i] = (LL)POW[i - 1] * i % MOD;
	REV[N - 1] = Pow(POW[N - 1], MOD - 2);
	for (int i = N - 2; i; i--) REV[i] = REV[i + 1] * (i + 1LL) % MOD;
}

// In-place iterative NTT of poly[0, len). Requires pos[] to hold the
// bit-reversal permutation for len (set up by solve) and entries in [0, MOD).
// rev != 0 computes the inverse transform, including the 1/len scaling.
inline void ntt(int *poly, int len, int rev = 0) {
	for (int i = 0; i < len; i++)
		if (pos[i] < i) swap(poly[i], poly[pos[i]]);
	for (int l = 2; l <= len; l <<= 1) {
		int wn = Pow(RT, (MOD - 1) / l);
		if (rev) wn = Pow(wn, MOD - 2);
		const int half = l >> 1;
		for (int i = 0; i < len; i += l) {
			for (int j = 0, w = 1; j < half; j++, w = (LL)w * wn % MOD) {
				const int tmp = (LL)w * poly[i + j + half] % MOD;
				// "+ MOD" keeps every coefficient in [0, MOD).
				poly[i + j + half] = (poly[i + j] - tmp + MOD) % MOD;
				poly[i + j] = (poly[i + j] + tmp) % MOD;
			}
		}
	}
	if (rev) {
		// Scale all len coefficients, index 0 included.
		const int inv = Pow(len, MOD - 2);
		for (int i = 0; i < len; i++) poly[i] = (LL)poly[i] * inv % MOD;
	}
}

// Cyclic convolution of length len = smallest power of two >= l + 1.
// Result is written to a; b is overwritten with its transform.
// Pass l >= the degree of the full product so no wrap-around occurs.
inline void solve(int l, int *a, int *b) {
	int t = -1, len = 1;
	while (len < l + 1) len <<= 1, t++;
	for (int i = 0; i < len; i++) {
		pos[i] = pos[i >> 1] >> 1;
		if (i & 1) pos[i] |= 1 << t;
	}
	ntt(a, len);
	ntt(b, len);
	for (int i = 0; i < len; i++) a[i] = (LL)a[i] * b[i] % MOD;
	ntt(a, len, 1);
}

// Rebuilds f and g for the current (n1, n2, m):
//   g[i] = binom(n2-1, i-1) * binom(m, i)                  for 1 <= i <= n2
//   f[i] = S(n1, i) (Stirling, 2nd kind)                   for prime i
//   f[i] = (i == n1) ? 1 : binom(n1, i-1)                  for i == 1 / composite
// (The combinatorial meaning of this split comes from the problem statement,
// which is not reproduced here.)
inline void update() {
	memset(f, 0, sizeof(f));
	memset(g, 0, sizeof(g));
	memset(a, 0, sizeof(a));
	memset(b, 0, sizeof(b));
	for (int i = 1; i <= n2; i++) g[i] = (LL)C(i - 1, n2 - 1) * C(i, m) % MOD;
	// S(n1, k) = sum_i [(-1)^i / i!] * [(k-i)^n1 / (k-i)!]  -> one convolution.
	for (int i = 0; i <= n1; i++) {
		a[i] = (i & 1) ? MOD - REV[i] : REV[i];
		b[i] = (LL)Pow(i, n1) * REV[i] % MOD;
	}
	solve(n1 << 1, a, b);  // now a[k] = S(n1, k) for 0 <= k <= n1
	for (int i = 1; i <= n1; i++) {
		if (vis[i]) f[i] = (i == n1) ? 1 : C(i - 1, n1);
		else f[i] = a[i];
	}
}

// Input: T, then per test "n1 n2 m", q, and q indices x; prints (f*g)[x].
int main() {
	prework();
	for (int T = read(); T; T--) {
		n1 = read();
		n2 = read();
		m = read();
		update();
		solve(n1 + n2, f, g);
		for (int q = read(); q; q--) printf("%d\n", (f[read()] + MOD) % MOD);
	}
	return 0;
}