NoiminのNoise

競技プログラミング (多め) とWeb (たまに) ,自然言語処理 (ブログではまだ)。数式の書き方を一気に KaTeX に変えようとして記事を全削除してインポートし直すなどしたので,過去にブックマークされた記事は URL が変わってしまっている可能性があります…….

yukicoder No.802 だいたい等差数列

問題文: No.802 だいたい等差数列 - yukicoder

Writer 解説: https://yukicoder.me/problems/no/802/editorial

問題概要

長さ$ N (\leq 3 \times 10^5)$の整数列$ A_1, A_2, \cdots, A_N$で,次の 2 つの条件の両方を満たすものの数をを$ 10^9+7 $で割ったものを求めよ.

  • $ 1 \leq A_1 \leq A_2 \leq \cdots \leq A_N \leq M \leq 10^6$
  • $ 1 \leq i \leq N-1$ なるすべての i について$ D_1 \leq A_{i+1} - A_i \leq D_2 $

解法概要

しばらくサンプルで遊んでみたが行き詰まったので,テスターである heno_code さんのソースコードをチラ見. どうやら包除原理を利用するらしいというヒントを得たのでそこから考えてみた.

各隣接要素間の差 $ A_{i+1} - A_i $ について,条件を満たす / 満たさないものがいくつあるかについて包除原理を適用して計算したい.

このとき,隣接要素間の差の上限と下限の両方に条件付けがされていると扱いづらそうなので,どちらか片方だけにしたい.上限をどうにかするのは難しそうなので,いかなる入力においても隣接要素間の差の最小$ D_1 = 0$となるような問題にあらかじめ変換する.

元の問題を $ (N, M, D_1, D_2) $と表現すると,返還後の問題は $ (N, M - D_1(N-1), 0, D_2 - D_1) $ のように表現できる.これならどうにかなりそう.

あとは,「条件を満たす (または満たさない) "隣接要素間の差" が i 個あるような整数列の数」を求めることができれば,包除原理が適用できる.

これは (N-1) 個ある "隣接要素間の差" のうち i 個を $ D_2 - D_1 + 1$以上であると考えて,条件を満たさない整数列について考えるとやりやすい.

「条件を満たす (または満たさない) "隣接要素間の差" が i 個あるような整数列の数」は,

  • (N-1) 個の "隣接要素間の差" のうち,条件を満たさないもの i 個の選び方が $ {}_{N-1} C_{i}$ 通り
  • 隣接要素間の差のうち i 個が条件を満たさない ($ D_2 - D_1 + 1$以上) であるような整数列は, 条件を満たさないものがどれかが決まっていれば $ {}_{N+1} H_{M_{rest}} = {}_{N+M_{rest}} C_{M_{rest}} $個.ただし$ M_{rest} = M - 1 - (D_2 - D_1 + 1)*i$.あらかじめ i 個の要素に $ D_2 - D_1 + 1$ ずつ値を割り振っておいて,残りを好きに割り振ると考える.ただし,整数列の最初は最低 1 なので 少なくとも 1 は1番最初の要素に割り当てられる (いちいち-1しているのはそのため).

これで「条件を満たす (または満たさない) "隣接要素間の差" が i 個あるような整数列の数」がわかったので,あとは

(条件を満たさないものが 0 個以上の整数列の数) - (条件を満たさないものが 1 個以上の整数列の数) + (条件を満たさないものが 2 個以上の整数列の数) - (条件を満たさないものが 3 個以上の整数列の数) ...

のように包除原理を使って計算すれば良い.

ソースコード

#include <iostream>

using namespace std;

typedef long long ll;

const ll MOD = 1000000007;
const int FACT_MAX = 1300001;

ll fact[FACT_MAX], rfact[FACT_MAX];

ll perm(ll n, ll r){
    return (fact[n] * rfact[r]) % MOD;
}

ll comb(ll n, ll r){
    return (perm(n, r) * rfact[n-r]) % MOD;
}

void init(ll n){
    fact[0] = fact[1] = 1;
    rfact[0] = rfact[1] = 1;
    for(int i=2;i<=n;++i) {
        fact[i] = (fact[i-1] * (ll)i) % MOD;
        rfact[i] = 1;
        ll k = MOD-2;
        ll a = fact[i];
        while(k > 0){
            if(k & 1){
                rfact[i] *= a;
                rfact[i] %= MOD;
            }
            a *= a;
            a %= MOD;
            k  >>= 1;
        }
    }
}

int main() {
    ll N,M,D1,D2;
    cin >> N >> M >> D1 >> D2;

    init(FACT_MAX);

    ll Mrest = M - 1 - D1*(N-1);
    ll ans = 0LL;
    for(int i=0;i<N&&Mrest>=0;++i) {
        ans += MOD + (i%2?-1:1) * (comb(N-1, i) * comb(N+Mrest, Mrest)) % MOD;
        ans %= MOD;
        Mrest -= (D2-D1+1);
    }
    cout << ans << endl;
}

所感

こんな綺麗に解けるのか〜〜. 包除原理の教育的良問では.