NoiminのNoise

競技プログラミング (多め) とWeb (たまに) ,自然言語処理 (ブログではまだ)。数式の書き方を一気に KaTeX に変えようとして記事を全削除してインポートし直すなどしたので,過去にブックマークされた記事は URL が変わってしまっている可能性があります…….

AtCoder Regular Contest 075 E - Meaningful Mean

E - Meaningful Mean

問題概要

長さNの整数列$ a=\{a_1, a_2, \cdots, a_N \} $の空でない連続する部分列について,算術平均がK以上のものの数を求める.

解法概要

「算術平均がK以上の部分列」が満たす条件を変形してみる.

$$ \begin{align} \frac{1}{r-l+1}\sum_{i=l}^{r} a_i &\leq K \\ \sum_{i=l}^{r} a_i &\leq (r-l+1) K \\ \sum_{i=1}^{r} a_i - \sum_{i=1}^{l - 1} a_i &\leq r K - ( l - 1 ) K \\ \sum_{i=1}^{r} a_i - r K &\leq \sum_{i=1}^{l - 1} a_i - ( l - 1 ) K \end{align} $$

つまり,$ \sum_{i=1}^{m} a_i - m K $を各要素について求めて,その大小関係を見ればよいということになる.ちなみにこの式変形のような「条件式の中でごっちゃになってる2つの要素を右辺と左辺に切り離すことで,考えるべき特徴量を1つだけにする」という考え方はCodeforces Round #478 (Div. 2) D. Ghosts - NoiminのNoiseでも使ったので典型みを感じた.

あとはBITで転倒数を数える要領で上の条件を満たすようなlとrの組の数を求めたいのだが,1つ気をつけなければならない点がある.それは整数列aの要素が $ 1 \leq a_i \leq 10^9$という制約を持っており,普通にBITで転倒数を数えようとするとMLEするほどの大きな配列が必要になってしまうという点である.

これは$ 1 \leq N \leq 2 \times 10^5 $,つまり「$ \sum_{i=1}^{m} a_i - m K $ (以下"特徴量"と呼ぶ) の値の種類も高々200000種類であること」と,「特徴量の大小関係にのみ関心があり,絶対値の大きさは考えなくてよい」ことを利用する.ソートの結果を使うことで,1から$ 10^9$までの値を大小関係を保ったまま1から$ 10^5$までの値に圧縮できる.

値の圧縮さえできれば,あとはBITで転倒数を数える要領で条件を満たすlとrの組を数えるだけ.

ソースコード

#include <iostream>
#include <vector>
#include <algorithm>

#define all(c) c.begin(), c.end()
#define rall(c) c.rbegin(), c.rend()

using namespace std;

typedef long long ll;

class BIT {
    public:
    int n;
    vector<int> bit;
    
    BIT(int n){
        this->n = n;
        bit.resize(n);
        fill(bit.begin(), bit.end(), 0);
    }

    void add(int idx, int x){
        for(int i=idx;i<=this->n;i+=i&-i) bit[i] += x;
    }

    //bit[1]からbit[end]までの和 (閉区間)
    int sum(int end){
        int ret = 0;
        for(int i=end;i>=1;i-=i&-i) ret += bit[i];
        return ret;
    }
};

int main(){
    int n; ll k;
    cin >> n >> k;
    vector<ll> a(n+1, 0);
    for(int i=1;i<=n;i++){
        cin >> a[i];
        a[i] += (i?a[i-1]:0LL) - k;
    }

    vector<ll> sorted_a(a);
    sort(all(sorted_a));

    ll ans = 0;
    BIT tree = BIT(n+2);
    for(int i=n;i>=0;--i){
        int idx = int(lower_bound(all(sorted_a), a[i]) - sorted_a.begin()) + 1;
        if(i != n) ans += ll(n-i - tree.sum(idx-1));
        tree.add(idx, 1);
    }
    cout << ans << endl;
}

感想

最近の私BIT好きすぎでは (それともそれだけBITが (私が解くような難易度帯では) 頻出のデータ構造ということ?)