Codeforces Round #589 Div. 2 E. Another Filling the Grid
問題: https://codeforces.com/contest/1228/problem/E
公式解説: https://codeforces.com/blog/entry/70162
問題概要
$n \times n$ のグリッドの各マスに,1 以上 $k ( \leq {10}^9 )$以下の整数値を書いていく.すべての行の最小値が 1 かつすべての列の最小値が 1 であるような整数値の書き方の通り数を ${10}^9 + 7$ で割ったものを求めよ.
解法概要
問題の条件は結局「各行各列に 1 つ以上 1 が含まれる」と等価なので,以下はこの条件で考える.
次のような dp を考える.
$ dp_{i, j} = i$ 行目まですべての列を埋めたとき, 1が 1 つでも入っているような列が $ j $ 個ぴったりあるような埋め方の通り数
$i-1$ 行目までで $n$ 列中 $j_1$ に 1 が含まれていて, $i$ 行目までだと $n$ 列中 $j_2$ (ただし $j_1 \leq j_2$) に 1 が含まれているような埋め方の通り数は,次のように考えられる.
- $i-1$ 行目までで 1 が含まれず,$i$ 行目で 1 が含まれる列の選び方は${}_{n-j_1} C_{j_2 - j_1}$ 通り
- $i-1$ 行目までで 1 が含まれているような列は$i$ 行目で何を選んでもよい.つまり $k$ 種類の値が入りうる.
- $i-1$ 行目までで 1 が含まれず,$i$ 行目で 1 が含まれない列は1 以外の値を選ばなければならない.つまり $k-1$ 種類の値が入りうる.
- $j_1 = n$ のときはすでにすべての列に 1 が含まれるが,$i$ 行目も少なくとも 1 つ 1 を含まなければならない.
これらのことから,
$$dp_{i, j} = \sum_{j^\prime = 1}^{j} \left( dp_{i - 1, j^\prime} \times {}_{n - j} C_{j^\prime - j} \times k^{j^\prime} \times (k - 1)^{n - j} \right) - dp_{i - 1, j} \times (k - 1)^n$$
が成り立つ.
あとは式をそのまま実装すれば良いのだが,$k$ や $k - 1$ の累乗を事前計算しておかないと出てくるたびに繰り返し二乗法を行う羽目になり最悪時間計算量が $O(n^3 \log n)$ となり TLE してしまうので注意.
ソースコード
#include <iostream> #include <vector> using namespace std; using ll = long long; constexpr ll MOD = 1000000007; constexpr int N_MAX = 250; ll n, k; ll fact[N_MAX+1], rfact[N_MAX+1]; vector<ll> kpow(N_MAX+1, 1), k1pow(N_MAX+1, 1); inline ll perm(ll n, ll r){ return (fact[n] * rfact[r]) % MOD; } inline ll comb(ll n, ll r){ return (perm(n, r) * rfact[n-r]) % MOD; } inline void init(ll n){ for(int i=1;i<=n;++i) { kpow[i] = (kpow[i-1] * k) % MOD; k1pow[i] = (k1pow[i-1] * (k-1)) % MOD; } fact[0] = fact[1] = 1; rfact[0] = rfact[1] = 1; for(int i=2;i<=n;++i) { fact[i] = (fact[i-1] * (ll)i) % MOD; rfact[i] = 1; ll k = MOD-2; ll a = fact[i]; while(k > 0){ if(k & 1){ rfact[i] *= a; rfact[i] %= MOD; } a *= a; a %= MOD; k >>= 1; } } } inline ll modpow(ll a, ll t) { if(a == k) { return kpow[t]; } else if(a == k-1) { return k1pow[t]; } ll ret = 1LL; while(t){ if(t & 1LL){ ret *= a; ret %= MOD; } a *= a; a %= MOD; t >>= 1; } return ret; } int main() { cin >> n >> k; init(n); vector<vector<ll>> dp(n, vector<ll>(n+1, 0)); for(int j=1;j<=n;++j) { dp[0][j] = (comb(n, j) * modpow(k-1, n-j)) % MOD; } for(int i=1;i<n;++i) { for(int j=1;j<=n;++j) { for(int j_=1;j_<=j;++j_) { dp[i][j] += (((dp[i-1][j_] * comb(n-j_, j-j_)) % MOD) * ((modpow(k, j_) * modpow(k-1, n-j)) % MOD)) % MOD; dp[i][j] %= MOD; } // j_ = j の場合,足し過ぎてしまうので引く dp[i][j] += MOD - ((dp[i-1][j] * modpow(k-1, n)) % MOD); dp[i][j] %= MOD; } } cout << dp[n-1][n] << endl; }
所感
log は定数じゃない,それはそう.